├── .gitignore ├── Cargo.lock ├── Cargo.toml ├── README.md ├── consumer_binary ├── Cargo.toml └── src │ └── main.rs ├── matrix ├── Cargo.toml └── src │ ├── lib.rs │ ├── macros.rs │ └── matrix.rs └── neural-network ├── Cargo.toml └── src ├── activations.rs ├── lib.rs └── network.rs /.gitignore: -------------------------------------------------------------------------------- 1 | target/ 2 | -------------------------------------------------------------------------------- /Cargo.lock: -------------------------------------------------------------------------------- 1 | # This file is automatically @generated by Cargo. 2 | # It is not intended for manual editing. 3 | version = 3 4 | 5 | [[package]] 6 | name = "cfg-if" 7 | version = "1.0.0" 8 | source = "registry+https://github.com/rust-lang/crates.io-index" 9 | checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" 10 | 11 | [[package]] 12 | name = "consumer_binary" 13 | version = "0.1.0" 14 | dependencies = [ 15 | "neural-network", 16 | ] 17 | 18 | [[package]] 19 | name = "darling" 20 | version = "0.14.4" 21 | source = "registry+https://github.com/rust-lang/crates.io-index" 22 | checksum = "7b750cb3417fd1b327431a470f388520309479ab0bf5e323505daf0290cd3850" 23 | dependencies = [ 24 | "darling_core", 25 | "darling_macro", 26 | ] 27 | 28 | [[package]] 29 | name = "darling_core" 30 | version = "0.14.4" 31 | source = "registry+https://github.com/rust-lang/crates.io-index" 32 | checksum = "109c1ca6e6b7f82cc233a97004ea8ed7ca123a9af07a8230878fcfda9b158bf0" 33 | dependencies = [ 34 | "fnv", 35 | "ident_case", 36 | "proc-macro2", 37 | "quote", 38 | "strsim", 39 | "syn", 40 | ] 41 | 42 | [[package]] 43 | name = "darling_macro" 44 | version = "0.14.4" 45 | source = "registry+https://github.com/rust-lang/crates.io-index" 46 | checksum = "a4aab4dbc9f7611d8b55048a3a16d2d010c2c8334e46304b40ac1cc14bf3b48e" 47 | dependencies = [ 48 | "darling_core", 49 | "quote", 50 | "syn", 51 | ] 52 | 53 | [[package]] 54 | name = 
"derive_builder" 55 | version = "0.12.0" 56 | source = "registry+https://github.com/rust-lang/crates.io-index" 57 | checksum = "8d67778784b508018359cbc8696edb3db78160bab2c2a28ba7f56ef6932997f8" 58 | dependencies = [ 59 | "derive_builder_macro", 60 | ] 61 | 62 | [[package]] 63 | name = "derive_builder_core" 64 | version = "0.12.0" 65 | source = "registry+https://github.com/rust-lang/crates.io-index" 66 | checksum = "c11bdc11a0c47bc7d37d582b5285da6849c96681023680b906673c5707af7b0f" 67 | dependencies = [ 68 | "darling", 69 | "proc-macro2", 70 | "quote", 71 | "syn", 72 | ] 73 | 74 | [[package]] 75 | name = "derive_builder_macro" 76 | version = "0.12.0" 77 | source = "registry+https://github.com/rust-lang/crates.io-index" 78 | checksum = "ebcda35c7a396850a55ffeac740804b40ffec779b98fffbb1738f4033f0ee79e" 79 | dependencies = [ 80 | "derive_builder_core", 81 | "syn", 82 | ] 83 | 84 | [[package]] 85 | name = "fnv" 86 | version = "1.0.7" 87 | source = "registry+https://github.com/rust-lang/crates.io-index" 88 | checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" 89 | 90 | [[package]] 91 | name = "getrandom" 92 | version = "0.2.9" 93 | source = "registry+https://github.com/rust-lang/crates.io-index" 94 | checksum = "c85e1d9ab2eadba7e5040d4e09cbd6d072b76a557ad64e797c2cb9d4da21d7e4" 95 | dependencies = [ 96 | "cfg-if", 97 | "libc", 98 | "wasi", 99 | ] 100 | 101 | [[package]] 102 | name = "ident_case" 103 | version = "1.0.1" 104 | source = "registry+https://github.com/rust-lang/crates.io-index" 105 | checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" 106 | 107 | [[package]] 108 | name = "libc" 109 | version = "0.2.144" 110 | source = "registry+https://github.com/rust-lang/crates.io-index" 111 | checksum = "2b00cc1c228a6782d0f076e7b232802e0c5689d41bb5df366f2a6b6621cfdfe1" 112 | 113 | [[package]] 114 | name = "matrix" 115 | version = "0.1.0" 116 | dependencies = [ 117 | "rand", 118 | ] 119 | 120 | [[package]] 121 | name = 
"neural-network" 122 | version = "0.1.0" 123 | dependencies = [ 124 | "derive_builder", 125 | "matrix", 126 | ] 127 | 128 | [[package]] 129 | name = "ppv-lite86" 130 | version = "0.2.17" 131 | source = "registry+https://github.com/rust-lang/crates.io-index" 132 | checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" 133 | 134 | [[package]] 135 | name = "proc-macro2" 136 | version = "1.0.59" 137 | source = "registry+https://github.com/rust-lang/crates.io-index" 138 | checksum = "6aeca18b86b413c660b781aa319e4e2648a3e6f9eadc9b47e9038e6fe9f3451b" 139 | dependencies = [ 140 | "unicode-ident", 141 | ] 142 | 143 | [[package]] 144 | name = "quote" 145 | version = "1.0.28" 146 | source = "registry+https://github.com/rust-lang/crates.io-index" 147 | checksum = "1b9ab9c7eadfd8df19006f1cf1a4aed13540ed5cbc047010ece5826e10825488" 148 | dependencies = [ 149 | "proc-macro2", 150 | ] 151 | 152 | [[package]] 153 | name = "rand" 154 | version = "0.8.5" 155 | source = "registry+https://github.com/rust-lang/crates.io-index" 156 | checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" 157 | dependencies = [ 158 | "libc", 159 | "rand_chacha", 160 | "rand_core", 161 | ] 162 | 163 | [[package]] 164 | name = "rand_chacha" 165 | version = "0.3.1" 166 | source = "registry+https://github.com/rust-lang/crates.io-index" 167 | checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" 168 | dependencies = [ 169 | "ppv-lite86", 170 | "rand_core", 171 | ] 172 | 173 | [[package]] 174 | name = "rand_core" 175 | version = "0.6.4" 176 | source = "registry+https://github.com/rust-lang/crates.io-index" 177 | checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" 178 | dependencies = [ 179 | "getrandom", 180 | ] 181 | 182 | [[package]] 183 | name = "strsim" 184 | version = "0.10.0" 185 | source = "registry+https://github.com/rust-lang/crates.io-index" 186 | checksum = 
"73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" 187 | 188 | [[package]] 189 | name = "syn" 190 | version = "1.0.109" 191 | source = "registry+https://github.com/rust-lang/crates.io-index" 192 | checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" 193 | dependencies = [ 194 | "proc-macro2", 195 | "quote", 196 | "unicode-ident", 197 | ] 198 | 199 | [[package]] 200 | name = "unicode-ident" 201 | version = "1.0.9" 202 | source = "registry+https://github.com/rust-lang/crates.io-index" 203 | checksum = "b15811caf2415fb889178633e7724bad2509101cde276048e013b9def5e51fa0" 204 | 205 | [[package]] 206 | name = "wasi" 207 | version = "0.11.0+wasi-snapshot-preview1" 208 | source = "registry+https://github.com/rust-lang/crates.io-index" 209 | checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" 210 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [workspace] 2 | 3 | members = [ 4 | "matrix", 5 | "neural-network", 6 | "consumer_binary" 7 | ] -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # neural-net-rs 2 | 3 | neural-net-rs is a Rust-based neural network framework designed for educational purposes. This project aims to provide a simple yet informative implementation of neural networks in the Rust programming language. 4 | 5 | 6 | [![Neural Net Rust](https://img.youtube.com/vi/DKbz9pNXVdE/0.jpg)](https://www.youtube.com/watch?v=DKbz9pNXVdE) 7 | ## Features 8 | 9 | - **Educational Focus:** neural-net-rs is created with the primary goal of helping users understand the fundamentals of neural networks in Rust. 10 | - **Simplicity:** The framework prioritizes simplicity to facilitate a smooth learning experience for beginners in deep learning. 
11 | - **Flexibility:** While keeping things simple, neural-net-rs is designed to be flexible, allowing users to experiment with different neural network architectures. 12 | 13 | ## Getting Started 14 | 15 | ### Prerequisites 16 | [Install Rust](https://www.rust-lang.org/learn/get-started) 17 | 18 | ### Installation 19 | 20 | ```bash 21 | git clone https://github.com/your-username/neural-net-rs.git 22 | cd neural-net-rs 23 | cargo build 24 | -------------------------------------------------------------------------------- /consumer_binary/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "consumer_binary" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 7 | 8 | [dependencies] 9 | neural-network = {path = "../neural-network"} -------------------------------------------------------------------------------- /consumer_binary/src/main.rs: --------------------------------------------------------------------------------
use neural_network::activations::SIGMOID;
use neural_network::matrix::Matrix;
use neural_network::network::Network;
use std::env;

