├── Cargo.lock ├── Cargo.toml ├── README.md ├── screenshot.png └── src ├── configs.rs ├── game.rs ├── lib.rs ├── main.rs ├── nn.rs ├── pop.rs ├── sim.rs ├── stream.rs ├── utils.rs └── viz.rs /Cargo.lock: -------------------------------------------------------------------------------- 1 | # This file is automatically @generated by Cargo. 2 | # It is not intended for manual editing. 3 | version = 3 4 | 5 | [[package]] 6 | name = "adler" 7 | version = "1.0.2" 8 | source = "registry+https://github.com/rust-lang/crates.io-index" 9 | checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" 10 | 11 | [[package]] 12 | name = "ahash" 13 | version = "0.8.11" 14 | source = "registry+https://github.com/rust-lang/crates.io-index" 15 | checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" 16 | dependencies = [ 17 | "cfg-if", 18 | "once_cell", 19 | "version_check", 20 | "zerocopy", 21 | ] 22 | 23 | [[package]] 24 | name = "autocfg" 25 | version = "1.2.0" 26 | source = "registry+https://github.com/rust-lang/crates.io-index" 27 | checksum = "f1fdabc7756949593fe60f30ec81974b613357de856987752631dea1e3394c80" 28 | 29 | [[package]] 30 | name = "bitflags" 31 | version = "1.3.2" 32 | source = "registry+https://github.com/rust-lang/crates.io-index" 33 | checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" 34 | 35 | [[package]] 36 | name = "bumpalo" 37 | version = "3.15.4" 38 | source = "registry+https://github.com/rust-lang/crates.io-index" 39 | checksum = "7ff69b9dd49fd426c69a0db9fc04dd934cdb6645ff000864d98f7e2af8830eaa" 40 | 41 | [[package]] 42 | name = "bytemuck" 43 | version = "1.15.0" 44 | source = "registry+https://github.com/rust-lang/crates.io-index" 45 | checksum = "5d6d68c57235a3a081186990eca2867354726650f42f7516ca50c28d6281fd15" 46 | 47 | [[package]] 48 | name = "byteorder" 49 | version = "1.5.0" 50 | source = "registry+https://github.com/rust-lang/crates.io-index" 51 | checksum = 
"1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" 52 | 53 | [[package]] 54 | name = "cfg-if" 55 | version = "1.0.0" 56 | source = "registry+https://github.com/rust-lang/crates.io-index" 57 | checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" 58 | 59 | [[package]] 60 | name = "color_quant" 61 | version = "1.1.0" 62 | source = "registry+https://github.com/rust-lang/crates.io-index" 63 | checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b" 64 | 65 | [[package]] 66 | name = "crc32fast" 67 | version = "1.4.0" 68 | source = "registry+https://github.com/rust-lang/crates.io-index" 69 | checksum = "b3855a8a784b474f333699ef2bbca9db2c4a1f6d9088a90a2d25b1eb53111eaa" 70 | dependencies = [ 71 | "cfg-if", 72 | ] 73 | 74 | [[package]] 75 | name = "fdeflate" 76 | version = "0.3.4" 77 | source = "registry+https://github.com/rust-lang/crates.io-index" 78 | checksum = "4f9bfee30e4dedf0ab8b422f03af778d9612b63f502710fc500a334ebe2de645" 79 | dependencies = [ 80 | "simd-adler32", 81 | ] 82 | 83 | [[package]] 84 | name = "flate2" 85 | version = "1.0.28" 86 | source = "registry+https://github.com/rust-lang/crates.io-index" 87 | checksum = "46303f565772937ffe1d394a4fac6f411c6013172fadde9dcdb1e147a086940e" 88 | dependencies = [ 89 | "crc32fast", 90 | "miniz_oxide", 91 | ] 92 | 93 | [[package]] 94 | name = "fontdue" 95 | version = "0.7.3" 96 | source = "registry+https://github.com/rust-lang/crates.io-index" 97 | checksum = "0793f5137567643cf65ea42043a538804ff0fbf288649e2141442b602d81f9bc" 98 | dependencies = [ 99 | "hashbrown", 100 | "ttf-parser", 101 | ] 102 | 103 | [[package]] 104 | name = "getrandom" 105 | version = "0.2.12" 106 | source = "registry+https://github.com/rust-lang/crates.io-index" 107 | checksum = "190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5" 108 | dependencies = [ 109 | "cfg-if", 110 | "libc", 111 | "wasi", 112 | ] 113 | 114 | [[package]] 115 | name = "glam" 116 | version = "0.21.3" 117 
| source = "registry+https://github.com/rust-lang/crates.io-index" 118 | checksum = "518faa5064866338b013ff9b2350dc318e14cc4fcd6cb8206d7e7c9886c98815" 119 | 120 | [[package]] 121 | name = "hashbrown" 122 | version = "0.13.2" 123 | source = "registry+https://github.com/rust-lang/crates.io-index" 124 | checksum = "43a3c133739dddd0d2990f9a4bdf8eb4b21ef50e4851ca85ab661199821d510e" 125 | dependencies = [ 126 | "ahash", 127 | ] 128 | 129 | [[package]] 130 | name = "image" 131 | version = "0.24.9" 132 | source = "registry+https://github.com/rust-lang/crates.io-index" 133 | checksum = "5690139d2f55868e080017335e4b94cb7414274c74f1669c84fb5feba2c9f69d" 134 | dependencies = [ 135 | "bytemuck", 136 | "byteorder", 137 | "color_quant", 138 | "num-traits", 139 | "png", 140 | ] 141 | 142 | [[package]] 143 | name = "libc" 144 | version = "0.2.153" 145 | source = "registry+https://github.com/rust-lang/crates.io-index" 146 | checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" 147 | 148 | [[package]] 149 | name = "macroquad" 150 | version = "0.4.5" 151 | source = "registry+https://github.com/rust-lang/crates.io-index" 152 | checksum = "002647f9301eb8827145a6ae71f1fc0b0441d3a3e0648146e3c7de926e78c37d" 153 | dependencies = [ 154 | "bumpalo", 155 | "fontdue", 156 | "glam", 157 | "image", 158 | "macroquad_macro", 159 | "miniquad", 160 | "quad-rand", 161 | "slotmap", 162 | ] 163 | 164 | [[package]] 165 | name = "macroquad_macro" 166 | version = "0.1.7" 167 | source = "registry+https://github.com/rust-lang/crates.io-index" 168 | checksum = "f5cecfede1e530599c8686f7f2d609489101d3d63741a6dc423afc997ce3fcc8" 169 | 170 | [[package]] 171 | name = "malloc_buf" 172 | version = "0.0.6" 173 | source = "registry+https://github.com/rust-lang/crates.io-index" 174 | checksum = "62bb907fe88d54d8d9ce32a3cceab4218ed2f6b7d35617cafe9adf84e43919cb" 175 | dependencies = [ 176 | "libc", 177 | ] 178 | 179 | [[package]] 180 | name = "miniquad" 181 | version = "0.4.0" 182 | source = 
"registry+https://github.com/rust-lang/crates.io-index" 183 | checksum = "91e9c578ad261f84512751bfdd9919c762ecd6103d06051fbdaede35136e1988" 184 | dependencies = [ 185 | "libc", 186 | "ndk-sys", 187 | "objc", 188 | "winapi", 189 | ] 190 | 191 | [[package]] 192 | name = "miniz_oxide" 193 | version = "0.7.2" 194 | source = "registry+https://github.com/rust-lang/crates.io-index" 195 | checksum = "9d811f3e15f28568be3407c8e7fdb6514c1cda3cb30683f15b6a1a1dc4ea14a7" 196 | dependencies = [ 197 | "adler", 198 | "simd-adler32", 199 | ] 200 | 201 | [[package]] 202 | name = "ndk-sys" 203 | version = "0.2.2" 204 | source = "registry+https://github.com/rust-lang/crates.io-index" 205 | checksum = "e1bcdd74c20ad5d95aacd60ef9ba40fdf77f767051040541df557b7a9b2a2121" 206 | 207 | [[package]] 208 | name = "num-traits" 209 | version = "0.2.18" 210 | source = "registry+https://github.com/rust-lang/crates.io-index" 211 | checksum = "da0df0e5185db44f69b44f26786fe401b6c293d1907744beaa7fa62b2e5a517a" 212 | dependencies = [ 213 | "autocfg", 214 | ] 215 | 216 | [[package]] 217 | name = "objc" 218 | version = "0.2.7" 219 | source = "registry+https://github.com/rust-lang/crates.io-index" 220 | checksum = "915b1b472bc21c53464d6c8461c9d3af805ba1ef837e1cac254428f4a77177b1" 221 | dependencies = [ 222 | "malloc_buf", 223 | ] 224 | 225 | [[package]] 226 | name = "once_cell" 227 | version = "1.19.0" 228 | source = "registry+https://github.com/rust-lang/crates.io-index" 229 | checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" 230 | 231 | [[package]] 232 | name = "png" 233 | version = "0.17.13" 234 | source = "registry+https://github.com/rust-lang/crates.io-index" 235 | checksum = "06e4b0d3d1312775e782c86c91a111aa1f910cbb65e1337f9975b5f9a554b5e1" 236 | dependencies = [ 237 | "bitflags", 238 | "crc32fast", 239 | "fdeflate", 240 | "flate2", 241 | "miniz_oxide", 242 | ] 243 | 244 | [[package]] 245 | name = "ppv-lite86" 246 | version = "0.2.17" 247 | source = 
"registry+https://github.com/rust-lang/crates.io-index" 248 | checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" 249 | 250 | [[package]] 251 | name = "proc-macro2" 252 | version = "1.0.79" 253 | source = "registry+https://github.com/rust-lang/crates.io-index" 254 | checksum = "e835ff2298f5721608eb1a980ecaee1aef2c132bf95ecc026a11b7bf3c01c02e" 255 | dependencies = [ 256 | "unicode-ident", 257 | ] 258 | 259 | [[package]] 260 | name = "quad-rand" 261 | version = "0.2.1" 262 | source = "registry+https://github.com/rust-lang/crates.io-index" 263 | checksum = "658fa1faf7a4cc5f057c9ee5ef560f717ad9d8dc66d975267f709624d6e1ab88" 264 | 265 | [[package]] 266 | name = "quote" 267 | version = "1.0.35" 268 | source = "registry+https://github.com/rust-lang/crates.io-index" 269 | checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef" 270 | dependencies = [ 271 | "proc-macro2", 272 | ] 273 | 274 | [[package]] 275 | name = "rand" 276 | version = "0.8.5" 277 | source = "registry+https://github.com/rust-lang/crates.io-index" 278 | checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" 279 | dependencies = [ 280 | "libc", 281 | "rand_chacha", 282 | "rand_core", 283 | ] 284 | 285 | [[package]] 286 | name = "rand_chacha" 287 | version = "0.3.1" 288 | source = "registry+https://github.com/rust-lang/crates.io-index" 289 | checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" 290 | dependencies = [ 291 | "ppv-lite86", 292 | "rand_core", 293 | ] 294 | 295 | [[package]] 296 | name = "rand_core" 297 | version = "0.6.4" 298 | source = "registry+https://github.com/rust-lang/crates.io-index" 299 | checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" 300 | dependencies = [ 301 | "getrandom", 302 | ] 303 | 304 | [[package]] 305 | name = "simd-adler32" 306 | version = "0.3.7" 307 | source = "registry+https://github.com/rust-lang/crates.io-index" 308 | checksum = 
"d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe" 309 | 310 | [[package]] 311 | name = "slotmap" 312 | version = "1.0.7" 313 | source = "registry+https://github.com/rust-lang/crates.io-index" 314 | checksum = "dbff4acf519f630b3a3ddcfaea6c06b42174d9a44bc70c620e9ed1649d58b82a" 315 | dependencies = [ 316 | "version_check", 317 | ] 318 | 319 | [[package]] 320 | name = "snake" 321 | version = "0.1.0" 322 | dependencies = [ 323 | "macroquad", 324 | "rand", 325 | ] 326 | 327 | [[package]] 328 | name = "syn" 329 | version = "2.0.58" 330 | source = "registry+https://github.com/rust-lang/crates.io-index" 331 | checksum = "44cfb93f38070beee36b3fef7d4f5a16f27751d94b187b666a5cc5e9b0d30687" 332 | dependencies = [ 333 | "proc-macro2", 334 | "quote", 335 | "unicode-ident", 336 | ] 337 | 338 | [[package]] 339 | name = "ttf-parser" 340 | version = "0.15.2" 341 | source = "registry+https://github.com/rust-lang/crates.io-index" 342 | checksum = "7b3e06c9b9d80ed6b745c7159c40b311ad2916abb34a49e9be2653b90db0d8dd" 343 | 344 | [[package]] 345 | name = "unicode-ident" 346 | version = "1.0.12" 347 | source = "registry+https://github.com/rust-lang/crates.io-index" 348 | checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" 349 | 350 | [[package]] 351 | name = "version_check" 352 | version = "0.9.4" 353 | source = "registry+https://github.com/rust-lang/crates.io-index" 354 | checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" 355 | 356 | [[package]] 357 | name = "wasi" 358 | version = "0.11.0+wasi-snapshot-preview1" 359 | source = "registry+https://github.com/rust-lang/crates.io-index" 360 | checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" 361 | 362 | [[package]] 363 | name = "winapi" 364 | version = "0.3.9" 365 | source = "registry+https://github.com/rust-lang/crates.io-index" 366 | checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" 367 | dependencies = [ 368 | 
"winapi-i686-pc-windows-gnu", 369 | "winapi-x86_64-pc-windows-gnu", 370 | ] 371 | 372 | [[package]] 373 | name = "winapi-i686-pc-windows-gnu" 374 | version = "0.4.0" 375 | source = "registry+https://github.com/rust-lang/crates.io-index" 376 | checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" 377 | 378 | [[package]] 379 | name = "winapi-x86_64-pc-windows-gnu" 380 | version = "0.4.0" 381 | source = "registry+https://github.com/rust-lang/crates.io-index" 382 | checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" 383 | 384 | [[package]] 385 | name = "zerocopy" 386 | version = "0.7.32" 387 | source = "registry+https://github.com/rust-lang/crates.io-index" 388 | checksum = "74d4d3961e53fa4c9a25a8637fc2bfaf2595b3d3ae34875568a5cf64787716be" 389 | dependencies = [ 390 | "zerocopy-derive", 391 | ] 392 | 393 | [[package]] 394 | name = "zerocopy-derive" 395 | version = "0.7.32" 396 | source = "registry+https://github.com/rust-lang/crates.io-index" 397 | checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" 398 | dependencies = [ 399 | "proc-macro2", 400 | "quote", 401 | "syn", 402 | ] 403 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "snake" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 7 | 8 | [dependencies] 9 | macroquad = "0.4.5" 10 | rand = "0.8.5" 11 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AI learns to play Snake! 
2 | A neural network learns to play Snake 3 | 4 | Built with [Rust](https://www.rust-lang.org/) and the [Macroquad](https://github.com/not-fl3/macroquad) game engine 5 | 6 | ![screenshot](/screenshot.png) 7 | 8 | # Explanation & Timelapse Video 9 | [![youtube](https://img.youtube.com/vi/YPWy-3CTB-I/0.jpg)](https://youtu.be/YPWy-3CTB-I) 10 | 11 | # Controls 12 | - `Tab` - Enable/Disable visualization (disabling it also turns slow mode off) 13 | - `Space` - Toggle slow mode; `Esc` - Quit 14 | 15 | # Glossary 16 | - **Neuro-evolution**: A subfield of artificial intelligence and evolutionary computation that uses evolutionary algorithms to evolve artificial neural networks. 17 | - **Genetic Algorithm (GA)**: An optimization algorithm inspired by the process of natural selection and genetics, used to evolve populations of solutions to optimization problems. 18 | - **Population**: A collection of individuals (neural networks) that are subject to evolution in a neuro-evolutionary algorithm. 19 | - **Island Model**: A technique in neuro-evolution where multiple separate populations (islands; in the code these are called streams) of individuals evolve independently, periodically exchanging individuals between islands to maintain diversity and promote exploration. 20 | - **Mutation**: The process of introducing random changes to the genetic material (weights, biases, or network structure) of individuals in the population. Mutation helps introduce new variations and explore the solution space. 21 | 22 | # Overview 23 | ### Setup 24 | - Every snake has a neural network that acts as its brain. 25 | - The snake can see in 4 directions. In each direction it can detect food, the wall and its own body. Total number of inputs = `4 * 3 = 12` 26 | - These 12 values are fed as input to the neural network, which produces 4 values, one per action: left, right, bottom and top. The action with the highest value is taken (a move straight back along the current axis is ignored). 27 | - Every generation consists of several streams (islands) of snakes, e.g. 5 streams of 1000 snakes each (see `NUM_STREAMS` and `NUM_GAMES_PER_STREAM` in `src/configs.rs`).
The snakes in each stream evolve independently of the snakes from other streams. 28 | - Occasionally, the best-performing snakes from one stream are injected into streams that are stuck at a local maximum. This technique is called "Island Rejuvenation". 29 | ### Algorithm 30 | - The simulation begins at `Generation 0` with the configured number of streams of games; the individuals in each of these streams have randomly generated neural networks. 31 | - On each step, we update every game, i.e. pass the vision inputs to the neural network and have it decide on an action to take. 32 | - When the snake collides with a wall or with itself, the game is flagged as complete. 33 | - Additionally, games are marked complete when the snake doesn't eat for a certain number of steps. This prevents snakes from looping indefinitely. 34 | - We update each game in a generation until all the games are complete. 35 | - At the end of each generation, each snake in a stream is ranked based on how it performed. 36 | - Based on this ranking, parents are chosen for the next batch of snakes. The snake at rank 1 is more likely to become a parent than the snake at rank 10. 37 | - Here's an example of how the population is distributed in every generation: 38 | 1. The top 1% of the snakes are retained for the next generation 39 | 2. 50% of the population is newly generated using 2 snakes from the previous generation as their parents 40 | 3. 20% of the snakes are freshly generated with no connection to past generations 41 | 4. The remaining 29% of the population are mutations of the current best-performing snakes 42 | - Once we have a new population, we start a new generation. These steps repeat until the simulation is closed manually. 43 | - Over time, these steps let the snakes fine-tune their strategies, which in turn leads to longer snakes.
44 | 45 | # Usage 46 | - Clone the repo 47 | ```bash 48 | git clone git@github.com:bones-ai/rust-snake-ai.git 49 | cd rust-snake-ai 50 | ``` 51 | - Run the simulation 52 | ```bash 53 | cargo run --release 54 | ``` 55 | 56 | ## Configurations 57 | - The project config file is located at `src/configs.rs` 58 | - Disable `VIZ_DARK_THEME` changes the theming 59 | - The streams feature is still experimental. A single stream with 1000 snakes will yield quick results. 60 | -------------------------------------------------------------------------------- /screenshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bones-ai/rust-snake-ai/1d097b88fe3da51c7fcc3b9a2ab08278824fc456/screenshot.png -------------------------------------------------------------------------------- /src/configs.rs: -------------------------------------------------------------------------------- 1 | use macroquad::prelude::*; 2 | 3 | // Game 4 | pub const GRID_W: i32 = 25; 5 | pub const GRID_H: i32 = 25; 6 | 7 | // Sim 8 | pub const NUM_GAMES_PER_STREAM: usize = 1000; 9 | pub const NUM_STREAMS: usize = 1; 10 | pub const NUM_SIM_STEPS: usize = 100; 11 | pub const STREAM_REJUVENATION_PERCENT: f32 = 0.1; 12 | pub const STREAM_LOCAL_MAX_WAIT_SECS: f32 = 90.0; 13 | pub const SIM_SLEEP_MILLIS: u64 = 50; 14 | 15 | // Pop 16 | pub const POP_NUM_RETAINED: f32 = 0.01; 17 | pub const POP_NUM_CHILDREN: f32 = 0.5; 18 | pub const POP_NUM_RANDOM: f32 = 0.2; 19 | pub const POP_NUM_RETAINED_MUTATED: f32 = 0.29; 20 | 21 | // Viz 22 | pub const VIZ_GRID_W: i32 = 5; 23 | pub const VIZ_GRID_H: i32 = 4; 24 | pub const VIZ_DARK_THEME: bool = true; 25 | 26 | // NN 27 | pub const BRAIN_MUTATION_RATE: f32 = 0.1; 28 | pub const BRAIN_MUTATION_VARIATION: f32 = 0.1; 29 | pub const INP_LAYER_SIZE: usize = 12; 30 | pub const HIDDEN_LAYER_SIZE: usize = 8; 31 | pub const OUTPUT_LAYER_SIZE: usize = 4; 32 | 
-------------------------------------------------------------------------------- /src/game.rs: -------------------------------------------------------------------------------- 1 | //! Snake Game Logic 2 | //! Snake Actions from a Neural Network 3 | 4 | use crate::nn::Net; 5 | use crate::*; 6 | 7 | #[derive(Clone)] 8 | pub struct Game { 9 | pub head: Point, 10 | pub body: Vec, 11 | pub food: Point, 12 | pub dir: FourDirs, 13 | pub brain: Net, 14 | 15 | pub is_complete: bool, 16 | no_food_steps: usize, 17 | num_steps: usize, 18 | } 19 | 20 | impl Game { 21 | pub fn new() -> Self { 22 | let mut body = Vec::new(); 23 | let head = Point::new(GRID_W / 2, GRID_H / 2); 24 | body.push(head.clone()); 25 | 26 | Self { 27 | body, 28 | head, 29 | food: Point::rand(), 30 | dir: FourDirs::get_rand_dir(), 31 | brain: Net::new(), 32 | is_complete: false, 33 | no_food_steps: 0, 34 | num_steps: 0, 35 | } 36 | } 37 | 38 | pub fn update(&mut self) { 39 | if self.is_complete { 40 | return; 41 | } 42 | 43 | self.num_steps += 1; 44 | self.dir = self.get_brain_output(); 45 | self.handle_food_collision(); 46 | self.update_snake_positions(); 47 | self.handle_step_limit(); 48 | if self.is_wall(self.head) || self.is_snake_body(self.head) { 49 | self.is_complete = true; 50 | } 51 | } 52 | 53 | pub fn get_net_output(&self) -> Vec> { 54 | let vision = self.get_snake_vision(); 55 | self.brain.predict(&vision) 56 | } 57 | 58 | fn get_brain_output(&self) -> FourDirs { 59 | let vision = self.get_snake_vision(); 60 | let nn_out = self.brain.predict(&vision).pop().unwrap(); 61 | let max_index = nn_out 62 | .iter() 63 | .enumerate() 64 | .max_by(|(_, &a), (_, &b)| a.partial_cmp(&b).unwrap_or(std::cmp::Ordering::Equal)) 65 | .map(|(i, _)| i) 66 | .unwrap(); 67 | let mut dir = match max_index { 68 | 0 => FourDirs::Left, 69 | 1 => FourDirs::Right, 70 | 2 => FourDirs::Bottom, 71 | _ => FourDirs::Top, 72 | }; 73 | 74 | if self.dir.is_horizontal() { 75 | if dir.is_horizontal() && self.dir != dir { 76 | dir = 
self.dir; 77 | } 78 | } 79 | if self.dir.is_vertical() { 80 | if dir.is_vertical() && self.dir != dir { 81 | dir = self.dir; 82 | } 83 | } 84 | 85 | dir 86 | } 87 | 88 | fn get_snake_vision(&self) -> Vec { 89 | // self.get_11_vision() 90 | // self.get_custom_vision() 91 | // self.get_eight_dir_vision() 92 | self.get_four_dir_vision() 93 | } 94 | 95 | fn get_four_dir_vision(&self) -> Vec { 96 | let mut vision = Vec::new(); 97 | let dirs = FourDirs::get_all_dirs(); 98 | 99 | for d in dirs { 100 | let (wall, food, body) = self.look_in_dir(self.head, d); 101 | vision.push(wall as f64); 102 | vision.push(if food { 1.0 } else { 0.0 }); 103 | vision.push(body as f64); 104 | } 105 | 106 | vision 107 | } 108 | 109 | pub fn fitness(&self) -> f32 { 110 | let score = self.body.len() as f32; 111 | if score <= 1.0 { 112 | return 1.0; 113 | } 114 | 115 | if score < 5.0 { 116 | return (self.num_steps as f32 * 0.1) * (2.0 as f32).powf(score) * score; 117 | } 118 | 119 | let mut fitness = 1.0; 120 | fitness *= (2.0 as f32).powf(score) * score; 121 | fitness *= self.num_steps as f32; 122 | 123 | // TODO f32 shouldn't work as it can't hold such a big value 124 | // This is broken 125 | fitness 126 | } 127 | 128 | pub fn score(&self) -> usize { 129 | self.body.len() 130 | } 131 | 132 | pub fn is_wall(&self, pt: Point) -> bool { 133 | pt.x >= GRID_W || pt.x <= 0 || pt.y >= GRID_H || pt.y <= 0 134 | } 135 | 136 | pub fn is_snake_body(&self, pt: Point) -> bool { 137 | for p in self.body.iter().skip(1) { 138 | if pt == *p { 139 | return true; 140 | } 141 | } 142 | 143 | false 144 | } 145 | 146 | fn update_snake_positions(&mut self) { 147 | self.head.x += self.dir.value().0; 148 | self.head.y += self.dir.value().1; 149 | 150 | let mut prev_pos = self.head.clone(); 151 | for p in self.body.iter_mut() { 152 | let new_pos = *p; 153 | *p = prev_pos; 154 | prev_pos = new_pos; 155 | } 156 | } 157 | 158 | pub fn with_brain(new_brain: &Net) -> Self { 159 | let mut new_game = Self::new(); 160 | 
new_game.brain = new_brain.clone(); 161 | 162 | new_game 163 | } 164 | 165 | fn handle_food_collision(&mut self) { 166 | if self.head != self.food { 167 | self.no_food_steps += 1; 168 | return; 169 | } 170 | 171 | self.body.push(Point::new(self.head.x, self.head.y)); 172 | self.food = self.get_random_empty_pos(); 173 | self.no_food_steps = 0; 174 | } 175 | 176 | fn handle_step_limit(&mut self) { 177 | let limit = match self.score() { 178 | score if score > 10 => NUM_SIM_STEPS * 2, 179 | score if score > 20 => NUM_SIM_STEPS * 3, 180 | score if score > 30 => NUM_SIM_STEPS * 5, 181 | score if score > 80 => NUM_SIM_STEPS * 8, 182 | _ => NUM_SIM_STEPS, 183 | }; 184 | 185 | if self.no_food_steps >= limit { 186 | self.is_complete = true; 187 | } 188 | } 189 | 190 | fn get_random_empty_pos(&self) -> Point { 191 | let mut pt = Point::rand(); 192 | 193 | let mut num_tries = 0; 194 | while num_tries < 5 { 195 | num_tries += 1; 196 | pt = Point::rand(); 197 | 198 | if !self.body.contains(&pt) { 199 | break; 200 | } 201 | } 202 | 203 | pt 204 | } 205 | 206 | fn look_in_dir(&self, st: Point, dir: (i32, i32)) -> (f32, bool, f32) { 207 | let mut food = false; 208 | // let mut body = false; 209 | let mut temp_pt: Point = st; 210 | let mut dist = 0; 211 | 212 | loop { 213 | if self.is_wall(temp_pt) { 214 | break; 215 | } 216 | 217 | if self.food == temp_pt { 218 | food = true; 219 | } 220 | 221 | if self.is_snake_body(temp_pt) { 222 | // body = true; 223 | break; 224 | } 225 | 226 | temp_pt = Point::new(temp_pt.x + dir.0, temp_pt.y + dir.1); 227 | 228 | dist += 1; 229 | if dist > 1000 { 230 | break; 231 | } 232 | } 233 | 234 | (1.0 / dist as f32, food, 1.0 / dist as f32) 235 | } 236 | 237 | pub fn render(&self) { 238 | for x in 0..=GRID_W { 239 | for y in 0..=GRID_H { 240 | let pt = (x, y).into(); 241 | if self.is_wall(pt) { 242 | print!("□"); 243 | continue; 244 | } 245 | if self.is_snake_body(pt) { 246 | print!("■"); 247 | continue; 248 | } 249 | if self.head == pt { 250 | 
print!("■"); 251 | } 252 | if self.food == pt { 253 | print!("●"); 254 | } 255 | print!("."); 256 | } 257 | println!(); 258 | } 259 | println!(); 260 | } 261 | } 262 | 263 | impl PartialEq for Game { 264 | fn eq(&self, other: &Self) -> bool { 265 | self.fitness() == other.fitness() 266 | } 267 | } 268 | 269 | impl PartialOrd for Game { 270 | fn partial_cmp(&self, other: &Self) -> Option { 271 | self.fitness().partial_cmp(&other.fitness()) 272 | } 273 | } 274 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | pub mod configs; 2 | pub mod game; 3 | pub mod nn; 4 | pub mod pop; 5 | pub mod sim; 6 | pub mod stream; 7 | pub mod utils; 8 | pub mod viz; 9 | 10 | pub use configs::*; 11 | pub use utils::*; 12 | -------------------------------------------------------------------------------- /src/main.rs: -------------------------------------------------------------------------------- 1 | use std::thread; 2 | use std::time::Duration; 3 | 4 | use macroquad::prelude::*; 5 | 6 | use snake::sim::Simulation; 7 | use snake::*; 8 | 9 | fn window_conf() -> Conf { 10 | Conf { 11 | window_title: "snake-ai".to_owned(), 12 | high_dpi: true, 13 | sample_count: 1, 14 | fullscreen: true, 15 | ..Default::default() 16 | } 17 | } 18 | 19 | #[macroquad::main(window_conf)] 20 | async fn main() { 21 | let mut sim = Simulation::new(); 22 | let mut is_viz_enabled = true; 23 | let mut is_slow_mode = true; 24 | 25 | loop { 26 | let mut iterations = 0; 27 | 28 | loop { 29 | sim.update(is_viz_enabled, is_slow_mode); 30 | iterations += 1; 31 | 32 | if is_slow_mode { 33 | break; 34 | } 35 | if iterations >= 50 { 36 | break; 37 | } 38 | } 39 | 40 | if is_key_pressed(KeyCode::Escape) { 41 | break; 42 | } 43 | if is_key_pressed(KeyCode::Tab) { 44 | is_viz_enabled = !is_viz_enabled; 45 | if !is_viz_enabled { 46 | is_slow_mode = false; 47 | } 48 | } 49 | if 
is_key_released(KeyCode::Space) { 50 | is_slow_mode = !is_slow_mode; 51 | } 52 | 53 | if is_slow_mode { 54 | thread::sleep(Duration::from_millis(SIM_SLEEP_MILLIS)); 55 | } 56 | next_frame().await 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /src/nn.rs: -------------------------------------------------------------------------------- 1 | //! A simple Neural Network 2 | //! There is no way to train this network 3 | //! It can only be used for neuro-evolution 4 | 5 | use rand::Rng; 6 | 7 | use crate::*; 8 | 9 | #[derive(Clone)] 10 | pub struct Net { 11 | n_inputs: usize, 12 | layers: Vec, 13 | } 14 | 15 | #[derive(Clone)] 16 | struct Layer { 17 | nodes: Vec>, 18 | } 19 | 20 | impl Net { 21 | pub fn new() -> Self { 22 | let layer_sizes = vec![ 23 | INP_LAYER_SIZE, 24 | HIDDEN_LAYER_SIZE, 25 | // HIDDEN_LAYER_SIZE, 26 | OUTPUT_LAYER_SIZE, 27 | ]; 28 | 29 | if layer_sizes.len() < 2 { 30 | panic!("Need at least 2 layers"); 31 | } 32 | for &size in layer_sizes.iter() { 33 | if size < 1 { 34 | panic!("Empty layers not allowed"); 35 | } 36 | } 37 | 38 | let mut layers = Vec::new(); 39 | let first_layer_size = *layer_sizes.first().unwrap(); 40 | let mut prev_layer_size = first_layer_size; 41 | 42 | for &layer_size in layer_sizes[1..].iter() { 43 | layers.push(Layer::new(layer_size, prev_layer_size)); 44 | prev_layer_size = layer_size; 45 | } 46 | 47 | Self { 48 | layers, 49 | n_inputs: first_layer_size, 50 | } 51 | } 52 | 53 | pub fn merge(&self, other: &Net) -> Self { 54 | assert_eq!(self.layers.len(), other.layers.len()); 55 | 56 | let mut merged_layers = Vec::new(); 57 | for i in 0..self.layers.len() { 58 | let merged_layer = &self.layers[i].merge(&other.layers[i]); 59 | merged_layers.push(merged_layer.clone()); 60 | } 61 | 62 | Net { 63 | layers: merged_layers, 64 | n_inputs: self.n_inputs, 65 | } 66 | } 67 | 68 | pub fn predict(&self, inputs: &Vec) -> Vec> { 69 | if inputs.len() != self.n_inputs { 70 | panic!( 71 | 
"Bad input size, expected {:?} but got {:?}", 72 | self.n_inputs, 73 | inputs.len() 74 | ); 75 | } 76 | 77 | let mut outputs = Vec::new(); 78 | outputs.push(inputs.clone()); 79 | for (layer_index, layer) in self.layers.iter().enumerate() { 80 | let layer_results = layer.predict(&outputs[layer_index]); 81 | outputs.push(layer_results); 82 | } 83 | 84 | outputs 85 | } 86 | 87 | pub fn mutate(&mut self) { 88 | self.layers.iter_mut().for_each(|l| l.mutate()); 89 | } 90 | } 91 | 92 | impl Layer { 93 | fn new(layer_size: usize, prev_layer_size: usize) -> Self { 94 | let mut nodes: Vec> = Vec::new(); 95 | let mut rng = rand::thread_rng(); 96 | 97 | for _ in 0..layer_size { 98 | let mut node: Vec = Vec::new(); 99 | for _ in 0..prev_layer_size + 1 { 100 | let random_weight: f64 = rng.gen_range(-1.0f64..1.0f64); 101 | node.push(random_weight); 102 | } 103 | nodes.push(node); 104 | } 105 | 106 | Self { nodes } 107 | } 108 | 109 | fn merge(&self, other: &Layer) -> Self { 110 | assert_eq!(self.nodes.len(), other.nodes.len()); 111 | let mut rng = rand::thread_rng(); 112 | let mut nodes: Vec> = Vec::new(); 113 | 114 | for (node1, node2) in self.nodes.iter().zip(other.nodes.iter()) { 115 | let mut merged_node = Vec::new(); 116 | for (&weight1, &weight2) in node1.iter().zip(node2.iter()) { 117 | let selected_weight = if rng.gen::() { weight1 } else { weight2 }; 118 | merged_node.push(selected_weight); 119 | } 120 | nodes.push(merged_node); 121 | } 122 | 123 | Self { nodes } 124 | } 125 | 126 | fn predict(&self, inputs: &Vec) -> Vec { 127 | let mut layer_results = Vec::new(); 128 | for node in self.nodes.iter() { 129 | layer_results.push(self.sigmoid(self.dot_prod(&node, &inputs))); 130 | } 131 | 132 | layer_results 133 | } 134 | 135 | fn mutate(&mut self) { 136 | let mut rng = rand::thread_rng(); 137 | 138 | for n in self.nodes.iter_mut() { 139 | for val in n.iter_mut() { 140 | if rng.gen_range(0.0..1.0) >= BRAIN_MUTATION_RATE { 141 | continue; 142 | } 143 | 144 | *val += 
rng.gen_range(-BRAIN_MUTATION_VARIATION..BRAIN_MUTATION_VARIATION) as f64;
                if *val > 1.0 || *val < -1.0 {
                    let random_weight = rng.gen_range(-1.0f64..1.0f64);
                    *val = random_weight;
                }
            }
        }
    }

    fn dot_prod(&self, node: &Vec<f64>, values: &Vec<f64>) -> f64 {
        let mut it = node.iter();
        let mut total = *it.next().unwrap();
        for (weight, value) in it.zip(values.iter()) {
            total += weight * value;
        }

        total
    }

    fn sigmoid(&self, y: f64) -> f64 {
        1f64 / (1f64 + (-y).exp())
    }
}

