├── .gitignore ├── Cargo.toml ├── examples └── cart-pole │ ├── network.bin │ ├── Cargo.toml │ └── src │ ├── main.rs │ └── gui.rs ├── environments ├── cart-pole │ ├── src │ │ ├── utils.rs │ │ └── lib.rs │ └── Cargo.toml └── tictactoe │ ├── Cargo.toml │ └── src │ └── main.rs ├── environment ├── Cargo.toml └── src │ └── lib.rs ├── export ├── Cargo.toml └── src │ └── lib.rs ├── core ├── src │ ├── lib.rs │ ├── connection.rs │ ├── node.rs │ ├── genome │ │ ├── connection.rs │ │ ├── node.rs │ │ ├── crossover.rs │ │ └── mod.rs │ ├── neat │ │ ├── reporter.rs │ │ ├── speciation.rs │ │ ├── configuration.rs │ │ └── mod.rs │ ├── activation.rs │ ├── reporting.rs │ ├── graph.rs │ ├── aggregations.rs │ ├── network.rs │ ├── speciation │ │ ├── distance.rs │ │ └── mod.rs │ └── mutations.rs ├── Cargo.toml └── LICENSE ├── LICENSE └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | /**/target 2 | /**/Cargo.lock 3 | 4 | .DS_Store 5 | /**/.DS_Store 6 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [workspace] 2 | members = ["core", "export", "environment", "environments/*", "examples/*"] 3 | -------------------------------------------------------------------------------- /examples/cart-pole/network.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stjepangolemac/neat-rs/HEAD/examples/cart-pole/network.bin -------------------------------------------------------------------------------- /environments/cart-pole/src/utils.rs: -------------------------------------------------------------------------------- 1 | use std::f64::consts::PI; 2 | /// Converts an angle from degrees to radians. 3 | pub fn to_radians(degrees: f64) -> f64 { 4 | degrees * PI / 180.
5 | } 6 | -------------------------------------------------------------------------------- /environment/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "neat-environment" 3 | version = "0.1.0" 4 | authors = ["Stjepan Golemac "] 5 | edition = "2018" 6 | -------------------------------------------------------------------------------- /export/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "neat-export" 3 | version = "0.1.0" 4 | authors = ["Stjepan Golemac "] 5 | edition = "2018" 6 | 7 | [dependencies] 8 | neat-core = { path ="../core", features= ["network-serde"] } 9 | bincode = "1.3.1" 10 | -------------------------------------------------------------------------------- /core/src/lib.rs: -------------------------------------------------------------------------------- 1 | mod activation; 2 | mod aggregations; 3 | mod connection; 4 | mod genome; 5 | mod mutations; 6 | mod neat; 7 | mod network; 8 | mod node; 9 | pub mod reporting; 10 | mod speciation; 11 | // Public API re-exported at the crate root. 12 | pub use genome::*; 13 | pub use neat::*; 14 | pub use network::*; 15 | -------------------------------------------------------------------------------- /environment/src/lib.rs: -------------------------------------------------------------------------------- 1 | pub trait Environment { 2 | type State; 3 | type Input; 4 | /// Returns the current state of the environment. 5 | fn state(&self) -> Self::State; 6 | fn step(&mut self, input: Self::Input) -> Result<(), ()>; 7 | /// Returns `true` once the episode is over; the cart-pole training loop stops stepping at that point. 8 | fn done(&self) -> bool; 9 | fn reset(&mut self); 10 | 11 | fn render(&self); 12 | /// Fitness score for the current episode. 13 | fn fitness(&self) -> f64; 14 | } 15 | -------------------------------------------------------------------------------- /environments/cart-pole/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "neat-environment-cart-pole" 3 | version = "0.1.0" 4 | authors = ["Stjepan Golemac "] 5 | edition = "2018" 6 |
7 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 8 | 9 | [dependencies] 10 | neat-environment = { path = "../../environment" } 11 | rand = "0.8.0" 12 | -------------------------------------------------------------------------------- /environments/tictactoe/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "tictactoe" 3 | version = "0.1.0" 4 | authors = ["Stjepan Golemac "] 5 | edition = "2018" 6 | 7 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 8 | 9 | [dependencies] 10 | rand = "0.8.0" 11 | neat-environment = { path = "../../environment" } 12 | neat-core = { path = "../../core" } 13 | -------------------------------------------------------------------------------- /core/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "neat-core" 3 | version = "0.1.0" 4 | authors = ["Stjepan Golemac "] 5 | edition = "2018" 6 | 7 | [dependencies] 8 | rand = "0.7.3" 9 | rand_distr = "0.3.0" 10 | rayon = "1.5.0" 11 | serde = { version = "1.0.118", features=["derive"], optional = true } 12 | uuid = { version = "0.8.1", features = ["v4"] } 13 | 14 | [features] 15 | network-serde = ["serde"] 16 | -------------------------------------------------------------------------------- /examples/cart-pole/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "cart-pole" 3 | version = "0.1.0" 4 | authors = ["Stjepan Golemac "] 5 | edition = "2018" 6 | 7 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 8 | 9 | [dependencies] 10 | neat-core = { path = "../../core" } 11 | neat-environment-cart-pole = { path = "../../environments/cart-pole" } 12 | neat-export = { path = "../../export" } 13 | nannou = "0.15.0" 14 |
-------------------------------------------------------------------------------- /core/src/connection.rs: -------------------------------------------------------------------------------- 1 | use crate::genome::connection::ConnectionGene; 2 | /// A weighted connection between node indices, built from a `ConnectionGene`. 3 | #[derive(Debug)] 4 | #[cfg_attr( 5 | feature = "network-serde", 6 | derive(serde::Serialize, serde::Deserialize) 7 | )] 8 | pub struct Connection { 9 | pub from: usize, 10 | pub to: usize, 11 | pub weight: f64, 12 | } 13 | // Copies endpoints and weight; the gene's `disabled` flag is not carried over. 14 | impl From<&ConnectionGene> for Connection { 15 | fn from(g: &ConnectionGene) -> Self { 16 | Connection { 17 | from: g.from, 18 | to: g.to, 19 | weight: g.weight, 20 | } 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /core/src/node.rs: -------------------------------------------------------------------------------- 1 | use crate::activation::ActivationKind; 2 | use crate::aggregations::Aggregation; 3 | use crate::genome::node::NodeGene; 4 | /// Role of a node in a network. 5 | #[derive(Debug, Clone, PartialEq, Hash)] 6 | #[cfg_attr( 7 | feature = "network-serde", 8 | derive(serde::Serialize, serde::Deserialize) 9 | )] 10 | pub enum NodeKind { 11 | Input, 12 | Hidden, 13 | Output, 14 | Constant, 15 | } 16 | /// A node built from a `NodeGene`, plus a `value` slot that starts empty. 17 | #[derive(Debug)] 18 | #[cfg_attr( 19 | feature = "network-serde", 20 | derive(serde::Serialize, serde::Deserialize) 21 | )] 22 | pub struct Node { 23 | pub kind: NodeKind, 24 | pub aggregation: Aggregation, 25 | pub activation: ActivationKind, 26 | pub bias: f64, 27 | pub value: Option, 28 | } 29 | // `value` starts as `None`; presumably it is filled in during a forward pass (not visible here). 30 | impl From<&NodeGene> for Node { 31 | fn from(g: &NodeGene) -> Self { 32 | Node { 33 | kind: g.kind.clone(), 34 | activation: g.activation.clone(), 35 | bias: g.bias, 36 | value: None, 37 | aggregation: g.aggregation.clone(), 38 | } 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Stjepan
Golemac 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /core/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Stjepan Golemac 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software.
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /core/src/genome/connection.rs: -------------------------------------------------------------------------------- 1 | use rand::random; 2 | use std::hash::{Hash, Hasher}; 3 | 4 | #[derive(Debug, Clone)] 5 | pub struct ConnectionGene { 6 | pub from: usize, 7 | pub to: usize, 8 | pub weight: f64, 9 | pub disabled: bool, 10 | } 11 | 12 | impl ConnectionGene { 13 | pub fn new(from: usize, to: usize) -> Self { 14 | ConnectionGene { 15 | from, 16 | to, 17 | weight: random::() * 2. 
- 1., 18 | disabled: false, 19 | } 20 | } 21 | 22 | pub fn innovation_number(&self) -> usize { 23 | let a = self.from; 24 | let b = self.to; 25 | 26 | let first_part = (a + b) * (a + b + 1); 27 | let second_part = b; 28 | 29 | first_part.checked_div(2).unwrap() + second_part 30 | } 31 | } 32 | 33 | impl PartialEq for ConnectionGene { 34 | fn eq(&self, other: &Self) -> bool { 35 | self.from == other.from 36 | && self.to == other.to 37 | && self.disabled == other.disabled 38 | && (self.weight - other.weight).abs() < f64::EPSILON 39 | } 40 | } 41 | 42 | impl Eq for ConnectionGene {} 43 | 44 | impl Hash for ConnectionGene { 45 | fn hash(&self, state: &mut H) { 46 | self.from.hash(state); 47 | self.to.hash(state); 48 | self.disabled.hash(state); 49 | self.weight.to_bits().hash(state); 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /core/src/genome/node.rs: -------------------------------------------------------------------------------- 1 | use crate::activation::ActivationKind; 2 | use crate::aggregations::Aggregation; 3 | use crate::node::NodeKind; 4 | use rand::random; 5 | use std::hash::{Hash, Hasher}; 6 | 7 | #[derive(Debug, Clone)] 8 | pub struct NodeGene { 9 | pub kind: NodeKind, 10 | pub aggregation: Aggregation, 11 | pub activation: ActivationKind, 12 | pub bias: f64, 13 | } 14 | 15 | impl NodeGene { 16 | pub fn new(kind: NodeKind) -> Self { 17 | let aggregation = random(); 18 | let activation = match kind { 19 | NodeKind::Input => ActivationKind::Input, 20 | _ => random(), 21 | }; 22 | let bias: f64 = match kind { 23 | NodeKind::Input => 0., 24 | _ => random::() * 2. 
- 1., 25 | }; 26 | 27 | NodeGene { 28 | aggregation, 29 | kind, 30 | activation, 31 | bias, 32 | } 33 | } 34 | } 35 | 36 | impl PartialEq for NodeGene { 37 | fn eq(&self, other: &Self) -> bool { 38 | self.kind == other.kind 39 | && self.aggregation == other.aggregation 40 | && self.activation == other.activation 41 | && (self.bias - other.bias).abs() < f64::EPSILON 42 | } 43 | } 44 | 45 | impl Eq for NodeGene {} 46 | 47 | impl Hash for NodeGene { 48 | fn hash(&self, state: &mut H) { 49 | self.kind.hash(state); 50 | self.aggregation.hash(state); 51 | self.activation.hash(state); 52 | self.bias.to_bits().hash(state); 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /export/src/lib.rs: -------------------------------------------------------------------------------- 1 | use neat_core::Network; 2 | use std::fs::{read, write}; 3 | use std::path::Path; 4 | 5 | pub fn to_bytes(network: &Network) -> Vec { 6 | bincode::serialize(network).unwrap() 7 | } 8 | 9 | pub fn from_bytes(bytes: &[u8]) -> Network { 10 | bincode::deserialize(bytes).unwrap() 11 | } 12 | 13 | pub fn to_file>(path: S, network: &Network) { 14 | write(path, to_bytes(&network)).unwrap(); 15 | } 16 | 17 | pub fn from_file>(path: S) -> Network { 18 | from_bytes(&read(path).unwrap()) 19 | } 20 | 21 | #[cfg(test)] 22 | mod tests { 23 | use super::*; 24 | use neat_core::Genome; 25 | 26 | #[test] 27 | fn to_bytes_works() { 28 | let network: Network = (&Genome::new(3, 1)).into(); 29 | 30 | to_bytes(&network); 31 | } 32 | 33 | #[test] 34 | fn from_bytes_works() { 35 | let mut network: Network = (&Genome::new(3, 1)).into(); 36 | let output_before = network.forward_pass(vec![1., 2., 3.]); 37 | 38 | let bytes = to_bytes(&network); 39 | let mut imported_network = from_bytes(&bytes); 40 | 41 | let output_after = imported_network.forward_pass(vec![1., 2., 3.]); 42 | 43 | assert_eq!(output_before, output_after); 44 | } 45 | 46 | #[test] 47 | fn file_import_export_works() { 
48 | let filename = "network.bin"; 49 | 50 | let mut network: Network = (&Genome::new(3, 1)).into(); 51 | let output_before = network.forward_pass(vec![1., 2., 3.]); 52 | 53 | to_file(filename, &network); 54 | let mut imported_network = from_file(filename); 55 | 56 | let output_after = imported_network.forward_pass(vec![1., 2., 3.]); 57 | 58 | assert_eq!(output_before, output_after); 59 | 60 | std::fs::remove_file(filename).unwrap(); 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /core/src/neat/reporter.rs: -------------------------------------------------------------------------------- 1 | use super::NEAT; 2 | /// Hook called with the report index (presumably the generation number) and the NEAT system. 3 | pub type Hook = fn(i: usize, &NEAT); 4 | /// Runs registered hooks at fixed intervals of the report index. 5 | pub struct Reporter { 6 | hooks: Vec<(usize, Hook)>, 7 | } 8 | 9 | impl Reporter { 10 | pub fn new() -> Self { 11 | Reporter { hooks: vec![] } 12 | } 13 | /// Registers `hook` to run whenever the report index is a multiple of `every`. 14 | pub fn register(&mut self, every: usize, hook: Hook) { 15 | self.hooks.push((every, hook)); 16 | } 17 | /// Runs every hook whose interval divides `i`; a hook registered with `every == 0` never runs (instead of panicking on a remainder by zero). 18 | pub fn report(&self, i: usize, system: &NEAT) { 19 | self.hooks 20 | .iter() 21 | .filter(|(every, _)| i.checked_rem(*every) == Some(0)) 22 | .for_each(|(_, hook)| hook(i, system)); 23 | } 24 | } 25 | 26 | #[cfg(test)] 27 | mod tests { 28 | use super::*; 29 | 30 | #[test] 31 | fn print_every() { 32 | use crate::neat::NEAT; 33 | 34 | let mut reporter = Reporter::new(); 35 | 36 | reporter.register(1, |i, _| { 37 | dbg!(i); 38 | }); 39 | 40 | let system = NEAT::new(1, 1, |_| 0.); 41 | 42 | for i in 1..=10 { 43 | reporter.report(i, &system); 44 | } 45 | } 46 | 47 | #[test] 48 | fn print_every_3() { 49 | use crate::neat::NEAT; 50 | 51 | let mut reporter = Reporter::new(); 52 | 53 | reporter.register(3, |i, _| { 54 | dbg!(i); 55 | }); 56 | 57 | let system = NEAT::new(1, 1, |_| 0.); 58 | 59 | for i in 1..=10 { 60 | reporter.report(i, &system); 61 | } 62 | } 63 | 64 | #[test] 65 | fn access_system() { 66 | use crate::neat::NEAT; 67 | 68 | let mut reporter = Reporter::new(); 69 | 70 | reporter.register(4, |i,
 system| { 71 | println!("At generation {} input count is {}", i, system.inputs); 72 | }); 73 | 74 | let system = NEAT::new(1, 1, |_| 0.); 75 | 76 | for i in 1..=10 { 77 | reporter.report(i, &system); 78 | } 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /examples/cart-pole/src/main.rs: -------------------------------------------------------------------------------- 1 | use neat_core::{Configuration, NEAT}; 2 | use neat_environment_cart_pole::{CartPole, Environment}; 3 | use neat_export::to_file; 4 | 5 | mod gui; 6 | 7 | fn train() { 8 | let mut system = NEAT::new(4, 1, |network| { 9 | let num_simulations = 10; 10 | let max_steps = 1000; 11 | let mut env = CartPole::new(); 12 | 13 | // Fitness is the mean of `env.fitness()` over `num_simulations` episodes. 14 | let mut fitness = 0.; 15 | 16 | for _ in 0..num_simulations { 17 | env.reset(); 18 | 19 | for _ in 0..max_steps { 20 | if env.done() { 21 | break; 22 | } 23 | 24 | let state = env.state(); 25 | let network_output = network.forward_pass(state.to_vec()); 26 | let env_input = f64::max(-1., f64::min(1., *network_output.first().unwrap())); 27 | 28 | env.step(env_input).unwrap(); 29 | 30 | } 31 | 32 | fitness += env.fitness(); 33 | } 34 | 35 | fitness / num_simulations as f64 36 | }); 37 | 38 | system.set_configuration(Configuration { 39 | population_size: 100, 40 | max_generations: 500, 41 | stagnation_after: 50, 42 | node_cost: 1., 43 | connection_cost: 1., 44 | compatibility_threshold: 2., 45 | ..Default::default() 46 | }); 47 | 48 | system.add_hook(10, |generation, system| { 49 | println!( 50 | "Generation {}, best fitness is {}, {} species alive", 51 | generation, 52 | system.get_best().2, 53 | system.species_set.species().len() 54 | ); 55 | }); 56 | 57 | let (network, fitness) = system.start(); 58 | 59 | // println!( 60 | // "Found network with {} nodes and {} connections, fitness is {}", 61 | // network.nodes.len(), 62 | // network.connections.len(), 63 | // fitness 64 | // ); 65 |
66 | to_file("network.bin", &network); 67 | } 68 | 69 | fn main() { 70 | let param: String = std::env::args().skip(1).take(1).collect(); 71 | // Usage: `cargo run --release -- train` or `-- visualize`; any other argument does nothing. 72 | if param == "train" { 73 | train(); 74 | }; 75 | 76 | if param == "visualize" { 77 | gui::visualize(); 78 | }; 79 | } 80 | -------------------------------------------------------------------------------- /core/src/neat/speciation.rs: -------------------------------------------------------------------------------- 1 | use std::cell::RefCell; 2 | use std::collections::HashMap; 3 | use std::rc::Rc; 4 | 5 | use super::configuration::Configuration; 6 | use crate::genome::{Genome, GenomeId}; 7 | 8 | /// Holds the current genomes, the previous generation's genomes, and the recorded fitnesses 9 | #[derive(Debug)] 10 | pub struct GenomeBank { 11 | configuration: Rc>, 12 | genomes: HashMap, 13 | previous_genomes: HashMap, 14 | fitnesses: HashMap, 15 | } 16 | 17 | impl GenomeBank { 18 | pub fn new(configuration: Rc>) -> Self { 19 | GenomeBank { 20 | configuration, 21 | genomes: HashMap::new(), 22 | previous_genomes: HashMap::new(), 23 | fitnesses: HashMap::new(), 24 | } 25 | } 26 | 27 | /// Adds a new genome 28 | pub fn add_genome(&mut self, genome: Genome) { 29 | self.genomes.insert(genome.id(), genome); 30 | } 31 | 32 | /// Starts a new generation: current genomes become `previous_genomes`; genomes and fitnesses are reset 33 | pub fn clear(&mut self) { 34 | let mut new_bank = GenomeBank::new(self.configuration.clone()); 35 | new_bank.previous_genomes = std::mem::take(&mut self.genomes); 36 | // Moving (not cloning) the genomes is enough because `self` is overwritten below. 37 | *self = new_bank; 38 | } 39 | 40 | /// Returns a reference to the genomes 41 | pub fn genomes(&self) -> &HashMap { 42 | &self.genomes 43 | } 44 | 45 | pub fn previous_genomes(&self) -> &HashMap { 46 | &self.previous_genomes 47 | } 48 | 49 | /// Tracks the fitness of a particular genome 50 | pub fn mark_fitness(&mut self, genome_id: GenomeId, fitness: f64) { 51 | self.fitnesses.insert(genome_id, fitness); 52 | } 53 | 54 | /// Returns a reference to the fitnesses 55 | pub fn fitnesses(&self) -> &HashMap { 56 | &self.fitnesses 57 | } 58 | } 59 | 60 | #[cfg(test)] 61 |
mod tests { 62 | use super::*; 63 | 64 | #[test] 65 | fn can_add_genome() { 66 | let configuration: Rc> = Default::default(); 67 | let mut bank = GenomeBank::new(configuration); 68 | 69 | let genome = Genome::new(1, 1); 70 | bank.add_genome(genome); 71 | } 72 | 73 | #[test] 74 | fn can_mark_fitness() { 75 | let configuration: Rc> = Default::default(); 76 | let mut bank = GenomeBank::new(configuration); 77 | 78 | let genome = Genome::new(1, 1); 79 | bank.add_genome(genome.clone()); 80 | 81 | bank.mark_fitness(genome.id(), 1337.); 82 | } 83 | } 84 | -------------------------------------------------------------------------------- /examples/cart-pole/src/gui.rs: -------------------------------------------------------------------------------- 1 | use nannou::prelude::*; 2 | use nannou::ui::prelude::*; 3 | 4 | use neat_core::Network; 5 | use neat_environment_cart_pole::{CartPole, CartPoleConfiguration, Environment}; 6 | use neat_export::from_file; 7 | 8 | pub fn visualize() { 9 | nannou::app(model).update(update).view(view).run(); 10 | } 11 | 12 | struct Model { 13 | network: Option, 14 | env: CartPole, 15 | } 16 | 17 | fn model(app: &App) -> Model { 18 | app.new_window() 19 | .size(480, 320) 20 | .dropped_file(dropped_file) 21 | .key_released(key_released) 22 | .build() 23 | .unwrap(); 24 | 25 | Model { 26 | network: None, 27 | env: CartPole::new(), 28 | } 29 | } 30 | // Each frame: feed the cart-pole state to the loaded network (if any) and reset the environment when `step` fails. 31 | fn update(_app: &App, model: &mut Model, _update: Update) { 32 | if let Some(ref mut network) = model.network { 33 | let state = model.env.state(); 34 | let network_output = network.forward_pass(state.to_vec()); 35 | let env_input = f64::max(-1., f64::min(1., *network_output.first().unwrap())); 36 | 37 | if model.env.step(env_input).is_err() { 38 | model.env.reset(); 39 | } 40 | } 41 | } 42 | 43 | fn view(app: &App, model: &Model, frame: Frame) { 44 | let CartPoleConfiguration { length_pole, .. } = model.env.configuration; 45 | let [x, _, theta, _] = model.env.state(); 46 | 47 | let cart_x = 0.
+ x as f32 * 100.; 48 | let cart_width = 20.; 49 | let cart_height = 10.; 50 | let pole_height = length_pole as f32 * 100. * 2.; 51 | let pole_rotational_dx = theta.sin() as f32 * pole_height / 2.; 52 | let pole_rotational_dy = theta.cos() as f32 * pole_height / 2.; 53 | 54 | let draw = app.draw(); 55 | 56 | draw.background().rgb(0., 0., 0.); 57 | 58 | draw.rect() 59 | .x(cart_x) 60 | .y(0. - cart_height / 2.) 61 | .w_h(cart_width, cart_height); 62 | draw.rect() 63 | .x(cart_x + pole_rotational_dx) 64 | .y(pole_rotational_dy) 65 | .w_h(1., pole_height) 66 | .rotate(-theta as f32); 67 | 68 | // Write the result of our drawing to the window's frame. 69 | draw.to_frame(app, &frame).unwrap(); 70 | } 71 | 72 | fn dropped_file(_app: &App, model: &mut Model, path: std::path::PathBuf) { 73 | model.network = Some(from_file(path)); 74 | } 75 | 76 | fn key_released(_app: &App, model: &mut Model, key: Key) { 77 | let force = match key { 78 | Key::Left => -1., 79 | Key::Right => 1., 80 | _ => 0., 81 | }; 82 | 83 | model.env.apply_force_to_pole(force); 84 | } 85 | -------------------------------------------------------------------------------- /core/src/activation.rs: -------------------------------------------------------------------------------- 1 | use rand::distributions::{Distribution, Standard}; 2 | use rand::Rng; 3 | /// Activation function of a node; `Input` marks input nodes, which have no activation function. 4 | #[derive(Debug, Clone, PartialEq, Hash)] 5 | #[cfg_attr( 6 | feature = "network-serde", 7 | derive(serde::Serialize, serde::Deserialize) 8 | )] 9 | pub enum ActivationKind { 10 | Input, 11 | Tanh, 12 | Relu, 13 | Step, 14 | Logistic, 15 | Identity, 16 | Softsign, 17 | Sinusoid, 18 | Gaussian, 19 | BentIdentity, 20 | Bipolar, 21 | Inverse, 22 | SELU, 23 | } 24 | // Samples uniformly among the 12 kinds other than `Input` (values 0..12, with 11 falling through to `Inverse`). 25 | impl Distribution for Standard { 26 | fn sample(&self, rng: &mut R) -> ActivationKind { 27 | match rng.gen_range(0, 12) { 28 | 0 => ActivationKind::Tanh, 29 | 1 => ActivationKind::Relu, 30 | 2 => ActivationKind::Step, 31 | 3 => ActivationKind::Logistic, 32 | 4 => ActivationKind::Identity, 33 |
5 => ActivationKind::Softsign, 34 | 6 => ActivationKind::Sinusoid, 35 | 7 => ActivationKind::Gaussian, 36 | 8 => ActivationKind::BentIdentity, 37 | 9 => ActivationKind::Bipolar, 38 | 10 => ActivationKind::SELU, 39 | _ => ActivationKind::Inverse, 40 | } 41 | } 42 | } 43 | /// Applies `kind` to `x` (note: `Relu` is leaky, slope 0.01 for x <= 0); panics for `ActivationKind::Input`. 44 | pub fn activate(x: f64, kind: &ActivationKind) -> f64 { 45 | match kind { 46 | ActivationKind::Tanh => x.tanh(), 47 | ActivationKind::Relu => { 48 | if x > 0. { 49 | x 50 | } else { 51 | 0.01 * x 52 | } 53 | } 54 | ActivationKind::Step => { 55 | if x > 0. { 56 | 1. 57 | } else { 58 | 0. 59 | } 60 | } 61 | ActivationKind::Logistic => 1. / (1. + (-x).exp()), 62 | ActivationKind::Identity => x, 63 | ActivationKind::Softsign => x / (1. + x.abs()), 64 | ActivationKind::Sinusoid => x.sin(), 65 | ActivationKind::Gaussian => (-x.powi(2)).exp(), 66 | ActivationKind::BentIdentity => (((x.powi(2) + 1.).sqrt() - 1.) / 2.) + x, 67 | ActivationKind::Bipolar => { 68 | if x > 0. { 69 | 1. 70 | } else { 71 | -1. 72 | } 73 | } 74 | ActivationKind::Inverse => 1. - x, 75 | ActivationKind::SELU => { 76 | let alpha = 1.6732632423543772; 77 | let scale = 1.05070098735548; 78 | 79 | let fx = if x > 0.
{ x } else { alpha * x.exp() - alpha }; 80 | 81 | fx * scale 82 | } 83 | ActivationKind::Input => panic!("Input nodes have no activation function"), 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /core/src/reporting.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | use std::time::Instant; 3 | 4 | use crate::genome::Genome; 5 | 6 | // TODO 7 | pub type Species = usize; 8 | pub struct Generation { 9 | number: usize, 10 | started: Instant, 11 | } 12 | pub struct Population<'system> { 13 | genomes: Vec, 14 | fitnesses: HashMap<&'system Genome, f64>, 15 | } 16 | 17 | pub trait Reporter { 18 | fn on_generation_start(generation: Generation) {} 19 | fn on_generation_end(generation: Generation, population: Population, species: &[Species]) {} 20 | fn on_evaluation_end(population: Population, species: &[Species], best_genome: &Genome) {} 21 | fn on_reproduction_end(population: Population, species: &[Species]) {} 22 | fn on_extinction() {} 23 | fn on_solution_found(generation: Generation, population: Population, best_genome: &Genome) {} 24 | fn on_species_stagnant(species_id: usize, species: &[Species]) {} 25 | } 26 | 27 | struct StdoutReporter; 28 | 29 | impl Reporter for StdoutReporter { 30 | fn on_generation_start(generation: Generation) { 31 | println!("Running generation {}", generation.number); 32 | } 33 | 34 | fn on_generation_end(generation: Generation, population: Population, species: &[Species]) { 35 | println!( 36 | "Generation {} done in {} seconds with {} members in {} species", 37 | generation.number, 38 | generation.started.elapsed().as_secs(), 39 | population.genomes.len(), 40 | species.len() 41 | ); 42 | } 43 | // NOTE: the average below is NaN when `population.fitnesses` is empty (0.0 / 0.0). 44 | fn on_evaluation_end(population: Population, species: &[Species], best_genome: &Genome) { 45 | let average_fitness = population 46 | .fitnesses 47 | .iter() 48 | .map(|(_, fitness)| fitness) 49 | .sum::() 50 | / population.fitnesses.len() as f64; 51 | 52
| println!( 53 | "Evaluated members have an average fitness of {}, best genome has {}", 54 | average_fitness, 55 | population.fitnesses.get(best_genome).unwrap() 56 | ); 57 | } 58 | 59 | fn on_extinction() { 60 | println!("All species are extinct"); 61 | } 62 | // Panics if `best_genome` has no entry in `population.fitnesses` (see the `unwrap` below). 63 | fn on_solution_found(generation: Generation, population: Population, best_genome: &Genome) { 64 | println!( 65 | "Best genome found in generation {} and has fitness {}", 66 | generation.number, 67 | population.fitnesses.get(best_genome).unwrap() 68 | ); 69 | } 70 | 71 | fn on_species_stagnant(species_id: usize, species: &[Species]) { 72 | println!("Removing stagnant species {}", species_id); 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # neat-rs 2 | 3 | A neuroevolution framework written in rust. 4 | 5 | ## How to use 6 | 7 | Here is how to train a cart pole balancing neural network, available in the 8 | `examples/` dir.
9 | 10 | `NEAT::new` takes only 3 parameters: 11 | 12 | - Number of input neurons 13 | - Number of output neurons 14 | - A fitness function that receives a network and returns an `f64` score 15 | 16 | ```rust 17 | let mut system = NEAT::new(4, 1, |network| { 18 | let num_simulations = 10; 19 | let max_steps = 1000; 20 | let mut env = CartPole::new(); 21 | 22 | let mut steps_done = 0; 23 | let mut fitness = 0.; 24 | 25 | for _ in 0..num_simulations { 26 | env.reset(); 27 | 28 | for _ in 0..max_steps { 29 | if env.done() { 30 | break; 31 | } 32 | 33 | let state = env.state(); 34 | let network_output = network.forward_pass(state.to_vec()); 35 | let env_input = f64::max(-1., f64::min(1., *network_output.first().unwrap())); 36 | 37 | env.step(env_input).unwrap(); 38 | steps_done += 1; 39 | } 40 | 41 | fitness += env.fitness(); 42 | } 43 | 44 | fitness / num_simulations as f64 45 | }); 46 | 47 | system.set_configuration(Configuration { 48 | population_size: 100, 49 | max_generations: 500, 50 | stagnation_after: 50, 51 | node_cost: 1., 52 | connection_cost: 1., 53 | compatibility_threshold: 2., 54 | ..Default::default() 55 | }); 56 | 57 | system.add_hook(10, |generation, system| { 58 | println!( 59 | "Generation {}, best fitness is {}, {} species alive", 60 | generation, 61 | system.get_best().2, 62 | system.species_set.species().len() 63 | ); 64 | }); 65 | 66 | let (network, fitness) = system.start(); 67 | ``` 68 | 69 | To start training, go to the `examples/cart-pole/` directory and run the 70 | command below. It will produce a `network.bin` file that contains the neural network 71 | "recipe". 72 | 73 | ```bash 74 | cargo run --release -- train 75 | ``` 76 | 77 | After training, you can watch the neural network balance the pole by running 78 | the command below and then dragging `network.bin` into the window. That will 79 | load and instantiate the neural network. You can apply "wind" with the left and right arrow keys.
80 | 81 | ```bash 82 | cargo run --release -- visualize 83 | ``` 84 | 85 | ## Things I'd like to add (but probably won't due to the lack of time) 86 | 87 | - Two pole balancing task (started it in a different branch) 88 | - Recurrent connections 89 | - Extend the `system` so it works with both `f32` and `f64` (might improve performance) 90 | - HyperNEAT 91 | - FS NEAT (feature selection) 92 | 93 | ## Is this useful? 94 | 95 | If somebody finds this code useful, or is even willing to fund further 96 | development, I'd be happy to talk to you. You can reach me by [sending me a message 97 | on Twitter](https://twitter.com/SGolemac). 98 | -------------------------------------------------------------------------------- /core/src/graph.rs: -------------------------------------------------------------------------------- 1 | trait Node {} 2 | // NOTE(review): core/src/lib.rs declares no `mod graph;`, so this file (and its tests) appears not to be compiled; confirm before relying on it. 3 | trait Edge { 4 | fn from() -> usize; 5 | fn to() -> usize; 6 | } 7 | 8 | trait Graph { 9 | type NodeType: Node; 10 | type EdgeType: Edge; 11 | 12 | fn nodes(&mut self) -> &mut [Self::NodeType]; 13 | fn edges(&mut self) -> &mut [Self::EdgeType]; 14 | 15 | fn add_node(&mut self, node: Self::NodeType); 16 | fn add_edge(&mut self, edge: Self::EdgeType); 17 | 18 | fn filter_modify_nodes bool, F: Fn(&mut Self::NodeType)>( 19 | &mut self, 20 | p: P, 21 | f: F, 22 | ); 23 | } 24 | 25 | #[cfg(test)] 26 | mod tests { 27 | use super::*; 28 | 29 | #[derive(Debug, PartialEq)] 30 | struct ExampleNode(usize); 31 | 32 | impl Node for ExampleNode {} 33 | 34 | #[derive(Debug)] 35 | struct ExampleEdge; 36 | 37 | impl Edge for ExampleEdge { 38 | fn from() -> usize { 39 | 0 40 | } 41 | 42 | fn to() -> usize { 43 | 0 44 | } 45 | } 46 | 47 | #[derive(Debug)] 48 | struct ExampleGraph { 49 | nodes: Vec, 50 | edges: Vec, 51 | } 52 | 53 | impl ExampleGraph { 54 | pub fn new() -> Self { 55 | ExampleGraph { 56 | nodes: vec![], 57 | edges: vec![], 58 | } 59 | } 60 | } 61 | 62 | impl Graph for ExampleGraph { 63 | type NodeType = ExampleNode; 64 | type EdgeType
= ExampleEdge; 65 | 66 | fn nodes(&mut self) -> &mut [ExampleNode] { 67 | &mut self.nodes[..] 68 | } 69 | 70 | fn edges(&mut self) -> &mut [ExampleEdge] { 71 | &mut self.edges[..] 72 | } 73 | 74 | fn add_node(&mut self, node: ExampleNode) { 75 | self.nodes.push(node); 76 | } 77 | 78 | fn add_edge(&mut self, edge: ExampleEdge) { 79 | self.edges.push(edge); 80 | } 81 | 82 | fn filter_modify_nodes bool, F: Fn(&mut Self::NodeType)>( 83 | &mut self, 84 | p: P, 85 | f: F, 86 | ) { 87 | self.nodes().iter_mut().filter(|n| p(n)).for_each(|n| f(n)); 88 | } 89 | } 90 | 91 | #[test] 92 | fn init_graph() { 93 | ExampleGraph::new(); 94 | } 95 | 96 | #[test] 97 | fn add_node() { 98 | let mut graph = ExampleGraph::new(); 99 | 100 | graph.add_node(ExampleNode(1)); 101 | } 102 | 103 | #[test] 104 | fn add_edge() { 105 | let mut graph = ExampleGraph::new(); 106 | 107 | graph.add_edge(ExampleEdge); 108 | } 109 | 110 | #[test] 111 | fn filter_and_modify_nodes() { 112 | let mut graph = ExampleGraph::new(); 113 | 114 | graph.add_node(ExampleNode(1)); 115 | graph.add_node(ExampleNode(2)); 116 | graph.add_node(ExampleNode(3)); 117 | 118 | graph.filter_modify_nodes(|n| n.0 > 1, |n| n.0 = 10); 119 | 120 | let nodes = graph.nodes(); 121 | 122 | assert_eq!(nodes[0].0, 1); 123 | assert_eq!(nodes[1].0, 10); 124 | assert_eq!(nodes[2].0, 10); 125 | } 126 | } 127 | -------------------------------------------------------------------------------- /core/src/neat/configuration.rs: -------------------------------------------------------------------------------- 1 | use std::default::Default; 2 | 3 | use crate::mutations::MutationKind; 4 | 5 | /// Holds configuration options of the whole NEAT process 6 | #[derive(Debug)] 7 | pub struct Configuration { 8 | /// The generation limit for the evolution process 9 | pub max_generations: usize, 10 | 11 | /// The maximum number of genomes in each generation 12 | pub population_size: usize, 13 | 14 | /// The ratio of champion individuals that are copied to the
next generation 15 | pub elitism: f64, 16 | 17 | /// The minimum number of species that must remain after stagnated ones are removed 18 | pub elitism_species: usize, 19 | 20 | /// How many generations without improvement in a species' mean fitness count as stagnation 21 | pub stagnation_after: usize, 22 | 23 | /// The fitness cost of every node in the genome 24 | pub node_cost: f64, 25 | 26 | /// The fitness cost of every connection in the genome 27 | pub connection_cost: f64, 28 | 29 | /// The mutation rate of offspring 30 | pub mutation_rate: f64, 31 | 32 | /// The ratio of genomes that will survive to the next generation 33 | pub survival_ratio: f64, 34 | 35 | /// The types of mutations available and their sampling weights 36 | pub mutation_kinds: Vec<(MutationKind, usize)>, 37 | 38 | /// The process will stop if the fitness goal is reached 39 | pub fitness_goal: Option, 40 | 41 | /* 42 | * Genomic distance during speciation 43 | */ 44 | /// Controls how much connections can affect distance 45 | pub distance_connection_disjoint_coefficient: f64, // multiplies the number of connections present in only one genome 46 | pub distance_connection_weight_coeficcient: f64, // multiplies |weight difference| of matching connections (field name misspelled; kept for compatibility) 47 | pub distance_connection_disabled_coefficient: f64, // added when a matching connection is disabled in only one genome 48 | 49 | /// Controls how much nodes can affect distance 50 | pub distance_node_bias_coefficient: f64, // multiplies |bias difference| of nodes at the same index 51 | pub distance_node_activation_coefficient: f64, // added when nodes at the same index use different activations 52 | pub distance_node_aggregation_coefficient: f64, // added when nodes at the same index use different aggregations 53 | 54 | /// A limit on how distant two genomes can be to belong to the same species 55 | pub compatibility_threshold: f64, 56 | } 57 | 58 | impl Default for Configuration { 59 | fn default() -> Self { 60 | Configuration { 61 | max_generations: 1000, 62 | population_size: 150, 63 | elitism: 0.1, 64 | elitism_species: 3, 65 | stagnation_after: 50, 66 | node_cost: 0., 67 | connection_cost: 0., 68 | mutation_rate: 0.5, 69 | survival_ratio: 0.5, 70 | mutation_kinds: default_mutation_kinds(), 71 | fitness_goal: None, 72 | distance_connection_disjoint_coefficient: 1., 73 | distance_connection_weight_coeficcient: 0.5, 74 |
distance_connection_disabled_coefficient: 0.5, 75 | distance_node_bias_coefficient: 0.33, 76 | distance_node_activation_coefficient: 0.33, 77 | distance_node_aggregation_coefficient: 0.33, 78 | compatibility_threshold: 3., 79 | } 80 | } 81 | } 82 | 83 | pub fn default_mutation_kinds() -> Vec<(MutationKind, usize)> { // every mutation kind, each with sampling weight 10 84 | use MutationKind::*; 85 | 86 | vec![ 87 | (AddConnection, 10), 88 | (RemoveConnection, 10), 89 | (AddNode, 10), 90 | (RemoveNode, 10), 91 | (ModifyWeight, 10), 92 | (ModifyBias, 10), 93 | (ModifyActivation, 10), 94 | (ModifyAggregation, 10), 95 | ] 96 | } 97 | -------------------------------------------------------------------------------- /core/src/aggregations.rs: -------------------------------------------------------------------------------- 1 | use rand::distributions::{Distribution, Standard}; 2 | use rand::Rng; 3 | 4 | pub fn aggregate(kind: &Aggregation, components: &[f64]) -> f64 { // reduces `components` with the function selected by `kind` 5 | use Aggregation::*; 6 | 7 | let func: fn(components: &[f64]) -> f64 = match kind { 8 | Product => product, 9 | Sum => sum, 10 | Max => max, 11 | Min => min, 12 | MaxAbs => maxabs, 13 | Median => median, 14 | Mean => mean, 15 | }; 16 | 17 | func(components) 18 | } 19 | 20 | #[derive(Debug, Clone, PartialEq, Hash)] 21 | #[cfg_attr( 22 | feature = "network-serde", 23 | derive(serde::Serialize, serde::Deserialize) 24 | )] 25 | pub enum Aggregation { 26 | Product, 27 | Sum, 28 | Max, 29 | Min, 30 | MaxAbs, 31 | Median, 32 | Mean, 33 | } 34 | 35 | impl Distribution for Standard { 36 | fn sample(&self, rng: &mut R) -> Aggregation { 37 | use Aggregation::*; 38 | 39 | match rng.gen_range(0, 7) { // uniform over the 7 variants; `_` covers 6 => Mean 40 | 0 => Product, 41 | 1 => Sum, 42 | 2 => Max, 43 | 3 => Min, 44 | 4 => MaxAbs, 45 | 5 => Median, 46 | _ => Mean, 47 | } 48 | } 49 | } 50 | 51 | fn product(components: &[f64]) -> f64 { 52 | components 53 | .iter() 54 | .fold(1., |result, current| result * current) // empty input yields 1. 55 | } 56 | 57 | fn sum(components: &[f64]) -> f64 { 58 | components.iter().sum() // empty input yields 0. 59 | } 60 | 61 | fn
max(components: &[f64]) -> f64 { // largest component (NaN is ignored); 0 for empty input 62 | if components.is_empty() { 63 | return 0.; // like `median`: no inputs aggregate to 0, not f64::MIN 64 | } 65 | components.iter().cloned().fold(f64::MIN, f64::max) 66 | } 67 | 68 | fn min(components: &[f64]) -> f64 { // smallest component (NaN is ignored); 0 for empty input 69 | if components.is_empty() { 70 | return 0.; // like `median`: no inputs aggregate to 0, not f64::MAX 71 | } 72 | components.iter().cloned().fold(f64::MAX, f64::min) 73 | } 74 | 75 | fn maxabs(components: &[f64]) -> f64 { // largest absolute value; 0 for empty input (via `max`) 76 | let abs_components: Vec = components.iter().map(|component| component.abs()).collect(); 77 | max(&abs_components) 78 | } 79 | 80 | fn median(components: &[f64]) -> f64 { 81 | use std::cmp::Ordering; 82 | 83 | if components.is_empty() { 84 | return 0.; 85 | } 86 | 87 | let mut sorted = components.to_vec(); 88 | // `partial_cmp` yields Equal for equal values; the old `a < b` check 89 | // answered Greater in both directions, breaking the total-order 90 | // contract `sort_by` relies on. NaN has no order and is treated as 91 | // Equal. 92 | sorted.sort_by(|a, b| { 93 | a.partial_cmp(b).unwrap_or(Ordering::Equal) 94 | }); 95 | 96 | let length = sorted.len(); 97 | let is_length_even = length % 2 == 0; 98 | let median_index = if is_length_even { // even length: the lower of the two middle values 99 | length / 2 - 1 100 | } else { 101 | length / 2 102 | }; 103 | 104 | *sorted.get(median_index).unwrap() 105 | } 106 | 107 | fn mean(components: &[f64]) -> f64 { 108 | let sum: f64 = components.iter().sum(); 109 | if components.is_empty() { 0. } else { sum / components.len() as f64 } // empty: 0 like `median`, instead of 0/0 = NaN 110 | } 111 | 112 | #[cfg(test)] 113 | mod tests { 114 | use super::*; 115 | 116 | #[test] 117 | fn product_works() { 118 | let components = vec![1., 2., 3., 4.]; 119 | 120 | assert!((product(&components) - 24.).abs() < f64::EPSILON); 121 | } 122 | 123 | #[test] 124 | fn sum_works() { 125 | let components = vec![1., 2., 3., 4.]; 126 | 127 | assert!((sum(&components) - 10.).abs() < f64::EPSILON); 128 | } 129 | 130 | #[test] 131 | fn max_works() { 132 | let components = vec![1., 2., 3., 4.]; 133 | 134 | assert!((max(&components) - 4.).abs() < f64::EPSILON); 135 | } 136 | 137 | #[test] 138 | fn min_works() { 139 | let components = vec![1., 2., 3., 4.]; 140 | 141 | assert!((min(&components) - 1.).abs() < f64::EPSILON); 142 | } 143 | 144 | #[test] 145 | fn maxabs_works() { 146 | let components = vec![-5., -3., 1., 2.,
3., 4.]; 147 | 148 | assert!((maxabs(&components) - 5.).abs() < f64::EPSILON); 149 | } 150 | 151 | #[test] 152 | fn median_works() { 153 | let components = vec![3., -3., 4., -5., 1., 2.]; // sorted: [-5, -3, 1, 2, 3, 4]; even length takes the lower middle value 154 | 155 | assert!((median(&components) - 1.).abs() < f64::EPSILON); 156 | } 157 | 158 | #[test] 159 | fn mean_works() { 160 | let components = vec![1., 2., 3., 4.]; 161 | 162 | assert!((mean(&components) - 2.5).abs() < f64::EPSILON); 163 | } 164 | } 165 | -------------------------------------------------------------------------------- /core/src/genome/crossover.rs: -------------------------------------------------------------------------------- 1 | use rand::random; 2 | 3 | use super::{ConnectionGene, Genome, NodeGene}; 4 | 5 | pub fn crossover(a: (&Genome, f64), b: (&Genome, f64)) -> Option { // `a`/`b` are (genome, fitness); None when input/output counts differ or the child's `node_order()` is None 6 | if (a.0.inputs != b.0.inputs) || (a.0.outputs != b.0.outputs) { 7 | return None; 8 | } 9 | 10 | let mut parent_a = a.0.clone(); 11 | let mut fitness_a = a.1; 12 | 13 | let mut parent_b = b.0.clone(); 14 | let mut fitness_b = b.1; 15 | 16 | // Parent A will always be the fitter one (on a tie, `a` stays parent A) 17 | if fitness_a < fitness_b { 18 | std::mem::swap(&mut parent_a, &mut parent_b); 19 | std::mem::swap(&mut fitness_a, &mut fitness_b); 20 | } 21 | 22 | let mut child = Genome::empty(parent_a.inputs, parent_a.outputs); 23 | 24 | let child_connection_genes: Vec = parent_a 25 | .connection_genes 26 | .iter() 27 | .map(|connection| { 28 | let maybe_counterpart_connection = parent_b 29 | .connection_genes 30 | .iter() 31 | .find(|cb| cb.innovation_number() == connection.innovation_number()); 32 | 33 | // Matching genes come from either parent at random; unmatched ones from parent A 34 | let chosen_connection = 35 | if let Some(counterpart_connection) = maybe_counterpart_connection { 36 | if random::() < 0.5 { 37 | connection 38 | } else { 39 | counterpart_connection 40 | } 41 | } else { 42 | connection 43 | }; 44 | 45 | /* 46 | * Decides whether the new connection will be disabled: 47 | * - disabled in both parents, 75% chance it will be disabled 48 | * -
enabled in both parents, it will be enabled 49 | * - disabled in one parent, 50% chance it will stay disabled; a gene without a counterpart keeps parent A's flag 50 | */ 51 | let new_disabled = if let Some(counterpart_connection) = maybe_counterpart_connection { 52 | match (connection.disabled, counterpart_connection.disabled) { 53 | (true, true) => random::() < 0.75, 54 | (false, false) => false, 55 | _ => random::() < 0.5, 56 | } 57 | } else { 58 | connection.disabled 59 | }; 60 | 61 | let mut new_connection = chosen_connection.clone(); 62 | new_connection.disabled = new_disabled; 63 | 64 | new_connection 65 | }) 66 | .collect(); 67 | 68 | let required_node_count = 1 + child_connection_genes // NOTE(review): sized from connection endpoints only, so nodes above the highest referenced index (e.g. an output whose connections were all removed) are dropped and a connection-less child gets a single node — confirm intended 69 | .iter() 70 | .fold(0, |max, c| usize::max(usize::max(max, c.from), c.to)); 71 | 72 | let child_node_genes: Vec = (0..required_node_count) 73 | .map( 74 | |i| match (parent_a.node_genes.get(i), parent_b.node_genes.get(i)) { 75 | (Some(a), Some(b)) => { 76 | if random::() < 0.5 { 77 | a 78 | } else { 79 | b 80 | } 81 | } 82 | (Some(a), None) => a, 83 | (None, Some(b)) => b, 84 | _ => panic!("Node selection out of bounds"), 85 | }, 86 | ) 87 | .cloned() 88 | .collect(); 89 | 90 | child.connection_genes = child_connection_genes; 91 | child.node_genes = child_node_genes; 92 | 93 | child.node_order().and(Some(child)) // Some(child) only when `node_order()` succeeds 94 | } 95 | 96 | #[cfg(test)] 97 | mod tests { 98 | use super::*; 99 | 100 | #[test] 101 | fn crossover_success() { 102 | let a = Genome::new(2, 2); 103 | let b = Genome::new(2, 2); 104 | 105 | let maybe_child = crossover((&a, 1.), (&b, 2.)); 106 | assert!(maybe_child.is_some()); 107 | } 108 | 109 | #[test] 110 | fn crossover_outputs_wrong() { 111 | let a = Genome::new(2, 3); 112 | let b = Genome::new(2, 2); 113 | 114 | let maybe_child = crossover((&a, 1.), (&b, 2.)); 115 | assert!(maybe_child.is_none()); 116 | } 117 | 118 | #[test] 119 | fn crossover_inputs_wrong() { 120 | let a = Genome::new(3, 2); 121 | let b = Genome::new(2, 2); 122 | 123 | let maybe_child = crossover((&a, 1.), (&b, 2.)); 124 |
assert!(maybe_child.is_none()); 125 | } 126 | } 127 | -------------------------------------------------------------------------------- /environments/cart-pole/src/lib.rs: -------------------------------------------------------------------------------- 1 | use rand::{random, thread_rng, Rng}; 2 | 3 | pub use neat_environment::Environment; 4 | use utils::*; 5 | 6 | mod utils; 7 | 8 | pub struct CartPoleConfiguration { // physical constants and failure limits of the simulation 9 | pub gravity: f64, 10 | pub mass_cart: f64, 11 | pub mass_pole: f64, 12 | pub length_pole: f64, 13 | pub time_step: f64, 14 | 15 | pub limit_position: f64, 16 | pub limit_angle_radians: f64, 17 | } 18 | 19 | impl Default for CartPoleConfiguration { 20 | fn default() -> Self { 21 | CartPoleConfiguration { 22 | gravity: 9.8, 23 | mass_cart: 1.0, 24 | mass_pole: 0.1, 25 | length_pole: 0.5, 26 | time_step: 1. / 60., 27 | 28 | limit_position: 2.4, 29 | limit_angle_radians: to_radians(45.), 30 | } 31 | } 32 | } 33 | 34 | pub struct CartPole { 35 | pub configuration: CartPoleConfiguration, 36 | 37 | x: f64, 38 | theta: f64, 39 | dx: f64, 40 | dtheta: f64, 41 | t: f64, // elapsed simulated time 42 | xacc: f64, 43 | tacc: f64, 44 | fitness: f64, 45 | 46 | finished: bool, 47 | } 48 | 49 | impl CartPole { 50 | pub fn new() -> Self { // random start: x and theta within half their limits, dx and dtheta in [-1, 1) 51 | let configuration: CartPoleConfiguration = Default::default(); 52 | let mut rng = thread_rng(); 53 | 54 | let x = 55 | rng.gen_range(-0.5 * configuration.limit_position..0.5 * configuration.limit_position); 56 | let theta = rng.gen_range( 57 | -0.5 * configuration.limit_angle_radians..0.5 * configuration.limit_angle_radians, 58 | ); 59 | let dx = rng.gen_range(-1f64..1f64); 60 | let dtheta = rng.gen_range(-1f64..1f64); 61 | 62 | CartPole { 63 | configuration, 64 | 65 | x, 66 | theta, 67 | dx, 68 | dtheta, 69 | t: 0., 70 | xacc: 0., 71 | tacc: 0., 72 | fitness: 0., 73 | 74 | finished: false, 75 | } 76 | } 77 | 78 | fn continuous_actuator_force(input: f64) -> f64 { // scales an input in [-1, 1] to a force of magnitude up to 10 79 | input * 10.
80 | } 81 | 82 | fn continuous_noisy_actuator_force(input: f64) -> f64 { // NOTE(review): not called anywhere in this crate 83 | (input + random::() * 0.75) * 10. 84 | } 85 | 86 | fn measure_fitness(&mut self) { 87 | let x_component = f64::max(0., self.configuration.limit_position - self.x.abs()); 88 | let theta_component = f64::max( 89 | 0., 90 | self.configuration.limit_angle_radians - self.theta.abs(), 91 | ); 92 | 93 | let step_fitness = 1. - x_component * theta_component; // NOTE(review): the product is largest for a centred cart and upright pole, so this step reward shrinks (and can go negative before squaring) in good states — confirm intended 94 | 95 | self.fitness += step_fitness.powi(2); 96 | } 97 | 98 | fn check_finished(&mut self) { // the episode ends once position or angle exceeds its limit 99 | if self.x.abs() > self.configuration.limit_position 100 | || self.theta.abs() > self.configuration.limit_angle_radians 101 | { 102 | self.finished = true; 103 | } 104 | } 105 | 106 | pub fn apply_force_to_pole(&mut self, force: f64) { // adds `force` directly to the pole's angular velocity 107 | self.dtheta += force; 108 | } 109 | } 110 | 111 | impl Environment for CartPole { 112 | type State = [f64; 4]; // [x, dx, theta, dtheta] 113 | type Input = f64; 114 | 115 | fn state(&self) -> Self::State { 116 | [self.x, self.dx, self.theta, self.dtheta] 117 | } 118 | 119 | fn step(&mut self, input: Self::Input) -> Result<(), ()> { // panics for input outside [-1, 1]; Err(()) once the episode is done 120 | if input > 1. || input < -1.
{ 121 | panic!("Input must be between 1 and -1"); 122 | } 123 | if self.done() { 124 | return Err(()); 125 | } 126 | 127 | let force = CartPole::continuous_actuator_force(input); 128 | let xacc_current = self.xacc; 129 | let tacc_current = self.tacc; 130 | let mass_all = self.configuration.mass_pole + self.configuration.mass_cart; 131 | 132 | self.x += self.configuration.time_step * self.dx // position/angle: v*dt + 0.5*a*dt^2 using the previous step's accelerations 133 | + 0.5 * xacc_current * self.configuration.time_step.powi(2); 134 | self.theta += self.configuration.time_step * self.dtheta 135 | + 0.5 * tacc_current * self.configuration.time_step.powi(2); 136 | 137 | let theta_sin = self.theta.sin(); 138 | let theta_cos = self.theta.cos(); 139 | 140 | self.tacc = (self.configuration.gravity * theta_sin // new angular and linear accelerations at the updated angle 141 | + theta_cos 142 | * (-force 143 | - self.configuration.mass_pole 144 | * self.configuration.length_pole 145 | * self.dtheta.powi(2) 146 | * theta_sin) 147 | / mass_all) 148 | / (self.configuration.length_pole 149 | * (4. / 3. - self.configuration.mass_pole * theta_cos.powi(2) / mass_all)); 150 | self.xacc = (force 151 | + self.configuration.mass_pole 152 | * self.configuration.length_pole 153 | * (self.dtheta.powi(2) * theta_sin - self.tacc * theta_cos)) 154 | / mass_all; 155 | 156 | self.dx += 0.5 * (xacc_current + self.xacc) * self.configuration.time_step; // velocities use the mean of old and new accelerations 157 | self.dtheta += 0.5 * (tacc_current + self.tacc) * self.configuration.time_step; 158 | 159 | self.t += self.configuration.time_step; 160 | 161 | self.measure_fitness(); 162 | self.check_finished(); 163 | 164 | Ok(()) 165 | } 166 | 167 | fn done(&self) -> bool { 168 | self.finished 169 | } 170 | 171 | fn fitness(&self) -> f64 { 172 | self.fitness 173 | } 174 | 175 | fn reset(&mut self) { // NOTE(review): `CartPole::new()` uses the default configuration, so a customised `configuration` is lost on reset 176 | *self = CartPole::new(); 177 | } 178 | 179 | fn render(&self) { 180 | unimplemented!(); 181 | } 182 | } 183 | 184 | #[cfg(test)] 185 | mod tests { 186 | use super::*; 187 | 188 | #[test] 189 | fn misc() { 190 | let mut env = CartPole::new(); 191 | 192 | for _ in 0..5 { 193 |
env.step(1.).unwrap(); 194 | 195 | let state = env.state(); 196 | dbg!(state); 197 | } 198 | 199 | let fitness = env.fitness(); 200 | 201 | dbg!(fitness); 202 | } 203 | } 204 | -------------------------------------------------------------------------------- /core/src/network.rs: -------------------------------------------------------------------------------- 1 | use crate::activation::*; 2 | use crate::aggregations::aggregate; 3 | use crate::connection::*; 4 | use crate::genome::Genome; 5 | use crate::node::*; 6 | 7 | #[derive(Debug)] 8 | #[cfg_attr( 9 | feature = "network-serde", 10 | derive(serde::Serialize, serde::Deserialize) 11 | )] 12 | pub struct Network { 13 | pub input_count: usize, 14 | pub output_count: usize, 15 | pub nodes: Vec, 16 | pub connections: Vec, 17 | node_calculation_order: Vec, 18 | } 19 | 20 | impl Network { 21 | /// Runs one pass through the network and returns the output node values. 22 | /// 23 | /// Nodes are evaluated in `node_calculation_order`: input node `i` takes 24 | /// `inputs[i]`, every other node aggregates its weighted incoming values, 25 | /// adds its bias and applies its activation function. Output values are 26 | /// returned in node index order. 27 | /// 28 | /// # Panics 29 | /// 30 | /// Panics if `inputs` has no entry at an input node's index, if a node reads 31 | /// an incoming node that has no value yet, or if an output node has no value. 32 | pub fn forward_pass(&mut self, inputs: Vec) -> Vec { 33 | for i in &self.node_calculation_order { 34 | let node = self.nodes.get(*i).unwrap(); 35 | 36 | if matches!(node.kind, NodeKind::Input) { 37 | let input_value = *inputs 38 | .get(*i) 39 | .expect("Inputs need to be of the same length as the number of input nodes"); 40 | self.nodes.get_mut(*i).unwrap().value = Some(input_value); 41 | } else {
42 | let components: Vec = self 43 | .connections 44 | .iter() 45 | .filter(|c| c.to == *i) 46 | .map(|c| { 47 | let incoming_value = self.nodes.get(c.from).unwrap().value.unwrap(); 48 | incoming_value * c.weight 49 | }) 50 | .collect(); 51 | 52 | let aggregated = aggregate(&node.aggregation, &components); 53 | let aggregated_with_bias = aggregated + node.bias; 54 | 55 | self.nodes.get_mut(*i).unwrap().value = 56 | Some(activate(aggregated_with_bias, &node.activation)); 57 | } 58 | } 59 | 60 | self.nodes 61 | .iter() 62 | .filter(|n| matches!(n.kind, NodeKind::Output)) 63 | .map(|n| n.value.unwrap()) 64 | .collect() 65 | } 66 | } 67 | 68 | impl From<&Genome> for Network { 69 | /// Builds a network from a genome, skipping disabled connections. 70 | /// 71 | /// Panics if `g.node_order()` returns `None`. 72 | fn from(g: &Genome) -> Self { 73 | let nodes: Vec = g.nodes().iter().map(From::from).collect(); 74 | let connections: Vec = g 75 | .connections() 76 | .iter() 77 | .filter(|c| !c.disabled) 78 | .map(From::from) 79 | .collect(); 80 | 81 | Network { 82 | input_count: g.input_count(), 83 | output_count: g.output_count(), 84 | nodes, 85 | connections, 86 | node_calculation_order: g.node_order().unwrap(), 87 | } 88 | } 89 | }
90 | 91 | #[cfg(test)] 92 | mod tests { 93 | use super::*; 94 | 95 | #[test] 96 | fn init_network() { 97 | let g = Genome::new(1, 1); 98 | Network::from(&g); 99 | } 100 | 101 | #[test] 102 | fn forward_pass() { 103 | let g = Genome::new(2, 1); 104 | let mut n = Network::from(&g); 105 | 106 | let inputs: Vec> = vec![vec![0., 0.], vec![0., 1.], vec![1., 0.], vec![1., 1.]]; 107 | 108 | for i in inputs { 109 | let o = n.forward_pass(i.clone()); 110 | 111 | dbg!(i, o); 112 | } 113 | } 114 | } 115 | -------------------------------------------------------------------------------- /core/src/speciation/distance.rs: -------------------------------------------------------------------------------- 1 | use std::cell::RefCell; 2 | use std::collections::HashMap; 3 | use std::rc::Rc; 4 | 5 | use crate::Configuration; 6 | use crate::{ConnectionGene, Genome}; 7 | 8 | type DistanceKey = String; 9 | pub struct GenomicDistanceCache { 10 | configuration: Rc>, 11 | cache: HashMap, 12 | } 13 | 14 | impl GenomicDistanceCache { 15 | pub fn new(configuration: Rc>) -> Self { 16 | GenomicDistanceCache { 17 | configuration,
18 | cache: HashMap::new(), 19 | } 20 | } 21 | 22 | pub fn get(&mut self, a: &Genome, b: &Genome) -> f64 { // memoized: get(a, b) and get(b, a) share one cache entry 23 | let distance_key = GenomicDistanceCache::make_key(a, b); 24 | 25 | if let Some(distance) = self.cache.get(&distance_key) { 26 | *distance 27 | } else { 28 | let distance = self.distance(a, b); 29 | self.cache.insert(distance_key, distance); 30 | 31 | distance 32 | } 33 | } 34 | 35 | fn distance(&self, a: &Genome, b: &Genome) -> f64 { // node differences plus connection differences normalized by the larger connection count 36 | let ( 37 | distance_connection_disjoint_coefficient, 38 | distance_connection_weight_coeficcient, 39 | distance_connection_disabled_coefficient, 40 | distance_node_bias_coefficient, 41 | distance_node_activation_coefficient, 42 | distance_node_aggregation_coefficient, 43 | ) = { 44 | let conf = self.configuration.borrow(); 45 | 46 | ( 47 | conf.distance_connection_disjoint_coefficient, 48 | conf.distance_connection_weight_coeficcient, 49 | conf.distance_connection_disabled_coefficient, 50 | conf.distance_node_bias_coefficient, 51 | conf.distance_node_activation_coefficient, 52 | conf.distance_node_aggregation_coefficient, 53 | ) 54 | }; 55 | 56 | let mut distance = 0.; 57 | 58 | let max_connection_genes = usize::max(a.connections().len(), b.connections().len()); 59 | // NOTE(review): nodes are compared pairwise by index (`zip` below), so nodes beyond the shorter genome add no distance 60 | 61 | let mut disjoint_connections: Vec<&ConnectionGene> = vec![]; 62 | let mut common_connections: Vec<(&ConnectionGene, &ConnectionGene)> = vec![]; 63 | 64 | let mut disjoint_map: HashMap = HashMap::new(); 65 | a.connections() 66 | .iter() 67 | .chain(b.connections().iter()) 68 | .map(|connection| connection.innovation_number()) 69 | .for_each(|innovation_number| { // seen more than once (normally once per genome) => common; seen once => disjoint 70 | if let Some(is_disjoint) = disjoint_map.get_mut(&innovation_number) { 71 | *is_disjoint = false; 72 | } else { 73 | disjoint_map.insert(innovation_number, true); 74 | } 75 | }); 76 | 77 | disjoint_map 78 | .into_iter() 79 | .for_each(|(innovation_number, is_disjoint)| { 80 | if is_disjoint { 81 | let disjoint_connection
= a 82 | .connections() 83 | .iter() 84 | .chain(b.connections().iter()) 85 | .find(|connection| connection.innovation_number() == innovation_number) 86 | .unwrap(); 87 | 88 | disjoint_connections.push(disjoint_connection); 89 | } else { 90 | let common_connection_a = a 91 | .connections() 92 | .iter() 93 | .find(|connection| connection.innovation_number() == innovation_number) 94 | .unwrap(); 95 | let common_connection_b = b 96 | .connections() 97 | .iter() 98 | .find(|connection| connection.innovation_number() == innovation_number) 99 | .unwrap(); 100 | 101 | common_connections.push((common_connection_a, common_connection_b)); 102 | } 103 | }); 104 | 105 | let disjoint_factor = 106 | disjoint_connections.len() as f64 * distance_connection_disjoint_coefficient; 107 | 108 | let connections_difference_factor: f64 = common_connections 109 | .iter() 110 | .map(|(connection_a, connection_b)| { 111 | let mut connection_distance = 0.; 112 | 113 | if connection_a.disabled != connection_b.disabled { 114 | connection_distance += 1. * distance_connection_disabled_coefficient; 115 | } 116 | 117 | connection_distance += (connection_a.weight - connection_b.weight).abs() 118 | * distance_connection_weight_coeficcient; 119 | 120 | connection_distance 121 | }) 122 | .sum::(); 123 | 124 | let nodes_difference_factor: f64 = a 125 | .nodes() 126 | .iter() 127 | .zip(b.nodes()) 128 | .map(|(node_a, node_b)| { 129 | let mut node_distance = 0.; 130 | 131 | if node_a.activation != node_b.activation { 132 | node_distance += 1. * distance_node_activation_coefficient; 133 | } 134 | 135 | if node_a.aggregation != node_b.aggregation { 136 | node_distance += 1.
* distance_node_aggregation_coefficient; 137 | } 138 | 139 | node_distance += (node_a.bias - node_b.bias).abs() * distance_node_bias_coefficient; 140 | 141 | node_distance 142 | }) 143 | .sum(); 144 | 145 | distance += nodes_difference_factor; 146 | distance += (connections_difference_factor + disjoint_factor) / usize::max(max_connection_genes, 1) as f64; // max(.., 1) avoids 0/0 = NaN when neither genome has connections 147 | 148 | distance 149 | } 150 | 151 | pub fn mean(&self) -> f64 { // NaN when the cache is empty 152 | self.cache.values().sum::() / self.cache.len() as f64 153 | } 154 | 155 | fn make_key<'o>(a: &'o Genome, b: &'o Genome) -> String { // order-independent key built from the hashes of both genomes 156 | use std::collections::hash_map::DefaultHasher; 157 | use std::hash::{Hash, Hasher}; 158 | 159 | let hash_a = { 160 | let mut hasher = DefaultHasher::new(); 161 | a.hash(&mut hasher); 162 | hasher.finish() 163 | }; 164 | 165 | let hash_b = { 166 | let mut hasher = DefaultHasher::new(); 167 | b.hash(&mut hasher); 168 | hasher.finish() 169 | }; 170 | 171 | // Larger hash first so the key is order-independent; the separator keeps 172 | // pairs such as (21, 10) and (211, 0) from producing the same key 173 | if hash_a > hash_b { 174 | format!("{}-{}", hash_a, hash_b) 175 | } else { 176 | format!("{}-{}", hash_b, hash_a) 177 | } 178 | } 179 | } 180 | -------------------------------------------------------------------------------- /core/src/speciation/mod.rs: -------------------------------------------------------------------------------- 1 | use std::cell::RefCell; 2 | use std::collections::HashMap; 3 | use std::collections::HashSet; 4 | use std::rc::Rc; 5 | 6 | use crate::Configuration; 7 | use crate::{Genome, GenomeId}; 8 | 9 | use distance::GenomicDistanceCache; 10 | 11 | mod distance; 12 | 13 | pub struct SpeciesSet { 14 | configuration: Rc>, 15 | last_index: Option, 16 | species: HashMap, 17 | } 18 | 19 | impl SpeciesSet { 20 | pub fn new(configuration: Rc>) -> Self { 21 | SpeciesSet { 22 | configuration, 23 | last_index: None, 24 | species: HashMap::new(), 25 | } 26 | } 27 | 28 | pub fn species(&self) -> &HashMap { 29 | &self.species 30 | } 31 | 32 | pub fn speciate( 33 | &mut self, 34 | generation: usize, 35 |
current_genomes: &[GenomeId], 36 | all_genomes: &HashMap, 37 | fitnesses: &HashMap, 38 | ) { 39 | let (compatibility_threshold, stagnation_after, elitism_species) = { 40 | let config = self.configuration.borrow(); 41 | 42 | ( 43 | config.compatibility_threshold, 44 | config.stagnation_after, 45 | config.elitism_species, 46 | ) 47 | }; 48 | 49 | let mut distances = GenomicDistanceCache::new(self.configuration.clone()); 50 | 51 | let mut unspeciated_genomes: HashSet = current_genomes.iter().cloned().collect(); 52 | let mut new_species: HashMap = self.species.clone(); 53 | 54 | // Find new representatives for existing species 55 | self.species.iter().for_each(|(species_id, species)| { 56 | let genome_representative = all_genomes.get(&species.representative).unwrap(); 57 | 58 | let (maybe_new_representative_id, _) = current_genomes // NOTE(review): genomes already picked for another species are not excluded, so one genome can represent several species 59 | .iter() 60 | .map(|genome_id| { 61 | let genome = all_genomes.get(genome_id).unwrap(); 62 | (genome_id, distances.get(genome, genome_representative)) 63 | }) 64 | .filter(|(_, distance)| *distance < compatibility_threshold) 65 | .fold( 66 | (None, f64::MAX), 67 | |(maybe_closest_genome_id, closest_genome_distance), 68 | (genome_id, genome_distance)| { 69 | if maybe_closest_genome_id.is_some() { 70 | if genome_distance < closest_genome_distance { 71 | return (Some(genome_id), genome_distance); 72 | } 73 | } else { 74 | return (Some(genome_id), genome_distance); 75 | } 76 | 77 | (maybe_closest_genome_id, closest_genome_distance) 78 | }, 79 | ); 80 | 81 | if let Some(new_representative_id) = maybe_new_representative_id { 82 | let species = new_species.get_mut(species_id).unwrap(); 83 | species.representative = *new_representative_id; 84 | species.members = vec![*new_representative_id]; 85 | 86 | unspeciated_genomes.remove(&new_representative_id); 87 | } else { 88 | new_species.remove(species_id); 89 | } 90 | }); 91 | 92 | // Put unspeciated genomes into species 93 | unspeciated_genomes.iter().for_each(|genome_id| { 94 | let genome =
all_genomes.get(genome_id).unwrap(); 95 | 96 | let (maybe_closest_species_id, _) = { 97 | new_species 98 | .iter() 99 | .map(|(species_id, species)| { 100 | let species_representative_genome = 101 | all_genomes.get(&species.representative).unwrap(); 102 | 103 | ( 104 | species_id, 105 | distances.get(genome, species_representative_genome), 106 | ) 107 | }) 108 | .filter(|(_, distance)| *distance < compatibility_threshold) 109 | .fold( 110 | (None, f64::MAX), 111 | |(maybe_closest_species_id, closest_representative_distance), 112 | (species_id, representative_distance)| { 113 | if maybe_closest_species_id.is_some() { 114 | if representative_distance < closest_representative_distance { 115 | return (Some(*species_id), representative_distance); 116 | } 117 | } else { 118 | return (Some(*species_id), representative_distance); 119 | } 120 | 121 | (maybe_closest_species_id, closest_representative_distance) 122 | }, 123 | ) 124 | }; 125 | 126 | if let Some(closest_species_id) = maybe_closest_species_id { 127 | // Fits into an existing species 128 | new_species 129 | .get_mut(&closest_species_id) 130 | .unwrap() 131 | .members 132 | .push(*genome_id); 133 | } else { 134 | // Needs to go in a brand new species 135 | let species = Species::new(generation, *genome_id, vec![*genome_id]); 136 | let next_species_id = new_species.keys().max().or(Some(&0)).cloned().unwrap(); // NOTE(review): ids of species dropped above can be handed out again here 137 | 138 | new_species.insert(next_species_id + 1, species); 139 | } 140 | }); 141 | 142 | // Calculate fitness for every species 143 | new_species.iter_mut().for_each(|(_, mut species)| { 144 | let member_fitnesses: Vec = species 145 | .members 146 | .iter() 147 | .map(|member_genome_id| *fitnesses.get(member_genome_id).unwrap()) 148 | .collect(); 149 | 150 | let species_mean_fitness = 151 | member_fitnesses.iter().sum::() / member_fitnesses.len() as f64; 152 | let best_previous_fitness = species 153 | .fitness_history 154 | .iter() 155 | .cloned() 156 | .fold(f64::MIN, f64::max); 157 | 158 | if
species_mean_fitness > best_previous_fitness { 159 | species.last_improved = generation; 160 | } 161 | 162 | species.fitness = Some(species_mean_fitness); 163 | species.fitness_history.push(species_mean_fitness); 164 | }); 165 | 166 | // Calculate adjusted fitness for every species 167 | let species_fitnesses: Vec = new_species 168 | .iter() 169 | .map(|(_, species)| species.fitness.unwrap()) 170 | .collect(); 171 | let max_fitness = species_fitnesses.iter().cloned().fold(f64::MIN, f64::max); 172 | let exp_sum: f64 = species_fitnesses.iter().map(|fitness| (fitness - max_fitness).exp()).sum(); 173 | new_species.iter_mut().for_each(|(_, species)| { 174 | // Softmax shifted by the largest fitness: identical ratios, but exp() can no longer overflow to inf (inf / inf = NaN) 175 | let own_exp = (species.fitness.unwrap() - max_fitness).exp(); 176 | let adjusted_fitness = own_exp / exp_sum; 177 | 178 | species.adjusted_fitness = Some(adjusted_fitness); 179 | }); 180 | 181 | // Remove stagnated species 182 | let mut stagnated_ids_and_adjusted_fitnesses: Vec<(usize, f64)> = new_species 183 | .iter() 184 | .filter(|(_, species)| generation - species.last_improved >= stagnation_after) 185 | .map(|(id, species)| (*id, species.adjusted_fitness.unwrap())) 186 | .collect(); 187 | 188 | stagnated_ids_and_adjusted_fitnesses.sort_by(|a, b| { 189 | use std::cmp::Ordering::*; 190 | 191 | // Ascending by adjusted fitness: the weakest stagnated species are 192 | // removed first, so the ones kept by `elitism_species` are the fittest. 193 | // NaN compares Equal so the comparator stays consistent (the old 194 | // `>`-only check never returned Equal). 195 | a.1.partial_cmp(&b.1).unwrap_or(Equal) 196 | }); 197 | 198 | stagnated_ids_and_adjusted_fitnesses 199 | .iter() 200 | .take(new_species.len().saturating_sub(elitism_species)) // at most len - elitism_species removals, without usize underflow 201 | .for_each(|(id, _)| { 202 | new_species.remove(id).unwrap(); 203 | }); 204 | 205 | // Finally replace old species 206 | self.species = new_species; 207 | } 208 | } 209 | 210 | #[derive(Debug, Clone)] 211 | pub struct Species { 212 | created: usize, 213 | 214 | last_improved: usize, 215 | representative: GenomeId, 216 | pub members: Vec, 217 | 218 | fitness: Option, 219 | pub adjusted_fitness: Option, 220 | fitness_history: Vec, 221 | } 222 | 223 | impl Species { 224 | pub fn new(generation: usize, representative: GenomeId, members: Vec) -> Self { 225 | Species { 226 | created: generation, 227 | last_improved: generation,
228 | representative, 229 | members, 230 | fitness: None, 231 | adjusted_fitness: None, 232 | fitness_history: vec![], 233 | } 234 | } 235 | } 236 | 237 | #[cfg(test)] 238 | mod tests { 239 | use super::*; 240 | 241 | #[test] 242 | fn test_hash() { 243 | use std::collections::hash_map::DefaultHasher; 244 | use std::hash::{Hash, Hasher}; 245 | 246 | let genome = Genome::new(5, 3); 247 | 248 | let first_hash = { 249 | let mut hasher = DefaultHasher::new(); 250 | 251 | genome.hash(&mut hasher); 252 | hasher.finish().to_string() 253 | }; 254 | 255 | let second_hash = { 256 | let genome_clone = genome.clone(); 257 | let mut hasher = DefaultHasher::new(); 258 | 259 | genome_clone.hash(&mut hasher); 260 | hasher.finish().to_string() 261 | }; 262 | 263 | dbg!(&first_hash, &second_hash); 264 | 265 | assert_eq!(first_hash, second_hash); 266 | } 267 | } 268 | -------------------------------------------------------------------------------- /environments/tictactoe/src/main.rs: -------------------------------------------------------------------------------- 1 | use rand::random; 2 | 3 | use neat_core::{Configuration, Network, NEAT}; 4 | use neat_environment::Environment; 5 | 6 | #[derive(Clone, Copy, Debug)] 7 | enum Mark { 8 | X, 9 | O, 10 | Empty, 11 | } 12 | 13 | #[derive(Clone, Debug)] 14 | enum Player { 15 | External, 16 | Internal, 17 | } 18 | 19 | type Field = [Mark; 9]; 20 | 21 | #[derive(Debug)] 22 | struct TicTacToe { 23 | field: Field, 24 | first_player: Player, 25 | turn: Player, 26 | } 27 | 28 | impl TicTacToe { 29 | pub fn new() -> Self { 30 | let first_player: Player = if random::() < 0.5 { 31 | Player::External 32 | } else { 33 | Player::Internal 34 | }; 35 | 36 | let mut ttt = TicTacToe { 37 | field: [Mark::Empty; 9], 38 | first_player: first_player.clone(), 39 | turn: first_player.clone(), 40 | }; 41 | 42 | if let Player::Internal = first_player { 43 | ttt.step_internal(); 44 | } 45 | 46 | ttt 47 | } 48 | 49 | fn step_internal(&mut self) { 50 | if 
self.game_over() || self.is_external_turn() { 51 | return; 52 | } 53 | 54 | let empty_indexes: Vec = self 55 | .field 56 | .iter() 57 | .enumerate() 58 | .filter(|(_, mark)| matches!(mark, Mark::Empty)) 59 | .map(|(index, _)| index) 60 | .collect(); 61 | 62 | let random_index = empty_indexes 63 | .get(random::() % empty_indexes.len()) 64 | .unwrap(); 65 | 66 | let mark_to_place = if matches!(self.first_player, Player::Internal) { 67 | Mark::X 68 | } else { 69 | Mark::O 70 | }; 71 | 72 | *self.field.get_mut(*random_index).unwrap() = mark_to_place; 73 | self.turn = Player::External; 74 | } 75 | 76 | pub fn is_external_first(&self) -> bool { 77 | matches!(self.first_player, Player::External) 78 | } 79 | 80 | pub fn is_external_turn(&self) -> bool { 81 | matches!(self.turn, Player::External) 82 | } 83 | 84 | pub fn external_mark(&self) -> Mark { 85 | match self.first_player { 86 | Player::External => Mark::X, 87 | Player::Internal => Mark::O, 88 | } 89 | } 90 | 91 | fn game_over(&self) -> bool { 92 | let fields_full = self.field.iter().all(|mark| !matches!(mark, Mark::Empty)); 93 | 94 | fields_full || self.did_external_win() || self.did_internal_win() 95 | } 96 | 97 | fn did_external_win(&self) -> bool { 98 | self.did_mark_win(self.external_mark()) 99 | } 100 | 101 | fn did_internal_win(&self) -> bool { 102 | let internal_mark = { 103 | match self.first_player { 104 | Player::Internal => Mark::X, 105 | Player::External => Mark::O, 106 | } 107 | }; 108 | 109 | self.did_mark_win(internal_mark) 110 | } 111 | 112 | fn is_draw(&self) -> bool { 113 | self.game_over() && !self.did_external_win() && !self.did_internal_win() 114 | } 115 | 116 | fn did_mark_win(&self, check_mark: Mark) -> bool { 117 | let winning_lines = [ 118 | [0, 1, 2], 119 | [3, 4, 5], 120 | [6, 7, 8], 121 | [0, 3, 6], 122 | [1, 4, 7], 123 | [2, 5, 8], 124 | [0, 4, 8], 125 | [2, 4, 6], 126 | ]; 127 | 128 | let mark_won = winning_lines.iter().any(|line| { 129 | line.iter() 130 | .map(|mark_index|
self.field.get(*mark_index).unwrap()) 131 | .all(|mark| match (mark, check_mark) { 132 | (Mark::X, Mark::X) => true, 133 | (Mark::O, Mark::O) => true, 134 | _ => false, 135 | }) 136 | }); 137 | 138 | mark_won 139 | } 140 | } 141 | 142 | impl Environment for TicTacToe { 143 | type State = Field; 144 | type Input = usize; 145 | 146 | fn state(&self) -> Self::State { 147 | self.field 148 | } 149 | 150 | fn step(&mut self, input: Self::Input) -> Result<(), ()> { 151 | if input >= 9 { 152 | panic!("Field index out of bounds"); 153 | } 154 | 155 | if self.game_over() || !self.is_external_turn() { 156 | return Err(()); 157 | } 158 | 159 | let mark_to_place = if matches!(self.first_player, Player::External) { 160 | Mark::X 161 | } else { 162 | Mark::O 163 | }; 164 | let mark = self.field.get_mut(input).unwrap(); 165 | 166 | if matches!(mark, Mark::Empty) { 167 | *mark = mark_to_place; 168 | } else { 169 | return Err(()); 170 | } 171 | 172 | self.turn = Player::Internal; 173 | self.step_internal(); 174 | 175 | Ok(()) 176 | } 177 | 178 | fn done(&self) -> bool { 179 | self.game_over() 180 | } 181 | 182 | fn reset(&mut self) { 183 | *self = TicTacToe::new(); 184 | } 185 | 186 | fn render(&self) { 187 | self.field.iter().enumerate().for_each(|(index, mark)| { 188 | let character: String = match mark { 189 | Mark::X => "X".to_owned(), 190 | Mark::O => "O".to_owned(), 191 | Mark::Empty => "_".to_owned(), 192 | }; 193 | 194 | if index % 3 == 0 { 195 | print!("\n"); 196 | } 197 | print!("{} ", character); 198 | }); 199 | 200 | print!("\n\n"); 201 | } 202 | 203 | fn fitness(&self) -> f64 { 204 | if self.did_external_win() { 205 | 1. 206 | } else { 207 | 0.
208 | } 209 | } 210 | } 211 | 212 | fn state_to_inputs(env: &TicTacToe) -> Vec { 213 | let player_mark = env.external_mark(); 214 | 215 | env.state() 216 | .iter() 217 | .map(|mark| match (player_mark, *mark) { 218 | (Mark::X, Mark::X) => 1., 219 | (Mark::O, Mark::O) => 1., 220 | (Mark::X, Mark::O) => -1., 221 | (Mark::O, Mark::X) => -1., 222 | _ => 0., 223 | }) 224 | .collect() 225 | } 226 | /// Picks the move as the index of the largest output (first maximum wins ties; 0 for empty or all-NaN outputs) 227 | fn move_from_outputs(outputs: &[f64]) -> usize { 228 | outputs 229 | .iter() 230 | .enumerate() 231 | .fold((0, f64::NEG_INFINITY), |(max_index, max_output), (index, output)| { 232 | if output > &max_output { 233 | (index, *output) 234 | } else { 235 | (max_index, max_output) 236 | } 237 | }) 238 | .0 239 | } 240 | 241 | fn play_network(network: &mut Network) { 242 | println!("Playing..."); 243 | 244 | let mut env = TicTacToe::new(); 245 | 246 | let player_mark = env.external_mark(); 247 | println!("Player mark is {:?}", player_mark); 248 | 249 | loop { 250 | env.render(); 251 | 252 | if env.game_over() { 253 | break; 254 | } 255 | 256 | let inputs = state_to_inputs(&env); 257 | let outputs: Vec = network.forward_pass(inputs); 258 | let max_output_index: usize = move_from_outputs(&outputs); 259 | 260 | if env.step(max_output_index).is_err() { 261 | break; 262 | } 263 | } 264 | 265 | println!("Game over, last state"); 266 | env.render(); 267 | } 268 | 269 | fn main() { 270 | let mut system = NEAT::new(9, 9, |network| { 271 | let games = 100; 272 | let mut turns = 0; 273 | let mut games_won = 0; 274 | let mut games_draw = 0; 275 | 276 | let mut env = TicTacToe::new(); 277 | 278 | for _ in 0..games { 279 | env.reset(); 280 | 281 | loop { 282 | if env.game_over() { 283 | break; 284 | } 285 | 286 | let inputs = state_to_inputs(&env); 287 | let outputs: Vec = network.forward_pass(inputs); 288 | let max_output_index: usize = move_from_outputs(&outputs); 289 | 290 | if env.step(max_output_index).is_ok() { 291 | turns += 1; 292 | } else { 293 | break; 294 | } 295 | } 296 |
297 | games_won += if env.did_external_win() { 1 } else { 0 }; 298 | games_draw += if env.is_draw() { 1 } else { 0 }; 299 | } 300 | 301 | // games as f64 / (games_won as f64 + games_draw as f64) //+ turns as f64 * 0.01 302 | // turns as f64 / games as f64 303 | (games_won as f64 + games_draw as f64) / games as f64 304 | }); 305 | 306 | system.set_configuration(Configuration { 307 | population_size: 50, 308 | max_generations: 500, 309 | node_cost: 0.001, 310 | connection_cost: 0.0005, 311 | compatibility_threshold: 3., 312 | ..Default::default() 313 | }); 314 | system.add_hook(1, |i, system| { 315 | let (_, _, fitness) = system.get_best(); 316 | 317 | println!("Generation {}, best fitness is {}", i, fitness); 318 | }); 319 | 320 | let (mut network, fitness) = system.start(); 321 | 322 | println!( 323 | "Found network with {} nodes and {} connections, of fitness {}", 324 | network.nodes.len(), 325 | network.connections.len(), 326 | fitness 327 | ); 328 | 329 | for _ in 0..5 { 330 | play_network(&mut network); 331 | } 332 | } 333 | 334 | #[cfg(test)] 335 | mod tests { 336 | use super::*; 337 | 338 | #[test] 339 | fn can_run() { 340 | let mut env = TicTacToe::new(); 341 | 342 | if env.is_external_first() { 343 | println!("I am X"); 344 | } else { 345 | println!("I am O"); 346 | } 347 | 348 | loop { 349 | if env.game_over() { 350 | break; 351 | } 352 | 353 | while env.step(random::() % 9).is_err() {} 354 | } 355 | 356 | println!("I WON: {}", env.did_external_win()); 357 | env.render(); 358 | env.reset(); 359 | } 360 | } 361 | -------------------------------------------------------------------------------- /core/src/neat/mod.rs: -------------------------------------------------------------------------------- 1 | use rand::random; 2 | use rayon::prelude::*; 3 | use std::cell::RefCell; 4 | use std::rc::Rc; 5 | use uuid::Uuid; 6 | 7 | use crate::genome::{crossover, Genome, GenomeId}; 8 | use crate::mutations::MutationKind; 9 | use crate::network::Network; 10 | use
crate::speciation::SpeciesSet; 11 | pub use configuration::Configuration; 12 | use reporter::Reporter; 13 | use speciation::GenomeBank; 14 | 15 | mod configuration; 16 | mod reporter; 17 | mod speciation; 18 | 19 | pub struct NEAT { 20 | inputs: usize, 21 | outputs: usize, 22 | fitness_fn: fn(&mut Network) -> f64, 23 | pub genomes: GenomeBank, 24 | pub species_set: SpeciesSet, 25 | configuration: Rc>, 26 | reporter: Reporter, 27 | } 28 | 29 | impl NEAT { 30 | pub fn new(inputs: usize, outputs: usize, fitness_fn: fn(&mut Network) -> f64) -> Self { 31 | let configuration: Rc> = Default::default(); 32 | 33 | NEAT { 34 | inputs, 35 | outputs, 36 | fitness_fn, 37 | genomes: GenomeBank::new(configuration.clone()), 38 | species_set: SpeciesSet::new(configuration.clone()), 39 | configuration, 40 | reporter: Reporter::new(), 41 | } 42 | } 43 | 44 | pub fn set_configuration(&mut self, config: Configuration) { 45 | *self.configuration.borrow_mut() = config; 46 | } 47 | 48 | pub fn start(&mut self) -> (Network, f64) { 49 | let (population_size, max_generations) = { 50 | let config = self.configuration.borrow(); 51 | 52 | (config.population_size, config.max_generations) 53 | }; 54 | 55 | // Create initial genomes 56 | (0..population_size).for_each(|_| { 57 | self.genomes 58 | .add_genome(Genome::new(self.inputs, self.outputs)) 59 | }); 60 | 61 | self.test_fitness(); 62 | 63 | for i in 1..=max_generations { 64 | let current_genome_ids: Vec = 65 | self.genomes.genomes().keys().cloned().collect(); 66 | let previous_and_current_genomes = self 67 | .genomes 68 | .genomes() 69 | .iter() 70 | .chain(self.genomes.previous_genomes()) 71 | .map(|(genome_id, genome)| (genome_id.clone(), genome.clone())) 72 | .collect(); 73 | 74 | self.species_set.speciate( 75 | i, 76 | ¤t_genome_ids, 77 | &previous_and_current_genomes, 78 | self.genomes.fitnesses(), 79 | ); 80 | 81 | let (elitism, population_size, mutation_rate, survival_ratio) = { 82 | let config = self.configuration.borrow(); 83 | 84 | 
( 85 | config.elitism, 86 | config.population_size, 87 | config.mutation_rate, 88 | config.survival_ratio, 89 | ) 90 | }; 91 | 92 | let offspring: Vec = self 93 | .species_set 94 | .species() 95 | .values() 96 | .flat_map(|species| { 97 | let offspring_count: usize = (species.adjusted_fitness.unwrap() 98 | * population_size as f64) 99 | .ceil() as usize; 100 | let elites_count: usize = (offspring_count as f64 * elitism).ceil() as usize; 101 | let nonelites_count: usize = offspring_count - elites_count; 102 | 103 | let mut member_ids_and_fitnesses: Vec<(GenomeId, f64)> = species 104 | .members 105 | .iter() 106 | .map(|member_id| { 107 | ( 108 | *member_id, 109 | *self.genomes.fitnesses().get(member_id).unwrap(), 110 | ) 111 | }) 112 | .collect(); 113 | 114 | member_ids_and_fitnesses.sort_by(|a, b| { 115 | use std::cmp::Ordering::*; 116 | 117 | let fitness_a = a.1; 118 | let fitness_b = b.1; 119 | 120 | if fitness_a > fitness_b { 121 | Less 122 | } else { 123 | Greater 124 | } 125 | }); 126 | 127 | // Pick survivors 128 | let surviving_count: usize = 129 | (member_ids_and_fitnesses.len() as f64 * survival_ratio).ceil() as usize; 130 | member_ids_and_fitnesses.truncate(surviving_count); 131 | 132 | let elite_children: Vec = 133 | (0..usize::min(elites_count, member_ids_and_fitnesses.len())) 134 | .map(|elite_index| { 135 | let (elite_genome_id, _) = 136 | member_ids_and_fitnesses.get(elite_index).unwrap(); 137 | let elite_genome = 138 | self.genomes.genomes().get(elite_genome_id).unwrap(); 139 | 140 | elite_genome.clone() 141 | }) 142 | .collect(); 143 | 144 | let crossover_data: Vec<(&Genome, f64, &Genome, f64)> = (0..nonelites_count) 145 | .map(|_| { 146 | let parent_a_index = random::() % member_ids_and_fitnesses.len(); 147 | let parent_b_index = random::() % member_ids_and_fitnesses.len(); 148 | 149 | let (parent_a_id, parent_a_fitness) = 150 | member_ids_and_fitnesses.get(parent_a_index).unwrap(); 151 | let (parent_b_id, parent_b_fitness) = 152 | 
member_ids_and_fitnesses.get(parent_b_index).unwrap(); 153 | 154 | let parent_a_genome = self.genomes.genomes().get(parent_a_id).unwrap(); 155 | let parent_b_genome = self.genomes.genomes().get(parent_b_id).unwrap(); 156 | 157 | ( 158 | parent_a_genome, 159 | *parent_a_fitness, 160 | parent_b_genome, 161 | *parent_b_fitness, 162 | ) 163 | }) 164 | .collect(); 165 | 166 | let mut crossover_children: Vec = crossover_data 167 | .par_iter() 168 | .map(|(parent_a, fitness_a, parent_b, fitness_b)| { 169 | crossover((parent_a, *fitness_a), (parent_b, *fitness_b)) 170 | }) 171 | .filter(|maybe_genome| maybe_genome.is_some()) 172 | .map(|maybe_genome| maybe_genome.unwrap()) 173 | .collect(); 174 | 175 | let mutations_for_children: Vec> = crossover_children 176 | .iter() 177 | .map(|_| { 178 | if random::() < mutation_rate { 179 | Some(self.pick_mutation()) 180 | } else { 181 | None 182 | } 183 | }) 184 | .collect(); 185 | 186 | crossover_children 187 | .par_iter_mut() 188 | .zip(mutations_for_children) 189 | .for_each(|(child, maybe_mutation)| { 190 | if let Some(mutation) = maybe_mutation { 191 | child.mutate(&mutation); 192 | } 193 | }); 194 | 195 | elite_children 196 | .into_iter() 197 | .chain(crossover_children) 198 | .collect::>() 199 | }) 200 | .collect(); 201 | 202 | self.genomes.clear(); 203 | offspring 204 | .into_iter() 205 | .for_each(|genome| self.genomes.add_genome(genome)); 206 | 207 | self.test_fitness(); 208 | 209 | self.reporter.report(i, &self); 210 | 211 | let goal_reached = { 212 | if let Some(goal) = self.configuration.borrow().fitness_goal { 213 | let (_, _, best_fitness) = self.get_best(); 214 | 215 | best_fitness >= goal 216 | } else { 217 | false 218 | } 219 | }; 220 | 221 | if goal_reached { 222 | break; 223 | } 224 | } 225 | 226 | let (_, best_genome, best_fitness) = self.get_best(); 227 | (Network::from(best_genome), best_fitness) 228 | } 229 | 230 | fn test_fitness(&mut self) { 231 | let ids_and_networks: Vec<(GenomeId, Network)> = self 232 | 
.genomes 233 | .genomes() 234 | .iter() 235 | .map(|(genome_id, genome)| (*genome_id, Network::from(genome))) 236 | .collect(); 237 | 238 | let node_cost = self.configuration.borrow().node_cost; 239 | let connection_cost = self.configuration.borrow().connection_cost; 240 | let fitness_fn = self.fitness_fn; 241 | 242 | let ids_and_fitnesses: Vec<(GenomeId, f64)> = ids_and_networks 243 | .into_par_iter() 244 | .map(|(genome_id, mut network)| { 245 | let mut fitness: f64 = (fitness_fn)(&mut network); 246 | fitness -= node_cost * network.nodes.len() as f64; 247 | fitness -= connection_cost * network.connections.len() as f64; 248 | 249 | (genome_id, fitness) 250 | }) 251 | .collect(); 252 | 253 | ids_and_fitnesses 254 | .into_iter() 255 | .for_each(|(genome_id, genome_fitness)| { 256 | self.genomes.mark_fitness(genome_id, genome_fitness) 257 | }); 258 | } 259 | 260 | pub fn get_best(&self) -> (GenomeId, &Genome, f64) { 261 | let (best_genome_id, best_fitness) = self.genomes.fitnesses().iter().fold( 262 | (Uuid::new_v4(), f64::MIN), 263 | |(best_id, best_fitness), (genome_id, genome_fitness)| { 264 | if *genome_fitness > best_fitness { 265 | (*genome_id, *genome_fitness) 266 | } else { 267 | (best_id, best_fitness) 268 | } 269 | }, 270 | ); 271 | 272 | let best_genome = self.genomes.genomes().get(&best_genome_id).unwrap(); 273 | 274 | (best_genome_id, best_genome, best_fitness) 275 | } 276 | 277 | fn pick_mutation(&self) -> MutationKind { 278 | use rand::{distributions::Distribution, thread_rng}; 279 | use rand_distr::weighted_alias::WeightedAliasIndex; 280 | 281 | let dist = WeightedAliasIndex::new( 282 | self.configuration 283 | .borrow() 284 | .mutation_kinds 285 | .iter() 286 | .map(|k| k.1) 287 | .collect(), 288 | ) 289 | .unwrap(); 290 | 291 | let mut rng = thread_rng(); 292 | 293 | self.configuration 294 | .borrow() 295 | .mutation_kinds 296 | .get(dist.sample(&mut rng)) 297 | .cloned() 298 | .unwrap() 299 | .0 300 | } 301 | 302 | pub fn add_hook(&mut self, every: 
usize, hook: reporter::Hook) { 303 | self.reporter.register(every, hook); 304 | } 305 | } 306 | 307 | #[cfg(test)] 308 | mod tests { 309 | use super::*; 310 | 311 | #[test] 312 | fn xor() { 313 | let mut system = NEAT::new(2, 1, |n| { 314 | let inputs: Vec> = 315 | vec![vec![0., 0.], vec![0., 1.], vec![1., 0.], vec![1., 1.]]; 316 | let outputs: Vec = vec![0., 1., 1., 0.]; 317 | 318 | let mut error = 0.; 319 | 320 | for (i, o) in inputs.iter().zip(outputs) { 321 | let results = n.forward_pass(i.clone()); 322 | let result = results.first().unwrap(); 323 | 324 | error += (o - *result).powi(2); 325 | } 326 | 327 | 1. / (1. + error) 328 | }); 329 | 330 | system.set_configuration(Configuration { 331 | population_size: 150, 332 | max_generations: 100, 333 | mutation_rate: 0.75, 334 | fitness_goal: Some(0.9099), 335 | node_cost: 0.01, 336 | connection_cost: 0.01, 337 | compatibility_threshold: 3., 338 | ..Default::default() 339 | }); 340 | system.add_hook(1, |i, system| { 341 | let (_, _, fitness) = system.get_best(); 342 | println!("Generation {}, best fitness is {}", i, fitness); 343 | }); 344 | 345 | let (mut network, fitness) = system.start(); 346 | 347 | let inputs: Vec> = vec![vec![0., 0.], vec![0., 1.], vec![1., 0.], vec![1., 1.]]; 348 | for i in inputs { 349 | let o = network.forward_pass(i.clone()); 350 | dbg!(i, o); 351 | } 352 | 353 | dbg!(&network, &fitness); 354 | 355 | println!( 356 | "Found network with {} nodes and {} connections, of fitness {}", 357 | network.nodes.len(), 358 | network.connections.len(), 359 | fitness 360 | ); 361 | } 362 | } 363 | -------------------------------------------------------------------------------- /core/src/mutations.rs: -------------------------------------------------------------------------------- 1 | use rand::distributions::{Distribution, Standard}; 2 | use rand::random; 3 | use rand::thread_rng; 4 | use rand::Rng; 5 | use rand_distr::StandardNormal; 6 | 7 | use crate::activation::ActivationKind; 8 | use 
crate::genome::Genome; 9 | use crate::node::NodeKind; 10 | 11 | pub fn mutate(kind: &MutationKind, g: &mut Genome) { 12 | use MutationKind::*; 13 | 14 | match kind { 15 | AddConnection => add_connection(g), 16 | RemoveConnection => disable_connection(g), 17 | AddNode => add_node(g), 18 | RemoveNode => remove_node(g), 19 | ModifyWeight => change_weight(g), 20 | ModifyBias => change_bias(g), 21 | ModifyActivation => change_activation(g), 22 | ModifyAggregation => change_aggregation(g), 23 | }; 24 | } 25 | 26 | #[derive(Debug, Clone, Eq, PartialEq, Hash)] 27 | pub enum MutationKind { 28 | AddConnection, 29 | RemoveConnection, 30 | AddNode, 31 | RemoveNode, 32 | ModifyWeight, 33 | ModifyBias, 34 | ModifyActivation, 35 | ModifyAggregation, 36 | } 37 | 38 | impl Distribution for Standard { 39 | fn sample(&self, rng: &mut R) -> MutationKind { 40 | use MutationKind::*; 41 | 42 | match rng.gen_range(0, 7) { 43 | 0 => AddConnection, 44 | 1 => RemoveConnection, 45 | 2 => AddNode, 46 | 3 => RemoveNode, 47 | 4 => ModifyWeight, 48 | 5 => ModifyBias, 49 | _ => ModifyActivation, 50 | } 51 | } 52 | } 53 | 54 | /// Adds a new random connection 55 | pub fn add_connection(g: &mut Genome) { 56 | let existing_connections: Vec<(usize, usize, bool)> = g 57 | .connections() 58 | .iter() 59 | .map(|c| (c.from, c.to, c.disabled)) 60 | .collect(); 61 | 62 | let mut possible_connections: Vec<(usize, usize)> = (0..g.nodes().len()) 63 | .flat_map(|i| { 64 | let mut inner = vec![]; 65 | 66 | (0..g.nodes().len()).for_each(|j| { 67 | if i != j { 68 | if !existing_connections.contains(&(i, j, false)) { 69 | inner.push((i, j)); 70 | }; 71 | if !existing_connections.contains(&(j, i, false)) { 72 | inner.push((j, i)); 73 | }; 74 | } 75 | }); 76 | 77 | inner 78 | }) 79 | .collect(); 80 | 81 | possible_connections.sort_unstable(); 82 | possible_connections.dedup(); 83 | 84 | possible_connections = possible_connections 85 | .into_iter() 86 | .filter(|(i, j)| g.can_connect(*i, *j)) 87 | .collect(); 88 | 89 
| if possible_connections.is_empty() { 90 | return; 91 | } 92 | 93 | let picked_connection = possible_connections 94 | .get(random::() % possible_connections.len()) 95 | .unwrap(); 96 | 97 | g.add_connection(picked_connection.0, picked_connection.1) 98 | .unwrap(); 99 | } 100 | 101 | /// Removes a random connection if it's not the only one 102 | fn disable_connection(g: &mut Genome) { 103 | let eligible_indexes: Vec = g 104 | .connections() 105 | .iter() 106 | .enumerate() 107 | .filter(|(_, c)| { 108 | if c.disabled { 109 | return false; 110 | } 111 | 112 | let from_index = c.from; 113 | let to_index = c.to; 114 | 115 | // Number of outgoing connections for the `from` node 116 | let from_connections_count = g 117 | .connections() 118 | .iter() 119 | .filter(|c| c.from == from_index && !c.disabled) 120 | .count(); 121 | // Number of incoming connections for the `to` node 122 | let to_connections_count = g 123 | .connections() 124 | .iter() 125 | .filter(|c| c.to == to_index && !c.disabled) 126 | .count(); 127 | 128 | from_connections_count > 1 && to_connections_count > 1 129 | }) 130 | .map(|(i, _)| i) 131 | .collect(); 132 | 133 | if eligible_indexes.is_empty() { 134 | return; 135 | } 136 | 137 | let index = eligible_indexes 138 | .get(random::() % eligible_indexes.len()) 139 | .unwrap(); 140 | 141 | g.disable_connection(*index); 142 | } 143 | 144 | /// Adds a random hidden node to the genome and its connections 145 | pub fn add_node(g: &mut Genome) { 146 | let new_node_index = g.add_node(); 147 | 148 | // Only enabled connections can be disabled 149 | let enabled_connections: Vec = g 150 | .connections() 151 | .iter() 152 | .enumerate() 153 | .filter(|(_, c)| !c.disabled) 154 | .map(|(i, _)| i) 155 | .collect(); 156 | 157 | let (picked_index, picked_from, picked_to, picked_weight) = { 158 | let random_enabled_connection_index = random::() % enabled_connections.len(); 159 | let picked_index = enabled_connections 160 | .get(random_enabled_connection_index) 161 | 
.unwrap(); 162 | let picked_connection = g.connections().get(*picked_index).unwrap(); 163 | 164 | ( 165 | picked_index, 166 | picked_connection.from, 167 | picked_connection.to, 168 | picked_connection.weight, 169 | ) 170 | }; 171 | 172 | g.disable_connection(*picked_index); 173 | 174 | let connection_index = g.add_connection(picked_from, new_node_index).unwrap(); 175 | g.add_connection(new_node_index, picked_to).unwrap(); 176 | 177 | // Reuse the weight from the removed connection 178 | g.connection_mut(connection_index).unwrap().weight = picked_weight; 179 | } 180 | 181 | /// Removes a random hidden node from the genome and rewires connected nodes 182 | fn remove_node(g: &mut Genome) { 183 | let hidden_nodes: Vec = g 184 | .nodes() 185 | .iter() 186 | .enumerate() 187 | .filter(|(i, n)| { 188 | let incoming_count = g 189 | .connections() 190 | .iter() 191 | .filter(|c| c.to == *i && !c.disabled) 192 | .count(); 193 | let outgoing_count = g 194 | .connections() 195 | .iter() 196 | .filter(|c| c.from == *i && !c.disabled) 197 | .count(); 198 | 199 | matches!(n.kind, NodeKind::Hidden) && incoming_count > 0 && outgoing_count > 0 200 | }) 201 | .map(|(i, _)| i) 202 | .collect(); 203 | 204 | if hidden_nodes.is_empty() { 205 | return; 206 | } 207 | 208 | let picked_node_index = hidden_nodes 209 | .get(random::() % hidden_nodes.len()) 210 | .unwrap(); 211 | 212 | let incoming_connections_and_from_indexes: Vec<(usize, usize)> = g 213 | .connections() 214 | .iter() 215 | .enumerate() 216 | .filter(|(_, c)| c.to == *picked_node_index && !c.disabled) 217 | .map(|(i, c)| (i, c.from)) 218 | .collect(); 219 | let outgoing_connections_and_to_indexes: Vec<(usize, usize)> = g 220 | .connections() 221 | .iter() 222 | .enumerate() 223 | .filter(|(_, c)| c.from == *picked_node_index && !c.disabled) 224 | .map(|(i, c)| (i, c.to)) 225 | .collect(); 226 | 227 | let new_from_to_pairs: Vec<(usize, usize)> = incoming_connections_and_from_indexes 228 | .iter() 229 | .flat_map(|(_, from)| { 
230 | outgoing_connections_and_to_indexes 231 | .iter() 232 | .map(|(_, to)| (*from, *to)) 233 | .collect::>() 234 | }) 235 | .filter(|(from, to)| { 236 | // Skip pairs that already have an enabled connection 237 | !g.connections() 238 | .iter() 239 | .any(|c| c.from == *from && c.to == *to && !c.disabled) 240 | }) 241 | .collect(); 242 | 243 | g.add_many_connections(&new_from_to_pairs); 244 | 245 | let connection_indexes_to_delete: Vec = g 246 | .connections() 247 | .iter() 248 | .enumerate() 249 | .filter(|(_, c)| c.from == *picked_node_index || c.to == *picked_node_index) 250 | .map(|(i, _)| i) 251 | .collect(); 252 | 253 | g.disable_many_connections(&connection_indexes_to_delete); 254 | } 255 | 256 | /// Changes the weight of a random connection; does nothing if the genome has no connections 257 | fn change_weight(g: &mut Genome) { 258 | if g.connections().is_empty() { return; } 259 | let index = random::() % g.connections().len(); 260 | let picked_connection = g.connection_mut(index).unwrap(); 261 | let new_weight = if random::() < 0.1 { 262 | picked_connection.weight + thread_rng().sample::(StandardNormal) 263 | } else { 264 | random::() * 2. - 1. 265 | }; 266 | 267 | picked_connection.weight = new_weight.max(-1.).min(1.); 268 | } 269 | 270 | /// Changes the bias of a random non input node 271 | fn change_bias(g: &mut Genome) { 272 | let eligible_indexes: Vec = g 273 | .nodes() 274 | .iter() 275 | .enumerate() 276 | .filter(|(_, n)| !matches!(n.kind, NodeKind::Input)) 277 | .map(|(i, _)| i) 278 | .collect(); 279 | if eligible_indexes.is_empty() { return; } 280 | let index = eligible_indexes 281 | .get(random::() % eligible_indexes.len()) 282 | .unwrap(); 283 | let picked_node = g.node_mut(*index).unwrap(); 284 | 285 | let new_bias = if random::() < 0.1 { 286 | picked_node.bias + thread_rng().sample::(StandardNormal) 287 | } else { 288 | random::() * 2. - 1.
289 | }; 290 | 291 | picked_node.bias = new_bias.max(-1.).min(1.); 292 | } 293 | 294 | /// Changes the activation function of a random non input node 295 | fn change_activation(g: &mut Genome) { 296 | let eligible_indexes: Vec = g 297 | .nodes() 298 | .iter() 299 | .enumerate() 300 | .filter(|(_, n)| !matches!(n.kind, NodeKind::Input)) 301 | .map(|(i, _)| i) 302 | .collect(); 303 | if eligible_indexes.is_empty() { return; } 304 | let index = eligible_indexes 305 | .get(random::() % eligible_indexes.len()) 306 | .unwrap(); 307 | let picked_node = g.node_mut(*index).unwrap(); 308 | 309 | picked_node.activation = random::(); 310 | } 311 | /// Changes the aggregation function of a random non input node 312 | fn change_aggregation(g: &mut Genome) { 313 | let eligible_indexes: Vec = g 314 | .nodes() 315 | .iter() 316 | .enumerate() 317 | .filter(|(_, n)| !matches!(n.kind, NodeKind::Input)) 318 | .map(|(i, _)| i) 319 | .collect(); 320 | if eligible_indexes.is_empty() { return; } 321 | let index = eligible_indexes 322 | .get(random::() % eligible_indexes.len()) 323 | .unwrap(); 324 | let picked_node = g.node_mut(*index).unwrap(); 325 | 326 | picked_node.aggregation = random(); 327 | } 328 | 329 | #[cfg(test)] 330 | mod tests { 331 | use super::*; 332 | 333 | #[test] 334 | fn add_connection_adds_missing_connection() { 335 | let mut g = Genome::new(1, 2); 336 | 337 | g.add_node(); 338 | g.add_connection(0, 3).unwrap(); 339 | g.add_connection(3, 2).unwrap(); 340 | 341 | assert!(!g.connections().iter().any(|c| c.from == 3 && c.to == 1)); 342 | add_connection(&mut g); 343 | assert!(g.connections().iter().any(|c| c.from == 3 && c.to == 1)); 344 | } 345 | 346 | #[test] 347 | fn add_connection_doesnt_add_unecessary_connections() { 348 | let mut g = Genome::new(1, 2); 349 | 350 | g.add_node(); 351 | g.add_connection(0, 3).unwrap(); 352 | g.add_connection(3, 2).unwrap(); 353 | 354 | // This will add the last missing connection 355 | assert_eq!(g.connections().len(), 4); 356 | add_connection(&mut g); 357 | assert_eq!(g.connections().len(), 5); 358 | 359 | // There should be no new connections 360 | add_connection(&mut g); 361 |
assert_eq!(g.connections().len(), 5); 362 | } 363 | 364 | #[test] 365 | fn remove_connection_doesnt_remove_last_connection_of_a_node() { 366 | let mut g = Genome::new(1, 2); 367 | assert_eq!(g.connections().iter().filter(|c| !c.disabled).count(), 2); 368 | 369 | disable_connection(&mut g); 370 | assert_eq!(g.connections().iter().filter(|c| !c.disabled).count(), 2); 371 | } 372 | 373 | #[test] 374 | fn add_node_doesnt_change_existing_connections() { 375 | let mut g = Genome::new(1, 1); 376 | let original_connections = g.connections().to_vec(); 377 | 378 | add_node(&mut g); 379 | 380 | let original_connections_not_modified = original_connections 381 | .iter() 382 | .filter(|oc| { 383 | g.connections() 384 | .iter() 385 | .any(|c| c.from == oc.from && c.to == oc.to && c.disabled == oc.disabled) 386 | }) 387 | .count(); 388 | 389 | // When adding a node, a connection is selected to be disabled and replaced with a new node and two new 390 | // connections 391 | assert_eq!( 392 | original_connections_not_modified, 393 | original_connections.len() - 1 394 | ); 395 | } 396 | 397 | #[test] 398 | fn remove_node_doesnt_mess_up_the_connections() { 399 | let mut g = Genome::new(1, 1); 400 | let connection_enabled_initially = !g.connections().first().unwrap().disabled; 401 | 402 | add_node(&mut g); 403 | let connection_disabled_after_add = g.connections().first().unwrap().disabled; 404 | 405 | remove_node(&mut g); 406 | let connection_enabled_after_remove = !g.connections().first().unwrap().disabled; 407 | 408 | assert!(connection_enabled_initially); 409 | assert!(connection_disabled_after_add); 410 | assert!(connection_enabled_after_remove); 411 | } 412 | 413 | #[test] 414 | fn change_bias_doesnt_change_input_nodes() { 415 | let mut g = Genome::new(1, 1); 416 | 417 | let input_bias = g.nodes().get(0).unwrap().bias; 418 | let output_bias = g.nodes().get(1).unwrap().bias; 419 | 420 | for _ in 0..10 { 421 | change_bias(&mut g); 422 | } 423 | 424 | let new_input_bias =
g.nodes().get(0).unwrap().bias; 425 | let new_output_bias = g.nodes().get(1).unwrap().bias; 426 | 427 | assert!((input_bias - new_input_bias).abs() < f64::EPSILON); 428 | assert!((output_bias - new_output_bias).abs() > f64::EPSILON); 429 | } 430 | 431 | #[test] 432 | fn change_activation_doesnt_change_input_nodes() { 433 | let mut g = Genome::new(1, 1); 434 | 435 | let i_activation = g.nodes().get(0).unwrap().activation.clone(); 436 | let o_activation = g.nodes().get(1).unwrap().activation.clone(); 437 | 438 | let mut new_i_activations = vec![]; 439 | let mut new_o_activations = vec![]; 440 | 441 | for _ in 0..10 { 442 | change_activation(&mut g); 443 | 444 | new_i_activations.push(g.nodes().get(0).unwrap().activation.clone()); 445 | new_o_activations.push(g.nodes().get(1).unwrap().activation.clone()); 446 | } 447 | 448 | assert!(new_i_activations.iter().all(|a| *a == i_activation)); 449 | assert!(new_o_activations.iter().any(|a| *a != o_activation)); 450 | } 451 | 452 | #[test] 453 | fn mutate_genome() { 454 | use std::collections::HashMap; 455 | use std::convert::TryFrom; 456 | use std::time; 457 | 458 | let mut times: HashMap> = HashMap::new(); 459 | let mut g = Genome::new(1, 1); 460 | 461 | let limit = 50; 462 | for i in 1..=limit { 463 | let kind: MutationKind = random(); 464 | 465 | let before = std::time::Instant::now(); 466 | mutate(&kind, &mut g); 467 | let after = std::time::Instant::now(); 468 | let duration = after.duration_since(before); 469 | 470 | // Group the measured durations by mutation kind (single map lookup) 471 | times 472 | .entry(kind.clone()) 473 | .or_default() 474 | .push(duration); 475 | 476 | if g.connections().iter().all(|c| c.disabled) { 477 | panic!("All connections disabled, happened after {:?}", kind); 478 | } 479 | 480 | println!("mutation {}/{}", i, limit); 481 | } 482 | 483 | let mut kind_average_times: Vec<(MutationKind, time::Duration)> = times 484 | .iter() 485 | .map(|(k, t)| { 486 | let sum: u128 = t.iter().map(|d|
d.as_micros()).sum(); 487 | let avg: u128 = sum.div_euclid(u128::try_from(t.len()).unwrap()); 488 | 489 | let duration = time::Duration::from_micros(u64::try_from(avg).unwrap()); 490 | 491 | (k.clone(), duration) 492 | }) 493 | .collect(); 494 | 495 | kind_average_times.sort_by_key(|(_, duration)| *duration); 496 | 497 | kind_average_times.iter().for_each(|(k, duration)| { 498 | println!("{:?} on avg took {:?}", k, duration); 499 | }); 500 | 501 | println!( 502 | "Genome had {} nodes and {} connections, of which {} were active", 503 | g.nodes().len(), 504 | g.connections().len(), 505 | g.connections().iter().filter(|c| !c.disabled).count(), 506 | ); 507 | } 508 | } 509 | -------------------------------------------------------------------------------- /core/src/genome/mod.rs: -------------------------------------------------------------------------------- 1 | use std::collections::{HashMap, HashSet, VecDeque}; 2 | use uuid::Uuid; 3 | 4 | use crate::mutations::MutationKind; 5 | use crate::node::NodeKind; 6 | pub use connection::ConnectionGene; 7 | pub use crossover::*; 8 | pub use node::NodeGene; 9 | 10 | pub mod connection; 11 | pub mod crossover; 12 | pub mod node; 13 | 14 | pub type GenomeId = Uuid; 15 | 16 | #[derive(Debug, Clone, PartialEq, Eq, Hash)] 17 | pub struct Genome { 18 | id: Uuid, 19 | inputs: usize, 20 | outputs: usize, 21 | connection_genes: Vec, 22 | node_genes: Vec, 23 | } 24 | 25 | impl Genome { 26 | pub fn new(inputs: usize, outputs: usize) -> Self { 27 | let mut node_genes = vec![]; 28 | 29 | (0..inputs).for_each(|_| node_genes.push(NodeGene::new(NodeKind::Input))); 30 | (0..outputs).for_each(|_| node_genes.push(NodeGene::new(NodeKind::Output))); 31 | 32 | let connection_genes: Vec = (0..inputs) 33 | .flat_map(|i| { 34 | (inputs..inputs + outputs) 35 | .map(|o| ConnectionGene::new(i, o)) 36 | .collect::>() 37 | }) 38 | .collect(); 39 | 40 | Genome { 41 | id: Uuid::new_v4(), 42 | inputs, 43 | outputs, 44 |
connection_genes, 45 | node_genes, 46 | } 47 | } 48 | 49 | fn empty(inputs: usize, outputs: usize) -> Self { 50 | Genome { 51 | id: Uuid::new_v4(), 52 | inputs, 53 | outputs, 54 | connection_genes: vec![], 55 | node_genes: vec![], 56 | } 57 | } 58 | 59 | pub fn id(&self) -> GenomeId { 60 | // use std::collections::hash_map::DefaultHasher; 61 | // use std::hash::{Hash, Hasher}; 62 | 63 | // let mut hasher = DefaultHasher::new(); 64 | // self.hash(&mut hasher); 65 | 66 | // hasher.finish() 67 | self.id 68 | } 69 | 70 | pub fn input_count(&self) -> usize { 71 | self.inputs 72 | } 73 | 74 | pub fn output_count(&self) -> usize { 75 | self.outputs 76 | } 77 | 78 | pub fn nodes(&self) -> &[NodeGene] { 79 | &self.node_genes 80 | } 81 | 82 | pub fn node_mut(&mut self, index: usize) -> Option<&mut NodeGene> { 83 | self.node_genes.get_mut(index) 84 | } 85 | 86 | pub fn connections(&self) -> &[ConnectionGene] { 87 | &self.connection_genes 88 | } 89 | 90 | pub fn connection_mut(&mut self, index: usize) -> Option<&mut ConnectionGene> { 91 | self.connection_genes.get_mut(index) 92 | } 93 | 94 | fn calculate_node_order( 95 | &self, 96 | additional_connections: Option>, 97 | ) -> Option> { 98 | let mut connections: Vec = self 99 | .connection_genes 100 | .iter() 101 | .filter(|c| !c.disabled) 102 | .cloned() 103 | .collect(); 104 | 105 | if let Some(mut conns) = additional_connections { 106 | connections.append(&mut conns); 107 | } 108 | 109 | if connections.is_empty() { 110 | return None; 111 | } 112 | 113 | let mut visited: Vec = vec![]; 114 | 115 | // Input nodes are automatically visited as they get their values from inputs 116 | self.node_genes 117 | .iter() 118 | .enumerate() 119 | .filter(|(_, n)| matches!(n.kind, NodeKind::Input)) 120 | .for_each(|(i, _)| { 121 | visited.push(i); 122 | }); 123 | 124 | let mut newly_visited = 1; 125 | while newly_visited != 0 { 126 | newly_visited = 0; 127 | 128 | let mut nodes_to_visit: Vec = self 129 | .node_genes 130 | .iter() 131 | 
.enumerate() 132 | .filter(|(i, _)| { 133 | // The node is not visited but all prerequisite nodes are visited 134 | !visited.contains(i) 135 | && connections 136 | .iter() 137 | .filter(|c| c.to == *i) 138 | .map(|c| c.from) 139 | .all(|node_index| visited.contains(&node_index)) 140 | }) 141 | .map(|(i, _)| i) 142 | .collect(); 143 | 144 | newly_visited += nodes_to_visit.len(); 145 | visited.append(&mut nodes_to_visit); 146 | } 147 | 148 | if visited.len() != self.node_genes.len() { 149 | return None; 150 | } 151 | 152 | Some(visited) 153 | } 154 | 155 | pub fn node_order(&self) -> Option> { 156 | self.calculate_node_order(None) 157 | } 158 | 159 | pub fn node_order_with( 160 | &self, 161 | additional_connections: Vec, 162 | ) -> Option> { 163 | self.calculate_node_order(Some(additional_connections)) 164 | } 165 | 166 | fn calculate_node_distance_from_inputs(&self) -> HashMap { 167 | // Inputs are immediately added with distance of 0 168 | let mut distances: HashMap = self 169 | .nodes() 170 | .iter() 171 | .enumerate() 172 | .filter(|(_, n)| matches!(n.kind, NodeKind::Input)) 173 | .map(|(i, _)| (i, 0)) 174 | .collect(); 175 | 176 | // Inputs need to be visited first 177 | let mut to_visit: VecDeque = self 178 | .nodes() 179 | .iter() 180 | .enumerate() 181 | .filter(|(_, n)| matches!(n.kind, NodeKind::Input)) 182 | .map(|(i, _)| i) 183 | .collect(); 184 | 185 | while let Some(i) = to_visit.pop_front() { 186 | let source_distance = *distances.get(&i).unwrap_or(&0); 187 | 188 | self.connections() 189 | .iter() 190 | .filter(|c| c.from == i) 191 | .for_each(|c| { 192 | let node_index = c.to; 193 | let potential_distance = source_distance + 1; 194 | 195 | let maybe_change = if let Some(distance) = distances.get(&node_index) { 196 | if potential_distance > *distance { 197 | to_visit.push_back(node_index); 198 | Some(potential_distance) 199 | } else { 200 | None 201 | } 202 | } else { 203 | to_visit.push_back(node_index); 204 | Some(potential_distance) 205 | }; 206 | 
207 | if let Some(new_distance) = maybe_change { 208 | distances.insert(node_index, new_distance); 209 | } 210 | }); 211 | } 212 | 213 | distances 214 | } 215 | 216 | fn is_projecting_directly(&self, source: usize, target: usize) -> bool { 217 | self.connection_genes 218 | .iter() 219 | .filter(|c| !c.disabled) 220 | .any(|c| c.from == source && c.to == target) 221 | } 222 | 223 | fn is_projected_directly(&self, target: usize, source: usize) -> bool { 224 | self.is_projecting_directly(source, target) 225 | } 226 | 227 | fn is_projecting(&self, source: usize, target: usize) -> bool { 228 | let mut visited_nodes: HashSet = HashSet::new(); 229 | let mut nodes_to_visit: VecDeque = VecDeque::new(); 230 | 231 | nodes_to_visit.push_back(source); 232 | 233 | let mut projecting = false; 234 | while let Some(i) = nodes_to_visit.pop_front() { 235 | visited_nodes.insert(i); 236 | if self.is_projecting_directly(i, target) { 237 | projecting = true; 238 | break; 239 | } else { 240 | self.connection_genes 241 | .iter() 242 | .filter(|c| c.from == i && !c.disabled && !visited_nodes.contains(&i)) 243 | .for_each(|c| nodes_to_visit.push_back(c.to)); 244 | } 245 | } 246 | 247 | projecting 248 | } 249 | 250 | fn is_projected(&self, target: usize, source: usize) -> bool { 251 | self.is_projecting(source, target) 252 | } 253 | 254 | pub fn can_connect(&self, from: usize, to: usize) -> bool { 255 | let from_node = self.node_genes.get(from).unwrap(); 256 | let to_node = self.node_genes.get(to).unwrap(); 257 | 258 | let is_from_output = matches!(from_node.kind, NodeKind::Output); 259 | let is_to_input = matches!(to_node.kind, NodeKind::Input); 260 | 261 | let distances = self.calculate_node_distance_from_inputs(); 262 | let from_distance = distances.get(&from).unwrap(); 263 | let to_distance = distances.get(&to).unwrap_or(&usize::MAX); 264 | let is_recurrent = from_distance > to_distance; 265 | 266 | if is_from_output || is_to_input || is_recurrent { 267 | false 268 | } else { 269 | 
!self.is_projecting(from, to) 270 | } 271 | } 272 | 273 | pub fn add_connection(&mut self, from: usize, to: usize) -> Result { 274 | if !self.can_connect(from, to) { 275 | return Err(()); 276 | } 277 | 278 | let maybe_connection = self 279 | .connection_genes 280 | .iter_mut() 281 | .find(|c| c.from == from && c.to == to); 282 | 283 | if let Some(mut conn) = maybe_connection { 284 | conn.disabled = false; 285 | } else { 286 | self.connection_genes.push(ConnectionGene::new(from, to)); 287 | } 288 | 289 | Ok(self.connection_genes.len() - 1) 290 | } 291 | 292 | pub fn add_many_connections(&mut self, params: &[(usize, usize)]) -> Vec> { 293 | let results = params 294 | .iter() 295 | .map(|(from, to)| self.add_connection(*from, *to)) 296 | .collect(); 297 | 298 | results 299 | } 300 | 301 | pub fn disable_connection(&mut self, index: usize) { 302 | self.connection_genes.get_mut(index).unwrap().disabled = true; 303 | } 304 | 305 | pub fn disable_many_connections(&mut self, indexes: &[usize]) { 306 | indexes.iter().for_each(|i| self.disable_connection(*i)); 307 | } 308 | 309 | /// Add a new hidden node to the genome 310 | pub fn add_node(&mut self) -> usize { 311 | let index = self.node_genes.len(); 312 | self.node_genes.push(NodeGene::new(NodeKind::Hidden)); 313 | 314 | index 315 | } 316 | 317 | pub fn mutate(&mut self, kind: &MutationKind) { 318 | crate::mutations::mutate(kind, self); 319 | } 320 | } 321 | 322 | #[cfg(test)] 323 | mod tests { 324 | use super::*; 325 | 326 | #[test] 327 | fn initialize() { 328 | Genome::new(2, 2); 329 | } 330 | 331 | #[test] 332 | fn add_node_does_not_change_connections() { 333 | let mut g = Genome::new(1, 2); 334 | 335 | g.add_node(); 336 | 337 | let first_connection = g.connection_genes.get(0).unwrap(); 338 | assert_eq!(first_connection.from, 0); 339 | assert_eq!(first_connection.to, 1); 340 | 341 | let second_connection = g.connection_genes.get(1).unwrap(); 342 | assert_eq!(second_connection.from, 0); 343 | 
assert_eq!(second_connection.to, 2); 344 | } 345 | 346 | #[test] 347 | fn is_projecting_directly() { 348 | let g = Genome::new(2, 2); 349 | 350 | assert!(g.is_projecting_directly(0, 2)); 351 | assert!(g.is_projecting_directly(0, 3)); 352 | assert!(g.is_projecting_directly(1, 2)); 353 | assert!(g.is_projecting_directly(1, 3)); 354 | 355 | assert!(!g.is_projecting_directly(2, 0)); 356 | assert!(!g.is_projecting_directly(3, 0)); 357 | assert!(!g.is_projecting_directly(2, 1)); 358 | assert!(!g.is_projecting_directly(3, 1)); 359 | } 360 | 361 | #[test] 362 | fn is_projected_directly() { 363 | let g = Genome::new(2, 2); 364 | 365 | assert!(g.is_projected_directly(2, 0)); 366 | assert!(g.is_projected_directly(3, 0)); 367 | assert!(g.is_projected_directly(2, 1)); 368 | assert!(g.is_projected_directly(3, 1)); 369 | 370 | assert!(!g.is_projected_directly(0, 2)); 371 | assert!(!g.is_projected_directly(0, 3)); 372 | assert!(!g.is_projected_directly(1, 2)); 373 | assert!(!g.is_projected_directly(1, 3)); 374 | } 375 | 376 | // TODO rewrite the tests or both the implementation and tests 377 | 378 | // #[test] 379 | // fn is_projecting() { 380 | // let mut g = Genome::empty(1, 1); 381 | 382 | // g.node_genes.push(NodeGene::new(NodeKind::Input)); 383 | // g.node_genes.push(NodeGene::new(NodeKind::Hidden)); 384 | // g.node_genes.push(NodeGene::new(NodeKind::Hidden)); 385 | // g.node_genes.push(NodeGene::new(NodeKind::Output)); 386 | 387 | // g.connection_genes.push(ConnectionGene::new(0, 1)); 388 | // g.connection_genes.push(ConnectionGene::new(1, 2)); 389 | // g.connection_genes.push(ConnectionGene::new(2, 3)); 390 | 391 | // assert!(g.is_projecting(0, 3)); 392 | // assert!(g.is_projecting(1, 3)); 393 | // assert!(g.is_projecting(2, 3)); 394 | 395 | // assert!(!g.is_projecting(3, 0)); 396 | // assert!(!g.is_projecting(3, 1)); 397 | // assert!(!g.is_projecting(3, 2)); 398 | // } 399 | 400 | // #[test] 401 | // fn is_projected() { 402 | // let mut g = Genome::empty(1, 1); 403 | 404 | 
// g.node_genes.push(NodeGene::new(NodeKind::Input)); 405 | // g.node_genes.push(NodeGene::new(NodeKind::Hidden)); 406 | // g.node_genes.push(NodeGene::new(NodeKind::Hidden)); 407 | // g.node_genes.push(NodeGene::new(NodeKind::Output)); 408 | 409 | // g.connection_genes.push(ConnectionGene::new(0, 1)); 410 | // g.connection_genes.push(ConnectionGene::new(1, 2)); 411 | // g.connection_genes.push(ConnectionGene::new(2, 3)); 412 | 413 | // assert!(g.is_projected(3, 0)); 414 | // assert!(g.is_projected(3, 1)); 415 | // assert!(g.is_projected(3, 2)); 416 | 417 | // assert!(!g.is_projected(0, 3)); 418 | // assert!(!g.is_projected(1, 3)); 419 | // assert!(!g.is_projected(2, 3)); 420 | // } 421 | 422 | #[test] 423 | fn can_connect() { 424 | let mut g = Genome::empty(1, 1); 425 | 426 | g.node_genes.push(NodeGene::new(NodeKind::Input)); 427 | g.node_genes.push(NodeGene::new(NodeKind::Hidden)); 428 | g.node_genes.push(NodeGene::new(NodeKind::Hidden)); 429 | g.node_genes.push(NodeGene::new(NodeKind::Hidden)); 430 | g.node_genes.push(NodeGene::new(NodeKind::Output)); 431 | 432 | g.connection_genes.push(ConnectionGene::new(0, 1)); 433 | g.connection_genes.push(ConnectionGene::new(0, 2)); 434 | g.connection_genes.push(ConnectionGene::new(1, 3)); 435 | g.connection_genes.push(ConnectionGene::new(2, 3)); 436 | g.connection_genes.push(ConnectionGene::new(3, 4)); 437 | 438 | assert!(g.can_connect(1, 2)); 439 | assert!(g.can_connect(2, 1)); 440 | 441 | assert!(!g.can_connect(3, 1)); 442 | assert!(!g.can_connect(3, 2)); 443 | assert!(!g.can_connect(4, 1)); 444 | assert!(!g.can_connect(4, 2)); 445 | } 446 | 447 | #[test] 448 | fn get_node_order() { 449 | let mut g = Genome::empty(2, 1); 450 | 451 | g.node_genes.push(NodeGene::new(NodeKind::Input)); 452 | g.node_genes.push(NodeGene::new(NodeKind::Input)); 453 | g.node_genes.push(NodeGene::new(NodeKind::Output)); 454 | g.node_genes.push(NodeGene::new(NodeKind::Hidden)); 455 | g.node_genes.push(NodeGene::new(NodeKind::Hidden)); 456 | 
g.node_genes.push(NodeGene::new(NodeKind::Hidden)); 457 | 458 | g.add_connection(0, 2).unwrap(); 459 | g.add_connection(1, 3).unwrap(); 460 | g.add_connection(1, 4).unwrap(); 461 | g.add_connection(1, 5).unwrap(); 462 | g.add_connection(3, 2).unwrap(); 463 | g.add_connection(4, 3).unwrap(); 464 | g.add_connection(5, 4).unwrap(); 465 | 466 | assert!(g.node_order().is_some()); 467 | assert!(g.node_order_with(vec![ConnectionGene::new(3, 5)]).is_none()); 468 | } 469 | 470 | #[test] 471 | fn no_recurrent_connections() { 472 | let mut g = Genome::empty(2, 1); 473 | 474 | g.node_genes.push(NodeGene::new(NodeKind::Input)); 475 | g.node_genes.push(NodeGene::new(NodeKind::Input)); 476 | g.node_genes.push(NodeGene::new(NodeKind::Output)); 477 | g.node_genes.push(NodeGene::new(NodeKind::Hidden)); 478 | g.node_genes.push(NodeGene::new(NodeKind::Hidden)); 479 | g.node_genes.push(NodeGene::new(NodeKind::Hidden)); 480 | 481 | g.add_connection(0, 2).unwrap(); 482 | g.add_connection(1, 3).unwrap(); 483 | g.add_connection(1, 4).unwrap(); 484 | g.add_connection(1, 5).unwrap(); 485 | g.add_connection(3, 2).unwrap(); 486 | g.add_connection(4, 3).unwrap(); 487 | g.add_connection(5, 4).unwrap(); 488 | 489 | assert!(g.add_connection(3, 5).is_err()); 490 | } 491 | 492 | #[test] 493 | fn node_distances_simple() { 494 | let g = Genome::new(2, 1); 495 | 496 | dbg!(g.calculate_node_distance_from_inputs()); 497 | } 498 | 499 | #[test] 500 | fn node_distances_block_recurrent_connections() { 501 | let mut g = Genome::empty(2, 1); 502 | 503 | g.node_genes.push(NodeGene::new(NodeKind::Input)); 504 | g.node_genes.push(NodeGene::new(NodeKind::Input)); 505 | g.node_genes.push(NodeGene::new(NodeKind::Output)); 506 | g.node_genes.push(NodeGene::new(NodeKind::Hidden)); 507 | g.node_genes.push(NodeGene::new(NodeKind::Hidden)); 508 | g.node_genes.push(NodeGene::new(NodeKind::Hidden)); 509 | 510 | g.add_connection(0, 3).unwrap(); 511 | g.add_connection(1, 3).unwrap(); 512 | g.add_connection(3, 4).unwrap(); 
513 | g.add_connection(4, 5).unwrap(); 514 | g.add_connection(4, 2).unwrap(); 515 | g.add_connection(5, 2).unwrap(); 516 | 517 | assert!(g.add_connection(5, 3).is_err()); 518 | } 519 | } 520 | --------------------------------------------------------------------------------