/// Demo: trains a 2-3-1 sigmoid network on four two-bit samples for 100000
/// epochs, then prints the network's prediction for each sample.
///
/// NOTE(review): the targets equal the second input bit ([0, 1, 0, 1]), not
/// XOR ([0, 1, 1, 0]); confirm which function this demo is meant to learn.
fn main() {
    // Request a backtrace if anything panics later on.
    env::set_var("RUST_BACKTRACE", "1");

    // The four two-bit patterns, used both for training and for the final report.
    let samples: [[f64; 2]; 4] = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]];
    let inputs: Vec<Vec<f64>> = samples.iter().map(|sample| sample.to_vec()).collect();
    let targets = vec![vec![0.0], vec![1.0], vec![0.0], vec![1.0]];

    // Layer sizes 2 -> 3 -> 1, sigmoid activation, learning rate 0.5.
    let mut network = Network::new(vec![2, 3, 1], SIGMOID, 0.5);
    network.train(inputs, targets, 100000);

    for sample in samples {
        let prediction = network.feed_forward(Matrix::from(sample.to_vec()));
        println!("{:?}", prediction);
    }
}
-------------------------------------------------------------------------------- /matrix/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "matrix" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 7 | 8 | [dependencies] 9 | rand = "0.8.5" 10 | -------------------------------------------------------------------------------- /matrix/src/lib.rs: --------------------------------------------------------------------------------
pub mod macros;