--------------------------------------------------------------------------------
/src/pop.rs:
--------------------------------------------------------------------------------
//! Population
//! Handles multiple streams (islands) of neuro-evolving agents
//! Also responsible for Island Rejuvenation

use std::time::Instant;

use rand::Rng;

use crate::stream::Stream;
use crate::*;

use self::nn::Net;

/// All streams (islands) of the simulation plus the start time of the
/// generation currently running.
pub struct Population {
    gen_start_ts: Instant,
    streams: Vec<Stream>,
}

/// Statistics collected at the end of a generation.
pub struct GenerationSummary {
    /// Seconds since the generation was started (see `Population::reset`).
    pub time_elapsed_secs: f32,
    /// Highest score reached by any game in any stream.
    pub max_score: usize,
    /// Brain of the game that reached `max_score`; `None` when nothing scored above 0.
    pub best_net: Option<Net>,
}

impl Population {
    /// Creates `NUM_STREAMS` fresh streams and starts the generation timer.
    pub fn new() -> Self {
        Self {
            streams: (0..NUM_STREAMS).map(|_| Stream::new()).collect(),
            gen_start_ts: Instant::now(),
        }
    }

    /// Steps every stream once and returns how many games are still running
    /// across the whole population (0 means the generation is over).
    pub fn update(&mut self) -> usize {
        let total_games = NUM_GAMES_PER_STREAM * NUM_STREAMS;
        // Each stream reports how many of its games have completed.
        let completed: usize = self.streams.iter_mut().map(|stream| stream.update()).sum();

        // Saturate instead of underflowing if the streams ever report more completed
        // games than the nominal total: a wrapped `usize` would look like "many games
        // still alive" and the generation would never end.
        total_games.saturating_sub(completed)
    }

