├── train ├── value │ ├── src │ │ ├── lib.rs │ │ └── bin │ │ │ ├── shatranj.rs │ │ │ ├── frompolicy.rs │ │ │ ├── convert.rs │ │ │ ├── expand.rs │ │ │ ├── ataxx.rs │ │ │ └── chess.rs │ └── Cargo.toml └── policy │ ├── Cargo.toml │ └── src │ ├── bin │ ├── ataxx.rs │ ├── shatranj.rs │ ├── chess.rs │ ├── shuffle.rs │ ├── filter.rs │ └── interleave.rs │ ├── ataxx.rs │ ├── shatranj.rs │ ├── chess.rs │ └── lib.rs ├── datagen ├── src │ ├── impls.rs │ ├── impls │ │ ├── ataxx.rs │ │ ├── shatranj.rs │ │ └── chess.rs │ ├── bin │ │ ├── montyxx.rs │ │ ├── montyj.rs │ │ └── monty.rs │ ├── rng.rs │ ├── lib.rs │ └── thread.rs └── Cargo.toml ├── .gitignore ├── Cargo.toml ├── src ├── bin │ ├── monty.rs │ ├── montyj.rs │ └── montyxx.rs ├── mcts │ ├── helpers.rs │ └── params.rs ├── games │ ├── ataxx │ │ ├── policy.rs │ │ ├── util.rs │ │ ├── moves.rs │ │ └── board.rs │ ├── shatranj │ │ ├── policy.rs │ │ ├── consts.rs │ │ ├── moves.rs │ │ └── attacks.rs │ ├── chess │ │ ├── value.rs │ │ ├── policy.rs │ │ ├── moves.rs │ │ ├── frc.rs │ │ ├── consts.rs │ │ └── attacks.rs │ ├── ataxx.rs │ ├── shatranj.rs │ └── chess.rs ├── lib.rs ├── value.rs ├── tree │ ├── hash.rs │ ├── edge.rs │ └── node.rs ├── games.rs ├── mcts.rs ├── comm.rs └── tree.rs ├── Makefile ├── LICENSE ├── resources ├── ataxx-fens.txt └── chess-fens.txt ├── README.md └── Cargo.lock /train/value/src/lib.rs: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /datagen/src/impls.rs: -------------------------------------------------------------------------------- 1 | pub mod ataxx; 2 | pub mod chess; 3 | pub mod shatranj; /* one submodule per supported game: ataxx, chess, shatranj */ 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | /.cargo 3 | *-policy-* 4 | *.data 5 | *.binpack 6 | *.exe 7 | *.network 8 | *.epd 9 | checkpoints 10 | data
11 | nets 12 | monty 13 | -------------------------------------------------------------------------------- /datagen/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "datagen" 3 | version = "0.1.0" 4 | edition = "2021" 5 | authors = ["Jamie Whiting"] 6 | 7 | [dependencies] 8 | monty = { path = "../" } 9 | bulletformat = "1.3.0" 10 | -------------------------------------------------------------------------------- /train/policy/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "policy" 3 | version = "0.1.0" 4 | edition = "2021" 5 | authors = ["Jamie Whiting"] 6 | 7 | [dependencies] 8 | datagen = { path = "../../datagen" } 9 | monty = { path = "../../", features = ["nonet"] } 10 | goober = { git = 'https://github.com/jw1912/goober' } 11 | -------------------------------------------------------------------------------- /train/policy/src/bin/ataxx.rs: -------------------------------------------------------------------------------- 1 | use monty::ataxx::PolicyNetwork; 2 | 3 | fn main() { 4 | let mut args = std::env::args(); 5 | args.next(); 6 | let threads = args.next().unwrap().parse().unwrap(); /* first CLI argument is the thread count; panics if it is missing or not a number */ 7 | 8 | policy::train::(threads, "data/ataxx/blah.data".to_string(), 30, 20); 9 | } 10 | -------------------------------------------------------------------------------- /train/policy/src/bin/shatranj.rs: -------------------------------------------------------------------------------- 1 | use monty::shatranj::PolicyNetwork; 2 | 3 | fn main() { 4 | let mut args = std::env::args(); 5 | args.next(); 6 | let threads = args.next().unwrap().parse().unwrap(); /* first CLI argument is the thread count; panics if it is missing or not a number */ 7 | 8 | policy::train::(threads, "data/shatranj/whatever.data".to_string(), 10, 7); 9 | } 10 | -------------------------------------------------------------------------------- /train/policy/src/bin/chess.rs: -------------------------------------------------------------------------------- 1 | use
monty::chess::PolicyNetwork; 2 | 3 | fn main() { 4 | let mut args = std::env::args(); 5 | args.next(); 6 | let threads = args.next().unwrap().parse().unwrap(); /* first CLI argument is the thread count; panics if it is missing or not a number */ 7 | 8 | policy::train::( 9 | threads, 10 | "data/chess/policy-with-frc.data".to_string(), 11 | 60, 12 | 25, 13 | ); 14 | } 15 | -------------------------------------------------------------------------------- /train/value/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "value" 3 | version = "0.1.0" 4 | edition = "2021" 5 | authors = ["Jamie Whiting"] 6 | 7 | [features] 8 | shatranj = [] 9 | 10 | [dependencies] 11 | bullet = { package = "bullet_lib", git = 'https://github.com/jw1912/bullet', features = ["cuda"] } 12 | datagen = { path = "../../datagen" } 13 | monty = { path = "../../", features = ["nonet"]} 14 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "monty" 3 | version = "1.0.0" 4 | edition = "2021" 5 | authors = ["Jamie Whiting"] 6 | 7 | [profile.release] 8 | panic = 'abort' 9 | strip = true 10 | lto = true 11 | codegen-units = 1 12 | 13 | [dependencies] 14 | goober = { git = 'https://github.com/jw1912/goober.git' } 15 | 16 | [features] 17 | nonet = [] 18 | 19 | [workspace] 20 | members = ["datagen", "train/policy", "train/value"] 21 | resolver = "2" 22 | -------------------------------------------------------------------------------- /src/bin/monty.rs: -------------------------------------------------------------------------------- 1 | use monty::{ 2 | chess::{PolicyNetwork, Uci, ValueNetwork}, 3 | UciLike, 4 | }; 5 | 6 | #[repr(C)] 7 | struct Nets(ValueNetwork, PolicyNetwork); 8 | 9 | const NETS: Nets = unsafe { std::mem::transmute(*include_bytes!("../../resources/net.network")) }; /* SAFETY(review): assumes net.network contains exactly the bytes of a #[repr(C)] Nets; transmute only enforces at compile time that the sizes match */ 10 | 11 | static VALUE: ValueNetwork = NETS.0; 12 | static POLICY: PolicyNetwork = NETS.1; 13 | 14
| fn main() { 15 | let mut args = std::env::args(); 16 | let arg1 = args.nth(1); 17 | 18 | if let Some("bench") = arg1.as_deref() { 19 | monty::chess::Uci::bench(5, &POLICY, &VALUE); 20 | return; 21 | } 22 | 23 | Uci::run(&POLICY, &VALUE); 24 | } 25 | -------------------------------------------------------------------------------- /src/bin/montyj.rs: -------------------------------------------------------------------------------- 1 | use monty::{ 2 | shatranj::{PolicyNetwork, Uci}, 3 | UciLike, ValueNetwork, 4 | }; 5 | 6 | #[repr(C)] 7 | struct Nets(ValueNetwork<768, 8>, PolicyNetwork); 8 | 9 | const NETS: Nets = unsafe { std::mem::transmute(*include_bytes!("../../resources/net.network")) }; /* SAFETY(review): assumes net.network contains exactly the bytes of a #[repr(C)] Nets; transmute only enforces at compile time that the sizes match */ 10 | 11 | static VALUE: ValueNetwork<768, 8> = NETS.0; 12 | static POLICY: PolicyNetwork = NETS.1; 13 | 14 | fn main() { 15 | let mut args = std::env::args(); 16 | let arg1 = args.nth(1); 17 | 18 | if let Some("bench") = arg1.as_deref() { 19 | Uci::bench(6, &POLICY, &VALUE); 20 | return; 21 | } 22 | 23 | Uci::run(&POLICY, &VALUE); 24 | } 25 | -------------------------------------------------------------------------------- /src/bin/montyxx.rs: -------------------------------------------------------------------------------- 1 | use monty::{ 2 | ataxx::{PolicyNetwork, Uai}, 3 | UciLike, ValueNetwork, 4 | }; 5 | 6 | #[repr(C)] 7 | struct Nets(ValueNetwork<2916, 256>, PolicyNetwork); 8 | 9 | const NETS: Nets = unsafe { std::mem::transmute(*include_bytes!("../../resources/net.network")) }; /* SAFETY(review): assumes net.network contains exactly the bytes of a #[repr(C)] Nets; transmute only enforces at compile time that the sizes match */ 10 | 11 | static VALUE: ValueNetwork<2916, 256> = NETS.0; 12 | static POLICY: PolicyNetwork = NETS.1; 13 | 14 | fn main() { 15 | let mut args = std::env::args(); 16 | let arg1 = args.nth(1); 17 | 18 | if let Some("bench") = arg1.as_deref() { 19 | Uai::bench(5, &POLICY, &VALUE); 20 | return; 21 | } 22 | 23 | Uai::run(&POLICY, &VALUE); 24 | } 25 | -------------------------------------------------------------------------------- /datagen/src/impls/ataxx.rs:
-------------------------------------------------------------------------------- 1 | use monty::ataxx::{Ataxx, Board, Move}; 2 | 3 | use crate::{BinpackType, DatagenSupport}; 4 | 5 | impl DatagenSupport for Ataxx { 6 | type CompressedBoard = Board; 7 | type Binpack = (); 8 | } 9 | 10 | impl BinpackType for ::Binpack { 11 | fn new(_: Ataxx) -> Self {} 12 | 13 | fn push(&mut self, _: usize, _: Move, _: f32) {} 14 | 15 | fn deserialise_from( 16 | _: &mut impl std::io::BufRead, 17 | _: Vec<(u16, i16)>, 18 | ) -> std::io::Result { 19 | Ok(()) 20 | } 21 | 22 | fn serialise_into(&self, _: &mut impl std::io::Write) -> std::io::Result<()> { 23 | Ok(()) 24 | } 25 | 26 | fn set_result(&mut self, _: f32) {} 27 | } 28 | -------------------------------------------------------------------------------- /datagen/src/bin/montyxx.rs: -------------------------------------------------------------------------------- 1 | use datagen::{parse_args, run_datagen}; 2 | use monty::{ 3 | ataxx::{Ataxx, PolicyNetwork}, 4 | GameRep, ValueNetwork, 5 | }; 6 | 7 | #[repr(C)] 8 | struct Nets(ValueNetwork<2916, 256>, PolicyNetwork); 9 | 10 | const NETS: Nets = 11 | unsafe { std::mem::transmute(*include_bytes!("../../../resources/net.network")) }; /* SAFETY(review): assumes net.network contains exactly the bytes of a #[repr(C)] Nets; transmute only enforces at compile time that the sizes match */ 12 | 13 | static VALUE: ValueNetwork<2916, 256> = NETS.0; 14 | static POLICY: PolicyNetwork = NETS.1; 15 | 16 | fn main() { 17 | let args = std::env::args(); 18 | let (threads, book, policy) = parse_args(args); 19 | 20 | run_datagen::( 21 | Ataxx::default_mcts_params(), 22 | 1_000, 23 | threads, 24 | policy, 25 | "Ataxx", 26 | &POLICY, 27 | &VALUE, 28 | book, 29 | ); 30 | } 31 | -------------------------------------------------------------------------------- /datagen/src/bin/montyj.rs: -------------------------------------------------------------------------------- 1 | use datagen::{parse_args, run_datagen}; 2 | use monty::{ 3 | shatranj::{PolicyNetwork, Shatranj}, 4 | GameRep, ValueNetwork, 5 | }; 6 | 7 | #[repr(C)] 8 | struct Nets(ValueNetwork<768, 8>,
PolicyNetwork); 9 | 10 | const NETS: Nets = 11 | unsafe { std::mem::transmute(*include_bytes!("../../../resources/net.network")) }; /* SAFETY(review): assumes net.network contains exactly the bytes of a #[repr(C)] Nets; transmute only enforces at compile time that the sizes match */ 12 | 13 | static VALUE: ValueNetwork<768, 8> = NETS.0; 14 | static POLICY: PolicyNetwork = NETS.1; 15 | 16 | fn main() { 17 | let args = std::env::args(); 18 | let (threads, book, policy) = parse_args(args); 19 | 20 | run_datagen::( 21 | Shatranj::default_mcts_params(), 22 | 1_000, 23 | threads, 24 | policy, 25 | "Shatranj", 26 | &POLICY, 27 | &VALUE, 28 | book, 29 | ); 30 | } 31 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | EXE = monty 2 | 3 | ifeq ($(OS),Windows_NT) 4 | NAME := $(EXE).exe 5 | OLD := monty-$(VER).exe 6 | AVX2 := monty-$(VER)-avx2.exe 7 | else 8 | NAME := $(EXE) 9 | OLD := monty-$(VER) 10 | AVX2 := monty-$(VER)-avx2 11 | endif 12 | 13 | chess: 14 | cargo rustc --release --package monty --bin monty -- -C target-cpu=native --emit link=$(NAME) 15 | 16 | ataxx: 17 | cargo rustc --release --package monty --bin montyxx -- -C target-cpu=native --emit link=$(NAME) 18 | 19 | shatranj: 20 | cargo rustc --release --package monty --bin montyj -- -C target-cpu=native --emit link=$(NAME) 21 | 22 | gen: 23 | cargo rustc --release --package datagen --bin monty -- -C target-cpu=native --emit link=$(NAME) 24 | 25 | release: 26 | cargo rustc --release --bin monty -- --emit link=$(OLD) 27 | cargo rustc --release --bin monty -- -C target-cpu=x86-64-v2 -C target-feature=+avx2 --emit link=$(AVX2) -------------------------------------------------------------------------------- /src/mcts/helpers.rs: -------------------------------------------------------------------------------- 1 | use crate::{mcts::MctsParams, tree::Edge}; 2 | 3 | /// Stateless helpers for MCTS child selection: the CPUCT exploration constant, FPU and per-action values. 4 | pub struct SearchHelpers; 5 | 6 | impl SearchHelpers { 7 | /// Exploration constant (CPUCT) used when choosing among the children of `parent`. 8 | /// 9 | /// Starts from `params.cpuct()`, grows logarithmically with the parent's visit 10 | /// count, and, once the parent has more than one visit, is rescaled by 11 | /// `sqrt(parent.var())` relative to `params.cpuct_var_scale()`. 12 | pub fn get_cpuct(params: &MctsParams, parent: &Edge) -> f32 { 13 | // baseline CPUCT value 14 | let mut cpuct = params.cpuct(); 15 | 16 | // scale CPUCT as visits increase: multiply by 1 + ln((N + 8192) / 8192). 17 | // The ratio must be computed in floating point: integer division truncated it, 18 | // so the term was ln(1) = 0 for every N below 8192 and then grew in steps. 19 | // NOTE(review): this changes search behaviour below 8192 visits; re-check tuned cpuct values. 20 | cpuct *= 1.0 + ((parent.visits() + 8192) as f32 / 8192.0).ln(); 21 | 22 | // scale CPUCT with variance of Q 23 | if parent.visits() > 1 { 24 | let frac = parent.var().sqrt() / params.cpuct_var_scale(); 25 | cpuct *= 1.0 + params.cpuct_var_weight() * (frac - 1.0); 26 | } 27 | 28 | cpuct 29 | } 30 | 31 | /// First-play urgency: the value meant for children that have no visits yet 32 | /// (consumed by `get_action_value`), computed as `1.0 - parent.q()`. 33 | pub fn get_fpu(parent: &Edge) -> f32 { 34 | 1.0 - parent.q() 35 | } 36 | 37 | /// Value of `action` used during selection: `fpu` while it is unvisited, otherwise its Q. 38 | pub fn get_action_value(action: &Edge, fpu: f32) -> f32 { 39 | if action.visits() == 0 { 40 | fpu 41 | } else { 42 | action.q() 43 | } 44 | } 45 | } 46 |
-------------------------------------------------------------------------------- /datagen/src/bin/monty.rs: -------------------------------------------------------------------------------- 1 | use datagen::{parse_args, run_datagen}; 2 | use monty::{ 3 | chess::{Chess, PolicyNetwork, ValueNetwork}, 4 | GameRep, UciLike, 5 | }; 6 | 7 | #[repr(C)] 8 | struct Nets(ValueNetwork, PolicyNetwork); 9 | 10 | const NETS: Nets = 11 | unsafe { std::mem::transmute(*include_bytes!("../../../resources/net.network")) }; /* SAFETY(review): assumes net.network contains exactly the bytes of a #[repr(C)] Nets; transmute only enforces at compile time that the sizes match */ 12 | 13 | static VALUE: ValueNetwork = NETS.0; 14 | static POLICY: PolicyNetwork = NETS.1; 15 | 16 | fn main() { 17 | let args = std::env::args(); 18 | let (threads, book, policy) = parse_args(args); 19 | 20 | monty::chess::Uci::bench(4, &POLICY, &VALUE); 21 | 22 | if let Some(path) = &book { 23 | println!("Using book: {path}") 24 | } else { 25 | println!("Not using a book.") 26 | } 27 | 28 | let mut params = Chess::default_mcts_params(); 29 | 30 | // value data params 31 | params.set("root_pst", 2.62); 32 | params.set("cpuct", 1.08); 33 | 34 | run_datagen::( 35 | params, 5_000, threads, policy, "Chess", &POLICY, &VALUE, book, 36 | ); 37 | } 38 | -------------------------------------------------------------------------------- /datagen/src/rng.rs: -------------------------------------------------------------------------------- 1 | use std::time::{SystemTime, UNIX_EPOCH}; 2 |
3 | /// Small xorshift32 pseudo-random number generator. 4 | /// 5 | /// The state must never be zero: zero is a fixed point of the xorshift step, so a 6 | /// zero-seeded generator would return 0 forever. Every constructor therefore goes 7 | /// through `from_seed_nonzero`. 8 | pub struct Rand(u32); 9 | 10 | impl Default for Rand { 11 | /// Seeds from the low 32 bits of the current UNIX time in nanoseconds. 12 | fn default() -> Self { 13 | let nanos = SystemTime::now() 14 | .duration_since(UNIX_EPOCH) 15 | .expect("valid") 16 | .as_nanos(); 17 | 18 | Self::from_seed_nonzero((nanos & 0xFFFF_FFFF) as u32) 19 | } 20 | } 21 | 22 | impl Rand { 23 | /// Substituted for a zero seed; any non-zero value works (this is the 32-bit 24 | /// golden-ratio constant, chosen for its well-mixed bit pattern). 25 | const ZERO_SEED_FALLBACK: u32 = 0x9E37_79B9; 26 | 27 | /// Wraps `seed`, replacing the degenerate all-zero state with a fixed non-zero one. 28 | fn from_seed_nonzero(seed: u32) -> Self { 29 | Self(if seed == 0 { Self::ZERO_SEED_FALLBACK } else { seed }) 30 | } 31 | 32 | /// Advances the state by one xorshift32 step (shifts 13, 17, 5) and returns it. 33 | pub fn rand_int(&mut self) -> u32 { 34 | self.0 ^= self.0 << 13; 35 | self.0 ^= self.0 >> 17; 36 | self.0 ^= self.0 << 5; 37 | self.0 38 | } 39 | 40 | /// Maps the next `rand_int` onto `[-abs_max, abs_max]` (for non-negative `abs_max`). 41 | pub fn rand_f32(&mut self, abs_max: f32) -> f32 { 42 | let rand_int = self.rand_int(); 43 | let float = f64::from(rand_int) / f64::from(u32::MAX); 44 | (2.0 * float - 1.0) as f32 * abs_max 45 | } 46 | 47 | /// Creates a generator seeded from the current UNIX time in microseconds, 48 | /// truncated to 32 bits. Despite its name it takes no seed argument. 49 | /// 50 | /// NOTE(review): generators created within the same microsecond receive 51 | /// identical seeds and therefore produce identical streams. 52 | pub fn with_seed() -> Self { 53 | let seed = SystemTime::now() 54 | .duration_since(UNIX_EPOCH) 55 | .expect("Guaranteed increasing.") 56 | .as_micros() as u32; 57 | 58 | Self::from_seed_nonzero(seed) 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Jamie Whiting 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/games/ataxx/policy.rs: -------------------------------------------------------------------------------- 1 | use super::moves::Move; 2 | 3 | use goober::{activation, layer, FeedForwardNetwork, Matrix, SparseVector, Vector}; 4 | 5 | #[repr(C)] 6 | #[derive(Clone, Copy, FeedForwardNetwork)] 7 | pub struct SubNet { 8 | ft: layer::SparseConnected, 9 | } 10 | 11 | impl SubNet { 12 | pub const fn zeroed() -> Self { 13 | Self { 14 | ft: layer::SparseConnected::zeroed(), 15 | } 16 | } 17 | 18 | pub fn from_fn f32>(mut f: F) -> Self { 19 | let matrix = Matrix::from_fn(|_, _| f()); 20 | let vector = Vector::from_fn(|_| f()); 21 | 22 | Self { 23 | ft: layer::SparseConnected::from_raw(matrix, vector), 24 | } 25 | } 26 | } 27 | 28 | #[repr(C)] 29 | #[derive(Clone, Copy)] 30 | pub struct PolicyNetwork { 31 | pub subnets: [SubNet; 99], /* indices 0..=49 are selected by from-square and 50..=98 by to-square (see get) */ 32 | } 33 | 34 | impl PolicyNetwork { 35 | pub fn get(&self, mov: &Move, feats: &SparseVector) -> f32 { 36 | let from_subnet = &self.subnets[mov.from().min(49)]; /* from() values of 49 or more all share subnet 49 (NOTE(review): presumably a non-square sentinel such as a pass move; confirm) */ 37 | let from_vec = from_subnet.out(feats); 38 | 39 | let to_subnet = &self.subnets[50 + mov.to().min(48)]; /* to() is capped at 48, keeping the index within 50..=98 */ 40 | let to_vec = to_subnet.out(feats); 41 | 42 | from_vec.dot(&to_vec) /* move score: dot product of the from- and to-subnet outputs */ 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | mod comm; 2 | mod games; 3 | mod mcts; 4 | mod tree; 5 | mod value; 6 | 7 | pub use comm::UciLike; 8 | pub use games::{ataxx, chess, shatranj, GameRep, GameState}; 9 | pub use mcts::{Limits, MctsParams, Searcher}; 10 | pub use tree::Tree; 11 | pub use
value::ValueNetwork; 12 | 13 | // Macro for calculating tables (until const fn pointers are stable). 14 | #[macro_export] 15 | macro_rules! init { 16 | (|$sq:ident, $size:literal | $($rest:tt)+) => {{ 17 | let mut $sq = 0; 18 | let mut res = [{$($rest)+}; $size]; /* fill every slot with the value computed for $sq = 0, then overwrite each index in the loop below */ 19 | while $sq < $size { 20 | res[$sq] = {$($rest)+}; 21 | $sq += 1; 22 | } 23 | res 24 | }}; 25 | } 26 | 27 | #[macro_export] 28 | macro_rules! pop_lsb { 29 | ($idx:ident, $x:expr) => { 30 | let $idx = $x.trailing_zeros() as u16; /* index of the lowest set bit */ 31 | $x &= $x - 1 /* clear the lowest set bit; NOTE(review): $x must be non-zero, since for an unsigned $x the subtraction would underflow */ 32 | }; 33 | } 34 | 35 | /// # Safety 36 | /// Object must be valid if fully zeroed. 37 | pub unsafe fn boxed_and_zeroed() -> Box { 38 | unsafe { 39 | let layout = std::alloc::Layout::new::(); 40 | let ptr = std::alloc::alloc_zeroed(layout); /* NOTE(review): alloc_zeroed with a zero-sized layout is undefined behaviour, so this also assumes the pointee type is not zero-sized */ 41 | if ptr.is_null() { 42 | std::alloc::handle_alloc_error(layout); 43 | } 44 | Box::from_raw(ptr.cast()) 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /src/value.rs: -------------------------------------------------------------------------------- 1 | const SCALE: i32 = 400; 2 | const QA: i32 = 255; 3 | const QB: i32 = 64; 4 | const QAB: i32 = QA * QB; 5 | 6 | #[repr(C, align(64))] 7 | pub struct ValueNetwork { 8 | l1_weights: [Accumulator; INPUT], 9 | l1_bias: Accumulator, 10 | l2_weights: Accumulator, 11 | l2_bias: i16, 12 | } 13 | 14 | pub trait ValueFeatureMap { 15 | fn value_feature_map(&self, f: F); 16 | } 17 | 18 | #[derive(Clone, Copy)] 19 | #[repr(C)] 20 | struct Accumulator { 21 | vals: [i16; HIDDEN], 22 | } 23 | 24 | #[inline] 25 | fn screlu(x: i16) -> i32 { 26 | i32::from(x).clamp(0, QA).pow(2) 27 | } 28 | 29 | impl ValueNetwork