├── .gitignore ├── src ├── genes │ ├── id.rs │ ├── nodes.rs │ ├── nodes │ │ └── activations.rs │ └── connections.rs ├── mutations │ ├── error.rs │ ├── change_activation.rs │ ├── change_weights.rs │ ├── add_node.rs │ ├── add_connection.rs │ ├── add_recurrent_connection.rs │ ├── remove_recurrent_connection.rs │ ├── remove_connection.rs │ ├── remove_node.rs │ └── duplicate_node.rs ├── genes.rs ├── mutations.rs ├── favannat_impl.rs ├── lib.rs ├── parameters.rs ├── genome │ └── compatibility_distance.rs └── genome.rs ├── Cargo.toml ├── LICENSE ├── README.md └── benches └── bench_set_genome.rs /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | Cargo.lock 3 | -------------------------------------------------------------------------------- /src/genes/id.rs: -------------------------------------------------------------------------------- 1 | use serde::{Deserialize, Serialize}; 2 | 3 | /// Identity of ANN structure elements; new ids are derived by hashing (e.g. `Node::next_id`, `Connection::next_id`). 4 | #[derive(PartialEq, Eq, Debug, Clone, Copy, Hash, Serialize, Deserialize, PartialOrd, Ord)] 5 | pub struct Id(pub u64); 6 | -------------------------------------------------------------------------------- /src/mutations/error.rs: -------------------------------------------------------------------------------- 1 | use thiserror::Error; 2 | 3 | /// Reasons why a structural mutation could not be applied to a genome. 4 | #[derive(Error, Debug, PartialEq, Eq)] 5 | pub enum MutationError { 6 | #[error("No two nodes could be connected by a new feed-forward connection.")] 7 | CouldNotAddFeedForwardConnection, 8 | #[error("No two nodes could be connected by a new recurrent connection.")] 9 | CouldNotAddRecurrentConnection, 10 | #[error("No removable node present in the genome.")] 11 | CouldNotRemoveNode, 12 | #[error("No removable feed-forward connection present in the genome.")] 13 | CouldNotRemoveFeedForwardConnection, 14 | #[error("No removable recurrent connection present in the genome.")] 15 | CouldNotRemoveRecurrentConnection, 16 | #[error("No hidden node to duplicate 
was present in the genome.")] 16 | CouldNotDuplicateNode, 17 | } 18 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "set_genome" 3 | version = "0.6.7" 4 | authors = ["Silvan Buedenbender "] 5 | edition = "2018" 6 | license = "MIT" 7 | description = "A genetic data structure for neuroevolution algorithms." 8 | homepage = "https://github.com/SilvanCodes/set-genome" 9 | documentation = "https://docs.rs/set_genome" 10 | repository = "https://github.com/SilvanCodes/set-genome" 11 | readme = "README.md" 12 | keywords = ["ann", "evolution"] 13 | categories = ["data-structures", "science", "mathematics"] 14 | 15 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 16 | 17 | [dependencies] 18 | rand = { version = "0.8", features = ["small_rng"] } 19 | rand_distr = "0.4" 20 | serde = { version = "1.0", features = ["derive"] } 21 | config = "0.11" 22 | favannat = { version = "0.6.4", optional = true } 23 | thiserror = "1.0.30" 24 | seahash = "4.1.0" 25 | 26 | [dev-dependencies] 27 | criterion = "0.3" 28 | nalgebra = "0.28.0" 29 | 30 | [[bench]] 31 | name = "bench_set_genome" 32 | harness = false 33 | 34 | [features] 35 | default = ["favannat"] 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Silvan Büdenbender 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | 
furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # set_genome 2 | 3 | This crate is intended to act as the representation/reproduction component of neuroevolution algorithms and may be combined with arbitrary selection mechanisms. 4 | 5 | SET stands for **S**et **E**ncoded **T**opology and this crate implements a genetic data structure, the `Genome`, 6 | using this set encoding to describe artificial neural networks (ANNs). 7 | Furthermore, this crate defines operations on this genome, namely `Mutations` and `Crossover`. 8 | Mutations alter a genome by adding, removing or changing genes; crossover recombines two genomes. 9 | To obtain an intuitive definition of crossover for network structures, the [NEAT algorithm] defined a procedure; NEAT should be understood as the conceptual predecessor of this SET encoding, 10 | which is very much a formalization and progression of the ideas NEAT introduced regarding the genome. 11 | The thesis describing this genome and other ideas can be found [here]; a paper focusing just on the SET encoding will follow soon. 
12 | 13 | [neat algorithm]: http://nn.cs.utexas.edu/downloads/papers/stanley.ec02.pdf 14 | [here]: https://www.silvan.codes/SET-NEAT_Thesis.pdf 15 | 16 | ## Usage 17 | 18 | ```toml 19 | [dependencies] 20 | set_genome = "0.6" 21 | ``` 22 | 23 | See the [documentation] for more information. 24 | 25 | [documentation]: https://docs.rs/set_genome 26 | -------------------------------------------------------------------------------- /src/mutations/change_activation.rs: -------------------------------------------------------------------------------- 1 | use rand::{prelude::IteratorRandom, Rng}; 2 | 3 | use crate::{ 4 | genes::{Activation, Node}, 5 | genome::Genome, 6 | }; 7 | 8 | use super::Mutations; 9 | 10 | impl Mutations { 11 | /// This mutation changes the activation function of one random hidden node to any other one chosen from `activation_pool`. 12 | /// If the pool contains no activation function other than the current one, nothing is changed; the node keeps its `id`, `order` and `id_counter` either way. 13 | pub fn change_activation( 14 | activation_pool: &[Activation], 15 | genome: &mut Genome, 16 | rng: &mut impl Rng, 17 | ) { 18 | if let Some(node) = genome.hidden.random(rng) { 19 | let updated = Node { 20 | activation: activation_pool 21 | .iter() 22 | .filter(|&&activation| activation != node.activation) 23 | .choose(rng) 24 | .cloned() 25 | .unwrap_or(node.activation), 26 | ..node.clone() 27 | }; 28 | 29 | genome.hidden.replace(updated); 30 | } 31 | } 32 | } 33 | 34 | #[cfg(test)] 35 | mod tests { 36 | use rand::thread_rng; 37 | 38 | use crate::{activations::Activation, Genome, Mutations, Parameters}; 39 | 40 | #[test] 41 | fn change_activation() { 42 | let mut genome = Genome::initialized(&Parameters::default()); 43 | let activation_pool = Activation::all(); 44 | 45 | Mutations::add_node(&activation_pool, &mut genome, &mut thread_rng()); 46 | 47 | let old_activation = genome.hidden.iter().next().unwrap().activation; 48 | 49 | Mutations::change_activation(&activation_pool, &mut genome, &mut thread_rng()); 50 | 51 | assert_ne!( 52 | 
genome.hidden.iter().next().unwrap().activation, 53 | old_activation 54 | ); 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /src/mutations/change_weights.rs: -------------------------------------------------------------------------------- 1 | use rand::Rng; 2 | 3 | use super::Mutations; 4 | use crate::genome::Genome; 5 | 6 | impl Mutations { 7 | /// This mutation perturbs the weights of a `percent_perturbed` fraction (rounded up) of the feed-forward connections and of the recurrent connections, using approximately gaussian noise with the given `standard_deviation` (see `Connection::perturb_weight`). 8 | pub fn change_weights( 9 | percent_perturbed: f64, 10 | standard_deviation: f64, 11 | genome: &mut Genome, 12 | rng: &mut impl Rng, 13 | ) { 14 | let change_feed_forward_amount = 15 | (percent_perturbed * genome.feed_forward.len() as f64).ceil() as usize; 16 | let change_recurrent_amount = 17 | (percent_perturbed * genome.recurrent.len() as f64).ceil() as usize; 18 | 19 | genome.feed_forward = genome 20 | .feed_forward 21 | .drain_into_random(rng) 22 | .enumerate() 23 | .map(|(index, mut connection)| { 24 | if index < change_feed_forward_amount { 25 | connection.perturb_weight(standard_deviation, rng); 26 | } 27 | connection 28 | }) 29 | .collect(); 30 | 31 | genome.recurrent = genome 32 | .recurrent 33 | .drain_into_random(rng) 34 | .enumerate() 35 | .map(|(index, mut connection)| { 36 | if index < change_recurrent_amount { 37 | connection.perturb_weight(standard_deviation, rng); 38 | } 39 | connection 40 | }) 41 | .collect(); 42 | } 43 | } 44 | 45 | #[cfg(test)] 46 | mod tests { 47 | use rand::thread_rng; 48 | 49 | use crate::{Genome, Mutations, Parameters}; 50 | 51 | #[test] 52 | fn change_weights() { 53 | let mut genome = Genome::initialized(&Parameters::default()); 54 | 55 | let old_weight = genome.feed_forward.iter().next().unwrap().weight; 56 | 57 | Mutations::change_weights(1.0, 1.0, &mut genome, &mut thread_rng()); 58 | 59 | assert!( 60 | (old_weight - genome.feed_forward.iter().next().unwrap().weight).abs() > f64::EPSILON 
); 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /src/genes/nodes.rs: -------------------------------------------------------------------------------- 1 | use seahash::SeaHasher; 2 | use serde::{Deserialize, Serialize}; 3 | use std::{ 4 | cmp::Ordering, 5 | hash::{Hash, Hasher}, 6 | }; 7 | 8 | use self::activations::Activation; 9 | 10 | use super::{Gene, Id}; 11 | 12 | pub mod activations; 13 | 14 | /// Struct describing an ANN node. 15 | /// 16 | /// A node is made up of an identifier, an `order` used for sorting, an activation function and a counter from which `next_id` derives new ids. 17 | /// See [`Activation`] for the available activation functions. 18 | #[derive(Debug, Clone, Deserialize, Serialize)] 19 | pub struct Node { 20 | pub id: Id, 21 | pub order: usize, 22 | pub activation: Activation, 23 | pub id_counter: u64, 24 | } 25 | 26 | impl Node { 27 | pub fn input(id: Id, order: usize) -> Self { 28 | Node { 29 | id, 30 | order, 31 | activation: Activation::Linear, 32 | id_counter: 0, 33 | } 34 | } 35 | /// Creates an output node at position `order` with the given `activation`. 36 | pub fn output(id: Id, order: usize, activation: Activation) -> Self { 37 | Node { 38 | id, 39 | order, 40 | activation, 41 | id_counter: 0, 42 | } 43 | } 44 | /// Creates a hidden node; hidden nodes always have `order` 0. 45 | pub fn hidden(id: Id, activation: Activation) -> Self { 46 | Node { 47 | id, 48 | order: 0, 49 | activation, 50 | id_counter: 0, 51 | } 52 | } 53 | /// Derives a new `Id` by hashing this node's `id` together with `id_counter`, then increments the counter. 54 | pub fn next_id(&mut self) -> Id { 55 | let mut id_hasher = SeaHasher::new(); 56 | self.id.hash(&mut id_hasher); 57 | self.id_counter.hash(&mut id_hasher); 58 | self.id_counter += 1; 59 | Id(id_hasher.finish()) 60 | } 61 | } 62 | 63 | impl Gene for Node { 64 | fn recombine(&self, other: &Self) -> Self { 65 | Self { 66 | activation: other.activation, 67 | ..*self 68 | } 69 | } 70 | } 71 | 72 | impl Hash for Node { 73 | fn hash<H: Hasher>(&self, state: &mut H) { 74 | self.id.hash(state) 75 | } 76 | } 77 | 78 | impl PartialEq for Node { 79 | fn eq(&self, other: &Self) -> bool { 80 | self.id == other.id 81 | } 82 | } 83 | 84 | impl Eq for Node {} 85 | 86 | impl PartialOrd for Node { 87 | fn
partial_cmp(&self, other: &Self) -> Option<Ordering> { 88 | Some(self.cmp(other)) 89 | } 90 | } 91 | // Sorts by `order`; ties (e.g. hidden nodes, which all have `order` 0) are broken by `id` so `cmp` only yields `Equal` for nodes that are also `==`. 92 | impl Ord for Node { 93 | fn cmp(&self, other: &Self) -> Ordering { 94 | self.order.cmp(&other.order).then_with(|| self.id.cmp(&other.id)) 95 | } 96 | } 97 | -------------------------------------------------------------------------------- /src/genes/nodes/activations.rs: -------------------------------------------------------------------------------- 1 | //! Lists constant functions matching the [`Activation`] enum variants. 2 | //! 3 | //! The pool of activation functions is the same as in [this paper](https://weightagnostic.github.io/). 4 | 5 | use serde::{Deserialize, Serialize}; 6 | 7 | /// Possible activation functions for ANN nodes. 8 | /// 9 | /// See the [actual functions listed here] under **Constants**. 10 | /// 11 | /// [actual functions listed here]: ../activations/index.html#constants 12 | #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] 13 | pub enum Activation { 14 | Linear, 15 | Sigmoid, 16 | Tanh, 17 | Gaussian, 18 | Step, 19 | Sine, 20 | Cosine, 21 | Inverse, 22 | Absolute, 23 | Relu, 24 | Squared, 25 | } 26 | /// `Activation::all` returns every variant in declaration order, e.g. for use as a full activation pool. 27 | impl Activation { 28 | pub fn all() -> Vec<Self> { 29 | vec![ 30 | Self::Linear, 31 | Self::Sigmoid, 32 | Self::Tanh, 33 | Self::Gaussian, 34 | Self::Step, 35 | Self::Sine, 36 | Self::Cosine, 37 | Self::Inverse, 38 | Self::Absolute, 39 | Self::Relu, 40 | Self::Squared, 41 | ] 42 | } 43 | } 44 | 45 | /// Returns the argument unchanged. 46 | pub const LINEAR: fn(f64) -> f64 = |val| val; 47 | 48 | /// Steepened sigmoid function (slope factor 4.9), the same as used in the original NEAT paper. 49 | pub const SIGMOID: fn(f64) -> f64 = |val| 1.0 / (1.0 + (-4.9 * val).exp()); 50 | 51 | /// A [rescaled sigmoid function]; because it rescales the steepened `SIGMOID`, it equals `tanh(4.9 * val)`. 
52 | /// 53 | /// [rescaled sigmoid function]: https://brenocon.com/blog/2013/10/tanh-is-a-rescaled-logistic-sigmoid-function/ 54 | pub const TANH: fn(f64) -> f64 = |val| 2.0 * SIGMOID(2.0 * val) - 1.0; 55 | 56 | /// [Gaussian function] with parameters a = 1, b = 0, c = 1, i.e. the standard normal density without its 1/sqrt(2π) normalization factor (peak value 1 at zero). 57 | /// 58 | /// [Gaussian function]: https://en.wikipedia.org/wiki/Gaussian_function 59 | pub const GAUSSIAN: fn(f64) -> f64 = |val| (val * val / -2.0).exp(); 60 | 61 | /// Returns one if argument greater than zero, else zero. 62 | pub const STEP: fn(f64) -> f64 = |val| if val > 0.0 { 1.0 } else { 0.0 }; 63 | 64 | /// Returns sine of the argument multiplied by π. 65 | pub const SINE: fn(f64) -> f64 = |val| (val * std::f64::consts::PI).sin(); 66 | 67 | /// Returns cosine of the argument multiplied by π. 68 | pub const COSINE: fn(f64) -> f64 = |val| (val * std::f64::consts::PI).cos(); 69 | 70 | /// Returns the negated argument (negation, not the reciprocal). 71 | pub const INVERSE: fn(f64) -> f64 = |val| -val; 72 | 73 | /// Returns absolute value of argument. 74 | pub const ABSOLUTE: fn(f64) -> f64 = |val| val.abs(); 75 | 76 | /// Returns argument if it is greater than zero, else zero. 77 | pub const RELU: fn(f64) -> f64 = |val| 0f64.max(val); 78 | 79 | /// Returns square of argument. 80 | pub const SQUARED: fn(f64) -> f64 = |val| val * val; 81 | -------------------------------------------------------------------------------- /src/genes/connections.rs: -------------------------------------------------------------------------------- 1 | use rand::Rng; 2 | use seahash::SeaHasher; 3 | use serde::{Deserialize, Serialize}; 4 | use std::{ 5 | cmp::Ordering, 6 | hash::{Hash, Hasher}, 7 | }; 8 | 9 | use super::{Gene, Id}; 10 | 11 | /// Struct describing an ANN connection. 12 | /// 13 | /// A connection is characterised by its input/origin/start, its output/destination/end and its weight; equality, hashing and ordering only consider `(input, output)`. 
14 | #[derive(Debug, Clone, Serialize, Deserialize)] 15 | pub struct Connection { 16 | pub input: Id, 17 | pub output: Id, 18 | pub weight: f64, 19 | pub id_counter: u64, 20 | } 21 | 22 | impl Connection { 23 | pub fn new(input: Id, weight: f64, output: Id) -> Self { 24 | Self { 25 | input, 26 | output, 27 | weight, 28 | id_counter: 0, 29 | } 30 | } 31 | /// Returns `(input, output)`, the pair used for equality, hashing and ordering of connections. 32 | pub fn id(&self) -> (Id, Id) { 33 | (self.input, self.output) 34 | } 35 | /// Derives a new `Id` by hashing `input`, `output` and `id_counter`, then increments the counter. 36 | pub fn next_id(&mut self) -> Id { 37 | let mut id_hasher = SeaHasher::new(); 38 | self.input.hash(&mut id_hasher); 39 | self.output.hash(&mut id_hasher); 40 | self.id_counter.hash(&mut id_hasher); 41 | self.id_counter += 1; 42 | Id(id_hasher.finish()) 43 | } 44 | /// Replaces `weight` with `Self::weight_perturbation(weight, standard_deviation, rng)`. 45 | pub fn perturb_weight(&mut self, standard_deviation: f64, rng: &mut impl Rng) { 46 | self.weight = Self::weight_perturbation(self.weight, standard_deviation, rng); 47 | } 48 | /// Returns `weight` plus approximately normal noise scaled by `standard_deviation`, kept in [-1, 1]: out-of-range results are pulled back by repeatedly negating and halving the noise, and clamped when that cannot succeed (e.g. `weight` already outside [-1, 1] or non-finite noise); NaN propagates. 49 | pub fn weight_perturbation(weight: f64, standard_deviation: f64, rng: &mut impl Rng) -> f64 { 50 | // approximately normally distributed sample (sum of 12 uniforms minus 6 has variance 1), see: https://en.wikipedia.org/wiki/Irwin%E2%80%93Hall_distribution#Approximating_a_Normal_distribution 51 | let mut perturbation = 52 | ((0..12).map(|_| rng.gen::<f64>()).sum::<f64>() - 6.0) * standard_deviation; 53 | 54 | while perturbation != 0.0 && perturbation.is_finite() && ((weight + perturbation) > 1.0 || (weight + perturbation) < -1.0) { 55 | perturbation = -perturbation / 2.0; 56 | } // the zero/finite guard ends the loop when no amount of halving can bring the result back into range 57 | (weight + perturbation).clamp(-1.0, 1.0) 58 | } 59 | } 60 | 61 | impl Gene for Connection { 62 | fn recombine(&self, other: &Self) -> Self { 63 | Self { 64 | weight: other.weight, 65 | ..*self 66 | } 67 | } 68 | } 69 | 70 | impl PartialEq for Connection { 71 | fn eq(&self, other: &Self) -> bool { 72 | self.id() == other.id() 73 | } 74 | } 75 | 76 | impl Eq for Connection {} 77 | 78 | impl Hash for Connection { 79 | fn hash<H: Hasher>(&self, state: &mut H) { 80 | self.id().hash(state); 81 | } 82 | } 83 | 84 | impl PartialOrd for Connection { 85 | fn partial_cmp(&self, other: &Self) -> Option<Ordering> { 86 | Some(self.cmp(other)) 87 | } 88 | } 89 | 90 | impl
Ord for Connection { 91 | fn cmp(&self, other: &Self) -> Ordering { 92 | self.id().cmp(&other.id()) 93 | } 94 | } 95 | -------------------------------------------------------------------------------- /src/mutations/add_node.rs: -------------------------------------------------------------------------------- 1 | use rand::{prelude::SliceRandom, Rng}; 2 | 3 | use crate::{ 4 | genes::{Activation, Connection, Node}, 5 | genome::Genome, 6 | }; 7 | 8 | use super::Mutations; 9 | 10 | impl Mutations { 11 | /// This mutation adds a new node to the genome by "splitting" an existing connection, i.e. the existing connection gets "re-routed" via the new node and the weight of the split connection is set to zero. 12 | /// The connection leading into the new node is of weight 1.0 and the connection originating from the new node has the same weight as the split connection (before it is zeroed). Panics if the genome has no feed-forward connection or if `activation_pool` is empty. 13 | pub fn add_node(activation_pool: &[Activation], genome: &mut Genome, rng: &mut impl Rng) { 14 | // select a connection gene to split 15 | let mut random_connection = genome.feed_forward.random(rng).cloned().expect("add_node requires at least one feed-forward connection"); 16 | 17 | let mut id = random_connection.next_id(); 18 | 19 | // avoid id collisions, will cause some kind of "divergent evolution" eventually 20 | while genome.contains(id) { 21 | id = random_connection.next_id() 22 | } 23 | 24 | // construct new node gene 25 | let new_node = Node::hidden(id, activation_pool.choose(rng).cloned().expect("activation_pool must not be empty")); 26 | 27 | // insert new connection pointing to new node 28 | assert!(genome.feed_forward.insert(Connection::new( 29 | random_connection.input, 30 | 1.0, 31 | new_node.id, 32 | ))); 33 | // insert new connection pointing from new node 34 | assert!(genome.feed_forward.insert(Connection::new( 35 | new_node.id, 36 | random_connection.weight, 37 | random_connection.output, 38 | ))); 39 | // insert new node into genome 40 | assert!(genome.hidden.insert(new_node)); 41 | 42 | // update weight to zero to 'deactivate' connection 43 | 
random_connection.weight = 0.0; 44 | genome.feed_forward.replace(random_connection); 45 | } 46 | } 47 | 48 | #[cfg(test)] 49 | mod tests { 50 | use rand::thread_rng; 51 | 52 | use crate::{activations::Activation, Genome, Mutations, Parameters}; 53 | 54 | #[test] 55 | fn add_random_node() { 56 | let mut genome = Genome::initialized(&Parameters::default()); 57 | 58 | Mutations::add_node(&Activation::all(), &mut genome, &mut thread_rng()); 59 | 60 | assert_eq!(genome.feed_forward.len(), 3); 61 | } 62 | 63 | #[test] 64 | fn same_structure_same_id() { 65 | let mut genome1 = Genome::initialized(&Parameters::default()); 66 | let mut genome2 = Genome::initialized(&Parameters::default()); 67 | 68 | Mutations::add_node(&Activation::all(), &mut genome1, &mut thread_rng()); 69 | Mutations::add_node(&Activation::all(), &mut genome2, &mut thread_rng()); 70 | 71 | assert_eq!(genome1.hidden, genome2.hidden); 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /src/mutations/add_connection.rs: -------------------------------------------------------------------------------- 1 | use rand::{seq::SliceRandom, Rng}; 2 | 3 | use crate::{genes::Connection, genome::Genome}; 4 | 5 | use super::{MutationError, MutationResult, Mutations}; 6 | 7 | impl Mutations { 8 | /// This mutation adds a new feed-forward connection to the genome, should it be possible. 9 | /// It is possible when any two nodes[^details] are not yet connected with a feed-forward connection and connecting them would not form a cycle. 10 | /// 11 | /// [^details]: "any two nodes" is technically not correct as the start node for the connection has to come from the union of input and hidden nodes and the end node has to come from the union of the hidden and output nodes. 
12 | pub fn add_connection(genome: &mut Genome, rng: &mut impl Rng) -> MutationResult { 13 | let mut possible_start_nodes = genome 14 | .inputs 15 | .iter() 16 | .chain(genome.hidden.iter()) 17 | .collect::>(); 18 | possible_start_nodes.shuffle(rng); 19 | 20 | let mut possible_end_nodes = genome 21 | .hidden 22 | .iter() 23 | .chain(genome.outputs.iter()) 24 | .collect::>(); 25 | possible_end_nodes.shuffle(rng); 26 | 27 | for start_node in possible_start_nodes { 28 | if let Some(end_node) = possible_end_nodes.iter().cloned().find(|&end_node| { 29 | end_node != start_node 30 | && !genome.feed_forward.contains(&Connection::new( 31 | start_node.id, 32 | 0.0, 33 | end_node.id, 34 | )) 35 | && !genome.would_form_cycle(start_node, end_node) 36 | }) { 37 | // add new feed-forward connection 38 | assert!(genome.feed_forward.insert(Connection::new( 39 | start_node.id, 40 | Connection::weight_perturbation(0.0, 0.1, rng), 41 | end_node.id, 42 | ))); 43 | return Ok(()); 44 | } 45 | } 46 | // no possible connection end present 47 | Err(MutationError::CouldNotAddFeedForwardConnection) 48 | } 49 | } 50 | 51 | #[cfg(test)] 52 | mod tests { 53 | use rand::thread_rng; 54 | 55 | use crate::{Genome, MutationError, Mutations, Parameters}; 56 | 57 | #[test] 58 | fn add_random_connection() { 59 | let mut genome = Genome::uninitialized(&Parameters::default()); 60 | 61 | assert!(Mutations::add_connection(&mut genome, &mut thread_rng()).is_ok()); 62 | assert_eq!(genome.feed_forward.len(), 1); 63 | } 64 | 65 | #[test] 66 | fn dont_add_same_connection_twice() { 67 | let mut genome = Genome::uninitialized(&Parameters::default()); 68 | 69 | Mutations::add_connection(&mut genome, &mut thread_rng()).expect("add_connection"); 70 | 71 | if let Err(error) = Mutations::add_connection(&mut genome, &mut thread_rng()) { 72 | assert_eq!(error, MutationError::CouldNotAddFeedForwardConnection); 73 | } else { 74 | unreachable!() 75 | } 76 | 77 | assert_eq!(genome.feed_forward.len(), 1); 78 | } 79 | } 80 | 
-------------------------------------------------------------------------------- /src/mutations/add_recurrent_connection.rs: -------------------------------------------------------------------------------- 1 | use rand::{seq::SliceRandom, Rng}; 2 | 3 | use crate::{genes::Connection, genome::Genome}; 4 | 5 | use super::{MutationError, MutationResult, Mutations}; 6 | 7 | impl Mutations { 8 | /// This mutation adds a recurrent connection to the `genome` when possible. 9 | /// It is possible when any two nodes [^details] are not yet connected with a recurrent connection. 10 | /// 11 | /// [^details]: "any two nodes" is technically not correct as the end node has to come from the intersection of the hidden and output nodes. 12 | pub fn add_recurrent_connection(genome: &mut Genome, rng: &mut impl Rng) -> MutationResult { 13 | let mut possible_start_nodes = genome 14 | .inputs 15 | .iter() 16 | .chain(genome.hidden.iter()) 17 | .chain(genome.outputs.iter()) 18 | .collect::>(); 19 | possible_start_nodes.shuffle(rng); 20 | 21 | let mut possible_end_nodes = genome 22 | .hidden 23 | .iter() 24 | .chain(genome.outputs.iter()) 25 | .collect::>(); 26 | possible_end_nodes.shuffle(rng); 27 | 28 | for start_node in possible_start_nodes { 29 | if let Some(end_node) = possible_end_nodes.iter().cloned().find(|&end_node| { 30 | !genome 31 | .recurrent 32 | .contains(&Connection::new(start_node.id, 0.0, end_node.id)) 33 | }) { 34 | assert!(genome.recurrent.insert(Connection::new( 35 | start_node.id, 36 | Connection::weight_perturbation(0.0, 0.1, rng), 37 | end_node.id, 38 | ))); 39 | return Ok(()); 40 | } 41 | } 42 | // no possible connection end present 43 | Err(MutationError::CouldNotAddRecurrentConnection) 44 | } 45 | } 46 | 47 | #[cfg(test)] 48 | mod tests { 49 | use rand::thread_rng; 50 | 51 | use crate::{Genome, MutationError, Mutations, Parameters}; 52 | 53 | #[test] 54 | fn add_random_connection() { 55 | let mut genome = Genome::initialized(&Parameters::default()); 56 | 57 | 
Mutations::add_recurrent_connection(&mut genome, &mut thread_rng()) 58 | .expect("y no add recurrent connection"); 59 | 60 | assert_eq!(genome.recurrent.len(), 1); 61 | } 62 | 63 | #[test] 64 | fn dont_add_same_connection_twice() { 65 | let mut genome = Genome::initialized(&Parameters::default()); 66 | 67 | // create all possible recurrent connections 68 | Mutations::add_recurrent_connection(&mut genome, &mut thread_rng()) 69 | .expect("y no add recurrent connection"); 70 | 71 | Mutations::add_recurrent_connection(&mut genome, &mut thread_rng()) 72 | .expect("y no add recurrent connection"); 73 | 74 | if let Err(error) = Mutations::add_recurrent_connection(&mut genome, &mut thread_rng()) { 75 | assert_eq!(error, MutationError::CouldNotAddRecurrentConnection); 76 | } else { 77 | unreachable!() 78 | } 79 | 80 | assert_eq!(genome.recurrent.len(), 2); 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /src/mutations/remove_recurrent_connection.rs: -------------------------------------------------------------------------------- 1 | use rand::Rng; 2 | 3 | use crate::Genome; 4 | 5 | use super::{MutationError, MutationResult, Mutations}; 6 | 7 | impl Mutations { 8 | /// Removes a randomly chosen recurrent connection if at least one is present in the genome. 9 | /// When no recurrent connection exists, the genome is left unchanged and `MutationError::CouldNotRemoveRecurrentConnection` is returned. 
10 | pub fn remove_recurrent_connection(genome: &mut Genome, rng: &mut impl Rng) -> MutationResult { 11 | if let Some(removable_connection) = &genome 12 | .recurrent 13 | .iter() 14 | // make iterator wrap 15 | .cycle() 16 | // randomly offset into the iterator to choose any connection 17 | .skip((rng.gen::<f64>() * (genome.recurrent.len()) as f64).floor() as usize) 18 | .cloned() 19 | .next() 20 | { 21 | assert!(genome.recurrent.remove(removable_connection)); 22 | Ok(()) 23 | } else { 24 | Err(MutationError::CouldNotRemoveRecurrentConnection) 25 | } 26 | } 27 | } 28 | 29 | #[cfg(test)] 30 | mod tests { 31 | use rand::thread_rng; 32 | 33 | use crate::{ 34 | activations::Activation, 35 | genes::{Connection, Genes, Id, Node}, 36 | mutations::MutationError, 37 | Genome, Mutations, 38 | }; 39 | 40 | #[test] 41 | fn can_remove_recurrent_connection() { 42 | let mut genome = Genome { 43 | inputs: Genes(vec![Node::input(Id(0), 0)].iter().cloned().collect()), 44 | outputs: Genes( 45 | vec![Node::output(Id(1), 0, Activation::Linear)] 46 | .iter() 47 | .cloned() 48 | .collect(), 49 | ), 50 | feed_forward: Genes( 51 | vec![Connection::new(Id(0), 1.0, Id(1))] 52 | .iter() 53 | .cloned() 54 | .collect(), 55 | ), 56 | recurrent: Genes( 57 | vec![Connection::new(Id(0), 1.0, Id(1))] 58 | .iter() 59 | .cloned() 60 | .collect(), 61 | ), 62 | ..Default::default() 63 | }; 64 | 65 | assert!(Mutations::remove_recurrent_connection(&mut genome, &mut thread_rng()).is_ok()) 66 | } 67 | 68 | #[test] 69 | fn can_not_remove_recurrent_connection() { 70 | let mut genome = Genome { 71 | inputs: Genes(vec![Node::input(Id(0), 0)].iter().cloned().collect()), 72 | outputs: Genes( 73 | vec![Node::output(Id(1), 0, Activation::Linear)] 74 | .iter() 75 | .cloned() 76 | .collect(), 77 | ), 78 | feed_forward: Genes( 79 | vec![Connection::new(Id(0), 1.0, Id(1))] 80 | .iter() 81 | .cloned() 82 | .collect(), 83 | ), 84 | ..Default::default() 85 | }; 86 | 87 | if let Err(error) = 
Mutations::remove_recurrent_connection(&mut genome, &mut thread_rng()) { 88 | assert_eq!(error, MutationError::CouldNotRemoveRecurrentConnection); 89 | } else { 90 | unreachable!() 91 | } 92 | } 93 | } 94 | -------------------------------------------------------------------------------- /benches/bench_set_genome.rs: -------------------------------------------------------------------------------- 1 | use criterion::{criterion_group, criterion_main, Criterion}; 2 | use rand::{rngs::SmallRng, thread_rng, SeedableRng}; 3 | use set_genome::{activations::Activation, Genome, Mutations, Parameters}; 4 | 5 | pub fn crossover_same_genome_benchmark(c: &mut Criterion) { 6 | let parameters = Parameters::default(); 7 | 8 | let genome_0 = Genome::initialized(&parameters); 9 | let genome_1 = Genome::initialized(&parameters); 10 | 11 | c.bench_function("crossover same genome", |b| { 12 | b.iter(|| genome_0.cross_in(&genome_1)) 13 | }); 14 | } 15 | 16 | pub fn crossover_highly_mutated_genomes_benchmark(c: &mut Criterion) { 17 | let parameters = Parameters { 18 | structure: Default::default(), 19 | mutations: vec![ 20 | Mutations::AddNode { 21 | chance: 1.0, 22 | activation_pool: vec![ 23 | Activation::Linear, 24 | Activation::Sigmoid, 25 | Activation::Tanh, 26 | Activation::Gaussian, 27 | Activation::Step, 28 | Activation::Sine, 29 | Activation::Cosine, 30 | Activation::Inverse, 31 | Activation::Absolute, 32 | Activation::Relu, 33 | ], 34 | }, 35 | Mutations::AddConnection { chance: 1.0 }, 36 | ], 37 | }; 38 | 39 | let mut genome_0 = Genome::initialized(&parameters); 40 | let mut genome_1 = Genome::initialized(&parameters); 41 | 42 | for _ in 0..100 { 43 | genome_0.mutate(&parameters).expect("mutation"); 44 | genome_1.mutate(&parameters).expect("mutation"); 45 | } 46 | 47 | c.bench_function("crossover highly mutated genomes", |b| { 48 | b.iter(|| genome_0.cross_in(&genome_1)) 49 | }); 50 | } 51 | 52 | pub fn mutate_genome_benchmark(c: &mut Criterion) { 53 | let parameters = Parameters { 54 | structure: 
Default::default(), 55 | mutations: vec![ 56 | Mutations::AddNode { 57 | chance: 1.0, 58 | activation_pool: vec![ 59 | Activation::Linear, 60 | Activation::Sigmoid, 61 | Activation::Tanh, 62 | Activation::Gaussian, 63 | Activation::Step, 64 | Activation::Sine, 65 | Activation::Cosine, 66 | Activation::Inverse, 67 | Activation::Absolute, 68 | Activation::Relu, 69 | ], 70 | }, 71 | Mutations::AddConnection { chance: 1.0 }, 72 | ], 73 | }; 74 | 75 | let mut genome = Genome::initialized(&parameters); 76 | 77 | c.bench_function("mutate genome", |b| b.iter(|| genome.mutate(&parameters))); 78 | } 79 | 80 | pub fn add_node_to_genome_benchmark(c: &mut Criterion) { 81 | let genome = &mut Genome::initialized(&Parameters::default()); 82 | let rng = &mut SmallRng::from_rng(thread_rng()).unwrap(); 83 | let activation_pool = Activation::all(); // built once, outside the measured closure 84 | c.bench_function("add node to genome", |b| { 85 | b.iter(|| Mutations::add_node(&activation_pool, genome, rng)) 86 | }); 87 | } 88 | 89 | criterion_group!( 90 | benches, 91 | mutate_genome_benchmark, 92 | crossover_same_genome_benchmark, 93 | crossover_highly_mutated_genomes_benchmark, 94 | add_node_to_genome_benchmark 95 | ); 96 | criterion_main!(benches); 97 | -------------------------------------------------------------------------------- /src/mutations/remove_connection.rs: -------------------------------------------------------------------------------- 1 | use rand::Rng; 2 | 3 | use crate::Genome; 4 | 5 | use super::{MutationError, MutationResult, Mutations}; 6 | 7 | impl Mutations { 8 | /// Removes a connection, should this be possible without introducing dangling structure. 9 | /// Dangling means the in- or out-degree of any hidden node is zero, i.e. it either cannot receive or cannot propagate a signal. 10 | /// If it is not possible, no connection will be removed. 
11 | pub fn remove_connection(genome: &mut Genome, rng: &mut impl Rng) -> MutationResult { 12 | if let Some(removable_connection) = &genome 13 | .feed_forward 14 | .iter() 15 | // make iterator wrap 16 | .cycle() 17 | // randomly offset into the iterator to choose any node 18 | .skip((rng.gen::() * (genome.feed_forward.len()) as f64).floor() as usize) 19 | // just loop every value once 20 | .take(genome.feed_forward.len()) 21 | .find(|removal_candidate| { 22 | genome.has_alternative_input(removal_candidate.output, removal_candidate.input) 23 | && genome 24 | .has_alternative_output(removal_candidate.input, removal_candidate.output) 25 | }) 26 | .cloned() 27 | { 28 | assert!(genome.feed_forward.remove(removable_connection)); 29 | Ok(()) 30 | } else { 31 | Err(MutationError::CouldNotRemoveFeedForwardConnection) 32 | } 33 | } 34 | } 35 | 36 | #[cfg(test)] 37 | mod tests { 38 | use rand::thread_rng; 39 | 40 | use crate::{ 41 | activations::Activation, 42 | genes::{Connection, Genes, Id, Node}, 43 | mutations::MutationError, 44 | Genome, Mutations, 45 | }; 46 | 47 | #[test] 48 | fn can_remove_connection() { 49 | let mut genome = Genome { 50 | inputs: Genes(vec![Node::input(Id(0), 0)].iter().cloned().collect()), 51 | hidden: Genes( 52 | vec![Node::hidden(Id(2), Activation::Linear)] 53 | .iter() 54 | .cloned() 55 | .collect(), 56 | ), 57 | outputs: Genes( 58 | vec![Node::output(Id(1), 0, Activation::Linear)] 59 | .iter() 60 | .cloned() 61 | .collect(), 62 | ), 63 | feed_forward: Genes( 64 | vec![ 65 | Connection::new(Id(0), 1.0, Id(1)), 66 | Connection::new(Id(0), 1.0, Id(2)), 67 | Connection::new(Id(2), 1.0, Id(1)), 68 | ] 69 | .iter() 70 | .cloned() 71 | .collect(), 72 | ), 73 | ..Default::default() 74 | }; 75 | 76 | assert!(Mutations::remove_connection(&mut genome, &mut thread_rng()).is_ok()); 77 | } 78 | 79 | #[test] 80 | fn can_not_remove_connection() { 81 | let mut genome = Genome { 82 | inputs: Genes(vec![Node::input(Id(0), 0)].iter().cloned().collect()), 83 | 
outputs: Genes( 84 | vec![Node::output(Id(1), 0, Activation::Linear)] 85 | .iter() 86 | .cloned() 87 | .collect(), 88 | ), 89 | feed_forward: Genes( 90 | vec![Connection::new(Id(0), 1.0, Id(1))] 91 | .iter() 92 | .cloned() 93 | .collect(), 94 | ), 95 | ..Default::default() 96 | }; 97 | 98 | if let Err(error) = Mutations::remove_connection(&mut genome, &mut thread_rng()) { 99 | assert_eq!(error, MutationError::CouldNotRemoveFeedForwardConnection); 100 | } else { 101 | unreachable!() 102 | } 103 | } 104 | } 105 | -------------------------------------------------------------------------------- /src/genes.rs: -------------------------------------------------------------------------------- 1 | //! The `Gene` trait is a marker and in combination with the `Genes` struct describes common operations on collections (sets) of genes. 2 | //! 3 | //! The genome holds several fields with `Genes` of different types. 4 | 5 | use rand::{prelude::IteratorRandom, prelude::SliceRandom, Rng}; 6 | use seahash::SeaHasher; 7 | use serde::{Deserialize, Serialize}; 8 | use std::{ 9 | collections::HashSet, 10 | hash::{BuildHasher, Hash, Hasher}, 11 | iter::FromIterator, 12 | ops::Deref, 13 | ops::DerefMut, 14 | }; 15 | 16 | mod connections; 17 | mod id; 18 | mod nodes; 19 | 20 | pub use connections::Connection; 21 | pub use id::Id; 22 | pub use nodes::{ 23 | activations::{self, Activation}, 24 | Node, 25 | }; 26 | 27 | pub trait Gene: Eq + Hash { 28 | fn recombine(&self, other: &Self) -> Self; 29 | } 30 | 31 | #[derive(Clone, Default)] 32 | pub struct GeneHasher; 33 | 34 | impl BuildHasher for GeneHasher { 35 | type Hasher = SeaHasher; 36 | 37 | fn build_hasher(&self) -> Self::Hasher { 38 | Self::Hasher::new() 39 | } 40 | } 41 | 42 | #[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] 43 | pub struct Genes(pub HashSet); 44 | 45 | // see here: 
https://stackoverflow.com/questions/60882381/what-is-the-fastest-correct-way-to-detect-that-there-are-no-duplicates-in-a-json/60884343#60884343 46 | impl Hash for Genes { 47 | fn hash(&self, state: &mut H) { 48 | let mut hash = 0; 49 | for gene in &self.0 { 50 | let mut gene_hasher = SeaHasher::new(); 51 | gene.hash(&mut gene_hasher); 52 | hash ^= gene_hasher.finish(); 53 | } 54 | state.write_u64(hash); 55 | } 56 | } 57 | 58 | impl Default for Genes { 59 | fn default() -> Self { 60 | Genes(Default::default()) 61 | } 62 | } 63 | 64 | impl Deref for Genes { 65 | type Target = HashSet; 66 | 67 | fn deref(&self) -> &Self::Target { 68 | &self.0 69 | } 70 | } 71 | 72 | impl DerefMut for Genes { 73 | fn deref_mut(&mut self) -> &mut Self::Target { 74 | &mut self.0 75 | } 76 | } 77 | 78 | impl Genes { 79 | pub fn random(&self, rng: &mut impl Rng) -> Option<&T> { 80 | self.iter().choose(rng) 81 | } 82 | 83 | pub fn drain_into_random(&mut self, rng: &mut impl Rng) -> impl Iterator { 84 | let mut random_vec = self.drain().collect::>(); 85 | random_vec.shuffle(rng); 86 | random_vec.into_iter() 87 | } 88 | 89 | pub fn iterate_matching_genes<'a>( 90 | &'a self, 91 | other: &'a Genes, 92 | ) -> impl Iterator { 93 | self.intersection(other) 94 | // we know item exists in other as we are iterating the intersection 95 | .map(move |item_self| (item_self, other.get(item_self).unwrap())) 96 | } 97 | 98 | pub fn iterate_unique_genes<'a>(&'a self, other: &'a Genes) -> impl Iterator { 99 | self.symmetric_difference(other) 100 | } 101 | } 102 | 103 | impl FromIterator for Genes { 104 | fn from_iter>(iter: I) -> Self { 105 | Genes(iter.into_iter().collect()) 106 | } 107 | } 108 | 109 | impl Genes { 110 | pub fn as_sorted_vec(&self) -> Vec<&T> { 111 | let mut vec: Vec<&T> = self.iter().collect(); 112 | vec.sort_unstable(); 113 | vec 114 | } 115 | } 116 | 117 | impl Genes { 118 | pub fn cross_in(&self, other: &Self, rng: &mut impl Rng) -> Self { 119 | self.iterate_matching_genes(other) 120 | 
.map(|(gene_self, gene_other)| { // for each pair of matching genes: keep this genome's gene with probability 0.5, otherwise recombine it with the other genome's gene 121 | if rng.gen::() < 0.5 { 122 | gene_self.clone() 123 | } else { 124 | gene_self.recombine(gene_other) 125 | } 126 | }) 127 | .chain(self.difference(other).cloned()) // genes present only in `self` are inherited as-is; genes present only in `other` are not inherited 128 | .collect() 129 | } 130 | } 131 | -------------------------------------------------------------------------------- /src/mutations.rs: -------------------------------------------------------------------------------- 1 | use rand::Rng; 2 | use serde::{Deserialize, Serialize}; 3 | 4 | use crate::{genes::Activation, genome::Genome}; 5 | 6 | pub use self::error::MutationError; 7 | 8 | pub type MutationResult = Result<(), MutationError>; 9 | 10 | mod add_connection; 11 | mod add_node; 12 | mod add_recurrent_connection; 13 | mod change_activation; 14 | mod change_weights; 15 | mod duplicate_node; 16 | mod error; 17 | mod remove_connection; 18 | mod remove_node; 19 | mod remove_recurrent_connection; 20 | 21 | /// Lists all possible mutations with their corresponding parameters. 22 | /// 23 | /// Each mutation acts as a self-contained unit and has to be listed in the [`crate::Parameters::mutations`] field in order to take effect when calling [`crate::Genome::mutate`]. When (de)serialized, each variant is internally tagged by a `type` field holding its snake_case name (e.g. `add_node`). 24 | #[derive(Debug, Clone, Deserialize, Serialize)] 25 | #[serde(tag = "type")] 26 | #[serde(rename_all = "snake_case")] 27 | pub enum Mutations { 28 | /// See [`Mutations::change_weights`]. 29 | ChangeWeights { 30 | chance: f64, 31 | percent_perturbed: f64, 32 | standard_deviation: f64, 33 | }, 34 | /// See [`Mutations::change_activation`]. 35 | ChangeActivation { 36 | chance: f64, 37 | activation_pool: Vec, 38 | }, 39 | /// See [`Mutations::add_node`]. 40 | AddNode { 41 | chance: f64, 42 | activation_pool: Vec, 43 | }, 44 | /// See [`Mutations::add_connection`]. 45 | AddConnection { chance: f64 }, 46 | /// See [`Mutations::add_recurrent_connection`]. 47 | AddRecurrentConnection { chance: f64 }, 48 | /// See [`Mutations::remove_node`].
49 | RemoveNode { chance: f64 }, 50 | /// See [`Mutations::remove_connection`]. 51 | RemoveConnection { chance: f64 }, 52 | /// See [`Mutations::remove_recurrent_connection`]. 53 | RemoveRecurrentConnection { chance: f64 }, 54 | /// See [`Mutations::duplicate_node`]. 55 | DuplicateNode { chance: f64 }, 56 | } 57 | 58 | impl Mutations { 59 | /// Mutate a [`Genome`] but respects the associate `chance` field of the [`Mutations`] enum variants. 60 | /// The user needs to supply some RNG manually when using this method directly. 61 | /// Use [`crate::Genome::mutate`] as the default API. 62 | pub fn mutate(&self, genome: &mut Genome, rng: &mut impl Rng) -> MutationResult { 63 | match self { 64 | &Mutations::ChangeWeights { 65 | chance, 66 | percent_perturbed, 67 | standard_deviation, 68 | } => { 69 | if rng.gen::() < chance { 70 | Self::change_weights(percent_perturbed, standard_deviation, genome, rng); 71 | } 72 | } 73 | Mutations::AddNode { 74 | chance, 75 | activation_pool, 76 | } => { 77 | if rng.gen::() < *chance { 78 | Self::add_node(activation_pool, genome, rng) 79 | } 80 | } 81 | &Mutations::AddConnection { chance } => { 82 | if rng.gen::() < chance { 83 | return Self::add_connection(genome, rng); 84 | } 85 | } 86 | &Mutations::AddRecurrentConnection { chance } => { 87 | if rng.gen::() < chance { 88 | return Self::add_recurrent_connection(genome, rng); 89 | } 90 | } 91 | Mutations::ChangeActivation { 92 | chance, 93 | activation_pool, 94 | } => { 95 | if rng.gen::() < *chance { 96 | Self::change_activation(activation_pool, genome, rng) 97 | } 98 | } 99 | &Mutations::RemoveNode { chance } => { 100 | if rng.gen::() < chance { 101 | return Self::remove_node(genome, rng); 102 | } 103 | } 104 | &Mutations::RemoveConnection { chance } => { 105 | if rng.gen::() < chance { 106 | return Self::remove_connection(genome, rng); 107 | } 108 | } 109 | &Mutations::RemoveRecurrentConnection { chance } => { 110 | if rng.gen::() < chance { 111 | return 
Self::remove_recurrent_connection(genome, rng); 112 | } 113 | } 114 | &Mutations::DuplicateNode { chance } => { 115 | if rng.gen::() < chance { 116 | return Self::duplicate_node(genome, rng); 117 | } 118 | } 119 | } 120 | Ok(()) 121 | } 122 | } 123 | -------------------------------------------------------------------------------- /src/mutations/remove_node.rs: -------------------------------------------------------------------------------- 1 | use rand::Rng; 2 | 3 | use crate::Genome; 4 | 5 | use super::{MutationError, MutationResult, Mutations}; 6 | 7 | impl Mutations { 8 | /// Removes a node and all incoming and outgoing connections, should this be possible without introducing dangling structure. 9 | /// Dangling means the in- or out-degree of any hidden node is zero, i.e. it neither can receive nor propagate a signal. 10 | /// If it is not possible, no node will be removed. 11 | pub fn remove_node(genome: &mut Genome, rng: &mut impl Rng) -> MutationResult { 12 | if let Some(removable_node) = &genome 13 | .hidden 14 | .iter() 15 | // make iterator wrap 16 | .cycle() 17 | // randomly offset into the iterator to choose any node 18 | .skip((rng.gen::() * (genome.hidden.len()) as f64).floor() as usize) 19 | // just loop every value once 20 | .take(genome.hidden.len()) 21 | .find(|removal_candidate| { 22 | genome 23 | .connections() 24 | // find all input nodes of removal candidate 25 | .filter_map(|connection| { 26 | if connection.output == removal_candidate.id { 27 | Some(connection.input) 28 | } else { 29 | None 30 | } 31 | }) 32 | // make sure they have an alternative output 33 | .all(|id| genome.has_alternative_output(id, removal_candidate.id)) 34 | && genome 35 | .connections() 36 | // find all output nodes of removal candidate 37 | .filter_map(|connection| { 38 | if connection.input == removal_candidate.id { 39 | Some(connection.output) 40 | } else { 41 | None 42 | } 43 | }) 44 | // make sure they have an alternative input 45 | .all(|id| 
genome.has_alternative_input(id, removal_candidate.id)) 46 | }) 47 | .cloned() 48 | { 49 | // remove all feed-forward connections involving the node to be removed 50 | genome.feed_forward.retain(|connection| { 51 | connection.input != removable_node.id && connection.output != removable_node.id 52 | }); 53 | // remove all recurrent connections involving the node to be removed 54 | genome.recurrent.retain(|connection| { 55 | connection.input != removable_node.id && connection.output != removable_node.id 56 | }); 57 | // remove the node to be removed 58 | assert!(genome.hidden.remove(removable_node)); 59 | Ok(()) 60 | } else { 61 | Err(MutationError::CouldNotRemoveNode) 62 | } 63 | } 64 | } 65 | 66 | #[cfg(test)] 67 | mod tests { 68 | use rand::thread_rng; 69 | 70 | use crate::{ 71 | activations::Activation, 72 | genes::{Connection, Genes, Id, Node}, 73 | mutations::MutationError, 74 | Genome, Mutations, 75 | }; 76 | 77 | #[test] 78 | fn can_remove_node() { 79 | let mut genome = Genome { 80 | inputs: Genes(vec![Node::input(Id(0), 0)].iter().cloned().collect()), 81 | hidden: Genes( 82 | vec![ 83 | Node::hidden(Id(2), Activation::Linear), 84 | Node::hidden(Id(3), Activation::Linear), 85 | ] 86 | .iter() 87 | .cloned() 88 | .collect(), 89 | ), 90 | outputs: Genes( 91 | vec![Node::output(Id(1), 0, Activation::Linear)] 92 | .iter() 93 | .cloned() 94 | .collect(), 95 | ), 96 | feed_forward: Genes( 97 | vec![ 98 | Connection::new(Id(0), 1.0, Id(2)), 99 | Connection::new(Id(0), 1.0, Id(3)), 100 | Connection::new(Id(2), 1.0, Id(1)), 101 | Connection::new(Id(3), 1.0, Id(1)), 102 | ] 103 | .iter() 104 | .cloned() 105 | .collect(), 106 | ), 107 | ..Default::default() 108 | }; 109 | 110 | assert!(Mutations::remove_node(&mut genome, &mut thread_rng()).is_ok()) 111 | } 112 | 113 | #[test] 114 | fn can_not_remove_node() { 115 | let mut genome = Genome { 116 | inputs: Genes(vec![Node::input(Id(0), 0)].iter().cloned().collect()), 117 | hidden: Genes( 118 | vec![Node::hidden(Id(2), 
Activation::Linear)] 119 | .iter() 120 | .cloned() 121 | .collect(), 122 | ), 123 | outputs: Genes( 124 | vec![Node::output(Id(1), 0, Activation::Linear)] 125 | .iter() 126 | .cloned() 127 | .collect(), 128 | ), 129 | feed_forward: Genes( 130 | vec![ 131 | Connection::new(Id(0), 1.0, Id(2)), 132 | Connection::new(Id(2), 1.0, Id(1)), 133 | ] 134 | .iter() 135 | .cloned() 136 | .collect(), 137 | ), 138 | ..Default::default() 139 | }; 140 | 141 | if let Err(error) = Mutations::remove_node(&mut genome, &mut thread_rng()) { 142 | assert_eq!(error, MutationError::CouldNotRemoveNode); 143 | } else { 144 | unreachable!() 145 | } 146 | } 147 | } 148 | -------------------------------------------------------------------------------- /src/favannat_impl.rs: -------------------------------------------------------------------------------- 1 | use favannat::network::{EdgeLike, NetworkLike, NodeLike, Recurrent}; 2 | 3 | use crate::{ 4 | genes::{activations, Activation, Connection, Node}, 5 | genome::Genome, 6 | }; 7 | 8 | impl NodeLike for Node { 9 | fn id(&self) -> usize { 10 | self.id.0 as usize 11 | } 12 | fn activation(&self) -> fn(f64) -> f64 { 13 | match self.activation { 14 | Activation::Linear => activations::LINEAR, 15 | Activation::Sigmoid => activations::SIGMOID, 16 | Activation::Gaussian => activations::GAUSSIAN, 17 | Activation::Tanh => activations::TANH, 18 | Activation::Step => activations::STEP, 19 | Activation::Sine => activations::SINE, 20 | Activation::Cosine => activations::COSINE, 21 | Activation::Inverse => activations::INVERSE, 22 | Activation::Absolute => activations::ABSOLUTE, 23 | Activation::Relu => activations::RELU, 24 | Activation::Squared => activations::SQUARED, 25 | } 26 | } 27 | } 28 | 29 | impl EdgeLike for Connection { 30 | fn start(&self) -> usize { 31 | self.input.0 as usize 32 | } 33 | fn end(&self) -> usize { 34 | self.output.0 as usize 35 | } 36 | fn weight(&self) -> f64 { 37 | self.weight 38 | } 39 | } 40 | 41 | impl NetworkLike for Genome 
{ 42 | fn nodes(&self) -> Vec<&Node> { 43 | self.nodes().collect() 44 | } 45 | fn edges(&self) -> Vec<&Connection> { 46 | self.feed_forward.as_sorted_vec() 47 | } 48 | fn inputs(&self) -> Vec<&Node> { 49 | self.inputs.as_sorted_vec() 50 | } 51 | fn outputs(&self) -> Vec<&Node> { 52 | self.outputs.as_sorted_vec() 53 | } 54 | fn hidden(&self) -> Vec<&Node> { 55 | self.hidden.as_sorted_vec() 56 | } 57 | } 58 | 59 | impl Recurrent for Genome { 60 | fn recurrent_edges(&self) -> Vec<&Connection> { 61 | self.recurrent.as_sorted_vec() 62 | } 63 | } 64 | 65 | #[cfg(test)] 66 | mod tests { 67 | use favannat::{MatrixRecurrentFabricator, StatefulEvaluator, StatefulFabricator}; 68 | use rand_distr::{Distribution, Uniform}; 69 | 70 | use crate::{activations::Activation, Genome, Mutations, Parameters, Structure}; 71 | 72 | // This test brakes with favannat version 0.6.1 due to a bug there. Now with favannat 0.6.2 it is fine. 73 | #[test] 74 | fn verify_output_does_not_occasionally_leak_internal_state() { 75 | let parameters = Parameters { 76 | structure: Structure { 77 | number_of_inputs: 13, 78 | number_of_outputs: 3, 79 | percent_of_connected_inputs: 1.0, 80 | outputs_activation: Activation::Sigmoid, 81 | seed: 42, 82 | }, 83 | mutations: vec![ 84 | Mutations::ChangeWeights { 85 | chance: 0.8, 86 | percent_perturbed: 0.5, 87 | standard_deviation: 0.2, 88 | }, 89 | Mutations::AddNode { 90 | chance: 0.1, 91 | activation_pool: vec![ 92 | Activation::Sigmoid, 93 | Activation::Tanh, 94 | Activation::Gaussian, 95 | Activation::Step, 96 | // Activation::Sine, 97 | // Activation::Cosine, 98 | Activation::Inverse, 99 | Activation::Absolute, 100 | Activation::Relu, 101 | ], 102 | }, 103 | Mutations::AddConnection { chance: 0.2 }, 104 | Mutations::AddConnection { chance: 0.02 }, 105 | Mutations::AddRecurrentConnection { chance: 0.1 }, 106 | Mutations::RemoveConnection { chance: 0.05 }, 107 | Mutations::RemoveConnection { chance: 0.01 }, 108 | Mutations::RemoveNode { chance: 0.05 }, 109 | 
], 110 | }; 111 | 112 | let mut genome = Genome::initialized(¶meters); 113 | 114 | for _ in 0..100 { 115 | genome.mutate(¶meters); 116 | } 117 | 118 | let mut evaluator = MatrixRecurrentFabricator::fabricate(&genome).expect("not okay"); 119 | 120 | let between = Uniform::from(-10000.0..10000.0); 121 | let mut rng = rand::thread_rng(); 122 | 123 | for _ in 0..1000 { 124 | let input = (0..13) 125 | .map(|_| between.sample(&mut rng)) 126 | .collect::>(); 127 | let output = evaluator.evaluate(input); 128 | assert!( 129 | output[0] <= 1.0, 130 | "got {} which is bigger than 1.0, genome: {:?}, evaluator: {:?}", 131 | output[0], 132 | &genome, 133 | &evaluator 134 | ); 135 | assert!( 136 | output[0] >= 0.0, 137 | "got {} which is smaller than 0.0, genome: {:?}, evaluator: {:?}", 138 | output[0], 139 | &genome, 140 | &evaluator 141 | ); 142 | assert!( 143 | output[1] <= 1.0, 144 | "got {} which is bigger than 1.0, genome: {:?}, evaluator: {:?}", 145 | output[1], 146 | &genome, 147 | &evaluator 148 | ); 149 | assert!( 150 | output[1] >= 0.0, 151 | "got {} which is smaller than 0.0, genome: {:?}, evaluator: {:?}", 152 | output[1], 153 | &genome, 154 | &evaluator 155 | ); 156 | assert!( 157 | output[2] <= 1.0, 158 | "got {} which is bigger than 1.0, genome: {:?}, evaluator: {:?}", 159 | output[2], 160 | &genome, 161 | &evaluator 162 | ); 163 | assert!( 164 | output[2] >= 0.0, 165 | "got {} which is smaller than 0.0, genome: {:?}, evaluator: {:?}", 166 | output[2], 167 | &genome, 168 | &evaluator 169 | ); 170 | } 171 | } 172 | } 173 | -------------------------------------------------------------------------------- /src/mutations/duplicate_node.rs: -------------------------------------------------------------------------------- 1 | use rand::Rng; 2 | 3 | use crate::{genes::Node, genome::Genome, MutationError}; 4 | 5 | use super::Mutations; 6 | 7 | impl Mutations { 8 | /// This mutation adds a new node to the genome by "duplicating" an existing node, i.e. 
all incoming and outgoing connections to that node are duplicated as well. 9 | /// 10 | /// The outgoing connection weights of the selected node are halved and the duplicate receives copies of those halved connections (incoming connections are copied with unchanged weights), so the mutation starts out functionally equivalent to the genome without the mutation. 11 | /// By duplicating a node and its connections small local symmetry can develop. 12 | /// It draws inspiration from the concept of gene duplication and cell division. 13 | pub fn duplicate_node(genome: &mut Genome, rng: &mut impl Rng) -> Result<(), MutationError> { 14 | // select a hidden node gene to duplicate; without any hidden node this returns `CouldNotDuplicateNode` 15 | if let Some(mut random_hidden_node) = genome.hidden.random(rng).cloned() { 16 | let mut id = random_hidden_node.next_id(); // NOTE(review): next_id presumably advances an id counter stored on the node, which is why the node is written back at the end — confirm 17 | 18 | // avoid id collisions, will cause some kind of "divergent evolution" eventually 19 | while genome.contains(id) { 20 | id = random_hidden_node.next_id() 21 | } 22 | 23 | // construct new node gene 24 | let new_node = Node::hidden(id, random_hidden_node.activation); 25 | 26 | // collect the selected node's outgoing feed-forward connections (to be halved and mirrored onto the new node) 27 | let mut outgoing_feedforward_connections = genome 28 | .feed_forward 29 | .iter() 30 | .filter(|c| c.input == random_hidden_node.id) 31 | .cloned() 32 | .collect::>(); 33 | 34 | // collect the selected node's incoming feed-forward connections (copies will feed the new node) 35 | let incoming_feedforward_connections = genome 36 | .feed_forward 37 | .iter() 38 | .filter(|c| c.output == random_hidden_node.id) 39 | .cloned() 40 | .collect::>(); 41 | 42 | let mut new_feedworward_connections = Vec::with_capacity( 43 | outgoing_feedforward_connections.len() + incoming_feedforward_connections.len(), 44 | ); 45 | 46 | // halve each outgoing weight and mirror the halved connection as an outgoing connection of the new node 47 | for connection in outgoing_feedforward_connections.iter_mut() { 48 | connection.weight = connection.weight / 2.0; 49 | let mut new_connection = connection.clone(); 50 | new_connection.input = new_node.id; 51 | new_feedworward_connections.push(new_connection); 52 | } 53 | 54 | // write the halved connections back; each one already exists in the set, so `replace` must return `Some` 55 | for connection in outgoing_feedforward_connections { 56 |
assert!(genome.feed_forward.replace(connection).is_some()) 57 | } 58 | 59 | // update ouputs 60 | for mut connection in incoming_feedforward_connections { 61 | connection.output = new_node.id; 62 | new_feedworward_connections.push(connection); 63 | } 64 | 65 | // insert all new connections 66 | for connection in new_feedworward_connections { 67 | assert!(genome.feed_forward.insert(connection)) 68 | } 69 | 70 | // duplicate outgoing recurrent connections 71 | let mut outgoing_recurrent_connections = genome 72 | .recurrent 73 | .iter() 74 | .filter(|c| c.input == random_hidden_node.id && c.output != random_hidden_node.id) 75 | .cloned() 76 | .collect::>(); 77 | 78 | // duplicate incoming recurrent connections 79 | let incoming_recurrent_connections = genome 80 | .recurrent 81 | .iter() 82 | .filter(|c| c.output == random_hidden_node.id && c.input != random_hidden_node.id) 83 | .cloned() 84 | .collect::>(); 85 | 86 | let mut new_recurrent_connections = Vec::with_capacity( 87 | outgoing_recurrent_connections.len() + incoming_recurrent_connections.len(), 88 | ); 89 | 90 | // update weights 91 | for connection in outgoing_recurrent_connections.iter_mut() { 92 | connection.weight = connection.weight / 2.0; 93 | let mut new_connection = connection.clone(); 94 | new_connection.input = new_node.id; 95 | new_recurrent_connections.push(new_connection); 96 | } 97 | 98 | // replace updated 99 | for connection in outgoing_recurrent_connections { 100 | assert!(genome.recurrent.replace(connection).is_some()) 101 | } 102 | 103 | // update ouputs 104 | for mut connection in incoming_recurrent_connections { 105 | connection.output = new_node.id; 106 | new_recurrent_connections.push(connection); 107 | } 108 | 109 | // insert all new connections 110 | for connection in new_recurrent_connections { 111 | assert!(genome.recurrent.insert(connection)) 112 | } 113 | 114 | if let Some(self_loop) = genome 115 | .recurrent 116 | .iter() 117 | .find(|c| c.input == random_hidden_node.id && 
c.output == random_hidden_node.id) 118 | { 119 | let mut new_self_loop = self_loop.clone(); 120 | new_self_loop.input = new_node.id; 121 | new_self_loop.output = new_node.id; 122 | assert!(genome.recurrent.insert(new_self_loop)) 123 | } 124 | 125 | // replace selected node with updated id_counter 126 | assert!(genome.hidden.replace(random_hidden_node).is_some()); 127 | 128 | // insert duplicated node 129 | assert!(genome.hidden.insert(new_node)); 130 | Ok(()) 131 | } else { 132 | Err(MutationError::CouldNotDuplicateNode) 133 | } 134 | } 135 | } 136 | 137 | #[cfg(test)] 138 | mod tests { 139 | use rand::thread_rng; 140 | 141 | use crate::{activations::Activation, Genome, Mutations, Parameters}; 142 | 143 | #[test] 144 | fn duplicate_random_node() { 145 | let mut genome = Genome::initialized(&Parameters::default()); 146 | assert_eq!(genome.feed_forward.len(), 1); 147 | 148 | Mutations::add_node(&Activation::all(), &mut genome, &mut thread_rng()); 149 | assert_eq!(genome.hidden.len(), 1); 150 | assert_eq!(genome.feed_forward.len(), 3); 151 | 152 | // create all possible recurrent connections 153 | assert!(Mutations::add_recurrent_connection(&mut genome, &mut thread_rng()).is_ok()); 154 | assert!(Mutations::add_recurrent_connection(&mut genome, &mut thread_rng()).is_ok()); 155 | assert!(Mutations::add_recurrent_connection(&mut genome, &mut thread_rng()).is_ok()); 156 | assert!(Mutations::add_recurrent_connection(&mut genome, &mut thread_rng()).is_ok()); 157 | assert!(Mutations::add_recurrent_connection(&mut genome, &mut thread_rng()).is_ok()); 158 | assert!(Mutations::add_recurrent_connection(&mut genome, &mut thread_rng()).is_ok()); 159 | assert_eq!(genome.recurrent.len(), 6); 160 | 161 | assert!(Mutations::duplicate_node(&mut genome, &mut thread_rng()).is_ok()); 162 | 163 | println!("{}", Genome::dot(&genome)); 164 | 165 | assert_eq!(genome.feed_forward.len(), 5); 166 | assert_eq!(genome.recurrent.len(), 10); 167 | assert_eq!(genome.hidden.len(), 2); 168 | } 169 | 
170 | #[test] 171 | fn same_structure_same_id() { 172 | let mut genome1 = Genome::initialized(&Parameters::default()); 173 | let mut genome2 = Genome::initialized(&Parameters::default()); 174 | 175 | Mutations::add_node(&Activation::all(), &mut genome1, &mut thread_rng()); 176 | assert!(Mutations::duplicate_node(&mut genome1, &mut thread_rng()).is_ok()); 177 | 178 | Mutations::add_node(&Activation::all(), &mut genome2, &mut thread_rng()); 179 | assert!(Mutations::duplicate_node(&mut genome2, &mut thread_rng()).is_ok()); 180 | 181 | assert_eq!(genome1.hidden, genome2.hidden); 182 | } 183 | } 184 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | //! This crate is supposed to act as the representation/reproduction aspect in neuroevolution algorithms and may be combined with arbitrary selection mechanisms. 2 | //! 3 | //! # What you can do with this crate 4 | //! ``` 5 | //! # use set_genome::{Genome, Parameters}; 6 | //! # use favannat::{ 7 | //! # MatrixFeedforwardFabricator, Evaluator, Fabricator, 8 | //! # }; 9 | //! # use nalgebra::dmatrix; 10 | //! // Setup a genome context for networks with 10 inputs and 10 outputs. 11 | //! let parameters = Parameters::basic(10, 10); 12 | //! 13 | //! // Initialize a genome. 14 | //! let mut genome = Genome::initialized(¶meters); 15 | //! 16 | //! // Mutate a genome. 17 | //! genome.mutate(¶meters); 18 | //! 19 | //! // Get a phenotype of the genome. 20 | //! let network = MatrixFeedforwardFabricator::fabricate(&genome).expect("Cool network."); 21 | //! 22 | //! // Evaluate a network on an input. 23 | //! let output = network.evaluate(dmatrix![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]); 24 | //! ``` 25 | //! 26 | //! # SET genome 27 | //! 28 | //! 
SET stands for **S**et **E**ncoded **T**opology and this crate implements a genetic data structure, the [`Genome`], using this set encoding to describe artificial neural networks (ANNs). 29 | //! Furthermore, this crate defines operations on this genome, namely [`Mutations`] and [crossover]. Mutations alter a genome by adding or removing genes; crossover recombines two genomes. 30 | //! To obtain an intuitive definition of crossover for network structures, the [NEAT algorithm] defined a procedure; NEAT should be understood as the conceptual predecessor of this SET encoding, 31 | //! which is essentially a formalization and extension of the ideas NEAT introduced regarding the genome. 32 | //! The thesis describing this genome and other ideas can be found [here]; a paper focusing just on the SET encoding will follow soon. 33 | //! 34 | //! # Getting started 35 | //! 36 | //! We start by defining our parameters: 37 | //! 38 | //! Suppose we know our task has ten inputs and two outputs, which translate to the input and output layer of our ANN. 39 | //! Furthermore, we want 100% of our input nodes to be initially connected to the outputs, and the outputs shall use the [`activations::Activation::Tanh`] function. 40 | //! Also, the weights of our connections are supposed to be capped between \[-1, 1\] and change by deltas sampled from a normal distribution with 0.1 standard deviation (weight changes are configured through [`Mutations::ChangeWeights`]; the example below leaves `mutations` empty). 41 | //! 42 | //! ``` 43 | //! use set_genome::{activations::Activation, Parameters, Structure}; 44 | //! 45 | //! let parameters = Parameters { 46 | //! structure: Structure { 47 | //! // ten inputs 48 | //! number_of_inputs: 10, 49 | //! // two outputs 50 | //! number_of_outputs: 2, 51 | //! // 100% connected 52 | //! percent_of_connected_inputs: 1.0, 53 | //! // specified output activation 54 | //! outputs_activation: Activation::Tanh, 55 | //! // seed for initial genome construction 56 | //! seed: 42 57 | //! }, 58 | //! mutations: vec![], 59 | //! }; 60 | //! ``` 61 |
This allows us to create an initialized genome which conforms to our description above: 62 | //! 63 | //! ``` 64 | //! # use set_genome::{Genome, activations::Activation, Parameters, Structure}; 65 | //! # 66 | //! # let parameters = Parameters { 67 | //! # structure: Structure { 68 | //! # // ten inputs 69 | //! # number_of_inputs: 10, 70 | //! # // two outputs 71 | //! # number_of_outputs: 2, 72 | //! # // 100% connected 73 | //! # percent_of_connected_inputs: 1.0, 74 | //! # // specified output activation 75 | //! # outputs_activation: Activation::Tanh, 76 | //! # // seed for initial genome construction 77 | //! # seed: 42 78 | //! # }, 79 | //! # mutations: vec![], 80 | //! # }; 81 | //! # 82 | //! let genome_with_connections = Genome::initialized(&parameters); 83 | //! ``` 84 | //! "Initialized" here means that the configured percent of input nodes have been connected to the outputs with random weights. 85 | //! "Uninitialized" thereby implies that no connections have been constructed; such a genome is also available: 86 | //! 87 | //! ``` 88 | //! # use set_genome::{Genome, activations::Activation, Parameters, Structure}; 89 | //! # 90 | //! # let parameters = Parameters { 91 | //! # structure: Structure { 92 | //! # // ten inputs 93 | //! # number_of_inputs: 10, 94 | //! # // two outputs 95 | //! # number_of_outputs: 2, 96 | //! # // 100% connected 97 | //! # percent_of_connected_inputs: 1.0, 98 | //! # // specified output activation 99 | //! # outputs_activation: Activation::Tanh, 100 | //! # // seed for initial genome construction 101 | //! # seed: 42 102 | //! # 103 | //! # }, 104 | //! # mutations: vec![], 105 | //! # }; 106 | //! # 107 | //! let genome_without_connections = Genome::uninitialized(&parameters); 108 | //! ``` 109 | //! Setting the `percent_of_connected_inputs` field in the [`Structure`] parameter to zero makes the 110 | //! "initialized" and "uninitialized" genomes look the same. 111 | //! 112 | //! Now that we have a genome, let's mutate it: [`Genome::mutate`].
113 | //! 114 | //! The possible mutations: 115 | //! - [`Mutations::add_connection`] 116 | //! - [`Mutations::add_node`] 117 | //! - [`Mutations::add_recurrent_connection`] 118 | //! - [`Mutations::change_activation`] 119 | //! - [`Mutations::change_weights`] 120 | //! - [`Mutations::remove_node`] 121 | //! - [`Mutations::remove_connection`] 122 | //! - [`Mutations::remove_recurrent_connection`] 123 | //! - [`Mutations::duplicate_node`] 124 | //! 125 | //! # Features 126 | //! 127 | //! This crate exposes the 'favannat' feature. [favannat] is a library to translate the genome into an executable form and also to execute it. 128 | //! It can be seen as a phenotype of the genome. 129 | //! The feature is enabled by default, as you probably want to evaluate your evolved genomes, but disabling it is as easy as this: 130 | //! 131 | //! ```toml 132 | //! [dependencies] 133 | //! set_genome = { version = "x.x.x", default-features = false } 134 | //! ``` 135 | //! 136 | //! If you are interested in how they connect, [see here]. 137 | //! favannat can be used to evaluate other data structures of yours, too, if they implement [`favannat::network::NetworkLike`]. ;) 138 | //! 139 | //! [thesis]: https://www.silvan.codes/SET-NEAT_Thesis.pdf 140 | //! [this crate]: https://crates.io/crates/favannat 141 | //! [crossover]: `Genome::cross_in` 142 | //! [NEAT algorithm]: http://nn.cs.utexas.edu/downloads/papers/stanley.ec02.pdf 143 | //! [here]: https://www.silvan.codes/SET-NEAT_Thesis.pdf 144 | //! [favannat]: https://docs.rs/favannat 145 | //!
[see here]: https://github.com/SilvanCodes/set-genome/blob/main/src/favannat_impl.rs 146 | 147 | pub use genes::{activations, Connection, Id, Node}; 148 | pub use genome::{CompatibilityDistance, Genome}; 149 | pub use mutations::{MutationError, MutationResult, Mutations}; 150 | pub use parameters::{Parameters, Structure}; 151 | use rand::{rngs::SmallRng, thread_rng, SeedableRng}; 152 | 153 | #[cfg(feature = "favannat")] 154 | mod favannat_impl; 155 | mod genes; 156 | mod genome; 157 | mod mutations; 158 | mod parameters; 159 | 160 | impl Genome { 161 | /// Creates a genome with the configured input and output nodes but without any connection genes. 162 | pub fn uninitialized(parameters: &Parameters) -> Self { 163 | Self::new(&parameters.structure) 164 | } 165 | /// Creates a genome and connects the configured percent of input nodes to all output nodes, i.e. it creates connection genes with random weights. 166 | pub fn initialized(parameters: &Parameters) -> Self { 167 | let mut genome = Genome::new(&parameters.structure); 168 | genome.init(&parameters.structure); 169 | genome 170 | } 171 | 172 | /// Apply all mutations listed in the [`Parameters`] with respect to their chance of happening. 173 | /// If a mutation is listed multiple times it is applied multiple times. If applying a mutation returns an error, the remaining mutations are skipped and that error is returned. 174 | /// 175 | /// This will probably be the most common way to apply mutations to a genome. 176 | /// 177 | /// # Examples 178 | /// 179 | /// ``` 180 | /// use set_genome::{Genome, Parameters}; 181 | /// 182 | /// // Create parameters, usually read from a configuration file. 183 | /// let parameters = Parameters::default(); 184 | /// 185 | /// // Create an initialized `Genome`. 186 | /// let mut genome = Genome::initialized(&parameters); 187 | /// 188 | /// // Randomly mutate the genome according to the mutations listed in the parameters and their corresponding chances.
189 | /// genome.mutate(&parameters); 190 | /// ``` 191 | /// 192 | pub fn mutate(&mut self, parameters: &Parameters) -> MutationResult { 193 | let rng = &mut SmallRng::from_rng(thread_rng()).unwrap(); 194 | 195 | for mutation in &parameters.mutations { 196 | // gamble for application of mutation right here instead of in mutate() ?? 197 | mutation.mutate(self, rng)? 198 | } 199 | Ok(()) 200 | } 201 | } 202 | -------------------------------------------------------------------------------- /src/parameters.rs: -------------------------------------------------------------------------------- 1 | use crate::{genes::Activation, mutations::Mutations}; 2 | use config::{Config, ConfigError, File}; 3 | use serde::{Deserialize, Serialize}; 4 | 5 | /// This struct captures configuration about the basic ANN structure and [available mutations]. 6 | /// 7 | /// It can be constructed manually or from a `.toml` file. 8 | /// 9 | /// # Examples 10 | /// 11 | /// ## Code 12 | /// 13 | /// The following lists everything that is possible to specify: 14 | /// ``` 15 | /// use set_genome::{Parameters, Structure, Mutations, activations::Activation}; 16 | /// 17 | /// let parameters = Parameters { 18 | /// structure: Structure { 19 | /// number_of_inputs: 25, 20 | /// number_of_outputs: 3, 21 | /// percent_of_connected_inputs: 1.0, 22 | /// outputs_activation: Activation::Tanh, 23 | /// seed: 42 24 | /// }, 25 | /// mutations: vec![ 26 | /// Mutations::ChangeWeights { 27 | /// chance: 1.0, 28 | /// percent_perturbed: 0.5, 29 | /// standard_deviation: 0.1, 30 | /// }, 31 | /// Mutations::ChangeActivation { 32 | /// chance: 0.05, 33 | /// activation_pool: vec![ 34 | /// Activation::Linear, 35 | /// Activation::Sigmoid, 36 | /// Activation::Tanh, 37 | /// Activation::Gaussian, 38 | /// Activation::Step, 39 | /// Activation::Sine, 40 | /// Activation::Cosine, 41 | /// Activation::Inverse, 42 | /// Activation::Absolute, 43 | /// Activation::Relu, 44 | /// ], 45 | /// }, 46 | /// Mutations::AddNode { 47
| /// chance: 0.005, 48 | /// activation_pool: vec![ 49 | /// Activation::Linear, 50 | /// Activation::Sigmoid, 51 | /// Activation::Tanh, 52 | /// Activation::Gaussian, 53 | /// Activation::Step, 54 | /// Activation::Sine, 55 | /// Activation::Cosine, 56 | /// Activation::Inverse, 57 | /// Activation::Absolute, 58 | /// Activation::Relu, 59 | /// ], 60 | /// }, 61 | /// Mutations::RemoveNode { chance: 0.001 }, 62 | /// Mutations::AddConnection { chance: 0.1 }, 63 | /// Mutations::RemoveConnection { chance: 0.001 }, 64 | /// Mutations::AddRecurrentConnection { chance: 0.01 }, 65 | /// Mutations::RemoveRecurrentConnection { chance: 0.001 }, 66 | /// ], 67 | /// }; 68 | /// ``` 69 | /// 70 | /// ## Configuration 71 | /// 72 | /// Write a config file like so: 73 | /// ```toml 74 | /// [structure] 75 | /// number_of_inputs = 9 76 | /// number_of_outputs = 2 77 | /// percent_of_connected_inputs = 1.0 78 | /// outputs_activation = "Tanh" 79 | /// seed = 42 80 | /// [[mutations]] 81 | /// type = "add_connection" 82 | /// chance = 0.1 83 | /// 84 | /// [[mutations]] 85 | /// type = "add_recurrent_connection" 86 | /// chance = 0.01 87 | /// 88 | /// [[mutations]] 89 | /// type = "add_node" 90 | /// chance = 0.005 91 | /// activation_pool = [ 92 | /// "Sigmoid", 93 | /// "Tanh", 94 | /// "Relu", 95 | /// "Linear", 96 | /// "Gaussian", 97 | /// "Step", 98 | /// "Sine", 99 | /// "Cosine", 100 | /// "Inverse", 101 | /// "Absolute", 102 | /// ] 103 | /// 104 | /// [[mutations]] 105 | /// type = "remove_node" 106 | /// chance = 0.001 107 | /// 108 | /// [[mutations]] 109 | /// type = "change_weights" 110 | /// chance = 1.0 111 | /// percent_perturbed = 0.5 112 | /// standard_deviation = 0.1 113 | /// 114 | /// [[mutations]] 115 | /// type = "change_activation" 116 | /// chance = 0.05 117 | /// activation_pool = [ 118 | /// "Sigmoid", 119 | /// "Tanh", 120 | /// "Relu", 121 | /// "Linear", 122 | /// "Gaussian", 123 | /// "Step", 124 | /// "Sine", 125 | /// "Cosine", 126 | /// "Inverse",
127 | /// "Absolute", 128 | /// ] 129 | /// 130 | /// [[mutations]] 131 | /// type = "remove_connection" 132 | /// chance = 0.001 133 | /// 134 | /// [[mutations]] 135 | /// type = "remove_recurrent_connection" 136 | /// chance = 0.001 137 | /// ``` 138 | /// 139 | /// And then read the file: 140 | /// 141 | /// ```text 142 | /// // let parameters = Parameters::new("path/to/file"); 143 | /// ``` 144 | /// 145 | /// [available mutations]: `Mutations` 146 | /// 147 | #[derive(Deserialize, Serialize, Debug, Clone)] 148 | pub struct Parameters { 149 | /// Describes basic structure of the ANN. 150 | pub structure: Structure, 151 | /// List of mutations that are applied by [`crate::Genome::mutate`]. 152 | pub mutations: Vec<Mutations>, 153 | } 154 | 155 | impl Default for Parameters { 156 | fn default() -> Self { 157 | Self { 158 | structure: Structure::default(), 159 | mutations: vec![ 160 | Mutations::ChangeWeights { 161 | chance: 1.0, 162 | percent_perturbed: 0.5, 163 | standard_deviation: 0.1, 164 | }, 165 | Mutations::ChangeActivation { 166 | chance: 0.05, 167 | activation_pool: vec![ 168 | Activation::Linear, 169 | Activation::Sigmoid, 170 | Activation::Tanh, 171 | Activation::Gaussian, 172 | Activation::Step, 173 | Activation::Sine, 174 | Activation::Cosine, 175 | Activation::Inverse, 176 | Activation::Absolute, 177 | Activation::Relu, 178 | ], 179 | }, 180 | Mutations::AddNode { 181 | chance: 0.005, 182 | activation_pool: vec![ 183 | Activation::Linear, 184 | Activation::Sigmoid, 185 | Activation::Tanh, 186 | Activation::Gaussian, 187 | Activation::Step, 188 | Activation::Sine, 189 | Activation::Cosine, 190 | Activation::Inverse, 191 | Activation::Absolute, 192 | Activation::Relu, 193 | ], 194 | }, 195 | Mutations::AddConnection { chance: 0.1 }, 196 | Mutations::AddRecurrentConnection { chance: 0.01 }, 197 | ], 198 | } 199 | } 200 | } 201 | 202 | impl Parameters { 203 | /// The basic parameters allow for the mutations of weights (100% of the time 50% of the weights are
mutated), new nodes (1% chance) and new connections (10% chance) and are meant to quickly get a general feel for how this crate works. 204 | /// Output nodes and all added hidden nodes use the [`Activation::Tanh`] function. 205 | pub fn basic(number_of_inputs: usize, number_of_outputs: usize) -> Self { 206 | Self { 207 | structure: Structure::basic(number_of_inputs, number_of_outputs), 208 | mutations: vec![ 209 | Mutations::ChangeWeights { 210 | chance: 1.0, 211 | percent_perturbed: 0.5, 212 | standard_deviation: 0.1, 213 | }, 214 | Mutations::AddNode { 215 | chance: 0.01, 216 | activation_pool: vec![Activation::Tanh], 217 | }, 218 | Mutations::AddConnection { chance: 0.1 }, 219 | ], 220 | } 221 | } 222 | } 223 | 224 | /// This struct describes the invariants of the ANN structure. 225 | #[derive(Debug, Clone, Deserialize, Serialize)] 226 | pub struct Structure { 227 | /// Number of input nodes. 228 | pub number_of_inputs: usize, 229 | /// Number of output nodes. 230 | pub number_of_outputs: usize, 231 | /// Fraction (0.0 to 1.0, where 1.0 means 100%) of input nodes initially connected to all output nodes. 232 | pub percent_of_connected_inputs: f64, 233 | /// Activation function for all output nodes. 234 | pub outputs_activation: Activation, 235 | /// Seed to generate the initial node ids. 236 | pub seed: u64, 237 | } 238 | 239 | impl Default for Structure { 240 | fn default() -> Self { 241 | Self { 242 | number_of_inputs: 1, 243 | number_of_outputs: 1, 244 | percent_of_connected_inputs: 1.0, 245 | outputs_activation: Activation::Tanh, 246 | seed: 42, 247 | } 248 | } 249 | } 250 | 251 | impl Structure { 252 | /// The basic structure uses the given numbers of inputs and outputs and the default values otherwise: every input is connected to every output, outputs use [`Activation::Tanh`] and the seed is 42.
253 | pub fn basic(number_of_inputs: usize, number_of_outputs: usize) -> Self { 254 | Self { 255 | number_of_inputs, 256 | number_of_outputs, 257 | ..Default::default() 258 | } 259 | } 260 | } 261 | 262 | impl Parameters { 263 | pub fn new(path: &str) -> Result<Self, ConfigError> { 264 | let mut s = Config::new(); 265 | 266 | // Merge in the configuration file found at `path`. 267 | s.merge(File::with_name(path))?; 268 | 269 | // Deserialize (and thus freeze) the entire configuration into `Parameters`. 270 | s.try_into() 271 | } 272 | } 273 | -------------------------------------------------------------------------------- /src/genome/compatibility_distance.rs: -------------------------------------------------------------------------------- 1 | use crate::Genome; 2 | 3 | /// Mechanism to compute distances between genomes. 4 | /// 5 | /// Compatibility distance is a concept introduced in [NEAT] and defines a distance metric between genomes. 6 | /// It can be useful for other evolutionary mechanisms such as speciation. 7 | /// 8 | /// Three aspects contribute to the resulting difference: 9 | /// - the share of connections that are not shared between the genomes 10 | /// - the total weight difference between shared connections 11 | /// - the share of matching hidden nodes whose activations differ 12 | /// 13 | /// Each aspect gives a normalized value between 0 and 1 and is then weighted by the corresponding factor. 14 | /// The computed difference is the weighted sum of the aspects divided by the sum of all factors. 15 | /// 16 | /// For details read [here] part 2.5.1.
17 | /// 18 | /// [NEAT]: http://nn.cs.utexas.edu/downloads/papers/stanley.ec02.pdf 19 | /// [here]: https://silvan.codes/assets/pdf/SET-NEAT_Thesis.pdf 20 | /// 21 | /// # Example 22 | /// ``` 23 | /// # use set_genome::{Genome, Parameters, CompatibilityDistance}; 24 | /// let distance = CompatibilityDistance::with_factors(1.0, 1.0, 1.0); 25 | /// 26 | /// let parameters = Parameters::basic(10, 10); 27 | /// 28 | /// // Randomly initialize two genomes. 29 | /// let genome_one = Genome::initialized(&parameters); 30 | /// let genome_two = Genome::initialized(&parameters); 31 | /// 32 | /// // Both calls compute the same distance. 33 | /// assert!(distance.between(&genome_one, &genome_two) > 0.0); 34 | /// assert!(CompatibilityDistance::compatability_distance(&genome_one, &genome_two, 1.0, 1.0, 1.0).0 > 0.0); 35 | /// ``` 36 | pub struct CompatibilityDistance { 37 | factor_connections: f64, 38 | factor_weights: f64, 39 | factor_activations: f64, 40 | } 41 | 42 | impl CompatibilityDistance { 43 | pub fn with_factors( 44 | factor_connections: f64, 45 | factor_weights: f64, 46 | factor_activations: f64, 47 | ) -> Self { 48 | Self { 49 | factor_connections, 50 | factor_weights, 51 | factor_activations, 52 | } 53 | } 54 | /// Computes the overall compatibility distance between `genome_0` and `genome_1` using the stored factors. 55 | pub fn between(&self, genome_0: &Genome, genome_1: &Genome) -> f64 { 56 | let Self { 57 | factor_connections, 58 | factor_weights, 59 | factor_activations, 60 | } = *self; 61 | 62 | CompatibilityDistance::compatability_distance( 63 | genome_0, 64 | genome_1, 65 | factor_connections, 66 | factor_weights, 67 | factor_activations, 68 | ) 69 | .0 70 | } 71 | 72 | /// Directly compute the compatibility distance.
73 | /// 74 | /// The result is a 4-tuple of: 75 | /// - the overall difference 76 | /// - the scaled connection difference 77 | /// - the scaled weight difference 78 | /// - the scaled activation difference 79 | /// 80 | /// # Example 81 | /// ``` 82 | /// # use set_genome::{Genome, Parameters, CompatibilityDistance}; 83 | /// let parameters = Parameters::basic(10, 10); 84 | /// 85 | /// // Randomly initialize two genomes. 86 | /// let genome_one = Genome::initialized(&parameters); 87 | /// let genome_two = Genome::initialized(&parameters); 88 | /// 89 | /// assert!(CompatibilityDistance::compatability_distance(&genome_one, &genome_two, 1.0, 1.0, 1.0).0 > 0.0); 90 | /// 91 | /// assert_eq!(CompatibilityDistance::compatability_distance(&genome_one, &genome_one, 1.0, 1.0, 1.0).1, 0.0); 92 | /// assert_eq!(CompatibilityDistance::compatability_distance(&genome_one, &genome_one, 1.0, 1.0, 1.0).2, 0.0); 93 | /// assert_eq!(CompatibilityDistance::compatability_distance(&genome_one, &genome_one, 1.0, 1.0, 1.0).3, 0.0); 94 | /// ``` 95 | pub fn compatability_distance( 96 | genome_0: &Genome, 97 | genome_1: &Genome, 98 | factor_connections: f64, 99 | factor_weights: f64, 100 | factor_activations: f64, 101 | ) -> (f64, f64, f64, f64) { 102 | let mut weight_difference = 0.0; 103 | let mut activation_difference = 0.0; 104 | 105 | let matching_connections_count = (genome_0 106 | .feed_forward 107 | .iterate_matching_genes(&genome_1.feed_forward) 108 | .inspect(|(connection_0, connection_1)| { 109 | weight_difference += (connection_0.weight - connection_1.weight).abs(); 110 | }) 111 | .count() 112 | + genome_0 113 | .recurrent 114 | .iterate_matching_genes(&genome_1.recurrent) 115 | .inspect(|(connection_0, connection_1)| { 116 | weight_difference += (connection_0.weight - connection_1.weight).abs(); 117 | }) 118 | .count()) as f64; 119 | 120 | let different_connections_count = (genome_0 121 | .feed_forward 122 | .iterate_unique_genes(&genome_1.feed_forward) 123 | .count() 124 | +
genome_0 125 | .recurrent 126 | .iterate_unique_genes(&genome_1.recurrent) 127 | .count()) as f64; 128 | 129 | let matching_nodes_count = genome_0 130 | .hidden 131 | .iterate_matching_genes(&genome_1.hidden) 132 | .inspect(|(node_0, node_1)| { 133 | if node_0.activation != node_1.activation { 134 | activation_difference += 1.0; 135 | } 136 | }) 137 | .count() as f64; 138 | 139 | // Connection weights are capped between 1.0 and -1.0. So the maximum difference per matching connection is 2.0. 140 | let maximum_weight_difference = matching_connections_count * 2.0; 141 | 142 | // percent of different genes, considering all unique genes from both genomes 143 | let scaled_connection_difference = factor_connections * different_connections_count 144 | / (matching_connections_count + different_connections_count); 145 | 146 | // average weight differences , considering matching connection genes 147 | let scaled_weight_difference = factor_weights 148 | * if maximum_weight_difference > 0.0 { 149 | weight_difference / maximum_weight_difference 150 | } else { 151 | 0.0 152 | }; 153 | 154 | // percent of different activation functions, considering matching nodes genes 155 | let scaled_activation_difference = factor_activations 156 | * if matching_nodes_count > 0.0 { 157 | activation_difference / matching_nodes_count 158 | } else { 159 | 0.0 160 | }; 161 | 162 | let overall_scaled_distance = (scaled_connection_difference 163 | + scaled_weight_difference 164 | + scaled_activation_difference) 165 | / (factor_connections + factor_weights + factor_activations); 166 | 167 | ( 168 | overall_scaled_distance, 169 | scaled_connection_difference, 170 | scaled_weight_difference, 171 | scaled_activation_difference, 172 | ) 173 | } 174 | } 175 | 176 | #[cfg(test)] 177 | mod tests { 178 | use crate::{ 179 | activations::Activation, genes::Genes, 180 | genome::compatibility_distance::CompatibilityDistance, Connection, Genome, Id, Node, 181 | }; 182 | 183 | #[test] 184 | fn 
compatability_distance_same_genome() { 185 | let genome_0 = Genome { 186 | inputs: Genes(vec![Node::input(Id(0), 0)].iter().cloned().collect()), 187 | outputs: Genes( 188 | vec![Node::output(Id(1), 0, Activation::Linear)] 189 | .iter() 190 | .cloned() 191 | .collect(), 192 | ), 193 | 194 | feed_forward: Genes( 195 | vec![Connection::new(Id(0), 1.0, Id(1))] 196 | .iter() 197 | .cloned() 198 | .collect(), 199 | ), 200 | ..Default::default() 201 | }; 202 | 203 | let genome_1 = genome_0.clone(); 204 | 205 | let delta = 206 | CompatibilityDistance::compatability_distance(&genome_0, &genome_1, 1.0, 0.4, 0.0).0; 207 | 208 | assert!(delta.abs() < f64::EPSILON); 209 | } 210 | 211 | #[test] 212 | fn compatability_distance_different_weight_genome() { 213 | let genome_0 = Genome { 214 | inputs: Genes(vec![Node::input(Id(0), 0)].iter().cloned().collect()), 215 | outputs: Genes( 216 | vec![Node::output(Id(1), 0, Activation::Linear)] 217 | .iter() 218 | .cloned() 219 | .collect(), 220 | ), 221 | 222 | feed_forward: Genes( 223 | vec![Connection::new(Id(0), 0.0, Id(1))] 224 | .iter() 225 | .cloned() 226 | .collect(), 227 | ), 228 | ..Default::default() 229 | }; 230 | 231 | let mut genome_1 = genome_0.clone(); 232 | 233 | genome_1 234 | .feed_forward 235 | .replace(Connection::new(Id(0), 1.0, Id(1))); 236 | 237 | let factor_weight = 2.0; 238 | 239 | let delta = CompatibilityDistance::compatability_distance( 240 | &genome_0, 241 | &genome_1, 242 | 0.0, 243 | factor_weight, 244 | 0.0, 245 | ) 246 | .0; 247 | 248 | // factor 2 times 2 expressed difference over 2 possible difference over factor 2 249 | assert!((delta - factor_weight * 1.0 / 2.0 / factor_weight).abs() < f64::EPSILON); 250 | } 251 | 252 | #[test] 253 | fn compatability_distance_different_connection_genome() { 254 | let genome_0 = Genome { 255 | inputs: Genes(vec![Node::input(Id(0), 0)].iter().cloned().collect()), 256 | outputs: Genes( 257 | vec![Node::output(Id(1), 0, Activation::Linear)] 258 | .iter() 259 | .cloned() 260 
| .collect(), 261 | ), 262 | 263 | feed_forward: Genes( 264 | vec![Connection::new(Id(0), 1.0, Id(1))] 265 | .iter() 266 | .cloned() 267 | .collect(), 268 | ), 269 | ..Default::default() 270 | }; 271 | 272 | let mut genome_1 = genome_0.clone(); 273 | 274 | genome_1 275 | .feed_forward 276 | .insert(Connection::new(Id(0), 1.0, Id(2))); 277 | genome_1 278 | .feed_forward 279 | .insert(Connection::new(Id(2), 2.0, Id(1))); 280 | 281 | let delta = 282 | CompatibilityDistance::compatability_distance(&genome_0, &genome_1, 2.0, 0.0, 0.0).0; 283 | 284 | // factor 2 times 2 different genes over 3 total genes over factor 2 285 | assert!((delta - 2.0 * 2.0 / 3.0 / 2.0).abs() < f64::EPSILON); 286 | } 287 | } 288 | -------------------------------------------------------------------------------- /src/genome.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | collections::HashSet, 3 | hash::{Hash, Hasher}, 4 | }; 5 | 6 | use crate::{ 7 | genes::{Connection, Genes, Id, Node}, 8 | parameters::Structure, 9 | }; 10 | 11 | use rand::{rngs::SmallRng, seq::SliceRandom, thread_rng, Rng, SeedableRng}; 12 | use seahash::SeaHasher; 13 | use serde::{Deserialize, Serialize}; 14 | 15 | mod compatibility_distance; 16 | 17 | pub use compatibility_distance::CompatibilityDistance; 18 | 19 | /// This is the core data structure this crate revolves around. 20 | /// 21 | /// A genome can be changed by mutation (a random alteration of its structure) or by crossing in another genome (recombining their matching parts). 22 | /// A lot of additional information explaining details of the structure can be found in the [thesis] that developed this idea. 23 | /// More and more knowledge from there will find its way into this documentation over time.
24 | /// 25 | /// [thesis]: https://www.silvan.codes/SET-NEAT_Thesis.pdf 26 | #[derive(Debug, Default, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] 27 | pub struct Genome { 28 | pub inputs: Genes<Node>, 29 | pub hidden: Genes<Node>, 30 | pub outputs: Genes<Node>, 31 | pub feed_forward: Genes<Connection>, 32 | pub recurrent: Genes<Connection>, 33 | } 34 | 35 | impl Genome { 36 | /// Creates a new genome according to the [`Structure`] it is given. 37 | /// It generates all necessary identities based on an RNG seeded from a hash of the I/O configuration and the seed of the structure. 38 | /// This allows genomes with identical I/O configuration and seed to be crossed over in a meaningful way. 39 | pub fn new(structure: &Structure) -> Self { 40 | let mut seed_hasher = SeaHasher::new(); 41 | structure.number_of_inputs.hash(&mut seed_hasher); 42 | structure.number_of_outputs.hash(&mut seed_hasher); 43 | structure.seed.hash(&mut seed_hasher); 44 | 45 | let mut rng = SmallRng::seed_from_u64(seed_hasher.finish()); 46 | 47 | Genome { 48 | inputs: (0..structure.number_of_inputs) 49 | .map(|order| Node::input(Id(rng.gen::<u64>()), order)) 50 | .collect(), 51 | outputs: (0..structure.number_of_outputs) 52 | .map(|order| { 53 | Node::output(Id(rng.gen::<u64>()), order, structure.outputs_activation) 54 | }) 55 | .collect(), 56 | ..Default::default() 57 | } 58 | } 59 | 60 | /// Returns an iterator over references to all node genes (input + hidden + output) in the genome. 61 | pub fn nodes(&self) -> impl Iterator<Item = &Node> { 62 | self.inputs 63 | .iter() 64 | .chain(self.hidden.iter()) 65 | .chain(self.outputs.iter()) 66 | } 67 | /// Returns true if any input, hidden or output node of the genome has the given `id`. 68 | pub fn contains(&self, id: Id) -> bool { 69 | let fake_node = &Node::input(id, 0); 70 | self.inputs.contains(fake_node) 71 | || self.hidden.contains(fake_node) 72 | || self.outputs.contains(fake_node) 73 | } 74 | 75 | /// Returns an iterator over references to all connection genes (feed-forward + recurrent) in the genome.
76 | pub fn connections(&self) -> impl Iterator<Item = &Connection> { 77 | self.feed_forward.iter().chain(self.recurrent.iter()) 78 | } 79 | 80 | /// Initializes a genome, i.e. connects the percent of inputs configured in the [`Structure`] (rounded up to whole nodes) to all outputs by creating connection genes with random weights. 81 | pub fn init(&mut self, structure: &Structure) { 82 | let rng = &mut SmallRng::from_rng(thread_rng()).unwrap(); 83 | 84 | let mut possible_inputs = self.inputs.iter().collect::<Vec<_>>(); 85 | possible_inputs.shuffle(rng); 86 | 87 | for input in possible_inputs.iter().take( 88 | (structure.percent_of_connected_inputs * structure.number_of_inputs as f64).ceil() 89 | as usize, 90 | ) { 91 | // connect to every output 92 | for output in self.outputs.iter() { 93 | assert!(self.feed_forward.insert(Connection::new( 94 | input.id, 95 | Connection::weight_perturbation(0.0, 0.1, rng), 96 | output.id 97 | ))); 98 | } 99 | } 100 | } 101 | 102 | /// Returns the number of connection genes inside the genome (feed-forward + recurrent). 103 | pub fn len(&self) -> usize { 104 | self.feed_forward.len() + self.recurrent.len() 105 | } 106 | 107 | /// Returns true when no connection genes are present in the genome. 108 | pub fn is_empty(&self) -> bool { 109 | self.feed_forward.is_empty() && self.recurrent.is_empty() 110 | } 111 | 112 | /// Cross-in another genome. 113 | /// For connection genes present in both genomes flip a coin to determine the weight inside the new genome. 114 | /// For node genes present in both genomes flip a coin to determine the activation function inside the new genome. 115 | /// Any structure not present in `other` is taken over unchanged from `self`. 116 | pub fn cross_in(&self, other: &Self) -> Self { 117 | // Instantiating an RNG for every call might slow things down.
118 | let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); 119 | 120 | let feed_forward = self.feed_forward.cross_in(&other.feed_forward, &mut rng); 121 | let recurrent = self.recurrent.cross_in(&other.recurrent, &mut rng); 122 | let hidden = self.hidden.cross_in(&other.hidden, &mut rng); 123 | 124 | Genome { 125 | feed_forward, 126 | recurrent, 127 | hidden, 128 | // use inputs and outputs from `self` (the fitter genome); they should be identical to those of `other` 129 | inputs: self.inputs.clone(), 130 | outputs: self.outputs.clone(), 131 | ..Default::default() 132 | } 133 | } 134 | 135 | /// Check if connecting `start_node` to `end_node` would introduce a cycle into the ANN structure. 136 | /// Think of the ANN as a graph: if you follow the feed-forward connection arrows, can you reach `start_node` from `end_node`? 137 | pub fn would_form_cycle(&self, start_node: &Node, end_node: &Node) -> bool { 138 | let mut to_visit = vec![end_node.id]; 139 | let mut visited = HashSet::new(); 140 | 141 | while let Some(node) = to_visit.pop() { 142 | if !visited.contains(&node) { 143 | visited.insert(node); 144 | for connection in self 145 | .feed_forward 146 | .iter() 147 | .filter(|connection| connection.input == node) 148 | { 149 | if connection.output == start_node.id { 150 | return true; 151 | } else { 152 | to_visit.push(connection.output) 153 | } 154 | } 155 | } 156 | } 157 | false 158 | } 159 | 160 | /// Check if node `node` has an incoming connection gene (feed-forward or recurrent) from any node other than `exclude`. 161 | pub fn has_alternative_input(&self, node: Id, exclude: Id) -> bool { 162 | self.connections() 163 | .filter(|connection| connection.output == node) 164 | .any(|connection| connection.input != exclude) 165 | } 166 | 167 | /// Check if node `node` has an outgoing connection gene (feed-forward or recurrent) to any node other than `exclude`.
168 | pub fn has_alternative_output(&self, node: Id, exclude: Id) -> bool { 169 | self.connections() 170 | .filter(|connection| connection.input == node) 171 | .any(|connection| connection.output != exclude) 172 | } 173 | 174 | /// Get the encoded neural network as a string in [DOT][1] format. 175 | /// 176 | /// The output can be [visualized here][2] for example. 177 | /// 178 | /// [1]: https://www.graphviz.org/doc/info/lang.html 179 | /// [2]: https://dreampuf.github.io/GraphvizOnline 180 | pub fn dot(genome: &Self) -> String { 181 | let mut dot = "digraph {\n".to_owned(); 182 | dot.push_str("\tgraph [splines=curved ranksep=8]\n"); 183 | 184 | dot.push_str("\tsubgraph cluster_inputs {\n"); 185 | dot.push_str("\t\tgraph [label=\"Inputs\"]\n"); 186 | dot.push_str("\t\tnode [color=\"#D6B656\", fillcolor=\"#FFF2CC\", style=\"filled\"]\n"); 187 | dot.push_str("\n"); 188 | for node in genome.inputs.iter() { 189 | // fill color: FFF2CC 190 | // line color: D6B656 191 | 192 | dot.push_str(&format!( 193 | "\t\t{} [label={:?}];\n", 194 | node.id.0, node.activation 195 | )); 196 | } 197 | dot.push_str("\t}\n"); 198 | 199 | dot.push_str("\tsubgraph hidden {\n"); 200 | dot.push_str("\t\tgraph [label=\"Hidden\" rank=\"same\"]\n"); 201 | dot.push_str("\t\tnode [color=\"#6C8EBF\", fillcolor=\"#DAE8FC\", style=\"filled\"]\n"); 202 | dot.push_str("\n"); 203 | for node in genome.hidden.iter() { 204 | // fill color: DAE8FC 205 | // line color: 6C8EBF 206 | 207 | dot.push_str(&format!( 208 | "\t\t{} [label={:?}];\n", 209 | node.id.0, node.activation 210 | )); 211 | } 212 | dot.push_str("\t}\n"); 213 | 214 | dot.push_str("\tsubgraph cluster_outputs {\n"); 215 | dot.push_str("\t\tgraph [label=\"Outputs\" labelloc=\"b\"]\n"); 216 | dot.push_str("\t\tnode [color=\"#9673A6\", fillcolor=\"#E1D5E7\", style=\"filled\"]\n"); 217 | dot.push_str("\n"); 218 | for node in genome.outputs.iter() { 219 | // fill color: E1D5E7 220 | // line color: 9673A6 221 | 222 | dot.push_str(&format!( 223 | 
"\t\t{} [label={:?}];\n", 224 | node.id.0, node.activation 225 | )); 226 | } 227 | dot.push_str("\t}\n"); 228 | 229 | dot.push_str("\n"); 230 | 231 | dot.push_str("\tsubgraph feedforward_connections {\n"); 232 | dot.push_str("\n"); 233 | for connection in genome.feed_forward.iter() { 234 | dot.push_str(&format!( 235 | "\t\t{0} -> {1} [label=\"\" arrowsize={3:?} penwidth={3:?} tooltip={2:?} labeltooltip={2:?}];\n", 236 | connection.input.0, 237 | connection.output.0, 238 | connection.weight, 239 | connection.weight.abs() * 0.95 + 0.05 240 | )); 241 | } 242 | dot.push_str("\t}\n"); 243 | 244 | dot.push_str("\tsubgraph recurrent_connections {\n"); 245 | dot.push_str("\t\tedge [color=\"#FF8000\"]\n"); 246 | dot.push_str("\n"); 247 | for connection in genome.recurrent.iter() { 248 | // color: FF8000 249 | 250 | dot.push_str(&format!( 251 | "\t\t{0} -> {1} [label=\"\" arrowsize={3:?} penwidth={3:?} tooltip={2:?} labeltooltip={2:?}];\n", 252 | connection.input.0, 253 | connection.output.0, 254 | connection.weight, 255 | connection.weight.abs() * 0.95 + 0.05 256 | )); 257 | } 258 | dot.push_str("\t}\n"); 259 | 260 | dot.push_str("}\n"); 261 | dot 262 | } 263 | } 264 | 265 | #[cfg(test)] 266 | mod tests { 267 | use std::hash::{Hash, Hasher}; 268 | 269 | use rand::thread_rng; 270 | use seahash::SeaHasher; 271 | 272 | use super::Genome; 273 | use crate::{ 274 | genes::{Activation, Connection, Genes, Id, Node}, 275 | Mutations, Parameters, Structure, 276 | }; 277 | 278 | #[test] 279 | fn find_alternative_input() { 280 | let genome = Genome { 281 | inputs: Genes( 282 | vec![Node::input(Id(0), 0), Node::input(Id(1), 1)] 283 | .iter() 284 | .cloned() 285 | .collect(), 286 | ), 287 | outputs: Genes( 288 | vec![Node::output(Id(2), 0, Activation::Linear)] 289 | .iter() 290 | .cloned() 291 | .collect(), 292 | ), 293 | feed_forward: Genes( 294 | vec![ 295 | Connection::new(Id(0), 1.0, Id(2)), 296 | Connection::new(Id(1), 1.0, Id(2)), 297 | ] 298 | .iter() 299 | .cloned() 300 | 
.collect(), 301 | ), 302 | ..Default::default() 303 | }; 304 | 305 | assert!(genome.has_alternative_input(Id(2), Id(1))) 306 | } 307 | 308 | #[test] 309 | fn find_no_alternative_input() { 310 | let genome = Genome { 311 | inputs: Genes(vec![Node::input(Id(0), 0)].iter().cloned().collect()), 312 | outputs: Genes( 313 | vec![Node::output(Id(1), 0, Activation::Linear)] 314 | .iter() 315 | .cloned() 316 | .collect(), 317 | ), 318 | feed_forward: Genes( 319 | vec![Connection::new(Id(0), 1.0, Id(1))] 320 | .iter() 321 | .cloned() 322 | .collect(), 323 | ), 324 | ..Default::default() 325 | }; 326 | 327 | assert!(!genome.has_alternative_input(Id(1), Id(0))) 328 | } 329 | 330 | #[test] 331 | fn find_alternative_output() { 332 | let genome = Genome { 333 | inputs: Genes(vec![Node::input(Id(0), 0)].iter().cloned().collect()), 334 | outputs: Genes( 335 | vec![ 336 | Node::output(Id(2), 0, Activation::Linear), 337 | Node::output(Id(1), 0, Activation::Linear), 338 | ] 339 | .iter() 340 | .cloned() 341 | .collect(), 342 | ), 343 | feed_forward: Genes( 344 | vec![ 345 | Connection::new(Id(0), 1.0, Id(1)), 346 | Connection::new(Id(0), 1.0, Id(2)), 347 | ] 348 | .iter() 349 | .cloned() 350 | .collect(), 351 | ), 352 | ..Default::default() 353 | }; 354 | 355 | assert!(genome.has_alternative_output(Id(0), Id(1))) 356 | } 357 | 358 | #[test] 359 | fn find_no_alternative_output() { 360 | let genome = Genome { 361 | inputs: Genes(vec![Node::input(Id(0), 0)].iter().cloned().collect()), 362 | outputs: Genes( 363 | vec![Node::output(Id(1), 0, Activation::Linear)] 364 | .iter() 365 | .cloned() 366 | .collect(), 367 | ), 368 | feed_forward: Genes( 369 | vec![Connection::new(Id(0), 1.0, Id(1))] 370 | .iter() 371 | .cloned() 372 | .collect(), 373 | ), 374 | ..Default::default() 375 | }; 376 | 377 | assert!(!genome.has_alternative_output(Id(0), Id(1))) 378 | } 379 | 380 | #[test] 381 | fn crossover() { 382 | let parameters = Parameters::default(); 383 | 384 | let mut genome_0 = 
Genome::initialized(&parameters); 385 | let mut genome_1 = Genome::initialized(&parameters); 386 | 387 | let rng = &mut thread_rng(); 388 | 389 | // mutate genome_0 390 | Mutations::add_node(&Activation::all(), &mut genome_0, rng); 391 | 392 | // mutate genome_1 393 | Mutations::add_node(&Activation::all(), &mut genome_1, rng); 394 | Mutations::add_node(&Activation::all(), &mut genome_1, rng); 395 | 396 | // shorter genome is fitter genome 397 | let offspring = genome_0.cross_in(&genome_1); 398 | 399 | assert_eq!(offspring.hidden.len(), 1); 400 | assert_eq!(offspring.feed_forward.len(), 3); 401 | } 402 | 403 | #[test] 404 | fn detect_no_cycle() { 405 | let parameters = Parameters::default(); 406 | 407 | let genome = Genome::initialized(&parameters); 408 | 409 | let input = genome.inputs.iter().next().unwrap(); 410 | let output = genome.outputs.iter().next().unwrap(); 411 | 412 | assert!(!genome.would_form_cycle(&input, &output)); 413 | } 414 | 415 | #[test] 416 | fn detect_cycle() { 417 | let parameters = Parameters::default(); 418 | 419 | let genome = Genome::initialized(&parameters); 420 | 421 | let input = genome.inputs.iter().next().unwrap(); 422 | let output = genome.outputs.iter().next().unwrap(); 423 | 424 | assert!(genome.would_form_cycle(&output, &input)); 425 | } 426 | 427 | #[test] 428 | fn crossover_no_cycle() { 429 | // assumption: 430 | // crossover of equal fitness genomes should not produce cycles 431 | // prerequisites: 432 | // genomes with equal fitness (0.0 in this case) 433 | // "mirrored" structure as simplest example 434 | 435 | let mut genome_0 = Genome { 436 | inputs: Genes(vec![Node::input(Id(0), 0)].iter().cloned().collect()), 437 | outputs: Genes( 438 | vec![Node::output(Id(1), 0, Activation::Linear)] 439 | .iter() 440 | .cloned() 441 | .collect(), 442 | ), 443 | hidden: Genes( 444 | vec![ 445 | Node::hidden(Id(2), Activation::Tanh), 446 | Node::hidden(Id(3), Activation::Tanh), 447 | ] 448 | .iter() 449 | .cloned() 450 | .collect(), 451 | ), 452 |
feed_forward: Genes( 453 | vec![ 454 | Connection::new(Id(0), 1.0, Id(2)), 455 | Connection::new(Id(2), 1.0, Id(1)), 456 | Connection::new(Id(0), 1.0, Id(3)), 457 | Connection::new(Id(3), 1.0, Id(1)), 458 | ] 459 | .iter() 460 | .cloned() 461 | .collect(), 462 | ), 463 | ..Default::default() 464 | }; 465 | 466 | let mut genome_1 = genome_0.clone(); 467 | 468 | // insert connectio one way in genome0 469 | genome_0 470 | .feed_forward 471 | .insert(Connection::new(Id(2), 1.0, Id(3))); 472 | 473 | // insert connection the other way in genome1 474 | genome_1 475 | .feed_forward 476 | .insert(Connection::new(Id(3), 1.0, Id(2))); 477 | 478 | let offspring = genome_0.cross_in(&genome_1); 479 | 480 | for connection0 in offspring.feed_forward.iter() { 481 | for connection1 in offspring.feed_forward.iter() { 482 | assert!( 483 | !(connection0.input == connection1.output 484 | && connection0.output == connection1.input) 485 | ) 486 | } 487 | } 488 | } 489 | 490 | #[test] 491 | fn hash_genome() { 492 | let genome_0 = Genome { 493 | inputs: Genes( 494 | vec![Node::input(Id(1), 0), Node::input(Id(0), 0)] 495 | .iter() 496 | .cloned() 497 | .collect(), 498 | ), 499 | outputs: Genes( 500 | vec![Node::output(Id(2), 0, Activation::Linear)] 501 | .iter() 502 | .cloned() 503 | .collect(), 504 | ), 505 | 506 | feed_forward: Genes( 507 | vec![Connection::new(Id(0), 1.0, Id(1))] 508 | .iter() 509 | .cloned() 510 | .collect(), 511 | ), 512 | ..Default::default() 513 | }; 514 | 515 | let genome_1 = Genome { 516 | inputs: Genes( 517 | vec![Node::input(Id(0), 0), Node::input(Id(1), 0)] 518 | .iter() 519 | .cloned() 520 | .collect(), 521 | ), 522 | outputs: Genes( 523 | vec![Node::output(Id(2), 0, Activation::Linear)] 524 | .iter() 525 | .cloned() 526 | .collect(), 527 | ), 528 | 529 | feed_forward: Genes( 530 | vec![Connection::new(Id(0), 1.0, Id(1))] 531 | .iter() 532 | .cloned() 533 | .collect(), 534 | ), 535 | ..Default::default() 536 | }; 537 | 538 | assert_eq!(genome_0, genome_1); 539 | 
540 | let mut hasher = SeaHasher::new(); 541 | genome_0.hash(&mut hasher); 542 | let genome_0_hash = hasher.finish(); 543 | 544 | let mut hasher = SeaHasher::new(); 545 | genome_1.hash(&mut hasher); 546 | let genome_1_hash = hasher.finish(); 547 | 548 | assert_eq!(genome_0_hash, genome_1_hash); 549 | } 550 | 551 | #[test] 552 | fn create_dot_from_genome() { 553 | let genome = Genome { 554 | inputs: Genes(vec![Node::input(Id(0), 0)].iter().cloned().collect()), 555 | outputs: Genes( 556 | vec![Node::output(Id(1), 0, Activation::Linear)] 557 | .iter() 558 | .cloned() 559 | .collect(), 560 | ), 561 | hidden: Genes( 562 | vec![Node::hidden(Id(2), Activation::Tanh)] 563 | .iter() 564 | .cloned() 565 | .collect(), 566 | ), 567 | feed_forward: Genes( 568 | vec![ 569 | Connection::new(Id(0), 0.25795942718883524, Id(2)), 570 | Connection::new(Id(2), -0.09736946507786626, Id(1)), 571 | ] 572 | .iter() 573 | .cloned() 574 | .collect(), 575 | ), 576 | recurrent: Genes( 577 | vec![Connection::new(Id(1), 0.19777863112749228, Id(2))] 578 | .iter() 579 | .cloned() 580 | .collect(), 581 | ), 582 | ..Default::default() 583 | }; 584 | 585 | // let dot = "digraph {\n\t0 [label=Linear color=\"#D6B656\" fillcolor=\"#FFF2CC\" style=\"filled\"];\n\t2 [label=Tanh color=\"#6C8EBF\" fillcolor=\"#DAE8FC\" style=\"filled\"];\n\t1 [label=Linear color=\"#9673A6\" fillcolor=\"#E1D5E7\" style=\"filled\"];\n\t0 -> 2 [label=0.25795942718883524];\n\t2 -> 1 [label=0.09736946507786626];\n\t1 -> 2 [label=0.19777863112749228 color=\"#FF8000\"];\n}\n"; 586 | 587 | let dot = "digraph { 588 | \tgraph [splines=curved ranksep=8] 589 | \tsubgraph cluster_inputs { 590 | \t\tgraph [label=\"Inputs\"] 591 | \t\tnode [color=\"#D6B656\", fillcolor=\"#FFF2CC\", style=\"filled\"] 592 | 593 | \t\t0 [label=Linear]; 594 | \t} 595 | \tsubgraph hidden { 596 | \t\tgraph [label=\"Hidden\" rank=\"same\"] 597 | \t\tnode [color=\"#6C8EBF\", fillcolor=\"#DAE8FC\", style=\"filled\"] 598 | 599 | \t\t2 [label=Tanh]; 600 | \t} 601 | 
\tsubgraph cluster_outputs { 602 | \t\tgraph [label=\"Outputs\" labelloc=\"b\"] 603 | \t\tnode [color=\"#9673A6\", fillcolor=\"#E1D5E7\", style=\"filled\"] 604 | 605 | \t\t1 [label=Linear]; 606 | \t} 607 | 608 | \tsubgraph feedforward_connections { 609 | 610 | \t\t0 -> 2 [label=\"\" arrowsize=0.29506145582939347 penwidth=0.29506145582939347 tooltip=0.25795942718883524 labeltooltip=0.25795942718883524]; 611 | \t\t2 -> 1 [label=\"\" arrowsize=0.14250099182397294 penwidth=0.14250099182397294 tooltip=-0.09736946507786626 labeltooltip=-0.09736946507786626]; 612 | \t} 613 | \tsubgraph recurrent_connections { 614 | \t\tedge [color=\"#FF8000\"] 615 | 616 | \t\t1 -> 2 [label=\"\" arrowsize=0.23788969957111766 penwidth=0.23788969957111766 tooltip=0.19777863112749228 labeltooltip=0.19777863112749228]; 617 | \t} 618 | } 619 | "; 620 | assert_eq!(&Genome::dot(&genome), dot) 621 | } 622 | 623 | #[test] 624 | fn print_big_dot() { 625 | let parameters = Parameters { 626 | structure: Structure { 627 | number_of_inputs: 10, 628 | number_of_outputs: 10, 629 | percent_of_connected_inputs: 0.2, 630 | ..Default::default() 631 | }, 632 | mutations: vec![ 633 | Mutations::ChangeWeights { 634 | chance: 1.0, 635 | percent_perturbed: 0.5, 636 | standard_deviation: 0.1, 637 | }, 638 | Mutations::ChangeActivation { 639 | chance: 0.05, 640 | activation_pool: vec![ 641 | Activation::Linear, 642 | Activation::Sigmoid, 643 | Activation::Tanh, 644 | Activation::Gaussian, 645 | Activation::Step, 646 | Activation::Sine, 647 | Activation::Cosine, 648 | Activation::Inverse, 649 | Activation::Absolute, 650 | Activation::Relu, 651 | ], 652 | }, 653 | Mutations::AddNode { 654 | chance: 0.005, 655 | activation_pool: vec![ 656 | Activation::Linear, 657 | Activation::Sigmoid, 658 | Activation::Tanh, 659 | Activation::Gaussian, 660 | Activation::Step, 661 | Activation::Sine, 662 | Activation::Cosine, 663 | Activation::Inverse, 664 | Activation::Absolute, 665 | Activation::Relu, 666 | ], 667 | }, 668 | 
Mutations::AddConnection { chance: 0.01 }, 669 | Mutations::AddRecurrentConnection { chance: 0.01 }, 670 | ], 671 | }; 672 | let mut genome = Genome::initialized(¶meters); 673 | 674 | for _ in 0..1000 { 675 | genome.mutate(¶meters); 676 | } 677 | 678 | print!("{}", Genome::dot(&genome)); 679 | } 680 | } 681 | --------------------------------------------------------------------------------