    /// Starts a new generation in every stream, then rejuvenates stagnant streams.
    ///
    /// Every stream's `reset` returns its best brain. Each stream that reports
    /// `is_local_maximum()` then has part of its games replaced (via `inject`) by
    /// copies of a brain picked at random from those best brains. With a single
    /// stream there is nothing to cross.
    pub fn reset(&mut self) {
        self.gen_start_ts = Instant::now();

        // Streams reset
        let nets: Vec<Net> = self.streams.iter_mut().map(|stream| stream.reset()).collect();

        // No Streams to cross
        if self.streams.len() <= 1 {
            return;
        }

        // Streams crossing
        let mut rng = rand::thread_rng();
        for stream in self.streams.iter_mut().filter(|stream| stream.is_local_maximum()) {
            // NOTE(review): the donor may be this stream's own best brain.
            stream.inject(&nets[rng.gen_range(0..nets.len())]);
        }
    }

    /// Summarizes the current generation.
    ///
    /// Returns the best score over all streams and the brain that achieved it. The
    /// brain is taken from the first stream reaching that score; it stays `None` if
    /// no game scored above 0. Also returns the time elapsed since the generation started.
    pub fn get_gen_summary(&self) -> GenerationSummary {
        let mut max_score = 0;
        let mut best_net = None;

        for stream in self.streams.iter() {
            let (stream_score, stream_net) = stream.get_stream_summary();
            if stream_score > max_score {
                max_score = stream_score;
                best_net = stream_net;
            }
        }

        GenerationSummary {
            max_score,
            time_elapsed_secs: self.gen_start_ts.elapsed().as_secs_f32(),
            best_net,
        }
    }
}

