├── .github └── workflows │ └── ci.yaml ├── .gitignore ├── .python-version ├── .vscode ├── launch.json ├── settings.json └── tasks.json ├── LICENSE.md ├── README.md ├── images ├── learning.png └── tui.png ├── pyproject.toml ├── rust ├── .gitignore ├── Cargo.lock ├── Cargo.toml ├── proptest-regressions │ ├── c4r.txt │ ├── mcts.txt │ └── solver.txt └── src │ ├── c4r.rs │ ├── interactive_play.rs │ ├── lib.rs │ ├── mcts.rs │ ├── pybridge.rs │ ├── self_play.rs │ ├── solver.rs │ ├── tui.rs │ ├── types.rs │ └── utils.rs ├── src └── c4a0 │ ├── __init__.py │ ├── explore.ipynb │ ├── main.py │ ├── nn.py │ ├── sweep.py │ ├── tournament.py │ ├── training.py │ └── utils.py ├── tests └── c4a0_tests │ ├── __init__.py │ ├── conftest.py │ ├── nn_test.py │ └── tournament_test.py └── uv.lock /.github/workflows/ci.yaml: -------------------------------------------------------------------------------- 1 | name: c4a0 CI 2 | 3 | on: 4 | # Triggers the workflow on push events for all branches 5 | push: 6 | 7 | # Allows you to run this workflow manually from the Actions tab 8 | workflow_dispatch: 9 | 10 | jobs: 11 | ci: 12 | runs-on: ubuntu-latest 13 | 14 | steps: 15 | - uses: actions/checkout@v4 16 | 17 | # Rust actions 18 | - name: install rust 19 | run: rustup update stable && rustup default stable 20 | - name: cargo test 21 | working-directory: rust 22 | run: cargo test 23 | 24 | # Python actions 25 | - uses: actions/setup-python@v4 26 | - uses: yezz123/setup-uv@v4 27 | - run: uv sync 28 | - run: uv run ruff check 29 | - run: uv run maturin build --release 30 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .aider* 2 | 3 | .pytest_cache/ 4 | __pycache__/ 5 | .ruff_cache/ 6 | 7 | lightning_logs/ 8 | training/ 9 | training-sweeps/ 10 | solver/ 11 | 12 | *.bak 13 | *.db 14 | *.log 15 | *.pkl 16 | -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.11.6 2 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | "configurations": [ 3 | { 4 | "name": "Python: src/main.py", 5 | "type": "debugpy", 6 | "request": "launch", 7 | "program": "rye run python src/main.py", 8 | "console": "integratedTerminal", 9 | "justMyCode": true 10 | } 11 | ] 12 | } 13 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.testing.pytestArgs": [ 3 | "tests" 4 | ], 5 | "python.testing.unittestEnabled": false, 6 | "python.testing.pytestEnabled": true, 7 | "python.analysis.extraPaths": [ 8 | "./src" 9 | ], 10 | "python.analysis.typeCheckingMode": "basic", 11 | "rust-analyzer.linkedProjects": [ 12 | "rust/Cargo.toml" 13 | ] 14 | } 15 | -------------------------------------------------------------------------------- /.vscode/tasks.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "2.0.0", 3 | "tasks": [ 4 | { 5 | "label": "Run Python Script", 6 | "type": "shell", 7 | "command": "rye", 8 | "args": ["run", "python", "src/c4a0/main.py", "train"], 9 | "group": { 10 | "kind": "build", 11 | "isDefault": true 12 | }, 13 | "problemMatcher": [], 14 | 
"options": { 15 | "env": { 16 | "PYTHONPATH": "${workspaceFolder}/src" 17 | } 18 | } 19 | } 20 | ] 21 | } 22 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Copyright 2024 Advait Shinde 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and 4 | associated documentation files (the “Software”), to deal in the Software without restriction, 5 | including without limitation the rights to use, copy, modify, merge, publish, distribute, 6 | sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is 7 | furnished to do so, subject to the following conditions: 8 | 9 | The above copyright notice and this permission notice shall be included in all copies or substantial 10 | portions of the Software. 11 | 12 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT 13 | NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 14 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES 15 | OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 16 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 17 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # c4a0: Connect Four Alpha-Zero 2 | 3 | ![CI](https://github.com/advait/c4a0/actions/workflows/ci.yaml/badge.svg?ts=2) 4 | 5 | An Alpha-Zero-style Connect Four engine trained entirely via self play. 6 | 7 | The game logic, Monte Carlo Tree Search, and multi-threaded self play engine is written in rust 8 | [here](https://github.com/advait/c4a0/tree/master/rust). 9 | 10 | The NN is written in Python/PyTorch [here](https://github.com/advait/c4a0/tree/master/src/c4a0?ts=2) 11 | and interfaces with rust via [PyO3](https://pyo3.rs/v0.22.2/) 12 | 13 | ![Terminal UI](https://raw.githubusercontent.com/advait/c4a0/refs/heads/master/images/tui.png) 14 | 15 | ## Usage 16 | 17 | 1. Install clang 18 | ```sh 19 | # Instructions for Ubuntu/Debian (other OSs may vary) 20 | sudo apt install clang 21 | ``` 22 | 23 | 2. Install [uv](https://docs.astral.sh/uv/getting-started/installation/) for python dep/env management 24 | ```sh 25 | curl -LsSf https://astral.sh/uv/install.sh | sh 26 | ``` 27 | 28 | 3. Install deps and create virtual env: 29 | ```sh 30 | uv sync 31 | ``` 32 | 33 | 4. Compile rust code 34 | ```sh 35 | uv run maturin develop --release 36 | ``` 37 | 38 | 4. Train a network 39 | ```sh 40 | uv run src/c4a0/main.py train --max-gens=10 41 | ``` 42 | 43 | 5. Play against the network 44 | ```sh 45 | uv run src/c4a0/main.py play --model=best 46 | ``` 47 | 48 | 6. 
(Optional) Download a [connect four solver](https://github.com/PascalPons/connect4?ts=2) to
49 | objectively measure training progress:
50 | ```
51 | git clone https://github.com/PascalPons/connect4.git solver
52 | cd solver
53 | make
54 | # Download opening book to speed up solutions
55 | wget https://github.com/PascalPons/connect4/releases/download/book/7x6.book
56 | ```
57 |
58 | Now pass the solver paths to `train`, `score`, and other commands:
59 | ```
60 | uv run python src/c4a0/main.py score solver/c4solver solver/7x6.book
61 | ```
62 |
63 | ## Results
64 | After 9 generations of training (approx. 15 min on an RTX 3090) we achieve the following results:
65 |
66 | ![Training Results](https://raw.githubusercontent.com/advait/c4a0/refs/heads/master/images/learning.png)
67 |
68 | ## Architecture
69 |
70 | ### PyTorch NN [`src/c4a0/nn.py`](https://github.com/advait/c4a0/blob/master/src/c4a0/nn.py?ts=2)
71 |
72 | A resnet-style CNN that takes a board position as input and outputs a Policy (probability
73 | distribution over moves weighted by promise) and a Q Value (predicted win/loss value in [-1, 1]).
74 |
75 | Various NN hyperparameters are sweepable via the `nn-sweep` command.
76 |
77 | ### Connect Four Game Logic [`rust/src/c4r.rs`](https://github.com/advait/c4a0/blob/master/rust/src/c4r.rs?ts=2)
78 |
79 | Implements a compact bitboard representation of board state (`Pos`) and all connect four rules
80 | and game logic.
81 |
82 | ### Monte Carlo Tree Search (MCTS) [`rust/src/mcts.rs`](https://github.com/advait/c4a0/blob/master/rust/src/mcts.rs?ts=2)
83 |
84 | Implements Monte Carlo Tree Search, the core algorithm behind Alpha-Zero. Probabilistically
85 | explores potential game pathways and hones in on the optimal move to play from any
86 | position.
87 |
88 | MCTS relies on outputs from the NN. The output of MCTS helps train the next generation's NN.
89 |
90 | ### Self Play [`rust/src/self_play.rs`](https://github.com/advait/c4a0/blob/master/rust/src/self_play.rs?ts=2)
91 |
92 | Uses rust multi-threading to parallelize self play (training data generation).
93 |
94 | ### Solver [`rust/src/solver.rs`](https://github.com/advait/c4a0/blob/master/rust/src/solver.rs?ts=2)
95 |
96 | Connect Four is a perfectly solved game. See Pascal Pons's [great
97 | writeup](http://blog.gamesolver.org/) on how to build a perfect solver. We can use these solutions
98 | to objectively measure our NN's performance. Importantly, we **never train on these solutions**,
99 | instead only using our self-play data to improve the NN's performance.
100 |
101 | `solver.rs` contains the stdin/stdout interface used to learn the objective solutions to our training
102 | positions. Because solutions are expensive to compute, we cache them in a local
103 | [rocksdb](https://docs.rs/rocksdb/latest/rocksdb/) database (`solutions.db`). We then measure how
104 | often the policies at our training positions recommend optimal moves as determined by the solver.
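The scoring step is conceptually simple: for each recorded training position, take the column that the stored MCTS policy ranks highest and check whether it is in the solver's set of optimal moves, then report the fraction of positions where the two agree. A minimal sketch of that comparison (hypothetical helper names and types, not the actual `solver.rs` interface):

```rust
/// Hypothetical sketch: returns the column with the highest policy weight.
fn argmax(policy: &[f32; 7]) -> usize {
    let mut best = 0;
    for col in 1..policy.len() {
        if policy[col] > policy[best] {
            best = col;
        }
    }
    best
}

/// True if the policy's top-ranked column is one of the solver-optimal columns.
fn policy_is_optimal(policy: &[f32; 7], optimal_cols: &[usize]) -> bool {
    optimal_cols.contains(&argmax(policy))
}

/// Fraction of (policy, solver-optimal-moves) pairs that agree.
/// Assumes `samples` is non-empty.
fn optimal_move_rate(samples: &[([f32; 7], Vec<usize>)]) -> f32 {
    let hits = samples
        .iter()
        .filter(|(policy, optimal)| policy_is_optimal(policy, optimal))
        .count();
    hits as f32 / samples.len() as f32
}
```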
105 | -------------------------------------------------------------------------------- /images/learning.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/advait/c4a0/49cb6584fd4cc31e68f056f210faff9bac8823cf/images/learning.png -------------------------------------------------------------------------------- /images/tui.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/advait/c4a0/49cb6584fd4cc31e68f056f210faff9bac8823cf/images/tui.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "c4a0" 3 | version = "0.1.0" 4 | description = "" 5 | authors = [{ name = "Advait Shinde", email = "advait.shinde@gmail.com" }] 6 | dependencies = [ 7 | "pytorch-lightning>=2.5.0", 8 | "numpy>=2.2.3", 9 | "einops>=0.8.1", 10 | "torchmetrics>=1.6.2", 11 | "tensorboardx>=2.6.2.2", 12 | "tensorboard>=2.19.0", 13 | "tqdm>=4.67.1", 14 | "jupyterlab>=4.3.5", 15 | "pydantic>=2.10.6", 16 | "tabulate>=0.9.0", 17 | "matplotlib>=3.10.1", 18 | "maturin>=1.8.2", 19 | "loguru>=0.7.3", 20 | "typer>=0.15.2", 21 | "optuna>=4.2.1", 22 | "optuna-dashboard>=0.17.0", 23 | ] 24 | readme = "README.md" 25 | requires-python = ">= 3.11" 26 | 27 | [build-system] 28 | requires = ["maturin>=1,<2"] 29 | build-backend = "maturin" 30 | 31 | [tool.uv] 32 | managed = true 33 | dev-dependencies = [ 34 | "pytest>=8.3.5", 35 | "pytest-asyncio>=0.25.3", 36 | "pyright>=1.1.396", 37 | "jupyterlab>=4.0.10", 38 | "pandas>=2.1.4", 39 | "rankit>=0.3.3", 40 | "ruff>=0.9.9", 41 | ] 42 | 43 | [tool.pyright] 44 | typeCheckingMode = "basic" 45 | extraPaths = ["src", "tests"] 46 | 47 | [tool.pytest.ini_options] 48 | pythonpath = ["src"] 49 | testpaths = ["tests"] 50 | filterwarnings = [ 51 | # Disable warnings we get form pytorch lightning that clutter pytest output 52 | "ignore:pkg_resources is deprecated.*:DeprecationWarning", 53 | "ignore:Deprecated call to `pkg_resources.declare_namespace`.*:DeprecationWarning", 54 | "ignore:Deprecated call to `pkg_resources.declare_namespace.*:DeprecationWarning", 55 | "ignore:You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet.*", 56 | ] 57 | -------------------------------------------------------------------------------- /rust/.gitignore: -------------------------------------------------------------------------------- 1 | target/ 2 | -------------------------------------------------------------------------------- /rust/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "c4a0_rust" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | [lib] 7 | name = "c4a0_rust" 8 | crate-type = ["cdylib"] 9 | 10 | [dependencies] 11 | pyo3 = { version = "0.21.2", features = ["extension-module"] } 12 | approx = "0.5.1" 13 | more-asserts = "0.3.1" 14 | crossbeam = "0.8.4" 15 | crossbeam-channel = "0.5.13" 16 | num_cpus = "1.16.0" 17 | rand = "0.8.5" 18 | indicatif = "0.17.8" 19 | numpy = "0.21.0" 20 | serde = { version = "1.0.204", features = ["derive"] } 21 | serde_cbor = "0.11.2" 22 | ratatui = "0.27.0" 23 | parking_lot = "0.12.3" 24 | log = "0.4.22" 25 | env_logger = "0.11.5" 26 | rocksdb = "0.22.0" 27 | 28 | [dev-dependencies] 29 | proptest = "1.5.0" 30 | -------------------------------------------------------------------------------- 
/rust/proptest-regressions/c4r.txt: -------------------------------------------------------------------------------- 1 | # Seeds for failure cases proptest has generated in the past. It is 2 | # automatically read and these particular cases re-run before any 3 | # novel cases are generated. 4 | # 5 | # It is recommended to check this file in to source control so that 6 | # everyone who runs the test benefits from these saved cases. 7 | cc f144b09895d55faeaa53836e51932532269c4177695e3cdcff819fed145760a1 # shrinks to pos = 🔵⚫⚫⚫⚫⚫⚫ 🔴⚫⚫⚫⚫⚫⚫ 🔵⚫⚫⚫⚫⚫⚫ 🔴⚫⚫⚫⚫⚫⚫ 🔵⚫⚫⚫⚫⚫⚫ 🔴⚫🔵⚫⚫⚫⚫ mask: 0000000000000000000000000000100000010000001000000100000010000101 value: 0000000000000000000000000000000000010000000000000100000000000001 8 | cc f2c9ed96feb1929fc6be7b6f4ceb7b4bb2945cb31f9413338ca96d4e791dafee # shrinks to pos = 🔵⚫⚫⚫⚫⚫⚫ 🔴⚫⚫⚫⚫⚫⚫ 🔵⚫⚫⚫⚫⚫⚫ 🔴⚫⚫⚫⚫⚫🔵 🔵🔴⚫⚫⚫⚫🔵 🔴🔵🔴⚫⚫⚫🔴 mask: 0000000000000000000000000000100000010000001100000110000111000111 value: 0000000000000000000000000000000000010000000000000100000101000101 9 | -------------------------------------------------------------------------------- /rust/proptest-regressions/mcts.txt: -------------------------------------------------------------------------------- 1 | # Seeds for failure cases proptest has generated in the past. It is 2 | # automatically read and these particular cases re-run before any 3 | # novel cases are generated. 4 | # 5 | # It is recommended to check this file in to source control so that 6 | # everyone who runs the test benefits from these saved cases. 7 | cc 40e16b9b688b505df9abaa47b5f4776d343e4b7b1e64b74d269df92d58f2cfb2 # shrinks to policy = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] 8 | cc 19bd609bfb408cf67e793e846d64293283bb382e095250405f5ddaecc76bd3bf # shrinks to policy_log = [0.0, 0.0, -6.872888e19, 0.0, 0.0, 0.0, 0.0] 9 | cc d17aeb3793210e7443418e6a8b852d22c01eddc1225c6d644bcf0c117c2e1c98 # shrinks to policy = [0.9286058, 0.0, 0.0033046294, 0.06687763, 0.0, 0.0, 0.001211846] 10 | cc b4b48a3dbc95e9b50dd37dcd6a6a89ddd75347d26bc888ac92e4086e0481a8e3 # shrinks to policy = [0.4780801, 2.5148089e-5, 2.5148089e-5, 0.52179414, 2.5148089e-5, 2.5148089e-5, 2.5148089e-5] 11 | cc db28139dac7e220d26e3b4a97b97792ff1efc2f86afd17a6a4d3ef0040701844 # shrinks to policy = [0.0, 0.933416, 0.00035163847, 0.0009350313, 0.0, 0.06520966, 8.7709726e-5] 12 | cc 19b63b05c1a9a12558899ca40e18023e4be6da5c0be577fc65589720bd59a4d9 # shrinks to policy = [0.0, 0.106206864, 0.0, 0.0, 0.148644, 0.7410872, 0.004062006] 13 | -------------------------------------------------------------------------------- /rust/proptest-regressions/solver.txt: -------------------------------------------------------------------------------- 1 | # Seeds for failure cases proptest has generated in the past. It is 2 | # automatically read and these particular cases re-run before any 3 | # novel cases are generated. 4 | # 5 | # It is recommended to check this file in to source control so that 6 | # everyone who runs the test benefits from these saved cases. 
7 | cc 2540f4e7d164d44b26da0fb0d76df80bf87d8a43b10e8a5eb9dbbe5f78313814 # shrinks to pos = ⚫⚫⚫⚫⚫⚫⚫ ⚫⚫⚫⚫⚫⚫⚫ ⚫⚫⚫⚫⚫⚫⚫ ⚫⚫⚫⚫⚫⚫⚫ ⚫⚫⚫⚫⚫⚫⚫ ⚫⚫⚫⚫⚫⚫⚫ mask: 0000000000000000000000000000000000000000000000000000000000000000 value: 0000000000000000000000000000000000000000000000000000000000000000 8 | cc c580718d3fb4c6a50604f149054a9409a4d353082e9ad7147c4c1372662bc333 # shrinks to pos = ⚫⚫⚫⚫⚫⚫⚫ ⚫⚫⚫⚫⚫⚫⚫ ⚫⚫⚫⚫⚫⚫⚫ ⚫⚫⚫⚫⚫⚫⚫ ⚫⚫⚫⚫⚫⚫⚫ ⚫⚫⚫⚫⚫⚫⚫ mask: 0000000000000000000000000000000000000000000000000000000000000000 value: 0000000000000000000000000000000000000000000000000000000000000000 9 | -------------------------------------------------------------------------------- /rust/src/c4r.rs: -------------------------------------------------------------------------------- 1 | use core::fmt; 2 | use std::{array::from_fn, fmt::Display}; 3 | 4 | use more_asserts::debug_assert_gt; 5 | use serde::{Deserialize, Serialize}; 6 | 7 | use crate::types::{Policy, QValue}; 8 | 9 | /// Connect four position. 10 | /// Internally consists of a u64 mask (bitmask representing whether a piece exists at a given 11 | /// location) and a u64 value (bitmask representing the color of the given piece). 12 | /// Bit indexing is specified by [Pos::_idx_mask_unsafe]. 13 | #[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] 14 | pub struct Pos { 15 | mask: u64, 16 | value: u64, 17 | } 18 | 19 | /// The oponnent/player token within a cell. 20 | #[derive(Debug, PartialEq, Eq, Clone, Copy)] 21 | pub enum CellValue { 22 | Opponent = 0, 23 | Player = 1, 24 | } 25 | 26 | /// Possible terminal states of a connect four game. 27 | #[derive(Debug, PartialEq, Eq, Clone, Copy)] 28 | pub enum TerminalState { 29 | PlayerWin, 30 | OpponentWin, 31 | Draw, 32 | } 33 | 34 | /// The column for a given move (0..[Pos::N_COLS]) 35 | pub type Move = usize; 36 | 37 | impl Default for Pos { 38 | fn default() -> Self { 39 | Pos { mask: 0, value: 0 } 40 | } 41 | } 42 | 43 | impl Pos { 44 | pub const N_ROWS: usize = 6; 45 | pub const N_COLS: usize = 7; 46 | 47 | /// The number of channels in the numpy buffer (one per player) 48 | pub const BUF_N_CHANNELS: usize = 2; 49 | /// The length of a single channel (in # of f32s) of the numpy buffer 50 | pub const BUF_CHANNEL_LEN: usize = Self::N_ROWS * Self::N_COLS; 51 | /// The required length (in # of f32s) of the numpy buffer 52 | pub const BUF_LEN: usize = Self::BUF_N_CHANNELS * Self::BUF_CHANNEL_LEN; 53 | 54 | /// Plays a move in the given column from the perspective of the [CellValue::Player]. 55 | /// Returns a new position where the cell values are flipped. 56 | /// Performs bounds and collision checing. 57 | /// DOES NOT perform win checking. 58 | pub fn make_move(&self, col: Move) -> Option { 59 | if col > Self::N_COLS { 60 | return None; 61 | } 62 | 63 | for row in 0..Self::N_ROWS { 64 | let idx = Self::_idx_mask_unsafe(row, col); 65 | if (idx & self.mask) == 0 { 66 | let mut ret = self.clone(); 67 | ret._set_piece_unsafe(row, col, Some(CellValue::Player)); 68 | return Some(ret.invert()); 69 | } 70 | } 71 | None 72 | } 73 | 74 | /// Returns the value of cell at the given position. 75 | /// Performs bounds checking. 
76 | pub fn get(&self, row: usize, col: usize) -> Option { 77 | if col > Self::N_COLS || row > Self::N_ROWS { 78 | return None; 79 | } 80 | let idx = Self::_idx_mask_unsafe(row, col); 81 | 82 | if (self.mask & idx) == 0 { 83 | return None; 84 | } 85 | 86 | if (self.value & idx) == 0 { 87 | Some(CellValue::Opponent) 88 | } else { 89 | Some(CellValue::Player) 90 | } 91 | } 92 | 93 | /// Returns the ply of the position or the number of moves that have been played. 94 | /// Ply of 0 is the starting position. 95 | pub fn ply(&self) -> usize { 96 | u64::count_ones(self.mask).try_into().unwrap() 97 | } 98 | 99 | /// Mutably sets the given piece without any bounds or collision checking. 100 | fn _set_piece_unsafe(&mut self, row: usize, col: usize, piece: Option) { 101 | let idx_mask = Self::_idx_mask_unsafe(row, col); 102 | match piece { 103 | Some(CellValue::Opponent) => { 104 | self.mask |= idx_mask; 105 | self.value &= !idx_mask; 106 | } 107 | Some(CellValue::Player) => { 108 | self.mask |= idx_mask; 109 | self.value |= idx_mask; 110 | } 111 | None => { 112 | self.mask &= !idx_mask; 113 | self.value &= !idx_mask; 114 | } 115 | }; 116 | } 117 | 118 | /// Returns a single bit for the given row and column. 119 | const fn _idx_mask_unsafe(row: usize, col: usize) -> u64 { 120 | let idx = row * Self::N_COLS + col; 121 | 0b1 << idx 122 | } 123 | 124 | /// Inverts the colors of this position. 125 | pub fn invert(mut self) -> Pos { 126 | self.value = !self.value; 127 | self.value &= self.mask; 128 | self 129 | } 130 | 131 | /// Generates a horizontal win mask starting from the given cell. 132 | const fn _gen_h_win_mask(row: usize, col: usize) -> u64 { 133 | Self::_idx_mask_unsafe(row, col) 134 | | Self::_idx_mask_unsafe(row, col + 1) 135 | | Self::_idx_mask_unsafe(row, col + 2) 136 | | Self::_idx_mask_unsafe(row, col + 3) 137 | } 138 | 139 | /// Generates a vertical win mask starting from the given cell. 140 | const fn _gen_v_win_mask(row: usize, col: usize) -> u64 { 141 | Self::_idx_mask_unsafe(row, col) 142 | | Self::_idx_mask_unsafe(row + 1, col) 143 | | Self::_idx_mask_unsafe(row + 2, col) 144 | | Self::_idx_mask_unsafe(row + 3, col) 145 | } 146 | 147 | /// Generates a diagonal (top-left to bottom-right) win mask starting from the given cell. 148 | const fn _gen_d1_win_mask(row: usize, col: usize) -> u64 { 149 | Self::_idx_mask_unsafe(row, col) 150 | | Self::_idx_mask_unsafe(row + 1, col + 1) 151 | | Self::_idx_mask_unsafe(row + 2, col + 2) 152 | | Self::_idx_mask_unsafe(row + 3, col + 3) 153 | } 154 | 155 | /// Generates a diagonal (bottom-left to top-right) win mask starting from the given cell. 156 | const fn _gen_d2_win_mask(row: usize, col: usize) -> u64 { 157 | Self::_idx_mask_unsafe(row, col) 158 | | Self::_idx_mask_unsafe(row - 1, col + 1) 159 | | Self::_idx_mask_unsafe(row - 2, col + 2) 160 | | Self::_idx_mask_unsafe(row - 3, col + 3) 161 | } 162 | 163 | /// Represents the set of all possible wins. 164 | /// Each item is a bitmask representing the required locations of consecutive pieces. 
165 | const WIN_MASKS: [u64; 69] = { 166 | // Note rust doesn't support for loops in const functions so we have to resort to while: 167 | // See: https://github.com/rust-lang/rust/issues/87575 168 | 169 | let mut masks = [0u64; 69]; 170 | let mut index = 0; 171 | 172 | // Horizontal wins 173 | let mut row = 0; 174 | while row < Self::N_ROWS { 175 | let mut col = 0; 176 | while col <= Self::N_COLS - 4 { 177 | masks[index] = Self::_gen_h_win_mask(row, col); 178 | index += 1; 179 | col += 1; 180 | } 181 | row += 1; 182 | } 183 | 184 | // Vertical wins 185 | let mut col = 0; 186 | while col < Self::N_COLS { 187 | let mut row = 0; 188 | while row <= Self::N_ROWS - 4 { 189 | masks[index] = Self::_gen_v_win_mask(row, col); 190 | index += 1; 191 | row += 1; 192 | } 193 | col += 1; 194 | } 195 | 196 | // Diagonal (top-left to bottom-right) wins 197 | row = 0; 198 | while row <= Self::N_ROWS - 4 { 199 | let mut col = 0; 200 | while col <= Self::N_COLS - 4 { 201 | masks[index] = Self::_gen_d1_win_mask(row, col); 202 | index += 1; 203 | col += 1; 204 | } 205 | row += 1; 206 | } 207 | 208 | // Diagonal (bottom-left to top-right) wins 209 | row = 3; 210 | while row < Self::N_ROWS { 211 | let mut col = 0; 212 | while col <= Self::N_COLS - 4 { 213 | masks[index] = Self::_gen_d2_win_mask(row, col); 214 | index += 1; 215 | col += 1; 216 | } 217 | row += 1; 218 | } 219 | 220 | if index != 69 { 221 | panic!("expected 69 win masks"); 222 | } 223 | masks 224 | }; 225 | 226 | /// Determines if the game is over, and if so, who won. 227 | /// If the game is not over, returns None. 228 | pub fn is_terminal_state(&self) -> Option { 229 | if self._is_terminal_for_player() { 230 | Some(TerminalState::PlayerWin) 231 | } else if self.clone().invert()._is_terminal_for_player() { 232 | Some(TerminalState::OpponentWin) 233 | } else if self.ply() == Self::N_COLS * Self::N_ROWS { 234 | Some(TerminalState::Draw) 235 | } else { 236 | None 237 | } 238 | } 239 | 240 | /// Determines if the current player has won. 241 | fn _is_terminal_for_player(&self) -> bool { 242 | let player_tokens = self.mask & self.value; 243 | for win_mask in Self::WIN_MASKS { 244 | if u64::count_ones(player_tokens & win_mask) == 4 { 245 | return true; 246 | } 247 | } 248 | false 249 | } 250 | 251 | /// Returns the f32 terminal value of the position. The first value is with the ply penalty 252 | /// and the second value is wwithout the ply penalty. Returns None if the game is not over. 253 | pub fn terminal_value_with_ply_penalty(&self, c_ply_penalty: f32) -> Option<(QValue, QValue)> { 254 | let ply_penalty_magnitude = c_ply_penalty * self.ply() as f32; 255 | self.is_terminal_state().map(|t| match t { 256 | // If the player wins, we apply a penalty to encourage shorter wins 257 | TerminalState::PlayerWin => (1.0 - ply_penalty_magnitude, 1.0), 258 | // If the player loses, we apply a penalty to encourage more drawn out games 259 | TerminalState::OpponentWin => (-1.0 + ply_penalty_magnitude, -1.0), 260 | // Drawn games do not have any ply penalty 261 | TerminalState::Draw => (0.0, 0.0), 262 | }) 263 | } 264 | 265 | /// Indicates which moves (columns) are legal to play. 266 | pub fn legal_moves(&self) -> [bool; Self::N_COLS] { 267 | let top_row = Self::N_ROWS - 1; 268 | from_fn(|col| self.get(top_row, col).is_none()) 269 | } 270 | 271 | /// Mask the policy logprobs by setting illegal moves to [f32::NEG_INFINITY]. 
272 | pub fn mask_policy(&self, policy_logprobs: &mut Policy) { 273 | let legal_moves = self.legal_moves(); 274 | debug_assert_gt!( 275 | { legal_moves.iter().filter(|&&m| m).count() }, 276 | 0, 277 | "no legal moves in leaf node" 278 | ); 279 | 280 | // Mask policy for illegal moves and softmax 281 | for mov in 0..Pos::N_COLS { 282 | if !legal_moves[mov] { 283 | policy_logprobs[mov] = f32::NEG_INFINITY; 284 | } 285 | } 286 | } 287 | 288 | /// Returns a new [Pos] that is horizonitally flipped. 289 | pub fn flip_h(&self) -> Pos { 290 | let mut ret = Pos::default(); 291 | (0..Pos::N_ROWS).for_each(|row| { 292 | (0..Pos::N_COLS).for_each(|col| { 293 | if let Some(piece) = self.get(row, col) { 294 | ret._set_piece_unsafe(row, Pos::N_COLS - 1 - col, Some(piece)); 295 | } 296 | }) 297 | }); 298 | ret 299 | } 300 | 301 | /// Returns a list of moves that can be played to reach the given position. 302 | /// Note this might not be the actual move sequence that was played. 303 | /// This move sequence can be used to pass our [Pos] states to external solvers for evaluation. 304 | pub fn to_moves(&self) -> Vec { 305 | self.to_moves_rec(self.clone(), Vec::new()) 306 | .expect(format!("failed to generate moves for pos:\n{}", self).as_str()) 307 | .into_iter() 308 | .rev() 309 | .collect() 310 | } 311 | 312 | /// Returns a [Pos] from a list of moves. Panics if the moves are invalid. 313 | pub fn from_moves(moves: &[Move]) -> Pos { 314 | let mut pos = Pos::default(); 315 | for &mov in moves { 316 | pos = pos.make_move(mov).unwrap(); 317 | } 318 | pos 319 | } 320 | 321 | /// Helper function for [Self::to_moves] that attempts to recursively remove pieces from the top 322 | /// of the `temp` board until it is empty, then returns the [Move]s representing the removals. 323 | /// 324 | /// We can't remove pieces in a greedy way as that might result in "trapped" pieces. As such, 325 | /// we have to recursively backtrack and try removing pieces in a different order until we find 326 | /// an order that results in an empty board. 327 | fn to_moves_rec(&self, temp: Pos, moves: Vec) -> Option> { 328 | if temp.ply() == 0 { 329 | return Some(moves); 330 | } 331 | 332 | // Whether we are remove player 0's piece or player 1's piece 333 | let removing_p0_piece = (self.ply() % 2 == 0) ^ (temp.ply() % 2 == 0); 334 | 335 | 'next_col: for col in 0..Self::N_COLS { 336 | 'next_row: for row in (0..Self::N_ROWS).rev() { 337 | let self_piece = self.get(row, col); 338 | let temp_piece = temp.get(row, col); 339 | let should_remove_piece = if removing_p0_piece { 340 | (self_piece, temp_piece) == (Some(CellValue::Player), Some(CellValue::Player)) 341 | } else { 342 | (self_piece, temp_piece) 343 | == (Some(CellValue::Opponent), Some(CellValue::Opponent)) 344 | }; 345 | 346 | if should_remove_piece { 347 | let mut temp = temp.clone(); 348 | temp._set_piece_unsafe(row, col, None); 349 | let mut moves = moves.clone(); 350 | moves.push(col); 351 | 352 | // Recursively try to continue removing pieces, or if that fails, 353 | // try the next column instead. 354 | if let Some(ret) = self.to_moves_rec(temp, moves) { 355 | return Some(ret); 356 | } else { 357 | continue 'next_col; 358 | } 359 | } else if temp_piece.is_none() { 360 | // Already removed this piece from temp, continue to next row down 361 | continue 'next_row; 362 | } else { 363 | // No more eligible pieces in this column, continue to the next column 364 | continue 'next_col; 365 | } 366 | } 367 | } 368 | 369 | // Failed to successfully remove all pieces (i.e. 
stuck pieces remain). 370 | // Return None to enable the caller to backtrack. 371 | None 372 | } 373 | 374 | /// Writes the position to a buffer intended to be interpreted as a [numpy] array. 375 | /// The final array is of shape (2, 6, 7) where the first dim represents player/opponent, 376 | /// the second dim represents rows, and the final dim represents columns. The data is written 377 | /// in row-major format. 378 | pub fn write_numpy_buffer(&self, buf: &mut [f32]) { 379 | assert_eq!(buf.len(), Self::BUF_LEN); 380 | (0..Self::BUF_N_CHANNELS).for_each(|player| { 381 | (0..Self::N_ROWS).for_each(|row| { 382 | (0..Self::N_COLS).for_each(|col| { 383 | let idx = player * Self::BUF_CHANNEL_LEN + row * Self::N_COLS + col; 384 | buf[idx] = match self.get(row, col) { 385 | Some(CellValue::Player) if player == 0 => 1.0, 386 | Some(CellValue::Opponent) if player == 1 => 1.0, 387 | _ => 0.0, 388 | }; 389 | }); 390 | }); 391 | }) 392 | } 393 | } 394 | 395 | impl Display for Pos { 396 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 397 | let mut ret: Vec = Vec::with_capacity(Self::N_ROWS); 398 | for row in (0..Self::N_ROWS).rev() { 399 | let mut s = String::with_capacity(Pos::N_COLS); 400 | for col in 0..Self::N_COLS { 401 | let p = match self.get(row, col) { 402 | Some(CellValue::Player) => "🔴", 403 | Some(CellValue::Opponent) => "🔵", 404 | None => "⚫", 405 | }; 406 | s.push_str(p); 407 | } 408 | ret.push(s); 409 | } 410 | let ret = ret.join("\n"); 411 | write!(f, "{}", ret) 412 | } 413 | } 414 | 415 | impl From<&str> for Pos { 416 | fn from(s: &str) -> Self { 417 | let mut pos = Pos::default(); 418 | for (row, line) in s.lines().rev().enumerate() { 419 | for (col, c) in line.chars().enumerate() { 420 | let cell_value = match c { 421 | '🔴' => CellValue::Player, 422 | '🔵' => CellValue::Opponent, 423 | _ => continue, 424 | }; 425 | pos._set_piece_unsafe(row, col, Some(cell_value)); 426 | } 427 | } 428 | pos 429 | } 430 | } 431 | 432 | impl fmt::Debug for Pos { 433 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 434 | write!( 435 | f, 436 | "{}\nmask: {:064b}\nvalue: {:064b}", 437 | self.to_string(), 438 | self.mask, 439 | self.value 440 | ) 441 | } 442 | } 443 | 444 | impl From<&Vec> for Pos { 445 | fn from(moves: &Vec) -> Self { 446 | let mut pos = Pos::default(); 447 | for &mov in moves { 448 | pos = pos.make_move(mov).unwrap(); 449 | } 450 | pos 451 | } 452 | } 453 | 454 | #[cfg(test)] 455 | pub mod tests { 456 | use super::*; 457 | use proptest::prelude::*; 458 | 459 | // Test helpers for Pos 460 | impl Pos { 461 | fn test_move(&self, col: usize) -> Pos { 462 | self.make_move(col).unwrap() 463 | } 464 | 465 | fn test_moves(self, cols: &[usize]) -> Pos { 466 | let mut pos = self; 467 | for col in cols { 468 | pos = pos.test_move(*col) 469 | } 470 | pos 471 | } 472 | } 473 | 474 | #[test] 475 | fn playing_moves_works() { 476 | let mut pos = Pos::default(); 477 | for col in 0..Pos::N_COLS { 478 | for row in 0..Pos::N_ROWS { 479 | pos = pos.test_move(col); 480 | assert_eq!(pos.get(row, col), Some(CellValue::Opponent)); 481 | } 482 | 483 | // Playing here should overflow column 484 | assert!(!pos.legal_moves()[col]); 485 | assert_eq!(pos.make_move(col), None); 486 | } 487 | } 488 | 489 | #[test] 490 | fn row_win() { 491 | let pos = Pos::from_moves(&[0, 0, 1, 1, 2, 2, 3]); 492 | 493 | // Because the board is inverted, the last move results in the opponent winning 494 | assert_eq!(pos.is_terminal_state(), Some(TerminalState::OpponentWin)); 495 | } 496 | 497 | #[test] 498 | fn 
col_win() { 499 | let pos = Pos::from_moves(&[6, 0, 6, 0, 6, 0, 6]); 500 | assert_eq!(pos.is_terminal_state(), Some(TerminalState::OpponentWin)); 501 | } 502 | 503 | #[test] 504 | fn draw() { 505 | let pos = Pos::from_moves(&[ 506 | // Fill first three rows with alternating moves 507 | 0, 1, 2, 3, 4, 5, // First row 508 | 0, 1, 2, 3, 4, 5, // Second row 509 | 0, 1, 2, 3, 4, 5, // Third row 510 | // Fill fourth and fifth rows in reverse order to continue pattern 511 | 5, 4, 3, 2, 1, 0, // Fourth row 512 | 5, 4, 3, 2, 1, 0, // Fifth row 513 | 5, 4, 3, 2, 1, 0, // Sixth row 514 | // Fill the last column (column 6) to complete all rows 515 | 6, 6, 6, 6, 6, 6, // Last column full 516 | ]); 517 | 518 | // Verify if the terminal state is a draw 519 | assert_eq!(pos.is_terminal_state(), Some(TerminalState::Draw)); 520 | } 521 | 522 | #[test] 523 | fn to_str() { 524 | let pos = Pos::from_moves(&[ 525 | 0, 1, 2, 3, 4, 5, // First row 526 | 0, 1, 2, 3, 4, 5, // Second row 527 | 0, 1, 2, 3, 4, 5, // Third row 528 | 5, 4, 3, 2, 1, 0, // Fourth row 529 | 5, 4, 3, 2, 1, 0, // Fifth row 530 | 5, 4, 3, 2, 1, 0, // Sixth row 531 | 6, 6, 6, 6, 6, 6, // Last column full 532 | ]); 533 | 534 | let expected = [ 535 | "🔵🔴🔵🔴🔵🔴🔵", 536 | "🔵🔴🔵🔴🔵🔴🔴", 537 | "🔵🔴🔵🔴🔵🔴🔵", 538 | "🔴🔵🔴🔵🔴🔵🔴", 539 | "🔴🔵🔴🔵🔴🔵🔵", 540 | "🔴🔵🔴🔵🔴🔵🔴", 541 | ] 542 | .join("\n"); 543 | 544 | assert_eq!(pos.to_string(), expected); 545 | assert_eq!(Pos::from(expected.as_str()), pos); 546 | } 547 | 548 | #[test] 549 | fn legal_moves() { 550 | let mut pos = Pos::default(); 551 | assert_legal_moves(&pos, "OOOOOOO"); 552 | 553 | pos = pos.test_moves(&[ 554 | 0, 1, 2, 3, 4, 5, // First row 555 | 0, 1, 2, 3, 4, 5, // Second row 556 | 0, 1, 2, 3, 4, 5, // Third row 557 | 5, 4, 3, 2, 1, 0, // Fourth row 558 | 5, 4, 3, 2, 1, 0, // Fifth row 559 | ]); 560 | 561 | // Fill up top row 562 | assert_legal_moves(&pos, "OOOOOOO"); 563 | pos = pos.test_move(5); 564 | assert_legal_moves(&pos, "OOOOOXO"); 565 | pos = pos.test_move(4); 566 | assert_legal_moves(&pos, "OOOOXXO"); 567 | pos = pos.test_move(3); 568 | assert_legal_moves(&pos, "OOOXXXO"); 569 | pos = pos.test_move(2); 570 | assert_legal_moves(&pos, "OOXXXXO"); 571 | pos = pos.test_move(1); 572 | assert_legal_moves(&pos, "OXXXXXO"); 573 | pos = pos.test_move(0); 574 | assert_legal_moves(&pos, "XXXXXXO"); 575 | 576 | // Fill up last column 577 | pos = pos.test_moves(&[6, 6, 6, 6, 6, 6]); 578 | assert_legal_moves(&pos, "XXXXXXX"); 579 | } 580 | 581 | fn assert_legal_moves(pos: &Pos, s: &str) { 582 | let legal_moves = pos.legal_moves(); 583 | for mov in 0..Pos::N_COLS { 584 | if legal_moves[mov] && s[mov..mov + 1] != *"O" { 585 | assert!( 586 | false, 587 | "expected col {} to be legal in game\n\n{}", 588 | mov, 589 | pos.to_string() 590 | ); 591 | } else if !legal_moves[mov] && s[mov..mov + 1] != *"X" { 592 | assert!( 593 | false, 594 | "expected col {} to be illegal in game\n\n{}", 595 | mov, 596 | pos.to_string() 597 | ); 598 | } 599 | } 600 | } 601 | 602 | #[test] 603 | fn flip_h_symmetrical() { 604 | let pos = Pos::from_moves(&[3, 3, 3]); 605 | let flipped = pos.flip_h(); 606 | assert_eq!(pos, flipped); 607 | assert_eq!(pos, flipped.flip_h()); 608 | } 609 | 610 | prop_compose! { 611 | /// Strategy to generate random connect four positions. We start with a Vec of random 612 | /// columns to play in and play them in order. If any moves are invalid, we ignore them. 613 | /// This allows proptest's shrinking to undo moves to find the smallest failing case. 
614 | pub fn random_pos()(moves in prop::collection::vec(0..Pos::N_COLS, 0..500)) -> Pos { 615 | let mut pos = Pos::default(); 616 | 617 | for &mov in &moves { 618 | if pos.is_terminal_state().is_some() { 619 | break 620 | } 621 | 622 | if pos.legal_moves()[mov] { 623 | pos = pos.test_move(mov); 624 | } 625 | } 626 | 627 | pos 628 | } 629 | } 630 | 631 | proptest! { 632 | /// Double flipping the position should result in the same position. 633 | #[test] 634 | fn flip_h(pos in random_pos()) { 635 | let flipped = pos.flip_h(); 636 | assert_eq!(pos, flipped.flip_h()); 637 | } 638 | 639 | /// Converting a position to a string and back should result in the same position. 640 | #[test] 641 | fn to_from_string(pos in random_pos()) { 642 | let s = pos.to_string(); 643 | assert_eq!(Pos::from(s.as_str()), pos); 644 | } 645 | 646 | /// Generating moves from a position and converting them back should result in the same pos. 647 | #[test] 648 | fn to_moves(pos in random_pos()) { 649 | let moves = pos.to_moves(); 650 | let generated = Pos::from(&moves); 651 | assert_eq!(generated, pos); 652 | } 653 | } 654 | } 655 | -------------------------------------------------------------------------------- /rust/src/interactive_play.rs: -------------------------------------------------------------------------------- 1 | use std::sync::Arc; 2 | 3 | use parking_lot::{Mutex, MutexGuard}; 4 | 5 | use crate::{ 6 | c4r::{Move, Pos}, 7 | mcts::MctsGame, 8 | types::{EvalPosT, GameMetadata, Policy, QValue}, 9 | }; 10 | 11 | /// Enables interactive play with a game using MCTS. 12 | #[derive(Clone, Debug)] 13 | pub struct InteractivePlay { 14 | state: Arc>>, 15 | } 16 | 17 | impl InteractivePlay { 18 | pub fn new( 19 | eval_pos: E, 20 | max_mcts_iterations: usize, 21 | c_exploration: f32, 22 | c_ply_penalty: f32, 23 | ) -> Self { 24 | Self::new_from_pos( 25 | Pos::default(), 26 | eval_pos, 27 | max_mcts_iterations, 28 | c_exploration, 29 | c_ply_penalty, 30 | ) 31 | } 32 | 33 | pub fn new_from_pos( 34 | pos: Pos, 35 | eval_pos: E, 36 | max_mcts_iterations: usize, 37 | c_exploration: f32, 38 | c_ply_penalty: f32, 39 | ) -> Self { 40 | let state = State { 41 | eval_pos, 42 | game: MctsGame::new_from_pos(pos, GameMetadata::default()), 43 | max_mcts_iterations, 44 | c_exploration, 45 | c_ply_penalty, 46 | bg_thread_running: false, 47 | }; 48 | 49 | let ret = Self { 50 | state: Arc::new(Mutex::new(state)), 51 | }; 52 | ret.lock_and_ensure_bg_thread(); 53 | ret 54 | } 55 | 56 | /// Returns a snapshot of the current state of the interactive play. 57 | pub fn snapshot(&self) -> Snapshot { 58 | let state_guard = self.state.lock(); 59 | state_guard.snapshot() 60 | } 61 | 62 | /// Increases the number of MCTS iterations by the given amount. 63 | pub fn increase_mcts_iters(&self, n: usize) { 64 | let mut state_guard = self.state.lock(); 65 | state_guard.max_mcts_iterations += n; 66 | self.ensure_bg_thread(state_guard); 67 | } 68 | 69 | /// Makes the given move. 70 | pub fn make_move(&self, mov: Move) { 71 | let mut state_guard = self.state.lock(); 72 | let move_successful = state_guard.make_move(mov); 73 | if move_successful { 74 | self.ensure_bg_thread(state_guard); 75 | } 76 | } 77 | 78 | /// Makes a random move using the given temperature. 
79 | pub fn make_random_move(&self, temperature: f32) { 80 | let mut state_guard = self.state.lock(); 81 | let move_successful = state_guard.make_random_move(temperature); 82 | if move_successful { 83 | self.ensure_bg_thread(state_guard); 84 | } 85 | } 86 | 87 | /// Resets the game to the starting position. 88 | pub fn reset_game(&self) { 89 | let mut state_guard = self.state.lock(); 90 | state_guard.game.reset_game(); 91 | self.ensure_bg_thread(state_guard); 92 | } 93 | 94 | /// Undoes the last move if possible. 95 | pub fn undo_move(&self) { 96 | let mut state_guard = self.state.lock(); 97 | let move_successful = state_guard.game.undo_move(); 98 | if move_successful { 99 | self.ensure_bg_thread(state_guard); 100 | } 101 | } 102 | 103 | /// Locks the state and then ensures that the background thread is running. 104 | fn lock_and_ensure_bg_thread(&self) { 105 | let state_guard = self.state.lock(); 106 | self.ensure_bg_thread(state_guard); 107 | } 108 | 109 | /// Ensures that the background thread is running with the given lock. 110 | fn ensure_bg_thread(&self, mut state_guard: MutexGuard>) { 111 | if state_guard.bg_thread_should_stop() || state_guard.bg_thread_running { 112 | return; 113 | } 114 | 115 | state_guard.bg_thread_running = true; 116 | drop(state_guard); 117 | let state = Arc::clone(&self.state); 118 | std::thread::Builder::new() 119 | .name("mcts_bg_thread".into()) 120 | .spawn(move || loop { 121 | let mut state_guard = state.lock(); 122 | if state_guard.bg_thread_should_stop() { 123 | state_guard.bg_thread_running = false; 124 | return; 125 | } 126 | 127 | state_guard.bg_thread_tick(); 128 | }) 129 | .expect("failed to start mcts_bg_thread"); 130 | } 131 | } 132 | 133 | /// The state of the interactive play. 134 | #[derive(Debug)] 135 | struct State { 136 | eval_pos: E, 137 | game: MctsGame, 138 | max_mcts_iterations: usize, 139 | c_exploration: f32, 140 | c_ply_penalty: f32, 141 | bg_thread_running: bool, 142 | } 143 | 144 | impl State { 145 | fn snapshot(&self) -> Snapshot { 146 | let mut pos = self.game.root_pos(); 147 | let mut q_penalty = self.game.root_q_with_penalty(); 148 | let mut q_no_penalty = self.game.root_q_no_penalty(); 149 | if pos.ply() % 2 == 1 { 150 | pos = pos.invert(); 151 | q_penalty = -q_penalty; 152 | q_no_penalty = -q_no_penalty; 153 | } 154 | 155 | Snapshot { 156 | pos, 157 | policy: self.game.root_policy(), 158 | q_penalty, 159 | q_no_penalty, 160 | n_mcts_iterations: self.game.root_visit_count(), 161 | max_mcts_iterations: self.max_mcts_iterations, 162 | c_exploration: self.c_exploration, 163 | c_ply_penalty: self.c_ply_penalty, 164 | bg_thread_running: self.bg_thread_running, 165 | } 166 | } 167 | 168 | /// Makes the given move returning whether it was successfully played. 169 | pub fn make_move(&mut self, mov: Move) -> bool { 170 | let pos = &self.game.root_pos(); 171 | if pos.is_terminal_state().is_some() || !pos.legal_moves()[mov] { 172 | return false; 173 | } 174 | self.game.make_move(mov, self.c_exploration); 175 | true 176 | } 177 | 178 | /// Makes a random move using the given temperature. 179 | pub fn make_random_move(&mut self, temperature: f32) -> bool { 180 | if self.game.root_pos().is_terminal_state().is_some() { 181 | return false; 182 | } 183 | self.game.make_random_move(self.c_exploration, temperature); 184 | true 185 | } 186 | 187 | /// Returns true if the background thread should stop. 
188 | fn bg_thread_should_stop(&self) -> bool { 189 | self.game.root_visit_count() >= self.max_mcts_iterations 190 | || self.game.root_pos().is_terminal_state().is_some() 191 | } 192 | 193 | /// A single tick of the background thread. 194 | /// Performs a single MCTS iteration and updates the game state accordingly. 195 | fn bg_thread_tick(&mut self) { 196 | // TODO: Preemptively forward pass additional pos leafs and store their results in cache 197 | // to maximize GPU parallelism instead of evaluating a single pos at a time. 198 | let leaf_pos = self.game.leaf_pos(); 199 | let eval = self 200 | .eval_pos 201 | .eval_pos(0, vec![leaf_pos]) 202 | .into_iter() 203 | .next() 204 | .unwrap(); 205 | 206 | self.game.on_received_policy( 207 | eval.policy, 208 | eval.q_penalty, 209 | eval.q_no_penalty, 210 | self.c_exploration, 211 | self.c_ply_penalty, 212 | ); 213 | 214 | let snapshot = self.snapshot(); 215 | log::debug!( 216 | "bg_thread_tick finished; root_policy: {:?}\nroot_value: {:.2}", 217 | snapshot.policy, 218 | snapshot.q_penalty 219 | ); 220 | } 221 | } 222 | 223 | /// A snapshot of the current state of the interactive play. The snapshot is always from the 224 | /// perspective of Player 0 (i.e. odd plys have inverted [Pos] and [QValue] to reflect 225 | /// Player 0's perspective). 226 | #[derive(Debug)] 227 | pub struct Snapshot { 228 | pub pos: Pos, 229 | pub policy: Policy, 230 | pub q_penalty: QValue, 231 | pub q_no_penalty: QValue, 232 | pub n_mcts_iterations: usize, 233 | pub max_mcts_iterations: usize, 234 | pub c_exploration: f32, 235 | pub c_ply_penalty: f32, 236 | pub bg_thread_running: bool, 237 | } 238 | 239 | #[cfg(test)] 240 | mod tests { 241 | use more_asserts::assert_ge; 242 | 243 | use crate::{c4r::Pos, self_play::tests::UniformEvalPos}; 244 | 245 | use super::{InteractivePlay, Snapshot}; 246 | 247 | const TEST_C_EXPLORATION: f32 = 4.0; 248 | const TEST_C_PLY_PENALTY: f32 = 0.01; 249 | 250 | impl InteractivePlay { 251 | fn new_test(pos: Pos, max_mcts_iters: usize) -> InteractivePlay { 252 | InteractivePlay::new_from_pos( 253 | pos, 254 | UniformEvalPos {}, 255 | max_mcts_iters, 256 | TEST_C_EXPLORATION, 257 | TEST_C_PLY_PENALTY, 258 | ) 259 | } 260 | 261 | fn block_then_snapshot(&self) -> Snapshot { 262 | loop { 263 | let state_guard = self.state.lock(); 264 | if !state_guard.bg_thread_running { 265 | return state_guard.snapshot(); 266 | } 267 | std::thread::yield_now(); 268 | } 269 | } 270 | } 271 | 272 | #[test] 273 | fn forcing_position() { 274 | let pos = Pos::from( 275 | [ 276 | "⚫⚫⚫⚫⚫⚫⚫", 277 | "⚫⚫⚫⚫⚫⚫⚫", 278 | "⚫⚫⚫⚫⚫⚫⚫", 279 | "⚫⚫⚫⚫⚫⚫⚫", 280 | "⚫⚫🔵🔵⚫⚫⚫", 281 | "⚫⚫🔴🔴⚫⚫⚫", 282 | ] 283 | .join("\n") 284 | .as_str(), 285 | ); 286 | let play = InteractivePlay::new_test(pos, 10_000); 287 | let snapshot = play.block_then_snapshot(); 288 | let winning_moves = snapshot.policy[1] + snapshot.policy[4]; 289 | assert_ge!(winning_moves, 0.98); 290 | assert_ge!(snapshot.q_penalty, 0.91); 291 | assert_ge!(snapshot.q_no_penalty, 0.98); 292 | 293 | play.make_move(1); 294 | let snapshot = play.block_then_snapshot(); 295 | assert_ge!(snapshot.q_penalty, 0.91); 296 | assert_ge!(snapshot.q_no_penalty, 0.98); 297 | 298 | play.make_move(0); 299 | let snapshot = play.block_then_snapshot(); 300 | assert_ge!(snapshot.policy[4], 0.99); 301 | assert_ge!(snapshot.q_penalty, 0.91); 302 | assert_ge!(snapshot.q_no_penalty, 0.98); 303 | } 304 | } 305 | -------------------------------------------------------------------------------- /rust/src/lib.rs: 
-------------------------------------------------------------------------------- 1 | #![allow(dead_code)] 2 | mod c4r; 3 | mod interactive_play; 4 | mod mcts; 5 | mod pybridge; 6 | mod self_play; 7 | mod solver; 8 | mod tui; 9 | mod types; 10 | mod utils; 11 | 12 | use c4r::Pos; 13 | use env_logger::Env; 14 | use pybridge::PlayGamesResult; 15 | use pyo3::prelude::*; 16 | use types::{GameMetadata, GameResult, Sample}; 17 | 18 | /// A Python module implemented in Rust. The name of this function must match 19 | /// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to 20 | /// import the module. 21 | #[pymodule] 22 | fn c4a0_rust(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { 23 | env_logger::Builder::from_env(Env::default().default_filter_or("info")).init(); 24 | 25 | m.add("N_COLS", Pos::N_COLS)?; 26 | m.add("N_ROWS", Pos::N_ROWS)?; 27 | m.add("BUF_N_CHANNELS", Pos::BUF_N_CHANNELS)?; 28 | 29 | m.add_class::()?; 30 | m.add_class::()?; 31 | m.add_class::()?; 32 | m.add_class::()?; 33 | 34 | m.add_function(wrap_pyfunction!(pybridge::play_games, m)?)?; 35 | m.add_function(wrap_pyfunction!(pybridge::run_tui, m)?)?; 36 | 37 | Ok(()) 38 | } 39 | -------------------------------------------------------------------------------- /rust/src/mcts.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | array, 3 | cell::RefCell, 4 | rc::{Rc, Weak}, 5 | }; 6 | 7 | use rand::{ 8 | distributions::{Distribution, WeightedIndex}, 9 | rngs::StdRng, 10 | SeedableRng, 11 | }; 12 | 13 | use crate::{ 14 | c4r::{Move, Pos}, 15 | types::{policy_from_iter, GameMetadata, GameResult, ModelID, Policy, QValue, Sample}, 16 | utils::OrdF32, 17 | }; 18 | 19 | /// A single Monte Carlo Tree Search connect four game. 20 | /// We store the MCTS tree in Vec form where child pointers are indicated by NodeId (the index 21 | /// within the Vec where the given node is stored). 22 | /// The [Self::root_id] indicates the root and the [Self::leaf_id] indicates the leaf node that has 23 | /// yet to be expanded. 24 | /// [Self::make_move] allows us to play a move (updating the root node to the played child) so we 25 | /// can preserve any prior MCTS iterations that happened through that node. 26 | #[derive(Debug, Clone)] 27 | pub struct MctsGame { 28 | metadata: GameMetadata, 29 | root: Rc>, 30 | leaf: Rc>, 31 | moves: Vec, 32 | } 33 | 34 | impl Default for MctsGame { 35 | fn default() -> Self { 36 | MctsGame::new_from_pos(Pos::default(), GameMetadata::default()) 37 | } 38 | } 39 | 40 | /// SAFETY: MctsGame is Send because it doesn't have any public methods that expose the Rc/RefCell 41 | /// allowing for illegal cross-thread mutation. 42 | unsafe impl Send for MctsGame {} 43 | 44 | impl MctsGame { 45 | pub const UNIFORM_POLICY: Policy = [1.0 / Pos::N_COLS as f32; Pos::N_COLS]; 46 | 47 | /// New game with the given id and start position. 48 | pub fn new_from_pos(pos: Pos, metadata: GameMetadata) -> MctsGame { 49 | let root_node = Rc::new(RefCell::new(Node::new(pos, Weak::new(), 1.0))); 50 | MctsGame { 51 | metadata, 52 | root: Rc::clone(&root_node), 53 | leaf: root_node, 54 | moves: Vec::new(), 55 | } 56 | } 57 | 58 | /// Gets the root position - the last moved that was played. 59 | pub fn root_pos(&self) -> Pos { 60 | self.root.borrow().pos.clone() 61 | } 62 | 63 | /// Gets the leaf node position that needs to be evaluated by the NN. 
64 | pub fn leaf_pos(&self) -> Pos { 65 | self.leaf.borrow().pos.clone() 66 | } 67 | 68 | /// Gets the [ModelID] that is to play in the leaf position. The [ModelID] corresponds to which 69 | /// NN we need to call to evaluate the position. 70 | pub fn leaf_model_id_to_play(&self) -> ModelID { 71 | if self.leaf.borrow().pos.ply() % 2 == 0 { 72 | self.metadata.player0_id 73 | } else { 74 | self.metadata.player1_id 75 | } 76 | } 77 | 78 | /// Called when we receive a new policy/value from the NN forward pass for this leaf node. 79 | /// This is the heart of the MCTS algorithm: 80 | /// 1. Expands the current leaf with the given policy (if it is non-terminal) 81 | /// 2. Backpropagates up the tree with the given value (or the objective terminal value) 82 | /// 3. selects a new leaf for the next MCTS iteration. 83 | pub fn on_received_policy( 84 | &mut self, 85 | mut policy_logprobs: Policy, 86 | q_penalty: QValue, 87 | q_no_penalty: QValue, 88 | c_exploration: f32, 89 | c_ply_penalty: f32, 90 | ) { 91 | let leaf_pos = self.leaf_pos(); 92 | if let Some((q_penalty, q_no_penalty)) = 93 | leaf_pos.terminal_value_with_ply_penalty(c_ply_penalty) 94 | { 95 | // If this is a terminal state, the received policy is irrelevant. We backpropagate 96 | // the objective terminal value and select a new leaf. 97 | self.backpropagate_value(q_penalty, q_no_penalty); 98 | self.select_new_leaf(c_exploration); 99 | } else { 100 | // If this is a non-terminal state, we use the received policy to expand the leaf, 101 | // backpropagate the received value, and select a new leaf. 102 | leaf_pos.mask_policy(&mut policy_logprobs); 103 | let policy_probs = softmax(policy_logprobs); 104 | self.expand_leaf(policy_probs); 105 | self.backpropagate_value(q_penalty, q_no_penalty); 106 | self.select_new_leaf(c_exploration); 107 | } 108 | } 109 | 110 | /// Expands the the leaf by adding child nodes to it which then be eligible for exploration via 111 | /// subsequent MCTS iterations. Each child node's [Node::initial_policy_value] is determined by 112 | /// the provided policy. 113 | /// Noop for terminal nodes. 114 | fn expand_leaf(&self, policy_probs: Policy) { 115 | let leaf_pos = self.leaf_pos(); 116 | if leaf_pos.is_terminal_state().is_some() { 117 | return; 118 | } 119 | let legal_moves = leaf_pos.legal_moves(); 120 | 121 | let children: [Option>>; Pos::N_COLS] = std::array::from_fn(|m| { 122 | if legal_moves[m] { 123 | let child_pos = leaf_pos.make_move(m).unwrap(); 124 | let child = Node::new(child_pos, Rc::downgrade(&self.leaf), policy_probs[m]); 125 | Some(Rc::new(RefCell::new(child))) 126 | } else { 127 | None 128 | } 129 | }); 130 | let mut leaf = self.leaf.borrow_mut(); 131 | leaf.children = Some(children); 132 | } 133 | 134 | /// Backpropagate value up the tree, alternating value signs for each step. 135 | /// If the leaf node is a non-terminal node, the value is taken from the NN forward pass. 136 | /// If the leaf node is a terminal node, the value is the objective value of the win/loss/draw. 
137 | fn backpropagate_value(&self, mut q_penalty: QValue, mut q_no_penalty: QValue) { 138 | let mut node_ref = Rc::clone(&self.leaf); 139 | loop { 140 | let mut node = node_ref.borrow_mut(); 141 | node.visit_count += 1; 142 | node.q_sum_penalty += q_penalty; 143 | node.q_sum_no_penalty += q_no_penalty; 144 | 145 | q_penalty = -q_penalty; 146 | q_no_penalty = -q_no_penalty; 147 | 148 | if let Some(parent) = node.parent.upgrade() { 149 | drop(node); // Drop node_ref borrow so we can reassign node_ref 150 | node_ref = parent; 151 | } else { 152 | break; 153 | } 154 | } 155 | } 156 | 157 | /// Select the next leaf node by traversing from the root node, repeatedly selecting the child 158 | /// with the highest [Node::uct_value] until we reach a node with no expanded children (leaf 159 | /// node). 160 | fn select_new_leaf(&mut self, c_exploration: f32) { 161 | let mut node_ref = Rc::clone(&self.root); 162 | 163 | loop { 164 | let next = node_ref.borrow().children.as_ref().and_then(|children| { 165 | children 166 | .iter() 167 | .flatten() 168 | .max_by_key(|&child| { 169 | let score = child.borrow().uct_value(c_exploration); 170 | OrdF32(score) 171 | }) 172 | .cloned() 173 | }); 174 | 175 | if let Some(next) = next { 176 | node_ref = Rc::clone(&next) 177 | } else { 178 | break; 179 | } 180 | } 181 | 182 | self.leaf = node_ref; 183 | } 184 | 185 | /// Makes a move, updating the root node to be the child node corresponding to the move. 186 | /// Stores the previous position and policy in the [Self::moves] vector. 187 | pub fn make_move(&mut self, m: Move, c_exploration: f32) { 188 | self.moves.push(RecordedMove { 189 | pos: self.root_pos(), 190 | policy: self.root_policy(), 191 | mov: m, 192 | }); 193 | 194 | let child = { 195 | let root = self.root.borrow(); 196 | let children = root.children.as_ref().expect("root node has no children"); 197 | let child = children[m as usize] 198 | .as_ref() 199 | .expect("attempted to make an invalid move"); 200 | Rc::clone(&child) 201 | }; 202 | self.root = child; 203 | 204 | // We must select a new leaf as the old leaf might not be in the subtree of the new root 205 | self.select_new_leaf(c_exploration); 206 | } 207 | 208 | /// Makes a move probabalistically based on the root node's policy. 209 | /// Uses the game_id and ply as rng seeds for deterministic sampling. 210 | /// 211 | /// The temperature parameter scales the policy probabilities, with values > 1.0 making the 212 | /// sampled distribution more uniform and values < 1.0 making the sampled distribution favor 213 | /// the most lucrative moves. 214 | pub fn make_random_move(&mut self, c_exploration: f32, temperature: f32) { 215 | let seed = self.metadata.game_id * ((Pos::N_ROWS * Pos::N_COLS) + self.moves.len()) as u64; 216 | let mut rng = StdRng::seed_from_u64(seed); 217 | let policy = self.root_policy(); 218 | let policy = apply_temperature(&policy, temperature); 219 | let dist = WeightedIndex::new(policy).unwrap(); 220 | let mov = dist.sample(&mut rng); 221 | self.make_move(mov, c_exploration); 222 | } 223 | 224 | /// Resets the game to the starting position. 225 | pub fn reset_game(&mut self) { 226 | while self.undo_move() {} 227 | } 228 | 229 | /// Undo the last move. 
230 | pub fn undo_move(&mut self) -> bool { 231 | if self.moves.is_empty() { 232 | return false; 233 | } 234 | 235 | let mut moves = self.moves.clone(); 236 | let last_move = moves.pop().unwrap(); 237 | 238 | // last_move.pos is the previous position 239 | let root = Node::new(last_move.pos, Weak::new(), 1.0); 240 | let root = Rc::new(RefCell::new(root)); 241 | self.root = Rc::clone(&root); 242 | self.leaf = root; 243 | self.moves = moves; 244 | true 245 | } 246 | 247 | /// The number of visits to the root node. 248 | pub fn root_visit_count(&self) -> usize { 249 | self.root.borrow().visit_count 250 | } 251 | 252 | /// After performing many MCTS iterations, the resulting policy is determined by the visit count 253 | /// of each child (more visits implies more lucrative). 254 | pub fn root_policy(&self) -> Policy { 255 | self.root.borrow().policy() 256 | } 257 | 258 | /// The average [QValue] of the root node as a consequence of performing MCTS iterations 259 | /// (with ply penalties applied). 260 | pub fn root_q_with_penalty(&self) -> QValue { 261 | self.root.borrow().q_with_penalty() 262 | } 263 | 264 | /// The average [QValue] of the root node as a consequence of performing MCTS iterations 265 | /// (without ply penalties applied). 266 | pub fn root_q_no_penalty(&self) -> QValue { 267 | self.root.borrow().q_no_penalty() 268 | } 269 | 270 | /// Converts a finished game into a Vec of [Sample] for future NN training. 271 | pub fn to_result(self, c_ply_penalty: f32) -> GameResult { 272 | let (q_penalty, q_no_penalty) = self 273 | .root 274 | .borrow() 275 | .pos 276 | .terminal_value_with_ply_penalty(c_ply_penalty) 277 | .expect("attempted to convert a non-terminal game to a training sample"); 278 | 279 | // Q values alternate for each ply as perspective alternates between players. 280 | let mut alternating_q = vec![(q_penalty, q_no_penalty), (-q_penalty, -q_no_penalty)] 281 | .into_iter() 282 | .cycle(); 283 | if self.moves.len() % 2 == 1 { 284 | // If we have an odd number of moves (even number of total positions), the first Q value 285 | // should be inverted so that the final Q value is based on the terminal state above. 286 | alternating_q.next(); 287 | } 288 | 289 | let mut samples: Vec<_> = self 290 | .moves 291 | .iter() 292 | .zip(alternating_q) 293 | .map(|(mov, (q_penalty, q_no_penalty))| Sample { 294 | pos: mov.pos.clone(), 295 | policy: mov.policy, 296 | q_penalty, 297 | q_no_penalty, 298 | }) 299 | .collect(); 300 | 301 | // Add the final (terminal) position with an arbitray uniform policy 302 | samples.push(Sample { 303 | pos: self.root.borrow().pos.clone(), 304 | policy: MctsGame::UNIFORM_POLICY, 305 | q_penalty, 306 | q_no_penalty, 307 | }); 308 | 309 | GameResult { 310 | metadata: self.metadata.clone(), 311 | samples: samples, 312 | } 313 | } 314 | } 315 | 316 | /// Recorded move during the MCTS process. 317 | #[derive(Debug, Clone)] 318 | struct RecordedMove { 319 | pos: Pos, 320 | policy: Policy, 321 | mov: Move, 322 | } 323 | 324 | /// A node within an MCTS tree. 325 | /// [Self::parent] is a weak reference to the parent node to avoid reference cycles. 326 | /// [Self::children] is an array of optional child nodes. If a child is None, it means that the 327 | /// move is illegal. Otherwise the child is a [Rc>] reference to the child node. 328 | /// We maintain two separate Q values: one with ply penalties applied ([Self::q_sum_penalty]) and 329 | /// one without ([Self::q_sum_no_penalty]). 
These are normalized with [Self::visit_count] to get the 330 | /// average [QValue]s in [Self::q_with_penalty()] and [Self::q_no_penalty()]. 331 | #[derive(Debug, Clone)] 332 | struct Node { 333 | pos: Pos, 334 | parent: Weak>, 335 | visit_count: usize, 336 | q_sum_penalty: f32, 337 | q_sum_no_penalty: f32, 338 | initial_policy_value: QValue, 339 | children: Option<[Option>>; Pos::N_COLS]>, 340 | } 341 | 342 | impl Node { 343 | const EPS: f32 = 1e-8; 344 | 345 | fn new(pos: Pos, parent: Weak>, initial_policy_value: QValue) -> Node { 346 | Node { 347 | pos, 348 | parent, 349 | visit_count: 0, 350 | q_sum_penalty: 0.0, 351 | q_sum_no_penalty: 0.0, 352 | initial_policy_value, 353 | children: None, 354 | } 355 | } 356 | 357 | /// The exploitation component of the UCT value (i.e. the average win rate) with a penalty 358 | /// applied for additional plys to discourage longer sequences. 359 | fn q_with_penalty(&self) -> QValue { 360 | self.q_sum_penalty / ((self.visit_count as f32) + 1.0) 361 | } 362 | 363 | /// The exploitation component of the UCT value (i.e. the average win rate) without any 364 | /// ply penalty. 365 | fn q_no_penalty(&self) -> QValue { 366 | self.q_sum_no_penalty / ((self.visit_count as f32) + 1.0) 367 | } 368 | 369 | /// The exploration component of the UCT value. Higher visit counts result in lower values. 370 | /// We also weight the exploration value by the initial policy value to allow the network 371 | /// to guide the search. 372 | fn exploration_value(&self) -> QValue { 373 | let parent_visit_count = self 374 | .parent 375 | .upgrade() 376 | .map_or(self.visit_count as f32, |parent| { 377 | parent.borrow().visit_count as f32 378 | }) as f32; 379 | let exploration_value = (parent_visit_count.ln() / (self.visit_count as f32 + 1.)).sqrt(); 380 | exploration_value * (self.initial_policy_value + Self::EPS) 381 | } 382 | 383 | /// The UCT value of this node. Represents the lucrativeness of this node according to MCTS. 384 | /// Because [Self::uct_value] is called from the perspective of the *parent* node, we negate 385 | /// the exploration value. 386 | fn uct_value(&self, c_exploration: f32) -> QValue { 387 | -self.q_with_penalty() + c_exploration * self.exploration_value() 388 | } 389 | 390 | /// Whether the game is over (won, los, draw) from this position. 391 | fn is_terminal(&self) -> bool { 392 | self.pos.is_terminal_state().is_some() 393 | } 394 | 395 | /// Uses the child counts as weights to determine the implied policy from this position. 396 | fn policy(&self) -> Policy { 397 | if let Some(children) = &self.children { 398 | let child_counts = policy_from_iter(children.iter().map(|maybe_child| { 399 | maybe_child 400 | .as_ref() 401 | .map_or(0., |child_ref| child_ref.borrow().visit_count as f32) 402 | })); 403 | let child_counts_sum = child_counts.iter().sum::(); 404 | if child_counts_sum == 0.0 { 405 | MctsGame::UNIFORM_POLICY 406 | } else { 407 | child_counts.map(|c| c / child_counts_sum) 408 | } 409 | } else { 410 | MctsGame::UNIFORM_POLICY 411 | } 412 | } 413 | } 414 | 415 | /// Softmax function for a policy. 416 | fn softmax(policy_logprobs: Policy) -> Policy { 417 | let max = policy_logprobs 418 | .iter() 419 | .cloned() 420 | .fold(f32::NEG_INFINITY, f32::max); 421 | if max.is_infinite() { 422 | // If the policy is all negative infinity, we fall back to uniform policy. 423 | // This can happen if the NN dramatically underflows. 424 | // We panic as this is an issue that should be fixed in the NN. 
425 | panic!("softmax: policy is all negative infinity, debug NN on why this is happening."); 426 | } 427 | let exps = policy_logprobs 428 | .iter() 429 | // Subtract max value to avoid overflow 430 | .map(|p| (p - max).exp()) 431 | .collect::>(); 432 | let sum = exps.iter().sum::(); 433 | array::from_fn(|i| exps[i] / sum) 434 | } 435 | 436 | /// Applies temperature scaling to a policy. 437 | /// Expects the policy to be in [0-1] (non-log) space. 438 | /// Temperature=0.0 is argmax, temperature=1.0 is a noop. 439 | pub fn apply_temperature(policy: &Policy, temperature: f32) -> Policy { 440 | if temperature == 1.0 || policy.iter().all(|&p| p == policy[0]) { 441 | // Temp 1.0 or uniform policy is noop 442 | return policy.clone(); 443 | } else if temperature == 0.0 { 444 | // Temp 0.0 is argmax 445 | let max = policy.iter().cloned().fold(f32::NEG_INFINITY, f32::max); 446 | let ret = policy.map(|p| if p == max { 1.0 } else { 0.0 }); 447 | let sum = ret.iter().sum::(); 448 | return ret.map(|p| p / sum); // Potentially multiple argmaxes 449 | } 450 | 451 | let policy_log = policy.map(|p| p.ln() / temperature); 452 | let policy_log_sum_exp = policy_log.map(|p| p.exp()).iter().sum::().ln(); 453 | policy_log.map(|p| (p - policy_log_sum_exp).exp().clamp(0.0, 1.0)) 454 | } 455 | 456 | #[cfg(test)] 457 | mod tests { 458 | use super::*; 459 | use approx::assert_relative_eq; 460 | use more_asserts::{assert_gt, assert_lt}; 461 | use proptest::prelude::*; 462 | 463 | const CONST_COL_WEIGHT: f32 = 1.0 / (Pos::N_COLS as f32); 464 | const CONST_POLICY: Policy = [CONST_COL_WEIGHT; Pos::N_COLS]; 465 | const TEST_C_EXPLORATION: f32 = 4.0; 466 | const TEST_C_PLY_PENALTY: f32 = 0.01; 467 | 468 | /// Runs a batch with a single game and a constant evaluation function. 469 | fn run_mcts(pos: Pos, n_iterations: usize) -> (Policy, QValue, QValue) { 470 | let mut game = MctsGame::new_from_pos(pos, GameMetadata::default()); 471 | for _ in 0..n_iterations { 472 | game.on_received_policy( 473 | MctsGame::UNIFORM_POLICY, 474 | 0.0, 475 | 0.0, 476 | TEST_C_EXPLORATION, 477 | TEST_C_PLY_PENALTY, 478 | ) 479 | } 480 | ( 481 | game.root_policy(), 482 | game.root_q_with_penalty(), 483 | game.root_q_no_penalty(), 484 | ) 485 | } 486 | 487 | #[test] 488 | fn mcts_prefers_center_column() { 489 | let (policy, _q_penalty, _q_no_penalty) = run_mcts(Pos::default(), 1000); 490 | assert_policy_sum_1(&policy); 491 | assert_gt!(policy[3], CONST_COL_WEIGHT); 492 | } 493 | 494 | #[test] 495 | fn mcts_depth_one() { 496 | let (policy, _q_penalty, _q_no_penalty) = 497 | run_mcts(Pos::default(), 1 + Pos::N_COLS + Pos::N_COLS); 498 | assert_policy_eq(&policy, &CONST_POLICY, Node::EPS); 499 | } 500 | 501 | #[test] 502 | fn mcts_depth_two() { 503 | let (policy, _q_penalty, _q_no_penalty) = run_mcts( 504 | Pos::default(), 505 | 1 + Pos::N_COLS + (Pos::N_COLS * Pos::N_COLS) + (Pos::N_COLS * Pos::N_COLS), 506 | ); 507 | assert_policy_eq(&policy, &CONST_POLICY, Node::EPS); 508 | } 509 | 510 | #[test] 511 | fn mcts_depth_uneven() { 512 | let (policy, _q_penalty, _q_no_penalty) = run_mcts(Pos::default(), 47); 513 | assert_policy_ne(&policy, &CONST_POLICY, Node::EPS); 514 | } 515 | 516 | /// From an obviously winning position, mcts should end up with a policy that prefers the 517 | /// winning move. 
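As a quick illustration of `apply_temperature` above: in non-log space it is equivalent to raising each probability to the power 1/T and renormalizing. The standalone sketch below uses made-up numbers and deliberately ignores the T=0 and uniform-policy special cases that the real function handles.

```rust
// p_i' = p_i^(1/T) / sum_j p_j^(1/T): T < 1 sharpens toward the argmax, T > 1 flattens.
fn scale(policy: [f32; 7], temperature: f32) -> [f32; 7] {
    let powered = policy.map(|p| p.powf(1.0 / temperature));
    let sum: f32 = powered.iter().sum();
    powered.map(|p| p / sum)
}

fn main() {
    let policy = [0.05, 0.10, 0.15, 0.40, 0.15, 0.10, 0.05];
    println!("t=0.5 {:?}", scale(policy, 0.5)); // sharper: mass concentrates on column 3
    println!("t=2.0 {:?}", scale(policy, 2.0)); // flatter: closer to uniform
}
```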
518 | #[test] 519 | fn winning_position() { 520 | let pos = Pos::from( 521 | [ 522 | "⚫⚫⚫⚫⚫⚫⚫", 523 | "⚫⚫⚫⚫⚫⚫⚫", 524 | "⚫⚫⚫⚫⚫⚫⚫", 525 | "⚫⚫⚫⚫⚫⚫⚫", 526 | "⚫🔵🔵🔵⚫⚫⚫", 527 | "⚫🔴🔴🔴⚫⚫⚫", 528 | ] 529 | .join("\n") 530 | .as_str(), 531 | ); 532 | let (policy, q_penalty, q_no_penalty) = run_mcts(pos, 10_000); 533 | let winning_moves = policy[0] + policy[4]; 534 | assert_relative_eq!(policy.iter().sum::(), 1.0); 535 | assert_gt!(winning_moves, 0.99); 536 | assert_gt!(q_penalty, 0.92); 537 | assert_gt!(q_no_penalty, 0.99); 538 | } 539 | 540 | /// From a winning position, mcts should end up with a policy that prefers the winning move. 541 | #[test] 542 | fn winning_position2() { 543 | let pos = Pos::from( 544 | [ 545 | "⚫⚫⚫⚫⚫⚫⚫", 546 | "⚫⚫⚫⚫⚫⚫⚫", 547 | "⚫⚫⚫⚫⚫⚫⚫", 548 | "⚫⚫⚫⚫⚫⚫⚫", 549 | "⚫⚫🔵🔵⚫⚫⚫", 550 | "⚫⚫🔴🔴⚫⚫⚫", 551 | ] 552 | .join("\n") 553 | .as_str(), 554 | ); 555 | let (policy, q_penalty, q_no_penalty) = run_mcts(pos, 10_000); 556 | let winning_moves = policy[1] + policy[4]; 557 | assert_gt!(winning_moves, 0.98); 558 | assert_gt!(q_penalty, 0.90); 559 | assert_gt!(q_no_penalty, 0.98); 560 | assert_gt!(q_no_penalty, q_penalty); 561 | } 562 | 563 | /// From a winning position, mcts should end up with a policy that prefers the winning move. 564 | #[test] 565 | fn winning_position3() { 566 | let pos = Pos::from( 567 | [ 568 | "⚫⚫⚫⚫⚫⚫⚫", 569 | "⚫⚫⚫⚫⚫⚫⚫", 570 | "⚫⚫⚫⚫⚫⚫⚫", 571 | "⚫🔴🔵🔵⚫⚫⚫", 572 | "⚫🔵🔴🔴🔴⚫⚫", 573 | "⚫🔵🔵🔴🔵🔴⚫", 574 | ] 575 | .join("\n") 576 | .as_str(), 577 | ); 578 | let (policy, q_penalty, q_no_penalty) = run_mcts(pos, 10_000); 579 | assert_gt!(policy[5], 0.99); 580 | assert_gt!(q_penalty, 0.86); 581 | assert_gt!(q_no_penalty, 0.99); 582 | assert_gt!(q_no_penalty, q_penalty); 583 | } 584 | 585 | /// From a definitively losing position, mcts should end up with a uniform policy because it's 586 | /// desperately trying to find a non-losing move. 587 | #[test] 588 | fn losing_position() { 589 | let pos = Pos::from( 590 | [ 591 | "⚫⚫⚫⚫⚫⚫⚫", 592 | "⚫⚫⚫⚫⚫⚫⚫", 593 | "⚫⚫⚫⚫⚫⚫⚫", 594 | "⚫⚫⚫⚫⚫⚫⚫", 595 | "⚫🔴🔴⚫⚫⚫⚫", 596 | "⚫🔵🔵🔵⚫⚫⚫", 597 | ] 598 | .join("\n") 599 | .as_str(), 600 | ); 601 | let (policy, q_penalty, q_no_penalty) = run_mcts(pos, 300_000); 602 | assert_policy_sum_1(&policy); 603 | policy.iter().for_each(|&p| { 604 | assert_relative_eq!(p, CONST_COL_WEIGHT, epsilon = 0.01); 605 | }); 606 | assert_lt!(q_penalty, -0.93); 607 | assert_lt!(q_no_penalty, -0.99); 608 | assert_lt!(q_no_penalty, q_penalty); 609 | } 610 | 611 | /// From a position with two wins, prefer the shorter win. Here, playing 0 leads to a forced 612 | /// win, but playing 4 leads to an immediate win. 613 | #[test] 614 | fn prefer_shorter_wins() { 615 | let pos = Pos::from( 616 | [ 617 | "⚫⚫⚫🔵⚫⚫⚫", 618 | "⚫🔵🔵🔵⚫⚫⚫", 619 | "⚫🔴🔵🔵⚫⚫⚫", 620 | "⚫🔴🔴🔴⚫⚫⚫", 621 | "⚫🔴🔴🔴⚫⚫⚫", 622 | "⚫🔵🔴🔵⚫⚫⚫", 623 | ] 624 | .join("\n") 625 | .as_str(), 626 | ); 627 | let (policy, q_penalty, q_no_penalty) = run_mcts(pos, 10_000); 628 | assert_gt!(policy[4], 0.99); 629 | assert_gt!(q_penalty, 0.82); 630 | assert_gt!(q_no_penalty, 0.99); 631 | assert_gt!(q_no_penalty, q_penalty); 632 | } 633 | 634 | /// Strategy for generating a policy with at least one non-zero value. 
635 | fn policy_strategy() -> impl Strategy { 636 | let min = 0.0f32; 637 | let max = 10.0f32; 638 | let positive_strategy = min..max; 639 | let neg_inf_strategy = Just(f32::NEG_INFINITY); 640 | prop::array::uniform7(prop_oneof![positive_strategy, neg_inf_strategy]) 641 | .prop_filter("all neg infinity not allowed", |policy_logits| { 642 | !policy_logits.iter().all(|&p| p == f32::NEG_INFINITY) 643 | }) 644 | .prop_map(|policy_log| softmax(policy_log)) 645 | } 646 | 647 | proptest! { 648 | /// Softmax policies should sum up to one. 649 | #[test] 650 | fn softmax_sum_1(policy in policy_strategy()) { 651 | assert_policy_sum_1(&policy); 652 | } 653 | 654 | /// Temperature of 1.0 should not affect the policy. 655 | #[test] 656 | fn temperature_1(policy in policy_strategy()) { 657 | let policy_with_temp = apply_temperature(&policy, 1.0); 658 | assert_policy_eq(&policy, &policy_with_temp, 1e-5); 659 | } 660 | 661 | /// Temperature of 2.0 should change the policy. 662 | #[test] 663 | fn temperature_2(policy in policy_strategy()) { 664 | let policy_with_temp = apply_temperature(&policy, 2.0); 665 | assert_policy_sum_1(&policy_with_temp); 666 | // If policy is nonuniform and there are at least two non-zero probabilities, the 667 | // policy with temperature should be different from the original policy 668 | if policy.iter().filter(|&&p| p != CONST_COL_WEIGHT && p > 0.0).count() >= 2 { 669 | assert_policy_ne(&policy, &policy_with_temp, Node::EPS); 670 | } 671 | } 672 | 673 | /// Temperature of 0.0 should be argmax. 674 | #[test] 675 | fn temperature_0(policy in policy_strategy()) { 676 | let policy_with_temp = apply_temperature(&policy, 0.0); 677 | let max = policy_with_temp.iter().fold(f32::NEG_INFINITY, |a, &b| f32::max(a, b)); 678 | let max_count = policy_with_temp.iter().filter(|&&p| p == max).count() as f32; 679 | assert_policy_sum_1(&policy_with_temp); 680 | for p in policy_with_temp { 681 | if p == max { 682 | assert_eq!(1.0 / max_count, p); 683 | } 684 | } 685 | } 686 | } 687 | 688 | fn assert_policy_sum_1(policy: &Policy) { 689 | let sum = policy.iter().sum::(); 690 | if (sum - 1.0).abs() > 1e-5 { 691 | panic!("policy sum {:?} is not 1.0: {:?}", sum, policy); 692 | } 693 | } 694 | 695 | fn assert_policy_eq(p1: &Policy, p2: &Policy, epsilon: f32) { 696 | let eq = p1 697 | .iter() 698 | .zip(p2.iter()) 699 | .all(|(a, b)| (a - b).abs() < epsilon); 700 | if !eq { 701 | panic!("policies are not equal: {:?} {:?}", p1, p2); 702 | } 703 | } 704 | 705 | fn assert_policy_ne(p1: &Policy, p2: &Policy, epsilon: f32) { 706 | let ne = p1 707 | .iter() 708 | .zip(p2.iter()) 709 | .any(|(a, b)| (a - b).abs() > epsilon); 710 | if !ne { 711 | panic!("policies are equal: {:?} {:?}", p1, p2); 712 | } 713 | } 714 | } 715 | -------------------------------------------------------------------------------- /rust/src/pybridge.rs: -------------------------------------------------------------------------------- 1 | use numpy::{ndarray::Array4, IntoPyArray, PyArray4, PyReadonlyArray1, PyReadonlyArray2}; 2 | use pyo3::{ 3 | prelude::*, 4 | types::{PyBytes, PyList}, 5 | }; 6 | use serde::{Deserialize, Serialize}; 7 | 8 | use crate::{ 9 | c4r::Pos, 10 | self_play::self_play, 11 | solver::CachingSolver, 12 | tui, 13 | types::{EvalPosResult, EvalPosT, GameMetadata, GameResult, ModelID, Policy, Sample}, 14 | }; 15 | use rand::{rngs::StdRng, seq::SliceRandom, SeedableRng}; 16 | 17 | /// Play games via MCTS. This is a python wrapper around [self_play]. 
18 | /// `reqs` is a list of [GameMetadata] that describes the games to play. 19 | /// The `py_eval_pos_cb` callback is expected to be a pytorch model that runs on the GPU. 20 | #[pyfunction] 21 | pub fn play_games<'py>( 22 | py: Python<'py>, 23 | reqs: &Bound<'py, PyList>, 24 | max_nn_batch_size: usize, 25 | n_mcts_iterations: usize, 26 | c_exploration: f32, 27 | c_ply_penalty: f32, 28 | py_eval_pos_cb: &Bound<'py, PyAny>, 29 | ) -> PyResult { 30 | let reqs: Vec = reqs.extract().expect("error extracting reqs"); 31 | 32 | let eval_pos = PyEvalPos { 33 | py_eval_pos_cb: py_eval_pos_cb.to_object(py), 34 | }; 35 | 36 | let results = { 37 | // Start background processing threads while releasing the GIL with allow_threads. 38 | // This allows other python threads (e.g. pytorch) to continue while we generate training 39 | // samples. When we need to call the py_eval_pos callback, we will re-acquire the GIL. 40 | py.allow_threads(move || { 41 | self_play( 42 | eval_pos, 43 | reqs, 44 | max_nn_batch_size, 45 | n_mcts_iterations, 46 | c_exploration, 47 | c_ply_penalty, 48 | ) 49 | }) 50 | }; 51 | 52 | Ok(PlayGamesResult { results }) 53 | } 54 | 55 | /// The result of [play_games]. 56 | /// Note we explicitly spcify pyclass(module="c4a0_rust") as the module name is required in 57 | /// order for pickling to work. 58 | #[derive(Debug, Clone, Serialize, Deserialize)] 59 | #[pyclass(module = "c4a0_rust")] 60 | pub struct PlayGamesResult { 61 | #[pyo3(get)] 62 | pub results: Vec, 63 | } 64 | 65 | #[pymethods] 66 | impl PlayGamesResult { 67 | /// Empty constructor is required for unpickling. 68 | #[new] 69 | fn new() -> Self { 70 | PlayGamesResult { results: vec![] } 71 | } 72 | 73 | fn to_cbor(&self, py: Python) -> PyResult { 74 | let cbor = serde_cbor::to_vec(self).map_err(pyify_err)?; 75 | Ok(PyBytes::new_bound(py, &cbor).into()) 76 | } 77 | 78 | #[staticmethod] 79 | fn from_cbor(_py: Python, cbor: &[u8]) -> PyResult { 80 | serde_cbor::from_slice(cbor).map_err(pyify_err) 81 | } 82 | 83 | /// Used for pickling serialization. 84 | fn __getstate__(&self, py: Python) -> PyResult { 85 | self.to_cbor(py) 86 | } 87 | 88 | /// Used for pickling deserialization. 89 | fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> { 90 | let cbor: &[u8] = state.extract(py)?; 91 | *self = Self::from_cbor(py, cbor)?; 92 | Ok(()) 93 | } 94 | 95 | /// Combine two PlayGamesResult objects. 96 | fn __add__<'py>(&mut self, py: Python<'py>, other: PyObject) -> PyResult { 97 | let other = other.extract::(py)?; 98 | Ok(PlayGamesResult { 99 | results: self 100 | .results 101 | .iter() 102 | .chain(other.results.iter()) 103 | .cloned() 104 | .collect(), 105 | }) 106 | } 107 | 108 | /// Splits the results into training and test datasets. 109 | /// Ensures that whole games end up in either the training set or test set. 110 | /// Expects `train_frac` to be in [0, 1]. 111 | fn split_train_test( 112 | &mut self, 113 | train_frac: f32, 114 | seed: u64, 115 | ) -> PyResult<(Vec, Vec)> { 116 | let mut rng = StdRng::seed_from_u64(seed); 117 | self.results.shuffle(&mut rng); 118 | let n_train = (self.results.len() as f32 * train_frac).round() as usize; 119 | let (train, test) = self.results.split_at(n_train); 120 | Ok(( 121 | train.into_iter().flat_map(|r| r.samples.clone()).collect(), 122 | test.into_iter().flat_map(|r| r.samples.clone()).collect(), 123 | )) 124 | } 125 | 126 | /// Scores the policies in the results using the given solver. 
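The pickling support above (`__getstate__`/`__setstate__`) boils down to a CBOR round trip through `serde_cbor`. A minimal standalone sketch of that round trip, using a made-up struct rather than the real `PlayGamesResult`:

```rust
use serde::{Deserialize, Serialize};

#[derive(Debug, PartialEq, Serialize, Deserialize)]
struct Demo {
    id: u64,
    scores: Vec<f32>,
}

fn main() {
    let original = Demo { id: 7, scores: vec![0.5, -1.0] };
    let bytes = serde_cbor::to_vec(&original).expect("serialize");
    let restored: Demo = serde_cbor::from_slice(&bytes).expect("deserialize");
    assert_eq!(original, restored);
    println!("round-tripped {} bytes", bytes.len());
}
```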
127 | /// `solver_path` is the path to the solver binary, see: 128 | /// https://github.com/PascalPons/connect4 129 | /// `solver_book_path` is the path to the solver book: 130 | /// https://github.com/PascalPons/connect4/releases/tag/book?ts=2 131 | /// `solution_cache_path` is the path to the solution cache file which will be created if it 132 | /// is missing. 133 | fn score_policies( 134 | &self, 135 | solver_path: String, 136 | solver_book_path: String, 137 | solution_cache_path: String, 138 | ) -> PyResult { 139 | let solver = CachingSolver::new(solver_path, solver_book_path, solution_cache_path); 140 | let pos_and_policies = self 141 | .results 142 | .iter() 143 | .flat_map(|r| r.samples.iter()) 144 | .filter(|s| s.pos.is_terminal_state().is_none()) 145 | .map(|p| (p.pos.clone(), p.policy.clone())) 146 | .collect::>(); 147 | let scores = solver.score_policies(pos_and_policies).map_err(pyify_err)?; 148 | let n_scores = scores.len(); 149 | let avg_score = scores.into_iter().sum::() / n_scores as f32; 150 | Ok(avg_score) 151 | } 152 | 153 | /// Returns the number of unique positions in the results. 154 | fn unique_positions(&self) -> usize { 155 | self.results 156 | .iter() 157 | .flat_map(|r| r.samples.iter()) 158 | .map(|s| s.pos.clone()) 159 | .collect::>() 160 | .len() 161 | } 162 | } 163 | 164 | /// [EvalPosT] implementation that calls the `py_eval_pos_cb` python callback. 165 | struct PyEvalPos { 166 | py_eval_pos_cb: PyObject, 167 | } 168 | 169 | impl EvalPosT for PyEvalPos { 170 | /// Evaluates a batch of positions by calling the [Self::py_eval_pos_cb] callback. 171 | /// This is intended to be a pytorch model that runs on the GPU. Because this is a python 172 | /// call we need to first re-acquire the GIL to call this function from a background thread 173 | /// before performing the callback. 174 | fn eval_pos(&self, model_id: ModelID, pos: Vec) -> Vec { 175 | Python::with_gil(|py| { 176 | let batch_size = pos.len(); 177 | let pos_batch = create_pos_batch(py, &pos); 178 | 179 | let (policy, q_penalty, q_no_penalty): ( 180 | PyReadonlyArray2, 181 | PyReadonlyArray1, 182 | PyReadonlyArray1, 183 | ) = (&self 184 | .py_eval_pos_cb 185 | .call_bound(py, (model_id, pos_batch), None) 186 | .expect("Failed to call py_eval_pos_cb")) 187 | .extract(py) 188 | .expect("Failed to extract result"); 189 | 190 | let policy = policy.as_slice().expect("Failed to get policy slice"); 191 | let q_penalty = q_penalty.as_slice().expect("Failed to get value slice"); 192 | let q_no_penalty = q_no_penalty.as_slice().expect("Failed to get value slice"); 193 | 194 | (0..batch_size) 195 | .map(|i| EvalPosResult { 196 | policy: policy_from_slice(&policy[i * Pos::N_COLS..(i + 1) * Pos::N_COLS]), 197 | q_penalty: q_penalty[i], 198 | q_no_penalty: q_no_penalty[i], 199 | }) 200 | .collect() 201 | }) 202 | } 203 | } 204 | 205 | /// Creates a batch of positions in tensor format. 
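For context on the buffer convention used by `eval_pos` above and `policy_from_slice`/`create_pos_batch` below: the NN exchange uses flat buffers that are split back into fixed-size per-position arrays. A standalone sketch of that slicing, with made-up data standing in for real NN output:

```rust
fn main() {
    const N_COLS: usize = 7;
    // A fake batch of two positions: one flat buffer of batch_size * N_COLS policy values.
    let flat: Vec<f32> = (0..2 * N_COLS).map(|i| i as f32).collect();
    let per_pos: Vec<[f32; N_COLS]> = flat
        .chunks_exact(N_COLS)
        .map(|chunk| <[f32; N_COLS]>::try_from(chunk).expect("chunk has exactly N_COLS items"))
        .collect();
    assert_eq!(per_pos[1][0], 7.0); // the second position starts at offset N_COLS
    println!("{per_pos:?}");
}
```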
206 | fn create_pos_batch<'py>(py: Python<'py>, positions: &Vec) -> Bound<'py, PyArray4> { 207 | let mut buffer = vec![0.0; positions.len() * Pos::BUF_LEN]; 208 | for i in 0..positions.len() { 209 | let pos = &positions[i]; 210 | let pos_buffer = &mut buffer[i * Pos::BUF_LEN..(i + 1) * Pos::BUF_LEN]; 211 | pos.write_numpy_buffer(pos_buffer); 212 | } 213 | 214 | Array4::from_shape_vec( 215 | ( 216 | positions.len(), 217 | Pos::BUF_N_CHANNELS, 218 | Pos::N_ROWS, 219 | Pos::N_COLS, 220 | ), 221 | buffer, 222 | ) 223 | .expect("Failed to create Array4 from buffer") 224 | .into_pyarray_bound(py) 225 | } 226 | 227 | /// Convert a slice of probabilities into a [Policy]. 228 | fn policy_from_slice(policy: &[f32]) -> Policy { 229 | debug_assert_eq!(policy.len(), Pos::N_COLS); 230 | let mut ret = Policy::default(); 231 | ret.copy_from_slice(policy); 232 | ret 233 | } 234 | 235 | #[pyfunction] 236 | pub fn run_tui<'py>( 237 | py: Python<'py>, 238 | py_eval_pos_cb: &Bound<'py, PyAny>, 239 | max_mcts_iters: usize, 240 | c_exploration: f32, 241 | c_ply_penalty: f32, 242 | ) -> PyResult<()> { 243 | let eval_pos = PyEvalPos { 244 | py_eval_pos_cb: py_eval_pos_cb.to_object(py), 245 | }; 246 | 247 | // Start the TUI while releasing the GIL with allow_threads. 248 | py.allow_threads(move || { 249 | let mut terminal = tui::init()?; 250 | let mut app = tui::App::new(eval_pos, max_mcts_iters, c_exploration, c_ply_penalty); 251 | app.run(&mut terminal)?; 252 | tui::restore()?; 253 | Ok(()) 254 | }) 255 | } 256 | 257 | /// Convert a Rust error into a Python exception. 258 | fn pyify_err(e: T) -> PyErr 259 | where 260 | T: std::fmt::Debug, 261 | { 262 | PyErr::new::(format!("{:?}", e)) 263 | } 264 | -------------------------------------------------------------------------------- /rust/src/self_play.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | collections::{BTreeMap, HashMap, HashSet}, 3 | mem, 4 | sync::{ 5 | atomic::{AtomicUsize, Ordering}, 6 | Arc, 7 | }, 8 | }; 9 | 10 | use crossbeam::thread; 11 | use crossbeam_channel::{bounded, Receiver, RecvError, Sender}; 12 | use indicatif::{MultiProgress, ProgressBar, ProgressStyle}; 13 | 14 | use crate::{ 15 | c4r::Pos, 16 | mcts::MctsGame, 17 | types::{EvalPosResult, EvalPosT, GameMetadata, GameResult, ModelID}, 18 | }; 19 | 20 | /// Generate training samples with self play and MCTS. 21 | /// 22 | /// We use a batched NN forward pass to expand a given node (to determine the initial policy values 23 | /// based on the NN's output policy). Because we want to batch these NN calls for performance, we 24 | /// partially compute many MCTS traversals simultaneously (via [MctsThread]), pausing each until we 25 | /// reach the node expansion phase. Then we are able to batch several NN calls simultaneously 26 | /// (via [NNThread]). This process ping-pongs until the game reaches a terminal state after which 27 | /// it is added to `done_queue`. 28 | /// 29 | /// We use one [NNThread] and (n-1) [MctsThread]s (where n=core count). 30 | /// The thread termination mechanism is as follows: 31 | /// 1. [MctsThread] check whether we have finished all games via `n_games_remaining` atomic. When 32 | /// the first thread detects all work is complete, it sends a [MctsJob::PoisonPill] to all 33 | /// remaining [MctsThread]s, resulting in all of these threads completing. 34 | /// 2. 
When the last [MctsThread] completes, it drops the last `nn_queue_tx` 35 | /// [crossbeam_channel::Sender], causing the `nn_queue_rx` [crossbeam_channel::Receiver] to 36 | /// close. This notifies the [NNThread], allowing it to close. 37 | /// 3. The main thread simply waits for all threads to complete and then returns the results 38 | /// from the `done_queue`. 39 | pub fn self_play( 40 | eval_pos: E, 41 | reqs: Vec, 42 | max_nn_batch_size: usize, 43 | n_mcts_iterations: usize, 44 | c_exploration: f32, 45 | c_ply_penalty: f32, 46 | ) -> Vec { 47 | let n_games = reqs.len(); 48 | let (pb_game_done, pb_nn_eval, pb_mcts_iter) = init_progress_bars(n_games); 49 | let (nn_queue_tx, nn_queue_rx) = bounded::(n_games); 50 | let (mcts_queue_tx, mcts_queue_rx) = bounded::(n_games); 51 | let (done_queue_tx, done_queue_rx) = bounded::(n_games); 52 | let n_games_remaining = Arc::new(AtomicUsize::new(n_games)); 53 | 54 | // Create initial games 55 | for req in reqs { 56 | let game = MctsGame::new_from_pos(Pos::default(), req); 57 | nn_queue_tx.send(game).unwrap(); 58 | } 59 | 60 | thread::scope(|s| { 61 | // NN batch inference thread 62 | let mcts_queue = mcts_queue_tx.clone(); 63 | s.builder() 64 | .name("nn_thread".into()) 65 | .spawn(move |_| { 66 | NNThread::new( 67 | nn_queue_rx, 68 | mcts_queue, 69 | max_nn_batch_size, 70 | eval_pos, 71 | pb_nn_eval, 72 | ) 73 | .loop_until_close() 74 | }) 75 | .unwrap(); 76 | 77 | // MCTS threads 78 | let n_mcts_threads = usize::max(1, num_cpus::get() - 1); 79 | for i in 0..n_mcts_threads { 80 | let nn_queue_tx = nn_queue_tx.clone(); 81 | let mcts_queue_tx = mcts_queue_tx.clone(); 82 | let mcts_queue_rx = mcts_queue_rx.clone(); 83 | let done_queue_tx = done_queue_tx.clone(); 84 | let n_games_remaining = Arc::clone(&n_games_remaining); 85 | let pb_game_done = pb_game_done.clone(); 86 | let pb_mcts_iter = pb_mcts_iter.clone(); 87 | s.builder() 88 | .name(format!("mcts_thread {}", i)) 89 | .spawn(move |_| { 90 | MctsThread { 91 | nn_queue_tx, 92 | mcts_queue_tx, 93 | mcts_queue_rx, 94 | done_queue_tx, 95 | n_games_remaining, 96 | n_mcts_iterations, 97 | n_mcts_threads, 98 | c_exploration, 99 | c_ply_penalty, 100 | pb_game_done, 101 | pb_mcts_iter, 102 | } 103 | .loop_until_close() 104 | }) 105 | .unwrap(); 106 | } 107 | 108 | // The main thread doesn't tx on any channels. Explicitly drop the txs so the zero reader 109 | // channel close mechanism enables all threads to terminate. 110 | drop(nn_queue_tx); 111 | drop(mcts_queue_tx); 112 | drop(done_queue_tx); 113 | }) 114 | .unwrap(); 115 | 116 | let ret: Vec<_> = done_queue_rx.into_iter().collect(); 117 | 118 | let unique_pos: HashSet<_> = ret 119 | .iter() 120 | .flat_map(|result| result.samples.iter().map(|s| s.pos.clone())) 121 | .collect(); 122 | println!( 123 | "Generated {} games with {} unique positions", 124 | ret.len(), 125 | unique_pos.len() 126 | ); 127 | 128 | ret 129 | } 130 | 131 | /// Performs NN batch inference by reading from the [NNThread::nn_queue]. 132 | /// Performs a batch inference of [Pos]s using [NNThread::eval_pos] with up to 133 | /// [NNThread::max_nn_batch_size] positions for the [ModelID] that has the most positions to 134 | /// evaluate. 135 | /// 136 | /// After the batch inference returns its evaluation, we send the evaluated positions back to the 137 | /// MCTS threads via [NNThread::mcts_queue]. 138 | /// 139 | /// [NNThread::loop_until_close] will continue to loop until the [NNThread::nn_queue] is closed and 140 | /// there are no more pending games to evaluate. 
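The shutdown mechanism described above is a standard poison-pill pattern over crossbeam channels. A minimal, self-contained sketch of that pattern, detached from the real game and job types (it only assumes `crossbeam-channel` as a dependency):

```rust
use crossbeam_channel::bounded;
use std::thread;

enum Job {
    Work(u32),
    PoisonPill,
}

fn main() {
    let (tx, rx) = bounded::<Job>(8);
    let n_workers = 3;
    let mut handles = Vec::new();
    for _ in 0..n_workers {
        let rx = rx.clone();
        handles.push(thread::spawn(move || {
            while let Ok(job) = rx.recv() {
                match job {
                    Job::Work(n) => println!("processed {n}"),
                    Job::PoisonPill => break, // terminate this worker
                }
            }
        }));
    }
    for n in 0..5 {
        tx.send(Job::Work(n)).unwrap();
    }
    // One pill per worker guarantees every worker wakes up and exits.
    for _ in 0..n_workers {
        tx.send(Job::PoisonPill).unwrap();
    }
    drop(tx);
    for h in handles {
        h.join().unwrap();
    }
}
```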
141 | struct NNThread { 142 | nn_queue: Receiver, 143 | mcts_queue: Sender, 144 | max_nn_batch_size: usize, 145 | eval_pos: E, 146 | pb_nn_eval: ProgressBar, 147 | pending_games: Vec, 148 | chan_closed: bool, 149 | } 150 | 151 | impl NNThread { 152 | fn new( 153 | nn_queue: Receiver, 154 | mcts_queue: Sender, 155 | max_nn_batch_size: usize, 156 | eval_pos: E, 157 | pb_nn_eval: ProgressBar, 158 | ) -> Self { 159 | Self { 160 | nn_queue, 161 | mcts_queue, 162 | max_nn_batch_size, 163 | eval_pos, 164 | pb_nn_eval, 165 | pending_games: Vec::default(), 166 | chan_closed: false, 167 | } 168 | } 169 | 170 | /// Drains any items in the [NNThread::nn_queue] into the [NNThread::pending_games] vector, 171 | /// blocking if we have no pending games yet. 172 | /// Sets [NNThread::chan_closed] when the queue closes. 173 | fn drain_queue(&mut self) { 174 | if self.pending_games.is_empty() { 175 | match self.nn_queue.recv() { 176 | Ok(game) => { 177 | self.pending_games.push(game); 178 | } 179 | Err(RecvError) => { 180 | self.chan_closed = true; 181 | return; 182 | } 183 | } 184 | } 185 | 186 | // Optimistically drain additional games from the queue. 187 | while let Ok(game) = self.nn_queue.try_recv() { 188 | self.pending_games.push(game); 189 | } 190 | } 191 | 192 | /// Main [NNThread] logic. Optimistically drain items from the queue, call [NNThread::eval_pos] 193 | /// for the [ModelID] with the most queued positions, send the evaluated positions back to the 194 | /// [NNThread::mcts_queue], and update [NNThread::pending_games] with all games that were not 195 | /// processed in this tick. 196 | fn loop_once(&mut self) { 197 | self.drain_queue(); 198 | if self.pending_games.is_empty() { 199 | // pending_games can be empty if the channel closes 200 | return; 201 | } 202 | 203 | let mut model_pos = BTreeMap::>::new(); 204 | for game in self.pending_games.iter() { 205 | let model_id = game.leaf_model_id_to_play(); 206 | let entry = model_pos.entry(model_id).or_default(); 207 | entry.insert(game.leaf_pos()); 208 | } 209 | 210 | // Select the model with the most positions and evaluate 211 | let model_id = model_pos 212 | .iter() 213 | .max_by_key(|(_, positions)| positions.len()) 214 | .map(|(model_id, _)| *model_id) 215 | .unwrap(); 216 | let pos = model_pos[&model_id] 217 | .iter() 218 | .take(self.max_nn_batch_size) 219 | .cloned() 220 | .collect::>(); 221 | self.pb_nn_eval.inc(pos.len() as u64); 222 | let evals = self.eval_pos.eval_pos(model_id, pos.clone()); 223 | let eval_map = pos.into_iter().zip(evals).collect::>(); 224 | 225 | let mut games = Vec::::default(); 226 | mem::swap(&mut self.pending_games, &mut games); 227 | for game in games.into_iter() { 228 | let pos = game.leaf_pos(); 229 | if game.leaf_model_id_to_play() != model_id || !eval_map.contains_key(&pos) { 230 | self.pending_games.push(game); 231 | continue; 232 | } 233 | 234 | let nn_result = eval_map[&pos].clone(); 235 | self.mcts_queue.send(MctsJob::Job(game, nn_result)).unwrap(); 236 | } 237 | } 238 | 239 | /// Continuously loops until the [NNThread::chan_closed] flag is set and there are no more 240 | /// pending games to evaluate. 241 | fn loop_until_close(&mut self) { 242 | while !self.chan_closed || !self.pending_games.is_empty() { 243 | self.loop_once(); 244 | } 245 | } 246 | } 247 | 248 | /// Performs MCTS iterations by reading from the [Self::mcts_queue_rx]. 249 | /// If we reach the requisite number of iterations, we probabalistically make a move with 250 | /// [MctsGame::make_move]. 
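The batching heuristic in `NNThread::loop_once` above groups pending leaves by model and evaluates the largest group first. A standalone sketch of just that grouping step, with made-up data in place of real games:

```rust
use std::collections::BTreeMap;

fn main() {
    // (model_id, position) pairs standing in for pending games.
    let pending = [(0u64, "posA"), (1, "posB"), (1, "posC"), (1, "posD"), (0, "posE")];
    let mut by_model: BTreeMap<u64, Vec<&str>> = BTreeMap::new();
    for (model_id, pos) in pending {
        by_model.entry(model_id).or_default().push(pos);
    }
    let (busiest_model, batch) = by_model
        .iter()
        .max_by_key(|(_, positions)| positions.len())
        .expect("at least one pending position");
    println!("evaluate model {busiest_model} on {} positions", batch.len());
}
```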
Then, if the game reaches a terminal position, pass the game to 251 | /// [Self::done_queue]. Otherwise, we pass back to the nn via [Self::nn_queue]. 252 | struct MctsThread { 253 | nn_queue_tx: Sender, 254 | mcts_queue_tx: Sender, 255 | mcts_queue_rx: Receiver, 256 | done_queue_tx: Sender, 257 | n_games_remaining: Arc, 258 | n_mcts_iterations: usize, 259 | n_mcts_threads: usize, 260 | c_exploration: f32, 261 | c_ply_penalty: f32, 262 | pb_game_done: ProgressBar, 263 | pb_mcts_iter: ProgressBar, 264 | } 265 | 266 | impl MctsThread { 267 | /// Main [MctsThread] logic. Returns [Loop] whether we should continue or break from the loop. 268 | fn loop_once(&mut self) -> Loop { 269 | match self.mcts_queue_rx.recv() { 270 | Ok(MctsJob::PoisonPill) => Loop::Break, 271 | Ok(MctsJob::Job(mut game, nn_result)) => { 272 | self.pb_mcts_iter.inc(1); 273 | game.on_received_policy( 274 | nn_result.policy, 275 | nn_result.q_penalty, 276 | nn_result.q_no_penalty, 277 | self.c_exploration, 278 | self.c_ply_penalty, 279 | ); 280 | 281 | // If we haven't reached the requisite number of MCTS iterations, send back to NN 282 | // to evaluate the next leaf. 283 | if game.root_visit_count() < self.n_mcts_iterations { 284 | self.nn_queue_tx.send(game).unwrap(); 285 | return Loop::Continue; 286 | } 287 | 288 | // We have reached the sufficient number of MCTS iterations to make a move. 289 | let root_pos = game.root_pos(); 290 | if root_pos.is_terminal_state().is_none() { 291 | // Make a random move according to the MCTS policy. 292 | // If we are in the early game, use a higher temperature to encourage 293 | // generating more diverse (but suboptimal) games. 294 | let ply = root_pos.ply(); 295 | let temperature = match () { 296 | _ if ply < 4 => 4.0, 297 | _ if ply < 8 => 2.0, 298 | _ => 1.0, 299 | }; 300 | game.make_random_move(self.c_exploration, temperature); 301 | self.nn_queue_tx.send(game).unwrap(); 302 | } else { 303 | // Game is over. Send to done_queue. 304 | self.n_games_remaining.fetch_sub(1, Ordering::Relaxed); 305 | self.done_queue_tx 306 | .send(game.to_result(self.c_ply_penalty)) 307 | .unwrap(); 308 | self.pb_game_done.inc(1); 309 | 310 | if self.n_games_remaining.load(Ordering::Relaxed) == 0 { 311 | // We wrote the last game. Send poison pills to remaining threads. 312 | self.terminate_and_poison_other_threads(); 313 | return Loop::Break; 314 | } 315 | } 316 | 317 | Loop::Continue 318 | } 319 | Err(RecvError) => { 320 | panic!("mcts_thread: mcts_queue unexpectedly closed") 321 | } 322 | } 323 | } 324 | 325 | fn terminate_and_poison_other_threads(&self) { 326 | self.pb_mcts_iter 327 | .finish_with_message("MCTS iterations complete"); 328 | self.pb_game_done.finish_with_message("All games generated"); 329 | for _ in 0..(self.n_mcts_threads - 1) { 330 | self.mcts_queue_tx.send(MctsJob::PoisonPill).unwrap(); 331 | } 332 | } 333 | 334 | fn loop_until_close(&mut self) { 335 | while let Loop::Continue = self.loop_once() {} 336 | } 337 | } 338 | 339 | /// A piece of work for [mcts_thread]s. [MctsJob::PoisonPill] indicates the thread should terminate. 340 | enum MctsJob { 341 | Job(MctsGame, EvalPosResult), 342 | PoisonPill, 343 | } 344 | 345 | /// Indicates whether we should continue or break from the loop. 346 | enum Loop { 347 | Break, 348 | Continue, 349 | } 350 | 351 | /// Initialize progress bars for monitoring. 
352 | fn init_progress_bars(n_games: usize) -> (ProgressBar, ProgressBar, ProgressBar) { 353 | let multi_pb = MultiProgress::new(); 354 | 355 | let pb_game_done = multi_pb.add(ProgressBar::new(n_games as u64)); 356 | pb_game_done.set_style(ProgressStyle::default_bar() 357 | .template("{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {pos}/{len} games ({per_sec} games)") 358 | .unwrap() 359 | .progress_chars("#>-")); 360 | multi_pb.add(pb_game_done.clone()); 361 | 362 | let pb_nn_eval = multi_pb.add(ProgressBar::new_spinner()); 363 | pb_nn_eval.set_style( 364 | ProgressStyle::default_bar() 365 | .template("{spinner:.green} [{elapsed_precise}] NN evals: {pos} ({per_sec} pos)") 366 | .unwrap() 367 | .progress_chars("#>-"), 368 | ); 369 | multi_pb.add(pb_nn_eval.clone()); 370 | 371 | let pb_mcts_iter = multi_pb.add(ProgressBar::new_spinner()); 372 | pb_mcts_iter.set_style( 373 | ProgressStyle::default_bar() 374 | .template("{spinner:.green} [{elapsed_precise}] MCTS iterations: {pos} ({per_sec} it)") 375 | .unwrap() 376 | .progress_chars("#>-"), 377 | ); 378 | multi_pb.add(pb_mcts_iter.clone()); 379 | 380 | (pb_game_done, pb_nn_eval, pb_mcts_iter) 381 | } 382 | 383 | #[cfg(test)] 384 | pub mod tests { 385 | use more_asserts::{assert_ge, assert_le}; 386 | 387 | use super::*; 388 | 389 | const MAX_NN_BATCH_SIZE: usize = 10; 390 | 391 | pub struct UniformEvalPos {} 392 | impl EvalPosT for UniformEvalPos { 393 | fn eval_pos(&self, _model_id: ModelID, pos: Vec) -> Vec { 394 | assert_le!(pos.len(), MAX_NN_BATCH_SIZE); 395 | pos.into_iter() 396 | .map(|_| EvalPosResult { 397 | policy: MctsGame::UNIFORM_POLICY, 398 | q_penalty: 0.0, 399 | q_no_penalty: 0.0, 400 | }) 401 | .collect() 402 | } 403 | } 404 | 405 | #[test] 406 | fn test_self_play() { 407 | let n_games = 1; 408 | let mcts_iterations = 50; 409 | let c_exploration = 1.0; 410 | let c_ply_penalty = 0.01; 411 | let results = self_play( 412 | UniformEvalPos {}, 413 | (0..n_games) 414 | .map(|game_id| GameMetadata { 415 | game_id, 416 | player0_id: 0, 417 | player1_id: 0, 418 | }) 419 | .collect(), 420 | MAX_NN_BATCH_SIZE, 421 | mcts_iterations, 422 | c_exploration, 423 | c_ply_penalty, 424 | ); 425 | 426 | for result in results { 427 | assert_ge!(result.samples.len(), 7); 428 | assert_eq!( 429 | result 430 | .samples 431 | .iter() 432 | .filter(|sample| sample.pos == Pos::default()) 433 | .count(), 434 | 1, 435 | "game {:?} should have a single starting position", 436 | result 437 | ); 438 | 439 | let terminal_positions = result 440 | .samples 441 | .iter() 442 | .filter(|sample| sample.pos.is_terminal_state().is_some()) 443 | .collect::>(); 444 | assert_eq!( 445 | terminal_positions.len(), 446 | 1, 447 | "game {:?} should have a single terminal position", 448 | result 449 | ); 450 | let terminal_value = terminal_positions[0].q_no_penalty; 451 | if terminal_value != -1.0 && terminal_value != 0.0 && terminal_value != 1.0 { 452 | assert!( 453 | false, 454 | "expected terminal value {} to be -1, 0, or 1", 455 | terminal_value 456 | ); 457 | } 458 | } 459 | } 460 | } 461 | -------------------------------------------------------------------------------- /rust/src/solver.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | collections::{HashMap, HashSet}, 3 | error::Error, 4 | io::Write, 5 | ops, 6 | process::{Command, Stdio}, 7 | }; 8 | 9 | use rocksdb::{Options, DB}; 10 | use serde::{Deserialize, Serialize}; 11 | 12 | use crate::{c4r::Pos, types::Policy, utils::OrdF32}; 13 | 14 | /// A 
caching wrapper around [Solver] that caches solutions to positions in [rocksdb]. 15 | pub struct CachingSolver { 16 | solver: Solver, 17 | db: DB, 18 | } 19 | 20 | impl CachingSolver { 21 | /// path_to_solver: Path to the solver binary. 22 | /// path_to_book: Path to the solver's book file. 23 | /// path_to_solution_db: Path to the rocksdb database to cache solutions. 24 | pub fn new(path_to_solver: String, path_to_book: String, path_to_solution_db: String) -> Self { 25 | let mut options = Options::default(); 26 | options.create_if_missing(true); 27 | let db = DB::open(&options, path_to_solution_db).expect("failed to open rocksdb"); 28 | 29 | Self { 30 | solver: Solver::new(path_to_solver, path_to_book), 31 | db, 32 | } 33 | } 34 | 35 | /// Scores the given policies for the given positions based on the solutions from the solver. 36 | pub fn score_policies( 37 | &self, 38 | pos_and_policy: Vec<(Pos, Policy)>, 39 | ) -> Result, Box> { 40 | let (pos, policy): (Vec, Vec) = pos_and_policy.into_iter().unzip(); 41 | let solutions = self.solve(pos)?; 42 | let ret = solutions 43 | .into_iter() 44 | .zip(policy.into_iter()) 45 | .map(|(sol, pol)| sol.score_policy(&pol)) 46 | .collect(); 47 | Ok(ret) 48 | } 49 | 50 | /// Solves the given position, resorting to cached positions if possible, relying on 51 | /// [Self::solver] to solve missing positions, and then caches resulting solutions. 52 | fn solve(&self, pos: Vec) -> Result, Box> { 53 | let missing_pos = pos 54 | .iter() 55 | .filter(|p| self.get(&p).is_none()) 56 | .cloned() 57 | .collect::>() // Remove duplicates 58 | .into_iter() 59 | .collect::>(); 60 | 61 | log::debug!("Solving {} missing positions", missing_pos.len()); 62 | for chunk in missing_pos.chunks(100) { 63 | log::debug!("Chunk size {}", chunk.len()); 64 | let chunk_solutions = self.solver.solve(chunk)?; 65 | for (pos, solution) in chunk.into_iter().zip(chunk_solutions.into_iter()) { 66 | self.put(&pos, &solution); 67 | } 68 | } 69 | self.db.flush()?; 70 | log::debug!("Finished solving positions"); 71 | 72 | let ret = pos.into_iter().map(|pos| self.get(&pos).unwrap()).collect(); 73 | Ok(ret) 74 | } 75 | 76 | fn get(&self, pos: &Pos) -> Option { 77 | self.db 78 | .get(serde_cbor::to_vec(pos).expect("failed to serialize")) 79 | .expect("failed to get from db") 80 | .map(|bytes| serde_cbor::from_slice(&bytes).expect("failed to deserialize")) 81 | } 82 | 83 | fn put(&self, pos: &Pos, solution: &Solution) { 84 | self.db 85 | .put( 86 | serde_cbor::to_vec(pos).expect("failed to serialize"), 87 | serde_cbor::to_vec(solution).expect("failed to serialize"), 88 | ) 89 | .expect("failed to put to db"); 90 | } 91 | } 92 | 93 | #[derive(Debug, Serialize, Deserialize, Default)] 94 | struct SolutionCache(HashMap); 95 | 96 | /// Interface to PascalPons's connect4 solver: https://github.com/PascalPons/connect4 97 | /// Runs the solver in a subprocess, communicating via stdin/out. 98 | struct Solver { 99 | path_to_solver: String, 100 | path_to_book: String, 101 | } 102 | 103 | impl Solver { 104 | /// Creates a new solver with the given path to the solver binary and book file. 105 | /// Book files available here: https://github.com/PascalPons/connect4/releases/tag/book 106 | fn new(path_to_solver: String, path_to_book: String) -> Self { 107 | Self { 108 | path_to_solver, 109 | path_to_book, 110 | } 111 | } 112 | 113 | /// Calls the solver to solve the given positions. 
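The subprocess protocol used by `solve` below encodes each position as its move sequence in 1-indexed column notation, one position per line on the solver's stdin. A standalone sketch of that encoding (the example move list is hypothetical):

```rust
fn encode_position(moves: &[usize]) -> String {
    moves.iter().map(|m| (m + 1).to_string()).collect::<Vec<_>>().join("")
}

fn main() {
    // Both players stack the center-left columns: 0-indexed moves [2, 2, 3, 3]
    // become the 1-indexed string "3344" expected by the solver.
    let line = encode_position(&[2, 2, 3, 3]);
    assert_eq!(line, "3344");
    // One such line per position is written to the solver's stdin; the solver
    // answers with one line of per-column scores for each position.
    println!("{line}");
}
```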
114 | fn solve(&self, pos: &[Pos]) -> Result, Box> { 115 | let mut cmd = Command::new(self.path_to_solver.clone()) 116 | .arg("-b") 117 | .arg(self.path_to_book.clone()) 118 | .arg("-a") 119 | .stdin(Stdio::piped()) 120 | .stdout(Stdio::piped()) 121 | .stderr(Stdio::null()) 122 | .spawn()?; 123 | 124 | let stdin_bytes = pos 125 | .iter() 126 | .map(|p| { 127 | p.to_moves() 128 | .iter() 129 | .map(|m| (m + 1).to_string()) 130 | .collect::>() 131 | .join("") 132 | }) 133 | .collect::>() 134 | .join("\n") 135 | + "\n"; 136 | 137 | let mut stdin = cmd.stdin.take().ok_or("failed to open stdin")?; 138 | stdin.write_all(&stdin_bytes.into_bytes())?; 139 | drop(stdin); // Close stdin to signal we're done writing 140 | 141 | let output = cmd.wait_with_output()?; 142 | let stdout_str = String::from_utf8(output.stdout)?; 143 | 144 | let ret = stdout_str 145 | .split("\n") 146 | .filter(|l| l.len() > 1) 147 | .map(|l| { 148 | let mut nums: Vec<_> = l.trim().split(" ").collect(); 149 | if nums.len() == Pos::N_COLS + 1 { 150 | // Remove the first number which is the move sequence. 151 | // If there is no first move number, we're playing the starting position. 152 | nums.remove(0); 153 | } 154 | 155 | nums.iter() 156 | .map(|num| { 157 | num.parse::() 158 | .expect(format!("failed to parse stdout: '{}'", num).as_str()) 159 | }) 160 | .collect() 161 | }) 162 | .collect(); 163 | Ok(ret) 164 | } 165 | } 166 | 167 | /// Solution from the solver. Each index represents a column. Positive values indicate that the 168 | /// current player will win if they play in that column, negative values indicate that the current 169 | /// player will lose if they play in that column. The magnitude of the value indicates the number 170 | /// of tokens remaining at the end of the game for the current player. 171 | #[derive(Debug, Serialize, Deserialize, Clone)] 172 | struct Solution([i16; Pos::N_COLS]); 173 | 174 | impl FromIterator for Solution { 175 | fn from_iter>(iter: T) -> Self { 176 | let arr: [i16; Pos::N_COLS] = iter.into_iter().collect::>().try_into().expect(""); 177 | Solution(arr) 178 | } 179 | } 180 | 181 | impl From<[i16; Pos::N_COLS]> for Solution { 182 | fn from(arr: [i16; Pos::N_COLS]) -> Self { 183 | Solution(arr) 184 | } 185 | } 186 | 187 | impl ops::Neg for Solution { 188 | type Output = Self; 189 | 190 | fn neg(self) -> Self::Output { 191 | self.0.iter().map(|&x| -x).collect() 192 | } 193 | } 194 | 195 | impl Solution { 196 | /// Given a [Pos] and a [Policy], score the policy relative to this solution. Only considers the 197 | /// highest probability move in the policy. 198 | /// Selecting the best move according to this solution will score 1.0. 199 | /// Selecting a winning move (but not best) will score 0.5. 200 | /// Selecting a losing move will score 0.0. 
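A worked restatement of the scoring rule described above, as a standalone sketch with a made-up solution vector (the crate's own private `score_policy` below is the authoritative version; tie-breaking of equal-probability moves may differ):

```rust
// 1.0 if the policy's argmax is one of the solver's best moves,
// 0.5 if it is merely a winning move, 0.0 otherwise.
fn score(solution: [i16; 7], policy: [f32; 7]) -> f32 {
    let best = *solution.iter().max().unwrap();
    let argmax = policy
        .iter()
        .enumerate()
        .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
        .map(|(i, _)| i)
        .unwrap();
    if solution[argmax] == best {
        1.0
    } else if solution[argmax] > 0 {
        0.5
    } else {
        0.0
    }
}

fn main() {
    // Made-up solution: columns 1 and 3 are best (value 3), column 2 also wins.
    let solution = [-2, 3, 1, 3, -1, 0, -4];
    assert_eq!(score(solution, [0.0, 0.9, 0.1, 0.0, 0.0, 0.0, 0.0]), 1.0);
    assert_eq!(score(solution, [0.0, 0.1, 0.9, 0.0, 0.0, 0.0, 0.0]), 0.5);
    assert_eq!(score(solution, [0.9, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0]), 0.0);
}
```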
201 | fn score_policy(&self, policy: &Policy) -> f32 { 202 | let &sol_max = self.0.iter().max().unwrap(); 203 | let best_moves = self 204 | .0 205 | .iter() 206 | .enumerate() 207 | .filter(|(_, &x)| x == sol_max) 208 | .map(|(i, _)| i) 209 | .collect::>(); 210 | let winning_moves = self 211 | .0 212 | .iter() 213 | .enumerate() 214 | .filter(|(_, &x)| x > 0) 215 | .map(|(i, _)| i) 216 | .collect::>(); 217 | 218 | let policy_max = policy.iter().map(|&p| OrdF32(p)).max().unwrap().0; 219 | let selected_move = policy.iter().position(|&p| p == policy_max).unwrap(); 220 | 221 | if best_moves.contains(&selected_move) { 222 | 1.0 223 | } else if winning_moves.contains(&selected_move) { 224 | 0.5 225 | } else { 226 | 0.0 227 | } 228 | } 229 | } 230 | 231 | #[cfg(test)] 232 | mod tests { 233 | use std::path::Path; 234 | 235 | use proptest::prelude::*; 236 | 237 | use crate::c4r::{tests::random_pos, Pos}; 238 | 239 | use super::*; 240 | 241 | // TODO: Dynamically pull/compile this solver in CI 242 | const PATH_TO_SOLVER: &str = "/home/advait/connect4/c4solver"; 243 | const PATH_TO_BOOK: &str = "/home/advait/connect4/7x6.book"; 244 | 245 | fn paths_exist() -> bool { 246 | Path::new(PATH_TO_SOLVER).exists() && Path::new(PATH_TO_BOOK).exists() 247 | } 248 | 249 | fn test_solver() -> Solver { 250 | Solver { 251 | path_to_solver: PATH_TO_SOLVER.to_string(), 252 | path_to_book: PATH_TO_BOOK.to_string(), 253 | } 254 | } 255 | 256 | fn test_solve(pos: Pos) -> Solution { 257 | test_solver() 258 | .solve(&vec![pos]) 259 | .unwrap() 260 | .into_iter() 261 | .next() 262 | .unwrap() 263 | } 264 | 265 | fn one_hot(idx: usize) -> Policy { 266 | let mut ret = Policy::default(); 267 | ret[idx] = 1.0; 268 | ret 269 | } 270 | 271 | #[test] 272 | fn default_pos() { 273 | if !paths_exist() { 274 | eprintln!("Warning: Skipping Solver tests because solver paths do not exist."); 275 | return; 276 | } 277 | let solution = test_solve(Pos::default()); 278 | let expected_scores = &[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]; 279 | for (i, &score) in expected_scores.iter().enumerate() { 280 | assert_eq!(solution.score_policy(&one_hot(i)), score); 281 | } 282 | } 283 | 284 | #[test] 285 | fn p0_winning_pos() { 286 | if !paths_exist() { 287 | return; 288 | } 289 | let solution = test_solve(Pos::from_moves(&[2, 2, 3, 3])); 290 | let expected_scores = &[0.0, 1.0, 0.5, 0.5, 1.0, 0.0, 0.0]; 291 | for (i, &score) in expected_scores.iter().enumerate() { 292 | assert_eq!(solution.score_policy(&one_hot(i)), score); 293 | } 294 | } 295 | 296 | #[test] 297 | fn p1_winning_pos() { 298 | if !paths_exist() { 299 | return; 300 | } 301 | let solution = test_solve(Pos::from_moves(&[0])); 302 | let expected_scores = &[0.0, 1.0, 0.5, 1.0, 0.0, 0.5, 0.0]; 303 | for (i, &score) in expected_scores.iter().enumerate() { 304 | assert_eq!(solution.score_policy(&one_hot(i)), score); 305 | } 306 | } 307 | 308 | proptest! 
{ 309 | #[test] 310 | fn random_solutions( 311 | pos in random_pos().prop_filter( 312 | "non-terminal positions", 313 | |p| p.is_terminal_state().is_none() 314 | ) 315 | ) { 316 | if !paths_exist() { 317 | return Ok(()); 318 | } 319 | let _solution = test_solve(pos); 320 | } 321 | } 322 | } 323 | -------------------------------------------------------------------------------- /rust/src/tui.rs: -------------------------------------------------------------------------------- 1 | use std::io; 2 | use std::io::{stdout, Stdout}; 3 | use std::time::Duration; 4 | 5 | use ratatui::layout::{Constraint, Layout}; 6 | use ratatui::style::{Color, Style}; 7 | use ratatui::widgets::{Bar, BarChart, BarGroup, Padding}; 8 | use ratatui::{ 9 | backend::CrosstermBackend, 10 | buffer::Buffer, 11 | crossterm::{ 12 | event::{self, Event, KeyCode, KeyEvent, KeyEventKind}, 13 | execute, 14 | terminal::{disable_raw_mode, enable_raw_mode, EnterAlternateScreen, LeaveAlternateScreen}, 15 | }, 16 | layout::{Alignment, Rect}, 17 | style::Stylize, 18 | symbols::border, 19 | text::Line, 20 | widgets::{block::Title, Block, Paragraph, Widget}, 21 | Terminal, 22 | }; 23 | 24 | use crate::c4r::{CellValue, Pos, TerminalState}; 25 | use crate::interactive_play::{InteractivePlay, Snapshot}; 26 | use crate::types::{EvalPosT, QValue}; 27 | 28 | /// A type alias for the terminal type used in this application 29 | pub type Tui = Terminal>; 30 | 31 | /// Initialize the terminal 32 | pub fn init() -> io::Result { 33 | execute!(stdout(), EnterAlternateScreen)?; 34 | enable_raw_mode()?; 35 | Terminal::new(CrosstermBackend::new(stdout())) 36 | } 37 | 38 | /// Restore the terminal to its original state 39 | pub fn restore() -> io::Result<()> { 40 | execute!(stdout(), LeaveAlternateScreen)?; 41 | disable_raw_mode()?; 42 | Ok(()) 43 | } 44 | 45 | #[derive(Debug)] 46 | pub struct App { 47 | game: InteractivePlay, 48 | exit: bool, 49 | } 50 | 51 | impl App { 52 | pub fn new( 53 | eval_pos: E, 54 | max_mcts_iterations: usize, 55 | c_exploration: f32, 56 | c_ply_penalty: f32, 57 | ) -> Self { 58 | Self { 59 | game: InteractivePlay::new(eval_pos, max_mcts_iterations, c_exploration, c_ply_penalty), 60 | exit: false, 61 | } 62 | } 63 | 64 | /// runs the application's main loop until the user quits 65 | pub fn run(&mut self, terminal: &mut Tui) -> io::Result<()> { 66 | while !self.exit { 67 | terminal.draw(|frame| { 68 | let snapshot = self.game.snapshot(); 69 | draw_app(&snapshot, frame.size(), frame.buffer_mut()); 70 | })?; 71 | 72 | if event::poll(Duration::from_millis(100))? { 73 | self.handle_events()?; 74 | } 75 | } 76 | Ok(()) 77 | } 78 | 79 | /// updates the application's state based on user input 80 | fn handle_events(&mut self) -> io::Result<()> { 81 | match event::read()? { 82 | // it's important to check that the event is a key press event as 83 | // crossterm also emits key release and repeat events on Windows. 
84 | Event::Key(key_event) if key_event.kind == KeyEventKind::Press => { 85 | self.handle_key_event(key_event) 86 | } 87 | _ => {} 88 | }; 89 | Ok(()) 90 | } 91 | 92 | fn handle_key_event(&mut self, key_event: KeyEvent) { 93 | match key_event.code { 94 | KeyCode::Char('b') => self.game.make_random_move(0.0), 95 | KeyCode::Char('m') => self.game.increase_mcts_iters(100), 96 | KeyCode::Char('n') => self.game.reset_game(), 97 | KeyCode::Char('r') => self.game.make_random_move(1.0), 98 | KeyCode::Char('q') => self.exit = true, 99 | KeyCode::Char('t') => self.game.increase_mcts_iters(1), 100 | KeyCode::Char('u') => self.game.undo_move(), 101 | KeyCode::Char('1') => self.game.make_move(0), 102 | KeyCode::Char('2') => self.game.make_move(1), 103 | KeyCode::Char('3') => self.game.make_move(2), 104 | KeyCode::Char('4') => self.game.make_move(3), 105 | KeyCode::Char('5') => self.game.make_move(4), 106 | KeyCode::Char('6') => self.game.make_move(5), 107 | KeyCode::Char('7') => self.game.make_move(6), 108 | _ => {} 109 | }; 110 | } 111 | } 112 | 113 | fn draw_app(snapshot: &Snapshot, rect: Rect, buf: &mut Buffer) { 114 | let title = Title::from(" c4a0 - Connect Four AlphaZero ".bold()); 115 | let outer_block = Block::bordered() 116 | .title(title.alignment(Alignment::Center)) 117 | .padding(Padding::horizontal(1)) 118 | .border_set(border::THICK); 119 | let inner = outer_block.inner(rect); 120 | outer_block.render(rect, buf); 121 | 122 | let layout = Layout::vertical([ 123 | Constraint::Length(24), // Game, Evals 124 | Constraint::Fill(1), // Policy 125 | Constraint::Length(11), // Instructions 126 | ]) 127 | .spacing(1) 128 | .split(inner); 129 | 130 | draw_game_and_evals(&snapshot, layout[0], buf); 131 | draw_policy(&snapshot, layout[1], buf); 132 | draw_instructions(layout[2], buf); 133 | } 134 | 135 | fn draw_game_and_evals(snapshot: &Snapshot, rect: Rect, buf: &mut Buffer) { 136 | let isp0 = snapshot.pos.ply() % 2 == 0; 137 | let to_play = match snapshot.pos.is_terminal_state() { 138 | Some(TerminalState::PlayerWin) if isp0 => vec![" Blue".blue(), " won".into()], 139 | Some(TerminalState::PlayerWin) => vec![" Red".red(), " won".into()], 140 | Some(TerminalState::OpponentWin) if isp0 => vec![" Blue".blue(), " won".into()], 141 | Some(TerminalState::OpponentWin) => vec![" Red".red(), " won".into()], 142 | Some(TerminalState::Draw) => vec![" Draw".gray()], 143 | None if isp0 => vec![" Red".red(), " to play".into()], 144 | None => vec![" Blue".blue(), " to play".into()], 145 | }; 146 | 147 | let block = Block::bordered() 148 | .title(" Game") 149 | .title_bottom(to_play) 150 | .padding(Padding::uniform(1)); 151 | let inner = block.inner(rect); 152 | block.render(rect, buf); 153 | 154 | let layout = Layout::horizontal([Constraint::Length(40), Constraint::Length(18)]) 155 | .spacing(1) 156 | .split(inner); 157 | 158 | draw_game_grid(&snapshot.pos, layout[0], buf); 159 | draw_evals(snapshot.q_penalty, snapshot.q_no_penalty, layout[1], buf); 160 | } 161 | 162 | fn draw_game_grid(pos: &Pos, rect: Rect, buf: &mut Buffer) { 163 | let cell_width = 5; 164 | let cell_height = 3; 165 | for row in 0..Pos::N_ROWS { 166 | for col in 0..Pos::N_COLS { 167 | let cell_rect = Rect::new( 168 | rect.left() + (col as u16 * cell_width), 169 | rect.top() + (row as u16 * cell_height), 170 | cell_width, 171 | cell_height, 172 | ) 173 | .intersection(rect); 174 | draw_game_cell( 175 | pos.get(Pos::N_ROWS - row - 1, col), 176 | row, 177 | col, 178 | cell_rect, 179 | buf, 180 | ); 181 | } 182 | } 183 | 184 | // Labels below grid 
185 | for col in 0..Pos::N_COLS { 186 | let label_rect = Rect::new( 187 | rect.left() + (col as u16 * cell_width), 188 | rect.top() + (Pos::N_ROWS as u16 * cell_height) + 1, 189 | cell_width, 190 | cell_height, 191 | ); 192 | Paragraph::new(format!("{}", col + 1)) 193 | .centered() 194 | .bold() 195 | .render(label_rect, buf); 196 | } 197 | } 198 | 199 | fn draw_game_cell(value: Option, row: usize, col: usize, rect: Rect, buf: &mut Buffer) { 200 | let bg_style = match value { 201 | Some(CellValue::Player) => Style::default().bg(Color::Red), 202 | Some(CellValue::Opponent) => Style::default().bg(Color::Blue), 203 | None => Style::default(), 204 | }; 205 | for y in (rect.top() + 1)..rect.bottom() { 206 | for x in (rect.left() + 1)..rect.right() { 207 | buf.get_mut(x, y).set_style(bg_style); 208 | } 209 | } 210 | 211 | let border_style = Style::default().fg(Color::White); 212 | let mut set_border = |x, y, ch| { 213 | buf.get_mut(x, y).set_char(ch).set_style(border_style); 214 | }; 215 | 216 | // Draw horizontal top borders 217 | for x in rect.left()..rect.right() { 218 | set_border(x, rect.top(), '─'); 219 | } 220 | 221 | // Draw vertical left borders 222 | for y in rect.top()..rect.bottom() { 223 | set_border(rect.left(), y, '│'); 224 | } 225 | 226 | // Top left corners 227 | set_border( 228 | rect.left(), 229 | rect.top(), 230 | match (row, col) { 231 | (0, 0) => '┌', 232 | (0, _c) => '┬', 233 | (_r, 0) => '├', 234 | _ => '┼', 235 | }, 236 | ); 237 | 238 | if row == Pos::N_ROWS - 1 { 239 | // Draw horizontal bottom borders 240 | for x in rect.left()..rect.right() { 241 | set_border(x, rect.bottom(), '─'); 242 | } 243 | 244 | // Bottom left corners 245 | set_border( 246 | rect.left(), 247 | rect.bottom(), 248 | match col { 249 | 0 => '└', 250 | _ => '┴', 251 | }, 252 | ) 253 | } 254 | 255 | if col == Pos::N_COLS - 1 { 256 | // Draw vertical right borders 257 | for y in rect.top()..rect.bottom() { 258 | set_border(rect.right(), y, '│'); 259 | } 260 | 261 | // Top right corners 262 | set_border( 263 | rect.right(), 264 | rect.top(), 265 | match row { 266 | 0 => '┐', 267 | _ => '┤', 268 | }, 269 | ); 270 | 271 | // Single bottom right corner 272 | if row == Pos::N_ROWS - 1 { 273 | set_border(rect.right(), rect.bottom(), '┘'); 274 | } 275 | } 276 | } 277 | 278 | fn draw_evals(q_penalty: QValue, q_no_penalty: QValue, rect: Rect, buf: &mut Buffer) { 279 | let value_max = 1000u64; 280 | let q_penalty_u64 = ((q_penalty + 1.0) / 2.0 * (value_max as f32)) as u64; 281 | let q_no_penalty_u64 = ((q_no_penalty + 1.0) / 2.0 * (value_max as f32)) as u64; 282 | let bars = vec![ 283 | Bar::default() 284 | .label("Eval".into()) 285 | .value(q_penalty_u64) 286 | .text_value(format!("{:.2}", q_penalty).into()) 287 | .style(if q_penalty >= 0.0 { 288 | Style::new().red() 289 | } else { 290 | Style::new().blue() 291 | }), 292 | Bar::default() 293 | .label("Win %".into()) 294 | .value(q_no_penalty_u64) 295 | .text_value(format!("{:.0}%", q_no_penalty * 100.).into()) 296 | .style(if q_no_penalty >= 0.0 { 297 | Style::new().red() 298 | } else { 299 | Style::new().blue() 300 | }), 301 | ]; 302 | BarChart::default() 303 | .data(BarGroup::default().bars(&bars)) 304 | .bar_width((rect.width - 4) / 2 - 1) 305 | .bar_gap(2) 306 | .max(value_max) 307 | .value_style(Style::new().green().bold()) 308 | .label_style(Style::new().white()) 309 | .render(rect, buf); 310 | } 311 | 312 | fn draw_policy(snapshot: &Snapshot, rect: Rect, buf: &mut Buffer) { 313 | let mcts_status = Line::from(vec![ 314 | " ".into(), 315 | if 
snapshot.bg_thread_running { 316 | "MCTS running: ".green() 317 | } else { 318 | "MCTS stopped: ".red() 319 | }, 320 | snapshot.n_mcts_iterations.to_string().bold(), 321 | "/".into(), 322 | snapshot.max_mcts_iterations.to_string().bold(), 323 | ]); 324 | 325 | let policy_max = 1000u64; 326 | let bars = snapshot 327 | .policy 328 | .iter() 329 | .enumerate() 330 | .map(|(i, p)| { 331 | Bar::default() 332 | .label(format!("{}", i + 1).into()) 333 | .value((p * (policy_max as f32)) as u64) 334 | .text_value(format!("{:.2}", p)[1..].into()) 335 | }) 336 | .collect::>(); 337 | 338 | BarChart::default() 339 | .data(BarGroup::default().bars(&bars)) 340 | .bar_width(5) 341 | .bar_gap(2) 342 | .max(policy_max) 343 | .bar_style(Style::new().yellow()) 344 | .value_style(Style::new().green().bold()) 345 | .label_style(Style::new().white()) 346 | .block( 347 | Block::bordered() 348 | .title(" Policy") 349 | .title_bottom(mcts_status) 350 | .padding(Padding::uniform(1)), 351 | ) 352 | .render(rect, buf); 353 | } 354 | 355 | fn draw_instructions(rect: Rect, buf: &mut Buffer) { 356 | let instruction_text = vec![ 357 | Line::from(vec!["<1-7>".blue().bold(), " Play Move".into()]), 358 | Line::from(vec!["".blue().bold(), " Play the best move".into()]), 359 | Line::from(vec!["".blue().bold(), " Play a random move".into()]), 360 | Line::from(vec!["".blue().bold(), " More MCTS iterations".into()]), 361 | Line::from(vec!["".blue().bold(), " Undo last move".into()]), 362 | Line::from(vec!["".blue().bold(), " New game".into()]), 363 | Line::from(vec!["".blue().bold(), " Quit".into()]), 364 | ]; 365 | Paragraph::new(instruction_text) 366 | .block( 367 | Block::bordered() 368 | .title(" Instructions") 369 | .padding(Padding::uniform(1)), 370 | ) 371 | .render(rect, buf); 372 | } 373 | -------------------------------------------------------------------------------- /rust/src/types.rs: -------------------------------------------------------------------------------- 1 | use core::panic; 2 | use std::array; 3 | 4 | use numpy::{ 5 | ndarray::{Array0, Array3}, 6 | PyArray0, PyArray1, PyArray3, 7 | }; 8 | use pyo3::prelude::*; 9 | use serde::{Deserialize, Serialize}; 10 | 11 | use crate::c4r::Pos; 12 | 13 | /// Probabilities for how lucrative each column is. 14 | pub type Policy = [f32; Pos::N_COLS]; 15 | 16 | /// The lucrativeness value of a given position. This is the objective we are trying to maximize. 17 | pub type QValue = f32; 18 | 19 | /// ID of the Model's NN. 20 | pub type ModelID = u64; 21 | 22 | /// Evaluate a batch of positions with an NN forward pass. 23 | /// The ordering of the results corresponds to the ordering of the input positions. 24 | pub trait EvalPosT { 25 | fn eval_pos(&self, model_id: ModelID, pos: Vec) -> Vec; 26 | } 27 | 28 | /// The returned output from the forward pass of the NN. 29 | #[derive(Debug, Clone)] 30 | pub struct EvalPosResult { 31 | pub policy: Policy, // Probability distribution over moves from the position. 32 | pub q_penalty: QValue, // Lucrativeness [-1, 1] of the position with ply penalty. 33 | pub q_no_penalty: QValue, // Lucrativeness [-1, 1] of the position without ply penalty. 34 | } 35 | 36 | /// Metadata about a game. 
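To make the `EvalPosT` contract above concrete, here is a hypothetical in-crate sketch (not part of the real code): a trivial evaluator that ignores the position and returns a fixed, center-biased policy with neutral value estimates. It assumes `EvalPosT`, `ModelID`, `Pos`, `Policy`, and `EvalPosResult` from this module are in scope; a baseline like this could stand in when no trained network is available.

```rust
struct CenterBiasEval;

impl EvalPosT for CenterBiasEval {
    fn eval_pos(&self, _model_id: ModelID, pos: Vec<Pos>) -> Vec<EvalPosResult> {
        // Fixed policy favoring the center column; values are illustrative.
        let policy: Policy = [0.05, 0.10, 0.20, 0.30, 0.20, 0.10, 0.05];
        pos.into_iter()
            .map(|_| EvalPosResult {
                policy,
                q_penalty: 0.0,
                q_no_penalty: 0.0,
            })
            .collect()
    }
}
```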
37 | #[derive(Debug, Clone, Default, Serialize, Deserialize)]
38 | #[pyclass]
39 | pub struct GameMetadata {
40 |     #[pyo3(get)]
41 |     pub game_id: u64,
42 | 
43 |     #[pyo3(get)]
44 |     pub player0_id: ModelID,
45 | 
46 |     #[pyo3(get)]
47 |     pub player1_id: ModelID,
48 | }
49 | 
50 | #[pymethods]
51 | impl GameMetadata {
52 |     #[new]
53 |     fn new(game_id: u64, player0_id: ModelID, player1_id: ModelID) -> Self {
54 |         GameMetadata {
55 |             game_id,
56 |             player0_id,
57 |             player1_id,
58 |         }
59 |     }
60 | }
61 | 
62 | /// The finished result of a game.
63 | #[derive(Debug, Clone, Serialize, Deserialize)]
64 | #[pyclass]
65 | pub struct GameResult {
66 |     #[pyo3(get)]
67 |     pub metadata: GameMetadata,
68 | 
69 |     #[pyo3(get)]
70 |     pub samples: Vec<Sample>,
71 | }
72 | 
73 | #[pymethods]
74 | impl GameResult {
75 |     /// Returns the score of the game from the perspective of Player 0.
76 |     /// If Player 0 wins, 1.0. If Player 0 loses, 0.0. If it's a draw, 0.5.
77 |     fn player0_score(&self) -> f32 {
78 |         for sample in self.samples.iter() {
79 |             if let Some(terminal) = sample.pos.is_terminal_state() {
80 |                 let score = match terminal {
81 |                     crate::c4r::TerminalState::PlayerWin => 1.0,
82 |                     crate::c4r::TerminalState::OpponentWin => 0.0,
83 |                     crate::c4r::TerminalState::Draw => 0.5,
84 |                 };
85 | 
86 |                 // When we play positions, we flip the pieces so that the "player to play" is
87 |                 // active. This means the terminal state is from the perspective of the player
88 |                 // who is about to play. For odd ply positions, the player to play is player 1
89 |                 // so we must flip the score.
90 |                 return if sample.pos.ply() % 2 == 1 {
91 |                     1.0 - score
92 |                 } else {
93 |                     score
94 |                 };
95 |             }
96 |         }
97 | 
98 |         panic!("player0_score called on an unfinished game");
99 |     }
100 | }
101 | 
102 | /// A training sample generated via self-play.
103 | #[derive(Debug, Clone, Serialize, Deserialize)]
104 | #[pyclass]
105 | pub struct Sample {
106 |     pub pos: Pos,
107 |     pub policy: Policy,
108 |     pub q_penalty: QValue,
109 |     pub q_no_penalty: QValue,
110 | }
111 | 
112 | #[pymethods]
113 | impl Sample {
114 |     /// Returns a new sample that is flipped horizontally.
115 |     pub fn flip_h(&self) -> Sample {
116 |         Sample {
117 |             pos: self.pos.flip_h(),
118 |             policy: array::from_fn(|col| self.policy[Pos::N_COLS - 1 - col]),
119 |             q_penalty: self.q_penalty,
120 |             q_no_penalty: self.q_no_penalty,
121 |         }
122 |     }
123 | 
124 |     /// [numpy] representation of the sample.
125 |     pub fn to_numpy<'py>(
126 |         &self,
127 |         py: Python<'py>,
128 |     ) -> (
129 |         Bound<'py, PyArray3<f32>>,
130 |         Bound<'py, PyArray1<f32>>,
131 |         Bound<'py, PyArray0<f32>>,
132 |         Bound<'py, PyArray0<f32>>,
133 |     ) {
134 |         let mut pos_buffer = vec![0.0; Pos::BUF_LEN];
135 |         self.pos.write_numpy_buffer(&mut pos_buffer);
136 |         let pos =
137 |             Array3::from_shape_vec([Pos::BUF_N_CHANNELS, Pos::N_ROWS, Pos::N_COLS], pos_buffer)
138 |                 .unwrap();
139 |         let pos = PyArray3::from_array_bound(py, &pos);
140 |         let policy = PyArray1::from_slice_bound(py, &self.policy);
141 |         let q_penalty = Array0::from_elem([] /* shape */, self.q_penalty);
142 |         let q_penalty = PyArray0::from_array_bound(py, &q_penalty);
143 |         let q_no_penalty = Array0::from_elem([] /* shape */, self.q_no_penalty);
144 |         let q_no_penalty = PyArray0::from_array_bound(py, &q_no_penalty);
145 | 
146 |         (pos, policy, q_penalty, q_no_penalty)
147 |     }
148 | 
149 |     /// String representation of the position.
150 | pub fn pos_str(&self) -> String { 151 | self.pos.to_string() 152 | } 153 | } 154 | 155 | pub fn policy_from_iter>(iter: I) -> Policy { 156 | let mut policy = [0.0; Pos::N_COLS]; 157 | for (i, p) in iter.into_iter().enumerate() { 158 | policy[i] = p; 159 | } 160 | policy 161 | } 162 | -------------------------------------------------------------------------------- /rust/src/utils.rs: -------------------------------------------------------------------------------- 1 | use std::cmp::Ordering; 2 | 3 | /// A wrapper around f32 that implements Ord. 4 | #[derive(Debug, Clone, Copy, PartialEq, PartialOrd)] 5 | pub struct OrdF32(pub f32); 6 | 7 | impl Eq for OrdF32 {} 8 | 9 | /// Panics if the f32 is NaN. 10 | impl Ord for OrdF32 { 11 | fn cmp(&self, other: &Self) -> Ordering { 12 | self.partial_cmp(other).unwrap() 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /src/c4a0/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/advait/c4a0/49cb6584fd4cc31e68f056f210faff9bac8823cf/src/c4a0/__init__.py -------------------------------------------------------------------------------- /src/c4a0/explore.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 3, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "from dataclasses import dataclass\n", 20 | "import os\n", 21 | "from pathlib import Path\n", 22 | "import sys\n", 23 | "from typing import List\n", 24 | "\n", 25 | "from loguru import logger\n", 26 | "import matplotlib.pyplot as plt\n", 27 | "\n", 28 | "# Add src/ to path\n", 29 | "root_dir = (Path() / \"..\" / \"..\").resolve()\n", 30 | "training_dir = str(root_dir / \"training\")\n", 31 | "cache_path = str(root_dir / \"solutions.db\")\n", 32 | "solver_path = \"/home/advait/connect4/c4solver\"\n", 33 | "book_path = \"/home/advait/connect4/7x6.book\"\n", 34 | "sys.path.append(str(root_dir / \"src\"))\n", 35 | "\n", 36 | "# Enable rust logging\n", 37 | "os.environ[\"RUST_LOG\"] = \"DEBUG\"" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 4, 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "# Import must happen after modifying python path\n", 47 | "from c4a0.training import TrainingGen" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 5, 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "@dataclass\n", 57 | "class GenStats:\n", 58 | " gen_numbers: List[int]\n", 59 | " gens: List[TrainingGen]\n", 60 | " set_sizes: List[int]\n", 61 | " policy_scores: List[float]\n", 62 | "\n", 63 | "def gen_stats(training_dir: str) -> GenStats:\n", 64 | " gens = TrainingGen.load_all(training_dir)\n", 65 | " gen_numbers = list(reversed(range(len(gens))))\n", 66 | " gens.pop(-1) # Zeroth gen is untrained\n", 67 | " gen_numbers.pop(-1)\n", 68 | " logger.info(f\"Computing stats for: {training_dir}\")\n", 69 | " set_sizes = [\n", 70 | " gen.get_games(str(training_dir)).unique_positions() # type: ignore\n", 71 | " for gen in gens\n", 72 | " ]\n", 73 | " policy_scores = [\n", 74 | " gen.get_games(training_dir).score_policies(solver_path, book_path, cache_path) # type: ignore\n", 75 | " for gen in gens\n", 76 | " ]\n", 77 | " 
logger.info(\"Done\")\n", 78 | " return GenStats(gen_numbers, gens, set_sizes, policy_scores)\n" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 6, 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "def plot_single_run(s: GenStats):\n", 88 | " fig, ax1 = plt.subplots(figsize=(12, 8))\n", 89 | "\n", 90 | " # Bar plot for set_sizes\n", 91 | " ax1.bar(s.gen_numbers, s.set_sizes, color='lavender', edgecolor='black')\n", 92 | " ax1.set_ylabel('Unique training positions in generation', fontsize=14)\n", 93 | " ax1.tick_params(axis='y')\n", 94 | " ax1.set_xlabel('Generation Number', fontsize=14)\n", 95 | " ax1.set_xticks(s.gen_numbers)\n", 96 | " ax1.set_xticklabels(s.gen_numbers, ha='right', fontsize=12)\n", 97 | "\n", 98 | " # Create a second y-axis for policy_scores\n", 99 | " ax2 = ax1.twinx()\n", 100 | " ax2.plot(s.gen_numbers, s.policy_scores, color='forestgreen', marker='o', linestyle='-', linewidth=2, markersize=6)\n", 101 | " ax2.set_ylabel('% Perfect Moves')\n", 102 | "\n", 103 | " # Add data labels on line plot\n", 104 | " for i, txt in enumerate(s.policy_scores):\n", 105 | " ax2.annotate(f'{txt:.2f}', (s.gen_numbers[i], s.policy_scores[i]), textcoords=\"offset points\", xytext=(0,10), ha='center', fontsize=10)\n", 106 | "\n", 107 | " plt.title('Generation Performance', fontsize=20)\n", 108 | " fig.tight_layout()\n", 109 | " plt.show()\n", 110 | " return plt\n", 111 | "\n", 112 | "def plot_multiple_runs(stats_list: List[GenStats]):\n", 113 | " fig, ax = plt.subplots(figsize=(12, 8))\n", 114 | "\n", 115 | " for s in stats_list:\n", 116 | " ax.plot(s.gen_numbers, s.policy_scores, marker='o', linestyle='-', linewidth=2, markersize=6, label=f'Run {stats_list.index(s) + 1}')\n", 117 | "\n", 118 | " ax.set_ylabel('% Perfect Moves', fontsize=14)\n", 119 | " ax.set_xlabel('Generation Number', fontsize=14)\n", 120 | " ax.set_xticks(stats_list[0].gen_numbers)\n", 121 | " ax.set_xticklabels(stats_list[0].gen_numbers, ha='right', fontsize=12)\n", 122 | " ax.legend(title='Runs', fontsize=12)\n", 123 | "\n", 124 | " plt.title('Generation Performance Across Multiple Runs', fontsize=20)\n", 125 | " fig.tight_layout()\n", 126 | " plt.show()\n", 127 | " return plt\n", 128 | "\n", 129 | "def plot_single_dir(training_dir: str):\n", 130 | " s = gen_stats(training_dir)\n", 131 | " return plot_single_run(s)\n", 132 | "\n", 133 | "def plot_multiple_dirs(dirs: List[str]):\n", 134 | " stats_list = [gen_stats(d) for d in dirs]\n", 135 | " return plot_multiple_runs(stats_list)" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": null, 141 | "metadata": {}, 142 | "outputs": [], 143 | "source": [ 144 | "plot_single_dir(str(root_dir / \"training\"))" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": null, 150 | "metadata": {}, 151 | "outputs": [], 152 | "source": [ 153 | "plot_multiple_dirs([str(root_dir / \"training-sweeps\" / f\"trial_{i}\") for i in range(20)])" 154 | ] 155 | } 156 | ], 157 | "metadata": { 158 | "kernelspec": { 159 | "display_name": ".venv", 160 | "language": "python", 161 | "name": "python3" 162 | }, 163 | "language_info": { 164 | "codemirror_mode": { 165 | "name": "ipython", 166 | "version": 3 167 | }, 168 | "file_extension": ".py", 169 | "mimetype": "text/x-python", 170 | "name": "python", 171 | "nbconvert_exporter": "python", 172 | "pygments_lexer": "ipython3", 173 | "version": "3.11.6" 174 | } 175 | }, 176 | "nbformat": 4, 177 | "nbformat_minor": 2 178 | } 179 | 
-------------------------------------------------------------------------------- /src/c4a0/main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from pathlib import Path 4 | import sys 5 | 6 | from typing import List, Optional 7 | import warnings 8 | 9 | from loguru import logger 10 | import optuna 11 | import torch 12 | import typer 13 | 14 | # Ensure that the parent directory of this file exists on Python path 15 | parent_dir = Path(__file__).resolve().parent.parent 16 | if str(parent_dir) not in sys.path: 17 | sys.path.insert(0, str(parent_dir)) 18 | 19 | from c4a0.nn import ModelConfig # noqa: E402 20 | from c4a0.sweep import perform_hparam_sweep # noqa: E402 21 | from c4a0.tournament import ModelID, RandomPlayer, UniformPlayer # noqa: E402 22 | from c4a0.training import ( # noqa: E402 23 | SolverConfig, 24 | TrainingGen, 25 | parse_lr_schedule, 26 | training_loop, 27 | ) 28 | from c4a0.utils import get_torch_device # noqa: E402 29 | 30 | import c4a0_rust # noqa: E402 31 | 32 | app = typer.Typer() 33 | 34 | 35 | @app.command() 36 | def train( 37 | base_dir: str = "training", 38 | device: str = str(get_torch_device()), 39 | # These parameters were chosen based on the results of the nn_sweep and mcts_sweep 40 | n_self_play_games: int = 1700, 41 | n_mcts_iterations: int = 1400, 42 | c_exploration: float = 6.6, 43 | c_ply_penalty: float = 0.01, 44 | self_play_batch_size: int = 2000, 45 | training_batch_size: int = 2000, 46 | n_residual_blocks: int = 1, 47 | conv_filter_size: int = 32, 48 | n_policy_layers: int = 4, 49 | n_value_layers: int = 2, 50 | lr_schedule: List[float] = [0, 2e-3, 10, 8e-4], 51 | l2_reg: float = 4e-4, 52 | max_gens: Optional[int] = None, 53 | solver_path: Optional[str] = None, 54 | book_path: Optional[str] = None, 55 | solutions_path: str = "./solutions.db", 56 | ): 57 | """Trains a model via self-play.""" 58 | 59 | model_config = ModelConfig( 60 | n_residual_blocks=n_residual_blocks, 61 | conv_filter_size=conv_filter_size, 62 | n_policy_layers=n_policy_layers, 63 | n_value_layers=n_value_layers, 64 | lr_schedule=parse_lr_schedule(lr_schedule), 65 | l2_reg=l2_reg, 66 | ) 67 | 68 | if solver_path and book_path: 69 | logger.info("Using solver") 70 | solver_config = SolverConfig( 71 | solver_path=solver_path, 72 | book_path=book_path, 73 | solutions_path=solutions_path, 74 | ) 75 | else: 76 | logger.info("Solver not provided, skipping solutions") 77 | solver_config = None 78 | 79 | training_loop( 80 | base_dir=base_dir, 81 | device=torch.device(device), 82 | n_self_play_games=n_self_play_games, 83 | n_mcts_iterations=n_mcts_iterations, 84 | c_exploration=c_exploration, 85 | c_ply_penalty=c_ply_penalty, 86 | self_play_batch_size=self_play_batch_size, 87 | training_batch_size=training_batch_size, 88 | model_config=model_config, 89 | max_gens=max_gens, 90 | solver_config=solver_config, 91 | ) 92 | 93 | 94 | @app.command() 95 | def play( 96 | base_dir: str = "training", 97 | max_mcts_iters: int = 1400, 98 | c_exploration: float = 6.6, 99 | c_ply_penalty: float = 0.01, 100 | model: str = "best", 101 | ): 102 | """Play interactive games""" 103 | gen = TrainingGen.load_latest(base_dir) 104 | if model == "best": 105 | nn = gen.get_model(base_dir) 106 | elif model == "random": 107 | nn = RandomPlayer(ModelID(0)) 108 | elif model == "uniform": 109 | nn = UniformPlayer(ModelID(0)) 110 | else: 111 | raise ValueError(f"unrecognized model: {model}") 112 | 113 | c4a0_rust.run_tui( # type: ignore 114 | lambda 
model_id, x: nn.forward_numpy(x), 115 | max_mcts_iters, 116 | c_exploration, 117 | c_ply_penalty, 118 | ) 119 | 120 | 121 | @app.command() 122 | def nn_sweep(base_dir: str = "training"): 123 | """ 124 | Performs a hyperparameter sweep to determine best nn model params based on existing training 125 | data. 126 | """ 127 | perform_hparam_sweep(base_dir) 128 | 129 | 130 | @app.command() 131 | def mcts_sweep( 132 | device: str = str(get_torch_device()), 133 | c_ply_penalty: float = 0.01, 134 | self_play_batch_size: int = 2000, 135 | training_batch_size: int = 2000, 136 | # These NN parameters were chosen based on the results of the nn_sweep 137 | n_residual_blocks: int = 1, 138 | conv_filter_size: int = 32, 139 | n_policy_layers: int = 4, 140 | n_value_layers: int = 2, 141 | lr_schedule: List[float] = [0, 2e-3], 142 | l2_reg: float = 4e-4, 143 | # End NN parameters 144 | base_training_dir: str = "training-sweeps", 145 | optuna_db_path: str = "optuna.db", 146 | n_trials: int = 100, 147 | max_gens_per_trial: int = 10, 148 | solver_path: str = "/home/advait/connect4/c4solver", 149 | book_path: str = "/home/advait/connect4/7x6.book", 150 | solutions_path: str = "./solutions.db", 151 | ): 152 | """ 153 | Performs sweep of MCTS hyperparameters (e.g. n_self_play_games, n_mcts_iterations, 154 | c_exploration) to determine optimal values by performing `n_trials` independent training 155 | runs, each with `max_gens_per_trial` generations, seeking to maximize the solver score. 156 | """ 157 | base_path = Path(base_training_dir) 158 | base_path.mkdir(exist_ok=True) 159 | 160 | model_config = ModelConfig( 161 | n_residual_blocks=n_residual_blocks, 162 | conv_filter_size=conv_filter_size, 163 | n_policy_layers=n_policy_layers, 164 | n_value_layers=n_value_layers, 165 | lr_schedule=parse_lr_schedule(lr_schedule), 166 | l2_reg=l2_reg, 167 | ) 168 | 169 | def objective(trial: optuna.Trial): 170 | trial_path = base_path / f"trial_{trial.number}" 171 | trial_path.mkdir(exist_ok=False) 172 | gen = training_loop( 173 | base_dir=str(trial_path), 174 | device=torch.device(device), 175 | n_self_play_games=trial.suggest_int("n_self_play_games", 1000, 5000), 176 | n_mcts_iterations=trial.suggest_int("n_mcts_iterations", 100, 1500), 177 | c_exploration=trial.suggest_float("c_exploration", 0.5, 12.0), 178 | c_ply_penalty=c_ply_penalty, 179 | self_play_batch_size=self_play_batch_size, 180 | training_batch_size=training_batch_size, 181 | model_config=model_config, 182 | max_gens=max_gens_per_trial, 183 | solver_config=SolverConfig( 184 | solver_path=solver_path, 185 | book_path=book_path, 186 | solutions_path=solutions_path, 187 | ), 188 | ) 189 | logger.info( 190 | "Trial {} completed. 
Solver score: {}", trial.number, gen.solver_score 191 | ) 192 | score = gen.solver_score 193 | assert score is not None 194 | return score 195 | 196 | storage_name = f"sqlite:///{optuna_db_path}" 197 | study = optuna.create_study( 198 | study_name="mcts_sweep", 199 | storage=storage_name, 200 | load_if_exists=True, 201 | direction="maximize", 202 | pruner=optuna.pruners.MedianPruner(), 203 | ) 204 | study.optimize(objective, n_trials=n_trials) 205 | 206 | 207 | @app.command() 208 | def score( 209 | solver_path: str, 210 | book_path: str, 211 | base_dir: str = "training", 212 | solutions_path: str = "./solutions.db", 213 | ): 214 | """Scores the training generations using the given solver.""" 215 | gens = TrainingGen.load_all(base_dir) 216 | for gen in gens: 217 | logger.info("Getting games for: {}", gen.gen_n) 218 | games = gen.get_games(base_dir) # type: ignore 219 | if not games: 220 | continue 221 | if gen.solver_score is not None: 222 | logger.info(f"Gen already has score: {gen.solver_score}") 223 | continue 224 | score = games.score_policies(solver_path, book_path, solutions_path) # type: ignore 225 | gen.solver_score = score 226 | gen.save_metadata(base_dir) 227 | logger.info("Gen {} has score: {}", gen.gen_n, score) 228 | 229 | 230 | if __name__ == "__main__": 231 | # Disable unnecessary pytorch warnings 232 | warnings.filterwarnings("ignore", ".*does not have many workers.*") 233 | 234 | app() 235 | -------------------------------------------------------------------------------- /src/c4a0/nn.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Tuple 2 | 3 | from loguru import logger 4 | import numpy as np 5 | from pydantic import BaseModel 6 | import torch 7 | import torch.nn as nn 8 | import torchmetrics 9 | import pytorch_lightning as pl 10 | from einops import rearrange 11 | 12 | from c4a0_rust import N_COLS, N_ROWS # type: ignore 13 | 14 | 15 | class ModelConfig(BaseModel): 16 | """Configuration for ConnectFourNet.""" 17 | 18 | n_residual_blocks: int 19 | """The number of residual blocks.""" 20 | 21 | conv_filter_size: int 22 | """The number of filters for the conv layers in the residual blocks.""" 23 | 24 | n_policy_layers: int 25 | """Number of fully connected layers in the policy head.""" 26 | 27 | n_value_layers: int 28 | """Number of fully connected layers in the value head.""" 29 | 30 | lr_schedule: Dict[int, float] 31 | """ 32 | Learning rate schedule. The first item in the tuple indicates which gen_n the learning rate 33 | begins to be effective for. The second item in the tuple is the learning rate. 34 | """ 35 | 36 | l2_reg: float 37 | """L2 weight decay regularization for the optimizer""" 38 | 39 | 40 | class ConnectFourNet(pl.LightningModule): 41 | """ 42 | A CNN that takes in as input connect four positions and outputs a policy (in logprob space) 43 | and two Q Values. The policy is a (log) probability distribution over moves. q_penalty 44 | represents the predicted lucrativeness of the position between [-1, +1] where -1 is a 45 | definitively losing position and +1 is a definitively winning position where there is a 46 | penalty applied based on the number of plys (to encourage faster wins). q_no_penalty is the 47 | lucrativeness without the ply penalty. 48 | 49 | The outputs of this network are used to guide MCTS. 50 | The outputs of MCTS are used to train the next network. 
51 | 52 | The network consists of a sequence of ResidualBlocks (CNN + CNN + BatchNormalization + Relu) 53 | followed by separate fully connected policy and value heads. 54 | """ 55 | 56 | EPS = 1e-8 # Epsilon small constant to avoid log(0) 57 | 58 | def __init__(self, config: ModelConfig): 59 | super().__init__() 60 | self.lr_schedule = config.lr_schedule 61 | self.l2_reg = config.l2_reg 62 | 63 | self.conv = nn.Sequential( 64 | nn.Conv2d(2, config.conv_filter_size, kernel_size=3, padding=1), 65 | *[ 66 | ResidualBlock(config.conv_filter_size) 67 | for i in range(config.n_residual_blocks) 68 | ], 69 | ) 70 | 71 | fc_size = self._calculate_conv_output_size() 72 | 73 | # Policy head 74 | self.fc_policy = nn.Sequential( 75 | *[ 76 | nn.Sequential( 77 | nn.Linear(fc_size, fc_size), 78 | nn.BatchNorm1d(fc_size), 79 | nn.ReLU(), 80 | ) 81 | for _ in range(config.n_policy_layers - 1) 82 | ], 83 | nn.Linear(fc_size, N_COLS), 84 | nn.LogSoftmax(dim=1), 85 | ) 86 | 87 | # Q Value head, one output dim for q_penalty and another for q_no_penalty 88 | self.fc_value = nn.Sequential( 89 | *[ 90 | nn.Sequential( 91 | nn.Linear(fc_size, fc_size), 92 | nn.BatchNorm1d(fc_size), 93 | nn.ReLU(), 94 | ) 95 | for _ in range(config.n_value_layers - 1) 96 | ], 97 | nn.Linear(fc_size, 2), 98 | nn.Tanh(), 99 | ) 100 | 101 | # Metrics 102 | self.policy_kl_div = torchmetrics.KLDivergence(log_prob=True) 103 | self.q_penalty_mse = torchmetrics.MeanSquaredError() 104 | self.q_no_penalty_mse = torchmetrics.MeanSquaredError() 105 | 106 | self.save_hyperparameters(config.model_dump()) 107 | 108 | def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 109 | x = self.conv(x) 110 | x = rearrange(x, "b c h w -> b (c h w)") 111 | policy_logprobs = self.fc_policy(x) 112 | q_values = self.fc_value(x) # b 2 113 | q_penalty, q_no_penalty = q_values.split(1, dim=1) # both: b 1 114 | q_penalty = q_penalty.squeeze(1) # b 115 | q_no_penalty = q_no_penalty.squeeze(1) # b 116 | return policy_logprobs, q_penalty, q_no_penalty 117 | 118 | def forward_numpy(self, x: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 119 | """ 120 | Forward pass with input/output as numpy. Model is run in inference mode. Used for self play. 
121 | """ 122 | self.eval() 123 | pos = torch.from_numpy(x).to(self.device) 124 | with torch.no_grad(): 125 | policy, q_penalty, q_no_penalty = self.forward(pos) 126 | policy = policy.cpu().numpy() 127 | q_penalty = q_penalty.cpu().numpy() 128 | q_no_penalty = q_no_penalty.cpu().numpy() 129 | return policy, q_penalty, q_no_penalty 130 | 131 | def _calculate_conv_output_size(self): 132 | """Helper function to calculate the output size of the convolutional block.""" 133 | # Apply the convolutional layers to a dummy input 134 | dummy_input = torch.zeros(1, 2, N_ROWS, N_COLS) 135 | with torch.no_grad(): 136 | dummy_output = self.conv(dummy_input) 137 | return int(torch.numel(dummy_output)) 138 | 139 | def configure_optimizers(self): 140 | gen_n: int = self.trainer.gen_n # type: ignore 141 | assert gen_n is not None, "please pass gen_n to trainer" 142 | schedule = sorted(list(self.lr_schedule.items())) 143 | _, lr = schedule.pop(0) 144 | for gen_threshold, gen_rate in schedule: 145 | if gen_n < gen_threshold: 146 | break 147 | lr = gen_rate 148 | 149 | logger.info("using lr {} for gen_n {}", lr, gen_n) 150 | optimizer = torch.optim.Adam(self.parameters(), lr=lr, weight_decay=self.l2_reg) 151 | return optimizer 152 | 153 | def training_step(self, batch, batch_idx): 154 | return self.step(batch, log_prefix="train") 155 | 156 | def validation_step(self, batch, batch_idx): 157 | return self.step(batch, log_prefix="val") 158 | 159 | def step(self, batch, log_prefix): 160 | # Forward pass 161 | pos, policy_target, q_penalty_target, q_no_penalty_target = batch 162 | policy_logprob, q_penalty_pred, q_no_penalty_pred = self.forward(pos) 163 | policy_logprob_targets = torch.log(policy_target + self.EPS) 164 | 165 | # Losses 166 | policy_loss = self.policy_kl_div(policy_logprob_targets, policy_logprob) 167 | q_penalty_loss = self.q_penalty_mse(q_penalty_pred, q_penalty_target) 168 | q_no_penalty_loss = self.q_no_penalty_mse( 169 | q_no_penalty_pred, q_no_penalty_target 170 | ) 171 | loss = policy_loss + q_penalty_loss + q_no_penalty_loss 172 | 173 | self.log(f"{log_prefix}_loss", loss, prog_bar=True) 174 | self.log(f"{log_prefix}_policy_kl_div", policy_loss) 175 | self.log(f"{log_prefix}_value_mse", q_penalty_loss) 176 | return loss 177 | 178 | 179 | class ResidualBlock(pl.LightningModule): 180 | def __init__(self, n_channels: int, kernel_size=3, padding=1) -> None: 181 | super().__init__() 182 | self.block = nn.Sequential( 183 | nn.Conv2d(n_channels, n_channels, kernel_size=kernel_size, padding=padding), 184 | nn.Conv2d(n_channels, n_channels, kernel_size=kernel_size, padding=padding), 185 | nn.BatchNorm2d(n_channels), 186 | nn.ReLU(), 187 | ) 188 | 189 | def forward(self, x): 190 | return x + self.block(x) 191 | -------------------------------------------------------------------------------- /src/c4a0/sweep.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from typing import List 3 | 4 | from loguru import logger 5 | import optuna 6 | import pytorch_lightning as pl 7 | from pytorch_lightning.callbacks import EarlyStopping 8 | 9 | from c4a0.training import SampleDataModule, TrainingGen 10 | from c4a0.nn import ConnectFourNet, ModelConfig 11 | from c4a0_rust import Sample # type: ignore 12 | 13 | 14 | def load_samples(base_dir: str, n_gens: int = 5) -> List[Sample]: 15 | gens = TrainingGen.load_all(base_dir)[:n_gens] 16 | game_results = [gen.get_games(base_dir) for gen in gens] 17 | samples = [ 18 | sample 19 | for game_result in game_results 20 
| if game_result 21 | for results in game_result.results 22 | for sample in results.samples 23 | ] 24 | return samples 25 | 26 | 27 | def objective(trial: optuna.Trial, samples: List[Sample]): 28 | model_config = ModelConfig( 29 | n_residual_blocks=trial.suggest_int("n_residual_blocks", 0, 1), 30 | conv_filter_size=trial.suggest_int("conv_filter_size", 16, 64), 31 | n_policy_layers=trial.suggest_int("n_policy_layers", 0, 4), 32 | n_value_layers=trial.suggest_int("n_value_layers", 0, 2), 33 | lr_schedule={0: trial.suggest_loguniform("learning_rate", 1e-4, 1e-2)}, 34 | l2_reg=trial.suggest_loguniform("l2_reg", 1e-5, 1e-3), 35 | ) 36 | model = ConnectFourNet(model_config) 37 | 38 | batch_size = trial.suggest_categorical("batch_size", [256, 512, 1024]) 39 | 40 | split_idx = int(0.8 * len(samples)) 41 | train, test = samples[:split_idx], samples[split_idx:] 42 | data_module = SampleDataModule(train, test, batch_size) 43 | 44 | trainer = pl.Trainer( 45 | max_epochs=30, 46 | accelerator="auto", 47 | devices="auto", 48 | callbacks=[ 49 | EarlyStopping(monitor="val_loss", patience=4, mode="min"), 50 | ], 51 | enable_progress_bar=False, # Disable progress bar for cleaner logs 52 | ) 53 | 54 | trainer.fit(model, data_module) 55 | 56 | val_loss = trainer.callback_metrics["val_loss"].item() 57 | trial.report(val_loss, step=trainer.current_epoch) 58 | return val_loss 59 | 60 | 61 | def perform_hparam_sweep(base_dir: str, study_name: str = "sweep_hparam"): 62 | samples = load_samples(base_dir) 63 | 64 | storage_name = f"sqlite:///{study_name}.db" 65 | study = optuna.create_study( 66 | study_name=study_name, 67 | storage=storage_name, 68 | load_if_exists=True, 69 | direction="minimize", 70 | pruner=optuna.pruners.MedianPruner(), 71 | ) 72 | 73 | study.optimize( 74 | functools.partial(objective, samples=samples), n_trials=100, catch=(Exception,) 75 | ) 76 | 77 | logger.info("Best trial:") 78 | trial = study.best_trial 79 | logger.info(f" Value: {trial.value}") 80 | logger.info(" Params: ") 81 | for key, value in trial.params.items(): 82 | logger.info(f" {key}: {value}") 83 | 84 | logger.info("") 85 | logger.info("Study statistics:") 86 | logger.info(f" Finished trials: {len(study.trials)}") 87 | logger.info( 88 | f" Pruned trials: {len(study.get_trials(states=[optuna.trial.TrialState.PRUNED]))}" 89 | ) 90 | logger.info( 91 | f" Completed trials: {len(study.get_trials(states=[optuna.trial.TrialState.COMPLETE]))}" 92 | ) 93 | -------------------------------------------------------------------------------- /src/c4a0/tournament.py: -------------------------------------------------------------------------------- 1 | """ 2 | Round-robin tournament to determine which model is the best. 
3 | """ 4 | 5 | import abc 6 | from collections import defaultdict 7 | from dataclasses import dataclass, field 8 | from datetime import datetime 9 | import itertools 10 | from typing import Callable, Dict, List, NewType, Optional, Tuple 11 | import numpy as np 12 | 13 | from loguru import logger 14 | from tabulate import tabulate 15 | import torch 16 | 17 | from c4a0.nn import ConnectFourNet 18 | import c4a0_rust # type: ignore 19 | from c4a0_rust import N_COLS # type: ignore 20 | 21 | PlayerName = NewType("PlayerName", str) 22 | 23 | ModelID = NewType("ModelID", int) 24 | 25 | 26 | class Player(abc.ABC): 27 | name: PlayerName 28 | model_id: ModelID 29 | 30 | def __init__(self, name: str, model_id: ModelID): 31 | self.name = PlayerName(name) 32 | self.model_id = model_id 33 | 34 | def forward_numpy(self, x: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 35 | raise NotImplementedError 36 | 37 | 38 | class ModelPlayer(Player): 39 | """Player whose policy and value are determined by a ConnectFourNet.""" 40 | 41 | model: ConnectFourNet 42 | device: torch.device 43 | 44 | def __init__(self, model_id: ModelID, model: ConnectFourNet, device: torch.device): 45 | super().__init__(f"gen{model_id}", model_id) 46 | self.model_id = model_id 47 | self.model = model 48 | self.model.to(device) 49 | 50 | def forward_numpy(self, x: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 51 | return self.model.forward_numpy(x) 52 | 53 | 54 | class RandomPlayer(Player): 55 | """Player that provides a random policy and value.""" 56 | 57 | def __init__(self, model_id: ModelID): 58 | super().__init__("random", model_id) 59 | 60 | def forward_numpy(self, x: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 61 | batch_size = x.shape[0] 62 | policy_logits = torch.rand(batch_size, N_COLS).numpy() 63 | q_value = (torch.rand(batch_size) * 2 - 1).numpy() # [-1, 1] 64 | return policy_logits, q_value, q_value 65 | 66 | 67 | class UniformPlayer(Player): 68 | """Player that provides a uniform policy and 0 value.""" 69 | 70 | def __init__(self, model_id: ModelID): 71 | super().__init__("uniform", model_id) 72 | 73 | def forward_numpy(self, x: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 74 | batch_size = x.shape[0] 75 | policy_logits = torch.ones(batch_size, N_COLS).numpy() 76 | q_value = torch.zeros(batch_size).numpy() 77 | return policy_logits, q_value, q_value 78 | 79 | 80 | @dataclass 81 | class TournamentResult: 82 | """Represents the results from a tournamnet.""" 83 | 84 | model_ids: List[ModelID] 85 | date: datetime = field(default_factory=datetime.now) 86 | games: Optional[c4a0_rust.PlayGamesResult] = None 87 | 88 | def get_scores(self) -> List[Tuple[ModelID, float]]: 89 | scores: Dict[ModelID, float] = defaultdict(lambda: 0.0) 90 | for result in self.games.results: # type: ignore 91 | player0_score = result.player0_score() 92 | scores[result.metadata.player0_id] += player0_score 93 | scores[result.metadata.player1_id] += 1 - player0_score 94 | 95 | ret = list(scores.items()) 96 | ret.sort(key=lambda x: x[1], reverse=True) 97 | return ret 98 | 99 | def scores_table(self, get_name: Callable[[int], str]) -> str: 100 | return tabulate( 101 | [(get_name(id), score) for id, score in self.get_scores()], 102 | headers=["Player", "Score"], 103 | tablefmt="github", 104 | ) 105 | 106 | def get_top_models(self) -> List[ModelID]: 107 | """Returns the top models from the tournament in descending order of performance.""" 108 | return [model_id for model_id, _ in self.get_scores()] 109 | 110 | 
111 | def play_tournament(
112 |     players: List[Player],
113 |     games_per_match: int,
114 |     batch_size: int,
115 |     mcts_iterations: int,
116 |     exploration_constant: float,
117 | ) -> TournamentResult:
118 |     """Plays a round-robin tournament, returning the total score of each player."""
119 |     assert games_per_match % 2 == 0, "games_per_match must be even"
120 | 
121 |     gen_id_to_player = {player.model_id: player for player in players}
122 |     tournament = TournamentResult(
123 |         model_ids=[player.model_id for player in players],
124 |     )
125 |     player_ids = [player.model_id for player in players]
126 |     pairings = list(itertools.permutations(player_ids, 2)) * int(games_per_match / 2)
127 |     reqs = [c4a0_rust.GameMetadata(id, p0, p1) for id, (p0, p1) in enumerate(pairings)]
128 | 
129 |     logger.info(f"Beginning tournament with {len(players)} players")
130 |     tournament.games = c4a0_rust.play_games(
131 |         reqs,
132 |         batch_size,
133 |         mcts_iterations,
134 |         exploration_constant,
135 |         lambda player_id, pos: gen_id_to_player[player_id].forward_numpy(pos),
136 |     )
137 |     logger.info(f"Finished tournament with {len(tournament.games.results)} games")  # type: ignore
138 | 
139 |     return tournament
140 | 
--------------------------------------------------------------------------------
/src/c4a0/training.py:
--------------------------------------------------------------------------------
1 | """
2 | Generation-based network training, alternating between self-play and training.
3 | """
4 | 
5 | import copy
6 | from datetime import datetime
7 | import os
8 | import pickle
9 | from typing import Dict, List, NewType, Optional, Tuple
10 | 
11 | from loguru import logger
12 | from pydantic import BaseModel
13 | import pytorch_lightning as pl
14 | from pytorch_lightning.callbacks import EarlyStopping
15 | import torch
16 | from torch.utils.data import DataLoader
17 | 
18 | from c4a0.nn import ConnectFourNet, ModelConfig
19 | from c4a0.utils import BestModelCheckpoint
20 | 
21 | import c4a0_rust  # type: ignore
22 | from c4a0_rust import PlayGamesResult, BUF_N_CHANNELS, N_COLS, N_ROWS, Sample  # type: ignore
23 | 
24 | 
25 | class TrainingGen(BaseModel):
26 |     """
27 |     Represents a single generation of training.
28 | """ 29 | 30 | created_at: datetime 31 | gen_n: int 32 | n_mcts_iterations: int 33 | c_exploration: float 34 | c_ply_penalty: float 35 | self_play_batch_size: int 36 | training_batch_size: int 37 | parent: Optional[datetime] = None 38 | val_loss: Optional[float] = None 39 | solver_score: Optional[float] = None 40 | 41 | @staticmethod 42 | def _gen_folder(created_at: datetime, base_dir: str) -> str: 43 | return os.path.join(base_dir, created_at.isoformat()) 44 | 45 | def gen_folder(self, base_dir: str) -> str: 46 | return TrainingGen._gen_folder(self.created_at, base_dir) 47 | 48 | def save_all( 49 | self, 50 | base_dir: str, 51 | games: Optional[PlayGamesResult], 52 | model: ConnectFourNet, 53 | ): 54 | gen_dir = self.gen_folder(base_dir) 55 | os.makedirs(gen_dir, exist_ok=True) 56 | 57 | metadata_path = os.path.join(gen_dir, "metadata.json") 58 | with open(metadata_path, "w") as f: 59 | f.write(self.model_dump_json(indent=2)) 60 | 61 | play_result_path = os.path.join(gen_dir, "games.pkl") 62 | with open(play_result_path, "wb") as f: 63 | pickle.dump(games, f) 64 | 65 | model_path = os.path.join(gen_dir, "model.pkl") 66 | with open(model_path, "wb") as f: 67 | pickle.dump(model, f) 68 | 69 | def save_metadata(self, base_dir: str): 70 | gen_dir = self.gen_folder(base_dir) 71 | os.makedirs(gen_dir, exist_ok=True) 72 | 73 | metadata_path = os.path.join(gen_dir, "metadata.json") 74 | with open(metadata_path, "w") as f: 75 | f.write(self.model_dump_json(indent=2)) 76 | 77 | @staticmethod 78 | def load(base_dir: str, created_at: datetime) -> "TrainingGen": 79 | gen_folder = TrainingGen._gen_folder(created_at, base_dir) 80 | with open(os.path.join(gen_folder, "metadata.json"), "r") as f: 81 | return TrainingGen.model_validate_json(f.read()) 82 | 83 | @staticmethod 84 | def load_all(base_dir: str) -> List["TrainingGen"]: 85 | timestamps = sorted( 86 | [ 87 | datetime.fromisoformat(f) 88 | for f in os.listdir(base_dir) 89 | if os.path.isdir(os.path.join(base_dir, f)) 90 | ], 91 | reverse=True, 92 | ) 93 | return [TrainingGen.load(base_dir, t) for t in timestamps] 94 | 95 | @staticmethod 96 | def load_latest(base_dir: str) -> "TrainingGen": 97 | timestamps = sorted( 98 | [ 99 | datetime.fromisoformat(f) 100 | for f in os.listdir(base_dir) 101 | if os.path.isdir(os.path.join(base_dir, f)) 102 | ], 103 | reverse=True, 104 | ) 105 | if not timestamps: 106 | raise FileNotFoundError("No existing generations") 107 | return TrainingGen.load(base_dir, timestamps[0]) 108 | 109 | @staticmethod 110 | def load_latest_with_default( 111 | base_dir: str, 112 | n_mcts_iterations: int, 113 | c_exploration: float, 114 | c_ply_penalty: float, 115 | self_play_batch_size: int, 116 | training_batch_size: int, 117 | model_config: ModelConfig, 118 | ): 119 | try: 120 | return TrainingGen.load_latest(base_dir) 121 | except FileNotFoundError: 122 | logger.info("No existing generations found, initializing root") 123 | gen = TrainingGen( 124 | created_at=datetime.now(), 125 | gen_n=0, 126 | n_mcts_iterations=n_mcts_iterations, 127 | c_exploration=c_exploration, 128 | c_ply_penalty=c_ply_penalty, 129 | self_play_batch_size=self_play_batch_size, 130 | training_batch_size=training_batch_size, 131 | ) 132 | model = ConnectFourNet(model_config) 133 | gen.save_all(base_dir, None, model) 134 | return gen 135 | 136 | def get_games(self, base_dir: str) -> Optional[PlayGamesResult]: 137 | gen_folder = self.gen_folder(base_dir) 138 | with open(os.path.join(gen_folder, "games.pkl"), "rb") as f: 139 | return pickle.load(f) 140 | 141 | 
def get_model(self, base_dir: str) -> ConnectFourNet: 142 | """Gets the model for this generation.""" 143 | gen_folder = self.gen_folder(base_dir) 144 | with open(os.path.join(gen_folder, "model.pkl"), "rb") as f: 145 | model = pickle.load(f) 146 | return model 147 | 148 | 149 | class SolverConfig(BaseModel): 150 | solver_path: str 151 | book_path: str 152 | solutions_path: str 153 | 154 | 155 | def train_single_gen( 156 | base_dir: str, 157 | device: torch.device, 158 | parent: TrainingGen, 159 | n_self_play_games: int, 160 | n_mcts_iterations: int, 161 | c_exploration: float, 162 | c_ply_penalty: float, 163 | self_play_batch_size: int, 164 | training_batch_size: int, 165 | solver_config: Optional[SolverConfig] = None, 166 | ) -> TrainingGen: 167 | """ 168 | Trains a new generation from the given parent. 169 | First generate games using c4a0_rust.play_games. 170 | Then train a new model based on the parent model using the generated samples. 171 | Finally, save the resulting games and model in the training directory. 172 | """ 173 | gen_n = parent.gen_n + 1 174 | logger.info(f"Beginning new generation {gen_n} from {parent.gen_n}") 175 | 176 | # TODO: log experiment metadata in MLFlow 177 | 178 | # Self play 179 | model = parent.get_model(base_dir) 180 | model.to(device) 181 | reqs = [c4a0_rust.GameMetadata(id, 0, 0) for id in range(n_self_play_games)] # type: ignore 182 | games = c4a0_rust.play_games( # type: ignore 183 | reqs, 184 | self_play_batch_size, 185 | n_mcts_iterations, 186 | c_exploration, 187 | c_ply_penalty, 188 | lambda player_id, pos: model.forward_numpy(pos), # type: ignore 189 | ) 190 | 191 | # Optionally judge generated policies against solver 192 | if solver_config is not None: 193 | logger.info("Scoring policies against solver") 194 | solver_score = games.score_policies( 195 | solver_config.solver_path, 196 | solver_config.book_path, 197 | solver_config.solutions_path, 198 | ) 199 | logger.info("Solver score: {}", solver_score) 200 | else: 201 | logger.info("Skipping scoring against solver") 202 | solver_score = None 203 | 204 | # Training 205 | logger.info("Beginning training") 206 | model = copy.deepcopy(model) 207 | train, test = games.split_train_test(0.8, 1337) # type: ignore 208 | data_module = SampleDataModule(train, test, training_batch_size) 209 | best_model_cb = BestModelCheckpoint(monitor="val_loss", mode="min") 210 | trainer = pl.Trainer( 211 | max_epochs=100, 212 | accelerator="auto", 213 | devices="auto", 214 | callbacks=[ 215 | best_model_cb, 216 | EarlyStopping(monitor="val_loss", patience=10, mode="min"), 217 | ], 218 | ) 219 | trainer.gen_n = gen_n # type: ignore 220 | model.train() # Switch batch normalization to train mode for training bn params 221 | trainer.fit(model, data_module) 222 | logger.info("Finished training") 223 | 224 | gen = TrainingGen( 225 | created_at=datetime.now(), 226 | gen_n=parent.gen_n + 1, 227 | n_mcts_iterations=n_mcts_iterations, 228 | c_exploration=c_exploration, 229 | c_ply_penalty=c_ply_penalty, 230 | self_play_batch_size=self_play_batch_size, 231 | training_batch_size=training_batch_size, 232 | parent=parent.created_at, 233 | val_loss=trainer.callback_metrics["val_loss"].item(), 234 | solver_score=solver_score, 235 | ) 236 | gen.save_all(base_dir, games, best_model_cb.get_best_model()) 237 | return gen 238 | 239 | 240 | def training_loop( 241 | base_dir: str, 242 | device: torch.device, 243 | n_self_play_games: int, 244 | n_mcts_iterations: int, 245 | c_exploration: float, 246 | c_ply_penalty: float, 247 | 
self_play_batch_size: int, 248 | training_batch_size: int, 249 | model_config: ModelConfig, 250 | max_gens: Optional[int] = None, 251 | solver_config: Optional[SolverConfig] = None, 252 | ) -> TrainingGen: 253 | """Main training loop. Sequentially trains generation after generation.""" 254 | logger.info("Beginning training loop") 255 | logger.info("device: {}", device) 256 | logger.info("n_self_play_games: {}", n_self_play_games) 257 | logger.info("n_mcts_iterations: {}", n_mcts_iterations) 258 | logger.info("c_exploration: {}", c_exploration) 259 | logger.info("c_ply_penalty: {}", c_ply_penalty) 260 | logger.info("self_play_batch_size: {}", self_play_batch_size) 261 | logger.info("training_batch_size: {}", training_batch_size) 262 | logger.info("model_config: \n{}", model_config.model_dump_json(indent=2)) 263 | logger.info("max_gens: {}", max_gens) 264 | logger.info( 265 | "solver_config: {}", solver_config and solver_config.model_dump_json(indent=2) 266 | ) 267 | 268 | gen = TrainingGen.load_latest_with_default( 269 | base_dir=base_dir, 270 | n_mcts_iterations=n_mcts_iterations, 271 | c_exploration=c_exploration, 272 | c_ply_penalty=c_ply_penalty, 273 | self_play_batch_size=self_play_batch_size, 274 | training_batch_size=training_batch_size, 275 | model_config=model_config, 276 | ) 277 | 278 | while True: 279 | gen = train_single_gen( 280 | base_dir=base_dir, 281 | device=device, 282 | parent=gen, 283 | n_self_play_games=n_self_play_games, 284 | n_mcts_iterations=n_mcts_iterations, 285 | c_exploration=c_exploration, 286 | c_ply_penalty=c_ply_penalty, 287 | self_play_batch_size=self_play_batch_size, 288 | training_batch_size=training_batch_size, 289 | solver_config=solver_config, 290 | ) 291 | if max_gens is not None and gen.gen_n >= max_gens: 292 | return gen 293 | 294 | 295 | SampleTensor = NewType( 296 | "SampleTensor", 297 | Tuple[ 298 | torch.Tensor, # Pos 299 | torch.Tensor, # Policy 300 | torch.Tensor, # Q Value with penalty 301 | torch.Tensor, # Q Value without penalty 302 | ], 303 | ) 304 | 305 | 306 | class SampleDataModule(pl.LightningDataModule): 307 | def __init__( 308 | self, 309 | training_data: List[Sample], 310 | validation_data: List[Sample], 311 | batch_size: int, 312 | ): 313 | super().__init__() 314 | self.batch_size = batch_size 315 | training_data += [s.flip_h() for s in training_data] 316 | validation_data += [s.flip_h() for s in validation_data] 317 | self.training_data = [self.sample_to_tensor(s) for s in training_data] 318 | self.validation_data = [self.sample_to_tensor(s) for s in validation_data] 319 | 320 | @staticmethod 321 | def sample_to_tensor(sample: Sample) -> "SampleTensor": 322 | pos, policy, q_penalty, q_no_penalty = sample.to_numpy() 323 | pos_t = torch.from_numpy(pos) 324 | policy_t = torch.from_numpy(policy) 325 | q_penalty_t = torch.from_numpy(q_penalty) 326 | q_no_penalty_t = torch.from_numpy(q_no_penalty) 327 | assert pos_t.shape == (BUF_N_CHANNELS, N_ROWS, N_COLS) 328 | assert policy_t.shape == (N_COLS,) 329 | assert q_penalty_t.shape == () 330 | assert q_no_penalty.shape == () 331 | return SampleTensor((pos_t, policy_t, q_penalty_t, q_no_penalty_t)) 332 | 333 | def train_dataloader(self): 334 | return DataLoader( 335 | self.training_data, # type: ignore 336 | batch_size=self.batch_size, 337 | shuffle=True, 338 | ) 339 | 340 | def val_dataloader(self): 341 | return DataLoader( 342 | self.validation_data, # type: ignore 343 | batch_size=self.batch_size, 344 | ) 345 | 346 | 347 | def parse_lr_schedule(floats: List[float]) -> Dict[int, float]: 
348 | """Parses an lr_schedule sequence like "0 2e-3 10 8e-4" into a dict of {0: 2e-3, 10: 8e-4}.""" 349 | assert len(floats) % 2 == 0, "lr_schedule must have an even number of elements" 350 | schedule = {} 351 | for i in range(0, len(floats), 2): 352 | threshold = int(floats[i]) 353 | assert ( 354 | threshold == floats[i] 355 | ), "lr_schedule must alternate between gen_id (int) and lr (float)" 356 | lr = floats[i + 1] 357 | schedule[threshold] = lr 358 | return schedule 359 | -------------------------------------------------------------------------------- /src/c4a0/utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from sys import platform 3 | from typing import Generic, Optional, TypeVar 4 | 5 | import torch 6 | import pytorch_lightning as pl 7 | from pytorch_lightning.callbacks import Callback 8 | from pytorch_lightning.trainer.trainer import Trainer 9 | 10 | 11 | def get_torch_device() -> torch.device: 12 | """Tries to use cuda or mps, if available, otherwise falls back to cpu.""" 13 | if torch.cuda.is_available(): 14 | return torch.device("cuda") 15 | 16 | if platform == "darwin": 17 | if torch.backends.mps.is_available(): 18 | return torch.device("mps") 19 | elif not torch.backends.mps.is_built(): 20 | raise RuntimeError( 21 | "MPS unavailable because the current torch install was not built with MPS enabled." 22 | ) 23 | else: 24 | raise RuntimeError( 25 | "MPS unavailable because the current MacOS version is not 12.3+ and/or you do not " 26 | "have an MPS-enabled device on this machine." 27 | ) 28 | 29 | return torch.device("cpu") 30 | 31 | 32 | M = TypeVar("M", bound=pl.LightningModule) 33 | 34 | 35 | class BestModelCheckpoint(Callback, Generic[M]): 36 | """ 37 | PyTorch Lightning callback that keeps track of the best model in memory during training. 38 | 39 | This callback monitors a specified metric and saves the model with the best 40 | score in memory. It can be used to retrieve the best model after training. 41 | """ 42 | 43 | def __init__(self, monitor: str = "val_loss", mode: str = "min") -> None: 44 | """ 45 | Initialize the BestModelCheckpoint callback. 46 | 47 | Args: 48 | monitor (str): Name of the metric to monitor. Defaults to 'val_loss'. 49 | mode (str): One of {'min', 'max'}. In 'min' mode, the lowest metric value is considered 50 | best, in 'max' mode the highest. Defaults to 'min'. 51 | """ 52 | super().__init__() 53 | self.monitor = monitor 54 | self.mode = mode 55 | self.best_model: Optional[M] = None 56 | self.best_score = float("inf") if mode == "min" else float("-inf") 57 | 58 | def on_validation_end(self, trainer: Trainer, pl_module: M) -> None: 59 | """ 60 | Check if the current model is the best so far. 61 | 62 | This method is called after each validation epoch. It compares the current 63 | monitored metric with the best one so far and updates the best model if necessary. 64 | 65 | Args: 66 | trainer (Trainer): The PyTorch Lightning trainer instance. 67 | pl_module (LightningModule): The current PyTorch Lightning module. 
68 | """ 69 | current_score = trainer.callback_metrics.get(self.monitor) 70 | if current_score is None: 71 | return 72 | 73 | if isinstance(current_score, torch.Tensor): 74 | current_score = current_score.item() 75 | 76 | if self.mode == "min": 77 | if current_score < self.best_score: 78 | self.best_score = current_score 79 | self.best_model = copy.deepcopy(pl_module) 80 | elif self.mode == "max": 81 | if current_score > self.best_score: 82 | self.best_score = current_score 83 | self.best_model = copy.deepcopy(pl_module) 84 | 85 | def get_best_model(self) -> M: 86 | """ 87 | Returns the best model found during training. 88 | """ 89 | assert self.best_model is not None, "no model checkpoint called" 90 | return self.best_model 91 | -------------------------------------------------------------------------------- /tests/c4a0_tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/advait/c4a0/49cb6584fd4cc31e68f056f210faff9bac8823cf/tests/c4a0_tests/__init__.py -------------------------------------------------------------------------------- /tests/c4a0_tests/conftest.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import torch 5 | 6 | 7 | # Set the random seed for deterministic behavior 8 | seed = 1337 9 | random.seed(seed) 10 | np.random.seed(seed) 11 | torch.manual_seed(seed) 12 | -------------------------------------------------------------------------------- /tests/c4a0_tests/nn_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from c4a0.c4 import N_COLS, N_ROWS, STARTING_POS 5 | from c4a0.nn import ConnectFourNet 6 | 7 | 8 | def test_random_nn_works(): 9 | model = ConnectFourNet() 10 | pos = torch.from_numpy(STARTING_POS).float().unsqueeze(0) 11 | model.eval() 12 | with torch.no_grad(): 13 | policy_logprobs, value = model(pos) 14 | 15 | policy = torch.exp(policy_logprobs).squeeze(0).numpy() 16 | value = value.squeeze(0).item() 17 | 18 | assert len(policy) == N_COLS 19 | assert policy.sum().item() == pytest.approx(1.0, abs=1e-5) 20 | assert -1.0 <= value <= 1.0 21 | 22 | 23 | @pytest.mark.filterwarnings("ignore:You are trying to `self.log()`*") 24 | def test_loss_of_zero(): 25 | """Using the model output as training labels should result in a loss of zero.""" 26 | model = ConnectFourNet() 27 | pos = torch.from_numpy(STARTING_POS).float().unsqueeze(0) 28 | model.eval() 29 | with torch.no_grad(): 30 | policy_logprobs, value = model(pos) 31 | policy = torch.exp(policy_logprobs) 32 | value = value.squeeze(0) 33 | 34 | assert pos.shape == (1, N_ROWS, N_COLS) 35 | assert policy.shape == (1, N_COLS) 36 | assert value.shape == (1,) 37 | 38 | training_batch = ( 39 | [0], 40 | pos, 41 | policy, 42 | value, 43 | ) 44 | loss = model.training_step(training_batch, 0) 45 | loss = loss.detach().item() 46 | assert loss == pytest.approx(0.0, abs=1e-6) 47 | 48 | 49 | @pytest.mark.filterwarnings("ignore:You are trying to `self.log()`*") 50 | def test_loss_of_nonzero(): 51 | """Using random label data should result in a > 0 loss.""" 52 | model = ConnectFourNet() 53 | pos = torch.from_numpy(STARTING_POS).float().unsqueeze(0) 54 | policy = torch.rand((1, N_COLS)) 55 | value = torch.rand((1,)) 56 | 57 | training_batch = ( 58 | [0], 59 | pos, 60 | policy, 61 | value, 62 | ) 63 | model.eval() 64 | loss = model.training_step(training_batch, 0) 65 | loss = loss.detach().item() 
66 | assert loss > 0.0 67 | -------------------------------------------------------------------------------- /tests/c4a0_tests/tournament_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from c4a0.nn import ConnectFourNet 5 | from c4a0.tournament import ( 6 | play_tournament, 7 | ModelPlayer, 8 | RandomPlayer, 9 | UniformPlayer, 10 | ModelID, 11 | ) 12 | 13 | 14 | @pytest.mark.asyncio 15 | async def test_tournament(): 16 | model = ConnectFourNet() 17 | model.eval() # Disable batch normalization 18 | model_player = ModelPlayer( 19 | model_id=ModelID(0), 20 | model=model, 21 | device=torch.device("cpu"), 22 | ) 23 | players = [model_player, RandomPlayer(), UniformPlayer()] 24 | 25 | tournament = await play_tournament( 26 | players=players, 27 | games_per_match=2, 28 | mcts_iterations=10, 29 | exploration_constant=1.4, 30 | ) 31 | 32 | assert len(tournament.games) == 2 * 2 * 3 33 | for game in tournament.games: 34 | assert game.p0 in players 35 | assert game.p1 in players 36 | assert 0 <= game.score <= 1 37 | --------------------------------------------------------------------------------
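Putting the pieces together: the boundary between the Python model and the Rust engine is the `forward_numpy` callback used in `nn.py`, `main.py`, and `training.py`. The sketch below is not part of the repo; it loads the latest generation and runs one batched forward pass by hand, assuming a `training/` directory produced by `main.py train`, that the pickled model can be loaded on the current machine, and that an all-zero buffer of shape `(BUF_N_CHANNELS, N_ROWS, N_COLS)` is the empty-board encoding expected by `ConnectFourNet`.

```python
# Minimal sketch of the Python <-> Rust NN boundary (assumptions noted above).
import numpy as np

from c4a0.training import TrainingGen
from c4a0_rust import BUF_N_CHANNELS, N_COLS, N_ROWS  # constants exported by the Rust crate

gen = TrainingGen.load_latest("training")
model = gen.get_model("training")

# One position buffer; the real engine fills these via Pos::write_numpy_buffer
# before invoking the Python callback during self play and the TUI.
batch = np.zeros((1, BUF_N_CHANNELS, N_ROWS, N_COLS), dtype=np.float32)
policy_logprobs, q_penalty, q_no_penalty = model.forward_numpy(batch)

print(np.exp(policy_logprobs[0]))                    # policy as probabilities over the columns
print(float(q_penalty[0]), float(q_no_penalty[0]))   # lucrativeness estimates in [-1, 1]
```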