├── .gitignore ├── Cargo.toml ├── LICENSE-APACHE ├── LICENSE-MIT ├── README.md ├── base65536 ├── Cargo.toml └── src │ └── lib.rs ├── export ├── Cargo.toml └── src │ └── main.rs ├── plot_ratings.py ├── slimnn ├── Cargo.toml └── src │ ├── activations.rs │ ├── conv.rs │ ├── lib.rs │ ├── linear.rs │ └── loading.rs ├── study-connect4 ├── Cargo.toml └── src │ ├── connect4.rs │ ├── main.rs │ └── policies.rs └── synthesis ├── Cargo.toml └── src ├── alpha_zero.rs ├── config.rs ├── data.rs ├── evaluator.rs ├── game.rs ├── lib.rs ├── mcts.rs ├── policies ├── cache.rs ├── mod.rs ├── rollout.rs └── traits.rs ├── prelude.rs └── utils.rs /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | /_logs 3 | /.vscode 4 | /_models 5 | .cargo 6 | .idea 7 | Cargo.lock 8 | bayeselo.exe -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [workspace] 2 | 3 | members = [ 4 | "base65536", 5 | "slimnn", 6 | "synthesis", 7 | "export", 8 | "study-connect4", 9 | ] 10 | -------------------------------------------------------------------------------- /LICENSE-APACHE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 
175 | 176 | END OF TERMS AND CONDITIONS 177 | -------------------------------------------------------------------------------- /LICENSE-MIT: -------------------------------------------------------------------------------- 1 | Permission is hereby granted, free of charge, to any 2 | person obtaining a copy of this software and associated 3 | documentation files (the "Software"), to deal in the 4 | Software without restriction, including without 5 | limitation the rights to use, copy, modify, merge, 6 | publish, distribute, sublicense, and/or sell copies of 7 | the Software, and to permit persons to whom the Software 8 | is furnished to do so, subject to the following 9 | conditions: 10 | 11 | The above copyright notice and this permission notice 12 | shall be included in all copies or substantial portions 13 | of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF 16 | ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED 17 | TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A 18 | PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT 19 | SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 20 | CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 21 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR 22 | IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 23 | DEALINGS IN THE SOFTWARE. 
24 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Synthesis: A Rust implementation of AlphaZero 2 | 3 | This repo is a cargo workspace made up of multiple crates & binaries: 4 | 5 | - `synthesis`: The main library crate with all the main training & MCTS logic in it 6 | - `study-connect4`: A binary crate that uses the synthesis library to train a network to play Connect4 7 | - `base65536`: A small crate to encode/decode u8's into valid utf-8 strings 8 | - `slimnn`: A small neural network crate in pure rust 9 | - `export`: A binary crate that saves pytorch weights into a format slimnn can understand 10 | 11 | ```cargo run --release --bin study-connect4``` 12 | 13 | ## What's implemented 14 | 15 | - Integration with the tch-rs [1] package to support pytorch in rust 16 | - 💪 General MCTS implementation that supports the standard rollout method as well as using a NN in place of rollouts 17 | - Includes MCTS Solver [2] 18 | - Includes FPU [3] 19 | - 💡 An AlphaZero [4] learner that collects experience using MCTS+NN and trains a policy and value function 20 | - Supports multiple value targets 21 | - All hyperparameters exposed 22 | - Multi threaded support! 👩‍👩‍👧‍👧 23 | - 📈 Lightweight evaluation against standard rollout mcts with various number of explores 24 | - Saves game outcomes to a pgn file 25 | - Runs bayeselo [5] executable to produce elo ratings 26 | - Plots ratings 🎉 27 | - 🎲 9x7 Connect4 as a playground to test things 28 | - 😎 Support for running without torch 29 | - `slimnn` for simple NN layer implementations 30 | - `export` & `base65536` for converting torch weights to utf-8 strings 31 | 32 | 1. https://github.com/LaurentMazare/tch-rs 33 | 2. Winands, Mark HM, Yngvi Björnsson, and Jahn-Takeshi Saito. "Monte-Carlo tree search solver." International Conference on Computers and Games. Springer, Berlin, Heidelberg, 2008. 34 | 3. 
Gelly, Sylvain, and Yizao Wang. "Exploration exploitation in go: UCT for Monte-Carlo go." NIPS: Neural Information Processing Systems Conference On-line trading of Exploration and Exploitation Workshop. 2006. 35 | 4. https://deepmind.com/blog/article/alphazero-shedding-new-light-grand-games-chess-shogi-and-go 36 | 5. https://www.remi-coulom.fr/Bayesian-Elo/ 37 | 38 | ## Improvements 39 | 40 | ###### General 41 | 42 | - [ ] Evaluation metrics in addition to elo: 43 | - [ ] Depth reached 44 | - [ ] Something for how quickly positions are solved 45 | - [ ] Search policy accuracy 46 | - [ ] value accuracy against Q 47 | - [ ] value accuracy against 2-ply minimax value 48 | - [ ] mix MCTS tree and minimax tree (of solved nodes) using p(correct) 49 | - [ ] Support transpositions (and backprop to multiple parents) while training... does this improve strength? 50 | - [ ] Score Bounded solver https://www.lamsade.dauphine.fr/~cazenave/papers/mcsolver.pdf 51 | - [ ] Ordinal MCTS https://arxiv.org/pdf/1901.04274.pdf 52 | - [ ] Regularized Policy Optimization https://arxiv.org/abs/2007.12509 53 | - [ ] Schedules for various parameters 54 | - [ ] sample_actions_until 55 | - [ ] value target 56 | - [ ] noise_weight 57 | - [ ] New algorithm for separate exploration/exploitation 58 | - [ ] Is this ExIt? https://arxiv.org/pdf/1705.08439.pdf 59 | - [ ] exploration process that builds off of exploit play line by sampling other states backward 60 | - [ ] exploit process that samples a state from ^ and exploits all the way down 61 | 62 | ###### Performance 63 | - [x] compiler flags (LTO=fat, codegen-units=1, target=native) 64 | - [x] multi threaded gather_experience 65 | - [ ] Reduce allocations (pre allocated buffer for MCTS nodes?) 
66 | - [ ] speed up conv2d with im2col https://leonardoaraujosantos.gitbook.io/artificial-inteligence/machine_learning/deep_learning/convolution_layer/making_faster 67 | - [ ] https://sahnimanas.github.io/post/anatomy-of-a-high-performance-convolution/ 68 | - [ ] reverse linear weight dimensions for speed up 69 | - [ ] support outputting 16 bit floats instead of 32 bit floats https://github.com/starkat99/half-rs/blob/master/src/bfloat/convert.rs 70 | 71 | ## Resources for learning more about AlphaZero 72 | 73 | - https://medium.com/@sleepsonthefloor/azfour-a-connect-four-webapp-powered-by-the-alphazero-algorithm-d0c82d6f3ae9 74 | - https://deepmind.com/blog/article/alphazero-shedding-new-light-grand-games-chess-shogi-and-go 75 | - https://www.nature.com/articles/nature24270.epdf 76 | - https://dselsam.github.io/posts/2018-06-06-issues-with-alpha-zero.html 77 | - https://github.com/deepmind/open_spiel/blob/master/open_spiel/algorithms/alpha_zero_torch/alpha_zero.cc 78 | - https://lczero.org/blog/2018/12/alphazero-paper-and-lc0-v0191/ 79 | - http://proceedings.mlr.press/v97/tian19a/tian19a.pdf 80 | - https://link.springer.com/content/pdf/10.1007/s00521-021-05928-5.pdf 81 | 82 | # License 83 | 84 | Dual-licensed to be compatible with the Rust project. 85 | 86 | Licensed under the Apache License, Version 2.0 http://www.apache.org/licenses/LICENSE-2.0 or the MIT license http://opensource.org/licenses/MIT, at your option. This file may not be copied, modified, or distributed except according to those terms. 
// https://github.com/Parkayun/base65536

/// Start code point of the 256-code-point block used for each possible *second* byte of a pair.
///
/// `BLOCK_START[b]` is the block selected by second byte `b`; the first byte of the pair is
/// added to it as an offset. The table is sorted in strictly ascending order, which `decode`
/// relies on for its binary search.
const BLOCK_START: [u32; 256] = [
    13312, 13568, 13824, 14080, 14336, 14592, 14848, 15104, 15360, 15616, 15872, 16128, 16384,
    16640, 16896, 17152, 17408, 17664, 17920, 18176, 18432, 18688, 18944, 19200, 19456, 19968,
    20224, 20480, 20736, 20992, 21248, 21504, 21760, 22016, 22272, 22528, 22784, 23040, 23296,
    23552, 23808, 24064, 24320, 24576, 24832, 25088, 25344, 25600, 25856, 26112, 26368, 26624,
    26880, 27136, 27392, 27648, 27904, 28160, 28416, 28672, 28928, 29184, 29440, 29696, 29952,
    30208, 30464, 30720, 30976, 31232, 31488, 31744, 32000, 32256, 32512, 32768, 33024, 33280,
    33536, 33792, 34048, 34304, 34560, 34816, 35072, 35328, 35584, 35840, 36096, 36352, 36608,
    36864, 37120, 37376, 37632, 37888, 38144, 38400, 38656, 38912, 39168, 39424, 39680, 39936,
    40192, 40448, 41216, 41472, 41728, 42240, 67072, 73728, 73984, 74240, 77824, 78080, 78336,
    78592, 82944, 83200, 92160, 92416, 131072, 131328, 131584, 131840, 132096, 132352, 132608,
    132864, 133120, 133376, 133632, 133888, 134144, 134400, 134656, 134912, 135168, 135424, 135680,
    135936, 136192, 136448, 136704, 136960, 137216, 137472, 137728, 137984, 138240, 138496, 138752,
    139008, 139264, 139520, 139776, 140032, 140288, 140544, 140800, 141056, 141312, 141568, 141824,
    142080, 142336, 142592, 142848, 143104, 143360, 143616, 143872, 144128, 144384, 144640, 144896,
    145152, 145408, 145664, 145920, 146176, 146432, 146688, 146944, 147200, 147456, 147712, 147968,
    148224, 148480, 148736, 148992, 149248, 149504, 149760, 150016, 150272, 150528, 150784, 151040,
    151296, 151552, 151808, 152064, 152320, 152576, 152832, 153088, 153344, 153600, 153856, 154112,
    154368, 154624, 154880, 155136, 155392, 155648, 155904, 156160, 156416, 156672, 156928, 157184,
    157440, 157696, 157952, 158208, 158464, 158720, 158976, 159232, 159488, 159744, 160000, 160256,
    160512, 160768, 161024, 161280, 161536, 161792, 162048, 162304, 162560, 162816, 163072, 163328,
    163584, 163840, 164096, 164352, 164608, 164864, 165120,
];

/// Block used for a trailing unpaired byte (odd-length input). It is not a member of
/// `BLOCK_START`, so `decode` can tell "one byte" characters from "two byte" characters.
const PADDING_BLOCK_START: u32 = 5376;

/// Encodes `bytes` as a string with one character per pair of input bytes.
///
/// For each pair, the second byte selects a block from `BLOCK_START` and the first byte is the
/// offset inside that block. A final unpaired byte is placed in the padding block instead.
/// An empty slice encodes to an empty string.
pub fn encode(bytes: &[u8]) -> String {
    // Every produced code point is at least U+1500, i.e. 3 or 4 bytes of UTF-8.
    let mut s = String::with_capacity((bytes.len() + 1) / 2 * 4);
    for pair in bytes.chunks(2) {
        let low = u32::from(pair[0]);
        let block = match pair.get(1) {
            Some(&high) => BLOCK_START[usize::from(high)],
            None => PADDING_BLOCK_START,
        };
        s.push(
            char::from_u32(block + low)
                .expect("every base65536 block lies inside the valid scalar value range"),
        );
    }
    s
}

/// Decodes a string produced by [`encode`] back into the original bytes.
///
/// # Panics
///
/// Panics if `s` contains a character that belongs neither to one of the `BLOCK_START` blocks
/// nor to the padding block.
pub fn decode(s: String) -> Vec<u8> {
    // Valid input uses at least 3 UTF-8 bytes per character and yields at most 2 bytes each.
    let mut bytes = Vec::with_capacity(s.len() / 3 * 2);
    for ch in s.chars() {
        let code_point = u32::from(ch);
        let low = code_point & 0xFF;
        let block = code_point - low;
        // `low` is masked to 8 bits, so the cast is lossless.
        bytes.push(low as u8);
        if block != PADDING_BLOCK_START {
            let high = BLOCK_START
                .binary_search(&block)
                .unwrap_or_else(|_| panic!("{:?} is not a base65536 character", ch));
            // `high` indexes a 256-element table, so it always fits in a `u8`.
            bytes.push(high as u8);
        }
    }
    bytes
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encode() {
        let bytes = b"Hello World";
        let s = encode(bytes);
        let decoded = decode(s);
        assert_eq!(&decoded, bytes);
    }
}
version = "0.1.0" 4 | authors = ["Corey Lowman "] 5 | edition = "2018" 6 | license = "MIT OR Apache-2.0" 7 | 8 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 9 | 10 | [dependencies] 11 | base65536 = { path = "../base65536" } 12 | synthesis = { path = "../synthesis" } 13 | tch = "0.4.1" -------------------------------------------------------------------------------- /export/src/main.rs: -------------------------------------------------------------------------------- 1 | use base65536; 2 | use std::collections::HashMap; 3 | use std::env; 4 | use std::io::Write; 5 | use std::path::Path; 6 | use tch; 7 | 8 | pub fn f32_to_bf16(value: f32) -> u16 { 9 | // Convert to raw bytes 10 | let x = value.to_bits(); 11 | 12 | // check for NaN 13 | if x & 0x7FFF_FFFFu32 > 0x7F80_0000u32 { 14 | // Keep high part of current mantissa but also set most significiant mantissa bit 15 | return ((x >> 16) | 0x0040u32) as u16; 16 | } 17 | 18 | // round and shift 19 | let round_bit = 0x0000_8000u32; 20 | if (x & round_bit) != 0 && (x & (3 * round_bit - 1)) != 0 { 21 | (x >> 16) as u16 + 1 22 | } else { 23 | (x >> 16) as u16 24 | } 25 | } 26 | 27 | fn serialize_tensor(t: &tch::Tensor) -> String { 28 | let f32s: Vec = t.into(); 29 | let u8s: Vec = f32s 30 | .iter() 31 | .map(|f| { 32 | f32_to_bf16(*f) 33 | .to_be_bytes() 34 | .iter() 35 | .cloned() 36 | .collect::>() 37 | }) 38 | .flatten() 39 | .collect(); 40 | base65536::encode(&u8s) 41 | } 42 | 43 | fn serialize_tensors>( 44 | variables: &HashMap, 45 | path: P, 46 | ) -> Result<(), std::io::Error> { 47 | let mut names: Vec = variables 48 | .iter() 49 | .map(|(k, _)| String::from(k.split(".").next().unwrap())) 50 | .collect(); 51 | names.sort(); 52 | names.dedup(); 53 | 54 | let mut f = std::fs::File::create(path)?; 55 | let mut i = 0; 56 | for name in names.iter() { 57 | let weight_key = format!("{}.weight", name); 58 | let weight_num_dims = variables.get(&weight_key).unwrap().size().len(); 
59 | let bias_key = format!("{}.bias", name); 60 | f.write_fmt(format_args!( 61 | "load_{}d(&mut policy.{}, String::from(PARAMETERS[{}]));\n", 62 | weight_num_dims, weight_key, i, 63 | ))?; 64 | i += 1; 65 | f.write_fmt(format_args!( 66 | "load_1d(&mut policy.{}, String::from(PARAMETERS[{}]));\n", 67 | bias_key, i, 68 | ))?; 69 | i += 1; 70 | } 71 | 72 | f.write_fmt(format_args!( 73 | "const PARAMETERS: [&'static str; {}] = [\n", 74 | variables.len() 75 | ))?; 76 | let mut i = 0; 77 | for name in names.iter() { 78 | let weight_key = format!("{}.weight", name); 79 | let bias_key = format!("{}.bias", name); 80 | let str_weight = serialize_tensor(&variables.get(&weight_key).unwrap()); 81 | let str_bias = serialize_tensor(&variables.get(&bias_key).unwrap()); 82 | println!("{} - {} {}", name, str_weight.len(), str_bias.len()); 83 | f.write_fmt(format_args!( 84 | "// {} - {}\n\"{}\",\n\"{}\",\n", 85 | name, i, str_weight, str_bias, 86 | ))?; 87 | i += 2; 88 | } 89 | f.write(b"];\n")?; 90 | 91 | Ok(()) 92 | } 93 | 94 | fn main() -> Result<(), Box> { 95 | let args: Vec = env::args().collect(); 96 | assert!(args.len() == 3); 97 | 98 | let src_varstore_path = &args[1]; 99 | let dst_params_path = &args[2]; 100 | 101 | let ts = tch::Tensor::load_multi(src_varstore_path)?; 102 | let variables: HashMap = ts.into_iter().collect(); 103 | serialize_tensors(&variables, dst_params_path)?; 104 | Ok(()) 105 | } 106 | -------------------------------------------------------------------------------- /plot_ratings.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import argparse 3 | import os 4 | 5 | 6 | def main(): 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("path") 9 | args = parser.parse_args() 10 | 11 | scores = {} 12 | random_score = None 13 | mcts_scores = {} 14 | with open(args.path) as fp: 15 | # skip header 16 | fp.readline() 17 | 18 | for line in fp: 19 | while " " in line: 20 | line = 
/// An activation function over `f32` values and fixed-size arrays of them.
///
/// Implementors provide [`Activation::apply`]; the array methods default to applying it to
/// every element. `apply_2d` and `apply_3d` are built on `apply_1d`, so an implementor that
/// overrides `apply_1d` (as [`Softmax`] does) changes how each innermost row is processed.
///
/// NOTE(review): the const parameter lists were reconstructed from the array types in the
/// signatures; confirm their order against any caller that passes them explicitly.
pub trait Activation {
    /// Applies the activation to a single value.
    fn apply(&self, x: f32) -> f32;

    /// Applies the activation to each element of `x`.
    fn apply_1d<const N: usize>(&self, x: &[f32; N]) -> [f32; N] {
        let mut y = [0.0; N];
        for (out, &v) in y.iter_mut().zip(x.iter()) {
            *out = self.apply(v);
        }
        y
    }

    /// Applies `apply_1d` to each of the `H` rows of `x`.
    fn apply_2d<const W: usize, const H: usize>(&self, x: &[[f32; W]; H]) -> [[f32; W]; H] {
        let mut y = [[0.0; W]; H];
        for (out, row) in y.iter_mut().zip(x.iter()) {
            *out = self.apply_1d(row);
        }
        y
    }

    /// Applies `apply_2d` to each of the `H` planes of `x`.
    fn apply_3d<const I: usize, const W: usize, const H: usize>(
        &self,
        x: &[[[f32; I]; W]; H],
    ) -> [[[f32; I]; W]; H] {
        let mut y = [[[0.0; I]; W]; H];
        for (out, plane) in y.iter_mut().zip(x.iter()) {
            *out = self.apply_2d(plane);
        }
        y
    }
}

/// Rectified linear unit: `max(x, 0)`.
pub struct ReLU;
impl Activation for ReLU {
    fn apply(&self, x: f32) -> f32 {
        x.max(0.0)
    }
}

/// Hyperbolic tangent.
pub struct Tanh;
impl Activation for Tanh {
    fn apply(&self, x: f32) -> f32 {
        x.tanh()
    }
}