--------------------------------------------------------------------------------
/src/sim.rs:
--------------------------------------------------------------------------------
//! Simulation
//! Responsible for updating the population and viz
//!
Handles generations

use macroquad::prelude::*;

use crate::pop::Population;
use crate::viz::Viz;

/// Top-level driver: steps the population every frame, rolls over to a new
/// generation once no game is running any more, and keeps the visualization
/// updated and drawn.
pub struct Simulation {
    gen_count: usize,
    pop: Population,
    viz: Viz,
}

impl Simulation {
    /// Creates a simulation at generation 0 with a fresh population and visualization.
    pub fn new() -> Self {
        Self {
            gen_count: 0,
            pop: Population::new(),
            viz: Viz::new(),
        }
    }

    /// Runs one frame.
    ///
    /// When the population reports that no game is still alive, the current
    /// generation is summarized and a new one is started. The viz flags are then
    /// forwarded to the visualization, which is updated and drawn.
    pub fn update(&mut self, is_viz_enabled: bool, is_slow_mode: bool) {
        // `Population::update` returns a `usize`, so "nothing alive" is exactly zero.
        let games_alive = self.pop.update();
        if games_alive == 0 {
            self.end_current_genration();
            self.start_new_generation();
        }

        self.viz.update_settings(is_viz_enabled, is_slow_mode);
        self.viz.update();
        self.viz.draw();
    }