pub mod matrix;
-------------------------------------------------------------------------------- /matrix/src/macros.rs: --------------------------------------------------------------------------------

use crate::matrix::Matrix;

/// Builds a `Matrix` from rows of comma-separated `f64` expressions, with rows
/// separated by `;` (a trailing `;` is accepted), e.g. `matrix![1.0, 2.0; 3.0, 4.0]`.
///
/// Each element expression is evaluated exactly once, in row-major order, and
/// the resulting `data` is stored row-major (`data[i * cols + j]`).
///
/// # Panics
///
/// Panics if the rows do not all contain the same number of elements.
#[macro_export]
macro_rules! matrix {
    ( $( $($val:expr),+ );* $(;)? ) => {
        {
            let mut data = Vec::<f64>::new();
            let mut rows: usize = 0;
            let mut cols: usize = 0;
            $(
                // Build the row once and take its length from the built row;
                // re-expanding `$val` would evaluate every element twice.
                let row_data: Vec<f64> = vec![$($val),+];
                let row_len = row_data.len();
                if rows == 0 {
                    cols = row_len;
                } else if cols != row_len {
                    panic!("Inconsistent number of elements in the matrix rows");
                }
                data.extend(row_data);
                rows += 1;
            )*

            // `$crate` path so callers do not need `Matrix` imported themselves.
            $crate::matrix::Matrix {
                rows,
                cols,
                data,
            }
        }
    };
}

#[cfg(test)]
mod tests {
    use super::Matrix;

    #[test]
    fn test_matrix_macro() {
        let m: Matrix = matrix![
            1.0, 2.0, 3.0;
            4.0, 5.0, 6.0;
            7.0, 8.0, 9.0
        ];

        assert_eq!(m.rows, 3);
        assert_eq!(m.cols, 3);
        assert_eq!(
            m.data,
            vec![
                1.0, 2.0, 3.0,
                4.0, 5.0, 6.0,
                7.0, 8.0, 9.0,
            ]
        );
    }

    #[test]
    fn test_matrix_macro_evaluates_each_element_once() {
        let mut calls = 0;
        let mut next = || {
            calls += 1;
            calls as f64
        };
        let m: Matrix = matrix![next(), next(); next(), next()];

        assert_eq!(m.data, vec![1.0, 2.0, 3.0, 4.0]);
        assert_eq!(calls, 4);
    }