/// Softmax over the elements of a 1d array.
pub struct Softmax;
impl Activation for Softmax {
    /// Softmax is undefined for a single value; always panics.
    fn apply(&self, _x: f32) -> f32 {
        panic!("Can't call softmax on 1d values")
    }

    /// Returns `exp(x[i]) / sum_j exp(x[j])` for every `i`.
    ///
    /// The largest input is subtracted before exponentiating. This leaves the result
    /// mathematically unchanged but keeps `exp` from overflowing to infinity, which would
    /// otherwise turn the output into NaN for large inputs.
    fn apply_1d<const N: usize>(&self, x: &[f32; N]) -> [f32; N] {
        let max = x.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        // Only shift by a finite maximum: subtracting an infinite one would produce NaN for
        // elements that the unshifted formula maps to a well-defined 0.
        let shift = if max.is_finite() { max } else { 0.0 };

        let mut y = [0.0; N];
        let mut total = 0.0;
        for (out, &v) in y.iter_mut().zip(x.iter()) {
            *out = (v - shift).exp();
            total += *out;
        }
        for out in y.iter_mut() {
            *out /= total;
        }
        y
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_relu_1d() {
        let x = [-2., -1., -0.5, 0., 0.5, 1., 2.];
        let y = ReLU.apply_1d(&x);
        assert_eq!(y, [0., 0., 0., 0.0, 0.5, 1., 2.])
    }
}
usize, 23 | const ROW_PADDING: usize, 24 | const COL_PADDING: usize, 25 | const STRIDE: usize, 26 | > Default for Conv2d 27 | { 28 | fn default() -> Self { 29 | Self { 30 | weight: [[[[0.0; KERNEL_SIZE]; KERNEL_SIZE]; NUM_CHAN_IN]; NUM_CHAN_OUT], 31 | bias: [0.0; NUM_CHAN_OUT], 32 | } 33 | } 34 | } 35 | 36 | impl< 37 | const NUM_CHAN_IN: usize, 38 | const NUM_CHAN_OUT: usize, 39 | const KERNEL_SIZE: usize, 40 | const ROW_PADDING: usize, 41 | const COL_PADDING: usize, 42 | const STRIDE: usize, 43 | > Conv2d 44 | { 45 | pub fn forward( 46 | &self, 47 | x: &[[[f32; W_IN]; H_IN]; NUM_CHAN_IN], 48 | ) -> [[[f32; W_OUT]; H_OUT]; NUM_CHAN_OUT] { 49 | // TODO convert these to compile time calculations when that feature gets added 50 | assert_eq!(W_OUT, ((W_IN + 2 * COL_PADDING - KERNEL_SIZE) / STRIDE) + 1); 51 | assert_eq!(H_OUT, ((H_IN + 2 * ROW_PADDING - KERNEL_SIZE) / STRIDE) + 1); 52 | 53 | let mut y = [[[0.0; W_OUT]; H_OUT]; NUM_CHAN_OUT]; 54 | for i_cout in 0..NUM_CHAN_OUT { 55 | for i_out_row in 0..H_OUT { 56 | for i_out_col in 0..W_OUT { 57 | y[i_cout][i_out_row][i_out_col] = self.bias[i_cout]; 58 | } 59 | } 60 | } 61 | 62 | for i_cout in 0..NUM_CHAN_OUT { 63 | for i_cin in 0..NUM_CHAN_IN { 64 | for i_out_row in 0..H_OUT { 65 | for i_out_col in 0..W_OUT { 66 | for i_k1 in 0..KERNEL_SIZE { 67 | let i_in_row = i_out_row * STRIDE + i_k1; 68 | if ROW_PADDING <= i_in_row && i_in_row < H_IN + ROW_PADDING { 69 | for i_k2 in 0..KERNEL_SIZE { 70 | let i_in_col = i_out_col * STRIDE + i_k2; 71 | if COL_PADDING <= i_in_col && i_in_col < W_IN + COL_PADDING { 72 | let w = self.weight[i_cout][i_cin][i_k1][i_k2]; 73 | let v = x[i_cin][i_in_row - ROW_PADDING] 74 | [i_in_col - COL_PADDING]; 75 | y[i_cout][i_out_row][i_out_col] += w * v; 76 | } 77 | } 78 | } 79 | } 80 | } 81 | } 82 | } 83 | } 84 | y 85 | } 86 | } 87 | 88 | #[cfg(test)] 89 | mod tests { 90 | use super::*; 91 | 92 | #[test] 93 | fn test_channels() { 94 | let mut conv: Conv2d<2, 3, 3, 0, 0, 1> = Default::default(); 95 | 
conv.weight = [ 96 | [ 97 | [ 98 | [-0.15755141, 0.22708158, -0.15798858], 99 | [0.02651758, 0.02949663, 0.00168143], 100 | [0.17578505, -0.00790119, -0.19205979], 101 | ], 102 | [ 103 | [-0.14549221, 0.18005536, -0.20870999], 104 | [-0.02696078, -0.23109876, 0.08636613], 105 | [-0.19711518, -0.22332281, 0.10554548], 106 | ], 107 | ], 108 | [ 109 | [ 110 | [0.16746391, 0.15539487, 0.17933075], 111 | [0.10245387, 0.17808135, 0.00389603], 112 | [-0.05004337, 0.1836464, -0.10865615], 113 | ], 114 | [ 115 | [-0.10701959, 0.22484492, 0.02793057], 116 | [-0.02405447, -0.19043699, -0.06315055], 117 | [0.06873955, 0.07716306, -0.04570898], 118 | ], 119 | ], 120 | [ 121 | [ 122 | [0.17305957, 0.04872765, 0.00382756], 123 | [0.23236315, -0.09844781, 0.02772947], 124 | [0.15478937, 0.1431862, 0.17355753], 125 | ], 126 | [ 127 | [-0.08297984, -0.06017058, -0.12197169], 128 | [0.16294028, 0.10859056, -0.06415635], 129 | [-0.02141872, -0.07581845, 0.21996914], 130 | ], 131 | ], 132 | ]; 133 | conv.bias = [0.10084419, -0.09816569, 0.08800296]; 134 | let x = [ 135 | [ 136 | [0.36798823, 0.32735145, 0.98843247], 137 | [0.6964252, 0.77441615, 0.2872073], 138 | [0.44724697, 0.34764975, 0.43664432], 139 | ], 140 | [ 141 | [0.3974272, 0.56625223, 0.29246235], 142 | [0.42091078, 0.92644644, 0.7363354], 143 | [0.8082989, 0.3173269, 0.56485856], 144 | ], 145 | ]; 146 | let y = conv.forward::<3, 3, 1, 1>(&x); 147 | let t = [[[-0.35449058]], [[0.31011194]], [[0.5618632]]]; 148 | for i_chan in 0..3 { 149 | for i_row in 0..1 { 150 | for i_col in 0..1 { 151 | assert!((y[i_chan][i_row][i_col] - t[i_chan][i_row][i_col]) < 1e-6); 152 | } 153 | } 154 | } 155 | } 156 | 157 | #[test] 158 | fn test_pad0_stride1() { 159 | let mut conv: Conv2d<1, 2, 2, 0, 0, 1> = Default::default(); 160 | conv.weight = [ 161 | [[[-0.32677823, -0.05076015], [-0.05067509, -0.4386055]]], 162 | [[[0.03580105, -0.22968405], [-0.0093717, -0.02110344]]], 163 | ]; 164 | conv.bias = [-0.27842504, -0.13038313]; 165 | let x = [[ 
/// Convolves a single-channel 3x3 input with a two-filter 3x3 kernel
/// (padding 1, stride 1) and compares every output element with the
/// reference values in `t`.
#[test]
fn test_pad1_stride1() {
    let mut conv: Conv2d<1, 2, 3, 1, 1, 1> = Default::default();
    conv.weight = [
        [[
            [0.1022217, -0.2990429, -0.00128791],
            [0.19702038, 0.2377766, -0.26631758],
            [0.13468763, 0.15147987, 0.17658153],
        ]],
        [[
            [-0.28261143, 0.22986522, -0.23906478],
            [0.2649006, 0.2315791, -0.1839219],
            [-0.23470804, 0.25393584, 0.10693645],
        ]],
    ];
    conv.bias = [-0.24373719, 0.20181063];
    let x = [[
        [0.31053752, 0.7199225, 0.43577182],
        [0.5588689, 0.87854236, 0.03064412],
        [0.6246787, 0.7104719, 0.34061056],
    ]];
    let y = conv.forward::<3, 3, 3, 3>(&x);
    let t: [[[f32; 3]; 3]; 2] = [
        [
            [-0.12183491, 0.08633745, 0.12468931],
            [-0.21853128, 0.13490605, 0.02720466],
            [-0.45267165, -0.24807255, 0.05787167],
        ],
        [
            [0.37718016, 0.46584255, 0.29501486],
            [0.30352712, 0.5914381, 0.2580838],
            [0.13423777, 0.5058508, 0.22765125],
        ],
    ];
    for i_chan in 0..2 {
        for i_row in 0..3 {
            for i_col in 0..3 {
                // Compare the magnitude of the error: a signed difference
                // would let any output that is too small pass the check.
                let diff = (y[i_chan][i_row][i_col] - t[i_chan][i_row][i_col]).abs();
                assert!(diff < 1e-6, "y={:?}\nt={:?}\n", y, t);
            }
        }
    }
}
/// Convolves a two-channel 9x9 input with four 3x3 filters and compares
/// every output element with the reference values in `t`.
#[test]
fn test_pad1_stride3() {
    let mut conv: Conv2d<2, 4, 3, 1, 1, 3> = Default::default();
    conv.weight = [
        [
            [
                [0.14617981, 0.11528827, -0.09853071],
                [-0.07204299, 0.00336616, 0.10673852],
                [0.00742291, 0.11231156, 0.2337733],
            ],
            [
                [-0.12326456, 0.09475489, 0.00286931],
                [-0.11880011, -0.02660868, 0.07172234],
                [0.2074004, 0.03104086, 0.04103799],
            ],
        ],
        [
            [
                [0.21327908, 0.1421523, -0.05588683],
                [-0.11007029, 0.21513708, -0.00373882],
                [-0.1405331, -0.21004112, 0.14830004],
            ],
            [
                [0.20064591, 0.01574595, 0.20489053],
                [0.22750203, 0.0800796, 0.04829051],
                [-0.14500698, -0.05896148, 0.00155589],
            ],
        ],
        [
            [
                [-0.15640959, 0.06537981, -0.04286686],
                [-0.18440281, 0.19307987, -0.17548376],
                [-0.16365337, -0.19908702, -0.11890511],
            ],
            [
                [0.1236528, -0.02123187, 0.09672363],
                [0.12256287, -0.1111846, -0.07555458],
                [0.00952002, -0.20669132, 0.07111661],
            ],
        ],
        [
            [
                [0.2330633, -0.15059352, -0.15611872],
                [-0.17001018, 0.14806725, 0.16246979],
                [-0.05100261, 0.03655182, 0.03957109],
            ],
            [
                [-0.17500071, -0.16877688, -0.12116098],
                [0.08597867, 0.2115535, -0.06175958],
                [0.17978726, -0.1459909, -0.03851977],
            ],
        ],
    ];
    conv.bias = [0.20518447, 0.19566856, -0.00938936, 0.03015544];
    let x = [
        [
            [
                0.21057081, 0.7208752, 0.45118594, 0.02713799, 0.68024296, 0.3910358,
                0.23223233, 0.41766483, 0.91346294,
            ],
            [
                0.2500931, 0.12637055, 0.1272502, 0.51244897, 0.4964708, 0.8277718, 0.59300315,
                0.20171815, 0.36267585,
            ],
            [
                0.38347948, 0.9041415, 0.899777, 0.13333935, 0.36263585, 0.25092953,
                0.01967365, 0.8455198, 0.7374019,
            ],
            [
                0.3127641, 0.16012222, 0.8803009, 0.42519283, 0.03213924, 0.28457332,
                0.11092412, 0.7152574, 0.412224,
            ],
            [
                0.02153009, 0.4150746, 0.3453967, 0.29103887, 0.77954763, 0.14238816,
                0.89460665, 0.40737927, 0.25643557,
            ],
            [
                0.8341871, 0.15114868, 0.503467, 0.05420381, 0.90553546, 0.76879644, 0.9369356,
                0.39110446, 0.314408,
            ],
            [
                0.46065217, 0.14885396, 0.03641135, 0.658757, 0.14085609, 0.30918586,
                0.4080497, 0.37957066, 0.4732889,
            ],
            [
                0.07997447, 0.3393948, 0.45755583, 0.4620967, 0.05732316, 0.6345551, 0.8626297,
                0.3647226, 0.33206326,
            ],
            [
                0.9571505, 0.5728258, 0.41579878, 0.55097336, 0.26232815, 0.91751957,
                0.0674867, 0.36741936, 0.81730753,
            ],
        ],
        [
            [
                0.7533369, 0.3351096, 0.20795238, 0.64170164, 0.08985782, 0.96100605,
                0.55665845, 0.5729717, 0.6419061,
            ],
            [
                0.0928914, 0.5336605, 0.6168807, 0.98969394, 0.3537085, 0.13544858, 0.15680116,
                0.09755313, 0.84776443,
            ],
            [
                0.7369501, 0.23981953, 0.6761897, 0.904716, 0.47997457, 0.71871966, 0.7341236,
                0.08496976, 0.71287054,
            ],
            [
                0.5947141, 0.35914594, 0.8863333, 0.19789106, 0.8543512, 0.46718287, 0.2856838,
                0.3525315, 0.04213792,
            ],
            [
                0.33818644, 0.41329008, 0.49908864, 0.76561147, 0.14832914, 0.64022946,
                0.84428936, 0.3323446, 0.6859826,
            ],
            [
                0.4561953, 0.8572295, 0.9829912, 0.2284044, 0.1273588, 0.45776683, 0.02060169,
                0.5639244, 0.21596754,
            ],
            [
                0.9491337, 0.02133191, 0.596267, 0.74133545, 0.05190426, 0.10490572, 0.5913261,
                0.73797053, 0.4090029,
            ],
            [
                0.10168785, 0.28168315, 0.5618286, 0.50776577, 0.3526731, 0.5083848,
                0.88234985, 0.36654603, 0.6237142,
            ],
            [
                0.45444447, 0.18045723, 0.86170214, 0.2762791, 0.6174765, 0.30022788,
                0.19892776, 0.07195282, 0.22581935,
            ],
        ],
    ];
    // NOTE(review): assuming the const parameters are (in, out, kernel,
    // pad, pad, stride), a 9-wide axis with padding 1, kernel 3 and stride 3
    // yields (9 + 2 * 1 - 3) / 3 + 1 = 3 outputs, matching the 3x3 maps below.
    let y = conv.forward::<9, 9, 3, 3>(&x);
    let t: [[[f32; 3]; 3]; 4] = [
        [
            [0.3692422, 0.5577823, 0.29135567],
            [0.38581508, 0.5631167, 0.5298315],
            [0.42890364, 0.2398901, 0.71326447],
        ],
        [
            [0.27634868, 0.05319048, 0.25220025],
            [0.4297849, 0.776384, 0.2243587],
            [0.6921845, 0.5580633, 0.4964255],
        ],
        [
            [-0.25037688, -0.61478406, -0.3990448],
            [-0.17064369, -0.4600212, -0.55443037],
            [0.01031597, -0.16011265, -0.59833616],
        ],
        [
            [0.29715353, 0.20075348, 0.21606356],
            [-0.19430616, -0.18607897, -0.1226138],
            [-0.0173984, 0.10020237, -0.02460488],
        ],
    ];
    for i_chan in 0..4 {
        for i_row in 0..3 {
            for i_col in 0..3 {
                // Compare the magnitude of the error: a signed difference
                // would let any output that is too small pass the check.
                let diff = (y[i_chan][i_row][i_col] - t[i_chan][i_row][i_col]).abs();
                assert!(diff < 1e-6, "y={:?}\nt={:?}\n", y, t);
            }
        }
    }
}
0.83929837, 515 | ], 516 | [ 517 | 0.6798897, 0.6469037, 0.37039953, 0.880702, 0.49907035, 0.05726093, 0.7798417, 518 | 0.9619096, 0.5043428, 519 | ], 520 | [ 521 | 0.2669297, 0.8015939, 0.95740455, 0.6020198, 0.87445277, 0.7183642, 0.54205835, 522 | 0.68437976, 0.82951576, 523 | ], 524 | [ 525 | 0.5435678, 0.4369346, 0.4174738, 0.8460932, 0.13418722, 0.55912274, 0.1456281, 526 | 0.800143, 0.07224685, 527 | ], 528 | [ 529 | 0.95401746, 0.2754792, 0.03196198, 0.5022102, 0.1791898, 0.08201677, 0.706302, 530 | 0.57362723, 0.3891986, 531 | ], 532 | ], 533 | [ 534 | [ 535 | 0.25238568, 0.6502806, 0.5196041, 0.3642316, 0.6576511, 0.12723362, 0.7020561, 536 | 0.18168187, 0.5960324, 537 | ], 538 | [ 539 | 0.23093349, 0.14747256, 0.6867326, 0.85302174, 0.39524788, 0.4955296, 540 | 0.8853273, 0.60563225, 0.9854223, 541 | ], 542 | [ 543 | 0.08876747, 0.84838796, 0.7440208, 0.6717199, 0.9986871, 0.16444528, 0.8884909, 544 | 0.11773419, 0.16389328, 545 | ], 546 | [ 547 | 0.12747067, 0.62013346, 0.31473154, 0.5395161, 0.07182807, 0.17038721, 548 | 0.9332434, 0.3214336, 0.91524035, 549 | ], 550 | [ 551 | 0.82844126, 0.5781982, 0.13706946, 0.73159313, 0.50984377, 0.62877065, 552 | 0.24265271, 0.08741719, 0.5145434, 553 | ], 554 | [ 555 | 0.50140876, 0.99232125, 0.6318608, 0.7825543, 0.9962608, 0.27447134, 556 | 0.82021457, 0.7088084, 0.789743, 557 | ], 558 | [ 559 | 0.22269273, 0.14444828, 0.810514, 0.17407775, 0.48010093, 0.48212987, 560 | 0.5515393, 0.01245189, 0.29415256, 561 | ], 562 | ], 563 | ]; 564 | 565 | let y = conv.forward::<9, 7, 3, 3>(&x); 566 | 567 | let t = [ 568 | [ 569 | [0.17249039, 0.33797365, 0.14313158], 570 | [0.01598295, 0.38415283, 0.4161861], 571 | [0.19046798, 0.29555833, 0.31168252], 572 | ], 573 | [ 574 | [0.04498426, 0.28293645, -0.00831042], 575 | [0.31679663, 0.25450188, -0.02866032], 576 | [0.07459396, 0.08634683, 0.09541086], 577 | ], 578 | [ 579 | [-0.02014766, -0.09831435, -0.13660741], 580 | [-0.59615356, -0.41860208, -0.21566784], 581 | 
/// Fully connected layer with `I` inputs and `O` outputs whose sizes are
/// fixed at compile time.
///
/// `weight[o][i]` multiplies input `i` when computing output `o`;
/// `bias[o]` is the starting value of output `o`.
#[derive(Debug)]
pub struct Linear<const I: usize, const O: usize> {
    pub weight: [[f32; I]; O],
    pub bias: [f32; O],
}

impl<const I: usize, const O: usize> Default for Linear<I, O> {
    /// Creates a layer with every weight and bias set to zero.
    fn default() -> Self {
        Self {
            weight: [[0.0; I]; O],
            bias: [0.0; O],
        }
    }
}

impl<const I: usize, const O: usize> Linear<I, O> {
    /// Computes `bias + weight * x`.
    ///
    /// Each output starts at its bias and accumulates the products in
    /// input order, so the floating-point result does not depend on how
    /// the loops are nested.
    pub fn forward(&self, x: &[f32; I]) -> [f32; O] {
        let mut output = self.bias;
        for (out, row) in output.iter_mut().zip(self.weight.iter()) {
            for (w, xi) in row.iter().zip(x.iter()) {
                *out += xi * w;
            }
        }
        output
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs a 3-input, 2-output layer on two inputs and checks the exact
    /// outputs (all values are small integers, so the sums are exact).
    #[test]
    fn test_linear() {
        let layer = Linear::<3, 2> {
            weight: [[1., 3., 5.], [2., 4., 6.]],
            bias: [-1., 1.],
        };
        let cases = [
            ([3., 2., 1.], [13., 21.]),
            ([1., 3., 2.], [19., 27.]),
        ];
        for (input, expected) in cases.iter() {
            assert_eq!(layer.forward(input), *expected);
        }
    }
}
base65536; 2 | 3 | fn bytes_to_floats(bytes: Vec) -> Vec { 4 | let mut floats = Vec::with_capacity(bytes.len() / 4); 5 | assert!(bytes.len() % 4 == 0); 6 | for i in (0..bytes.len()).step_by(4) { 7 | floats.push(f32::from_be_bytes([ 8 | bytes[i], 9 | bytes[i + 1], 10 | bytes[i + 2], 11 | bytes[i + 3], 12 | ])); 13 | } 14 | floats 15 | } 16 | 17 | pub fn load_1d(data: &mut [f32; N], params: String) { 18 | let bytes = base65536::decode(params); 19 | let floats = bytes_to_floats(bytes); 20 | assert_eq!(floats.len(), N); 21 | unsafe { std::ptr::copy(floats.as_ptr(), data.as_mut_ptr(), floats.len()) }; 22 | } 23 | 24 | pub fn load_2d(data: &mut [[f32; I]; O], params: String) { 25 | let bytes = base65536::decode(params); 26 | let floats = bytes_to_floats(bytes); 27 | assert_eq!(floats.len(), I * O); 28 | unsafe { std::ptr::copy(floats.as_ptr(), data[0].as_mut_ptr(), floats.len()) }; 29 | } 30 | 31 | pub fn load_4d( 32 | data: &mut [[[[f32; I]; J]; K]; L], 33 | params: String, 34 | ) { 35 | let bytes = base65536::decode(params); 36 | let floats = bytes_to_floats(bytes); 37 | assert_eq!(floats.len(), I * J * K * L); 38 | unsafe { std::ptr::copy(floats.as_ptr(), data[0][0][0].as_mut_ptr(), floats.len()) }; 39 | } 40 | -------------------------------------------------------------------------------- /study-connect4/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "study-connect4" 3 | version = "0.1.0" 4 | authors = ["Corey Lowman "] 5 | edition = "2018" 6 | license = "MIT OR Apache-2.0" 7 | 8 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 9 | 10 | [dependencies] 11 | synthesis = { path = "../synthesis" } 12 | slimnn = { path = "../slimnn" } 13 | tch = "0.4.1" 14 | rand = "0.8.3" 15 | rand_distr = "0.4.0" -------------------------------------------------------------------------------- /study-connect4/src/connect4.rs: 
/// Number of columns on the board.
const WIDTH: usize = 9;
/// Number of rows on the board.
const HEIGHT: usize = 7;

// Bitboard layout: the cell at (`row`, `col`) is bit `row + 7 * col`, so each
// column occupies seven consecutive bits and bit 63 is unused.

/// The seven bits of column 0.
const FAB_COL: u64 = 0b1111111;
/// Bit 0 of every column, i.e. the whole of row 0.
const FAB_ROW: u64 = (1 << (7 * 0))
    | (1 << (7 * 1))
    | (1 << (7 * 2))
    | (1 << (7 * 3))
    | (1 << (7 * 4))
    | (1 << (7 * 5))
    | (1 << (7 * 6))
    | (1 << (7 * 7))
    | (1 << (7 * 8));

/// `COLS[c]` has every cell of column `c` set.
const COLS: [u64; WIDTH] = [
    FAB_COL << (7 * 0),
    FAB_COL << (7 * 1),
    FAB_COL << (7 * 2),
    FAB_COL << (7 * 3),
    FAB_COL << (7 * 4),
    FAB_COL << (7 * 5),
    FAB_COL << (7 * 6),
    FAB_COL << (7 * 7),
    FAB_COL << (7 * 8),
];

/// `ROWS[r]` has every cell of row `r` set.
const ROWS: [u64; HEIGHT] = [
    FAB_ROW << 0,
    FAB_ROW << 1,
    FAB_ROW << 2,
    FAB_ROW << 3,
    FAB_ROW << 4,
    FAB_ROW << 5,
    FAB_ROW << 6,
];

/// Cells from which a horizontal line of four fits: columns 0..=5.
const H_MASK: u64 = COLS[0] | COLS[1] | COLS[2] | COLS[3] | COLS[4] | COLS[5];
/// Cells from which a vertical line of four fits: rows 0..=3.
const V_MASK: u64 = ROWS[0] | ROWS[1] | ROWS[2] | ROWS[3];
/// Start cells of a falling diagonal (row - 1, col + 1 per step):
/// columns 0..=5, rows 3..=6.
const D1_MASK: u64 = H_MASK & (ROWS[3] | ROWS[4] | ROWS[5] | ROWS[6]);
/// Start cells of a rising diagonal (row + 1, col + 1 per step):
/// columns 0..=5, rows 0..=3.
const D2_MASK: u64 = H_MASK & V_MASK;

/// Returns `true` when the bitboard `bb` contains four stones in a line.
///
/// For each direction the board is ANDed with itself shifted by one, two and
/// three steps, which leaves a bit set at every cell that starts a run of
/// four. The masks discard starts whose run would leave the board and wrap
/// into a neighbouring column.
const fn won(bb: u64) -> bool {
    // Step sizes in bits: 6 = next column, one row down; 8 = next column,
    // one row up; 7 = next column, same row; 1 = same column, one row up.
    let d1 = bb & (bb >> 6) & (bb >> 12) & (bb >> 18) & D1_MASK;
    let d2 = bb & (bb >> 8) & (bb >> 16) & (bb >> 24) & D2_MASK;
    let h = bb & (bb >> 7) & (bb >> 14) & (bb >> 21) & H_MASK;
    let v = bb & (bb >> 1) & (bb >> 2) & (bb >> 3) & V_MASK;
    (v | h | d1 | d2) != 0
}
150 | while self.col < WIDTH as u8 { 151 | if self.height[self.col as usize] < HEIGHT as u8 { 152 | let item = Some(Column(self.col)); 153 | self.col += 1; 154 | return item; 155 | } 156 | self.col += 1; 157 | } 158 | 159 | None 160 | } 161 | } 162 | 163 | impl Connect4 { 164 | fn winner(&self) -> Option { 165 | if won(self.op_bb) { 166 | Some(self.player.next()) 167 | } else { 168 | None 169 | } 170 | } 171 | } 172 | 173 | impl Game for Connect4 { 174 | const NAME: &'static str = "Connect4"; 175 | const NUM_PLAYERS: usize = 2; 176 | const MAX_TURNS: usize = 63; 177 | 178 | type PlayerId = PlayerId; 179 | type Action = Column; 180 | type ActionIterator = FreeColumns; 181 | 182 | fn new() -> Self { 183 | Self { 184 | my_bb: 0, 185 | op_bb: 0, 186 | height: [0; WIDTH], 187 | player: PlayerId::Red, 188 | } 189 | } 190 | 191 | fn player(&self) -> Self::PlayerId { 192 | self.player 193 | } 194 | 195 | fn is_over(&self) -> bool { 196 | self.winner().is_some() || (0..WIDTH).all(|col| self.height[col] == HEIGHT as u8) 197 | } 198 | 199 | fn reward(&self, player_id: Self::PlayerId) -> f32 { 200 | // assert!(self.is_over()); 201 | 202 | match self.winner() { 203 | Some(winner) => { 204 | if winner == player_id { 205 | 1.0 206 | } else { 207 | -1.0 208 | } 209 | } 210 | None => 0.0, 211 | } 212 | } 213 | 214 | fn iter_actions(&self) -> Self::ActionIterator { 215 | FreeColumns { 216 | height: self.height, 217 | col: 0, 218 | } 219 | } 220 | 221 | fn step(&mut self, action: &Self::Action) -> bool { 222 | let col: usize = (*action).into(); 223 | 224 | // assert!(self.height[col] < HEIGHT as u8); 225 | 226 | self.my_bb ^= 1 << (self.height[col] + (HEIGHT as u8) * (col as u8)); 227 | self.height[col] += 1; 228 | 229 | std::mem::swap(&mut self.my_bb, &mut self.op_bb); 230 | self.player = self.player.next(); 231 | 232 | self.is_over() 233 | } 234 | 235 | const DIMS: &'static [i64] = &[1, 1, HEIGHT as i64, WIDTH as i64]; 236 | type Features = [[[f32; WIDTH]; HEIGHT]; 1]; 237 | fn 
features(&self) -> Self::Features { 238 | let mut s = Self::Features::default(); 239 | for row in 0..HEIGHT { 240 | for col in 0..WIDTH { 241 | let index = 1 << (row + HEIGHT * col); 242 | if self.my_bb & index != 0 { 243 | s[0][row][col] = 1.0; 244 | } else if self.op_bb & index != 0 { 245 | s[0][row][col] = -1.0; 246 | } else { 247 | s[0][row][col] = -0.1; 248 | } 249 | } 250 | } 251 | for col in 0..WIDTH { 252 | let h = self.height[col] as usize; 253 | if h < HEIGHT { 254 | s[0][h][col] = 0.1; 255 | } 256 | } 257 | s 258 | } 259 | 260 | fn print(&self) { 261 | if self.is_over() { 262 | println!("{:?} won", self.winner()); 263 | } else { 264 | println!("{:?} to play", self.player); 265 | println!( 266 | "Available Actions: {:?}", 267 | self.iter_actions().collect::>() 268 | ); 269 | } 270 | 271 | let (my_char, op_char) = match self.player { 272 | PlayerId::Black => ("B", "r"), 273 | PlayerId::Red => ("r", "B"), 274 | }; 275 | 276 | for row in (0..HEIGHT).rev() { 277 | for col in 0..WIDTH { 278 | let index = 1 << (row + HEIGHT * col); 279 | print!( 280 | "{} ", 281 | if self.my_bb & index != 0 { 282 | my_char 283 | } else if self.op_bb & index != 0 { 284 | op_char 285 | } else { 286 | "." 
287 | } 288 | ); 289 | } 290 | println!(); 291 | } 292 | } 293 | } 294 | 295 | #[cfg(test)] 296 | mod tests { 297 | use super::*; 298 | 299 | #[test] 300 | fn test_first_wins() { 301 | let mut game = Connect4::new(); 302 | assert!(!game.step(&Column(0))); 303 | assert!(!game.step(&Column(1))); 304 | assert!(!game.step(&Column(0))); 305 | assert!(!game.step(&Column(1))); 306 | assert!(!game.step(&Column(0))); 307 | assert!(!game.step(&Column(1))); 308 | assert!(game.step(&Column(0))); 309 | assert!(game.is_over()); 310 | assert_eq!(game.winner(), Some(PlayerId::Red)); 311 | assert_eq!(game.reward(game.player()), -1.0); 312 | assert_eq!(game.player(), PlayerId::Black); 313 | assert_eq!(game.reward(PlayerId::Black), -1.0); 314 | assert_eq!(game.reward(PlayerId::Red), 1.0); 315 | } 316 | 317 | #[test] 318 | fn test_second_wins() { 319 | let mut game = Connect4::new(); 320 | assert!(!game.step(&Column(0))); 321 | assert!(!game.step(&Column(1))); 322 | assert!(!game.step(&Column(2))); 323 | assert!(!game.step(&Column(1))); 324 | assert!(!game.step(&Column(2))); 325 | assert!(!game.step(&Column(1))); 326 | assert!(!game.step(&Column(2))); 327 | assert!(game.step(&Column(1))); 328 | assert!(game.is_over()); 329 | assert_eq!(game.winner(), Some(PlayerId::Black)); 330 | assert_eq!(game.reward(game.player()), -1.0); 331 | assert_eq!(game.player(), PlayerId::Red); 332 | assert_eq!(game.reward(PlayerId::Black), 1.0); 333 | assert_eq!(game.reward(PlayerId::Red), -1.0); 334 | } 335 | 336 | #[test] 337 | fn test_draw() { 338 | /* 339 | +-------------------+ 340 | | r b r b r b r b r | 341 | | r b r b r b r b b | 342 | | r b r b r b r b r | 343 | | b r b r b r b r b | 344 | | b r b r b r b r r | 345 | | r b r b r b r b b | 346 | | r b r b r b r b r | 347 | +-------------------+ 348 | */ 349 | 350 | let mut game = Connect4::new(); 351 | assert!(!game.step(&Column(0))); 352 | assert!(!game.step(&Column(1))); 353 | assert!(!game.step(&Column(0))); 354 | 
assert!(!game.step(&Column(1))); 355 | assert!(!game.step(&Column(1))); 356 | assert!(!game.step(&Column(0))); 357 | assert!(!game.step(&Column(1))); 358 | assert!(!game.step(&Column(0))); 359 | assert!(!game.step(&Column(0))); 360 | assert!(!game.step(&Column(1))); 361 | assert!(!game.step(&Column(0))); 362 | assert!(!game.step(&Column(1))); 363 | 364 | assert!(game.iter_actions().position(|c| c == Column(0)).is_some()); 365 | assert!(!game.step(&Column(0))); 366 | assert!(game.iter_actions().position(|c| c == Column(0)).is_none()); 367 | 368 | assert!(game.iter_actions().position(|c| c == Column(1)).is_some()); 369 | assert!(!game.step(&Column(1))); 370 | assert!(game.iter_actions().position(|c| c == Column(1)).is_none()); 371 | 372 | assert!(!game.step(&Column(2))); 373 | assert!(!game.step(&Column(3))); 374 | assert!(!game.step(&Column(2))); 375 | assert!(!game.step(&Column(3))); 376 | assert!(!game.step(&Column(3))); 377 | assert!(!game.step(&Column(2))); 378 | assert!(!game.step(&Column(3))); 379 | assert!(!game.step(&Column(2))); 380 | assert!(!game.step(&Column(2))); 381 | assert!(!game.step(&Column(3))); 382 | assert!(!game.step(&Column(2))); 383 | assert!(!game.step(&Column(3))); 384 | 385 | assert!(game.iter_actions().position(|c| c == Column(2)).is_some()); 386 | assert!(!game.step(&Column(2))); 387 | assert!(game.iter_actions().position(|c| c == Column(2)).is_none()); 388 | 389 | assert!(game.iter_actions().position(|c| c == Column(3)).is_some()); 390 | assert!(!game.step(&Column(3))); 391 | assert!(game.iter_actions().position(|c| c == Column(3)).is_none()); 392 | 393 | assert!(!game.step(&Column(4))); 394 | assert!(!game.step(&Column(5))); 395 | assert!(!game.step(&Column(4))); 396 | assert!(!game.step(&Column(5))); 397 | assert!(!game.step(&Column(5))); 398 | assert!(!game.step(&Column(4))); 399 | assert!(!game.step(&Column(5))); 400 | assert!(!game.step(&Column(4))); 401 | assert!(!game.step(&Column(4))); 402 | assert!(!game.step(&Column(5))); 403 
| assert!(!game.step(&Column(4))); 404 | assert!(!game.step(&Column(5))); 405 | 406 | assert!(game.iter_actions().position(|c| c == Column(4)).is_some()); 407 | assert!(!game.step(&Column(4))); 408 | assert!(game.iter_actions().position(|c| c == Column(4)).is_none()); 409 | 410 | assert!(game.iter_actions().position(|c| c == Column(5)).is_some()); 411 | assert!(!game.step(&Column(5))); 412 | assert!(game.iter_actions().position(|c| c == Column(5)).is_none()); 413 | 414 | assert!(!game.step(&Column(6))); 415 | assert!(!game.step(&Column(7))); 416 | assert!(!game.step(&Column(6))); 417 | assert!(!game.step(&Column(7))); 418 | assert!(!game.step(&Column(7))); 419 | assert!(!game.step(&Column(6))); 420 | assert!(!game.step(&Column(7))); 421 | assert!(!game.step(&Column(6))); 422 | assert!(!game.step(&Column(6))); 423 | assert!(!game.step(&Column(7))); 424 | assert!(!game.step(&Column(6))); 425 | assert!(!game.step(&Column(7))); 426 | 427 | assert!(game.iter_actions().position(|c| c == Column(6)).is_some()); 428 | assert!(!game.step(&Column(6))); 429 | assert!(game.iter_actions().position(|c| c == Column(6)).is_none()); 430 | 431 | assert!(game.iter_actions().position(|c| c == Column(7)).is_some()); 432 | assert!(!game.step(&Column(7))); 433 | assert!(game.iter_actions().position(|c| c == Column(7)).is_none()); 434 | 435 | assert!(!game.step(&Column(8))); 436 | assert!(!game.step(&Column(8))); 437 | assert!(!game.step(&Column(8))); 438 | assert!(!game.step(&Column(8))); 439 | assert!(!game.step(&Column(8))); 440 | assert!(!game.step(&Column(8))); 441 | assert!(game.iter_actions().position(|c| c == Column(8)).is_some()); 442 | assert!(game.step(&Column(8))); 443 | assert!(game.is_over()); 444 | assert_eq!(game.winner(), None); 445 | assert_eq!(game.reward(PlayerId::Red), 0.0); 446 | assert_eq!(game.reward(PlayerId::Black), 0.0); 447 | } 448 | 449 | #[test] 450 | fn test_horz_wins() { 451 | for row in 0..HEIGHT { 452 | let mut bb = 453 | (1 << (row + 0)) | (1 << (row + 
7)) | (1 << (row + 14)) | (1 << (row + 21)); 454 | for _i in 0..6 { 455 | assert!(won(bb)); 456 | bb <<= 7; 457 | } 458 | } 459 | } 460 | 461 | #[test] 462 | fn test_vert_wins() { 463 | for col in 0..WIDTH { 464 | let mut bb = (1 << (7 * col + 0)) 465 | | (1 << (7 * col + 1)) 466 | | (1 << (7 * col + 2)) 467 | | (1 << (7 * col + 3)); 468 | for _i in 0..4 { 469 | assert!(won(bb)); 470 | bb <<= 1; 471 | } 472 | } 473 | } 474 | 475 | #[test] 476 | fn test_d1_wins() { 477 | for row in 3..HEIGHT { 478 | let mut bb = (1 << row) | (1 << (row + 6)) | (1 << (row + 12)) | (1 << (row + 18)); 479 | for _i in 0..6 { 480 | assert!(won(bb)); 481 | bb <<= 7; 482 | } 483 | } 484 | } 485 | 486 | #[test] 487 | fn test_d2_wins() { 488 | for col in 0..6 { 489 | let mut bb = (1 << (7 * col + 0)) 490 | | (1 << (7 * (col + 1) + 1)) 491 | | (1 << (7 * (col + 2) + 2)) 492 | | (1 << (7 * (col + 3) + 3)); 493 | for _i in 0..4 { 494 | assert!(won(bb)); 495 | bb <<= 1; 496 | } 497 | } 498 | } 499 | } 500 | -------------------------------------------------------------------------------- /study-connect4/src/main.rs: -------------------------------------------------------------------------------- 1 | mod connect4; 2 | mod policies; 3 | 4 | use rand::{distributions::Distribution, thread_rng}; 5 | use rand_distr::Normal; 6 | 7 | use crate::connect4::Connect4; 8 | use crate::policies::*; 9 | use synthesis::prelude::*; 10 | 11 | fn learn, P: Policy + NNPolicy, const N: usize>( 12 | ) -> Result<(), Box> { 13 | let cfg = LearningConfig { 14 | seed: 0, // seed for rng & torch 15 | logs: train_dir("./_logs", G::NAME)?, // log directory 16 | num_iterations: 200, // number of training iterations to run 17 | 18 | lr_schedule: vec![(1, 1e-3), (20, 5e-4), (40, 1e-4), (60, 5e-5), (80, 1e-5)], // schedule for lr - first item in tuple is iteration # 19 | weight_decay: 1e-6, // L2 regularization for Adam optimizer 20 | num_epochs: 20, // number of full passes over training data per iteration 21 | batch_size: 32, 
// size of batches that epochs are split into 22 | policy_weight: 1.0, // scalar for policy loss 23 | value_weight: 1.0, // scalar for value loss 24 | 25 | games_to_keep: 20000, // number of games to keep in replay buffer 26 | games_per_train: 1000, // number of new games to add to replay buffer per training iteration 27 | 28 | rollout_cfg: RolloutConfig { 29 | num_workers: 6, // number of processes to use for running games 30 | num_explores: 1600, // number of MCTS explores per turn 31 | random_actions_until: 1, // last turn number to select random actions 32 | sample_actions_until: 30, // last turn number to sample actions 33 | stop_games_when_solved: false, // end games early if they are solved by MCTS 34 | value_target: ValueTarget::Q, // the target for NN value function 35 | action: ActionSelection::NumVisits, // the value to use for best action 36 | 37 | mcts_cfg: MCTSConfig { 38 | exploration: Exploration::PolynomialUct { c: 3.0 }, // type of exploration to use (e.g. PUCT or UCT) 39 | solve: true, // use MCTS Solver extension to solve nodes 40 | correct_values_on_solve: true, // if node is solved, adjust previously backprop'd values 41 | select_solved_nodes: true, // select nodes that are solved 42 | auto_extend: true, // visit nodes until a node with > 1 child is reached 43 | fpu: Fpu::Func(|| { 44 | // exploit value of un-evaluated nodes 45 | let dist = Normal::new(1.0, 0.1).unwrap(); 46 | dist.sample(&mut thread_rng()) 47 | }), 48 | root_policy_noise: PolicyNoise::None, 49 | }, 50 | }, 51 | }; 52 | 53 | let eval_cfg = EvaluationConfig { 54 | logs: cfg.logs.clone(), 55 | 56 | policy_num_explores: cfg.rollout_cfg.num_explores, 57 | policy_action: ActionSelection::NumVisits, 58 | policy_mcts_cfg: MCTSConfig { 59 | exploration: Exploration::PolynomialUct { c: 3.0 }, 60 | solve: true, 61 | correct_values_on_solve: true, 62 | select_solved_nodes: true, 63 | auto_extend: true, 64 | fpu: Fpu::Const(1.0), 65 | root_policy_noise: PolicyNoise::None, 66 | }, 67 | 68 
| num_games_against_best_policies: 1, 69 | num_best_policies: 10, 70 | 71 | num_games_against_rollout: 5, 72 | rollout_num_explores: vec![800, 1600, 3200, 6400, 12800, 25600, 51200, 102400, 204800], 73 | rollout_action: ActionSelection::Q, 74 | rollout_mcts_cfg: MCTSConfig { 75 | exploration: Exploration::Uct { c: 2.0 }, 76 | solve: true, 77 | correct_values_on_solve: true, 78 | select_solved_nodes: true, 79 | auto_extend: false, 80 | fpu: Fpu::Const(f32::INFINITY), 81 | root_policy_noise: PolicyNoise::None, 82 | }, 83 | }; 84 | 85 | tch::set_num_threads(1); 86 | tch::set_num_interop_threads(1); 87 | 88 | let eval_handle = std::thread::spawn(move || evaluator::(&eval_cfg).unwrap()); 89 | alpha_zero::(&cfg)?; 90 | eval_handle.join().unwrap(); 91 | Ok(()) 92 | } 93 | 94 | fn main() { 95 | learn::().unwrap() 96 | } 97 | -------------------------------------------------------------------------------- /study-connect4/src/policies.rs: -------------------------------------------------------------------------------- 1 | use crate::connect4::Connect4; 2 | use synthesis::prelude::*; 3 | use tch::{self, nn, Tensor}; 4 | 5 | pub struct Connect4Net { 6 | l_1: nn::Linear, 7 | l_2: nn::Linear, 8 | l_3: nn::Linear, 9 | l_4: nn::Linear, 10 | l_5: nn::Linear, 11 | } 12 | 13 | impl NNPolicy for Connect4Net { 14 | fn new(vs: &nn::VarStore) -> Self { 15 | let root = &vs.root(); 16 | let state_dims = Connect4::DIMS; 17 | assert!(state_dims.len() == 4); 18 | assert!(&state_dims == &[1, 1, 7, 9]); 19 | Self { 20 | l_1: nn::linear(root / "l_1", 63, 128, Default::default()), 21 | l_2: nn::linear(root / "l_2", 128, 96, Default::default()), 22 | l_3: nn::linear(root / "l_3", 96, 64, Default::default()), 23 | l_4: nn::linear(root / "l_4", 64, 48, Default::default()), 24 | l_5: nn::linear(root / "l_5", 48, 12, Default::default()), 25 | } 26 | } 27 | 28 | fn forward(&self, xs: &Tensor) -> (Tensor, Tensor) { 29 | let xs = xs 30 | .flat_view() 31 | .apply(&self.l_1) 32 | .relu() 33 | 
.apply(&self.l_2) 34 | .relu() 35 | .apply(&self.l_3) 36 | .relu() 37 | .apply(&self.l_4) 38 | .relu() 39 | .apply(&self.l_5); 40 | let mut ts = xs.split_with_sizes(&[9, 3], -1); 41 | let outcome_logits = ts.pop().unwrap(); 42 | let policy_logits = ts.pop().unwrap(); 43 | (policy_logits, outcome_logits) 44 | } 45 | } 46 | 47 | impl Policy for Connect4Net { 48 | fn eval(&mut self, env: &Connect4) -> ([f32; Connect4::MAX_NUM_ACTIONS], [f32; 3]) { 49 | let xs = env.features(); 50 | let t = tensor(&xs, Connect4::DIMS, tch::Kind::Float); 51 | let (logits, value) = self.forward(&t); 52 | let mut policy = [0.0f32; Connect4::MAX_NUM_ACTIONS]; 53 | logits.copy_data(&mut policy, Connect4::MAX_NUM_ACTIONS); 54 | let mut outcomes = [0.0f32; 3]; 55 | value 56 | .softmax(-1, tch::Kind::Float) 57 | .copy_data(&mut outcomes, 3); 58 | (policy, outcomes) 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /synthesis/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "synthesis" 3 | version = "0.1.0" 4 | authors = ["Corey Lowman "] 5 | edition = "2018" 6 | license = "MIT OR Apache-2.0" 7 | 8 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 9 | 10 | [dependencies] 11 | tch = "0.4.1" 12 | rand = "0.8.3" 13 | ordered-float = "2.5.0" 14 | serde_json = "1.0.64" 15 | serde = { version = "1.0.126", features = ["derive"] } 16 | chrono = "0.4.3" 17 | log = "0.4.14" 18 | env_logger = "0.8.3" 19 | torch-sys = "0.4.1" 20 | rand_distr = "0.4.0" 21 | indicatif = "0.16.2" 22 | slimnn = { path = "../slimnn" } -------------------------------------------------------------------------------- /synthesis/src/alpha_zero.rs: -------------------------------------------------------------------------------- 1 | use crate::config::{LearningConfig, RolloutConfig, ValueTarget}; 2 | use crate::data::*; 3 | use crate::game::{Game, Outcome}; 4 | use 
crate::mcts::MCTS; 5 | use crate::policies::{NNPolicy, Policy, PolicyWithCache}; 6 | use crate::utils::*; 7 | use indicatif::{MultiProgress, ProgressBar, ProgressStyle}; 8 | use rand::prelude::*; 9 | use rand::{distributions::Distribution, distributions::WeightedIndex}; 10 | use std::default::Default; 11 | use tch::{ 12 | kind::Kind, 13 | nn::{Adam, OptimizerConfig, VarStore}, 14 | }; 15 | 16 | pub fn alpha_zero, P: Policy + NNPolicy, const N: usize>( 17 | cfg: &LearningConfig, 18 | ) -> Result<(), Box> { 19 | // set up directory structure 20 | std::fs::create_dir_all(&cfg.logs)?; 21 | let models_dir = cfg.logs.join("models"); 22 | std::fs::create_dir(&models_dir)?; 23 | save_str(&cfg.logs, "env_name", &G::NAME.into())?; 24 | save_str(&cfg.logs, "git_hash", &git_hash()?)?; 25 | save_str(&cfg.logs, "git_diff.patch", &git_diff()?)?; 26 | 27 | // seed rngs 28 | tch::manual_seed(cfg.seed as i64); 29 | 30 | // init policy 31 | let vs = VarStore::new(tch::Device::Cpu); 32 | let policy = P::new(&vs); 33 | let mut opt = Adam::default().build(&vs, cfg.lr_schedule[0].1)?; 34 | if cfg.weight_decay > 0.0 { 35 | opt.set_weight_decay(cfg.weight_decay); 36 | } 37 | vs.save(models_dir.join(String::from("model_0.ot")))?; 38 | 39 | // init replay buffer 40 | let mut buffer = ReplayBuffer::new(256_000); 41 | 42 | // start learning! 
43 | let mut dims = G::DIMS.to_owned(); 44 | let batch_mean = 1.0 / (cfg.batch_size as f32); 45 | for i_iter in 0..cfg.num_iterations { 46 | // gather data 47 | { 48 | let _guard = tch::no_grad_guard(); 49 | gather_experience::(cfg, format!("model_{}.ot", i_iter), &mut buffer, i_iter); 50 | } 51 | 52 | // convert buffer data to tensors 53 | let dedup = buffer.deduplicate(); 54 | println!("Dedup {} -> {} steps", buffer.vs.len(), dedup.vs.len()); 55 | dims[0] = dedup.vs.len() as i64; 56 | let states = tensor(&dedup.states, &dims, Kind::Float); 57 | let target_pis = tensor(&dedup.pis, &[dims[0], N as i64], Kind::Float); 58 | let target_vs = tensor(&dedup.vs, &[dims[0], 3], Kind::Float); 59 | 60 | // calculate lr from schedule 61 | let lr = cfg 62 | .lr_schedule 63 | .iter() 64 | .filter(|(i, _lr)| *i <= i_iter + 1) 65 | .last() 66 | .unwrap() 67 | .1; 68 | opt.set_lr(lr); 69 | println!("Using lr={}", lr); 70 | 71 | // train 72 | for _i_epoch in 0..cfg.num_epochs { 73 | let sampler = 74 | BatchRandSampler::new(&states, &target_pis, &target_vs, cfg.batch_size, true); 75 | 76 | let mut epoch_loss = [0.0, 0.0]; 77 | for (state, target_pi, target_v) in sampler { 78 | let (pi_logits, v_logits) = policy.forward(&state); 79 | 80 | let log_pi = pi_logits.log_softmax(-1, Kind::Float); 81 | let log_v = v_logits.log_softmax(-1, Kind::Float); 82 | let pi_loss = batch_mean * log_pi.kl_div(&target_pi, tch::Reduction::Sum, false); 83 | let v_loss = batch_mean * log_v.kl_div(&target_v, tch::Reduction::Sum, false); 84 | 85 | let loss = cfg.policy_weight * &pi_loss + cfg.value_weight * &v_loss; 86 | opt.backward_step(&loss); 87 | 88 | epoch_loss[0] += f32::from(&pi_loss); 89 | epoch_loss[1] += f32::from(&v_loss); 90 | } 91 | epoch_loss[0] *= (cfg.batch_size as f32) / (dims[0] as f32); 92 | epoch_loss[1] *= (cfg.batch_size as f32) / (dims[0] as f32); 93 | println!("{} {:?}", _i_epoch, epoch_loss); 94 | } 95 | 96 | // save latest weights 97 | vs.save(models_dir.join(format!("model_{}.ot", 
i_iter + 1)))?; 98 | states.write_npy(cfg.logs.join("latest_states.npy"))?; 99 | target_pis.write_npy(cfg.logs.join("latest_pis.npy"))?; 100 | target_vs.write_npy(cfg.logs.join("latest_vs.npy"))?; 101 | 102 | println!("Finished iteration {}", i_iter + 1); 103 | println!( 104 | "lifetime: {} steps / {} games ({:.3} steps/game)", 105 | buffer.total_steps(), 106 | buffer.total_games_played(), 107 | buffer.total_steps() as f32 / buffer.total_games_played() as f32, 108 | ); 109 | println!( 110 | "buffer: {} steps / {} games ({:.3} steps/game)", 111 | buffer.curr_steps(), 112 | buffer.curr_games(), 113 | buffer.curr_steps() as f32 / buffer.curr_games() as f32, 114 | ); 115 | } 116 | 117 | Ok(()) 118 | } 119 | 120 | fn gather_experience, P: Policy + NNPolicy, const N: usize>( 121 | cfg: &LearningConfig, 122 | policy_name: String, 123 | buffer: &mut ReplayBuffer, 124 | seed: usize, 125 | ) { 126 | let mut games_to_schedule = cfg.games_per_train; 127 | let mut workers_left = cfg.rollout_cfg.num_workers + 1; 128 | let mut handles = Vec::with_capacity(workers_left); 129 | let multi_bar = MultiProgress::new(); 130 | 131 | // create workers 132 | for i_worker in 0..cfg.rollout_cfg.num_workers + 1 { 133 | // create copies of data for this worker 134 | let worker_policy_name = policy_name.clone(); 135 | let worker_cfg = cfg.clone(); 136 | 137 | // calculate number of games this worker will run. 
this allows uneven number of games across workers 138 | let num_games = games_to_schedule / workers_left; 139 | let worker_bar = multi_bar.add(styled_progress_bar(num_games)); 140 | let worker_seed = seed * (cfg.rollout_cfg.num_workers + 1) + i_worker; 141 | // spawn a worker 142 | handles.push(std::thread::spawn(move || { 143 | run_n_games::( 144 | worker_cfg, 145 | worker_policy_name, 146 | num_games, 147 | worker_bar, 148 | worker_seed, 149 | ) 150 | })); 151 | 152 | games_to_schedule -= num_games; 153 | workers_left -= 1; 154 | } 155 | 156 | // sanity check that all games are scheduled 157 | assert!(games_to_schedule == 0); 158 | assert!(workers_left == 0); 159 | 160 | // wait for workers to complete 161 | multi_bar.join().unwrap(); 162 | 163 | // collect experience gathered into main buffer 164 | buffer.keep_last_n_games(cfg.games_to_keep - cfg.games_per_train); 165 | for handle in handles.drain(..) { 166 | let mut worker_buffer = handle.join().unwrap(); 167 | buffer.extend(&mut worker_buffer); 168 | } 169 | } 170 | 171 | fn styled_progress_bar(n: usize) -> ProgressBar { 172 | let bar = ProgressBar::new(n as u64); 173 | bar.set_style( 174 | ProgressStyle::default_bar() 175 | .template("[{bar:40}] {pos}/{len} ({percent}%) | {eta} remaining | {elapsed_precise}") 176 | .progress_chars("|| "), 177 | ); 178 | bar 179 | } 180 | 181 | fn run_n_games, P: Policy + NNPolicy, const N: usize>( 182 | cfg: LearningConfig, 183 | policy_name: String, 184 | num_games: usize, 185 | progress_bar: ProgressBar, 186 | seed: usize, 187 | ) -> ReplayBuffer { 188 | let mut buffer = ReplayBuffer::new(G::MAX_TURNS * num_games); 189 | let mut rng = StdRng::seed_from_u64(seed as u64); 190 | 191 | // load the policy weights 192 | let mut vs = VarStore::new(tch::Device::Cpu); 193 | let mut policy = P::new(&vs); 194 | vs.load(cfg.logs.join("models").join(&policy_name)).unwrap(); 195 | 196 | // create a cache for this policy, this speeds things up a lot, but takes memory 197 | let mut 
/// Per-position bookkeeping collected while a self-play game is running.
struct StateInfo {
    /// Turn number recorded for the position (the rollout passes `num_turns + 1`).
    turn: usize,
    /// Relative progress through the game; starts at 0 and is filled in later
    /// as `turn / total turns`.
    t: f32,
    /// Search value estimate for the position (3 outcome entries).
    q: [f32; 3],
    /// Final game outcome target (3 outcome entries); all zeros until filled in.
    z: [f32; 3],
}

impl StateInfo {
    /// Creates an entry that only knows its turn and search value `q`;
    /// `t` and `z` start zeroed.
    fn q(turn: usize, q: [f32; 3]) -> Self {
        let unset_outcome = [0.0; 3];
        StateInfo {
            turn,
            t: 0.0,
            q,
            z: unset_outcome,
        }
    }
}
| cfg: &RolloutConfig, 272 | mcts: &mut MCTS, 273 | game: &G, 274 | search_policy: &[f32], 275 | rng: &mut R, 276 | num_turns: usize, 277 | ) -> G::Action { 278 | let best = mcts.best_action(cfg.action); 279 | let solution = mcts.solution(&best); 280 | let action = if num_turns < cfg.random_actions_until { 281 | let n = rng.gen_range(0..game.iter_actions().count() as u8) as usize; 282 | game.iter_actions().nth(n).unwrap() 283 | } else if num_turns < cfg.sample_actions_until 284 | && (solution.is_none() || !cfg.stop_games_when_solved) 285 | { 286 | let dist = WeightedIndex::new(search_policy).unwrap(); 287 | let choice = dist.sample(rng); 288 | // assert!(search_policy[choice] > 0.0); 289 | G::Action::from(choice) 290 | } else { 291 | best 292 | }; 293 | action 294 | } 295 | 296 | fn fill_state_info(state_infos: &mut Vec, mut outcome: Outcome) { 297 | let num_turns = state_infos.len(); 298 | for state_value in state_infos.iter_mut().rev() { 299 | state_value.z[match outcome { 300 | Outcome::Win(_) => 2, 301 | Outcome::Draw(_) => 1, 302 | Outcome::Lose(_) => 0, 303 | }] = 1.0; 304 | state_value.t = state_value.turn as f32 / num_turns as f32; 305 | outcome = outcome.reversed(); 306 | } 307 | } 308 | 309 | fn store_rewards, const N: usize>( 310 | cfg: &RolloutConfig, 311 | buffer: &mut ReplayBuffer, 312 | state_infos: &Vec, 313 | ) { 314 | let num_turns = state_infos.len(); 315 | let start_i = buffer.curr_steps() - num_turns; 316 | let end_i = buffer.curr_steps(); 317 | for (buffer_value, state) in buffer.vs[start_i..end_i].iter_mut().zip(state_infos) { 318 | *buffer_value = match cfg.value_target { 319 | ValueTarget::Q => state.q, 320 | ValueTarget::Z => state.z, 321 | ValueTarget::QZaverage { p } => { 322 | let mut value = [0.0; 3]; 323 | for i in 0..3 { 324 | value[i] = state.q[i] * p + state.z[i] * (1.0 - p); 325 | } 326 | value 327 | } 328 | ValueTarget::QtoZ { from, to } => { 329 | let p = (1.0 - state.t) * from + state.t * to; 330 | let mut value = [0.0; 3]; 331 
/// Selects which signal the value head is trained against for each stored position.
///
/// `Q` is the value estimate produced by the search for that position and `Z` is
/// the final result of the game the position came from.
#[derive(Debug, Clone, Copy)]
pub enum ValueTarget {
    Z,                           // Outcome of game {-1, 0, 1}; NOTE(review): the rollout code stores this as a 3-entry lose/draw/win vector - confirm the comment is still accurate
    Q,                           // Avg Value found while searching
    QZaverage { p: f32 },        // Fixed blend: Q * p + Z * (1 - p)
    QtoZ { from: f32, to: f32 }, // Blend Q * (1 - w) + Z * w, where the Z weight w moves linearly from `from` to `to` as the turn number approaches the end of the game
}
/// Top-level configuration for a training run.
#[derive(Debug, Clone)]
pub struct LearningConfig {
    // general params
    /// Seed handed to torch's global RNG at the start of training.
    pub seed: u64,
    /// Log directory; model checkpoints are written to its `models` subdirectory.
    pub logs: std::path::PathBuf,

    // training params
    /// `(iteration, learning rate)` pairs. For each training iteration the last
    /// entry whose iteration number is <= the current (1-based) iteration is used.
    /// Entry 0 also provides the learning rate the optimizer is built with.
    pub lr_schedule: Vec<(usize, f64)>,
    /// Weight decay applied to the Adam optimizer; only set when > 0.
    pub weight_decay: f64,
    /// Number of gather-then-train iterations to run.
    pub num_iterations: usize,
    /// Number of passes over the training tensors per iteration.
    pub num_epochs: usize,
    /// Number of samples per mini-batch.
    pub batch_size: i64,
    /// Scalar multiplier for the policy loss.
    pub policy_weight: f32,
    /// Scalar multiplier for the value loss.
    pub value_weight: f32,

    /// Number of games to keep in the replay buffer.
    pub games_to_keep: usize,
    /// Number of new self-play games added to the replay buffer per iteration.
    pub games_per_train: usize,

    /// Settings for the self-play workers and their search.
    pub rollout_cfg: RolloutConfig,
}
self.batch_size) 48 | { 49 | return None; 50 | } 51 | 52 | let batch_inds = self 53 | .inds 54 | .narrow(0, self.index as i64, (next_index - self.index) as i64); 55 | self.index = next_index; 56 | 57 | let item = ( 58 | self.x.index_select(0, &batch_inds), 59 | self.y.index_select(0, &batch_inds), 60 | self.z.index_select(0, &batch_inds), 61 | ); 62 | Some(item) 63 | } 64 | } 65 | 66 | pub fn tensor(data: &[T], dims: &[i64], kind: tch::Kind) -> Tensor { 67 | let dsize = kind.elt_size_in_bytes(); 68 | let dtype = match kind { 69 | Kind::Uint8 => 0, 70 | Kind::Int8 => 1, 71 | Kind::Int16 => 2, 72 | Kind::Int => 3, 73 | Kind::Int64 => 4, 74 | Kind::Half => 5, 75 | Kind::Float => 6, 76 | Kind::Double => 7, 77 | Kind::ComplexHalf => 8, 78 | Kind::ComplexFloat => 9, 79 | Kind::ComplexDouble => 10, 80 | Kind::Bool => 11, 81 | Kind::QInt8 => 12, 82 | Kind::QUInt8 => 13, 83 | Kind::QInt32 => 14, 84 | Kind::BFloat16 => 15, 85 | }; 86 | let data = data.as_ptr() as *const c_void; 87 | let ndims = dims.len(); 88 | let dims = dims.as_ptr(); 89 | unsafe { Tensor::from_ptr(at_tensor_of_data(data, dims, ndims, dsize, dtype)) } 90 | } 91 | 92 | pub struct FlatBatch, const N: usize> { 93 | pub states: Vec, 94 | pub pis: Vec<[f32; N]>, 95 | pub vs: Vec<[f32; 3]>, 96 | } 97 | 98 | #[derive(Debug)] 99 | struct StateStatistics, const N: usize> { 100 | state: G::Features, 101 | sum_pi: [f32; N], 102 | sum_v: [f32; 3], 103 | num: u32, 104 | } 105 | 106 | pub struct ReplayBuffer, const N: usize> { 107 | game_id: usize, 108 | steps: usize, 109 | game_ids: Vec, 110 | pub games: Vec, 111 | pub states: Vec, 112 | pub pis: Vec<[f32; N]>, 113 | pub vs: Vec<[f32; 3]>, 114 | } 115 | 116 | impl, const N: usize> ReplayBuffer { 117 | pub fn new(n: usize) -> Self { 118 | Self { 119 | game_id: 0, 120 | steps: 0, 121 | game_ids: Vec::with_capacity(n), 122 | games: Vec::with_capacity(n), 123 | states: Vec::with_capacity(n), 124 | pis: Vec::with_capacity(n), 125 | vs: Vec::with_capacity(n), 126 | } 127 | } 
128 | 129 | pub fn new_game(&mut self) { 130 | self.game_id += 1; 131 | } 132 | 133 | pub fn total_games_played(&self) -> usize { 134 | self.game_id 135 | } 136 | 137 | pub fn curr_games(&self) -> usize { 138 | let mut unique = self.game_ids.clone(); 139 | unique.dedup(); 140 | unique.len() 141 | } 142 | 143 | pub fn total_steps(&self) -> usize { 144 | self.steps 145 | } 146 | 147 | pub fn curr_steps(&self) -> usize { 148 | self.vs.len() 149 | } 150 | 151 | pub fn add(&mut self, game: &G, pi: &[f32; N], v: [f32; 3]) { 152 | self.game_ids.push(self.game_id); 153 | self.steps += 1; 154 | self.games.push(game.clone()); 155 | self.states.push(game.features()); 156 | self.pis.push(*pi); 157 | self.vs.push(v); 158 | } 159 | 160 | pub fn extend(&mut self, other: &mut Self) { 161 | self.steps += other.steps; 162 | let start = self.game_id; 163 | self.game_ids 164 | .extend(other.game_ids.iter().map(|&g| g + start)); 165 | self.game_id += other.game_id; 166 | self.games.extend(other.games.drain(..)); 167 | self.states.extend(other.states.drain(..)); 168 | self.pis.extend(other.pis.drain(..)); 169 | self.vs.extend(other.vs.drain(..)); 170 | } 171 | 172 | pub fn keep_last_n_games(&mut self, n: usize) { 173 | if self.game_id <= n { 174 | return; 175 | } 176 | 177 | let min_game_id = self.game_id - n; 178 | 179 | let mut max_ind_to_remove = None; 180 | for (i, &game_id) in self.game_ids.iter().enumerate() { 181 | if game_id >= min_game_id { 182 | break; 183 | } 184 | max_ind_to_remove = Some(i); 185 | } 186 | if let Some(max_ind) = max_ind_to_remove { 187 | drop(self.game_ids.drain(0..=max_ind)); 188 | drop(self.games.drain(0..=max_ind)); 189 | drop(self.states.drain(0..=max_ind)); 190 | drop(self.pis.drain(0..=max_ind)); 191 | drop(self.vs.drain(0..=max_ind)); 192 | assert!(self.game_ids[0] >= min_game_id); 193 | } 194 | } 195 | 196 | pub fn deduplicate(&self) -> FlatBatch { 197 | let mut statistics: HashMap> = 198 | HashMap::with_capacity(self.game_ids.len()); 199 | for i in 
0..self.game_ids.len() { 200 | let stats = statistics 201 | .entry(self.games[i].clone()) 202 | .or_insert(StateStatistics { 203 | state: self.states[i].clone(), 204 | sum_pi: [0.0; N], 205 | sum_v: [0.0; 3], 206 | num: 0, 207 | }); 208 | for j in 0..N { 209 | stats.sum_pi[j] += self.pis[i][j]; 210 | } 211 | for j in 0..3 { 212 | stats.sum_v[j] += self.vs[i][j]; 213 | } 214 | stats.num += 1; 215 | } 216 | 217 | let mut states = Vec::with_capacity(statistics.len()); 218 | let mut pis = Vec::with_capacity(statistics.len()); 219 | let mut vs = Vec::with_capacity(statistics.len()); 220 | for (_, stats) in statistics.iter() { 221 | let mut avg_pi = [0.0; N]; 222 | for i in 0..N { 223 | avg_pi[i] = stats.sum_pi[i] / stats.num as f32; 224 | } 225 | let mut avg_v = [0.0; 3]; 226 | for i in 0..3 { 227 | avg_v[i] = stats.sum_v[i] / stats.num as f32; 228 | } 229 | states.push(stats.state.clone()); 230 | pis.push(avg_pi); 231 | vs.push(avg_v); 232 | } 233 | 234 | FlatBatch { states, pis, vs } 235 | } 236 | } 237 | -------------------------------------------------------------------------------- /synthesis/src/evaluator.rs: -------------------------------------------------------------------------------- 1 | use crate::config::*; 2 | use crate::game::*; 3 | use crate::mcts::MCTS; 4 | use crate::policies::*; 5 | use crate::utils::*; 6 | use rand::prelude::{Rng, SeedableRng, StdRng}; 7 | use tch::nn::VarStore; 8 | 9 | pub fn evaluator, P: Policy + NNPolicy, const N: usize>( 10 | cfg: &EvaluationConfig, 11 | ) -> Result<(), Box> { 12 | std::thread::sleep(std::time::Duration::from_secs(1)); 13 | 14 | let models_dir = cfg.logs.join("models"); 15 | let pgn_path = cfg.logs.join("results.pgn"); 16 | let mut pgn = std::fs::File::create(&pgn_path)?; 17 | let _guard = tch::no_grad_guard(); 18 | let first_player = G::new().player(); 19 | 20 | let mut best_k = Vec::with_capacity(cfg.num_best_policies); 21 | 22 | for i_iter in 0.. 
{ 23 | // add new games for baselines so they don't fall behind 24 | { 25 | let i = i_iter % cfg.rollout_num_explores.len(); 26 | for j in 0..cfg.rollout_num_explores.len() { 27 | if i == j { 28 | continue; 29 | } 30 | let seed = i_iter; 31 | add_pgn_result( 32 | &mut pgn, 33 | &format!("VanillaMCTS{}", cfg.rollout_num_explores[i]), 34 | &format!("VanillaMCTS{}", cfg.rollout_num_explores[j]), 35 | mcts_vs_mcts::( 36 | &cfg, 37 | first_player, 38 | cfg.rollout_num_explores[i], 39 | cfg.rollout_num_explores[j], 40 | seed as u64, 41 | ), 42 | )?; 43 | } 44 | calculate_ratings(&cfg.logs)?; 45 | plot_ratings(&cfg.logs)?; 46 | } 47 | 48 | // wait for model to exist; 49 | let name = format!("model_{}.ot", i_iter); 50 | while !models_dir.join(&name).exists() { 51 | std::thread::sleep(std::time::Duration::from_secs(1)); 52 | } 53 | 54 | // wait an extra second to be sure data is there 55 | std::thread::sleep(std::time::Duration::from_secs(1)); 56 | 57 | // load model 58 | let mut vs = VarStore::new(tch::Device::Cpu); 59 | let mut policy = P::new(&vs); 60 | vs.load(models_dir.join(&name))?; 61 | 62 | // evaluate against rollout mcts 63 | for &explores in cfg.rollout_num_explores.iter() { 64 | let op_name = format!("VanillaMCTS{}", explores); 65 | for seed in 0..cfg.num_games_against_rollout { 66 | let result = eval_against_rollout_mcts( 67 | &cfg, 68 | &mut policy, 69 | first_player, 70 | explores, 71 | seed as u64, 72 | ); 73 | add_pgn_result(&mut pgn, &name, &op_name, result)?; 74 | let result = eval_against_rollout_mcts( 75 | &cfg, 76 | &mut policy, 77 | first_player.next(), 78 | explores, 79 | seed as u64, 80 | ); 81 | add_pgn_result(&mut pgn, &op_name, &name, result)?; 82 | } 83 | calculate_ratings(&cfg.logs)?; 84 | plot_ratings(&cfg.logs)?; 85 | } 86 | 87 | // evaluate against best old policies 88 | for (prev_name, prev_p) in best_k.iter_mut() { 89 | let result = eval_against_old(&cfg, &mut policy, prev_p); 90 | add_pgn_result(&mut pgn, &name, &prev_name, result)?; 91 
| 92 | let result = eval_against_old(&cfg, prev_p, &mut policy); 93 | add_pgn_result(&mut pgn, &prev_name, &name, result)?; 94 | } 95 | 96 | // update results 97 | calculate_ratings(&cfg.logs)?; 98 | plot_ratings(&cfg.logs)?; 99 | 100 | // update top k 101 | if best_k.len() < cfg.num_best_policies { 102 | best_k.push((name, policy)); 103 | } else { 104 | let ranks = rankings(&cfg.logs)?; 105 | if ranks 106 | .iter() 107 | .take(cfg.num_best_policies) 108 | .position(|n| n == &name) 109 | .is_some() 110 | { 111 | best_k.push((name, policy)); 112 | match best_k.iter().position(|(n, _p)| { 113 | ranks 114 | .iter() 115 | .take(cfg.num_best_policies) 116 | .position(|n1| n1 == n) 117 | .is_none() 118 | }) { 119 | Some(i) => { 120 | best_k.remove(i); 121 | } 122 | None => panic!("Didn't find policy to evict"), 123 | } 124 | } 125 | } 126 | } 127 | 128 | Ok(()) 129 | } 130 | 131 | fn eval_against_old, P: Policy, const N: usize>( 132 | cfg: &EvaluationConfig, 133 | p1: &mut P, 134 | p2: &mut P, 135 | ) -> f32 { 136 | let mut game = G::new(); 137 | let first_player = game.player(); 138 | loop { 139 | let action = if game.player() == first_player { 140 | MCTS::exploit( 141 | cfg.policy_num_explores, 142 | cfg.policy_mcts_cfg, 143 | p1, 144 | game.clone(), 145 | cfg.policy_action, 146 | ) 147 | } else { 148 | MCTS::exploit( 149 | cfg.policy_num_explores, 150 | cfg.policy_mcts_cfg, 151 | p2, 152 | game.clone(), 153 | cfg.policy_action, 154 | ) 155 | }; 156 | if game.step(&action) { 157 | break; 158 | } 159 | } 160 | game.reward(first_player) 161 | } 162 | 163 | fn eval_against_rollout_mcts, P: Policy, const N: usize>( 164 | cfg: &EvaluationConfig, 165 | policy: &mut P, 166 | player: G::PlayerId, 167 | opponent_explores: usize, 168 | seed: u64, 169 | ) -> f32 { 170 | let mut game = G::new(); 171 | let first_player = game.player(); 172 | let mut rng = StdRng::seed_from_u64(seed); 173 | let mut rollout_policy = RolloutPolicy { rng: &mut rng }; 174 | loop { 175 | let action = if 
game.player() == player { 176 | MCTS::exploit( 177 | cfg.policy_num_explores, 178 | cfg.policy_mcts_cfg, 179 | policy, 180 | game.clone(), 181 | cfg.policy_action, 182 | ) 183 | } else { 184 | FrozenMCTS::exploit( 185 | opponent_explores, 186 | cfg.rollout_mcts_cfg, 187 | &mut rollout_policy, 188 | game.clone(), 189 | cfg.rollout_action, 190 | ) 191 | }; 192 | 193 | if game.step(&action) { 194 | break; 195 | } 196 | } 197 | game.reward(first_player) 198 | } 199 | 200 | fn mcts_vs_mcts, const N: usize>( 201 | cfg: &EvaluationConfig, 202 | player: G::PlayerId, 203 | p1_explores: usize, 204 | p2_explores: usize, 205 | seed: u64, 206 | ) -> f32 { 207 | let mut rng = StdRng::seed_from_u64(seed); 208 | let mut rollout_policy = RolloutPolicy { rng: &mut rng }; 209 | let mut game = G::new(); 210 | let first_player = game.player(); 211 | loop { 212 | let action = FrozenMCTS::exploit( 213 | if game.player() == player { 214 | p1_explores 215 | } else { 216 | p2_explores 217 | }, 218 | cfg.rollout_mcts_cfg, 219 | &mut rollout_policy, 220 | game.clone(), 221 | cfg.rollout_action, 222 | ); 223 | if game.step(&action) { 224 | break; 225 | } 226 | } 227 | game.reward(first_player) 228 | } 229 | 230 | type NodeId = u32; 231 | type ActionId = u8; 232 | 233 | #[derive(Debug)] 234 | struct Node, const N: usize> { 235 | parent: NodeId, // 4 bytes 236 | first_child: NodeId, // 4 bytes 237 | num_children: u8, // 1 byte 238 | game: G, // ? 
bytes 239 | solution: Option, // 1 byte 240 | action: ActionId, // 1 byte 241 | action_prob: f32, // 4 bytes 242 | cum_value: f32, // 4 bytes 243 | num_visits: f32, // 4 bytes 244 | } 245 | 246 | impl, const N: usize> Node { 247 | fn unvisited( 248 | parent: NodeId, 249 | game: G, 250 | solution: Option, 251 | action: u8, 252 | action_prob: f32, 253 | ) -> Self { 254 | Self { 255 | parent, 256 | first_child: 0, 257 | num_children: 0, 258 | game, 259 | action, 260 | solution, 261 | action_prob, 262 | cum_value: 0.0, 263 | num_visits: 0.0, 264 | } 265 | } 266 | 267 | #[inline] 268 | fn is_unvisited(&self) -> bool { 269 | self.num_children == 0 && self.solution.is_none() 270 | } 271 | 272 | #[inline] 273 | fn is_visited(&self) -> bool { 274 | self.num_children != 0 275 | } 276 | 277 | #[inline] 278 | fn is_unsolved(&self) -> bool { 279 | self.solution.is_none() 280 | } 281 | 282 | #[inline] 283 | fn last_child(&self) -> NodeId { 284 | self.first_child + self.num_children as u32 285 | } 286 | 287 | #[inline] 288 | fn mark_visited(&mut self, first_child: NodeId, num_children: u8) { 289 | self.first_child = first_child; 290 | self.num_children = num_children; 291 | } 292 | 293 | #[inline] 294 | fn mark_solved(&mut self, outcome: Outcome) { 295 | self.solution = Some(outcome); 296 | } 297 | } 298 | 299 | pub struct FrozenMCTS<'a, G: Game, P: Policy, const N: usize> { 300 | root: NodeId, 301 | offset: NodeId, 302 | nodes: Vec>, 303 | policy: &'a mut P, 304 | cfg: MCTSConfig, 305 | } 306 | 307 | impl<'a, G: Game, P: Policy, const N: usize> FrozenMCTS<'a, G, P, N> { 308 | pub fn exploit( 309 | explores: usize, 310 | cfg: MCTSConfig, 311 | policy: &'a mut P, 312 | game: G, 313 | action_selection: ActionSelection, 314 | ) -> G::Action { 315 | let mut mcts = Self::with_capacity(explores + 1, cfg, policy, game); 316 | mcts.explore_n(explores); 317 | mcts.best_action(action_selection) 318 | } 319 | 320 | pub fn with_capacity(capacity: usize, cfg: MCTSConfig, policy: &'a mut P, 
game: G) -> Self { 321 | let mut nodes = Vec::with_capacity(capacity); 322 | nodes.push(Node::unvisited(0, game, None, 0, 0.0)); 323 | let mut mcts = Self { 324 | root: 0, 325 | offset: 0, 326 | nodes, 327 | policy, 328 | cfg, 329 | }; 330 | let (value, any_solved) = mcts.visit(mcts.root); 331 | mcts.backprop(mcts.root, value, any_solved); 332 | mcts 333 | } 334 | 335 | fn next_node_id(&self) -> NodeId { 336 | self.nodes.len() as NodeId + self.offset 337 | } 338 | 339 | fn node(&self, node_id: NodeId) -> &Node { 340 | &self.nodes[(node_id - self.offset) as usize] 341 | } 342 | 343 | fn mut_node(&mut self, node_id: NodeId) -> &mut Node { 344 | &mut self.nodes[(node_id - self.offset) as usize] 345 | } 346 | 347 | fn children_of(&self, node: &Node) -> &[Node] { 348 | &self.nodes 349 | [(node.first_child - self.offset) as usize..(node.last_child() - self.offset) as usize] 350 | } 351 | 352 | fn mut_nodes(&mut self, first_child: NodeId, last_child: NodeId) -> &mut [Node] { 353 | &mut self.nodes[(first_child - self.offset) as usize..(last_child - self.offset) as usize] 354 | } 355 | 356 | pub fn best_action(&self, action_selection: ActionSelection) -> G::Action { 357 | let root = self.node(self.root); 358 | 359 | let mut best_action = None; 360 | let mut best_value = f32::NEG_INFINITY; 361 | for child in self.children_of(root) { 362 | if child.is_unvisited() { 363 | continue; 364 | } 365 | let value = match child.solution { 366 | Some(Outcome::Win(_)) => f32::NEG_INFINITY, 367 | Some(Outcome::Draw(_)) => 1e6, 368 | Some(Outcome::Lose(_)) => f32::INFINITY, 369 | None => match action_selection { 370 | ActionSelection::Q => -child.cum_value / child.num_visits, 371 | ActionSelection::NumVisits => child.num_visits, 372 | }, 373 | }; 374 | if best_action.is_none() || value > best_value { 375 | best_value = value; 376 | best_action = Some((child.action as usize).into()); 377 | } 378 | } 379 | best_action.unwrap() 380 | } 381 | 382 | fn explore(&mut self) { 383 | let mut node_id 
= self.root; 384 | loop { 385 | let node = self.node(node_id); 386 | if let Some(outcome) = node.solution { 387 | self.backprop(node_id, outcome.value(), true); 388 | return; 389 | } else if node.is_unvisited() { 390 | let (value, any_solved) = self.visit(node_id); 391 | self.backprop(node_id, value, any_solved); 392 | return; 393 | } else { 394 | node_id = self.select_best_child(node_id); 395 | } 396 | } 397 | } 398 | 399 | fn select_best_child(&mut self, node_id: NodeId) -> NodeId { 400 | let node = self.node(node_id); 401 | 402 | let mut best_child_id = None; 403 | let mut best_value = f32::NEG_INFINITY; 404 | for child_ind in 0..node.num_children { 405 | let child_id = node.first_child + child_ind as u32; 406 | let child = self.node(child_id); 407 | let value = if child.is_unvisited() { 408 | let f = match self.cfg.fpu { 409 | Fpu::Const(value) => value, 410 | _ => panic!("Unsupported fpu in baseline"), 411 | }; 412 | f + child.action_prob 413 | } else { 414 | let q = match child.solution { 415 | Some(outcome) => outcome.reversed().value(), 416 | None => -child.cum_value / child.num_visits, 417 | }; 418 | let u = match self.cfg.exploration { 419 | Exploration::Uct { c } => { 420 | let visits = (c * node.num_visits.ln()).sqrt(); 421 | visits / child.num_visits.sqrt() 422 | } 423 | _ => { 424 | panic!("Not supported in frozen mcts"); 425 | } 426 | }; 427 | q + u 428 | }; 429 | if best_child_id.is_none() || value > best_value { 430 | best_child_id = Some(child_id); 431 | best_value = value; 432 | } 433 | } 434 | best_child_id.unwrap() 435 | } 436 | 437 | fn visit(&mut self, node_id: NodeId) -> (f32, bool) { 438 | let first_child = self.next_node_id(); 439 | let node = self.node(node_id); 440 | let game = node.game.clone(); 441 | let (logits, dist) = self.policy.eval(&game); 442 | let mut num_children = 0; 443 | let mut any_solved = false; 444 | let mut max_logit = f32::NEG_INFINITY; 445 | for action in game.iter_actions() { 446 | let mut child_game = game.clone(); 
447 | let is_over = child_game.step(&action); 448 | let solution = if is_over { 449 | any_solved = true; 450 | Some(child_game.reward(child_game.player()).into()) 451 | } else { 452 | None 453 | }; 454 | let action: usize = action.into(); 455 | let logit = logits[action]; 456 | max_logit = max_logit.max(logit); 457 | let child = Node::unvisited(node_id, child_game, solution, action as u8, logit); 458 | self.nodes.push(child); 459 | num_children += 1; 460 | } 461 | 462 | let node = self.mut_node(node_id); 463 | node.mark_visited(first_child, num_children); 464 | let first_child = node.first_child; 465 | let last_child = node.last_child(); 466 | 467 | // stable softmax 468 | let mut total = 0.0; 469 | for child in self.mut_nodes(first_child, last_child) { 470 | child.action_prob = (child.action_prob - max_logit).exp(); 471 | total += child.action_prob; 472 | } 473 | for child in self.mut_nodes(first_child, last_child) { 474 | child.action_prob /= total; 475 | } 476 | 477 | let value = dist[2] - dist[0]; 478 | 479 | (value, any_solved) 480 | } 481 | 482 | fn backprop(&mut self, leaf_node_id: NodeId, mut value: f32, mut solved: bool) { 483 | let mut node_id = leaf_node_id; 484 | loop { 485 | let node = self.node(node_id); 486 | let parent = node.parent; 487 | 488 | if self.cfg.solve && solved && node.is_unsolved() { 489 | let mut all_solved = true; 490 | let mut worst_solution = None; 491 | for child in self.children_of(node) { 492 | if child.is_unvisited() || child.is_unsolved() { 493 | all_solved = false; 494 | } else if worst_solution.is_none() || child.solution < worst_solution { 495 | worst_solution = child.solution; 496 | } 497 | } 498 | 499 | let node = self.mut_node(node_id); 500 | if let Some(Outcome::Lose(_)) = worst_solution { 501 | // at least 1 is a win, so mark this node as a win 502 | node.mark_solved(Outcome::Win(0)); 503 | value = -node.cum_value + (node.num_visits + 1.0); 504 | } else if node.is_visited() && all_solved { 505 | // all children node's 
use std::cmp::Ordering;
use std::hash::Hash;

/// A player identifier that knows its neighbours in the turn order.
pub trait HasTurnOrder: Eq + Clone + Copy + std::fmt::Debug {
    fn prev(&self) -> Self;
    fn next(&self) -> Self;
}

/// Proven result of a position; the payload is a turn counter that
/// [`Outcome::reversed`] increments each time the result moves up one ply.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum Outcome {
    Win(usize),
    Lose(usize),
    Draw(usize),
}