    /// Advances the generation counter and resets the population for the next generation.
    pub fn start_new_generation(&mut self) {
        self.gen_count += 1;
        self.pop.reset();
    }

    /// Hands the finished generation's summary and number to the visualization.
    ///
    /// (The misspelled name is kept so existing callers keep compiling.)
    pub fn end_current_genration(&mut self) {
        let stats = self.pop.get_gen_summary();
        self.viz.reset(stats, self.gen_count);
    }
}

--------------------------------------------------------------------------------
/src/stream.rs:
--------------------------------------------------------------------------------
//! Stream
//!
Island of neuro-evolving agents 3 | 4 | use std::time::Instant; 5 | 6 | use rand::distributions::{Distribution, WeightedIndex}; 7 | 8 | use crate::game::Game; 9 | use crate::nn::Net; 10 | use crate::*; 11 | 12 | pub struct Stream { 13 | games: Vec, 14 | max_score: usize, 15 | max_score_ts: Instant, 16 | } 17 | 18 | impl Stream { 19 | pub fn new() -> Self { 20 | let mut games = Vec::new(); 21 | for _ in 0..NUM_GAMES_PER_STREAM { 22 | games.push(Game::new()); 23 | } 24 | 25 | Self { 26 | games, 27 | max_score: 0, 28 | max_score_ts: Instant::now(), 29 | } 30 | } 31 | 32 | pub fn update(&mut self) -> usize { 33 | let mut games_alive = NUM_GAMES_PER_STREAM; 34 | 35 | for g in self.games.iter_mut() { 36 | g.update(); 37 | 38 | let score = g.score(); 39 | if score > self.max_score { 40 | self.max_score = score; 41 | self.max_score_ts = Instant::now(); 42 | } 43 | 44 | if g.is_complete { 45 | games_alive -= 1; 46 | } 47 | } 48 | 49 | NUM_GAMES_PER_STREAM - games_alive 50 | } 51 | 52 | pub fn is_local_maximum(&self) -> bool { 53 | self.max_score_ts.elapsed().as_secs_f32() > STREAM_LOCAL_MAX_WAIT_SECS 54 | } 55 | 56 | pub fn inject(&mut self, net: &Net) { 57 | let new_game = Game::with_brain(net); 58 | let num_games = (NUM_GAMES_PER_STREAM as f32 * STREAM_REJUVENATION_PERCENT) as usize; 59 | 60 | self.games.drain(0..num_games); 61 | for _ in 0..num_games { 62 | self.games.push(new_game.clone()); 63 | } 64 | 65 | self.max_score = 0; 66 | self.max_score_ts = Instant::now(); 67 | } 68 | 69 | pub fn get_stream_summary(&self) -> (usize, Option) { 70 | let mut max_score = 0; 71 | let mut best_net = None; 72 | 73 | for g in self.games.iter() { 74 | let score = g.score(); 75 | if score > max_score { 76 | max_score = score; 77 | best_net = Some(g.brain.clone()); 78 | } 79 | } 80 | 81 | (max_score, best_net) 82 | } 83 | 84 | pub fn reset(&mut self) -> Net { 85 | let mut rng = rand::thread_rng(); 86 | let gene_pool = self.generate_gene_pool(); 87 | let mut new_games = Vec::new(); 88 | 
89 | // Population Distribution 90 | let num_retained = NUM_GAMES_PER_STREAM as f32 * POP_NUM_RETAINED; 91 | let num_children = NUM_GAMES_PER_STREAM as f32 * POP_NUM_CHILDREN; 92 | let num_random = NUM_GAMES_PER_STREAM as f32 * POP_NUM_RANDOM; 93 | let mut num_retained_mutated = NUM_GAMES_PER_STREAM as f32 * POP_NUM_RETAINED_MUTATED; 94 | 95 | // Retained no mutation 96 | let mut games_sorted = self.games.clone(); 97 | games_sorted.sort_by(|a, b| a.partial_cmp(b).unwrap()); 98 | games_sorted.reverse(); 99 | for i in 0..num_retained as usize { 100 | let old_brain = games_sorted[i].brain.clone(); 101 | let mut new_game = Game::new(); 102 | new_game.brain = old_brain; 103 | 104 | new_games.push(new_game); 105 | } 106 | 107 | // Children 108 | if let Some(pool) = gene_pool { 109 | for _ in 0..num_children as i32 { 110 | let rand_parent_1 = self.games[pool.sample(&mut rng)].clone(); 111 | let rand_parent_2 = self.games[pool.sample(&mut rng)].clone(); 112 | let mut new_brain = rand_parent_1.brain.merge(&rand_parent_2.brain); 113 | new_brain.mutate(); 114 | 115 | let new_game = Game::with_brain(&new_brain); 116 | new_games.push(new_game); 117 | } 118 | } else { 119 | // TODO: Error, failed to create a gene pool 120 | num_retained_mutated += num_children; 121 | } 122 | 123 | // Retained with mutation 124 | for i in 0..num_retained_mutated as usize { 125 | let mut old_brain = games_sorted[i].brain.clone(); 126 | let mut new_game = Game::new(); 127 | old_brain.mutate(); 128 | new_game.brain = old_brain; 129 | 130 | new_games.push(new_game); 131 | } 132 | 133 | // Full random 134 | for _ in 0..num_random as i32 { 135 | new_games.push(Game::new()); 136 | } 137 | 138 | self.games = new_games; 139 | games_sorted[0].brain.clone() 140 | } 141 | 142 | fn generate_gene_pool(&self) -> Option> { 143 | let mut max_fitness = 0.0; 144 | let mut weights = Vec::new(); 145 | 146 | for game in self.games.iter() { 147 | let fitness = game.fitness(); 148 | if fitness > max_fitness { 149 | 
max_fitness = fitness; 150 | } 151 | 152 | if fitness.is_finite() { 153 | weights.push(fitness); 154 | } 155 | } 156 | weights 157 | .iter_mut() 158 | .for_each(|i| *i = (*i / max_fitness) * 100.0); 159 | 160 | WeightedIndex::new(&weights).ok() 161 | } 162 | } 163 | -------------------------------------------------------------------------------- /src/utils.rs: -------------------------------------------------------------------------------- 1 | use macroquad::color::Color; 2 | use rand::Rng; 3 | 4 | use crate::*; 5 | 6 | #[derive(Default, PartialEq, Eq, Hash, Clone, Copy, Debug)] 7 | pub struct Point { 8 | pub x: i32, 9 | pub y: i32, 10 | } 11 | 12 | #[derive(Debug, Clone, Copy, Default, PartialEq)] 13 | pub enum FourDirs { 14 | #[default] 15 | Left, 16 | Right, 17 | Bottom, 18 | Top, 19 | } 20 | 21 | pub fn map_to_unit_interval(value: f32, range: f32) -> f32 { 22 | let x_abs = range.abs(); 23 | let clamped_value = value.clamp(-x_abs, x_abs); 24 | (clamped_value + x_abs) / (2.0 * x_abs) 25 | } 26 | 27 | pub fn grid_to_world(x: i32, y: i32, tile_size: f32, scale: f32) -> (f32, f32) { 28 | (x as f32 * tile_size * scale, y as f32 * tile_size * scale) 29 | } 30 | 31 | pub fn color_with_a(color: Color, a: f32) -> Color { 32 | Color::new(color.r, color.g, color.b, a) 33 | } 34 | 35 | pub fn are_colors_equal(c1: Color, c2: Color) -> bool { 36 | c1.r == c2.r && c1.g == c2.g && c1.b == c2.b 37 | } 38 | 39 | impl FourDirs { 40 | pub fn get_rand_dir() -> Self { 41 | let mut rng = rand::thread_rng(); 42 | match rng.gen_range(0..4) { 43 | 0 => Self::Left, 44 | 1 => Self::Right, 45 | 2 => Self::Bottom, 46 | _ => Self::Top, 47 | } 48 | } 49 | 50 | pub fn get_all_dirs() -> [(i32, i32); 4] { 51 | [ 52 | Self::Left.value(), 53 | Self::Right.value(), 54 | Self::Bottom.value(), 55 | Self::Top.value(), 56 | ] 57 | } 58 | 59 | pub fn get_rand_horizontal() -> Self { 60 | let mut rng = rand::thread_rng(); 61 | match rng.gen_range(0..2) { 62 | 0 => Self::Left, 63 | _ => Self::Right, 64 | } 
65 | } 66 | 67 | pub fn get_rand_vertical() -> Self { 68 | let mut rng = rand::thread_rng(); 69 | match rng.gen_range(0..2) { 70 | 0 => Self::Top, 71 | _ => Self::Bottom, 72 | } 73 | } 74 | 75 | pub fn value(&self) -> (i32, i32) { 76 | match self { 77 | Self::Left => (-1, 0), 78 | Self::Right => (1, 0), 79 | Self::Bottom => (0, 1), 80 | Self::Top => (0, -1), 81 | } 82 | } 83 | 84 | pub fn is_horizontal(&self) -> bool { 85 | match self { 86 | FourDirs::Left => true, 87 | FourDirs::Right => true, 88 | _ => false, 89 | } 90 | } 91 | 92 | pub fn is_vertical(&self) -> bool { 93 | match self { 94 | FourDirs::Top => true, 95 | FourDirs::Bottom => true, 96 | _ => false, 97 | } 98 | } 99 | } 100 | 101 | impl Point { 102 | pub fn new(x: i32, y: i32) -> Self { 103 | Self { x, y } 104 | } 105 | 106 | pub fn rand() -> Self { 107 | let mut rng = rand::thread_rng(); 108 | Self { 109 | x: rng.gen_range(1..GRID_W - 1), 110 | y: rng.gen_range(1..GRID_H - 1), 111 | } 112 | } 113 | } 114 | 115 | impl Into for (i32, i32) { 116 | fn into(self) -> Point { 117 | Point { 118 | x: self.0, 119 | y: self.1, 120 | } 121 | } 122 | } 123 | -------------------------------------------------------------------------------- /src/viz.rs: -------------------------------------------------------------------------------- 1 | //! Visualization 2 | //! 
Handles everything drawn on screen 3 | 4 | use macroquad::prelude::*; 5 | 6 | use std::time::Instant; 7 | 8 | use crate::game::Game; 9 | use crate::nn::Net; 10 | use crate::pop::GenerationSummary; 11 | use crate::*; 12 | 13 | pub struct Viz { 14 | games: Vec, 15 | sim_start_ts: Instant, 16 | max_score: usize, 17 | gen_count: usize, 18 | best_brain: Option, 19 | 20 | is_slow_mode: bool, 21 | is_show_viz: bool, 22 | colors: Colors, 23 | } 24 | 25 | struct Colors { 26 | bg: Color, 27 | snake_head: Color, 28 | snake_body: Color, 29 | food: Color, 30 | wall: Color, 31 | text: Color, 32 | node_enabled: Color, 33 | node_disabled: Color, 34 | node_hidden: Color, 35 | positive: Color, 36 | negative: Color, 37 | disabled: Color, 38 | opacity: f32, 39 | } 40 | 41 | impl Viz { 42 | pub fn new() -> Self { 43 | Self { 44 | games: Vec::new(), 45 | sim_start_ts: Instant::now(), 46 | max_score: 0, 47 | gen_count: 0, 48 | best_brain: None, 49 | is_slow_mode: false, 50 | is_show_viz: false, 51 | colors: if VIZ_DARK_THEME { 52 | Colors::dark() 53 | } else { 54 | Colors::light() 55 | }, 56 | } 57 | } 58 | 59 | pub fn update(&mut self) { 60 | if !self.is_show_viz { 61 | return; 62 | } 63 | 64 | let mut num_completed = 0; 65 | for g in self.games.iter_mut() { 66 | if g.is_complete { 67 | num_completed += 1; 68 | } 69 | 70 | g.update(); 71 | } 72 | 73 | if num_completed >= self.games.len() { 74 | self.init_games(); 75 | } 76 | } 77 | 78 | fn init_games(&mut self) { 79 | let new_brain = Net::new(); 80 | let brain = match &self.best_brain { 81 | Some(brain) => brain, 82 | None => &new_brain, 83 | }; 84 | 85 | let num_games = 100; 86 | let mut games = Vec::new(); 87 | for _ in 0..num_games { 88 | games.push(Game::with_brain(brain)); 89 | } 90 | self.games = games; 91 | } 92 | 93 | pub fn reset(&mut self, summary: GenerationSummary, gen_count: usize) { 94 | if summary.max_score > self.max_score { 95 | self.max_score = summary.max_score; 96 | self.best_brain = summary.best_net; 97 | // 
self.init_games(); 98 | } 99 | 100 | self.gen_count = gen_count; 101 | self.print_gen_info(summary.max_score); 102 | } 103 | 104 | pub fn draw(&self) { 105 | clear_background(self.colors.bg); 106 | 107 | self.draw_stats(); 108 | self.draw_best_games(); 109 | self.draw_net(); 110 | } 111 | 112 | fn draw_best_games(&self) { 113 | let mut pos_x = 0; 114 | let mut pos_y = 0; 115 | 116 | if self.games.len() <= 0 || !self.is_show_viz { 117 | return; 118 | } 119 | 120 | let grid_zero = [0, 1, VIZ_GRID_W, VIZ_GRID_W + 1]; 121 | let mut best_games = self.games.clone(); 122 | best_games.sort_by(|a, b| a.partial_cmp(b).unwrap()); 123 | best_games.reverse(); 124 | 125 | for index in 0..(VIZ_GRID_H * VIZ_GRID_W) { 126 | if !grid_zero.contains(&(index as i32)) { 127 | let game = &best_games[index as usize]; 128 | self.draw_game(game, pos_x, pos_y, 1.0); 129 | } 130 | 131 | pos_x += 1; 132 | if pos_x >= VIZ_GRID_W { 133 | pos_x = 0; 134 | pos_y += 1; 135 | } 136 | if pos_y >= VIZ_GRID_H { 137 | break; 138 | } 139 | } 140 | 141 | // Render 0th game 142 | let game = &best_games[0]; 143 | self.draw_game(game, 0, 0, 1.96); 144 | } 145 | 146 | fn draw_game(&self, game: &Game, pos_x: i32, pos_y: i32, scale: f32) { 147 | let padding = 10.0; 148 | let w = (screen_width() - padding * 2.0) * 0.7; 149 | let h = (screen_height() - padding * 2.0) * 0.99; 150 | let sq = w.min(h); 151 | let tile_size = ((sq / 4.0) / GRID_W as f32) * scale as f32; 152 | 153 | for x in 0..=GRID_W { 154 | for y in 0..=GRID_H { 155 | let mut color = self.colors.bg; 156 | let pt = (x, y).into(); 157 | 158 | if game.is_wall(pt) { 159 | color = self.colors.wall; 160 | } 161 | if game.is_snake_body(pt) { 162 | color = self.colors.snake_body; 163 | } 164 | if game.head == pt { 165 | color = self.colors.snake_head; 166 | } 167 | if game.food == pt { 168 | color = self.colors.food; 169 | } 170 | 171 | if game.is_complete { 172 | color = color_with_a(color, self.colors.opacity); 173 | } 174 | 175 | let (tx, ty) = 176 | 
grid_to_world((pos_x * GRID_W) + x, (pos_y * GRID_H) + y, tile_size, 1.0); 177 | draw_rectangle(tx + padding, ty + padding, tile_size, tile_size, color); 178 | } 179 | } 180 | 181 | let (tx, ty) = grid_to_world( 182 | (pos_x * GRID_W) + 3, 183 | (pos_y * GRID_H) + GRID_H, 184 | tile_size, 185 | 1.0, 186 | ); 187 | draw_text( 188 | format!("{:?}", game.score()).as_str(), 189 | tx, 190 | ty, 191 | 30.0, 192 | if game.is_complete { 193 | self.colors.disabled 194 | } else { 195 | self.colors.text 196 | }, 197 | ); 198 | } 199 | 200 | fn draw_net(&self) { 201 | if self.games.is_empty() || !self.is_show_viz { 202 | return; 203 | } 204 | 205 | let padding = 10.0; 206 | let w = (screen_width() - padding * 2.0) * 0.75 + 50.0; 207 | let h = screen_height() * 1.00; 208 | 209 | let node_border_color = color_with_a(GRAY, 0.0); 210 | let node_radius = 25.0; 211 | let node_border_thickness = 2.0; 212 | let line_thickness = 3.0; 213 | let y_padding = 120.0; 214 | let layer_1_x_padding = 0.0; 215 | let layer_2_x_padding = 150.0; 216 | let layer_3_x_padding = 300.0; 217 | 218 | let layer_1_y = self.calculate_circle_positions(INP_LAYER_SIZE, node_radius, h, 15.0); 219 | let layer_2_y = self.calculate_circle_positions(HIDDEN_LAYER_SIZE, node_radius, h, 15.0); 220 | let layer_3_y = self.calculate_circle_positions(OUTPUT_LAYER_SIZE, node_radius, h, 15.0); 221 | let (colors1, colors2, colors3) = self.get_node_colors(); 222 | 223 | // Bottom Text 224 | let bt_x = screen_width() * 0.75; 225 | draw_text( 226 | "Input", 227 | bt_x + 5.0, 228 | screen_height() * 0.97, 229 | 30.0, 230 | self.colors.text, 231 | ); 232 | draw_text( 233 | "Hidden", 234 | bt_x + layer_2_x_padding, 235 | screen_height() * 0.85, 236 | 30.0, 237 | self.colors.text, 238 | ); 239 | draw_text( 240 | "Output", 241 | bt_x + layer_3_x_padding, 242 | screen_height() * 0.725, 243 | 30.0, 244 | self.colors.text, 245 | ); 246 | 247 | // Lines 248 | for (y1, c1) in layer_1_y.iter().zip(colors1.iter()) { 249 | for y2 in 
layer_2_y.iter() { 250 | let color = self.get_line_color(*c1); 251 | draw_line( 252 | w + layer_1_x_padding, 253 | *y1 + y_padding, 254 | w + layer_2_x_padding, 255 | *y2 + y_padding, 256 | line_thickness, 257 | color, 258 | ); 259 | } 260 | } 261 | for y2 in layer_2_y.iter() { 262 | for (y3, c3) in layer_3_y.iter().zip(colors3.iter()) { 263 | let color = self.get_line_color(*c3); 264 | draw_line( 265 | w + layer_2_x_padding, 266 | *y2 + y_padding, 267 | w + layer_3_x_padding, 268 | *y3 + y_padding, 269 | line_thickness, 270 | color, 271 | ); 272 | } 273 | } 274 | 275 | // Nodes 276 | for (y, c) in layer_1_y.iter().zip(colors1.iter()) { 277 | draw_circle(w + layer_1_x_padding, *y + y_padding, node_radius, *c); 278 | draw_circle_lines( 279 | w + layer_1_x_padding, 280 | *y + y_padding, 281 | node_radius, 282 | node_border_thickness, 283 | node_border_color, 284 | ); 285 | } 286 | for (y, c) in layer_2_y.iter().zip(colors2.iter()) { 287 | draw_circle(w + layer_2_x_padding, *y + y_padding, node_radius, *c); 288 | draw_circle_lines( 289 | w + layer_2_x_padding, 290 | *y + y_padding, 291 | node_radius, 292 | node_border_thickness, 293 | node_border_color, 294 | ); 295 | } 296 | for ((idx, y), c) in layer_3_y.iter().enumerate().zip(colors3.iter()) { 297 | let (px, py) = (w + layer_3_x_padding, *y + y_padding); 298 | draw_circle(px, py, node_radius, *c); 299 | draw_circle_lines( 300 | px, 301 | py, 302 | node_radius, 303 | node_border_thickness, 304 | node_border_color, 305 | ); 306 | let text = match idx { 307 | 0 => "Left", 308 | 1 => "Right", 309 | 2 => "Bottom", 310 | _ => "Top", 311 | }; 312 | let color = if are_colors_equal(*c, self.colors.node_enabled) { 313 | self.colors.text 314 | } else { 315 | self.colors.disabled 316 | }; 317 | draw_text(text, px + 50.0, py + 5.0, 30.0, color); 318 | } 319 | } 320 | 321 | fn draw_stats(&self) { 322 | let w = screen_width() * 0.78; 323 | let h = screen_height() * 0.07; 324 | 325 | draw_text( 326 | format!("Gen: {:?}", 
self.gen_count).as_str(), 327 | w, 328 | h, 329 | 50.0, 330 | self.colors.text, 331 | ); 332 | draw_text( 333 | format!("Score: {:?}", self.max_score).as_str(), 334 | w, 335 | h + 40.0, 336 | 50.0, 337 | self.colors.text, 338 | ); 339 | draw_text( 340 | format!("Slow: {:?}", self.is_slow_mode).as_str(), 341 | w, 342 | h + 80.0, 343 | 50.0, 344 | if self.is_slow_mode { 345 | self.colors.positive 346 | } else { 347 | self.colors.negative 348 | }, 349 | ); 350 | draw_text( 351 | format!("Viz: {:?}", self.is_show_viz).as_str(), 352 | w, 353 | h + 120.0, 354 | 50.0, 355 | if self.is_show_viz { 356 | self.colors.positive 357 | } else { 358 | self.colors.negative 359 | }, 360 | ); 361 | 362 | if !self.is_show_viz { 363 | draw_text( 364 | "[Space] - Slow motion", 365 | w, 366 | h + 250.0, 367 | 30.0, 368 | self.colors.text, 369 | ); 370 | draw_text("[Tab] - Show Viz", w, h + 280.0, 30.0, self.colors.text); 371 | } 372 | } 373 | 374 | fn get_line_color(&self, c1: Color) -> Color { 375 | let mut output_color; 376 | if are_colors_equal(c1, self.colors.node_enabled) { 377 | output_color = self.colors.node_enabled; 378 | output_color.a = 0.3; 379 | } else { 380 | output_color = self.colors.node_disabled; 381 | output_color.a = 0.1; 382 | } 383 | 384 | output_color 385 | } 386 | 387 | pub fn update_settings(&mut self, is_viz_enabled: bool, is_slow_mode: bool) { 388 | self.is_show_viz = is_viz_enabled; 389 | self.is_slow_mode = is_slow_mode; 390 | } 391 | 392 | fn get_node_colors(&self) -> (Vec, Vec, Vec) { 393 | let mut color_enabled = self.colors.node_enabled; 394 | let mut color_disabled = self.colors.node_disabled; 395 | let mut color_hidden = self.colors.node_hidden; 396 | 397 | // TODO remove the resorting 398 | let mut best_games = self.games.clone(); 399 | best_games.sort_by(|a, b| a.partial_cmp(b).unwrap()); 400 | best_games.reverse(); 401 | let game = &best_games[0]; 402 | 403 | if game.is_complete { 404 | color_enabled = self.colors.disabled; 405 | color_disabled = 
self.colors.disabled; 406 | color_hidden = self.colors.disabled; 407 | } 408 | 409 | let net_out = game.get_net_output(); 410 | let inputs = net_out[0].clone(); 411 | let hidden = net_out[1].clone(); 412 | let output = net_out[2].clone(); 413 | 414 | let mut input_colors = Vec::new(); 415 | for i in inputs.iter() { 416 | if *i == 0.0 { 417 | input_colors.push(color_disabled); 418 | } else if *i > 0.2 { 419 | input_colors.push(color_enabled); 420 | } else { 421 | input_colors.push(color_disabled); 422 | } 423 | } 424 | 425 | let mut hidden_colors = Vec::new(); 426 | for i in hidden.iter() { 427 | let opacity = map_to_unit_interval(*i as f32, 0.5); 428 | if game.is_complete { 429 | hidden_colors.push(color_with_a(color_hidden, 1.0)); 430 | } else if opacity.is_finite() { 431 | hidden_colors.push(color_with_a(color_hidden, opacity)); 432 | } else { 433 | hidden_colors.push(color_with_a(color_hidden, 0.8)); 434 | } 435 | } 436 | 437 | let max_index = output 438 | .iter() 439 | .enumerate() 440 | .max_by(|(_, &a), (_, &b)| a.partial_cmp(&b).unwrap_or(std::cmp::Ordering::Equal)) 441 | .map(|(i, _)| i) 442 | .unwrap(); 443 | let mut dir = match max_index { 444 | 0 => FourDirs::Left, 445 | 1 => FourDirs::Right, 446 | 2 => FourDirs::Bottom, 447 | _ => FourDirs::Top, 448 | }; 449 | 450 | if game.dir.is_horizontal() { 451 | if dir.is_horizontal() && game.dir != dir { 452 | dir = game.dir; 453 | } 454 | } 455 | if game.dir.is_vertical() { 456 | if dir.is_vertical() && game.dir != dir { 457 | dir = game.dir; 458 | } 459 | } 460 | 461 | let mut output_colors = vec![ 462 | color_disabled, 463 | color_disabled, 464 | color_disabled, 465 | color_disabled, 466 | ]; 467 | if dir == FourDirs::Left { 468 | output_colors[0] = color_enabled; 469 | } 470 | if dir == FourDirs::Right { 471 | output_colors[1] = color_enabled; 472 | } 473 | if dir == FourDirs::Bottom { 474 | output_colors[2] = color_enabled; 475 | } 476 | if dir == FourDirs::Top { 477 | output_colors[3] = color_enabled; 478 | 
} 479 | 480 | // colors 481 | (input_colors, hidden_colors, output_colors) 482 | } 483 | 484 | fn calculate_circle_positions(&self, n: usize, r: f32, h: f32, y: f32) -> Vec { 485 | let total_height = n as f32 * (2.0 * r) + (n as f32 - 1.0) * y; 486 | let top_y = (h - total_height) / 2.0; 487 | 488 | let mut positions = Vec::new(); 489 | for i in 0..n { 490 | let circle_y = top_y + (2.0 * r + y) * i as f32; 491 | positions.push(circle_y); 492 | } 493 | 494 | positions 495 | } 496 | 497 | fn print_gen_info(&self, gen_max_score: usize) { 498 | let message = format!( 499 | "Gen: {}, Max Score: {}, Gen Max: {}, Sim Ts: {:.2?}m", 500 | self.gen_count, 501 | self.max_score, 502 | gen_max_score, 503 | self.sim_start_ts.elapsed().as_secs_f32() / 60.0, 504 | ); 505 | println!("{}", message); 506 | } 507 | } 508 | 509 | impl Colors { 510 | fn dark() -> Self { 511 | Self { 512 | bg: Color::from_hex(0x28334f), 513 | snake_head: Color::from_hex(0xe982f4), 514 | snake_body: Color::from_hex(0x67dbf8), 515 | food: Color::from_hex(0x7aed86), 516 | wall: Color::from_hex(0xadb4bf), 517 | text: WHITE, 518 | node_enabled: Color::from_hex(0x7aed86), 519 | node_disabled: Color::from_hex(0xfb7171), 520 | node_hidden: SKYBLUE, 521 | positive: Color::from_hex(0x51de88), 522 | negative: Color::from_hex(0xfb7171), 523 | disabled: GRAY, 524 | opacity: 0.3, 525 | } 526 | } 527 | 528 | fn light() -> Self { 529 | Self { 530 | bg: WHITE, 531 | snake_head: BLUE, 532 | snake_body: GREEN, 533 | food: RED, 534 | wall: BROWN, 535 | text: BLACK, 536 | node_enabled: GREEN, 537 | node_disabled: RED, 538 | node_hidden: SKYBLUE, 539 | positive: GREEN, 540 | negative: RED, 541 | disabled: GRAY, 542 | opacity: 0.3, 543 | } 544 | } 545 | } 546 | --------------------------------------------------------------------------------