├── .gitignore ├── .travis.yml ├── Cargo.toml ├── LICENSE-APACHE ├── LICENSE-MIT ├── README.md ├── k_means ├── Cargo.toml ├── README.md └── src │ ├── lib.rs │ └── main.rs ├── linear_regression ├── Cargo.toml ├── README.md └── src │ ├── lib.rs │ └── main.rs └── test.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # Generated by Cargo 2 | # will have compiled files and executables 3 | /target/ 4 | 5 | # Remove Cargo.lock from gitignore if creating an executable, leave it for libraries 6 | # More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html 7 | Cargo.lock 8 | 9 | # These are backup files generated by rustfmt 10 | **/*.rs.bk 11 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: rust 2 | # use trusty for newer openblas 3 | sudo: required 4 | dist: trusty 5 | matrix: 6 | include: 7 | - rust: 1.32.0 8 | env: 9 | - FEATURES='openblas' 10 | - RUSTFLAGS='-D warnings' 11 | - rust: stable 12 | env: 13 | - FEATURES='openblas' 14 | - RUSTFLAGS='-D warnings' 15 | - rust: beta 16 | env: 17 | - FEATURES='openblas' 18 | - CHANNEL='beta' 19 | - RUSTFLAGS='-D warnings' 20 | - rust: nightly 21 | env: 22 | - FEATURES='openblas' 23 | - CHANNEL='nightly' 24 | env: 25 | global: 26 | - HOST=x86_64-unknown-linux-gnu 27 | - CARGO_INCREMENTAL=0 28 | addons: 29 | apt: 30 | packages: 31 | - libopenblas-dev 32 | - gfortran 33 | script: 34 | - ./test.sh "$FEATURES" "$CHANNEL" -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [workspace] 2 | members = [ 3 | "./", 4 | "linear_regression", 5 | "k_means", 6 | ] -------------------------------------------------------------------------------- /LICENSE-APACHE: 
-------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /LICENSE-MIT: -------------------------------------------------------------------------------- 1 | Copyright (c) 2015 - 2018 Ulrik Sverdrup "bluss", 2 | Jim Turner, 3 | and ndarray developers 4 | 5 | Permission is hereby granted, free of charge, to any 6 | person obtaining a copy of this software and associated 7 | documentation files (the "Software"), to deal in the 8 | Software without restriction, including without 9 | limitation the rights to use, copy, modify, merge, 10 | publish, distribute, sublicense, and/or sell copies of 11 | the Software, and to permit persons to whom the Software 12 | is furnished to do so, subject to the following 13 | conditions: 14 | 15 | The above copyright notice and this permission notice 16 | shall be included in all copies or substantial portions 17 | of the Software. 18 | 19 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF 20 | ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED 21 | TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A 22 | PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT 23 | SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 24 | CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 25 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR 26 | IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 27 | DEALINGS IN THE SOFTWARE. 28 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ndarray-examples 2 | 3 | A collection of examples leveraging the `ndarray` ecosystem. 4 | 5 | Each example folder contains a description and instructions on how to run it. Do not run `cargo run` or `cargo build` from the top level folder! 6 | 7 | Table of contents: 8 | 9 | - [Linear regression](https://github.com/rust-ndarray/ndarray-examples/tree/master/linear_regression) 10 | - [K-Means clustering](https://github.com/rust-ndarray/ndarray-examples/tree/master/k_means) 11 | 12 | ## Contributing 13 | 14 | New examples are welcome! 15 | 16 | Please open an issue to discuss your example proposal and get involved! 17 | 18 | ## License 19 | 20 | Dual-licensed to be compatible with the Rust project. 21 | 22 | Licensed under the Apache License, Version 2.0 23 | http://www.apache.org/licenses/LICENSE-2.0 or the MIT license 24 | http://opensource.org/licenses/MIT, at your 25 | option. This file may not be copied, modified, or distributed 26 | except according to those terms. 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /k_means/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "k_means" 3 | version = "0.1.0" 4 | authors = ["LukeMathWalker"] 5 | edition = "2018" 6 | workspace = ".." 
7 | 8 | [dependencies] 9 | ndarray = {version = "0.13"} 10 | ndarray-stats = "0.3" 11 | ndarray-rand = "0.11" 12 | rand = "0.7" 13 | -------------------------------------------------------------------------------- /k_means/README.md: -------------------------------------------------------------------------------- 1 | K-Means 2 | ======= 3 | 4 | An implementation of K-Means clustering using the [standard algorithm](https://en.wikipedia.org/wiki/K-means_clustering). 5 | 6 | You can run the example using 7 | ```sh 8 | cargo run 9 | ``` 10 | -------------------------------------------------------------------------------- /k_means/src/lib.rs: -------------------------------------------------------------------------------- 1 | #![allow(non_snake_case)] 2 | use ndarray::{Array1, Array2, ArrayBase, Axis, Data, Ix1, Ix2}; 3 | use ndarray_stats::DeviationExt; 4 | use std::collections::HashMap; 5 | 6 | /// K-means clustering aims to partition a set of observations 7 | /// into `self.n_clusters` clusters, where each observation belongs 8 | /// to the cluster with the nearest mean. 9 | /// 10 | /// The mean of the points within a cluster is called *centroid*. 11 | /// 12 | /// Given the set of `centroids`, you can assign an observation to a cluster 13 | /// choosing the nearest centroid. 14 | /// 15 | /// Details on the algorithm can be found [here](https://en.wikipedia.org/wiki/K-means_clustering). 16 | /// 17 | /// We are implementing the _standard algorithm_. 18 | pub struct KMeans { 19 | /// Our set of centroids. 20 | /// 21 | /// Before `fit` is called, it's set to `None`. 22 | /// 23 | /// Once `fit` is called, we will have our set of centroids: the `centroids` matrix 24 | /// has shape `(n_clusters, n_features)`. 25 | pub centroids: Option>, 26 | /// The number of clusters we are trying to subdivide our observations into. 27 | /// It's set before-hand. 
28 | n_clusters: u16, 29 | } 30 | 31 | impl KMeans { 32 | /// The number of clusters we are looking for has to be chosen before-hand, 33 | /// hence we take it as a constructor parameter of our `KMeans` struct. 34 | pub fn new(n_clusters: u16) -> KMeans { 35 | KMeans { 36 | centroids: None, 37 | n_clusters, 38 | } 39 | } 40 | 41 | /// Given an input matrix `X`, with shape `(n_samples, n_features)` 42 | /// `fit` determines `self.n_clusters` centroids based on the training data distribution. 43 | /// 44 | /// `self` is modified in place (`self.centroids` is mutated), nothing is returned. 45 | pub fn fit(&mut self, X: &ArrayBase) 46 | where 47 | A: Data, 48 | { 49 | let (n_samples, _) = X.dim(); 50 | assert!( 51 | n_samples >= self.n_clusters as usize, 52 | "We need more sample points than clusters!" 53 | ); 54 | 55 | let tolerance = 1e-3; 56 | 57 | // Initialisation: we use the Forgy method - we pick `self.n_clusters` random 58 | // observations and we use them as our first set of centroids. 59 | let mut centroids = KMeans::get_random_centroids(self.n_clusters, X); 60 | 61 | // Keep repeating the assignment-update steps until we have convergence 62 | loop { 63 | // Assignment step: associate each sample to the closest centroid 64 | let cluster_memberships = X.map_axis(Axis(1), |sample| { 65 | KMeans::find_closest_centroid(¢roids, &sample) 66 | }); 67 | 68 | // Update step: using the newly computed `cluster_memberships`, 69 | // compute the new centroids, the means of our clusters 70 | let new_centroids = 71 | KMeans::compute_centroids(&X, &cluster_memberships); 72 | 73 | // Check the convergence condition: if the new centroids, 74 | // after the assignment-update cycle, are closer to the old centroids 75 | // than a pre-established tolerance we are finished. 
76 | let distance = centroids.sq_l2_dist(&new_centroids).unwrap(); 77 | let has_converged = distance < tolerance; 78 | 79 | centroids = new_centroids; 80 | 81 | if has_converged { 82 | break; 83 | } 84 | } 85 | 86 | // Set `self.centroids` to the outcome of our optimisation process. 87 | self.centroids = Some(centroids); 88 | } 89 | 90 | /// Given our observations, `X`, and the index to which each observation belongs, 91 | /// stored in `cluster_memberships`, 92 | /// we want to compute the mean of all observations in each cluster. 93 | fn compute_centroids( 94 | X: &ArrayBase, 95 | cluster_memberships: &ArrayBase, 96 | ) -> Array2 97 | where 98 | A: Data, 99 | B: Data, 100 | { 101 | let (_, n_features) = X.dim(); 102 | // `centroids` is a cluster index -> rolling mean mapping. 103 | // We will update it while we iterate over our observations. 104 | let mut centroids: HashMap = HashMap::new(); 105 | 106 | // We iterate over our observations and cluster memberships in lock-step. 107 | let iterator = X.genrows().into_iter().zip(cluster_memberships.iter()); 108 | for (row, cluster_index) in iterator { 109 | // If we have already encountered an observation that belongs to the 110 | // `cluster_index`th cluster, we retrieve the current rolling mean (our new centroid) 111 | // and we update it using the current observation. 112 | if let Some(rolling_mean) = centroids.get_mut(cluster_index) { 113 | rolling_mean.accumulate(&row); 114 | } else { 115 | // If we have not yet encountered an observation that belongs to the 116 | // `cluster_index`th cluster, we set its centroid to `row`, 117 | // initialising our `RollingMean` accumulator. 118 | let new_centroid = RollingMean::new(row.to_owned()); 119 | // .to_owned takes our `row` view as input and returns an owned array. 120 | centroids.insert(*cluster_index, new_centroid); 121 | } 122 | } 123 | 124 | // Convert our `HashMap` into a 2d array. 
125 | let mut new_centroids: Array2 = Array2::zeros((centroids.len(), n_features)); 126 | for (cluster_index, centroid) in centroids.into_iter() { 127 | let mut new_centroid = new_centroids.index_axis_mut(Axis(0), cluster_index); 128 | // .assign sets each element of `new_centroid` 129 | // to the corresponding element in `centroid.current_mean`. 130 | new_centroid.assign(¢roid.current_mean); 131 | } 132 | 133 | new_centroids 134 | } 135 | 136 | fn get_random_centroids(n_clusters: u16, X: &ArrayBase) -> Array2 137 | where 138 | A: Data, 139 | { 140 | let (n_samples, _) = X.dim(); 141 | let mut rng = rand::thread_rng(); 142 | let indices = rand::seq::index::sample(&mut rng, n_samples, n_clusters as usize).into_vec(); 143 | X.select(Axis(0), &indices) 144 | } 145 | 146 | fn find_closest_centroid( 147 | centroids: &ArrayBase, 148 | sample: &ArrayBase, 149 | ) -> usize 150 | where 151 | A: Data, 152 | B: Data, 153 | { 154 | let mut iterator = centroids.genrows().into_iter(); 155 | 156 | let first_centroid = iterator.next().expect("No centroids - degenerate case!"); 157 | let mut closest_index = 0; 158 | let mut minimum_distance = sample.sq_l2_dist(&first_centroid).unwrap(); 159 | 160 | for (index, centroid) in iterator.enumerate() { 161 | let distance = sample.sq_l2_dist(¢roid).unwrap(); 162 | if distance < minimum_distance { 163 | minimum_distance = distance; 164 | // We skipped the first centroid in the for loop 165 | closest_index = index + 1; 166 | } 167 | } 168 | 169 | closest_index 170 | } 171 | } 172 | 173 | struct RollingMean { 174 | pub current_mean: Array1, 175 | n_samples: u64, 176 | } 177 | 178 | impl RollingMean { 179 | pub fn new(first_sample: Array1) -> Self { 180 | RollingMean { 181 | current_mean: first_sample, 182 | n_samples: 1, 183 | } 184 | } 185 | 186 | pub fn accumulate(&mut self, new_sample: &ArrayBase) 187 | where 188 | A: Data, 189 | { 190 | let mut increment: Array1 = &self.current_mean - new_sample; 191 | increment.mapv_inplace(|x| x / 
(self.n_samples + 1) as f64); 192 | self.current_mean -= &increment; 193 | self.n_samples += 1; 194 | } 195 | } 196 | -------------------------------------------------------------------------------- /k_means/src/main.rs: -------------------------------------------------------------------------------- 1 | #![allow(non_snake_case)] 2 | use ndarray::{stack, Array, Array2, Axis}; 3 | use ndarray_rand::RandomExt; 4 | use ndarray_rand::rand_distr::Normal; 5 | 6 | // Import KMeans from other file ("lib.rs") in this example 7 | use k_means::KMeans; 8 | 9 | /// It returns a data distribution. 10 | /// 11 | /// The data is clearly centered around two distinct points, 12 | /// to quickly spot if something is wrong with the KMeans algorithm 13 | /// looking at the output. 14 | fn get_data(n_samples: usize, n_features: usize) -> Array2 { 15 | let shape = (n_samples / 2, n_features); 16 | let X1: Array2 = Array::random(shape, Normal::new(1000., 0.1).unwrap()); 17 | let X2: Array2 = Array::random(shape, Normal::new(-1000., 0.1).unwrap()); 18 | stack(Axis(0), &[X1.view(), X2.view()]).unwrap().to_owned() 19 | } 20 | 21 | pub fn main() { 22 | let n_samples = 50000; 23 | let n_features = 3; 24 | let n_clusters = 2; 25 | 26 | let X = get_data(n_samples, n_features); 27 | 28 | let mut k_means = KMeans::new(n_clusters); 29 | k_means.fit(&X); 30 | 31 | println!("The centroids are {:.3}", k_means.centroids.unwrap()); 32 | } 33 | -------------------------------------------------------------------------------- /linear_regression/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "linear_regression" 3 | version = "0.1.0" 4 | authors = ["LukeMathWalker"] 5 | edition = "2018" 6 | workspace = ".." 
7 | 8 | [features] 9 | default = [] 10 | openblas = ["ndarray-linalg/openblas"] 11 | intel-mkl = ["ndarray-linalg/intel-mkl"] 12 | netlib = ["ndarray-linalg/netlib"] 13 | 14 | [dependencies] 15 | ndarray = {version = "0.13", features = ["blas"]} 16 | ndarray-linalg = {version = "0.12", optional = true, default-features = false} 17 | ndarray-stats = "0.3" 18 | ndarray-rand = "0.11" 19 | rand = "0.7" 20 | -------------------------------------------------------------------------------- /linear_regression/README.md: -------------------------------------------------------------------------------- 1 | Linear Regression 2 | ================= 3 | 4 | An implementation of vanilla linear regression: it solves the normal equation to determine 5 | the optimal coefficients. 6 | 7 | You can run the example using 8 | ```sh 9 | cargo run --features=<backend> 10 | ``` 11 | where `<backend>` has to be either `openblas`, `netlib` or `intel-mkl`. 12 | 13 | If you want to use OpenBLAS: 14 | ```sh 15 | cargo run --features=openblas 16 | ``` 17 | 18 | See the following section for more details. 19 | 20 | BLAS/LAPACK Backend 21 | =================== 22 | 23 | This example uses `ndarray-linalg`: it therefore needs a BLAS/LAPACK backend in order to compile and run.
24 | 25 | Three BLAS/LAPACK implementations are supported: 26 | 27 | - [OpenBLAS](https://github.com/cmr/openblas-src) 28 | - requires `gfortran` (or another Fortran compiler) 29 | - [Netlib](https://github.com/cmr/netlib-src) 30 | - requires `cmake` and `gfortran` 31 | - [Intel MKL](https://github.com/termoshtt/rust-intel-mkl) (non-free license, see the linked page) 32 | -------------------------------------------------------------------------------- /linear_regression/src/lib.rs: -------------------------------------------------------------------------------- 1 | #![allow(non_snake_case)] 2 | use ndarray::{stack, Array, Array1, ArrayBase, Axis, Data, Ix1, Ix2}; 3 | use ndarray_linalg::Solve; 4 | 5 | /// The simple linear regression model is 6 | /// y = bX + e where e ~ N(0, sigma^2 * I) 7 | /// In probabilistic terms this corresponds to 8 | /// y - bX ~ N(0, sigma^2 * I) 9 | /// y | X, b ~ N(bX, sigma^2 * I) 10 | /// The loss for the model is simply the squared error between the model 11 | /// predictions and the true values: 12 | /// Loss = ||y - bX||^2 13 | /// The maximum likelihood estimation for the model parameters `beta` can be computed 14 | /// in closed form via the normal equation: 15 | /// b = (X^T X)^{-1} X^T y 16 | /// where (X^T X)^{-1} X^T is known as the pseudoinverse or Moore-Penrose inverse. 17 | /// 18 | /// Adapted from: https://github.com/ddbourgin/numpy-ml 19 | pub struct LinearRegression { 20 | pub beta: Option>, 21 | fit_intercept: bool, 22 | } 23 | 24 | impl LinearRegression { 25 | pub fn new(fit_intercept: bool) -> LinearRegression { 26 | LinearRegression { 27 | beta: None, 28 | fit_intercept, 29 | } 30 | } 31 | 32 | /// Given: 33 | /// - an input matrix `X`, with shape `(n_samples, n_features)`; 34 | /// - a target variable `y`, with shape `(n_samples,)`; 35 | /// `fit` tunes the `beta` parameter of the linear regression model 36 | /// to match the training data distribution. 
37 | /// 38 | /// `self` is modified in place, nothing is returned. 39 | pub fn fit(&mut self, X: &ArrayBase, y: &ArrayBase) 40 | where 41 | A: Data, 42 | B: Data, 43 | { 44 | let (n_samples, _) = X.dim(); 45 | 46 | // Check that our inputs have compatible shapes 47 | assert_eq!(y.dim(), n_samples); 48 | 49 | // If we are fitting the intercept, we need an additional column 50 | self.beta = if self.fit_intercept { 51 | let dummy_column: Array = Array::ones((n_samples, 1)); 52 | let X = stack(Axis(1), &[dummy_column.view(), X.view()]).unwrap(); 53 | Some(LinearRegression::solve_normal_equation(&X, y)) 54 | } else { 55 | Some(LinearRegression::solve_normal_equation(X, y)) 56 | }; 57 | } 58 | 59 | /// Given an input matrix `X`, with shape `(n_samples, n_features)`, 60 | /// `predict` returns the target variable according to linear model 61 | /// learned from the training data distribution. 62 | /// 63 | /// **Panics** if `self` has not be `fit`ted before calling `predict. 64 | pub fn predict(&self, X: &ArrayBase) -> Array1 65 | where 66 | A: Data, 67 | { 68 | let (n_samples, _) = X.dim(); 69 | 70 | // If we are fitting the intercept, we need an additional column 71 | if self.fit_intercept { 72 | let dummy_column: Array = Array::ones((n_samples, 1)); 73 | let X = stack(Axis(1), &[dummy_column.view(), X.view()]).unwrap(); 74 | self._predict(&X) 75 | } else { 76 | self._predict(X) 77 | } 78 | } 79 | 80 | fn solve_normal_equation(X: &ArrayBase, y: &ArrayBase) -> Array1 81 | where 82 | A: Data, 83 | B: Data, 84 | { 85 | let rhs = X.t().dot(y); 86 | let linear_operator = X.t().dot(X); 87 | linear_operator.solve_into(rhs).unwrap() 88 | } 89 | 90 | fn _predict(&self, X: &ArrayBase) -> Array1 91 | where 92 | A: Data, 93 | { 94 | match &self.beta { 95 | None => panic!("The linear regression estimator has to be fitted first!"), 96 | Some(beta) => X.dot(beta), 97 | } 98 | } 99 | } 100 | -------------------------------------------------------------------------------- 
/linear_regression/src/main.rs:
--------------------------------------------------------------------------------
1 | #![allow(non_snake_case)]
2 | use ndarray::{Array, Array1, Array2, Axis};
3 | use ndarray_linalg::random;
4 | use ndarray_rand::rand_distr::StandardNormal;
5 | use ndarray_rand::RandomExt;
6 | use ndarray_stats::DeviationExt;
7 | 
8 | // Import LinearRegression from other file ("lib.rs") in this example
9 | use linear_regression::LinearRegression;
10 | 
11 | /// It returns a tuple: input data and the associated target variable.
12 | ///
13 | /// The target variable is a linear function of the input (coefficients `beta`,
14 | /// printed to stdout), perturbed by standard gaussian noise.
15 | fn get_data(n_samples: usize, n_features: usize) -> (Array2<f64>, Array1<f64>) {
16 |     let shape = (n_samples, n_features);
17 |     let noise: Array1<f64> = Array::random(n_samples, StandardNormal);
18 | 
19 |     let beta: Array1<f64> = random(n_features) * 10.;
20 |     println!("Beta used to generate target variable: {:.3}", beta);
21 | 
22 |     let X: Array2<f64> = random(shape);
23 |     let y: Array1<f64> = X.dot(&beta) + noise;
24 |     (X, y)
25 | }
26 | 
27 | pub fn main() {
28 |     let n_train_samples = 5000;
29 |     let n_test_samples = 1000;
30 |     let n_features = 3;
31 | 
32 |     let (X, y) = get_data(n_train_samples + n_test_samples, n_features);
33 |     // The first `n_train_samples` rows are used for training, the rest for testing.
34 |     let (X_train, X_test) = X.view().split_at(Axis(0), n_train_samples);
35 |     let (y_train, y_test) = y.view().split_at(Axis(0), n_train_samples);
36 | 
37 |     let mut linear_regressor = LinearRegression::new(false);
38 |     linear_regressor.fit(&X_train, &y_train);
39 | 
40 |     let test_predictions = linear_regressor.predict(&X_test);
41 |     let mean_squared_error = test_predictions.mean_sq_err(&y_test.to_owned()).unwrap();
42 |     println!(
43 |         "Beta estimated from the training data: {:.3}",
44 |         linear_regressor.beta.unwrap()
45 |     );
46 |     println!(
47 |         "The fitted regressor has a mean squared error of {:.3}",
48 |         mean_squared_error
49 |     );
50 | }
51 | 
--------------------------------------------------------------------------------
/test.sh:
-------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | set -x 4 | set -e 5 | 6 | FEATURES=$1 7 | CHANNEL=$2 8 | 9 | if [[ "$CHANNEL" != "beta" ]]; then 10 | rustup component add rustfmt 11 | cargo fmt --all -- --check 12 | rustup component add clippy 13 | fi 14 | 15 | # Loop over the directories in the project, skipping the target directory 16 | for f in *; do 17 | if [[ -d ${f} ]] && [[ ${f} != "target" ]]; then 18 | # Will not run if no directories are available 19 | echo "\n\nTesting '${f}' example.\n\n" 20 | cd ${f} 21 | cargo run --features "${FEATURES}" 22 | if [[ "$CHANNEL" != "beta" ]]; then 23 | cargo clippy -- -D warnings 24 | fi 25 | cd .. 26 | fi 27 | done 28 | 29 | --------------------------------------------------------------------------------