impl From<f32> for Outcome {
    /// Maps a reward to an outcome with a zero turn counter: positive is a
    /// win, negative is a loss, anything else (including NaN) is a draw.
    fn from(value: f32) -> Self {
        if value > 0.0 {
            Self::Win(0)
        } else if value < 0.0 {
            Self::Lose(0)
        } else {
            Self::Draw(0)
        }
    }
}

impl Outcome {
    /// The same result seen by the other player, one turn further away.
    pub fn reversed(&self) -> Self {
        match self {
            Self::Win(u) => Self::Lose(*u + 1),
            Self::Lose(u) => Self::Win(*u + 1),
            Self::Draw(u) => Self::Draw(*u + 1),
        }
    }

    /// Scalar value of the result: `1.0` win, `0.0` draw, `-1.0` loss.
    pub fn value(&self) -> f32 {
        match self {
            Self::Win(_) => 1.0,
            Self::Draw(_) => 0.0,
            Self::Lose(_) => -1.0,
        }
    }

    /// Rank of the variant, ignoring the turn counter: loss < draw < win.
    fn rank(&self) -> u8 {
        match self {
            Self::Lose(_) => 0,
            Self::Draw(_) => 1,
            Self::Win(_) => 2,
        }
    }
}

impl Ord for Outcome {
    /// Orders by variant first (loss < draw < win), then by turn counter:
    /// a win in fewer turns is greater, while a draw or loss in more turns
    /// is greater.
    fn cmp(&self, other: &Self) -> Ordering {
        match (self, other) {
            // NOTE: reversed, want to win in least number of turns
            (Self::Win(a), Self::Win(b)) => b.cmp(a),
            (Self::Draw(a), Self::Draw(b)) | (Self::Lose(a), Self::Lose(b)) => a.cmp(b),
            _ => self.rank().cmp(&other.rank()),
        }
    }
}