    #[test]
    #[should_panic(expected = "Inconsistent number of elements in the matrix rows")]
    fn test_matrix_macro_rejects_ragged_rows() {
        let _ = matrix![1.0, 2.0; 3.0];
    }
}
-------------------------------------------------------------------------------- /matrix/src/matrix.rs: -------------------------------------------------------------------------------- 1
| use std::fmt; 2 | use std::ops::{Add, Mul}; 3 | use rand::Rng; 4 | use crate::macros::*; 5 | 6 | #[derive(Debug,Clone)] 7 | pub struct Matrix { 8 | pub rows: usize, 9 | pub cols: usize, 10 | pub data: Vec 11 | } 12 | 13 | 14 | // access through i* numofcols + j 15 | 16 | 17 | impl Matrix { 18 | 19 | pub fn elementwise_multiply(&self, other: &Matrix) -> Matrix { 20 | 21 | if self.rows != other.rows || self.cols != other.cols { 22 | panic!("Attempted to multiply by matrix of incorrect dimensions"); 23 | } 24 | 25 | let mut result_data = vec![0.0; self.cols * self.rows]; 26 | for i in 0..self.data.len() { // double check this 27 | result_data[i] = self.data[i] * other.data[i] 28 | } 29 | 30 | Matrix { 31 | rows: self.rows, 32 | cols: self.cols, 33 | data: result_data, 34 | } 35 | } 36 | pub fn random(rows: usize, cols: usize) -> Matrix { 37 | let mut buffer = Vec::::with_capacity(rows * cols); 38 | 39 | for _ in 0..rows*cols { 40 | let num = rand::thread_rng().gen_range(0.0..1.0); 41 | 42 | buffer.push(num); 43 | } 44 | 45 | Matrix{rows,cols,data:buffer} 46 | 47 | } 48 | 49 | pub fn new(rows: usize, cols: usize, data: Vec) -> Matrix { 50 | 51 | assert!(data.len()-1 != rows * cols, "Invalid Size"); 52 | Matrix { rows, cols, data } 53 | 54 | } 55 | 56 | pub fn zeros(rows:usize, cols:usize) -> Matrix { 57 | 58 | Matrix { rows, cols, data: vec![0.0; cols * rows] } 59 | } 60 | 61 | pub fn add(&self, other: &Matrix) -> Matrix { 62 | if self.rows != other.rows || self.cols != other.cols { 63 | panic!("Attempted to add matrix of incorrect dimensions"); 64 | } 65 | 66 | let mut buffer = Vec::::with_capacity(self.rows * self.cols); 67 | 68 | for i in 0..self.data.len() { 69 | 70 | let result = self.data[i] + other.data[i]; 71 | 72 | buffer.push(result); 73 | 74 | } 75 | 76 | Matrix { 77 | rows:self.rows, 78 | cols: self.cols, 79 | data: buffer 80 | } 81 | 82 | } 83 | 84 | pub fn subtract(&self, other: &Matrix) -> Matrix { 85 | 86 | assert!( 87 | self.rows == other.rows && 
self.cols == other.cols, 88 | "Cannot subtract matrices with different dimensions" 89 | ); 90 | 91 | let mut buffer = Vec::::with_capacity(self.rows * self.cols); 92 | 93 | for i in 0..self.data.len() { 94 | 95 | let result = self.data[i] - other.data[i]; 96 | 97 | buffer.push(result); 98 | 99 | } 100 | 101 | Matrix { 102 | rows:self.rows, 103 | cols: self.cols, 104 | data: buffer 105 | } 106 | 107 | } 108 | 109 | 110 | pub fn dot_multiply(&self, other: &Matrix) -> Matrix { 111 | 112 | 113 | if self.cols != other.rows { 114 | panic!("Attempted to multiply by matrix of incorrect dimensions"); 115 | } 116 | 117 | 118 | let mut result_data = vec![0.0; self.rows * other.cols]; 119 | 120 | for i in 0..self.rows { 121 | for j in 0..other.cols { 122 | let mut sum = 0.0; 123 | for k in 0..self.cols { 124 | sum += self.data[i * self.cols + k] * other.data[k * other.cols + j]; 125 | } 126 | result_data[i * other.cols + j] = sum; 127 | } 128 | } 129 | 130 | Matrix { 131 | rows: self.rows, 132 | cols: other.cols, 133 | data: result_data, 134 | } 135 | 136 | } 137 | 138 | pub fn transpose(&self) -> Matrix { 139 | let mut buffer = vec![0.0; self.cols * self.rows]; 140 | 141 | for i in 0..self.rows { 142 | for j in 0..self.cols { 143 | buffer[j * self.rows + i] = self.data[i * self.cols + j]; 144 | } 145 | } 146 | 147 | Matrix { 148 | rows: self.cols, 149 | cols: self.rows, 150 | data: buffer, 151 | } 152 | } 153 | 154 | pub fn map(&mut self, func: fn(&f64) -> f64) -> Matrix 155 | { 156 | let mut result = Matrix { 157 | rows: self.rows, 158 | cols: self.cols, 159 | data: Vec::with_capacity(self.data.len()), 160 | }; 161 | 162 | result.data.extend(self.data.iter().map(|&val| func(&val))); 163 | 164 | result 165 | } 166 | 167 | 168 | } 169 | impl From> for Matrix { 170 | fn from(vec: Vec) -> Self { 171 | let rows = vec.len(); 172 | let cols = 1; 173 | Matrix { 174 | rows, 175 | cols, 176 | data: vec, 177 | } 178 | } 179 | } 180 | 181 | impl PartialEq for Matrix { 182 | fn eq(&self, 
other: &Self) -> bool { 183 | self.rows == other.rows && self.cols == other.cols && self.data == other.data 184 | } 185 | } 186 | 187 | impl fmt::Display for Matrix { 188 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 189 | for row in 0..self.rows { 190 | for col in 0..self.cols { 191 | write!(f, "{}", self.data[row * self.cols + col])?; 192 | if col < self.cols - 1 { 193 | write!(f, "\t")?; // Separate columns with a tab 194 | } 195 | } 196 | writeln!(f)?; // Move to the next line after each row 197 | } 198 | Ok(()) 199 | } 200 | } 201 | 202 | 203 | 204 | 205 | #[cfg(test)] 206 | mod tests { 207 | use super::*; 208 | use crate::matrix; 209 | 210 | #[test] 211 | fn test_random_matrix() { 212 | let rows = 3; 213 | let cols = 4; 214 | let matrix = Matrix::random(rows, cols); 215 | 216 | assert_eq!(matrix.rows, rows); 217 | assert_eq!(matrix.cols, cols); 218 | assert_eq!(matrix.data.len(), rows * cols); 219 | 220 | for &num in &matrix.data { 221 | assert!(num >= 0.0 && num < 1.0); 222 | } 223 | } 224 | 225 | #[test] 226 | fn test_elementwise_multiply() { 227 | // Create two matrices for testing 228 | let matrix1 = Matrix::new(2, 2, vec![1.0, 2.0, 3.0, 4.0]); 229 | let matrix2 = Matrix::new(2, 2, vec![5.0, 6.0, 7.0, 8.0]); 230 | 231 | // Perform element-wise multiplication 232 | let result = matrix1.elementwise_multiply(&matrix2); 233 | 234 | // Define the expected result 235 | let expected_result = Matrix::new(2, 2, vec![5.0, 12.0, 21.0, 32.0]); 236 | 237 | // Check if the actual result matches the expected result 238 | assert_eq!(result, expected_result); 239 | } 240 | 241 | 242 | #[test] 243 | fn test_subtract_same_dimensions() { 244 | let matrix1 = matrix![ 245 | 1.0, 2.0; 246 | 3.0, 4.0 247 | ]; 248 | 249 | let matrix2 = matrix![ 250 | 5.0, 6.0; 251 | 7.0, 8.0 252 | ]; 253 | 254 | let result = matrix1.subtract(&matrix2); 255 | 256 | let expected = matrix![ 257 | -4.0, -4.0; 258 | -4.0, -4.0 259 | ]; 260 | 261 | assert_eq!(result, expected); 262 | } 
263 | 264 | #[test] 265 | fn test_dot_multiply() { 266 | let a = matrix![ 267 | 1.0, 2.0, 3.0; 268 | 4.0, 5.0, 6.0 269 | ]; 270 | let b = matrix![ 271 | 7.0, 8.0; 272 | 9.0, 10.0; 273 | 11.0, 12.0 274 | ]; 275 | 276 | let result = a.dot_multiply(&b); 277 | 278 | let expected_result = matrix![ 279 | 58.0, 64.0; 280 | 139.0, 154.0 281 | ]; 282 | 283 | assert_eq!(result, expected_result); 284 | } 285 | 286 | #[test] 287 | #[should_panic(expected = "Cannot subtract matrices with different dimensions")] 288 | fn test_subtract_different_dimensions() { 289 | let matrix1 = matrix![ 290 | 1.0, 2.0; 291 | 3.0, 4.0 292 | ]; 293 | 294 | let matrix2 = matrix![ 295 | 5.0, 6.0, 7.0; 296 | 8.0, 9.0, 10.0 297 | ]; 298 | 299 | let _ = matrix1.subtract(&matrix2); 300 | } 301 | 302 | #[test] 303 | fn test_matrix_addition() { 304 | let a = matrix![ 305 | 1.0, 2.0, 3.0; 306 | 4.0, 5.0, 6.0; 307 | 7.0, 8.0, 9.0 308 | ]; 309 | 310 | let b = matrix![ 311 | 5.0, 6.0, 7.0; 312 | 8.0, 9.0, 10.0; 313 | 11.0, 12.0, 13.0 314 | ]; 315 | 316 | let expected_result = matrix![ 317 | 6.0, 8.0, 10.0; 318 | 12.0, 14.0, 16.0; 319 | 18.0, 20.0, 22.0 320 | ]; 321 | 322 | let result = a.add(&b); 323 | 324 | assert_eq!(result, expected_result); 325 | } 326 | 327 | #[test] 328 | fn test_transpose_2x2() { 329 | let matrix = matrix![ 330 | 1.0, 2.0; 331 | 3.0, 4.0 332 | ]; 333 | let transposed = matrix.transpose(); 334 | 335 | let expected = matrix![ 336 | 1.0, 3.0; 337 | 2.0, 4.0 338 | ]; 339 | assert_eq!(transposed, expected); 340 | } 341 | 342 | #[test] 343 | fn test_transpose_3x3() { 344 | let matrix = matrix![ 345 | 1.0, 2.0, 3.0; 346 | 4.0, 5.0, 6.0; 347 | 7.0, 8.0, 9.0 348 | ]; 349 | let transposed = matrix.transpose(); 350 | 351 | let expected = matrix![ 352 | 1.0, 4.0, 7.0; 353 | 2.0, 5.0, 8.0; 354 | 3.0, 6.0, 9.0 355 | ]; 356 | assert_eq!(transposed, expected); 357 | } 358 | 359 | #[test] 360 | fn test_transpose_4x3() { 361 | let matrix = matrix![ 362 | 1.0, 2.0, 3.0; 363 | 4.0, 5.0, 6.0; 364 | 7.0, 
8.0, 9.0; 365 | 10.0, 11.0, 12.0 366 | ]; 367 | let transposed = matrix.transpose(); 368 | 369 | let expected = matrix![ 370 | 1.0, 4.0, 7.0, 10.0; 371 | 2.0, 5.0, 8.0, 11.0; 372 | 3.0, 6.0, 9.0, 12.0 373 | ]; 374 | assert_eq!(transposed, expected); 375 | } 376 | 377 | #[test] 378 | fn test_map_add_one() { 379 | let mut matrix = Matrix { 380 | rows: 2, 381 | cols: 2, 382 | data: vec![1.0, 2.0, 3.0, 4.0], 383 | }; 384 | 385 | let transformed = matrix.map(|x| x + 1.0); 386 | 387 | let expected = Matrix { 388 | rows: 2, 389 | cols: 2, 390 | data: vec![2.0, 3.0, 4.0, 5.0], 391 | }; 392 | 393 | assert_eq!(transformed, expected); 394 | } 395 | 396 | #[test] 397 | fn test_map_square() { 398 | let mut matrix = Matrix { 399 | rows: 2, 400 | cols: 2, 401 | data: vec![1.0, 2.0, 3.0, 4.0], 402 | }; 403 | 404 | let transformed = matrix.map(|x| x * x); 405 | 406 | let expected = Matrix { 407 | rows: 2, 408 | cols: 2, 409 | data: vec![1.0, 4.0, 9.0, 16.0], 410 | }; 411 | 412 | assert_eq!(transformed, expected); 413 | } 414 | } 415 | -------------------------------------------------------------------------------- /neural-network/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "neural-network" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 7 | 8 | [dependencies] 9 | derive_builder = "0.12.0" 10 | matrix = {path = "../matrix"} -------------------------------------------------------------------------------- /neural-network/src/activations.rs: -------------------------------------------------------------------------------- 1 | use std::f64::consts::E; 2 | 3 | 4 | 5 | #[derive(Clone,Copy,Debug)] 6 | pub struct Activation { 7 | pub function: fn(&f64) -> f64, 8 | pub derivative: fn(&f64) -> f64, 9 | } 10 | 11 | pub const SIGMOID: Activation = Activation { 12 | function: |x| 1.0 / (1.0 + E.powf(-x)), 13 | derivative: |x| 
x * (1.0 - x), 14 | }; 15 | -------------------------------------------------------------------------------- /neural-network/src/lib.rs: -------------------------------------------------------------------------------- 1 | #[macro_use] 2 | extern crate derive_builder; 3 | pub mod network; 4 | pub mod activations; 5 | 6 | pub mod matrix { 7 | 8 | pub use matrix::matrix::Matrix; 9 | } -------------------------------------------------------------------------------- /neural-network/src/network.rs: -------------------------------------------------------------------------------- 1 | use matrix::matrix::Matrix; 2 | 3 | use crate::activations::Activation; 4 | 5 | 6 | #[derive( Builder)] 7 | pub struct Network { 8 | layers:Vec, // amount of neurons in each layer, [72,16,10] 9 | weights: Vec, 10 | biases: Vec, 11 | data: Vec, 12 | activation: Activation, 13 | learning_rate: f64, 14 | } 15 | 16 | impl Network { 17 | 18 | pub fn new(layers: Vec,activation:Activation,learning_rate:f64 ) -> Self { 19 | 20 | let mut weights = vec![]; 21 | 22 | let mut biases = vec![]; 23 | 24 | for i in 0..layers.len() - 1 { 25 | weights.push(Matrix::random(layers[i+1], layers[i])); 26 | biases.push(Matrix::random(layers[i+1], 1)); 27 | } 28 | 29 | 30 | Network { 31 | layers, 32 | weights, 33 | biases, 34 | data: vec![], 35 | activation, 36 | learning_rate 37 | } 38 | 39 | 40 | } 41 | 42 | pub fn feed_forward(&mut self, inputs: Matrix) -> Matrix { 43 | 44 | assert!(self.layers[0] == inputs.data.len(), "Invalid Number of Inputs"); 45 | // println!("{:?} {:?}",self.weights[0],inputs); 46 | // println!("{:?}",self.weights[0].dot_multiply(&inputs).add(&self.biases[0])); 47 | 48 | let mut current = inputs; 49 | 50 | self.data = vec![current.clone()]; 51 | 52 | 53 | for i in 0..self.layers.len() -1 { 54 | current = self.weights[i] 55 | .dot_multiply(¤t) 56 | .add(&self.biases[i]).map(self.activation.function); 57 | 58 | self.data.push(current.clone()); 59 | } 60 | 61 | 62 | current 63 | 64 | } 65 | 66 | 
pub fn back_propogate(&mut self, inputs:Matrix, targets:Matrix) { 67 | 68 | let mut errors = targets.subtract(&inputs); 69 | 70 | let mut gradients = inputs.clone().map(self.activation.derivative); 71 | 72 | 73 | 74 | 75 | for i in (0..self.layers.len() -1).rev(){ 76 | 77 | gradients = gradients.elementwise_multiply(&errors).map(|x| x * 0.5); // learning rate 78 | 79 | 80 | 81 | 82 | self.weights[i] = self.weights[i].add(&gradients.dot_multiply(&self.data[i].transpose())); 83 | 84 | 85 | 86 | 87 | self.biases[i] = self.biases[i].add(&gradients); 88 | 89 | errors = self.weights[i].transpose().dot_multiply(&errors); 90 | gradients = self.data[i].map(self.activation.derivative); 91 | 92 | } 93 | } 94 | 95 | pub fn train(&mut self, inputs: Vec>, targets: Vec>, epochs: u32) { 96 | for i in 1..=epochs { 97 | if epochs < 100 || i % (epochs / 100) == 0 { 98 | println!("Epoch {} of {}", i, epochs); 99 | } 100 | for j in 0..inputs.len() { 101 | let outputs = self.feed_forward(Matrix::from(inputs[j].clone())); 102 | self.back_propogate(outputs,Matrix::from( targets[j].clone())); 103 | } 104 | } 105 | } 106 | 107 | 108 | 109 | 110 | } --------------------------------------------------------------------------------