├── .gitignore ├── src ├── matrix │ ├── mod.rs │ ├── feedforward │ │ ├── mod.rs │ │ ├── evaluator.rs │ │ └── fabricator.rs │ └── recurrent │ │ ├── mod.rs │ │ ├── evaluator.rs │ │ └── fabricator.rs ├── neat_original │ ├── mod.rs │ ├── fabricator.rs │ └── evaluator.rs ├── sparse_matrix │ ├── mod.rs │ ├── feedforward │ │ ├── mod.rs │ │ ├── evaluator.rs │ │ └── fabricator.rs │ └── recurrent │ │ ├── mod.rs │ │ ├── evaluator.rs │ │ └── fabricator.rs ├── lib.rs └── network │ ├── io.rs │ └── mod.rs ├── Cargo.toml ├── README.md └── LICENCE /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | Cargo.lock 3 | -------------------------------------------------------------------------------- /src/matrix/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod feedforward; 2 | pub mod recurrent; 3 | -------------------------------------------------------------------------------- /src/neat_original/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod evaluator; 2 | pub mod fabricator; 3 | -------------------------------------------------------------------------------- /src/sparse_matrix/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod feedforward; 2 | pub mod recurrent; 3 | -------------------------------------------------------------------------------- /src/matrix/feedforward/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod evaluator; 2 | pub mod fabricator; 3 | -------------------------------------------------------------------------------- /src/matrix/recurrent/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod evaluator; 2 | pub mod fabricator; 3 | -------------------------------------------------------------------------------- /src/sparse_matrix/feedforward/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod evaluator; 2 | pub mod fabricator; 3 | -------------------------------------------------------------------------------- /src/sparse_matrix/recurrent/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod evaluator; 2 | pub mod fabricator; 3 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "favannat" 3 | version = "0.6.4" 4 | authors = ["Silvan Buedenbender "] 5 | edition = "2018" 6 | license = "MIT" 7 | description = "Algorithms to evaluate the function encoded in ANN-like structures." 8 | homepage = "https://github.com/SilvanCodes/favannat" 9 | documentation = "https://docs.rs/favannat" 10 | repository = "https://github.com/SilvanCodes/favannat" 11 | readme = "README.md" 12 | keywords = ["ann", "evolution"] 13 | categories = ["algorithms", "science", "mathematics"] 14 | 15 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 16 | 17 | [dependencies] 18 | nalgebra = "0.32.3" 19 | nalgebra-sparse = "0.9.0" 20 | ndarray = { version = "0.15", optional = true } 21 | -------------------------------------------------------------------------------- /src/matrix/feedforward/evaluator.rs: -------------------------------------------------------------------------------- 1 | use nalgebra::DMatrix; 2 | 3 | use crate::network::{Evaluator, NetworkIO}; 4 | 5 | #[derive(Debug)] 6 | pub struct MatrixFeedforwardEvaluator { 7 | pub stages: Vec>, 8 | pub transformations: Vec, 9 | } 10 | 11 | impl Evaluator for MatrixFeedforwardEvaluator { 12 | fn evaluate(&self, state: T) -> T { 13 | let mut state = NetworkIO::input(state); 14 | // performs evaluation by sequentially matrix multiplying and transforming the state with every stage 15 | for (stage_matrix, transformations) in self.stages.iter().zip(&self.transformations) { 16 | state *= stage_matrix; 17 | for (value, activation) in state.iter_mut().zip(transformations) { 18 | *value = activation(*value); 19 | } 20 | } 21 | NetworkIO::output(state) 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # favannat (FAbricate and eVAluate Neural Networks of Arbitrary Topology) 2 | 3 | Crate is functional but still in early development. 4 | 5 | ## Introduction 6 | 7 | This crates aims to provide some semantics and data structures that allow to turn a somewhat generic description of a neural net into some executable function. 8 | 9 | Therefore it provides the "network" termes like "node" and "edge" and a roughly sketched interface to execute nets; 10 | namely the "Fabricator" trait and the "Evaluator" trait. 11 | 12 | Further it provides one implementation of those traits. 13 | 14 | 15 | ## Limitations 16 | 17 | Only DAGs (directed, acyclic graphs) can be evaluated, which is by design. It is planned to implement logic to unroll recurrent networks into DAGs. 18 | 19 | ## Contribution 20 | 21 | Any thoughts on style and correctness/usefulness are very welcome. 22 | Different implementations of the "Fabricate/Evaluate" traits are appreciated. 23 | -------------------------------------------------------------------------------- /LICENCE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Silvan Büdenbender 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | //! This crate allows to evaluate anything that implements the [`network::NetworkLike`] trait. 2 | //! 3 | //! See [`network::net`] for an examplatory implementation. 4 | //! 5 | //! Networks accept any value that implements the [`network::NetworkIO`] trait. 6 | //! 7 | //! The feature `ndarray` implements `NetworkIO` from `ndarray::Array1` when enabled. 8 | 9 | pub mod matrix; 10 | pub mod neat_original; 11 | pub mod network; 12 | pub mod sparse_matrix; 13 | 14 | pub use matrix::{ 15 | feedforward::{evaluator::MatrixFeedforwardEvaluator, fabricator::MatrixFeedforwardFabricator}, 16 | recurrent::{evaluator::MatrixRecurrentEvaluator, fabricator::MatrixRecurrentFabricator}, 17 | }; 18 | 19 | pub use sparse_matrix::{ 20 | feedforward::{ 21 | evaluator::SparseMatrixFeedforwardEvaluator, fabricator::SparseMatrixFeedforwardFabricator, 22 | }, 23 | recurrent::{ 24 | evaluator::SparseMatrixRecurrentEvaluator, fabricator::SparseMatrixRecurrentFabricator, 25 | }, 26 | }; 27 | 28 | pub use network::{Evaluator, Fabricator, StatefulEvaluator, StatefulFabricator}; 29 | 30 | type Matrix = Vec>; 31 | type Transformations = Vec f64>; 32 | -------------------------------------------------------------------------------- /src/matrix/recurrent/evaluator.rs: -------------------------------------------------------------------------------- 1 | use nalgebra::DMatrix; 2 | 3 | use crate::{ 4 | matrix::feedforward::evaluator::MatrixFeedforwardEvaluator, 5 | network::{Evaluator, NetworkIO, StatefulEvaluator}, 6 | }; 7 | 8 | #[derive(Debug)] 9 | pub struct MatrixRecurrentEvaluator { 10 | pub internal: DMatrix, 11 | pub evaluator: MatrixFeedforwardEvaluator, 12 | pub outputs: usize, 13 | } 14 | 15 | impl StatefulEvaluator for MatrixRecurrentEvaluator { 16 | fn evaluate(&mut self, input: T) -> T { 17 | let mut input = NetworkIO::input(input); 18 | input = DMatrix::from_iterator( 19 | 1, 20 | input.len() + self.internal.len(), 21 | input.iter().chain(self.internal.iter()).cloned(), 22 | ); 23 | 24 | self.internal = self.evaluator.evaluate(input); 25 | 26 | NetworkIO::output(DMatrix::from_iterator( 27 | 1, 28 | self.outputs, 29 | self.internal 30 | .view((0, 0), (1, self.outputs)) 31 | .iter() 32 | .cloned(), 33 | )) 34 | } 35 | 36 | fn reset_internal_state(&mut self) { 37 | self.internal = DMatrix::from_element(1, self.internal.len(), 0.0); 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /src/sparse_matrix/recurrent/evaluator.rs: -------------------------------------------------------------------------------- 1 | use nalgebra::DMatrix; 2 | 3 | use crate::{ 4 | network::{Evaluator, NetworkIO, StatefulEvaluator}, 5 | sparse_matrix::feedforward::evaluator::SparseMatrixFeedforwardEvaluator, 6 | }; 7 | 8 | #[derive(Debug)] 9 | pub struct SparseMatrixRecurrentEvaluator { 10 | pub internal: DMatrix, 11 | pub evaluator: SparseMatrixFeedforwardEvaluator, 12 | pub outputs: usize, 13 | } 14 | 15 | impl StatefulEvaluator for SparseMatrixRecurrentEvaluator { 16 | fn evaluate(&mut self, input: T) -> T { 17 | let mut input = NetworkIO::input(input); 18 | input = DMatrix::from_iterator( 19 | 1, 20 | input.len() + self.internal.len(), 21 | input.iter().chain(self.internal.iter()).cloned(), 22 | ); 23 | 24 | self.internal = self.evaluator.evaluate(input); 25 | 26 | NetworkIO::output(DMatrix::from_iterator( 27 | 1, 28 | self.outputs, 29 | self.internal 30 | .view((0, 0), (1, self.outputs)) 31 | .iter() 32 | .cloned(), 33 | )) 34 | } 35 | 36 | fn reset_internal_state(&mut self) { 37 | self.internal = DMatrix::from_element(1, self.internal.len(), 0.0); 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /src/network/io.rs: -------------------------------------------------------------------------------- 1 | use nalgebra::{DMatrix, DVector}; 2 | 3 | /// Data structures implementing this trait can be used as input and output of networks. 4 | pub trait NetworkIO { 5 | fn input(input: Self) -> DMatrix; 6 | fn output(output: DMatrix) -> Self; 7 | } 8 | 9 | impl NetworkIO for DMatrix { 10 | fn input(input: Self) -> DMatrix { 11 | input 12 | } 13 | fn output(output: DMatrix) -> Self { 14 | output 15 | } 16 | } 17 | 18 | impl NetworkIO for DVector { 19 | fn input(input: Self) -> DMatrix { 20 | DMatrix::from_iterator(1, input.len(), input.into_iter().cloned()) 21 | } 22 | fn output(output: DMatrix) -> Self { 23 | DVector::from(output.into_iter().cloned().collect::>()) 24 | } 25 | } 26 | 27 | impl NetworkIO for Vec { 28 | fn input(input: Self) -> DMatrix { 29 | DMatrix::from_iterator(1, input.len(), input.into_iter()) 30 | } 31 | fn output(output: DMatrix) -> Self { 32 | output.into_iter().cloned().collect::>() 33 | } 34 | } 35 | 36 | #[cfg(feature = "ndarray")] 37 | use ndarray::Array1; 38 | 39 | #[cfg(feature = "ndarray")] 40 | impl NetworkIO for Array1 { 41 | fn input(input: Self) -> DMatrix { 42 | DMatrix::from_iterator(1, input.len(), input.into_iter()) 43 | } 44 | fn output(output: DMatrix) -> Self { 45 | Array1::from_iter(output.into_iter().cloned()) 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /src/sparse_matrix/feedforward/evaluator.rs: -------------------------------------------------------------------------------- 1 | use nalgebra::DMatrix; 2 | use nalgebra_sparse::{CscMatrix, SparseEntry, SparseEntryMut}; 3 | 4 | use crate::network::{Evaluator, NetworkIO}; 5 | 6 | #[derive(Debug)] 7 | pub struct SparseMatrixFeedforwardEvaluator { 8 | pub stages: Vec>, 9 | pub transformations: Vec, 10 | } 11 | 12 | impl Evaluator for SparseMatrixFeedforwardEvaluator { 13 | fn evaluate(&self, state: T) -> T { 14 | let state = NetworkIO::input(state); 15 | let mut len = 0; 16 | let mut state: CscMatrix = (&state).into(); 17 | // performs evaluation by sequentially matrix multiplying and transforming the state with every stage 18 | for (stage_matrix, transformations) in self.stages.iter().zip(&self.transformations) { 19 | len = transformations.len(); 20 | state = state * stage_matrix; 21 | for (index, activation) in transformations.iter().enumerate() { 22 | if let SparseEntryMut::NonZero(value) = state.index_entry_mut(0, index) { 23 | *value = activation(*value); 24 | } 25 | } 26 | } 27 | NetworkIO::output(DMatrix::from_iterator( 28 | 1, 29 | len, 30 | (0..len).map(|index| { 31 | if let SparseEntry::NonZero(value) = state.index_entry(0, index) { 32 | *value 33 | } else { 34 | 0.0 35 | } 36 | }), 37 | )) 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /src/neat_original/fabricator.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | 3 | use crate::network::{EdgeLike, NodeLike, StatefulFabricator}; 4 | 5 | use super::evaluator::{DependentNode, NeatOriginalEvaluator}; 6 | 7 | #[derive(Debug)] 8 | pub struct NeatOriginalFabricator {} 9 | 10 | impl StatefulFabricator for NeatOriginalFabricator 11 | where 12 | N: NodeLike, 13 | E: EdgeLike, 14 | { 15 | type Output = super::evaluator::NeatOriginalEvaluator; 16 | 17 | fn fabricate(net: &impl crate::network::Recurrent) -> Result { 18 | let mut nodes: Vec = Vec::new(); 19 | 20 | let node_input_sum: Vec = vec![0.0; net.nodes().len()]; 21 | let node_active_output: Vec<[f64; 2]> = vec![[0.0; 2]; net.nodes().len()]; 22 | 23 | let mut id_gen = 0_usize..; 24 | let mut id_map: HashMap = HashMap::new(); 25 | 26 | for node in net.nodes() { 27 | id_map.insert(node.id(), id_gen.next().unwrap()); 28 | 29 | nodes.push(DependentNode { 30 | activation_function: node.activation(), 31 | inputs: Vec::new(), 32 | is_active: false, 33 | }); 34 | } 35 | 36 | for edge in net.edges() { 37 | nodes[*id_map.get(&edge.end()).unwrap()].inputs.push(( 38 | *id_map.get(&edge.start()).unwrap(), 39 | edge.weight(), 40 | false, 41 | )) 42 | } 43 | 44 | for edge in net.recurrent_edges() { 45 | nodes[*id_map.get(&edge.end()).unwrap()].inputs.push(( 46 | *id_map.get(&edge.start()).unwrap(), 47 | edge.weight(), 48 | true, 49 | )) 50 | } 51 | 52 | Ok(NeatOriginalEvaluator { 53 | input_ids: net 54 | .inputs() 55 | .iter() 56 | .map(|i| *id_map.get(&i.id()).unwrap()) 57 | .collect(), 58 | output_ids: net 59 | .outputs() 60 | .iter() 61 | .map(|i| *id_map.get(&i.id()).unwrap()) 62 | .collect(), 63 | nodes, 64 | node_input_sum, 65 | node_active_output, 66 | }) 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /src/matrix/recurrent/fabricator.rs: -------------------------------------------------------------------------------- 1 | use nalgebra::DMatrix; 2 | 3 | use crate::{ 4 | matrix::feedforward::fabricator::MatrixFeedforwardFabricator, 5 | network::{ 6 | net::unroll, EdgeLike, Fabricator, NetworkLike, NodeLike, Recurrent, StatefulFabricator, 7 | }, 8 | }; 9 | 10 | use super::evaluator::MatrixRecurrentEvaluator; 11 | 12 | pub struct MatrixRecurrentFabricator; 13 | 14 | impl StatefulFabricator for MatrixRecurrentFabricator 15 | where 16 | N: NodeLike, 17 | E: EdgeLike, 18 | { 19 | type Output = MatrixRecurrentEvaluator; 20 | 21 | fn fabricate(net: &impl Recurrent) -> Result { 22 | let unrolled = unroll(net); 23 | let evaluator = MatrixFeedforwardFabricator::fabricate(&unrolled)?; 24 | let memory = unrolled.outputs().len(); 25 | 26 | assert!(unrolled.inputs().len() - net.inputs().len() == memory); 27 | 28 | Ok(MatrixRecurrentEvaluator { 29 | internal: DMatrix::from_element(1, memory, 0.0), 30 | evaluator, 31 | outputs: net.outputs().len(), 32 | }) 33 | } 34 | } 35 | 36 | #[cfg(test)] 37 | mod tests { 38 | use nalgebra::dmatrix; 39 | 40 | use crate::{ 41 | edges, 42 | matrix::recurrent::fabricator::MatrixRecurrentFabricator, 43 | network::{net::Net, StatefulEvaluator, StatefulFabricator}, 44 | nodes, 45 | }; 46 | 47 | #[test] 48 | fn computes_without_recurrent_edges() { 49 | let some_net = Net::new( 50 | 1, 51 | 1, 52 | nodes!('l', 'l'), 53 | edges!( 54 | 0--1.0->1 55 | ), 56 | ); 57 | 58 | let mut evaluator = MatrixRecurrentFabricator::fabricate(&some_net).unwrap(); 59 | println!("stages {:?}", evaluator); 60 | 61 | let result = evaluator.evaluate(dmatrix![5.0]); 62 | assert_eq!(result, dmatrix![5.0]); 63 | } 64 | 65 | #[test] 66 | fn stateful_net_evaluator_0() { 67 | let mut some_net = Net::new( 68 | 2, 69 | 2, 70 | nodes!('l', 'l', 'l', 'l'), 71 | edges!( 72 | 0--1.0->2, 73 | 1--1.0->3 74 | ), 75 | ); 76 | 77 | some_net.set_recurrent_edges(edges!( 78 | 0--1.0->2, 79 | 1--1.0->3 80 | )); 81 | let mut evaluator = MatrixRecurrentFabricator::fabricate(&some_net).unwrap(); 82 | 83 | let result = evaluator.evaluate(dmatrix![5.0, 0.0]); 84 | assert_eq!(result, dmatrix![5.0, 0.0]); 85 | 86 | let result = evaluator.evaluate(dmatrix![5.0, 5.0]); 87 | assert_eq!(result, dmatrix![10.0, 5.0]); 88 | 89 | let result = evaluator.evaluate(dmatrix![0.0, 5.0]); 90 | assert_eq!(result, dmatrix![5.0, 10.0]); 91 | 92 | let result = evaluator.evaluate(dmatrix![0.0, 0.0]); 93 | assert_eq!(result, dmatrix![0.0, 5.0]); 94 | } 95 | } 96 | -------------------------------------------------------------------------------- /src/neat_original/evaluator.rs: -------------------------------------------------------------------------------- 1 | use nalgebra::DMatrix; 2 | 3 | use crate::network::{NetworkIO, StatefulEvaluator}; 4 | 5 | #[derive(Debug)] 6 | pub struct DependentNode { 7 | pub activation_function: fn(f64) -> f64, 8 | pub inputs: Vec<(usize, f64, bool)>, 9 | pub is_active: bool, 10 | } 11 | 12 | #[derive(Debug)] 13 | pub struct NeatOriginalEvaluator { 14 | pub input_ids: Vec, 15 | pub output_ids: Vec, 16 | pub nodes: Vec, 17 | pub node_input_sum: Vec, 18 | // [0] is current output, [1] it output before that 19 | pub node_active_output: Vec<[f64; 2]>, 20 | } 21 | 22 | impl NeatOriginalEvaluator { 23 | fn outputs_off(&self) -> bool { 24 | for &id in self.output_ids.iter() { 25 | if !self.nodes[id].is_active { 26 | return true; 27 | } 28 | } 29 | false 30 | } 31 | } 32 | 33 | impl StatefulEvaluator for NeatOriginalEvaluator { 34 | fn evaluate(&mut self, input: T) -> T { 35 | let input = NetworkIO::input(input); 36 | 37 | for (&id, &value) in self.input_ids.iter().zip(input.iter()) { 38 | self.node_active_output[id][0] = value; 39 | self.nodes[id].is_active = true; 40 | } 41 | 42 | let mut onetime = false; 43 | 44 | while self.outputs_off() || !onetime { 45 | for id in 0..self.nodes.len() { 46 | if !self.input_ids.contains(&id) { 47 | self.node_input_sum[id] = 0.0; 48 | self.nodes[id].is_active = false; 49 | 50 | let inputs = self.nodes[id].inputs.clone(); 51 | for &(dep_id, weight, recurrent) in inputs.iter() { 52 | if !recurrent { 53 | if self.nodes[dep_id].is_active { 54 | self.nodes[id].is_active = true; 55 | } 56 | self.node_input_sum[id] += self.node_active_output[dep_id][0] * weight; 57 | } else { 58 | self.node_input_sum[id] += self.node_active_output[dep_id][1] * weight; 59 | } 60 | } 61 | } 62 | } 63 | 64 | for id in 0..self.nodes.len() { 65 | if !self.input_ids.contains(&id) && self.nodes[id].is_active { 66 | // shift last output in time 67 | self.node_active_output[id][1] = self.node_active_output[id][0]; 68 | // compute new output when possible 69 | self.node_active_output[id][0] = 70 | (self.nodes[id].activation_function)(self.node_input_sum[id]); 71 | } 72 | } 73 | 74 | onetime = true; 75 | } 76 | 77 | NetworkIO::output(DMatrix::from_iterator( 78 | 1, 79 | self.output_ids.len(), 80 | self.output_ids 81 | .iter() 82 | .map(|&id| self.node_active_output[id][0]), // .collect::>(), 83 | )) 84 | } 85 | 86 | fn reset_internal_state(&mut self) { 87 | for value in self.node_input_sum.iter_mut() { 88 | *value = 0.0; 89 | } 90 | for value in self.node_active_output.iter_mut() { 91 | *value = [0.0; 2]; 92 | } 93 | } 94 | } 95 | -------------------------------------------------------------------------------- /src/sparse_matrix/recurrent/fabricator.rs: -------------------------------------------------------------------------------- 1 | use nalgebra::DMatrix; 2 | 3 | use crate::{ 4 | network::{ 5 | net::unroll, EdgeLike, Fabricator, NetworkLike, NodeLike, Recurrent, StatefulFabricator, 6 | }, 7 | sparse_matrix::feedforward::fabricator::SparseMatrixFeedforwardFabricator, 8 | }; 9 | 10 | pub struct SparseMatrixRecurrentFabricator; 11 | 12 | impl StatefulFabricator for SparseMatrixRecurrentFabricator 13 | where 14 | N: NodeLike, 15 | E: EdgeLike, 16 | { 17 | type Output = super::evaluator::SparseMatrixRecurrentEvaluator; 18 | 19 | fn fabricate(net: &impl Recurrent) -> Result { 20 | let unrolled = unroll(net); 21 | let evaluator = SparseMatrixFeedforwardFabricator::fabricate(&unrolled)?; 22 | let memory = unrolled.outputs().len(); 23 | 24 | assert!(unrolled.inputs().len() - net.inputs().len() == memory); 25 | 26 | Ok(super::evaluator::SparseMatrixRecurrentEvaluator { 27 | internal: DMatrix::from_element(1, memory, 0.0), 28 | evaluator, 29 | outputs: net.outputs().len(), 30 | }) 31 | } 32 | } 33 | 34 | #[cfg(test)] 35 | mod tests { 36 | use nalgebra::dmatrix; 37 | 38 | use crate::{ 39 | edges, 40 | network::{net::Net, StatefulEvaluator, StatefulFabricator}, 41 | nodes, 42 | sparse_matrix::recurrent::fabricator::SparseMatrixRecurrentFabricator, 43 | }; 44 | 45 | #[test] 46 | fn computes_without_recurrent_edges() { 47 | let some_net = Net::new( 48 | 1, 49 | 1, 50 | nodes!('l', 'l'), 51 | edges!( 52 | 0--1.0->1 53 | ), 54 | ); 55 | 56 | let mut evaluator = SparseMatrixRecurrentFabricator::fabricate(&some_net).unwrap(); 57 | // println!("stages {:?}", evaluator); 58 | 59 | let result = evaluator.evaluate(dmatrix![5.0]); 60 | assert_eq!(result, dmatrix![5.0]); 61 | } 62 | 63 | #[test] 64 | fn stateful_net_evaluator_0() { 65 | let mut some_net = Net::new( 66 | 2, 67 | 2, 68 | nodes!('l', 'l', 'l', 'l'), 69 | edges!( 70 | 0--1.0->2, 71 | 1--1.0->3 72 | ), 73 | ); 74 | 75 | some_net.set_recurrent_edges(edges!( 76 | 0--1.0->2, 77 | 1--1.0->3 78 | )); 79 | let mut evaluator = SparseMatrixRecurrentFabricator::fabricate(&some_net).unwrap(); 80 | // println!("stages {:?}", evaluator); 81 | 82 | let result = evaluator.evaluate(dmatrix![5.0, 0.0]); 83 | assert_eq!(result, dmatrix![5.0, 0.0]); 84 | 85 | let result = evaluator.evaluate(dmatrix![5.0, 5.0]); 86 | assert_eq!(result, dmatrix![10.0, 5.0]); 87 | 88 | let result = evaluator.evaluate(dmatrix![0.0, 5.0]); 89 | assert_eq!(result, dmatrix![5.0, 10.0]); 90 | 91 | let result = evaluator.evaluate(dmatrix![0.0, 0.0]); 92 | assert_eq!(result, dmatrix![0.0, 5.0]); 93 | } 94 | 95 | #[test] 96 | fn stateful_net_evaluator_with_hidden_node() { 97 | let mut some_net = Net::new( 98 | 1, 99 | 1, 100 | nodes!('l', 'l', 'l'), 101 | edges!( 102 | 0--1.0->1, 103 | 1--1.0->2 104 | ), 105 | ); 106 | 107 | some_net.set_recurrent_edges(edges!( 108 | 1--1.0->1 109 | )); 110 | let mut evaluator = SparseMatrixRecurrentFabricator::fabricate(&some_net).unwrap(); 111 | // println!("stages {:?}", evaluator); 112 | 113 | let result = evaluator.evaluate(dmatrix![1.0]); 114 | assert_eq!(result, dmatrix![1.0]); 115 | 116 | let result = evaluator.evaluate(dmatrix![1.0]); 117 | assert_eq!(result, dmatrix![2.0]); 118 | 119 | let result = evaluator.evaluate(dmatrix![1.0]); 120 | assert_eq!(result, dmatrix![3.0]); 121 | } 122 | 123 | #[test] 124 | fn stateful_net_evaluator_not_all_inputs_connected() { 125 | let mut some_net = Net::new( 126 | 2, 127 | 1, 128 | nodes!('l', 'l', 'l'), 129 | edges!( 130 | 0--1.0->2 131 | ), 132 | ); 133 | 134 | some_net.set_recurrent_edges(edges!( 135 | 0--1.0->2 136 | )); 137 | let mut evaluator = SparseMatrixRecurrentFabricator::fabricate(&some_net).unwrap(); 138 | // println!("stages {:?}", evaluator); 139 | 140 | let result = evaluator.evaluate(dmatrix![1.0, 1.0]); 141 | assert_eq!(result, dmatrix![1.0]); 142 | 143 | let result = evaluator.evaluate(dmatrix![1.0, 1.0]); 144 | assert_eq!(result, dmatrix![2.0]); 145 | 146 | let result = evaluator.evaluate(dmatrix![1.0, 1.0]); 147 | assert_eq!(result, dmatrix![2.0]); 148 | } 149 | 150 | #[test] 151 | fn stateful_net_evaluator_not_all_inputs_connected_with_hidden_node() { 152 | let mut some_net = Net::new( 153 | 2, 154 | 1, 155 | nodes!('l', 'l', 'l', 'l'), 156 | edges!( 157 | 0--1.0->2, 158 | 2--1.0->3 159 | ), 160 | ); 161 | 162 | some_net.set_recurrent_edges(edges!( 163 | 0--1.0->2, 164 | 2--1.0->2 165 | )); 166 | let mut evaluator = SparseMatrixRecurrentFabricator::fabricate(&some_net).unwrap(); 167 | // println!("stages {:?}", evaluator); 168 | 169 | let result = evaluator.evaluate(dmatrix![1.0, 1.0]); 170 | assert_eq!(result, dmatrix![1.0]); 171 | 172 | let result = evaluator.evaluate(dmatrix![1.0, 1.0]); 173 | assert_eq!(result, dmatrix![3.0]); 174 | 175 | let result = evaluator.evaluate(dmatrix![1.0, 1.0]); 176 | assert_eq!(result, dmatrix![5.0]); 177 | } 178 | 179 | #[test] 180 | fn stateful_net_evaluator_two_hidden_nodes() { 181 | let mut some_net = Net::new( 182 | 1, 183 | 1, 184 | nodes!('l', 'l', 'l', 'l'), 185 | edges!( 186 | 0--1.0->1, 187 | 0--1.0->2, 188 | 1--1.0->3, 189 | 2--1.0->3 190 | ), 191 | ); 192 | 193 | some_net.set_recurrent_edges(edges!( 194 | 3--1.0->1, 195 | 3--1.0->2 196 | )); 197 | let mut evaluator = SparseMatrixRecurrentFabricator::fabricate(&some_net).unwrap(); 198 | // println!("stages {:?}", evaluator); 199 | 200 | let result = evaluator.evaluate(dmatrix![1.0]); 201 | assert_eq!(result, dmatrix![2.0]); 202 | 203 | let result = evaluator.evaluate(dmatrix![1.0]); 204 | assert_eq!(result, dmatrix![6.0]); 205 | 206 | let result = evaluator.evaluate(dmatrix![1.0]); 207 | assert_eq!(result, dmatrix![14.0]); 208 | } 209 | 210 | #[test] 211 | fn stateful_net_evaluator_self_recurrence() { 212 | let mut some_net = Net::new( 213 | 1, 214 | 1, 215 | nodes!('l', 'l'), 216 | edges!( 217 | 0--1.0->1 218 | ), 219 | ); 220 | 221 | some_net.set_recurrent_edges(edges!( 222 | 1--1.0->1 223 | )); 224 | 225 | let mut evaluator = SparseMatrixRecurrentFabricator::fabricate(&some_net).unwrap(); 226 | // println!("stages {:?}", evaluator); 227 | 228 | let result = evaluator.evaluate(dmatrix![1.0]); 229 | assert_eq!(result, dmatrix![1.0]); 230 | 231 | let result = evaluator.evaluate(dmatrix![1.0]); 232 | assert_eq!(result, dmatrix![2.0]); 233 | 234 | let result = evaluator.evaluate(dmatrix![0.0]); 235 | assert_eq!(result, dmatrix![2.0]); 236 | 237 | let result = evaluator.evaluate(dmatrix![0.0]); 238 | assert_eq!(result, dmatrix![2.0]); 239 | 240 | let result = evaluator.evaluate(dmatrix![1.0]); 241 | assert_eq!(result, dmatrix![3.0]); 242 | } 243 | } 244 | -------------------------------------------------------------------------------- /src/matrix/feedforward/fabricator.rs: -------------------------------------------------------------------------------- 1 | use crate::network::{EdgeLike, Fabricator, NetworkLike, NodeLike}; 2 | use nalgebra::{DMatrix, DVector}; 3 | use std::collections::HashMap; 4 | 5 | pub struct MatrixFeedforwardFabricator; 6 | 7 | impl MatrixFeedforwardFabricator { 8 | fn get_matrix(dynamic_matrix: Vec>) -> DMatrix { 9 | let columns = dynamic_matrix 10 | .into_iter() 11 | .map(DVector::from_vec) 12 | .collect::>(); 13 | 14 | DMatrix::from_columns(&columns) 15 | } 16 | } 17 | 18 | impl Fabricator for MatrixFeedforwardFabricator 19 | where 20 | N: NodeLike, 21 | E: EdgeLike, 22 | { 23 | type Output = super::evaluator::MatrixFeedforwardEvaluator; 24 | 25 | fn fabricate(net: &impl NetworkLike) -> Result { 26 | // build dependency graph by collecting incoming edges per node 27 | let mut dependency_graph: HashMap> = HashMap::new(); 28 | 29 | for edge in net.edges() { 30 | dependency_graph 31 | .entry(edge.end()) 32 | .and_modify(|dependencies| dependencies.push(edge)) 33 | .or_insert_with(|| vec![edge]); 34 | } 35 | 36 | if dependency_graph.is_empty() { 37 | return Err("no edges present, net invalid"); 38 | } 39 | 40 | // keep track of dependencies present 41 | let mut dependency_count = dependency_graph.len(); 42 | 43 | // println!("initial dependency_graph {:#?}", dependency_graph); 44 | 45 | // contains list of matrices (stages) that form the computable net 46 | let mut compute_stages: Vec = Vec::new(); 47 | // contains activation functions corresponding to each stage 48 | let mut stage_transformations: Vec = Vec::new(); 49 | // set available nodes a.k.a net input 50 | let mut available_nodes = net.inputs(); 51 | // sort via Ord implementation of provided nodes to guarantee each input will be processed by the same node every time 52 | available_nodes.sort_unstable(); 53 | // reduce nodes to ids 54 | let mut available_nodes: Vec = available_nodes.iter().map(|n| n.id()).collect(); 55 | 56 | // println!("available_nodes {:?}", available_nodes); 57 | 58 | // set wanted nodes a.k.a net output 59 | let mut wanted_nodes = net.outputs(); 60 | // sort via Ord implementation of provided nodes to guarantee each output will appear in the same order every time 61 | wanted_nodes.sort_unstable(); 62 | // reduce nodes to ids 63 | let wanted_nodes: Vec = wanted_nodes.iter().map(|n| n.id()).collect(); 64 | 65 | // println!("wanted_nodes {:?}", wanted_nodes); 66 | 67 | // gather compute stages by finding computable nodes and required carries until all dependencies are resolved 68 | while !dependency_graph.is_empty() { 69 | // setup new compute stage 70 | let mut stage_matrix: crate::Matrix = Vec::new(); 71 | // setup new transformations 72 | let mut transformations: crate::Transformations = Vec::new(); 73 | // list of nodes becoming available by compute stage 74 | let mut next_available_nodes: Vec = Vec::new(); 75 | 76 | for (&dependent_node, dependencies) in dependency_graph.iter() { 77 | // marker if all dependencies are available 78 | let mut computable = true; 79 | // eventual compute vector 80 | let mut compute_or_carry = vec![f64::NAN; available_nodes.len()]; 81 | // check every dependency 82 | for &dependency in dependencies { 83 | let mut found = false; 84 | for (index, &id) in available_nodes.iter().enumerate() { 85 | if dependency.start() == id { 86 | // add weight to compute vector at position of input 87 | compute_or_carry[index] = dependency.weight(); 88 | found = true; 89 | } 90 | } 91 | // if any dependency is not found the node is not computable yet 92 | if !found { 93 | computable = false; 94 | } 95 | } 96 | if computable { 97 | // replace NAN with 0.0 98 | for n in &mut compute_or_carry { 99 | if n.is_nan() { 100 | *n = 0.0 101 | } 102 | } 103 | // add vec to compute stage 104 | stage_matrix.push(compute_or_carry); 105 | // add activation function to stage transformations 106 | transformations.push( 107 | net.nodes() 108 | .iter() 109 | .find(|&node| node.id() == dependent_node) 110 | .unwrap() 111 | .activation(), 112 | ); 113 | // mark node as available in next iteration 114 | next_available_nodes.push(dependent_node); 115 | } else { 116 | // figure out carries 117 | for (index, &weight) in compute_or_carry.iter().enumerate() { 118 | // if there is some partial dependency that is not carried yet 119 | if !next_available_nodes.contains(&available_nodes[index]) 120 | && !weight.is_nan() 121 | { 122 | let mut carry = vec![0.0; available_nodes.len()]; 123 | carry[index] = 1.0; 124 | // add carry vector 125 | stage_matrix.push(carry); 126 | // add identity function for carried vector 127 | transformations.push(|val| val); 128 | // add node as available 129 | next_available_nodes.push(available_nodes[index]); 130 | } 131 | } 132 | } 133 | } 134 | 135 | // keep any wanted notes if available (output) 136 | for wanted_node in wanted_nodes.iter() { 137 | for (index, available_node) in available_nodes.iter().enumerate() { 138 | if available_node == wanted_node { 139 | // carry only if not carried already 140 | if !next_available_nodes.contains(available_node) { 141 | let mut carry = vec![0.0; available_nodes.len()]; 142 | carry[index] = 1.0; 143 | // add carry vector 144 | stage_matrix.push(carry); 145 | // add identity function for carried vector 146 | transformations.push(|val| val); 147 | // add node as available 148 | next_available_nodes.push(*available_node); 149 | } 150 | } 151 | } 152 | } 153 | 154 | // remove resolved dependencies from dependency graph 155 | for node in next_available_nodes.iter() { 156 | dependency_graph.remove(node); 157 | } 158 | 159 | // if no dependency was removed no progess was made 160 | if dependency_graph.len() == dependency_count { 161 | return Err("can't resolve dependencies, net invalid"); 162 | } else { 163 | dependency_count = dependency_graph.len(); 164 | } 165 | 166 | // println!("next_available_nodes {:?}", next_available_nodes); 167 | 168 | // reorder last stage according to net output order (invalidates next_available_nodes order which wont be used after this point) 169 | if dependency_graph.is_empty() { 170 | // println!("stage_matrix {:?}", stage_matrix); 171 | 172 | let mut reordered_matrix = stage_matrix.clone(); 173 | let mut reordered_transformations = transformations.clone(); 174 | 175 | let mut matched_wanted_count = 0; 176 | 177 | for ((available_node, column), transformation) in next_available_nodes 178 | .iter() 179 | .zip(stage_matrix.into_iter()) 180 | .zip(transformations.into_iter()) 181 | { 182 | for (index, wanted_node) in wanted_nodes.iter().enumerate() { 183 | if available_node == wanted_node { 184 | reordered_matrix[index] = column; 185 | reordered_transformations[index] = transformation; 186 | matched_wanted_count += 1; 187 | break; 188 | } 189 | } 190 | } 191 | 192 | if matched_wanted_count < wanted_nodes.len() { 193 | return Err( 194 | "dependencies resolved but not all outputs computable, net invalid", 195 | ); 196 | } 197 | 198 | // println!("reordered_matrix {:?}", reordered_matrix); 199 | 200 | stage_matrix = reordered_matrix; 201 | transformations = reordered_transformations; 202 | } 203 | 204 | // add resolved dependencies and transformations to compute stages 205 | compute_stages.push(stage_matrix); 206 | stage_transformations.push(transformations); 207 | 208 | // set available nodes for next iteration 209 | available_nodes = next_available_nodes; 210 | } 211 | 212 | Ok(super::evaluator::MatrixFeedforwardEvaluator { 213 | stages: compute_stages 214 | .into_iter() 215 | .map(MatrixFeedforwardFabricator::get_matrix) 216 | .collect(), 217 | transformations: stage_transformations, 218 | }) 219 | } 220 | } 221 | 222 | #[cfg(test)] 223 | mod tests { 224 | use nalgebra::dmatrix; 225 | 226 | use super::MatrixFeedforwardFabricator; 227 | use crate::{ 228 | edges, 229 | network::{net::Net, Evaluator, Fabricator}, 230 | nodes, 231 | }; 232 | 233 | #[test] 234 | fn reports_error_on_empty_edges() { 235 | let net = Net::new(1, 1, nodes!('l', 'l'), Vec::new()); 236 | 237 | assert_eq!( 238 | MatrixFeedforwardFabricator::fabricate(&net).err(), 239 | Some("no edges present, net invalid") 240 | ); 241 | } 242 | 243 | #[test] 244 | fn reports_error_on_missing_edges_to_output() { 245 | let net = Net::new(1, 2, nodes!('l', 'l', 'l'), edges!(0--1.0->1)); 246 | 247 | assert_eq!( 248 | MatrixFeedforwardFabricator::fabricate(&net).err(), 249 | Some("dependencies resolved but not all outputs computable, net invalid") 250 | ); 251 | } 252 | 253 | // test uncomputable output 254 | #[test] 255 | fn reports_error_on_unresolvable_dependency() { 256 | let net = Net::new(1, 1, nodes!('l', 'l', 'l'), edges!(1--0.5->2)); 257 | 258 | assert_eq!( 259 | MatrixFeedforwardFabricator::fabricate(&net).err(), 260 | Some("can't resolve dependencies, net invalid") 261 | ); 262 | } 263 | 264 | // tests construction and evaluation of simplest network 265 | #[test] 266 | fn simple_net_evaluator_0() { 267 | let some_net = Net::new(1, 1, nodes!('l', 'l'), edges!(0--0.5->1)); 268 | 269 | let evaluator = MatrixFeedforwardFabricator::fabricate(&some_net).unwrap(); 270 | // println!("stages {:?}", evaluator.stages); 271 | 272 | let result = evaluator.evaluate(dmatrix![5.0]); 273 | // println!("result {:?}", result); 274 | 275 | assert_eq!(result, dmatrix![2.5]); 276 | } 277 | 278 | // tests input dimension > 1 279 | #[test] 280 | fn simple_net_evaluator_1() { 281 | let some_net = Net::new( 282 | 2, 283 | 1, 284 | nodes!('l', 'l', 'l'), 285 | edges!( 286 | 0--0.5->2, 287 | 1--0.5->2 288 | ), 289 | ); 290 | 291 | let evaluator = MatrixFeedforwardFabricator::fabricate(&some_net).unwrap(); 292 | // println!("stages {:?}", evaluator.stages); 293 | 294 | let result = evaluator.evaluate(dmatrix![5.0, 5.0]); 295 | // println!("result {:?}", result); 296 | 297 | assert_eq!(result, dmatrix![5.0]); 298 | } 299 | 300 | // test linear chaining of edges 301 | #[test] 302 | fn simple_net_evaluator_2() { 303 | let some_net = Net::new( 304 | 1, 305 | 1, 306 | nodes!('l', 'l', 'l'), 307 | edges!( 308 | 0--0.5->1, 309 | 1--0.5->2 310 | ), 311 | ); 312 | 313 | let evaluator = MatrixFeedforwardFabricator::fabricate(&some_net).unwrap(); 314 | // println!("stages {:?}", evaluator.stages); 315 | 316 | let result = evaluator.evaluate(dmatrix![5.0]); 317 | // println!("result {:?}", result); 318 | 319 | assert_eq!(result, dmatrix![1.25]); 320 | } 321 | 322 | // test construction of carry for later needs 323 | #[test] 324 | fn simple_net_evaluator_3() { 325 | let some_net = Net::new( 326 | 1, 327 | 1, 328 | nodes!('l', 'l', 'l'), 329 | edges!( 330 | 0--0.5->1, 331 | 1--0.5->2, 332 | 0--0.5->2 333 | ), 334 | ); 335 | 336 | let evaluator = MatrixFeedforwardFabricator::fabricate(&some_net).unwrap(); 337 | // println!("stages {:?}", evaluator.stages); 338 | 339 | let result = evaluator.evaluate(dmatrix![5.0]); 340 | // println!("result {:?}", result); 341 | 342 | assert_eq!(result, dmatrix![3.75]); 343 | } 344 | 345 | // test construction of carry for early result with dedup carry 346 | #[test] 347 | fn simple_net_evaluator_4() { 348 | let some_net = Net::new( 349 | 1, 350 | 2, 351 | nodes!('l', 'l', 'l', 'l'), 352 | edges!( 353 | 0--0.5->1, 354 | 1--0.5->2, 355 | 0--0.5->3, 356 | 0--0.5->2 357 | ), 358 | ); 359 | 360 | let evaluator = MatrixFeedforwardFabricator::fabricate(&some_net).unwrap(); 361 | 362 | let result = evaluator.evaluate(dmatrix![5.0]); 363 | 364 | assert_eq!(result, dmatrix![3.75, 2.5]); 365 | } 366 | 367 | // test construction of carry for early result flipped order 368 | #[test] 369 | fn simple_net_evaluator_5() { 370 | let some_net = Net::new( 371 | 1, 372 | 2, 373 | nodes!('l', 'l', 'l', 'l'), 374 | edges!( 375 | 0--0.5->1, 376 | 1--0.5->3, 377 | 0--0.5->2 378 | ), 379 | ); 380 | 381 | let evaluator = MatrixFeedforwardFabricator::fabricate(&some_net).unwrap(); 382 | // println!("stages {:?}", evaluator.stages); 383 | 384 | let result = evaluator.evaluate(dmatrix![5.0]); 385 | // println!("result {:?}", result); 386 | 387 | assert_eq!(result, dmatrix![2.5, 1.25]); 388 | } 389 | 390 | #[test] 391 | fn simple_net_evaluator_9() { 392 | let some_net = Net::new( 393 | 2, 394 | 1, 395 | nodes!('l', 'l', 'l'), 396 | edges!( 397 | 0--0.5->2, 398 | 1--0.0->2 399 | ), 400 | ); 401 | 402 | let evaluator = MatrixFeedforwardFabricator::fabricate(&some_net).unwrap(); 403 | // println!("stages {:?}", evaluator.stages); 404 | 405 | let result = evaluator.evaluate(dmatrix![5.0, 5.0]); 406 | // println!("result {:?}", result); 407 | 408 | assert_eq!(result, dmatrix![2.5]); 409 | } 410 | } 411 | -------------------------------------------------------------------------------- /src/network/mod.rs: -------------------------------------------------------------------------------- 1 | //! Defines vocabulary and interfaces for this crate. 2 | 3 | pub use self::io::NetworkIO; 4 | 5 | mod io; 6 | 7 | /// Declares a structure to have [`NodeLike`] properties. 8 | /// 9 | /// [`NodeLike`] provides the plumbing to accept user-defined structures and use them as nodes in this crates context. 10 | /// The implementation of [`NodeLike::id`] needs to provide a unique identifier per node. 11 | pub trait NodeLike: Ord { 12 | fn id(&self) -> usize; 13 | fn activation(&self) -> fn(f64) -> f64; 14 | } 15 | 16 | /// Declares a structure to have [`EdgeLike`] properties. 17 | /// 18 | /// [`EdgeLike`] provides the plumbing to accept user-defined structures and use them as edges in this crates context. 19 | pub trait EdgeLike { 20 | fn start(&self) -> usize; 21 | fn end(&self) -> usize; 22 | fn weight(&self) -> f64; 23 | } 24 | 25 | /// Declares a structure to have network-like properties. 26 | /// 27 | /// `NetworkLike` sits at the core of this crate. 28 | /// Together with [`NodeLike`] and [`EdgeLike`] it provides the interface to start using this crate. 29 | /// Structures that are `NetworkLike` can be fabricated and evaluated by the different implementations of the 30 | /// [`Fabricator`], [`Evaluator`], [`StatefulFabricator`] and [`StatefulEvaluator`] traits. 31 | pub trait NetworkLike { 32 | fn edges(&self) -> Vec<&E>; 33 | fn inputs(&self) -> Vec<&N>; 34 | fn hidden(&self) -> Vec<&N>; 35 | fn outputs(&self) -> Vec<&N>; 36 | 37 | fn nodes(&self) -> Vec<&N> { 38 | self.inputs() 39 | .into_iter() 40 | .chain(self.hidden().into_iter()) 41 | .chain(self.outputs().into_iter()) 42 | .collect() 43 | } 44 | } 45 | 46 | /// Declares a [`NetworkLike`] structure to have recurrent edges. 47 | /// 48 | /// Recurrent edges act like memory cells in a network. 49 | /// They imply that internal state has to be preserved. 50 | pub trait Recurrent: NetworkLike { 51 | fn recurrent_edges(&self) -> Vec<&E>; 52 | } 53 | 54 | /// A facade behind which evaluation of a fabricated [`NetworkLike`] structure is implemented. 55 | pub trait Evaluator { 56 | fn evaluate(&self, input: T) -> T; 57 | } 58 | 59 | /// A facade behind which evaluation of a fabricated [`Recurrent`] [`NetworkLike`] structure is implemented. 60 | /// 61 | /// Due to its statefulness it needs mutable access and provides a way to reset the internal state. 62 | pub trait StatefulEvaluator { 63 | fn evaluate(&mut self, input: T) -> T; 64 | fn reset_internal_state(&mut self); 65 | } 66 | 67 | /// A facade behind which the fabrication of a [`NetworkLike`] structure is implemented. 68 | /// 69 | /// Fabrication means transforming a description of a network, the [`NetworkLike`] structure, into an executable form of its encoded function, an [`Evaluator`]. 70 | pub trait Fabricator { 71 | type Output: Evaluator; 72 | 73 | fn fabricate(net: &impl NetworkLike) -> Result; 74 | } 75 | 76 | /// A facade behind which the fabrication of a [`Recurrent`] [`NetworkLike`] structure is implemented. 77 | /// 78 | /// Fabrication means transforming a description of a network, the [`Recurrent`] [`NetworkLike`] structure, into an executable form of its encoded function, a [`StatefulEvaluator`]. 79 | pub trait StatefulFabricator { 80 | type Output: StatefulEvaluator; 81 | 82 | fn fabricate(net: &impl Recurrent) -> Result; 83 | } 84 | 85 | /// Contains an example of a [`Recurrent`] [`NetworkLike`] structure. 86 | pub mod net { 87 | use std::collections::HashMap; 88 | 89 | use super::{EdgeLike, NetworkLike, NodeLike, Recurrent}; 90 | 91 | #[derive(Debug)] 92 | pub struct Node { 93 | id: usize, 94 | activation: fn(f64) -> f64, 95 | } 96 | 97 | impl Node { 98 | pub fn new(id: usize, activation: fn(f64) -> f64) -> Self { 99 | Self { id, activation } 100 | } 101 | } 102 | 103 | impl NodeLike for Node { 104 | fn id(&self) -> usize { 105 | self.id 106 | } 107 | fn activation(&self) -> fn(f64) -> f64 { 108 | self.activation 109 | } 110 | } 111 | 112 | impl PartialEq for Node { 113 | fn eq(&self, other: &Self) -> bool { 114 | self.id() == other.id() 115 | } 116 | } 117 | 118 | impl Eq for Node {} 119 | 120 | impl PartialOrd for Node { 121 | fn partial_cmp(&self, other: &Self) -> Option { 122 | Some(self.cmp(other)) 123 | } 124 | } 125 | 126 | impl Ord for Node { 127 | fn cmp(&self, other: &Self) -> std::cmp::Ordering { 128 | self.id().cmp(&other.id()) 129 | } 130 | } 131 | 132 | #[derive(Debug)] 133 | pub struct Edge { 134 | start: usize, 135 | end: usize, 136 | weight: f64, 137 | } 138 | 139 | impl Edge { 140 | pub fn new(start: usize, end: usize, weight: f64) -> Self { 141 | Self { start, end, weight } 142 | } 143 | } 144 | 145 | impl EdgeLike for Edge { 146 | fn start(&self) -> usize { 147 | self.start 148 | } 149 | fn end(&self) -> usize { 150 | self.end 151 | } 152 | fn weight(&self) -> f64 { 153 | self.weight 154 | } 155 | } 156 | 157 | /// [`Net`] is an example of a [`Recurrent`] [`NetworkLike`] structure and also used as an intermediate representation to perform the [`unroll`] operation on [`Recurrent`] [`NetworkLike`] structures. 158 | #[derive(Debug)] 159 | pub struct Net { 160 | inputs: usize, 161 | outputs: usize, 162 | nodes: Vec, 163 | edges: Vec, 164 | recurrent_edges: Vec, 165 | } 166 | 167 | impl NetworkLike for Net { 168 | fn edges(&self) -> Vec<&Edge> { 169 | self.edges.iter().collect() 170 | } 171 | fn inputs(&self) -> Vec<&Node> { 172 | self.nodes.iter().take(self.inputs).collect() 173 | } 174 | fn hidden(&self) -> Vec<&Node> { 175 | self.nodes 176 | .iter() 177 | .skip(self.inputs) 178 | .take(self.nodes.len() - self.inputs - self.outputs) 179 | .collect() 180 | } 181 | 182 | fn outputs(&self) -> Vec<&Node> { 183 | self.nodes 184 | .iter() 185 | .skip(self.nodes().len() - self.outputs) 186 | .collect() 187 | } 188 | 189 | fn nodes(&self) -> Vec<&Node> { 190 | self.nodes.iter().collect() 191 | } 192 | } 193 | 194 | impl Recurrent for Net { 195 | fn recurrent_edges(&self) -> Vec<&Edge> { 196 | self.recurrent_edges.iter().collect() 197 | } 198 | } 199 | 200 | impl Net { 201 | pub fn new(inputs: usize, outputs: usize, nodes: Vec, edges: Vec) -> Self { 202 | Net { 203 | inputs, 204 | outputs, 205 | nodes, 206 | edges, 207 | recurrent_edges: Vec::new(), 208 | } 209 | } 210 | pub fn set_recurrent_edges(&mut self, edges: Vec) { 211 | self.recurrent_edges = edges 212 | } 213 | } 214 | 215 | /// unroll is an essential operation in order to evaluate [`Recurrent`] [`NetworkLike`] structures. 216 | /// 217 | /// It restructures the edges and nodes to be evaluatable in a feedforward manner. 218 | /// The evaluation further depends on the implementations in [`crate::matrix::recurrent::evaluator`] and [`crate::sparse_matrix::recurrent::evaluator`] which handle the internal state. 219 | pub fn unroll, N: NodeLike, E: EdgeLike>(recurrent: &R) -> Net { 220 | // remember known ids as they can not be reused as otherwise 221 | // during rewriting edge inputs/outputs stuff would be confused 222 | let known_ids = recurrent 223 | .nodes() 224 | .iter() 225 | .map(|node| node.id()) 226 | .collect::>(); 227 | 228 | let mut known_edges = recurrent 229 | .edges() 230 | .iter() 231 | .map(|e| Edge { 232 | start: e.start(), 233 | end: e.end(), 234 | weight: e.weight(), 235 | }) 236 | .collect::>(); 237 | 238 | let mut known_recurrent_edges = recurrent 239 | .recurrent_edges() 240 | .iter() 241 | .map(|e| Edge { 242 | start: e.start(), 243 | end: e.end(), 244 | weight: e.weight(), 245 | }) 246 | .collect::>(); 247 | 248 | let mut new_low_ids = (usize::MIN..usize::MAX).filter(|tmp_id| !known_ids.contains(tmp_id)); 249 | 250 | // give static input nodes the lowest possible ids to not fuck up output order by sorting in feedforward fabricator 251 | let mut known_inputs = recurrent.inputs(); 252 | known_inputs.sort_unstable(); 253 | let mut known_inputs = known_inputs 254 | .iter() 255 | .map(|n| { 256 | let new_id = new_low_ids.next().unwrap(); 257 | 258 | // patch all edges to new id 259 | for edge in &mut known_edges { 260 | if edge.start == n.id() { 261 | edge.start = new_id 262 | } 263 | if edge.end == n.id() { 264 | edge.end = new_id 265 | } 266 | } 267 | 268 | // patch all recurrent edges to new id 269 | for edge in &mut known_recurrent_edges { 270 | if edge.start == n.id() { 271 | edge.start = new_id 272 | } 273 | if edge.end == n.id() { 274 | edge.end = new_id 275 | } 276 | } 277 | 278 | Node { 279 | id: new_id, 280 | activation: n.activation(), 281 | } 282 | }) 283 | .collect::>(); 284 | 285 | // give static output nodes the lowest possible ids to not fuck up output order by sorting in feedforward fabricator 286 | let mut known_outputs = recurrent.outputs(); 287 | known_outputs.sort_unstable(); 288 | let mut known_outputs = known_outputs 289 | .iter() 290 | .map(|n| { 291 | let new_id = new_low_ids.next().unwrap(); 292 | 293 | // patch all edges to new id 294 | for edge in &mut known_edges { 295 | if edge.start == n.id() { 296 | edge.start = new_id 297 | } 298 | if edge.end == n.id() { 299 | edge.end = new_id 300 | } 301 | } 302 | 303 | // patch all recurrent edges to new id 304 | for edge in &mut known_recurrent_edges { 305 | if edge.start == n.id() { 306 | edge.start = new_id 307 | } 308 | if edge.end == n.id() { 309 | edge.end = new_id 310 | } 311 | } 312 | 313 | Node { 314 | id: new_id, 315 | activation: n.activation(), 316 | } 317 | }) 318 | .collect::>(); 319 | 320 | let mut unroll_map: HashMap = HashMap::new(); 321 | 322 | // create wrapping input for all original outputs, regardless of if they are used 323 | // this is to simplify the state transfer inside the stateful matrix evaluator 324 | for output in &known_outputs { 325 | let wrapper_input_id = new_low_ids.next().unwrap(); 326 | 327 | let wrapper_input_node = Node { 328 | id: wrapper_input_id, 329 | activation: |val| val, 330 | }; 331 | 332 | known_inputs.push(wrapper_input_node); 333 | 334 | unroll_map.insert(output.id(), wrapper_input_id); 335 | } 336 | 337 | // create all wrapping nodes and egdes for recurrent connections with patched ids 338 | for recurrent_edge in known_recurrent_edges { 339 | let recurrent_input = unroll_map.entry(recurrent_edge.start()).or_insert_with(|| { 340 | let wrapper_input_id = new_low_ids.next().unwrap(); 341 | 342 | let wrapper_input_node = Node { 343 | id: wrapper_input_id, 344 | activation: |val| val, 345 | }; 346 | let wrapper_output_node = Node { 347 | id: new_low_ids.next().unwrap(), 348 | activation: |val| val, 349 | }; 350 | 351 | // used to carry value into next evaluation 352 | let outward_wrapping_edge = Edge { 353 | start: recurrent_edge.start(), 354 | weight: 1.0, 355 | end: wrapper_output_node.id(), 356 | }; 357 | 358 | // add nodes for wrapping 359 | known_inputs.push(wrapper_input_node); 360 | known_outputs.push(wrapper_output_node); 361 | 362 | // add outward wrapping connection 363 | known_edges.push(outward_wrapping_edge); 364 | 365 | wrapper_input_id 366 | }); 367 | 368 | let inward_wrapping_connection = Edge { 369 | start: *recurrent_input, 370 | end: recurrent_edge.end(), 371 | weight: recurrent_edge.weight(), 372 | }; 373 | 374 | known_edges.push(inward_wrapping_connection); 375 | } 376 | 377 | let inputs_count = known_inputs.len(); 378 | let outputs_count = known_outputs.len(); 379 | let nodes = known_inputs 380 | .into_iter() 381 | .chain(recurrent.hidden().iter().map(|n| Node { 382 | id: n.id(), 383 | activation: n.activation(), 384 | })) 385 | .chain(known_outputs.into_iter()) 386 | .collect::>(); 387 | let edges = known_edges; 388 | 389 | Net::new(inputs_count, outputs_count, nodes, edges) 390 | } 391 | 392 | pub mod activations { 393 | pub const LINEAR: fn(f64) -> f64 = |val| val; 394 | // pub const SIGMOID: fn(f64) -> f64 = |val| 1.0 / (1.0 + (-1.0 * val).exp()); 395 | pub const SIGMOID: fn(f64) -> f64 = |val| 1.0 / (1.0 + (-4.9 * val).exp()); 396 | pub const TANH: fn(f64) -> f64 = |val| 2.0 * SIGMOID(2.0 * val) - 1.0; 397 | // a = 1, b = 0, c = 1 398 | pub const GAUSSIAN: fn(f64) -> f64 = |val| (val * val / -2.0).exp(); 399 | // pub const STEP: fn(f64) -> f64 = |val| if val > 0.0 { 1.0 } else { 0.0 }; 400 | // pub const SINE: fn(f64) -> f64 = |val| (val * std::f64::consts::PI).sin(); 401 | // pub const COSINE: fn(f64) -> f64 = |val| (val * std::f64::consts::PI).cos(); 402 | pub const INVERSE: fn(f64) -> f64 = |val| -val; 403 | // pub const ABSOLUTE: fn(f64) -> f64 = |val| val.abs(); 404 | pub const RELU: fn(f64) -> f64 = |val| 0f64.max(val); 405 | pub const SQUARED: fn(f64) -> f64 = |val| val * val; 406 | } 407 | 408 | #[macro_export] 409 | macro_rules! edges { 410 | ( $( $start:literal -- $weight:literal -> $end:literal ),* ) => { 411 | { 412 | vec![ 413 | $( 414 | crate::network::net::Edge::new($start, $end, $weight), 415 | )* 416 | ] 417 | } 418 | }; 419 | } 420 | 421 | #[macro_export] 422 | macro_rules! nodes { 423 | ( $( $activation:literal ),* ) => { 424 | { 425 | let mut nodes = Vec::new(); 426 | 427 | $( 428 | nodes.push( 429 | crate::network::net::Node::new(nodes.len(), match $activation { 430 | 'l' => crate::network::net::activations::LINEAR, 431 | 's' => crate::network::net::activations::SIGMOID, 432 | 't' => crate::network::net::activations::TANH, 433 | 'g' => crate::network::net::activations::GAUSSIAN, 434 | 'r' => crate::network::net::activations::RELU, 435 | 'q' => crate::network::net::activations::SQUARED, 436 | 'i' => crate::network::net::activations::INVERSE, 437 | _ => crate::network::net::activations::SIGMOID } 438 | ) 439 | ); 440 | )* 441 | 442 | nodes 443 | } 444 | }; 445 | } 446 | } 447 | -------------------------------------------------------------------------------- /src/sparse_matrix/feedforward/fabricator.rs: -------------------------------------------------------------------------------- 1 | use crate::network::{EdgeLike, Fabricator, NetworkLike, NodeLike}; 2 | use nalgebra_sparse::{CooMatrix, CscMatrix}; 3 | use std::collections::HashMap; 4 | 5 | pub struct SparseMatrixFeedforwardFabricator; 6 | 7 | impl SparseMatrixFeedforwardFabricator { 8 | fn get_sparse( 9 | (col_inds, row_inds, data, rows): (Vec, Vec, Vec, usize), 10 | ) -> CscMatrix { 11 | let colums = col_inds.iter().max().unwrap() + 1; 12 | 13 | CscMatrix::from( 14 | &CooMatrix::try_from_triplets(rows, colums, row_inds, col_inds, data).unwrap(), 15 | ) 16 | } 17 | } 18 | 19 | impl Fabricator for SparseMatrixFeedforwardFabricator 20 | where 21 | N: NodeLike, 22 | E: EdgeLike, 23 | { 24 | type Output = super::evaluator::SparseMatrixFeedforwardEvaluator; 25 | 26 | fn fabricate(net: &impl NetworkLike) -> Result { 27 | // build dependency graph by collecting incoming edges per node 28 | let mut dependency_graph: HashMap> = HashMap::new(); 29 | 30 | for edge in net.edges() { 31 | dependency_graph 32 | .entry(edge.end()) 33 | .and_modify(|dependencies| dependencies.push(edge)) 34 | .or_insert_with(|| vec![edge]); 35 | } 36 | 37 | if dependency_graph.is_empty() { 38 | return Err("no edges present, net invalid"); 39 | } 40 | 41 | // keep track of dependencies present 42 | let mut dependency_count = dependency_graph.len(); 43 | 44 | // println!("initial dependency_graph {:#?}", dependency_graph); 45 | 46 | // contains list of matrices (stages) that form the computable net 47 | let mut compute_stages: Vec<(Vec, Vec, Vec, usize)> = Vec::new(); 48 | // contains activation functions corresponding to each stage 49 | let mut stage_transformations: Vec = Vec::new(); 50 | // set available nodes a.k.a net input 51 | let mut available_nodes = net.inputs(); 52 | // sort via Ord implementation of provided nodes to guarantee each input will be processed by the same node every time 53 | available_nodes.sort_unstable(); 54 | // reduce nodes to ids 55 | let mut available_nodes: Vec = available_nodes.iter().map(|n| n.id()).collect(); 56 | 57 | // println!("available_nodes {:?}", available_nodes); 58 | 59 | // set wanted nodes a.k.a net output 60 | let mut wanted_nodes = net.outputs(); 61 | // sort via Ord implementation of provided nodes to guarantee each output will appear in the same order every time 62 | wanted_nodes.sort_unstable(); 63 | // reduce nodes to ids 64 | let wanted_nodes: Vec = wanted_nodes.iter().map(|n| n.id()).collect(); 65 | 66 | // println!("wanted_nodes {:?}", wanted_nodes); 67 | 68 | // gather compute stages by finding computable nodes and required carries until all dependencies are resolved 69 | while !dependency_graph.is_empty() { 70 | // setup new transformations 71 | let mut transformations: crate::Transformations = Vec::new(); 72 | // list of nodes becoming available by compute stage 73 | let mut next_available_nodes: Vec = Vec::new(); 74 | 75 | let mut column_index = 0; 76 | let mut stage_column_indices: Vec = Vec::new(); 77 | let mut stage_row_indices = Vec::new(); 78 | let mut stage_data = Vec::new(); 79 | 80 | for (&dependent_node, dependencies) in dependency_graph.iter() { 81 | let mut node_column_indices = Vec::new(); 82 | let mut node_row_indices = Vec::new(); 83 | let mut node_data = Vec::new(); 84 | // marker if all dependencies are available 85 | let mut computable = true; 86 | // check every dependency 87 | for &dependency in dependencies { 88 | let mut found = false; 89 | for (row_index, &id) in available_nodes.iter().enumerate() { 90 | // index here is row index 91 | if dependency.start() == id { 92 | node_column_indices.push(column_index); 93 | node_row_indices.push(row_index); 94 | node_data.push(dependency.weight()); 95 | found = true; 96 | } 97 | } 98 | // if any dependency is not found the node is not computable yet 99 | if !found { 100 | computable = false; 101 | } 102 | } 103 | if computable { 104 | stage_column_indices = [stage_column_indices, node_column_indices].concat(); 105 | stage_row_indices = [stage_row_indices, node_row_indices].concat(); 106 | stage_data = [stage_data, node_data].concat(); 107 | // add activation function to stage transformations 108 | transformations.push( 109 | net.nodes() 110 | .iter() 111 | .find(|&node| node.id() == dependent_node) 112 | .unwrap() 113 | .activation(), 114 | ); 115 | column_index += 1; 116 | // mark node as available in next iteration 117 | next_available_nodes.push(dependent_node); 118 | } else { 119 | let mut carry_column_indices = Vec::new(); 120 | let mut carry_row_indices = Vec::new(); 121 | let mut carry_data = Vec::new(); 122 | for row_index in node_row_indices { 123 | if !next_available_nodes.contains(&available_nodes[row_index]) { 124 | carry_row_indices.push(row_index); 125 | carry_column_indices.push(column_index); 126 | column_index += 1; 127 | carry_data.push(1.0); 128 | transformations.push(|val| val); 129 | next_available_nodes.push(available_nodes[row_index]); 130 | } 131 | } 132 | stage_column_indices = [stage_column_indices, carry_column_indices].concat(); 133 | stage_row_indices = [stage_row_indices, carry_row_indices].concat(); 134 | stage_data = [stage_data, carry_data].concat(); 135 | } 136 | } 137 | 138 | // keep any wanted notes if available (output) 139 | for wanted_node in wanted_nodes.iter() { 140 | for (row_index, available_node) in available_nodes.iter().enumerate() { 141 | if available_node == wanted_node { 142 | // carry only if not carried already 143 | if !next_available_nodes.contains(available_node) { 144 | stage_column_indices.push(column_index); 145 | column_index += 1; 146 | stage_row_indices.push(row_index); 147 | stage_data.push(1.0); 148 | 149 | // add identity function for carried vector 150 | transformations.push(|val| val); 151 | // add node as available 152 | next_available_nodes.push(*available_node); 153 | } 154 | } 155 | } 156 | } 157 | 158 | // remove resolved dependencies from dependency graph 159 | for node in next_available_nodes.iter() { 160 | dependency_graph.remove(node); 161 | } 162 | 163 | // if no dependency was removed no progess was made 164 | if dependency_graph.len() == dependency_count { 165 | return Err("can't resolve dependencies, net invalid"); 166 | } else { 167 | dependency_count = dependency_graph.len(); 168 | } 169 | 170 | // println!("next_available_nodes {:?}", next_available_nodes); 171 | 172 | // reorder last stage according to net output order (invalidates next_available_nodes order which wont be used after this point) 173 | if dependency_graph.is_empty() { 174 | // println!("stage_matrix {:?}", stage_matrix); 175 | 176 | // let mut reordered_matrix = stage_matrix.clone(); 177 | let mut reordered_stage_column_indices = 178 | vec![usize::MAX; stage_column_indices.len()]; 179 | let mut reordered_transformations = transformations.clone(); 180 | 181 | let mut matched_wanted_count = 0; 182 | 183 | for (old_column_index, available_node) in next_available_nodes.iter().enumerate() { 184 | for (new_column_index, wanted_node) in wanted_nodes.iter().enumerate() { 185 | if available_node == wanted_node { 186 | for (reordered_index, &old_index) in reordered_stage_column_indices 187 | .iter_mut() 188 | .zip(stage_column_indices.iter()) 189 | { 190 | if old_index == old_column_index { 191 | *reordered_index = new_column_index; 192 | } 193 | } 194 | 195 | reordered_transformations[new_column_index] = 196 | transformations[old_column_index]; 197 | matched_wanted_count += 1; 198 | break; 199 | } 200 | } 201 | } 202 | 203 | if matched_wanted_count < wanted_nodes.len() { 204 | return Err( 205 | "dependencies resolved but not all outputs computable, net invalid", 206 | ); 207 | } 208 | 209 | // println!("reordered_matrix {:?}", reordered_matrix); 210 | 211 | stage_column_indices = reordered_stage_column_indices; 212 | transformations = reordered_transformations; 213 | } 214 | 215 | // add resolved dependencies and transformations to compute stages 216 | compute_stages.push(( 217 | stage_column_indices, 218 | stage_row_indices, 219 | stage_data, 220 | available_nodes.len(), 221 | )); 222 | stage_transformations.push(transformations); 223 | 224 | // set available nodes for next iteration 225 | available_nodes = next_available_nodes; 226 | } 227 | 228 | Ok(super::evaluator::SparseMatrixFeedforwardEvaluator { 229 | stages: compute_stages 230 | .into_iter() 231 | .map(SparseMatrixFeedforwardFabricator::get_sparse) 232 | .collect(), 233 | transformations: stage_transformations, 234 | }) 235 | } 236 | } 237 | 238 | #[cfg(test)] 239 | mod tests { 240 | use nalgebra::dmatrix; 241 | 242 | use super::SparseMatrixFeedforwardFabricator; 243 | use crate::{ 244 | edges, 245 | network::{net::Net, Evaluator, Fabricator}, 246 | nodes, 247 | }; 248 | 249 | // tests construction and evaluation of simplest network 250 | #[test] 251 | fn simple_net_evaluator_0() { 252 | let some_net = Net::new(1, 1, nodes!('l', 'l'), edges!(0--0.5->1)); 253 | 254 | let evaluator = SparseMatrixFeedforwardFabricator::fabricate(&some_net).unwrap(); 255 | // println!("stages {:?}", evaluator.stages); 256 | 257 | let result = evaluator.evaluate(dmatrix![5.0]); 258 | // println!("result {:?}", result); 259 | 260 | assert_eq!(result, dmatrix![2.5]); 261 | } 262 | 263 | // tests input dimension > 1 264 | #[test] 265 | fn simple_net_evaluator_1() { 266 | let some_net = Net::new( 267 | 2, 268 | 1, 269 | nodes!('l', 'l', 'l'), 270 | edges!( 271 | 0--0.5->2, 272 | 1--0.5->2 273 | ), 274 | ); 275 | 276 | let evaluator = SparseMatrixFeedforwardFabricator::fabricate(&some_net).unwrap(); 277 | // println!("stages {:?}", evaluator.stages); 278 | 279 | let result = evaluator.evaluate(dmatrix![5.0, 5.0]); 280 | // println!("result {:?}", result); 281 | 282 | assert_eq!(result, dmatrix![5.0]); 283 | } 284 | 285 | // test linear chaining of edges 286 | #[test] 287 | fn simple_net_evaluator_2() { 288 | let some_net = Net::new( 289 | 1, 290 | 1, 291 | nodes!('l', 'l', 'l'), 292 | edges!( 293 | 0--0.5->1, 294 | 1--0.5->2 295 | ), 296 | ); 297 | 298 | let evaluator = SparseMatrixFeedforwardFabricator::fabricate(&some_net).unwrap(); 299 | // println!("stages {:?}", evaluator.stages); 300 | 301 | let result = evaluator.evaluate(dmatrix![5.0]); 302 | // println!("result {:?}", result); 303 | 304 | assert_eq!(result, dmatrix![1.25]); 305 | } 306 | 307 | // test construction of carry for later needs 308 | #[test] 309 | fn simple_net_evaluator_3() { 310 | let some_net = Net::new( 311 | 1, 312 | 1, 313 | nodes!('l', 'l', 'l'), 314 | edges!( 315 | 0--0.5->1, 316 | 1--0.5->2, 317 | 0--0.5->2 318 | ), 319 | ); 320 | 321 | let evaluator = SparseMatrixFeedforwardFabricator::fabricate(&some_net).unwrap(); 322 | // println!("stages {:?}", evaluator.stages); 323 | 324 | let result = evaluator.evaluate(dmatrix![5.0]); 325 | // println!("result {:?}", result); 326 | 327 | assert_eq!(result, dmatrix![3.75]); 328 | } 329 | 330 | // test construction of carry for early result with dedup carry 331 | #[test] 332 | fn simple_net_evaluator_4() { 333 | let some_net = Net::new( 334 | 1, 335 | 2, 336 | nodes!('l', 'l', 'l', 'l'), 337 | edges!( 338 | 0--0.5->1, 339 | 1--0.5->2, 340 | 0--0.5->3, 341 | 0--0.5->2 342 | ), 343 | ); 344 | 345 | let evaluator = SparseMatrixFeedforwardFabricator::fabricate(&some_net).unwrap(); 346 | // println!("stages {:?}", evaluator.stages); 347 | 348 | let result = evaluator.evaluate(dmatrix![5.0]); 349 | // println!("result {:?}", result); 350 | 351 | assert_eq!(result, dmatrix![3.75, 2.5]); 352 | } 353 | 354 | // test construction of carry for early result flipped order 355 | #[test] 356 | fn simple_net_evaluator_5() { 357 | let some_net = Net::new( 358 | 1, 359 | 2, 360 | nodes!('l', 'l', 'l', 'l'), 361 | edges!( 362 | 0--0.5->1, 363 | 1--0.5->3, 364 | 0--0.5->2 365 | ), 366 | ); 367 | 368 | let evaluator = SparseMatrixFeedforwardFabricator::fabricate(&some_net).unwrap(); 369 | 370 | let result = evaluator.evaluate(dmatrix![5.0]); 371 | 372 | assert_eq!(result, dmatrix![2.5, 1.25]); 373 | } 374 | 375 | // test unconnected net 376 | #[test] 377 | fn simple_net_evaluator_6() { 378 | let some_net = Net::new(1, 1, nodes!('l', 'l'), Vec::new()); 379 | 380 | if let Err(message) = SparseMatrixFeedforwardFabricator::fabricate(&some_net) { 381 | assert_eq!(message, "no edges present, net invalid"); 382 | } else { 383 | unreachable!(); 384 | } 385 | } 386 | 387 | // test uncomputable output 388 | #[test] 389 | fn simple_net_evaluator_7() { 390 | let some_net = Net::new(1, 1, nodes!('l', 'l', 'l'), edges!(0--0.5->1)); 391 | 392 | if let Err(message) = SparseMatrixFeedforwardFabricator::fabricate(&some_net) { 393 | assert_eq!( 394 | message, 395 | "dependencies resolved but not all outputs computable, net invalid" 396 | ); 397 | } else { 398 | unreachable!(); 399 | } 400 | } 401 | 402 | // test uncomputable output 403 | #[test] 404 | fn simple_net_evaluator_8() { 405 | let some_net = Net::new(1, 1, nodes!('l', 'l', 'l'), edges!(1--0.5->2)); 406 | 407 | if let Err(message) = SparseMatrixFeedforwardFabricator::fabricate(&some_net) { 408 | assert_eq!(message, "can't resolve dependencies, net invalid"); 409 | } else { 410 | unreachable!(); 411 | } 412 | } 413 | 414 | #[test] 415 | fn simple_net_evaluator_9() { 416 | let some_net = Net::new( 417 | 2, 418 | 1, 419 | nodes!('l', 'l', 'l'), 420 | edges!( 421 | 0--0.5->2, 422 | 1--0.0->2 423 | ), 424 | ); 425 | 426 | let evaluator = SparseMatrixFeedforwardFabricator::fabricate(&some_net).unwrap(); 427 | // println!("stages {:?}", evaluator.stages); 428 | 429 | let result = evaluator.evaluate(dmatrix![5.0, 5.0]); 430 | // println!("result {:?}", result); 431 | 432 | assert_eq!(result, dmatrix![2.5]); 433 | } 434 | 435 | // This test fails as currently it is necessary to run connections into all outputs. 436 | // 437 | // #[test] 438 | // fn simple_net_evaluator_not_fully_connected_outputs() { 439 | // let some_net = Net::new( 440 | // 1, 441 | // 2, 442 | // nodes!('l', 'l', 'l'), 443 | // edges!( 444 | // 0--1.0->1 445 | // ), 446 | // ); 447 | 448 | // let evaluator = SparseMatrixFeedforwardFabricator::fabricate(&some_net).unwrap(); 449 | // // println!("stages {:?}", evaluator.stages); 450 | 451 | // let result = evaluator.evaluate(dmatrix![5.0]); 452 | // // println!("result {:?}", result); 453 | 454 | // assert_eq!(result, dmatrix![5.0, 0.0]); 455 | // } 456 | } 457 | --------------------------------------------------------------------------------