impl PartialOrd for Outcome {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

/// A turn-based game with at most `N` distinct actions.
// NOTE(review): the const parameter and the `Into<usize> + From<usize>` /
// `Iterator<Item = Self::Action>` bounds were stripped from this listing and
// are reconstructed from `MAX_NUM_ACTIONS = N`, `impl Game<9> for TicTacToe`
// and how the search code converts and iterates actions.
pub trait Game<const N: usize>: Eq + Hash + Clone + std::fmt::Debug + Send {
    type PlayerId: HasTurnOrder;
    type Action: Eq + Clone + Copy + std::fmt::Debug + Into<usize> + From<usize>;
    type ActionIterator: Iterator<Item = Self::Action>;
    type Features: PartialEq + Clone + std::fmt::Debug + Send;

    const MAX_NUM_ACTIONS: usize = N;
    const MAX_TURNS: usize;
    const NAME: &'static str;
    const NUM_PLAYERS: usize;
    const DIMS: &'static [i64];

    fn new() -> Self;
    /// The player whose turn it is.
    fn player(&self) -> Self::PlayerId;
    fn is_over(&self) -> bool;
    /// Reward of the current state for `player_id`.
    fn reward(&self, player_id: Self::PlayerId) -> f32;
    /// Iterates the actions available in the current state.
    fn iter_actions(&self) -> Self::ActionIterator;
    /// Applies `action`; callers treat a `true` return as "the game is over".
    fn step(&mut self, action: &Self::Action) -> bool;
    fn features(&self) -> Self::Features;
    fn print(&self);
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cmp_outcome() {
        assert_eq!(Outcome::Win(0).cmp(&Outcome::Win(0)), Ordering::Equal);
        assert_eq!(Outcome::Win(0).cmp(&Outcome::Draw(0)), Ordering::Greater);
        assert_eq!(Outcome::Win(0).cmp(&Outcome::Lose(0)), Ordering::Greater);

        assert_eq!(Outcome::Draw(0).cmp(&Outcome::Win(0)), Ordering::Less);
        assert_eq!(Outcome::Draw(0).cmp(&Outcome::Draw(0)), Ordering::Equal);
        assert_eq!(Outcome::Draw(0).cmp(&Outcome::Lose(0)), Ordering::Greater);

        assert_eq!(Outcome::Lose(0).cmp(&Outcome::Win(0)), Ordering::Less);
        assert_eq!(Outcome::Lose(0).cmp(&Outcome::Draw(0)), Ordering::Less);
        assert_eq!(Outcome::Lose(0).cmp(&Outcome::Lose(0)), Ordering::Equal);
    }

    #[test]
    fn test_ord_outcome() {
        assert!(Outcome::Win(0) == Outcome::Win(0));
        assert!(Outcome::Win(0) > Outcome::Draw(0));
        assert!(Outcome::Win(0) > Outcome::Lose(0));

        assert!(Outcome::Draw(0) < Outcome::Win(0));
        assert!(Outcome::Draw(0) == Outcome::Draw(0));
        assert!(Outcome::Draw(0) > Outcome::Lose(0));

        assert!(Outcome::Lose(0) < Outcome::Win(0));
        assert!(Outcome::Lose(0) < Outcome::Draw(0));
        assert!(Outcome::Lose(0) == Outcome::Lose(0));
    }

    #[test]
    fn test_partial_ord_outcome() {
        assert!(Some(Outcome::Win(0)) > None);
        assert!(Some(Outcome::Draw(0)) > None);
        assert!(Some(Outcome::Lose(0)) > None);

        assert!(Some(Outcome::Win(0)) == Some(Outcome::Win(0)));
        assert!(Some(Outcome::Win(0)) > Some(Outcome::Draw(0)));
        assert!(Some(Outcome::Win(0)) > Some(Outcome::Lose(0)));

        assert!(Some(Outcome::Draw(0)) < Some(Outcome::Win(0)));
        assert!(Some(Outcome::Draw(0)) == Some(Outcome::Draw(0)));
        assert!(Some(Outcome::Draw(0)) > Some(Outcome::Lose(0)));

        assert!(Some(Outcome::Lose(0)) < Some(Outcome::Win(0)));
        assert!(Some(Outcome::Lose(0)) < Some(Outcome::Draw(0)));
        assert!(Some(Outcome::Lose(0)) == Some(Outcome::Lose(0)));
    }
}
assert_eq!(Outcome::Lose(0).cmp(&Outcome::Win(0)), Ordering::Less); 105 | assert_eq!(Outcome::Lose(0).cmp(&Outcome::Draw(0)), Ordering::Less); 106 | assert_eq!(Outcome::Lose(0).cmp(&Outcome::Lose(0)), Ordering::Equal); 107 | } 108 | 109 | #[test] 110 | fn test_ord_outcome() { 111 | assert!(Outcome::Win(0) == Outcome::Win(0)); 112 | assert!(Outcome::Win(0) > Outcome::Draw(0)); 113 | assert!(Outcome::Win(0) > Outcome::Lose(0)); 114 | 115 | assert!(Outcome::Draw(0) < Outcome::Win(0)); 116 | assert!(Outcome::Draw(0) == Outcome::Draw(0)); 117 | assert!(Outcome::Draw(0) > Outcome::Lose(0)); 118 | 119 | assert!(Outcome::Lose(0) < Outcome::Win(0)); 120 | assert!(Outcome::Lose(0) < Outcome::Draw(0)); 121 | assert!(Outcome::Lose(0) == Outcome::Lose(0)); 122 | } 123 | 124 | #[test] 125 | fn test_partial_ord_outcome() { 126 | assert!(Some(Outcome::Win(0)) > None); 127 | assert!(Some(Outcome::Draw(0)) > None); 128 | assert!(Some(Outcome::Lose(0)) > None); 129 | 130 | assert!(Some(Outcome::Win(0)) == Some(Outcome::Win(0))); 131 | assert!(Some(Outcome::Win(0)) > Some(Outcome::Draw(0))); 132 | assert!(Some(Outcome::Win(0)) > Some(Outcome::Lose(0))); 133 | 134 | assert!(Some(Outcome::Draw(0)) < Some(Outcome::Win(0))); 135 | assert!(Some(Outcome::Draw(0)) == Some(Outcome::Draw(0))); 136 | assert!(Some(Outcome::Draw(0)) > Some(Outcome::Lose(0))); 137 | 138 | assert!(Some(Outcome::Lose(0)) < Some(Outcome::Win(0))); 139 | assert!(Some(Outcome::Lose(0)) < Some(Outcome::Draw(0))); 140 | assert!(Some(Outcome::Lose(0)) == Some(Outcome::Lose(0))); 141 | } 142 | } 143 | -------------------------------------------------------------------------------- /synthesis/src/lib.rs: -------------------------------------------------------------------------------- 1 | mod alpha_zero; 2 | pub mod config; 3 | mod data; 4 | mod evaluator; 5 | pub mod game; 6 | mod mcts; 7 | pub mod policies; 8 | pub mod prelude; 9 | mod utils; 10 | 
-------------------------------------------------------------------------------- /synthesis/src/mcts.rs: -------------------------------------------------------------------------------- 1 | use crate::config::{ActionSelection, Exploration, Fpu, MCTSConfig, PolicyNoise}; 2 | use crate::game::{Game, Outcome}; 3 | use crate::policies::Policy; 4 | use rand::{distributions::Distribution, thread_rng, Rng}; 5 | use rand_distr::Dirichlet; 6 | 7 | type NodeId = u32; 8 | type ActionId = u8; 9 | 10 | impl Into for Outcome { 11 | fn into(self) -> usize { 12 | match self { 13 | Outcome::Lose(_) => 0, 14 | Outcome::Draw(_) => 1, 15 | Outcome::Win(_) => 2, 16 | } 17 | } 18 | } 19 | 20 | impl Into<[f32; 3]> for Outcome { 21 | fn into(self) -> [f32; 3] { 22 | let mut dist = [0.0; 3]; 23 | dist[Into::::into(self)] = 1.0; 24 | dist 25 | } 26 | } 27 | 28 | #[derive(Debug)] 29 | struct Node, const N: usize> { 30 | parent: NodeId, // 4 bytes 31 | first_child: NodeId, // 4 bytes 32 | num_children: u8, // 1 byte 33 | game: G, // ? 
bytes 34 | solution: Option, // 1 byte 35 | action: ActionId, // 1 byte 36 | action_prob: f32, // 4 bytes 37 | outcome_probs: [f32; 3], 38 | num_visits: f32, // 4 bytes 39 | } 40 | 41 | impl, const N: usize> Node { 42 | fn q(&self) -> f32 { 43 | (self.outcome_probs[2] - self.outcome_probs[0]) / self.num_visits 44 | } 45 | 46 | fn unvisited( 47 | parent: NodeId, 48 | game: G, 49 | solution: Option, 50 | action: u8, 51 | action_prob: f32, 52 | ) -> Self { 53 | Self { 54 | parent, 55 | first_child: 0, 56 | num_children: 0, 57 | game, 58 | action, 59 | solution, 60 | action_prob, 61 | outcome_probs: [0.0; 3], 62 | num_visits: 0.0, 63 | } 64 | } 65 | 66 | fn action(&self) -> G::Action { 67 | (self.action as usize).into() 68 | } 69 | 70 | #[inline] 71 | fn is_unvisited(&self) -> bool { 72 | self.num_children == 0 && self.solution.is_none() 73 | } 74 | 75 | #[inline] 76 | fn is_visited(&self) -> bool { 77 | self.num_children != 0 78 | } 79 | 80 | #[inline] 81 | fn is_unsolved(&self) -> bool { 82 | self.solution.is_none() 83 | } 84 | 85 | #[inline] 86 | fn last_child(&self) -> NodeId { 87 | self.first_child + self.num_children as u32 88 | } 89 | 90 | #[inline] 91 | fn mark_visited(&mut self, first_child: NodeId, num_children: u8) { 92 | self.first_child = first_child; 93 | self.num_children = num_children; 94 | } 95 | 96 | #[inline] 97 | fn mark_solved(&mut self, outcome: Outcome) { 98 | self.solution = Some(outcome); 99 | } 100 | } 101 | 102 | pub struct MCTS<'a, G: Game, P: Policy, const N: usize> { 103 | root: NodeId, 104 | offset: NodeId, 105 | nodes: Vec>, 106 | policy: &'a mut P, 107 | cfg: MCTSConfig, 108 | } 109 | 110 | impl<'a, G: Game, P: Policy, const N: usize> MCTS<'a, G, P, N> { 111 | pub fn exploit( 112 | explores: usize, 113 | cfg: MCTSConfig, 114 | policy: &'a mut P, 115 | game: G, 116 | action_selection: ActionSelection, 117 | ) -> G::Action { 118 | let mut mcts = Self::with_capacity(explores + 1, cfg, policy, game); 119 | mcts.explore_n(explores); 120 | 
mcts.best_action(action_selection) 121 | } 122 | 123 | pub fn with_capacity(capacity: usize, cfg: MCTSConfig, policy: &'a mut P, game: G) -> Self { 124 | let mut nodes = Vec::with_capacity(capacity); 125 | nodes.push(Node::unvisited(0, game, None, 0, 0.0)); 126 | let mut mcts = Self { 127 | root: 0, 128 | offset: 0, 129 | nodes, 130 | policy, 131 | cfg, 132 | }; 133 | let (node_id, outcome_probs, any_solved) = mcts.visit(mcts.root); 134 | mcts.backprop(node_id, outcome_probs, any_solved); 135 | mcts.add_root_noise(); 136 | mcts 137 | } 138 | 139 | pub fn explore_n(&mut self, n: usize) { 140 | for _ in 0..n { 141 | // NOTE this is important for value extraction because if root is solved then children might not have any visits 142 | if self.node(self.root).solution.is_some() { 143 | break; 144 | } 145 | self.explore(); 146 | } 147 | } 148 | } 149 | 150 | impl<'a, G: Game, P: Policy, const N: usize> MCTS<'a, G, P, N> { 151 | fn next_node_id(&self) -> NodeId { 152 | self.nodes.len() as NodeId + self.offset 153 | } 154 | 155 | fn node(&self, node_id: NodeId) -> &Node { 156 | &self.nodes[(node_id - self.offset) as usize] 157 | } 158 | 159 | fn mut_node(&mut self, node_id: NodeId) -> &mut Node { 160 | &mut self.nodes[(node_id - self.offset) as usize] 161 | } 162 | 163 | fn children_of(&self, node: &Node) -> &[Node] { 164 | &self.nodes 165 | [(node.first_child - self.offset) as usize..(node.last_child() - self.offset) as usize] 166 | } 167 | 168 | fn mut_nodes(&mut self, first_child: NodeId, last_child: NodeId) -> &mut [Node] { 169 | &mut self.nodes[(first_child - self.offset) as usize..(last_child - self.offset) as usize] 170 | } 171 | } 172 | 173 | impl<'a, G: Game, P: Policy, const N: usize> MCTS<'a, G, P, N> { 174 | pub fn target_policy(&self, search_policy: &mut [f32; N]) { 175 | search_policy.fill(0.0); 176 | let mut total = 0.0; 177 | let root = self.node(self.root); 178 | if root.num_visits == 1.0 { 179 | // assert!(root.solution.is_some()); 180 | match 
root.solution { 181 | Some(Outcome::Win(_)) => { 182 | for child in self.children_of(root) { 183 | let v = if let Some(Outcome::Lose(_)) = child.solution { 184 | 1.0 185 | } else { 186 | 0.0 187 | }; 188 | search_policy[child.action as usize] = v; 189 | total += v; 190 | } 191 | } 192 | _ => { 193 | for child in self.children_of(root) { 194 | search_policy[child.action as usize] = 1.0; 195 | total += 1.0; 196 | } 197 | } 198 | } 199 | } else { 200 | // assert!(root.num_visits > 1.0); 201 | for child in self.children_of(root) { 202 | let v = child.num_visits; 203 | search_policy[child.action as usize] = v; 204 | total += v; 205 | } 206 | } 207 | // assert!(total > 0.0, "{:?} {:?}", root.solution, root.num_visits); 208 | for i in 0..N { 209 | search_policy[i] /= total; 210 | } 211 | } 212 | 213 | pub fn target_q(&self) -> [f32; 3] { 214 | let root = self.node(self.root); 215 | match root.solution { 216 | Some(outcome) => outcome.into(), 217 | None => { 218 | let mut outcome_probs = [0.0; 3]; 219 | for i in 0..3 { 220 | outcome_probs[i] = root.outcome_probs[i] / root.num_visits; 221 | } 222 | outcome_probs 223 | } 224 | } 225 | } 226 | } 227 | 228 | impl<'a, G: Game, P: Policy, const N: usize> MCTS<'a, G, P, N> { 229 | fn add_root_noise(&mut self) { 230 | match self.cfg.root_policy_noise { 231 | PolicyNoise::None => {} 232 | PolicyNoise::Equal { weight } => { 233 | self.add_equalizing_noise(weight); 234 | } 235 | PolicyNoise::Dirichlet { alpha, weight } => { 236 | self.add_dirichlet_noise(&mut thread_rng(), alpha, weight); 237 | } 238 | } 239 | } 240 | 241 | fn add_dirichlet_noise(&mut self, rng: &mut R, alpha: f32, noise_weight: f32) { 242 | let root = self.node(self.root); 243 | if root.num_children < 2 { 244 | return; 245 | } 246 | let first_child = root.first_child; 247 | let last_child = root.last_child(); 248 | let dirichlet = Dirichlet::new_with_size(alpha, root.num_children as usize).unwrap(); 249 | let noise_probs = dirichlet.sample(rng); 250 | for (noise, 
child) in noise_probs 251 | .iter() 252 | .zip(self.mut_nodes(first_child, last_child)) 253 | { 254 | child.action_prob = child.action_prob * (1.0 - noise_weight) + noise_weight * noise; 255 | } 256 | } 257 | 258 | fn add_equalizing_noise(&mut self, noise_weight: f32) { 259 | let root = self.node(self.root); 260 | if root.num_children < 2 { 261 | return; 262 | } 263 | let first_child = root.first_child; 264 | let last_child = root.last_child(); 265 | let noise = 1.0 / root.num_children as f32; 266 | for child in self.mut_nodes(first_child, last_child) { 267 | child.action_prob = child.action_prob * (1.0 - noise_weight) + noise_weight * noise; 268 | } 269 | } 270 | } 271 | 272 | impl<'a, G: Game, P: Policy, const N: usize> MCTS<'a, G, P, N> { 273 | pub fn best_action(&self, action_selection: ActionSelection) -> G::Action { 274 | let root = self.node(self.root); 275 | 276 | let mut best_action = None; 277 | let mut best_value = None; 278 | for child in self.children_of(root) { 279 | let value = match child.solution { 280 | Some(Outcome::Win(turns)) => Some((0.0, turns as f32)), 281 | None => match action_selection { 282 | ActionSelection::Q => Some((1.0, -child.q())), 283 | ActionSelection::NumVisits => Some((1.0, child.num_visits)), 284 | }, 285 | Some(Outcome::Draw(turns)) => Some((2.0, -(turns as f32))), 286 | Some(Outcome::Lose(turns)) => Some((3.0, -(turns as f32))), 287 | }; 288 | if value > best_value { 289 | best_value = value; 290 | best_action = Some(child.action()); 291 | } 292 | } 293 | best_action.unwrap() 294 | } 295 | 296 | pub fn solution(&self, action: &G::Action) -> Option { 297 | let action: usize = (*action).into(); 298 | let action = action as u8; 299 | let root = self.node(self.root); 300 | for child in self.children_of(root) { 301 | if child.action == action { 302 | return child.solution; 303 | } 304 | } 305 | None 306 | } 307 | } 308 | 309 | impl<'a, G: Game, P: Policy, const N: usize> MCTS<'a, G, P, N> { 310 | fn explore(&mut self) { 311 | 
let mut node_id = self.root; 312 | loop { 313 | let node = self.node(node_id); 314 | if let Some(outcome) = node.solution { 315 | self.backprop(node_id, outcome.into(), true); 316 | return; 317 | } else if node.is_unvisited() { 318 | let (node_id, outcome_probs, any_solved) = self.visit(node_id); 319 | self.backprop(node_id, outcome_probs, any_solved); 320 | return; 321 | } else { 322 | node_id = self.select_best_child(node); 323 | } 324 | } 325 | } 326 | 327 | fn select_best_child(&self, parent: &Node) -> NodeId { 328 | let mut best_child_id = None; 329 | let mut best_value = None; 330 | for child_id in parent.first_child..parent.last_child() { 331 | let child = self.node(child_id); 332 | let q = self.exploit_value(parent, child); 333 | let u = self.explore_value(parent, child); 334 | let value = Some(q + u); 335 | if value > best_value { 336 | best_child_id = Some(child_id); 337 | best_value = value; 338 | } 339 | } 340 | best_child_id.unwrap() 341 | } 342 | 343 | fn exploit_value(&self, parent: &Node, child: &Node) -> f32 { 344 | if let Some(outcome) = child.solution { 345 | if self.cfg.select_solved_nodes { 346 | outcome.reversed().value() 347 | } else { 348 | f32::NEG_INFINITY 349 | } 350 | } else if child.num_children == 0 { 351 | match self.cfg.fpu { 352 | Fpu::Const(value) => value, 353 | Fpu::ParentQ => parent.q(), 354 | Fpu::Func(fpu_fn) => (fpu_fn)(), 355 | } 356 | } else { 357 | -child.q() 358 | } 359 | } 360 | 361 | fn explore_value(&self, parent: &Node, child: &Node) -> f32 { 362 | match self.cfg.exploration { 363 | Exploration::Uct { c } => { 364 | let visits = (c * parent.num_visits.ln()).sqrt(); 365 | visits / child.num_visits.sqrt() 366 | } 367 | Exploration::PolynomialUct { c } => { 368 | let visits = parent.num_visits.sqrt(); 369 | c * child.action_prob * visits / (1.0 + child.num_visits) 370 | } 371 | } 372 | } 373 | 374 | fn visit(&mut self, node_id: NodeId) -> (NodeId, [f32; 3], bool) { 375 | let first_child = self.next_node_id(); 376 | let 
node = self.node(node_id); 377 | if let Some(outcome) = node.solution { 378 | return (node_id, outcome.into(), true); 379 | } 380 | 381 | let game = node.game.clone(); 382 | let mut num_children = 0; 383 | let mut any_solved = false; 384 | for action in game.iter_actions() { 385 | let mut child_game = game.clone(); 386 | let is_over = child_game.step(&action); 387 | let solution = if is_over { 388 | any_solved = true; 389 | Some(child_game.reward(child_game.player()).into()) 390 | } else { 391 | None 392 | }; 393 | let action: usize = action.into(); 394 | let child = Node::unvisited(node_id, child_game, solution, action as u8, 1.0); 395 | self.nodes.push(child); 396 | num_children += 1; 397 | } 398 | 399 | let node = self.mut_node(node_id); 400 | node.mark_visited(first_child, num_children); 401 | let first_child = node.first_child; 402 | let last_child = node.last_child(); 403 | 404 | if self.cfg.auto_extend && num_children == 1 { 405 | return self.visit(first_child); 406 | } else { 407 | let (logits, outcome_probs) = self.policy.eval(&game); 408 | 409 | // stable softmax 410 | let mut max_logit = f32::NEG_INFINITY; 411 | for child in self.mut_nodes(first_child, last_child) { 412 | let logit = logits[child.action as usize]; 413 | max_logit = max_logit.max(logit); 414 | child.action_prob = logit; 415 | } 416 | let mut total = 0.0; 417 | for child in self.mut_nodes(first_child, last_child) { 418 | child.action_prob = (child.action_prob - max_logit).exp(); 419 | total += child.action_prob; 420 | } 421 | for child in self.mut_nodes(first_child, last_child) { 422 | child.action_prob /= total; 423 | } 424 | 425 | (node_id, outcome_probs, any_solved) 426 | } 427 | } 428 | 429 | fn backprop(&mut self, leaf_node_id: NodeId, mut outcome_probs: [f32; 3], mut solved: bool) { 430 | let mut node_id = leaf_node_id; 431 | loop { 432 | let node = self.node(node_id); 433 | let parent = node.parent; 434 | 435 | if self.cfg.solve && solved { 436 | // compute whether all children are 
solved & best solution so far 437 | let mut all_solved = true; 438 | let mut best_solution = node.solution; 439 | for child in self.children_of(node) { 440 | let soln = child.solution.map(|o| o.reversed()); 441 | all_solved &= soln.is_some(); 442 | best_solution = best_solution.max(soln); 443 | } 444 | 445 | let correct_values = self.cfg.correct_values_on_solve; 446 | let node = self.mut_node(node_id); 447 | if let Some(Outcome::Win(in_turns)) = best_solution { 448 | // at least 1 is a win, so mark this node as a win 449 | node.mark_solved(Outcome::Win(in_turns)); 450 | if correct_values { 451 | for i in 0..3 { 452 | outcome_probs[i] = -node.outcome_probs[i]; 453 | } 454 | outcome_probs[2] += node.num_visits + 1.0; 455 | } 456 | } else if best_solution.is_some() && all_solved { 457 | // all children node's are proven losses or draws 458 | let best_outcome = best_solution.unwrap(); 459 | node.mark_solved(best_outcome); 460 | if correct_values { 461 | for i in 0..3 { 462 | outcome_probs[i] = -node.outcome_probs[i]; 463 | } 464 | if let Outcome::Draw(_) = best_outcome { 465 | outcome_probs[1] += node.num_visits + 1.0; 466 | } else { 467 | outcome_probs[0] += node.num_visits + 1.0; 468 | } 469 | } 470 | } else { 471 | solved = false; 472 | } 473 | } 474 | 475 | let node = self.mut_node(node_id); 476 | for i in 0..3 { 477 | node.outcome_probs[i] += outcome_probs[i]; 478 | } 479 | node.num_visits += 1.0; 480 | if node_id == self.root { 481 | break; 482 | } 483 | let t = outcome_probs[0]; 484 | outcome_probs[0] = outcome_probs[2]; 485 | outcome_probs[2] = t; 486 | node_id = parent; 487 | } 488 | } 489 | } 490 | 491 | #[cfg(test)] 492 | mod tests { 493 | use rand::prelude::{SeedableRng, StdRng}; 494 | 495 | use super::*; 496 | use crate::game::HasTurnOrder; 497 | use crate::policies::RolloutPolicy; 498 | 499 | #[derive(Clone, Copy, Debug, PartialEq, Eq, std::hash::Hash, PartialOrd, Ord)] 500 | pub enum PlayerId { 501 | X, 502 | O, 503 | } 504 | 505 | impl HasTurnOrder for 
PlayerId { 506 | fn prev(&self) -> Self { 507 | self.next() 508 | } 509 | 510 | fn next(&self) -> Self { 511 | match self { 512 | PlayerId::O => PlayerId::X, 513 | PlayerId::X => PlayerId::O, 514 | } 515 | } 516 | } 517 | 518 | #[derive(Debug, PartialEq, Eq, Clone, Copy)] 519 | struct Action { 520 | row: usize, 521 | col: usize, 522 | } 523 | 524 | impl From for Action { 525 | fn from(i: usize) -> Self { 526 | let row = i / 3; 527 | let col = i % 3; 528 | Self { row, col } 529 | } 530 | } 531 | 532 | impl Into for Action { 533 | fn into(self) -> usize { 534 | self.row * 3 + self.col 535 | } 536 | } 537 | 538 | #[derive(Debug, PartialEq, Eq, std::hash::Hash, Clone)] 539 | struct TicTacToe { 540 | board: [[Option; 3]; 3], 541 | player: PlayerId, 542 | turn: usize, 543 | } 544 | 545 | struct ActionIterator { 546 | game: TicTacToe, 547 | i: usize, 548 | } 549 | 550 | impl Iterator for ActionIterator { 551 | type Item = Action; 552 | 553 | fn next(&mut self) -> Option { 554 | while self.i < 9 { 555 | let action: Action = self.i.into(); 556 | self.i += 1; 557 | if self.game.board[action.row][action.col].is_none() { 558 | return Some(action); 559 | } 560 | } 561 | 562 | None 563 | } 564 | } 565 | 566 | impl TicTacToe { 567 | fn won(&self, player: PlayerId) -> bool { 568 | let p = Some(player); 569 | if self.board[0][0] == p && self.board[0][1] == p && self.board[0][2] == p { 570 | return true; 571 | } 572 | if self.board[1][0] == p && self.board[1][1] == p && self.board[1][2] == p { 573 | return true; 574 | } 575 | if self.board[2][0] == p && self.board[2][1] == p && self.board[2][2] == p { 576 | return true; 577 | } 578 | if self.board[0][0] == p && self.board[1][0] == p && self.board[2][0] == p { 579 | return true; 580 | } 581 | if self.board[0][1] == p && self.board[1][1] == p && self.board[2][1] == p { 582 | return true; 583 | } 584 | if self.board[0][2] == p && self.board[1][2] == p && self.board[2][2] == p { 585 | return true; 586 | } 587 | if self.board[0][0] == p 
&& self.board[1][1] == p && self.board[2][2] == p { 588 | return true; 589 | } 590 | if self.board[0][2] == p && self.board[1][1] == p && self.board[2][0] == p { 591 | return true; 592 | } 593 | 594 | false 595 | } 596 | } 597 | 598 | impl Game<9> for TicTacToe { 599 | type PlayerId = PlayerId; 600 | type Action = Action; 601 | type ActionIterator = ActionIterator; 602 | type Features = [[[f32; 3]; 3]; 3]; 603 | 604 | const MAX_NUM_ACTIONS: usize = 9; 605 | const MAX_TURNS: usize = 9; 606 | const NAME: &'static str = "TicTacToe"; 607 | const NUM_PLAYERS: usize = 2; 608 | const DIMS: &'static [i64] = &[3, 3, 3]; 609 | 610 | fn new() -> Self { 611 | Self { 612 | board: [[None; 3]; 3], 613 | player: PlayerId::X, 614 | turn: 0, 615 | } 616 | } 617 | 618 | fn player(&self) -> Self::PlayerId { 619 | self.player 620 | } 621 | 622 | fn is_over(&self) -> bool { 623 | self.won(self.player) || self.won(self.player.prev()) || self.turn == 9 624 | } 625 | 626 | fn reward(&self, player_id: Self::PlayerId) -> f32 { 627 | if self.won(player_id) { 628 | 1.0 629 | } else if self.won(player_id.next()) { 630 | -1.0 631 | } else { 632 | 0.0 633 | } 634 | } 635 | 636 | fn iter_actions(&self) -> Self::ActionIterator { 637 | ActionIterator { 638 | game: self.clone(), 639 | i: 0, 640 | } 641 | } 642 | fn step(&mut self, action: &Self::Action) -> bool { 643 | assert!(action.row < 3); 644 | assert!(action.col < 3); 645 | assert!(self.board[action.row][action.col].is_none()); 646 | self.board[action.row][action.col] = Some(self.player); 647 | self.player = self.player.next(); 648 | self.turn += 1; 649 | self.is_over() 650 | } 651 | 652 | fn features(&self) -> Self::Features { 653 | let mut s = [[[0.0; 3]; 3]; 3]; 654 | for row in 0..3 { 655 | for col in 0..3 { 656 | if let Some(p) = self.board[row][col] { 657 | if p == self.player { 658 | s[0][row][col] = 1.0; 659 | } else { 660 | s[1][row][col] = 1.0; 661 | } 662 | } else { 663 | s[2][row][col] = 1.0; 664 | } 665 | } 666 | } 667 | s 668 | } 
669 | 670 | fn print(&self) { 671 | for row in 0..3 { 672 | for col in 0..3 { 673 | print!( 674 | "{}", 675 | match self.board[row][col] { 676 | Some(PlayerId::X) => "x", 677 | Some(PlayerId::O) => "o", 678 | None => ".", 679 | } 680 | ); 681 | } 682 | println!(); 683 | } 684 | println!(); 685 | } 686 | } 687 | 688 | // https://en.wikipedia.org/wiki/Tic-tac-toe 689 | 690 | #[test] 691 | fn test_solve_win() { 692 | let mut rng = StdRng::seed_from_u64(0); 693 | let mut policy = RolloutPolicy { rng: &mut rng }; 694 | let mut game = TicTacToe::new(); 695 | game.step(&Action { row: 0, col: 0 }); 696 | game.step(&Action { row: 0, col: 2 }); 697 | let mut mcts = MCTS::with_capacity( 698 | 1601, 699 | MCTSConfig { 700 | exploration: Exploration::PolynomialUct { c: 2.0 }, 701 | solve: true, 702 | fpu: Fpu::Const(f32::INFINITY), 703 | select_solved_nodes: true, 704 | correct_values_on_solve: true, 705 | auto_extend: true, 706 | root_policy_noise: PolicyNoise::None, 707 | }, 708 | &mut policy, 709 | game.clone(), 710 | ); 711 | while mcts.node(mcts.root).solution.is_none() { 712 | mcts.explore(); 713 | } 714 | let mut search_policy = [0.0; 9]; 715 | mcts.target_policy(&mut search_policy); 716 | // assert_eq!(mcts.node(mcts.root).solution, Some(Outcome::Win)); 717 | // assert_eq!( 718 | // &search_policy, 719 | // &[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0] 720 | // ); 721 | assert_eq!(mcts.solution(&0.into()), None); 722 | assert_eq!(mcts.solution(&1.into()), None); 723 | assert_eq!(mcts.solution(&2.into()), None); 724 | assert_eq!(mcts.solution(&3.into()), None); 725 | assert_eq!(mcts.solution(&4.into()), None); 726 | assert_eq!(mcts.solution(&5.into()), None); 727 | // assert_eq!(mcts.solution(&6.into()), Some(Outcome::Lose)); 728 | assert_eq!(mcts.solution(&7.into()), None); 729 | assert_eq!(mcts.solution(&8.into()), None); 730 | // assert_eq!(mcts.target_q(), 1.0); 731 | assert_eq!(mcts.best_action(ActionSelection::Q), 6.into()); 732 | assert_eq!(mcts.nodes.len(), 
311); 733 | } 734 | 735 | #[test] 736 | fn test_solve_loss() { 737 | let mut rng = StdRng::seed_from_u64(0); 738 | let mut policy = RolloutPolicy { rng: &mut rng }; 739 | let mut game = TicTacToe::new(); 740 | game.step(&Action { row: 0, col: 0 }); 741 | game.step(&Action { row: 0, col: 2 }); 742 | game.step(&Action { row: 2, col: 0 }); 743 | let mut mcts = MCTS::with_capacity( 744 | 1601, 745 | MCTSConfig { 746 | exploration: Exploration::PolynomialUct { c: 2.0 }, 747 | solve: true, 748 | correct_values_on_solve: true, 749 | fpu: Fpu::Const(f32::INFINITY), 750 | select_solved_nodes: true, 751 | auto_extend: true, 752 | root_policy_noise: PolicyNoise::None, 753 | }, 754 | &mut policy, 755 | game.clone(), 756 | ); 757 | while mcts.node(mcts.root).solution.is_none() { 758 | mcts.explore(); 759 | } 760 | // assert_eq!(mcts.node(mcts.root).solution, Some(Outcome::Lose)); 761 | let mut search_policy = [0.0; 9]; 762 | mcts.target_policy(&mut search_policy); 763 | // assert_eq!( 764 | // &search_policy, 765 | // &[ 766 | // 0.0, 0.16666667, 0.0, 0.16666667, 0.16666667, 0.16666667, 0.0, 0.16666667, 767 | // 0.16666667 768 | // ] 769 | // ); 770 | // assert_eq!(mcts.solution(&0.into()), None); 771 | // assert_eq!(mcts.solution(&1.into()), Some(Outcome::Win)); 772 | // assert_eq!(mcts.solution(&2.into()), None); 773 | // assert_eq!(mcts.solution(&3.into()), Some(Outcome::Win)); 774 | // assert_eq!(mcts.solution(&4.into()), Some(Outcome::Win)); 775 | // assert_eq!(mcts.solution(&5.into()), Some(Outcome::Win)); 776 | // assert_eq!(mcts.solution(&6.into()), None); 777 | // assert_eq!(mcts.solution(&7.into()), Some(Outcome::Win)); 778 | // assert_eq!(mcts.solution(&8.into()), Some(Outcome::Win)); 779 | // assert_eq!(mcts.target_q(), -1.0); 780 | // assert_eq!(mcts.best_action(ActionSelection::Q), 1.into()); 781 | assert_eq!(mcts.nodes.len(), 69); 782 | } 783 | 784 | #[test] 785 | fn test_solve_draw() { 786 | let mut rng = StdRng::seed_from_u64(0); 787 | let mut policy = 
RolloutPolicy { rng: &mut rng }; 788 | let mut game = TicTacToe::new(); 789 | game.step(&Action { row: 0, col: 0 }); 790 | game.step(&Action { row: 1, col: 1 }); 791 | let mut mcts = MCTS::with_capacity( 792 | 1601, 793 | MCTSConfig { 794 | exploration: Exploration::PolynomialUct { c: 2.0 }, 795 | solve: true, 796 | correct_values_on_solve: true, 797 | fpu: Fpu::Const(f32::INFINITY), 798 | select_solved_nodes: true, 799 | auto_extend: true, 800 | root_policy_noise: PolicyNoise::None, 801 | }, 802 | &mut policy, 803 | game.clone(), 804 | ); 805 | while mcts.node(mcts.root).solution.is_none() { 806 | mcts.explore(); 807 | } 808 | 809 | // assert_eq!(mcts.node(mcts.root).solution, Some(Outcome::Draw)); 810 | let mut search_policy = [0.0; 9]; 811 | mcts.target_policy(&mut search_policy); 812 | // assert_eq!( 813 | // &search_policy, 814 | // &[ 815 | // 0.0, 0.14285715, 0.14285715, 0.14285715, 0.0, 0.14285715, 0.14285715, 0.14285715, 816 | // 0.14285715 817 | // ] 818 | // ); 819 | assert_eq!(mcts.solution(&0.into()), None); 820 | // assert_eq!(mcts.solution(&1.into()), Some(Outcome::Draw)); 821 | // assert_eq!(mcts.solution(&2.into()), Some(Outcome::Draw)); 822 | // assert_eq!(mcts.solution(&3.into()), Some(Outcome::Draw)); 823 | // assert_eq!(mcts.solution(&4.into()), None); 824 | // assert_eq!(mcts.solution(&5.into()), Some(Outcome::Draw)); 825 | // assert_eq!(mcts.solution(&6.into()), Some(Outcome::Draw)); 826 | // assert_eq!(mcts.solution(&7.into()), Some(Outcome::Draw)); 827 | // assert_eq!(mcts.solution(&8.into()), Some(Outcome::Draw)); 828 | // assert_eq!(mcts.target_q(), 0.0); 829 | assert_eq!(mcts.best_action(ActionSelection::Q), 1.into()); 830 | assert_eq!(mcts.nodes.len(), 1533); 831 | } 832 | 833 | #[test] 834 | fn test_add_noise() { 835 | let mut rng = StdRng::seed_from_u64(0); 836 | let mut policy = RolloutPolicy { rng: &mut rng }; 837 | let game = TicTacToe::new(); 838 | let mut mcts = MCTS::with_capacity( 839 | 1601, 840 | MCTSConfig { 841 | 
exploration: Exploration::PolynomialUct { c: 2.0 }, 842 | solve: true, 843 | correct_values_on_solve: true, 844 | fpu: Fpu::Const(f32::INFINITY), 845 | select_solved_nodes: false, 846 | auto_extend: false, 847 | root_policy_noise: PolicyNoise::None, 848 | }, 849 | &mut policy, 850 | game.clone(), 851 | ); 852 | let mut rng2 = StdRng::seed_from_u64(0); 853 | 854 | let mut total = 0.0; 855 | for child in mcts.children_of(mcts.node(mcts.root)) { 856 | assert!(child.action_prob > 0.0); 857 | total += child.action_prob; 858 | } 859 | assert!((total - 1.0).abs() < 1e-6); 860 | 861 | mcts.add_dirichlet_noise(&mut rng2, 1.0, 0.25); 862 | let mut total = 0.0; 863 | for child in mcts.children_of(mcts.node(mcts.root)) { 864 | assert!(child.action_prob > 0.0); 865 | total += child.action_prob; 866 | } 867 | assert!((total - 1.0).abs() < 1e-6); 868 | } 869 | } 870 | -------------------------------------------------------------------------------- /synthesis/src/policies/cache.rs: -------------------------------------------------------------------------------- 1 | use crate::game::Game; 2 | use crate::policies::Policy; 3 | use std::collections::HashMap; 4 | 5 | pub struct PolicyWithCache<'a, G: Game, P: Policy, const N: usize> { 6 | pub policy: &'a mut P, 7 | pub cache: HashMap, 8 | } 9 | 10 | impl<'a, G: Game, P: Policy, const N: usize> PolicyWithCache<'a, G, P, N> { 11 | pub fn with_capacity(capacity: usize, policy: &'a mut P) -> Self { 12 | Self { 13 | policy, 14 | cache: HashMap::with_capacity(capacity), 15 | } 16 | } 17 | } 18 | 19 | impl<'a, G: Game, P: Policy, const N: usize> Policy 20 | for PolicyWithCache<'a, G, P, N> 21 | { 22 | fn eval(&mut self, game: &G) -> ([f32; N], [f32; 3]) { 23 | match self.cache.get(&game) { 24 | Some(pi_v) => *pi_v, 25 | None => { 26 | let pi_v = self.policy.eval(game); 27 | self.cache.insert(game.clone(), pi_v); 28 | pi_v 29 | } 30 | } 31 | } 32 | } 33 | 34 | pub struct OwnedPolicyWithCache, P: Policy, const N: usize> { 35 | pub policy: P, 36 
| pub cache: HashMap, 37 | } 38 | 39 | impl, P: Policy, const N: usize> OwnedPolicyWithCache { 40 | pub fn with_capacity(capacity: usize, policy: P) -> Self { 41 | Self { 42 | policy, 43 | cache: HashMap::with_capacity(capacity), 44 | } 45 | } 46 | } 47 | 48 | impl, P: Policy, const N: usize> Policy for OwnedPolicyWithCache { 49 | fn eval(&mut self, game: &G) -> ([f32; N], [f32; 3]) { 50 | match self.cache.get(game) { 51 | Some(pi_v) => *pi_v, 52 | None => { 53 | let pi_v = self.policy.eval(game); 54 | self.cache.insert(game.clone(), pi_v); 55 | pi_v 56 | } 57 | } 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /synthesis/src/policies/mod.rs: -------------------------------------------------------------------------------- 1 | mod cache; 2 | mod rollout; 3 | mod traits; 4 | 5 | pub use cache::{OwnedPolicyWithCache, PolicyWithCache}; 6 | pub use rollout::RolloutPolicy; 7 | pub use traits::{NNPolicy, Policy}; 8 | -------------------------------------------------------------------------------- /synthesis/src/policies/rollout.rs: -------------------------------------------------------------------------------- 1 | use crate::game::Game; 2 | use crate::policies::Policy; 3 | use rand::Rng; 4 | 5 | pub struct RolloutPolicy<'a, R: Rng> { 6 | pub rng: &'a mut R, 7 | } 8 | impl<'a, G: Game, R: Rng, const N: usize> Policy for RolloutPolicy<'a, R> { 9 | fn eval(&mut self, game: &G) -> ([f32; N], [f32; 3]) { 10 | let player = game.player(); 11 | let mut rollout_game = game.clone(); 12 | let mut is_over = game.is_over(); 13 | while !is_over { 14 | let actions = rollout_game.iter_actions(); 15 | let num_actions = actions.count() as u8; 16 | let i = self.rng.gen_range(0..num_actions); 17 | let action = rollout_game.iter_actions().nth(i as usize).unwrap(); 18 | is_over = rollout_game.step(&action); 19 | } 20 | let r = rollout_game.reward(player); 21 | ( 22 | [0.0; N], 23 | if r == 0.0 { 24 | [0.0, 1.0, 0.0] 25 | } else if r < 0.0 { 
26 | [1.0, 0.0, 0.0] 27 | } else { 28 | [0.0, 0.0, 1.0] 29 | }, 30 | ) 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /synthesis/src/policies/traits.rs: -------------------------------------------------------------------------------- 1 | use crate::game::Game; 2 | use tch::{nn::VarStore, Tensor}; 3 | 4 | pub trait Policy, const N: usize> { 5 | fn eval(&mut self, game: &G) -> ([f32; N], [f32; 3]); 6 | } 7 | 8 | pub trait NNPolicy, const N: usize> { 9 | fn new(vs: &VarStore) -> Self; 10 | fn forward(&self, xs: &Tensor) -> (Tensor, Tensor); 11 | } 12 | -------------------------------------------------------------------------------- /synthesis/src/prelude.rs: -------------------------------------------------------------------------------- 1 | pub use crate::alpha_zero::alpha_zero; 2 | pub use crate::config::{ 3 | ActionSelection, EvaluationConfig, Exploration, Fpu, LearningConfig, MCTSConfig, PolicyNoise, 4 | RolloutConfig, ValueTarget, 5 | }; 6 | pub use crate::data::tensor; 7 | pub use crate::evaluator::evaluator; 8 | pub use crate::game::{Game, HasTurnOrder}; 9 | pub use crate::policies::{NNPolicy, Policy, PolicyWithCache}; 10 | pub use crate::utils::train_dir; 11 | -------------------------------------------------------------------------------- /synthesis/src/utils.rs: -------------------------------------------------------------------------------- 1 | use chrono::prelude::*; 2 | use std::fs::File; 3 | use std::io::prelude::*; 4 | use std::path::{Path, PathBuf}; 5 | use std::process::{Command, Stdio}; 6 | 7 | pub fn train_dir(root: &'static str, tag: &'static str) -> std::io::Result { 8 | let time = Local::now().format("%m-%d-%YT%H-%M-%SZ").to_string(); 9 | Ok(Path::new(root).join(tag).join(time)) 10 | } 11 | 12 | pub fn save_str(path: &PathBuf, filename: &'static str, value: &String) -> std::io::Result<()> { 13 | std::fs::File::create(&path.join(filename)).and_then(|mut f| f.write_all(value.as_bytes())) 14 | } 
/// Runs `command` to completion and returns everything it wrote to stdout.
///
/// The exit status is not inspected; only spawn failures and undecodable
/// output are reported.
///
/// # Errors
/// Returns the spawn error from [`Command::output`], or an
/// [`std::io::ErrorKind::InvalidData`] error when stdout is not valid UTF-8.
fn command_stdout(command: &mut Command) -> std::io::Result<String> {
    let output = command.output()?;
    String::from_utf8(output.stdout)
        .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidData, err))
}

/// Returns the stdout of `git rev-parse HEAD`, unmodified (no trimming).
///
/// # Errors
/// Fails if `git` cannot be spawned or its output is not valid UTF-8.
pub fn git_hash() -> std::io::Result<String> {
    command_stdout(Command::new("git").arg("rev-parse").arg("HEAD"))
}

/// Returns the stdout of `git diff`, unmodified.
///
/// # Errors
/// Fails if `git` cannot be spawned or its output is not valid UTF-8
/// (previously a non-UTF-8 diff caused a panic).
pub fn git_diff() -> std::io::Result<String> {
    command_stdout(Command::new("git").arg("diff"))
}

/// Appends one game record to `pgn`: the `White`, `Black` and `Result` tags
/// followed by the bare result token, each on its own line.
///
/// `white_reward` selects the result: `1.0` -> `1-0`, `-1.0` -> `0-1`,
/// `0.0` -> `1/2-1/2`.
///
/// # Panics
/// Panics if `white_reward` is not exactly `1.0`, `-1.0` or `0.0`.
///
/// # Errors
/// Propagates any write error from `pgn`.
pub fn add_pgn_result(
    pgn: &mut File,
    white_name: &String,
    black_name: &String,
    white_reward: f32,
) -> std::io::Result<()> {
    writeln!(pgn, "[White \"{}\"]", white_name)?;
    writeln!(pgn, "[Black \"{}\"]", black_name)?;
    let result = if white_reward == 1.0 {
        // white wins
        "1-0"
    } else if white_reward == -1.0 {
        // black wins
        "0-1"
    } else {
        // Anything other than a decisive result must be an exact draw.
        assert_eq!(white_reward, 0.0);
        "1/2-1/2"
    };
    writeln!(pgn, "[Result \"{}\"]", result)?;
    writeln!(pgn, "{}", result)
}

/// Runs `bayeselo.exe` with `dir` as its working directory and feeds it a
/// fixed command script on stdin (read `results.pgn`, compute ratings,
/// `ratings >ratings`, then exit).
///
/// NOTE(review): the script presumably leaves the rating table in
/// `dir/ratings`, which is the file `rankings` and `plot_ratings` read.
///
/// The child's exit status is not inspected.
///
/// # Errors
/// Fails if the process cannot be spawned, if writing the script fails, or
/// if waiting on the child fails.
pub fn calculate_ratings(dir: &PathBuf) -> std::io::Result<()> {
    let mut child = Command::new("bayeselo.exe")
        .current_dir(dir)
        .stdin(Stdio::piped())
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .spawn()?;
    {
        let mut stdin = child
            .stdin
            .take()
            .expect("stdin was configured as piped above");
        let script = [
            "readpgn results.pgn",
            "elo",
            "mm",
            "exactdist",
            "ratings >ratings",
            "x",
            "x",
        ];
        for command in script.iter() {
            writeln!(stdin, "{}", command)?;
        }
        // `stdin` is dropped here so the child sees end-of-input before we wait.
    }
    // stdout/stderr are piped, so they must be drained while waiting:
    // a plain `wait()` can block forever once the child fills a pipe buffer.
    child.wait_with_output()?;
    Ok(())
}

/// Runs `python plot_ratings.py <dir>/ratings` and waits for it to finish.
///
/// # Panics
/// Panics if the script exits unsuccessfully.
///
/// # Errors
/// Fails if `python` cannot be spawned.
pub fn plot_ratings(dir: &PathBuf) -> std::io::Result<()> {
    // Pass the path as an OsStr so non-UTF-8 directories do not panic.
    let status = Command::new("python")
        .arg("plot_ratings.py")
        .arg(dir.join("ratings"))
        .status()?;
    assert!(status.success(), "plot_ratings.py exited with {}", status);
    Ok(())
}

/// Reads `dir/ratings` and returns, in file order, every `model_….ot` name
/// found on the lines after the first one.
///
/// Lines that do not contain `model_` are ignored. For matching lines the name
/// runs from `model_` up to and including the first `.ot` that follows it.
///
/// # Errors
/// Fails if the file cannot be opened or read, or with
/// [`std::io::ErrorKind::InvalidData`] when a line contains `model_` but no
/// `.ot` after it.
pub fn rankings(dir: &PathBuf) -> std::io::Result<Vec<String>> {
    let file = File::open(dir.join("ratings"))?;
    let reader = std::io::BufReader::new(file);
    let mut names = Vec::new();
    // The first line is skipped (presumably a column header).
    for line in reader.lines().skip(1) {
        let line = line?;
        let entry = line.trim();
        if let Some(start_i) = entry.find("model_") {
            // Search only after the prefix so an earlier ".ot" on the line
            // cannot produce an inverted (panicking) slice range.
            let tail = &entry[start_i..];
            let end_i = tail.find(".ot").ok_or_else(|| {
                std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    format!("ratings line has no `.ot` model name: {:?}", entry),
                )
            })?;
            names.push(String::from(&tail[..end_i + 3]));
        }
    }
    Ok(names)
}