├── .gitignore
├── .mise.toml
├── AlphaZeroGUI
├── CustomGUI.py
├── __init__.py
├── __main__.py
├── _gui.py
├── args
│   └── stratego.json
├── img
│   ├── pause.png
│   ├── play.png
│   ├── save.png
│   └── stop.png
└── main.py
├── LICENSE
├── README.md
├── alphazero
├── Arena.pyx
├── Coach.py
├── Evaluator.py
├── Game.py
├── GenericPlayers.py
├── MCTS.pyx
├── NNetArchitecture.py
├── NNetWrapper.py
├── SelfPlayAgent.pyx
├── __init__.py
├── cGame.pxd
├── cGame.pyx
├── envs
│   ├── __init__.py
│   ├── brandubh
│   │   ├── __init__.py
│   │   ├── brandubh.py
│   │   ├── fastafl.pyx
│   │   ├── gui.py
│   │   └── players.py
│   ├── chess
│   │   ├── __init__.py
│   │   ├── chess.py
│   │   └── train.py
│   ├── connect4
│   │   ├── Connect4Logic.pyx
│   │   ├── __init__.py
│   │   ├── connect4.pyx
│   │   ├── gui.py
│   │   ├── pit.py
│   │   ├── players.py
│   │   ├── test_connect4.py
│   │   └── train.py
│   ├── gobang
│   │   ├── GobangLogic.pxd
│   │   ├── GobangLogic.pyx
│   │   ├── GobangPlayers.py
│   │   ├── __init__.py
│   │   ├── gobang.pyx
│   │   ├── pit.py
│   │   └── train.py
│   ├── hnefatafl
│   │   ├── __init__.py
│   │   ├── brandubh.pyx
│   │   ├── fastafl.pyx
│   │   ├── gui.py
│   │   ├── hnefatafl.py
│   │   ├── pit.py
│   │   ├── players.py
│   │   ├── tafl_old.pyx
│   │   ├── train.py
│   │   ├── train_brandubh.py
│   │   ├── train_fastafl.py
│   │   └── train_test.py
│   ├── othello
│   │   ├── OthelloLogic.pyx
│   │   ├── OthelloPlayers.py
│   │   ├── __init__.py
│   │   ├── othello.pyx
│   │   └── train.py
│   ├── stratego
│   │   ├── __init__.py
│   │   ├── engine.pxd
│   │   ├── engine.pyx
│   │   ├── pit.py
│   │   ├── players.py
│   │   ├── stratego.pyx
│   │   └── train.py
│   └── tictactoe
│   │   ├── TicTacToeLogic.py
│   │   ├── TicTacToePlayers.py
│   │   ├── __init__.py
│   │   ├── tictactoe.py
│   │   └── train.py
├── pit-multi.py
├── pit.py
├── pytorch_classification
│   ├── __init__.py
│   └── utils
│   │   ├── __init__.py
│   │   ├── eval.py
│   │   ├── images
│   │   ├── cifar.png
│   │   └── imagenet.png
│   │   ├── logger.py
│   │   ├── misc.py
│   │   └── progress
│   │   ├── LICENSE
│   │   ├── MANIFEST.in
│   │   ├── README.rst
│   │   ├── demo.gif
│   │   ├── progress
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── bar.cpython-37.pyc
│   │   │   ├── bar.cpython-38.pyc
│   │   │   ├── helpers.cpython-37.pyc
│   │   │   └── helpers.cpython-38.pyc
│   │   ├── bar.py
│   │   ├── counter.py
│   │   ├── helpers.py
│   │   └── spinner.py
│   │   ├── setup.py
│   │   └── test_progress.py
├── requirements.txt
├── roundrobin.py
└── utils.py
├── boardgame
├── __init__.py
├── __init__.pyo
├── board.pxd
├── board.pyx
├── boardgame.pyo
├── errors.py
├── errors.pyo
└── net.pyo
├── fastafl
├── __init__.py
├── cengine.pxd
├── cengine.pyx
├── engine.py
└── variants.py
├── hnefatafl
├── __init__.pyo
├── _gui.pyo
├── engine
│   ├── __init__.pyo
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   ├── __init__.cpython-38.pyc
│   │   ├── board.cpython-37.pyc
│   │   ├── board.cpython-38.pyc
│   │   ├── game.cpython-37.pyc
│   │   ├── game.cpython-38.pyc
│   │   ├── variants.cpython-37.pyc
│   │   └── variants.cpython-38.pyc
│   ├── board.pyo
│   └── game.pyo
└── net
│   ├── __init__.pyo
│   ├── __pycache__
│   ├── __init__.cpython-37.pyc
│   ├── __init__.cpython-38.pyc
│   ├── client.cpython-37.pyc
│   ├── client.cpython-38.pyc
│   └── server.cpython-37.pyc
│   └── client.pyo
├── remove_train.py
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | alphazero/__pycache__
2 | alphazero/pytorch_classification/__pycache__
3 | alphazero/pytorch_classification/utils/__pycache__
4 | AlphaZeroGUI/__pycache__
5 | AlphaZeroGUI/*.ui
6 | game2dboard/__pycache__
7 | boardgame/__pycache__
8 | hnefatafl/__pycache__
9 | 
fastafl/__pycache__ 10 | .idea 11 | checkpoint 12 | data 13 | runs 14 | -------------------------------------------------------------------------------- /.mise.toml: -------------------------------------------------------------------------------- 1 | [tools] 2 | python = "3.8" 3 | ninja = "1.12" 4 | clang = "19.1" 5 | -------------------------------------------------------------------------------- /AlphaZeroGUI/__init__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | PACKAGE_ROOT = Path('AlphaZeroGUI') 4 | ARGS_DIR = PACKAGE_ROOT / 'args' 5 | IMAGE_DIR = PACKAGE_ROOT / 'img' 6 | 7 | ALPHAZERO_ROOT = Path('alphazero') 8 | ENVS_DIR = ALPHAZERO_ROOT / 'envs' 9 | GENERIC_PLAYERS_MODULE = 'GenericPlayers' 10 | PLAYERS_MODULE = 'players' 11 | -------------------------------------------------------------------------------- /AlphaZeroGUI/__main__.py: -------------------------------------------------------------------------------- 1 | from .main import run 2 | run() 3 | -------------------------------------------------------------------------------- /AlphaZeroGUI/args/stratego.json: -------------------------------------------------------------------------------- 1 | {"run_name": "stratego", "cuda": true, "workers": 6, "startIter": 8, "numIters": 1000, "process_batch_size": 64, "train_batch_size": 512, "arena_batch_size": 32, "train_steps_per_iteration": 64, "train_sample_ratio": 2, "averageTrainSteps": false, "autoTrainSteps": true, "train_on_past_data": false, "past_data_chunk_size": 25, "past_data_run_name": "boardgame", "gamesPerIteration": 384, "minTrainHistoryWindow": 4, "maxTrainHistoryWindow": 20, "trainHistoryIncrementIters": 2, "max_moves": 512, "num_players": 2, "min_discount": 1, "fpu_reduction": 0.4, "num_stacked_observations": 8, "numWarmupIters": 1, "skipSelfPlayIters": null, "selfPlayModelIter": 7, "symmetricSamples": true, "numMCTSSims": 100, "numFastSims": 15, "numWarmupSims": 5, "probFastSim": 0.75, "mctsResetThreshold": null, "startTemp": 1, "temp_scaling_fn": "__CALLABLE__default_temp_scaling", "root_policy_temp": 1.1, "root_noise_frac": 0.2, "add_root_noise": true, "add_root_temp": true, "compareWithBaseline": true, "baselineTester": "__CALLABLE__RawMCTSPlayer", "arenaCompareBaseline": 192, "arenaCompare": 192, "arenaTemp": 0.25, "arenaMCTS": true, "arenaBatched": true, "baselineCompareFreq": 5, "compareWithPast": true, "pastCompareFreq": 5, "model_gating": true, "max_gating_iters": null, "min_next_model_winrate": 0.52, "use_draws_for_winrate": true, "load_model": true, "cpuct": 1.25, "value_loss_weight": 1.5, "checkpoint": "checkpoint", "data": "data", "scheduler": "__CALLABLE__MultiStepLR", "scheduler_args": {"milestones": [75, 125], "gamma": 0.1}, "lr": 0.01, "optimizer": "__CALLABLE__SGD", "optimizer_args": {"momentum": 0.9, "weight_decay": 0.0001}, "num_channels": 64, "depth": 4, "value_head_channels": 16, "policy_head_channels": 16, "value_dense_layers": [1024, 128], "policy_dense_layers": [1024], "nnet_type": "resnet"} -------------------------------------------------------------------------------- /AlphaZeroGUI/img/pause.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevaday/alphazero-general/99f74f19766923e98523dfc310a1ffef7379ca83/AlphaZeroGUI/img/pause.png -------------------------------------------------------------------------------- /AlphaZeroGUI/img/play.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevaday/alphazero-general/99f74f19766923e98523dfc310a1ffef7379ca83/AlphaZeroGUI/img/play.png -------------------------------------------------------------------------------- /AlphaZeroGUI/img/save.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevaday/alphazero-general/99f74f19766923e98523dfc310a1ffef7379ca83/AlphaZeroGUI/img/save.png -------------------------------------------------------------------------------- /AlphaZeroGUI/img/stop.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevaday/alphazero-general/99f74f19766923e98523dfc310a1ffef7379ca83/AlphaZeroGUI/img/stop.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 kevaday 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /alphazero/Game.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import List, Tuple, Any, Optional 3 | 4 | import numpy as np 5 | 6 | 7 | class GameState(ABC): 8 | def __init__(self, board): 9 | self._board = board 10 | self._player = 0 11 | self._turns = 0 12 | self.last_action = None 13 | 14 | def __str__(self) -> str: 15 | return f'Player:\t{self._player}\n{self._board}\n' 16 | 17 | @abstractmethod 18 | def __eq__(self, other: 'GameState') -> bool: 19 | """Compare the current game state to another""" 20 | pass 21 | 22 | @abstractmethod 23 | def clone(self) -> 'GameState': 24 | """Return a new clone of the game state, independent of the current one.""" 25 | pass 26 | 27 | @staticmethod 28 | @abstractmethod 29 | def action_size() -> int: 30 | """The size of the action space for the game""" 31 | pass 32 | 33 | @staticmethod 34 | @abstractmethod 35 | def observation_size() -> Tuple[int, int, int]: 36 | """ 37 | Returns: 38 | observation_size: the shape of observations of the current state, 39 | must be in the form channels x width x height. 40 | If only one plane is needed for observation, use 1 for channels. 
41 | """ 42 | pass 43 | 44 | @abstractmethod 45 | def valid_moves(self) -> np.ndarray: 46 | """Returns a numpy binary array containing zeros for invalid moves and ones for valids.""" 47 | pass 48 | 49 | @staticmethod 50 | @abstractmethod 51 | def num_players() -> int: 52 | """Returns the number of total players participating in the game.""" 53 | pass 54 | 55 | @staticmethod 56 | def max_turns() -> Optional[int]: 57 | """The maximum number of turns the game can last before a draw is declared.""" 58 | return None 59 | 60 | @staticmethod 61 | def has_draw() -> bool: 62 | """Returns True if the game has a draw condition.""" 63 | return True 64 | 65 | @property 66 | def player(self) -> int: 67 | return self._player 68 | 69 | @property 70 | def turns(self) -> int: 71 | return self._turns 72 | 73 | def _next_player(self, player, turns=1) -> int: 74 | return (player + turns) % self.num_players() 75 | 76 | def _update_turn(self) -> None: 77 | """Should be called at the end of play_action""" 78 | self._player = self._next_player(self._player) 79 | self._turns += 1 80 | 81 | @abstractmethod 82 | def play_action(self, action: int) -> None: 83 | """Play the action in the current state given by argument action.""" 84 | self.last_action = action 85 | 86 | @abstractmethod 87 | def win_state(self) -> np.ndarray: 88 | """ 89 | Get the win state of the game, a numpy array of boolean values 90 | for each player indicating if they have won, plus one more 91 | boolean at the end to indicate a draw. 92 | """ 93 | pass 94 | 95 | @abstractmethod 96 | def observation(self) -> np.ndarray: 97 | """Get an observation from the game state in the form of a numpy array with the size of self.observation_size""" 98 | pass 99 | 100 | def symmetries(self, pi) -> List[Tuple['GameState', np.ndarray]]: 101 | """ 102 | Args: 103 | pi: the current policy for the given canonical state 104 | 105 | Returns: 106 | symmetries: list of state, pi pairs for symmetric samples of 107 | the given state and pi (ex: mirror, rotation). 108 | This is an optional method as symmetric samples 109 | can be disabled for training. 110 | """ 111 | raise NotImplementedError( 112 | 'Symmetries not implemented for this environment. Set symmetricSamples to False in args.' 
113 | ) 114 | -------------------------------------------------------------------------------- /alphazero/GenericPlayers.py: -------------------------------------------------------------------------------- 1 | from alphazero.MCTS import MCTS 2 | from alphazero.Game import GameState 3 | from alphazero.NNetWrapper import NNetWrapper 4 | from alphazero.utils import dotdict, plot_mcts_tree 5 | 6 | from abc import ABC, abstractmethod 7 | 8 | import numpy as np 9 | import torch 10 | 11 | 12 | class BasePlayer(ABC): 13 | def __init__(self, game_cls: GameState = None, args: dotdict = None, verbose: bool = False): 14 | self.game_cls = game_cls 15 | self.args = args 16 | self.verbose = verbose 17 | 18 | def __call__(self, *args, **kwargs): 19 | return self.play(*args, **kwargs) 20 | 21 | @staticmethod 22 | def supports_process() -> bool: 23 | return False 24 | 25 | @staticmethod 26 | def requires_model() -> bool: 27 | return False 28 | 29 | @staticmethod 30 | def is_human() -> bool: 31 | return False 32 | 33 | def update(self, state: GameState, action: int) -> None: 34 | pass 35 | 36 | def reset(self): 37 | pass 38 | 39 | @abstractmethod 40 | def play(self, state: GameState) -> int: 41 | pass 42 | 43 | def process(self, batch): 44 | raise NotImplementedError 45 | 46 | 47 | class RandomPlayer(BasePlayer): 48 | def play(self, state): 49 | valids = state.valid_moves() 50 | valids = valids / np.sum(valids) 51 | a = np.random.choice(state.action_size(), p=valids) 52 | return a 53 | 54 | 55 | class NNPlayer(BasePlayer): 56 | def __init__(self, nn: NNetWrapper, *args, **kwargs): 57 | super().__init__(*args, **kwargs) 58 | self.nn = nn 59 | self.temp = self.args.startTemp 60 | 61 | @staticmethod 62 | def supports_process() -> bool: 63 | return True 64 | 65 | @staticmethod 66 | def requires_model() -> bool: 67 | return True 68 | 69 | def play(self, state) -> int: 70 | policy, _ = self.nn.predict(state.observation()) 71 | valids = state.valid_moves() 72 | options = policy * valids 73 | self.temp = self.args.temp_scaling_fn(self.temp, state.turns, state.max_turns()) 74 | if self.temp == 0: 75 | bestA = np.argmax(options) 76 | probs = [0] * len(options) 77 | probs[bestA] = 1 78 | else: 79 | probs = [x ** (1. 
/ self.temp) for x in options] 80 | probs /= np.sum(probs) 81 | 82 | choice = np.random.choice( 83 | np.arange(state.action_size()), p=probs 84 | ) 85 | 86 | if valids[choice] == 0: 87 | print() 88 | print(self.temp) 89 | print(valids) 90 | print(policy) 91 | print(probs) 92 | assert valids[choice] > 0 93 | 94 | return choice 95 | 96 | def process(self, *args, **kwargs): 97 | return self.nn.process(*args, **kwargs) 98 | 99 | 100 | class MCTSPlayer(BasePlayer): 101 | def __init__(self, nn: NNetWrapper, *args, print_policy=False, 102 | average_value=False, draw_mcts=False, draw_depth=2, **kwargs): 103 | super().__init__(*args, **kwargs) 104 | self.nn = nn 105 | self.temp = self.args.startTemp 106 | self.print_policy = print_policy 107 | self.average_value = average_value 108 | self.draw_mcts = draw_mcts 109 | self.draw_depth = draw_depth 110 | self.reset() 111 | if self.verbose: 112 | self.mcts.search( 113 | self.game_cls(), self.nn, self.args.numMCTSSims, self.args.add_root_noise, self.args.add_root_temp 114 | ) 115 | value = self.mcts.value(self.average_value) 116 | self.__rel_val_split = value if value > 0.5 else 1 - value 117 | print('initial value:', self.__rel_val_split) 118 | 119 | @staticmethod 120 | def supports_process() -> bool: 121 | return True 122 | 123 | @staticmethod 124 | def requires_model() -> bool: 125 | return True 126 | 127 | def update(self, state: GameState, action: int) -> None: 128 | self.mcts.update_root(state, action) 129 | 130 | def reset(self): 131 | self.mcts = MCTS(self.args) 132 | 133 | def play(self, state) -> int: 134 | self.mcts.search(state, self.nn, self.args.numMCTSSims, self.args.add_root_noise, self.args.add_root_temp) 135 | self.temp = self.args.temp_scaling_fn(self.temp, state.turns, state.max_turns()) 136 | policy = self.mcts.probs(state, self.temp) 137 | 138 | if self.print_policy: 139 | print(f'policy: {policy}') 140 | 141 | if self.verbose: 142 | _, value = self.nn.predict(state.observation()) 143 | print('max tree depth:', self.mcts.max_depth) 144 | print(f'raw network value: {value}') 145 | 146 | value = self.mcts.value(self.average_value) 147 | rel_val = 0.5 * (value - self.__rel_val_split) / (1 - self.__rel_val_split) + 0.5 \ 148 | if value >= self.__rel_val_split else (value / self.__rel_val_split) * 0.5 149 | 150 | print(f'value for player {state.player}: {value}') 151 | print('relative value:', rel_val) 152 | 153 | if self.draw_mcts: 154 | plot_mcts_tree(self.mcts, max_depth=self.draw_depth) 155 | 156 | action = np.random.choice(len(policy), p=policy) 157 | if self.verbose: 158 | print('confidence of action:', policy[action]) 159 | 160 | return action 161 | 162 | def process(self, *args, **kwargs): 163 | return self.nn.process(*args, **kwargs) 164 | 165 | 166 | class RawMCTSPlayer(MCTSPlayer): 167 | def __init__(self, *args, **kwargs): 168 | super().__init__(None, *args, **kwargs) 169 | self._POLICY_SIZE = self.game_cls.action_size() 170 | self._POLICY_FILL_VALUE = 1 / self._POLICY_SIZE 171 | self._VALUE_SIZE = self.game_cls.num_players() + 1 172 | 173 | @staticmethod 174 | def supports_process() -> bool: 175 | return True 176 | 177 | @staticmethod 178 | def requires_model() -> bool: 179 | return False 180 | 181 | def play(self, state) -> int: 182 | self.mcts.raw_search(state, self.args.numMCTSSims, self.args.add_root_noise, self.args.add_root_temp) 183 | self.temp = self.args.temp_scaling_fn(self.temp, state.turns, state.max_turns()) 184 | policy = self.mcts.probs(state, self.temp) 185 | action = np.random.choice(len(policy), p=policy) 
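# probs() converts the root's visit counts into a temperature-scaled
# distribution, so the action above is sampled in proportion to how often
# MCTS visited each move rather than being chosen greedily.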
186 | 187 | if self.verbose: 188 | print('max tree depth:', self.mcts.max_depth) 189 | print(f'value for player {state.player}: {self.mcts.value(self.average_value)}') 190 | print(f'policy: {policy}') 191 | print('confidence of action:', policy[action]) 192 | 193 | if self.draw_mcts: 194 | plot_mcts_tree(self.mcts, max_depth=self.draw_depth) 195 | 196 | return action 197 | 198 | def process(self, batch: torch.Tensor): 199 | return torch.full((batch.shape[0], self._POLICY_SIZE), self._POLICY_FILL_VALUE).to(batch.device), \ 200 | torch.zeros(batch.shape[0], self._VALUE_SIZE).to(batch.device) 201 | -------------------------------------------------------------------------------- /alphazero/NNetArchitecture.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import torch.nn as nn 3 | import torch 4 | 5 | from alphazero.Game import GameState 6 | from alphazero.utils import dotdict 7 | 8 | 9 | # 1x1 convolution 10 | def conv1x1(in_channels, out_channels, stride=1): 11 | return nn.Conv2d(in_channels, out_channels, kernel_size=1, 12 | stride=stride, padding=0, bias=False) 13 | 14 | # 3*3 convolution 15 | def conv3x3(in_channels, out_channels, stride=1): 16 | return nn.Conv2d(in_channels, out_channels, kernel_size=3, 17 | stride=stride, padding=1, bias=False) 18 | 19 | # fully connected layers 20 | def mlp( 21 | input_size: int, 22 | layer_sizes: list, 23 | output_size: int, 24 | output_activation=nn.Identity, 25 | activation=nn.ELU, 26 | ): 27 | sizes = [input_size] + layer_sizes + [output_size] 28 | layers = [] 29 | for i in range(len(sizes) - 1): 30 | act = activation if i < len(sizes) - 2 else output_activation 31 | layers += [nn.Linear(sizes[i], sizes[i + 1]), act()] 32 | return nn.Sequential(*layers) 33 | 34 | 35 | # Residual block 36 | class ResidualBlock(nn.Module): 37 | def __init__(self, in_channels, out_channels, downsample=False): 38 | super(ResidualBlock, self).__init__() 39 | 40 | stride = 1 41 | if downsample: 42 | stride = 2 43 | self.conv_ds = conv1x1(in_channels, out_channels, stride) 44 | self.bn_ds = nn.BatchNorm2d(out_channels) 45 | 46 | self.downsample = downsample 47 | self.bn1 = nn.BatchNorm2d(in_channels) 48 | self.relu = nn.ReLU(inplace=True) 49 | self.conv1 = conv3x3(in_channels, out_channels, stride) 50 | self.bn2 = nn.BatchNorm2d(out_channels) 51 | self.conv2 = conv3x3(out_channels, out_channels) 52 | 53 | def forward(self, x): 54 | residual = x 55 | out = x 56 | out = self.bn1(out) 57 | out = self.relu(out) 58 | out = self.conv1(out) 59 | out = self.bn2(out) 60 | out = self.relu(out) 61 | out = self.conv2(out) 62 | if self.downsample: 63 | residual = self.conv_ds(x) 64 | residual = self.bn_ds(residual) 65 | out += residual 66 | return out 67 | 68 | 69 | class ResNet(nn.Module): 70 | def __init__(self, game_cls: GameState, args: dotdict): 71 | super(ResNet, self).__init__() 72 | # game params 73 | self.channels, self.board_x, self.board_y = game_cls.observation_size() 74 | self.action_size = game_cls.action_size() 75 | 76 | self.conv1 = conv3x3(self.channels, args.num_channels) 77 | self.bn1 = nn.BatchNorm2d(args.num_channels) 78 | 79 | self.res_layers = [] 80 | for _ in range(args.depth): 81 | self.res_layers.append( 82 | ResidualBlock(args.num_channels, args.num_channels) 83 | ) 84 | self.resnet = nn.Sequential(*self.res_layers) 85 | 86 | self.v_conv = conv1x1(args.num_channels, args.value_head_channels) 87 | self.v_bn = nn.BatchNorm2d(args.value_head_channels) 88 | self.v_fc = mlp( 89 | 
self.board_x*self.board_y*args.value_head_channels,
90 | args.value_dense_layers,
91 | game_cls.num_players() + game_cls.has_draw(),
92 | activation=nn.Identity
93 | )
94 | 
95 | self.pi_conv = conv1x1(args.num_channels, args.policy_head_channels)
96 | self.pi_bn = nn.BatchNorm2d(args.policy_head_channels)
97 | self.pi_fc = mlp(
98 | self.board_x*self.board_y*args.policy_head_channels,
99 | args.policy_dense_layers,
100 | self.action_size,
101 | activation=nn.Identity
102 | )
103 | 
104 | def forward(self, s):
105 | # s: batch_size x num_channels x board_x x board_y
106 | s = s.view(-1, self.channels, self.board_x, self.board_y)
107 | s = F.relu(self.bn1(self.conv1(s)))
108 | s = self.resnet(s)
109 | 
110 | v = self.v_conv(s)
111 | v = self.v_bn(v)
112 | v = torch.flatten(v, 1)
113 | v = self.v_fc(v)
114 | 
115 | pi = self.pi_conv(s)
116 | pi = self.pi_bn(pi)
117 | pi = torch.flatten(pi, 1)
118 | pi = self.pi_fc(pi)
119 | 
120 | return F.log_softmax(pi, dim=1), F.log_softmax(v, dim=1)
121 | 
122 | 
123 | class FullyConnected(nn.Module):
124 | """
125 | Fully connected network which operates in the same way as NNetArchitecture.
126 | The fully_connected function is used to create the network, as well as the
127 | policy and value heads. Forward method returns log_softmax of policy and value head.
128 | """
129 | def __init__(self, game_cls: GameState, args: dotdict):
130 | super(FullyConnected, self).__init__()
131 | # input size: the observation is flattened, so multiply its dimensions
132 | channels, board_x, board_y = game_cls.observation_size()
133 | self.input_size = channels * board_x * board_y
134 | self.input_fc = mlp(
135 | self.input_size,
136 | args.input_fc_layers,
137 | args.input_fc_layers[-1],
138 | activation=nn.ReLU
139 | )
140 | self.v_fc = mlp(
141 | args.input_fc_layers[-1],
142 | args.value_dense_layers,
143 | game_cls.num_players() + game_cls.has_draw(),
144 | activation=nn.Identity
145 | )
146 | self.pi_fc = mlp(
147 | args.input_fc_layers[-1],
148 | args.policy_dense_layers,
149 | game_cls.action_size(),
150 | activation=nn.Identity
151 | )
152 | 
153 | def forward(self, s):
154 | # s: batch_size x num_channels x board_x x board_y
155 | # reshape s for input_fc
156 | s = s.view(-1, self.input_size)
157 | 
158 | s = self.input_fc(s)
159 | v = self.v_fc(s)
160 | pi = self.pi_fc(s)
161 | 
162 | return F.log_softmax(pi, dim=1), F.log_softmax(v, dim=1)
--------------------------------------------------------------------------------
/alphazero/SelfPlayAgent.pyx:
--------------------------------------------------------------------------------
1 | # cython: language_level=3
2 | 
3 | import torch.multiprocessing as mp
4 | import numpy as np
5 | import torch
6 | import traceback
7 | import itertools
8 | import time
9 | 
10 | from alphazero.MCTS import MCTS
11 | 
12 | 
13 | class SelfPlayAgent(mp.Process):
14 | def __init__(self, id, game_cls, ready_queue, batch_ready, batch_tensor, policy_tensor,
15 | value_tensor, output_queue, result_queue, complete_count, games_played,
16 | stop_event: mp.Event, pause_event: mp.Event, args, _is_arena=False, _is_warmup=False):
17 | super().__init__()
18 | self.id = id
19 | self.game_cls = game_cls
20 | self.ready_queue = ready_queue
21 | self.batch_ready = batch_ready
22 | self.batch_tensor = batch_tensor
23 | if _is_arena:
24 | self.batch_size = policy_tensor.shape[0]
25 | else:
26 | self.batch_size = self.batch_tensor.shape[0]
27 | self.policy_tensor = policy_tensor
28 | self.value_tensor = value_tensor
29 | self.output_queue = output_queue
30 | self.result_queue = result_queue
31 | self.games = []
32 | self.histories = []
33 | self.temps = [] 
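# one entry per concurrent game in this worker's batch: next_reset tracks the
# turn at which each game's search tree is rebuilt (see mctsResetThreshold),
# and mcts holds one tree per game (or one tree per player in arena mode)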
34 | self.next_reset = [] 35 | self.mcts = [] 36 | self.games_played = games_played 37 | self.complete_count = complete_count 38 | self.stop_event = stop_event 39 | self.pause_event = pause_event 40 | self.args = args 41 | 42 | self._is_arena = _is_arena 43 | self._is_warmup = _is_warmup 44 | if _is_arena: 45 | self.player_to_index = list(range(game_cls.num_players())) 46 | np.random.shuffle(self.player_to_index) 47 | self.batch_indices = None 48 | if _is_warmup: 49 | action_size = game_cls.action_size() 50 | self._WARMUP_POLICY = torch.full((action_size,), 1 / action_size).to(policy_tensor.device) 51 | value_size = game_cls.num_players() + 1 52 | self._WARMUP_VALUE = torch.full((value_size,), 1 / value_size).to(policy_tensor.device) 53 | 54 | self.fast = False 55 | for _ in range(self.batch_size): 56 | self.games.append(self.game_cls()) 57 | self.histories.append([]) 58 | self.temps.append(self.args.startTemp) 59 | self.next_reset.append(0) 60 | self.mcts.append(self._get_mcts()) 61 | 62 | def _get_mcts(self): 63 | if self._is_arena: 64 | return tuple([MCTS(self.args) for _ in range(self.game_cls.num_players())]) 65 | else: 66 | return MCTS(self.args) 67 | 68 | def _mcts(self, index: int) -> MCTS: 69 | mcts = self.mcts[index] 70 | if self._is_arena: 71 | return mcts[self.games[index].player] 72 | else: 73 | return mcts 74 | 75 | def _check_pause(self): 76 | while self.pause_event.is_set(): 77 | time.sleep(.1) 78 | 79 | def run(self): 80 | try: 81 | np.random.seed() 82 | while not self.stop_event.is_set() and self.games_played.value < self.args.gamesPerIteration: 83 | self._check_pause() 84 | self.fast = np.random.random_sample() < self.args.probFastSim 85 | sims = self.args.numFastSims if self.fast else self.args.numMCTSSims \ 86 | if not self._is_warmup else self.args.numWarmupSims 87 | for _ in range(sims): 88 | if self.stop_event.is_set(): break 89 | self.generateBatch() 90 | if self.stop_event.is_set(): break 91 | self.processBatch() 92 | if self.stop_event.is_set(): break 93 | self.playMoves() 94 | 95 | with self.complete_count.get_lock(): 96 | self.complete_count.value += 1 97 | if not self._is_arena: 98 | self.output_queue.close() 99 | self.output_queue.join_thread() 100 | except Exception: 101 | print(traceback.format_exc()) 102 | 103 | def generateBatch(self): 104 | if self._is_arena: 105 | batch_tensor = [[] for _ in range(self.game_cls.num_players())] 106 | self.batch_indices = [[] for _ in range(self.game_cls.num_players())] 107 | 108 | for i in range(self.batch_size): 109 | self._check_pause() 110 | state = self._mcts(i).find_leaf(self.games[i]) 111 | if self._is_warmup: 112 | self.policy_tensor[i].copy_(self._WARMUP_POLICY) 113 | self.value_tensor[i].copy_(self._WARMUP_VALUE) 114 | continue 115 | 116 | data = torch.from_numpy(state.observation()) 117 | if self._is_arena: 118 | data = data.view(-1, *state.observation_size()) 119 | player = self.player_to_index[self.games[i].player] 120 | batch_tensor[player].append(data) 121 | self.batch_indices[player].append(i) 122 | else: 123 | self.batch_tensor[i].copy_(data) 124 | 125 | if self._is_arena: 126 | for player in range(self.game_cls.num_players()): 127 | player = self.player_to_index[player] 128 | data = batch_tensor[player] 129 | if data: 130 | batch_tensor[player] = torch.cat(data) 131 | self.output_queue.put(batch_tensor) 132 | self.batch_indices = list(itertools.chain.from_iterable(self.batch_indices)) 133 | 134 | if not self._is_warmup: 135 | self.ready_queue.put(self.id) 136 | 137 | def processBatch(self): 138 | if not 
self._is_warmup: 139 | self.batch_ready.wait() 140 | self.batch_ready.clear() 141 | 142 | for i in range(self.batch_size): 143 | self._check_pause() 144 | index = self.batch_indices[i] if self._is_arena else i 145 | self._mcts(i).process_results( 146 | self.games[i], 147 | self.value_tensor[index].data.numpy(), 148 | self.policy_tensor[index].data.numpy(), 149 | False if self._is_arena else self.args.add_root_noise, 150 | False if self._is_arena else self.args.add_root_temp 151 | ) 152 | 153 | def playMoves(self): 154 | for i in range(self.batch_size): 155 | self._check_pause() 156 | self.temps[i] = self.args.temp_scaling_fn( 157 | self.temps[i], self.games[i].turns, self.game_cls.max_turns() 158 | ) if not self._is_arena else self.args.arenaTemp 159 | policy = self._mcts(i).probs(self.games[i], self.temps[i]) 160 | action = np.random.choice(self.games[i].action_size(), p=policy) 161 | if not self.fast and not self._is_arena: 162 | self.histories[i].append(( 163 | self.games[i].clone(), 164 | self._mcts(i).probs(self.games[i]) 165 | )) 166 | 167 | if self._is_arena: 168 | [mcts.update_root(self.games[i], action) for mcts in self.mcts[i]] 169 | else: 170 | self._mcts(i).update_root(self.games[i], action) 171 | self.games[i].play_action(action) 172 | if self.args.mctsResetThreshold and self.games[i].turns >= self.next_reset[i]: 173 | self.mcts[i] = self._get_mcts() 174 | self.next_reset[i] = self.games[i].turns + self.args.mctsResetThreshold 175 | 176 | winstate = self.games[i].win_state() 177 | if winstate.any(): 178 | self.result_queue.put((self.games[i].clone(), winstate, self.id)) 179 | lock = self.games_played.get_lock() 180 | lock.acquire() 181 | if self.games_played.value < self.args.gamesPerIteration: 182 | self.games_played.value += 1 183 | lock.release() 184 | if not self._is_arena: 185 | for hist in self.histories[i]: 186 | self._check_pause() 187 | if self.args.symmetricSamples: 188 | data = hist[0].symmetries(hist[1]) 189 | else: 190 | data = ((hist[0], hist[1]),) 191 | 192 | for state, pi in data: 193 | self._check_pause() 194 | self.output_queue.put(( 195 | state.observation(), pi, np.array(winstate, dtype=np.float32) 196 | )) 197 | self.games[i] = self.game_cls() 198 | self.histories[i] = [] 199 | self.temps[i] = self.args.startTemp 200 | self.mcts[i] = self._get_mcts() 201 | else: 202 | lock.release() 203 | -------------------------------------------------------------------------------- /alphazero/__init__.py: -------------------------------------------------------------------------------- 1 | from pyximport import install as pyxinstall 2 | from numpy import get_include 3 | 4 | pyxinstall(setup_args={'include_dirs': get_include()}) 5 | 6 | from alphazero.Coach import DEFAULT_ARGS 7 | from alphazero.Game import GameState 8 | 9 | # Options for args eval 10 | from torch.optim import * 11 | from torch.optim.lr_scheduler import * 12 | from alphazero.GenericPlayers import * 13 | from alphazero.utils import default_temp_scaling, const_temp_scaling 14 | 15 | import json 16 | import os 17 | 18 | CALLABLE_PREFIX = '__CALLABLE__' 19 | 20 | 21 | def load_args_file(filepath: str) -> dotdict: 22 | new_args = dotdict() 23 | raw_args = json.load(open(filepath, 'r')) 24 | 25 | for k, v in raw_args.items(): 26 | if isinstance(v, str) and CALLABLE_PREFIX in v: 27 | try: 28 | v = eval(v.replace(CALLABLE_PREFIX, '')) 29 | except Exception as e: 30 | raise RuntimeError('Failed to parse argument file: ' + str(e)) 31 | 32 | elif isinstance(v, dict): 33 | v = dotdict(v) 34 | 35 | new_args.update({k: 
v})
36 | 
37 | return new_args
38 | 
39 | 
40 | def save_args_file(args: dotdict or dict, filepath, replace=True):
41 | if not replace and os.path.exists(filepath): return
42 | 
43 | save_args = dict()
44 | for k, v in args.items():
45 | if callable(v):
46 | v = CALLABLE_PREFIX + v.__name__
47 | save_args.update({k: v})
48 | 
49 | with open(filepath, 'w') as f:
50 | json.dump(save_args, f)
51 | 
52 | return save_args
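# A minimal usage sketch for the two helpers above (illustrative only; the
# file name is hypothetical). Callables are stored as strings tagged with
# CALLABLE_PREFIX and re-created with eval() when the file is loaded:
#
#   from alphazero import load_args_file, save_args_file
#   from alphazero.utils import default_temp_scaling, dotdict
#
#   args = dotdict({'lr': 0.01, 'temp_scaling_fn': default_temp_scaling})
#   save_args_file(args, 'example_args.json')   # stored as "__CALLABLE__default_temp_scaling"
#   loaded = load_args_file('example_args.json')
#   assert loaded['temp_scaling_fn'] is default_temp_scaling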
54 | """ 55 | pass 56 | 57 | @property 58 | def player(self) -> int: 59 | return self._player 60 | 61 | @property 62 | def turns(self): 63 | return self._turns 64 | 65 | cpdef int _next_player(self, int player, int turns=1): 66 | return (player + turns) % GameState.num_players() 67 | 68 | cpdef void _update_turn(self): 69 | """Should be called at the end of play_action""" 70 | self._player = self._next_player(self._player) 71 | self._turns += 1 72 | 73 | cpdef void play_action(self, int action): 74 | """Play the action in the current state given by argument action.""" 75 | pass 76 | 77 | cpdef np.ndarray win_state(self): 78 | """ 79 | Get the win state of the game, a tuple of boolean values 80 | for each player indicating if they have won, plus one more 81 | boolean at the end to indicate a draw. 82 | """ 83 | pass 84 | 85 | cpdef float[:, :, :] observation(self): 86 | """Get an observation from the game state in the form of a numpy array with the size of self.observation_size""" 87 | pass 88 | 89 | cpdef list symmetries(self, float[:] pi): 90 | """ 91 | Args: 92 | pi: the current policy for the given canonical state 93 | 94 | Returns: 95 | symmetries: list of state, pi pairs for symmetric samples of 96 | the given state and pi (ex: mirror, rotation). 97 | This is an optional method as symmetric samples 98 | can be disabled for training. 99 | """ 100 | pass 101 | -------------------------------------------------------------------------------- /alphazero/envs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevaday/alphazero-general/99f74f19766923e98523dfc310a1ffef7379ca83/alphazero/envs/__init__.py -------------------------------------------------------------------------------- /alphazero/envs/brandubh/__init__.py: -------------------------------------------------------------------------------- 1 | from .brandubh import * 2 | -------------------------------------------------------------------------------- /alphazero/envs/brandubh/brandubh.py: -------------------------------------------------------------------------------- 1 | import pyximport, numpy 2 | pyximport.install(setup_args={'include_dirs': numpy.get_include()}) 3 | from .fastafl import Game, display 4 | -------------------------------------------------------------------------------- /alphazero/envs/brandubh/players.py: -------------------------------------------------------------------------------- 1 | import pyximport, numpy as np 2 | pyximport.install(setup_args={'include_dirs': np.get_include()}) 3 | 4 | #from hnefatafl.engine import Move, BoardGameException 5 | #from alphazero.envs.brandubh.tafl import get_action 6 | from alphazero.envs.brandubh.fastafl import get_action as ft_get_action 7 | from alphazero.GenericPlayers import BasePlayer 8 | from alphazero.Game import GameState 9 | from alphazero.Evaluator import MCTSEvaluator 10 | 11 | from boardgame.board import Square 12 | from boardgame.errors import InvalidMoveError 13 | 14 | 15 | """ 16 | class HumanTaflPlayer(BasePlayer): 17 | def play(self, state: GameState): 18 | valid_moves = state.valid_moves() 19 | 20 | def string_to_action(player_inp: str) -> int: 21 | try: 22 | move_lst = [int(x) for x in player_inp.split()] 23 | move = Move(state._board, move_lst) 24 | return get_action(state._board, move) 25 | except (ValueError, AttributeError, BoardGameException): 26 | return -1 27 | 28 | action = string_to_action(input(f"Enter the move to play for the player {state.player}: ")) 29 | while action 
58 | 
59 | 
60 | 
61 | class GreedyTaflPlayer(BasePlayer):
62 | def play(self, state: GameState):
63 | valids = state.valid_moves()
64 | candidates = []
65 | 
66 | for a in range(state.action_size()):
67 | if not valids[a]: continue
68 | new_state = state.clone()
69 | new_state.play_action(a)
70 | candidates.append((-new_state.crude_value(), a))
71 | 
72 | candidates.sort()
73 | return candidates[0][1]
74 | 
75 | 
76 | class GreedyMCTSTaflPlayer(BasePlayer):
77 | def __init__(self, *args, **kwargs):
78 | super().__init__(*args, **kwargs)
79 | self.evaluator = MCTSEvaluator(
80 | args=self.args,
81 | model=self._crude_model,
82 | #num_sims=self.args.numMCTSSims
83 | max_search_time=20
84 | )
85 | 
86 | def _crude_model(self, state: GameState):
87 | value = state.crude_value()
88 | return (
89 | np.full(state.action_size(), 1, dtype=np.float32),
90 | np.array([value, 1 - value, 0], dtype=np.float32)
91 | )
92 | 
93 | def play(self, state: GameState):
94 | self.evaluator.run(state, block=True)
95 | print('[DEBUG] GreedyMCTS value:', self.evaluator.get_value())
96 | return self.evaluator.get_best_actions()[0]
97 | 
98 | def update(self, state: GameState, action: int) -> None:
99 | self.evaluator.update(state, action)
100 | 
--------------------------------------------------------------------------------
/alphazero/envs/chess/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kevaday/alphazero-general/99f74f19766923e98523dfc310a1ffef7379ca83/alphazero/envs/chess/__init__.py
--------------------------------------------------------------------------------
/alphazero/envs/chess/chess.py:
--------------------------------------------------------------------------------
1 | from alphazero.Game import GameState
2 | from typing import List, Tuple, Any
3 | 
4 | import chess
5 | import string
6 | import numpy as np
7 | DIGS = string.digits + string.ascii_letters
8 | 
9 | NUM_PLAYERS = 2
10 | BOARD_SIZE = 8
11 | ACTION_SIZE = BOARD_SIZE ** 4
12 | NUM_CHANNELS = 1 #placeholder
13 | OBSERVATION_SIZE = (NUM_CHANNELS, BOARD_SIZE, BOARD_SIZE)
14 | 
15 | 
16 | # TODO: Use https://github.com/Unimax/alpha-zero-general-chess-and-battlesnake/blob/master/chesspy/ChessGame.py for
17 | # implementation
18 | 
19 | 
20 | def _int2base(x, base, length):
21 | if x < 0:
22 | sign = -1
23 | elif x == 0:
24 | return [DIGS[0]]*length
25 | else:
26 | sign = 1
27 | 
28 | x *= sign
29 | digits = []
30 | 
31 | while x:
32 | digits.append(DIGS[int(x % base)])
33 | x //= base
34 | 
35 | if sign < 0:
36 | digits.append('-')
37 | 
38 | while len(digits) < length: digits.append('0')
39 | return list(map(lambda x: int(x, base), digits))
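# Illustrative example: _int2base(42, 10, 4) collects digits least-significant
# first ('2', then '4'), zero-pads to the requested length, and returns
# [2, 4, 0, 0].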
40 | 
41 | 
42 | class Game(GameState):
43 | def __init__(self):
44 | super().__init__(self._get_board())
45 | 
46 | @staticmethod
47 | def _get_board():
48 | return chess.Board()
49 | 
50 | def __eq__(self, other) -> bool:
51 | return (
52 | self._board == other._board
53 | and self._player == other._player
54 | and self.turns == other.turns
55 | )
56 | 
57 | def clone(self):
58 | g = Game()
59 | g._board = self._board.copy()
60 | g._player = self._player
61 | g._turns = self.turns
62 | return g
63 | 
64 | @staticmethod
65 | def action_size() -> int:
66 | return ACTION_SIZE
67 | 
68 | @staticmethod
69 | def observation_size() -> Tuple[int, int, int]:
70 | return OBSERVATION_SIZE
71 | 
72 | def valid_moves(self):
73 | valids = [0] * self.action_size()
74 | for move in self._board.legal_moves:
75 | valids[
76 | move.tile.x
77 | + move.tile.y * BOARD_SIZE
78 | + move.new_tile.x * BOARD_SIZE ** 2
79 | + move.new_tile.y * BOARD_SIZE ** 3
80 | ] = 1
81 | return np.asarray(valids, dtype=np.uint8)
82 | 
83 | def play_action(self, action: int) -> None:
84 | pass
85 | 
86 | def win_state(self) -> np.ndarray:
87 | pass
88 | 
89 | def observation(self):
90 | pass
91 | 
92 | def symmetries(self, pi) -> List[Tuple[Any, int]]:
93 | pass
94 | 
--------------------------------------------------------------------------------
/alphazero/envs/chess/train.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kevaday/alphazero-general/99f74f19766923e98523dfc310a1ffef7379ca83/alphazero/envs/chess/train.py
--------------------------------------------------------------------------------
/alphazero/envs/connect4/Connect4Logic.pyx:
--------------------------------------------------------------------------------
1 | # cython: language_level=3
2 | # cython: boundscheck=False
3 | # cython: wraparound=False
4 | # cython: nonecheck=False
5 | # cython: overflowcheck=False
6 | # cython: initializedcheck=False
7 | # cython: cdivision=True
8 | # cython: auto_pickle=True
9 | # cython: profile=True
10 | 
11 | import numpy as np
12 | 
13 | 
14 | cdef class Board():
15 | """
16 | Connect4 Board.
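Pieces are stored in a (height, width) array of ints: 1 for player 1's
stones, -1 for player 2's, and 0 for an empty cell.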
17 | """ 18 | 19 | cdef int height 20 | cdef int width 21 | cdef int length 22 | cdef int win_length 23 | cdef public int[:,:] pieces 24 | 25 | def __init__(self, int height, int width, int win_length): 26 | """Set up initial board configuration.""" 27 | self.height = height 28 | self.width = width 29 | self.win_length = win_length 30 | 31 | self.pieces = np.zeros((self.height, self.width), dtype=np.intc) 32 | 33 | def __getstate__(self): 34 | return self.height, self.width, self.win_length, np.asarray(self.pieces) 35 | 36 | def __setstate__(self, state): 37 | self.height, self.width, self.win_length, pieces = state 38 | self.pieces = np.asarray(pieces) 39 | 40 | def add_stone(self, int column, int player): 41 | """Create copy of board containing new stone.""" 42 | cdef Py_ssize_t r 43 | for r in range(self.height): 44 | if self.pieces[(self.height-1)-r,column] == 0: 45 | self.pieces[(self.height-1)-r,column] = player 46 | return 47 | 48 | raise ValueError("Can't play column %s on board %s" % (column, self)) 49 | 50 | def get_valid_moves(self): 51 | """Any zero value in top row is a valid move""" 52 | cdef Py_ssize_t c 53 | cdef int[:] valid = np.zeros((self.width), dtype=np.intc) 54 | for c in range(self.width): 55 | if self.pieces[0,c] == 0: 56 | valid[c] = 1 57 | 58 | return valid 59 | 60 | def get_win_state(self): 61 | cdef int player 62 | cdef int total 63 | cdef int good 64 | cdef Py_ssize_t r, c, x 65 | for player in [1, -1]: 66 | #check row wins 67 | for r in range(self.height): 68 | total = 0 69 | for c in range(self.width): 70 | if self.pieces[r,c] == player: 71 | total += 1 72 | else: 73 | total = 0 74 | if total == self.win_length: 75 | return (True, player) 76 | #check column wins 77 | for c in range(self.width): 78 | total = 0 79 | for r in range(self.height): 80 | if self.pieces[r,c] == player: 81 | total += 1 82 | else: 83 | total = 0 84 | if total == self.win_length: 85 | return (True, player) 86 | #check diagonal 87 | for r in range(self.height - self.win_length + 1): 88 | for c in range(self.width - self.win_length + 1): 89 | good = True 90 | for x in range(self.win_length): 91 | if self.pieces[r+x,c+x] != player: 92 | good = False 93 | break 94 | if good: 95 | return (True, player) 96 | for c in range(self.win_length - 1, self.width): 97 | good = True 98 | for x in range(self.win_length): 99 | if self.pieces[r+x,c-x] != player: 100 | good = False 101 | break 102 | if good: 103 | return (True, player) 104 | 105 | # draw has very little value. 106 | if sum(self.get_valid_moves()) == 0: 107 | return (True, 0) 108 | 109 | # Game is not ended yet. 
110 | return (False, 0) 111 | 112 | def __str__(self): 113 | return str(np.asarray(self.pieces)) 114 | -------------------------------------------------------------------------------- /alphazero/envs/connect4/__init__.py: -------------------------------------------------------------------------------- 1 | from .connect4 import * 2 | -------------------------------------------------------------------------------- /alphazero/envs/connect4/connect4.pyx: -------------------------------------------------------------------------------- 1 | # cython: language_level=3 2 | # cython: auto_pickle=True 3 | # cython: profile=True 4 | from typing import List, Tuple, Any 5 | 6 | from alphazero.Game import GameState 7 | from alphazero.envs.connect4.Connect4Logic import Board 8 | 9 | import numpy as np 10 | 11 | DEFAULT_HEIGHT = 6 12 | DEFAULT_WIDTH = 7 13 | DEFAULT_WIN_LENGTH = 4 14 | NUM_PLAYERS = 2 15 | MAX_TURNS = 42 16 | MULTI_PLANE_OBSERVATION = True 17 | NUM_CHANNELS = 4 if MULTI_PLANE_OBSERVATION else 1 18 | 19 | 20 | class Game(GameState): 21 | def __init__(self): 22 | super().__init__(self._get_board()) 23 | 24 | @staticmethod 25 | def _get_board(): 26 | return Board(DEFAULT_HEIGHT, DEFAULT_WIDTH, DEFAULT_WIN_LENGTH) 27 | 28 | def __hash__(self) -> int: 29 | return hash(self._board.pieces.tobytes() + bytes([self.turns]) + bytes([self._player])) 30 | 31 | def __eq__(self, other: 'Game') -> bool: 32 | return self._board.pieces == other._board.pieces and self._player == other._player and self.turns == other.turns 33 | 34 | def clone(self) -> 'Game': 35 | game = Game() 36 | game._board.pieces = np.copy(np.asarray(self._board.pieces)) 37 | game._player = self._player 38 | game._turns = self.turns 39 | game.last_action = self.last_action 40 | return game 41 | 42 | @staticmethod 43 | def max_turns() -> int: 44 | return MAX_TURNS 45 | 46 | @staticmethod 47 | def has_draw() -> bool: 48 | return True 49 | 50 | @staticmethod 51 | def num_players() -> int: 52 | return NUM_PLAYERS 53 | 54 | @staticmethod 55 | def action_size() -> int: 56 | return DEFAULT_WIDTH 57 | 58 | @staticmethod 59 | def observation_size() -> Tuple[int, int, int]: 60 | return NUM_CHANNELS, DEFAULT_HEIGHT, DEFAULT_WIDTH 61 | 62 | def valid_moves(self): 63 | return np.asarray(self._board.get_valid_moves()) 64 | 65 | def play_action(self, action: int) -> None: 66 | super().play_action(action) 67 | self._board.add_stone(action, (1, -1)[self.player]) 68 | self._update_turn() 69 | 70 | def win_state(self) -> np.ndarray: 71 | result = [False] * 3 72 | game_over, player = self._board.get_win_state() 73 | 74 | if game_over: 75 | index = -1 76 | if player == 1: 77 | index = 0 78 | elif player == -1: 79 | index = 1 80 | result[index] = True 81 | 82 | return np.array(result, dtype=np.uint8) 83 | 84 | def observation(self): 85 | if MULTI_PLANE_OBSERVATION: 86 | pieces = np.asarray(self._board.pieces) 87 | player1 = np.where(pieces == 1, 1, 0) 88 | player2 = np.where(pieces == -1, 1, 0) 89 | colour = np.full_like(pieces, self.player) 90 | turn = np.full_like(pieces, self.turns / MAX_TURNS, dtype=np.float32) 91 | return np.array([player1, player2, colour, turn], dtype=np.float32) 92 | 93 | else: 94 | return np.expand_dims(np.asarray(self._board.pieces), axis=0) 95 | 96 | def symmetries(self, pi) -> List[Tuple[Any, int]]: 97 | new_state = self.clone() 98 | new_state._board.pieces = self._board.pieces[:, ::-1] 99 | return [(self.clone(), pi), (new_state, pi[::-1])] 100 | 101 | 102 | def display(board, action=None): 103 | if action: 104 | print(f'Action: 
{action}, Move: {action + 1}') 105 | print(" -----------------------") 106 | #print(' '.join(map(str, range(len(board[0]))))) 107 | print(board) 108 | print(" -----------------------") 109 | -------------------------------------------------------------------------------- /alphazero/envs/connect4/gui.py: -------------------------------------------------------------------------------- 1 | import pyximport; pyximport.install() 2 | 3 | from AlphaZeroGUI.CustomGUI import CustomGUI, GameWindow, NUM_BEST_ACTIONS 4 | from alphazero.envs.connect4.connect4 import Game 5 | from PySide2.QtCore import Qt 6 | 7 | 8 | class GUI(CustomGUI): 9 | def __init__(self, *args, **kwargs): 10 | super().__init__(Game, *args, **kwargs) 11 | _, self.height, self.width = Game.observation_size() 12 | self.window = GameWindow( 13 | self.width, 14 | self.height, 15 | cell_size=100, 16 | title=self.title, 17 | # image_dir=str(Path(__file__).parent / 'img'), 18 | evaluator=self.evaluator, 19 | verbose=True, 20 | num_best_actions=NUM_BEST_ACTIONS if self.show_hints else 0, 21 | use_evaluator=(self.evaluator is not None), 22 | action_to_move=lambda state, action: str(action + 1) 23 | ) 24 | if self.show_hints: 25 | self.window.eval_stats_timer.timeout.connect(self._update_draw_actions) 26 | self.board = self.window.game_board 27 | self.board.tileClicked.connect(self._mouse_click) 28 | self.board.closing.connect(self.on_window_close) 29 | 30 | self.board.add_circle_pixmap(1, Qt.black) 31 | self.board.add_circle_pixmap(2, Qt.white) 32 | 33 | self.update_state(self._state) 34 | 35 | def _update_draw_actions(self): 36 | if self.evaluator is None or not self.evaluator.is_running: 37 | return 38 | 39 | actions = self.evaluator.get_best_actions() 40 | if not actions: 41 | return 42 | 43 | self.board.clear_fills() 44 | self.board.fill_tile(actions[0], 0, Qt.green) 45 | self.board.fill_tile(actions[-1], 0, Qt.red) 46 | 47 | self.board.update() 48 | 49 | def _mouse_click(self, action, _): 50 | if ( 51 | self.user_input 52 | and self._state.valid_moves()[action] 53 | and callable(self.on_player_move) 54 | ): 55 | self.on_player_move(action) 56 | 57 | def show(self): 58 | self.window.show() 59 | 60 | def close(self): 61 | self.window.close() 62 | super().close() 63 | 64 | def undo(self): 65 | raise NotImplementedError 66 | 67 | def update_state(self, state): 68 | for x in range(self.width): 69 | for y in range(self.height): 70 | piece = state._board.pieces[y][x] 71 | if piece == -1: 72 | piece = 2 73 | self.board.set_tile(x, y, piece if piece else None) 74 | 75 | if state.last_action is not None: 76 | # remove previous highlight 77 | self.board.remove_highlights() 78 | # highlight the tile where the piece landed 79 | for y in range(self.height): 80 | if state._board.pieces[y][state.last_action] != 0: 81 | self.board.highlight_tile(state.last_action, y) 82 | break 83 | 84 | if state.win_state().any(): 85 | self.user_input = False 86 | self.window.stop_evaluator() 87 | else: 88 | self.window.side_menu.update_turn(state.player + 1) 89 | self.window.run_evaluator(state, block=False) 90 | 91 | self.window.update() 92 | super().update_state(state) 93 | 94 | 95 | if __name__ == '__main__': 96 | from PySide2.QtWidgets import QApplication 97 | import sys 98 | 99 | app = QApplication(sys.argv) 100 | gui = GUI(title='Connect 4') 101 | gui.show() 102 | sys.exit(app.exec_()) 103 | 104 | -------------------------------------------------------------------------------- /alphazero/envs/connect4/pit.py: 
--------------------------------------------------------------------------------
1 | import pyximport
2 | 
3 | pyximport.install()
4 | 
5 | from alphazero.Arena import Arena
6 | from alphazero.GenericPlayers import *
7 | from alphazero.NNetWrapper import NNetWrapper as NNet
8 | 
9 | """
10 | use this script to play any two agents against each other, or play manually with
11 | any agent.
12 | """
13 | if __name__ == '__main__':
14 | from alphazero.envs.connect4.connect4 import Game, display
15 | from alphazero.envs.connect4.players import HumanConnect4Player
16 | from alphazero.envs.connect4.train import args
17 | 
18 | import random
19 | 
20 | args.numMCTSSims = 2000
21 | #args.arena_batch_size = 64
22 | args.temp_scaling_fn = lambda x, y, z: 0
23 | args.add_root_noise = args.add_root_temp = False
24 | 
25 | # all players
26 | # rp = RandomPlayer(g).play
27 | # gp = OneStepLookaheadConnect4Player(g).play
28 | player2 = HumanConnect4Player()
29 | 
30 | # nnet players
31 | nn1 = NNet(Game, args)
32 | nn1.load_checkpoint('./checkpoint/connect4_fpu', 'iteration-0035.pkl')
33 | #nn2 = NNet(Game, args)
34 | #nn2.load_checkpoint('./checkpoint/connect4', 'iteration-0094.pkl')
35 | #player1 = nn1.process
36 | #player2 = nn1.process
37 | 
38 | # player2 = NNPlayer(g, nn1, args=args, verbose=True).play
39 | player1 = MCTSPlayer(nn1, Game, args=args, verbose=True, print_policy=True)#, draw_mcts=True, draw_depth=3)
40 | #args2 = args.copy()
41 | #args2.numMCTSSims = 10
42 | #player2 = MCTSPlayer(nn1, Game, args=args, verbose=True, draw_mcts=True, draw_depth=3)
43 | #player2 = RandomPlayer()
44 | #player2 = RawMCTSPlayer(Game, args).process
45 | 
46 | players = [player2, player1]
47 | #random.shuffle(players)
48 | arena = Arena(players, Game, use_batched_mcts=False, args=args, display=display)
49 | 
50 | """
51 | wins, draws, winrates = arena.play_games(256)
52 | for i in range(len(wins)):
53 | print(f'player{i+1}:\n\twins: {wins[i]}\n\twin rate: {winrates[i]}')
54 | print('draws: ', draws)
55 | """
56 | 
57 | _, result = arena.play_game(verbose=True)
58 | print('Game result:', result)
59 | 
--------------------------------------------------------------------------------
/alphazero/envs/connect4/players.py:
--------------------------------------------------------------------------------
1 | from alphazero.Game import GameState
2 | from alphazero.GenericPlayers import BasePlayer
3 | 
4 | import numpy as np
5 | 
6 | 
7 | class HumanConnect4Player(BasePlayer):
8 | @staticmethod
9 | def is_human() -> bool:
10 | return True
11 | 
12 | def play(self, state: GameState) -> int:
13 | valid_moves = state.valid_moves()
14 | print('\nMoves:', [i for (i, valid)
15 | in enumerate(valid_moves) if valid])
16 | 
17 | while True:
18 | move = int(input())
19 | if valid_moves[move]:
20 | break
21 | else:
22 | print('Invalid move')
23 | return move
24 | 
25 | 
26 | class OneStepLookaheadConnect4Player(BasePlayer):
27 | """Simple player who always takes a win if presented, or blocks a loss if obvious, otherwise is random."""
28 | 
29 | def __init__(self, verbose=False):
30 | self.verbose = verbose
31 | 
32 | def play(self, state: GameState) -> int:
33 | valid_moves = state.valid_moves()
34 | win_move_set = set()
35 | fallback_move_set = set()
36 | stop_loss_move_set = set()
37 | 
38 | for move, valid in enumerate(valid_moves):
39 | if not valid: continue
40 | 
41 | new_state = state.clone()
42 | new_state.play_action(move)
43 | ws = new_state.win_state()
44 | if ws[state.player]:
45 | win_move_set.add(move)
46 | elif 
ws[new_state.player]: 47 | stop_loss_move_set.add(move) 48 | else: 49 | fallback_move_set.add(move) 50 | 51 | if len(win_move_set) > 0: 52 | ret_move = np.random.choice(list(win_move_set)) 53 | if self.verbose: 54 | print('Playing winning action %s from %s' % 55 | (ret_move, win_move_set)) 56 | elif len(stop_loss_move_set) > 0: 57 | ret_move = np.random.choice(list(stop_loss_move_set)) 58 | if self.verbose: 59 | print('Playing loss stopping action %s from %s' % 60 | (ret_move, stop_loss_move_set)) 61 | elif len(fallback_move_set) > 0: 62 | ret_move = np.random.choice(list(fallback_move_set)) 63 | if self.verbose: 64 | print('Playing random action %s from %s' % 65 | (ret_move, fallback_move_set)) 66 | else: 67 | raise Exception('No valid moves remaining: %s' % state) 68 | 69 | return ret_move 70 | -------------------------------------------------------------------------------- /alphazero/envs/connect4/test_connect4.py: -------------------------------------------------------------------------------- 1 | """ 2 | To run tests: 3 | pytest-3 connect4 4 | """ 5 | import pyximport; pyximport.install() 6 | from collections import namedtuple 7 | import textwrap 8 | import numpy as np 9 | 10 | from .Connect4Game import Game 11 | 12 | # Tuple of (Board, Player, Game) to simplify testing. 13 | BPGTuple = namedtuple('BPGTuple', 'board player game') 14 | 15 | 16 | def init_board_from_moves(moves, height=None, width=None): 17 | """Returns a BPGTuple based on a series of specified moves.""" 18 | game = Game(height=height, width=width) 19 | board, player = game.getInitBoard(), 1 20 | for move in moves: 21 | board, player = game.getNextState(board, player, move) 22 | return BPGTuple(board, player, game) 23 | 24 | 25 | def init_board_from_array(board, player): 26 | """Returns a BPGTuple built from the given board array and player.""" 27 | game = Game(height=len(board), width=len(board[0])) 28 | return BPGTuple(board, player, game) 29 | 30 | 31 | def test_simple_moves(): 32 | board, player, game = init_board_from_moves([4, 5, 4, 3, 0, 6]) 33 | expected = np.array( 34 | [[ 0, 0, 0, 0, 0, 0, 0], 35 | [ 0, 0, 0, 0, 0, 0, 0], 36 | [ 0, 0, 0, 0, 0, 0, 0], 37 | [ 0, 0, 0, 0, 0, 0, 0], 38 | [ 0, 0, 0, 0, 1, 0, 0], 39 | [ 1, 0, 0, -1, 1, -1, -1]], dtype=np.intc).tobytes() 40 | assert expected == game.stringRepresentation(board) 41 | 42 | 43 | def test_overfull_column(): 44 | for height in range(1, 10): 45 | # Fill to max height is ok 46 | init_board_from_moves([4] * height, height=height) 47 | 48 | # Check overfilling causes an error. 49 | try: 50 | init_board_from_moves([4] * (height + 1), height=height) 51 | assert False, "Expected error when overfilling column" 52 | except ValueError: 53 | pass # Expected. 
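# A pytest-idiomatic sketch of the same overfill check (an editor's example,
# assuming pytest is importable here; the try/except above stays as the test):
#
#   import pytest
#   with pytest.raises(ValueError):
#       init_board_from_moves([4] * (height + 1), height=height)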
54 | 55 | 56 | def test_get_valid_moves(): 57 | """Tests that the vector of valid moves is correct.""" 58 | move_valid_pairs = [ 59 | ([], [True] * 7), 60 | ([0, 1, 2, 3, 4, 5, 6], [True] * 7), 61 | ([0, 1, 2, 3, 4, 5, 6] * 5, [True] * 7), 62 | ([0, 1, 2, 3, 4, 5, 6] * 6, [False] * 7), 63 | ([0, 1, 2] * 3 + [3, 4, 5, 6] * 6, [True] * 3 + [False] * 4), 64 | ] 65 | 66 | for moves, expected_valid in move_valid_pairs: 67 | board, player, game = init_board_from_moves(moves) 68 | assert (np.array(expected_valid) == game.getValidMoves(board, player)).all() 69 | 70 | 71 | def test_symmetries(): 72 | """Tests that symmetric boards are produced.""" 73 | board, player, game = init_board_from_moves([0, 0, 1, 0, 6]) 74 | pi = 0.8 75 | (board1, pi1), (board2, pi2) = game.getSymmetries(board, pi) 76 | assert pi == pi1 and pi == pi2 77 | 78 | expected_board1 = np.array( 79 | [[ 0, 0, 0, 0, 0, 0, 0], 80 | [ 0, 0, 0, 0, 0, 0, 0], 81 | [ 0, 0, 0, 0, 0, 0, 0], 82 | [-1, 0, 0, 0, 0, 0, 0], 83 | [-1, 0, 0, 0, 0, 0, 0], 84 | [ 1, 1, 0, 0, 0, 0, 1]], dtype=np.intc).tobytes() 85 | assert expected_board1 == game.stringRepresentation(board1) 86 | 87 | expected_board2 = np.array( 88 | [[ 0, 0, 0, 0, 0, 0, 0], 89 | [ 0, 0, 0, 0, 0, 0, 0], 90 | [ 0, 0, 0, 0, 0, 0, 0], 91 | [ 0, 0, 0, 0, 0, 0, -1], 92 | [ 0, 0, 0, 0, 0, 0, -1], 93 | [ 1, 0, 0, 0, 0, 1, 1]], dtype=np.intc).tobytes() 94 | assert expected_board2 == game.stringRepresentation(board2) 95 | 96 | 97 | def test_game_ended(): 98 | """Tests game end detection logic based on fixed boards.""" 99 | array_end_state_pairs = [ 100 | (np.array([[0, 0, 0, 0, 0, 0, 0], 101 | [0, 0, 0, 0, 0, 0, 0], 102 | [0, 0, 0, 0, 0, 0, 0], 103 | [0, 0, 0, 0, 0, 0, 0], 104 | [0, 0, 0, 0, 0, 0, 0]], dtype=np.intc), 1, 0), 105 | (np.array([[0, 0, 0, 0, 0, 0, 0], 106 | [0, 0, 0, 0, 0, 1, 0], 107 | [0, 0, 0, 0, 1, 0, 0], 108 | [0, 0, 0, 1, 0, 0, 0], 109 | [0, 0, 1, 0, 0, 0, 0], 110 | [0, 0, 0, 0, 0, 0, 0]], dtype=np.intc), 1, 1), 111 | (np.array([[0, 0, 0, 0, 1, 0, 0], 112 | [0, 0, 0, 1, 0, 0, 0], 113 | [0, 0, 1, 0, 0, 0, 0], 114 | [0, 1, 0, 0, 0, 0, 0], 115 | [0, 0, 0, 0, 0, 0, 0]], dtype=np.intc), -1, -1), 116 | (np.array([[0, 0, 0, 0, 0, 0, 0], 117 | [0, 0, 1, 0, 0, 0, 0], 118 | [0, 0, 0, 1, 0, 0, 0], 119 | [0, 0, 0, 0, 1, 0, 0], 120 | [0, 0, 0, 0, 0, 1, 0]], dtype=np.intc), -1, -1), 121 | (np.array([[0, 0, 0, -1], 122 | [0, 0, -1, 0], 123 | [0, -1, 0, 0], 124 | [-1, 0, 0, 0]], dtype=np.intc), 1, -1), 125 | (np.array([[0, 0, 0, 0, 1], 126 | [0, 0, 0, 1, 0], 127 | [0, 0, 1, 0, 0], 128 | [0, 1, 0, 0, 0]], dtype=np.intc), -1, -1), 129 | (np.array([[1, 0, 0, 0, 0], 130 | [0, 1, 0, 0, 0], 131 | [0, 0, 1, 0, 0], 132 | [0, 0, 0, 1, 0]], dtype=np.intc), -1, -1), 133 | (np.array([[ 0, 0, 0, 0, 0, 0, 0], 134 | [ 0, 0, 0, -1, 0, 0, 0], 135 | [ 0, 0, 0, -1, 0, 0, 1], 136 | [ 0, 0, 0, 1, 1, -1, -1], 137 | [ 0, 0, 0, -1, 1, 1, 1], 138 | [ 0, -1, 0, -1, 1, -1, 1]], dtype=np.intc), -1, 0), 139 | (np.array([[ 0., 0., 0., 0., 0., 0., 0.], 140 | [ 0., 0., 0., -1., 0., 0., 0.], 141 | [ 1., 0., 1., -1., 0., 0., 0.], 142 | [-1., -1., 1., 1., 0., 0., 0.], 143 | [ 1., 1., 1., -1., 0., 0., 0.], 144 | [ 1., -1., 1., -1., 0., -1., 0.]], dtype=np.intc), -1, -1), 145 | (np.array([[ 0., 0., 0., 1., 0., 0., 0.,], 146 | [ 0., 0., 0., 1., 0., 0., 0.,], 147 | [ 0., 0., 0., -1., 0., 0., 0.,], 148 | [ 0., 0., 1., 1., -1., 0., -1.,], 149 | [ 0., 0., -1., 1., 1., 1., 1.,], 150 | [-1., 0., -1., 1., -1., -1., -1.,],], dtype=np.intc), 1, 1), 151 | ] 152 | 153 | for np_pieces, player, expected_end_state in array_end_state_pairs: 154 | board, 
player, game = init_board_from_array(np_pieces, player) 155 | end_state = game.getGameEnded(board, player) 156 | assert expected_end_state == end_state, ("expected=%s, actual=%s, board=\n%s" % (expected_end_state, end_state, board)) 157 | 158 | 159 | def test_immutable_move(): 160 | """Test that the original board is not mutated when getNextState() is called.""" 161 | board, player, game = init_board_from_moves([1, 2, 3, 3, 4]) 162 | original_board_string = game.stringRepresentation(board) 163 | 164 | new_np_pieces, new_player = game.getNextState(board, -1, 3) 165 | 166 | assert original_board_string == game.stringRepresentation(board) 167 | assert original_board_string != game.stringRepresentation(new_np_pieces) 168 | -------------------------------------------------------------------------------- /alphazero/envs/connect4/train.py: -------------------------------------------------------------------------------- 1 | import pyximport; pyximport.install() 2 | 3 | from torch import multiprocessing as mp 4 | 5 | from alphazero.Coach import Coach, get_args 6 | from alphazero.NNetWrapper import NNetWrapper as nn 7 | from alphazero.envs.connect4.connect4 import Game 8 | from alphazero.GenericPlayers import RawMCTSPlayer 9 | from alphazero.utils import dotdict 10 | 11 | args = get_args(dotdict({ 12 | 'run_name': 'connect4_fpu', 13 | 'workers': mp.cpu_count(), 14 | 'startIter': 1, 15 | 'numIters': 1000, 16 | 'numWarmupIters': 1, 17 | 'process_batch_size': 2048, 18 | 'train_batch_size': 1024, 19 | # should preferably be a multiple of process_batch_size and workers 20 | 'gamesPerIteration': 2048 * mp.cpu_count(), 21 | 'symmetricSamples': True, 22 | 'skipSelfPlayIters': None, 23 | 'selfPlayModelIter': None, 24 | 'numMCTSSims': 200, 25 | 'numFastSims': 40, 26 | 'probFastSim': 0.75, 27 | 'compareWithBaseline': True, 28 | 'arenaCompareBaseline': 512, 29 | 'arenaCompare': 512, 30 | 'arena_batch_size': 128, 31 | 'arenaTemp': 1, 32 | 'arenaMCTS': True, 33 | 'baselineCompareFreq': 1, 34 | 'compareWithPast': True, 35 | 'pastCompareFreq': 1, 36 | 'cpuct': 4, 37 | 'fpu_reduction': 0.4, 38 | 'load_model': True, 39 | }), 40 | model_gating=True, 41 | max_gating_iters=None, 42 | max_moves=42, 43 | 44 | lr=0.01, 45 | num_channels=128, 46 | depth=8, 47 | value_head_channels=32, 48 | policy_head_channels=32, 49 | value_dense_layers=[1024, 256], 50 | policy_dense_layers=[1024] 51 | ) 52 | args.scheduler_args.milestones = [75, 150] 53 | 54 | 55 | if __name__ == "__main__": 56 | nnet = nn(Game, args) 57 | c = Coach(Game, nnet, args) 58 | c.learn() 59 | -------------------------------------------------------------------------------- /alphazero/envs/gobang/GobangLogic.pxd: -------------------------------------------------------------------------------- 1 | # cython: language_level=3 2 | # cython: boundscheck=False 3 | # cython: wraparound=False 4 | # cython: nonecheck=False 5 | # cython: overflowcheck=False 6 | # cython: initializedcheck=False 7 | # cython: cdivision=True 8 | # cython: auto_pickle=True 9 | 10 | import numpy as np 11 | cimport numpy as np 12 | 13 | DTYPE = np.intc 14 | ctypedef np.int32_t DTYPE_t 15 | 16 | 17 | cdef class Board: 18 | cdef public int n 19 | cdef public int n_in_row 20 | cdef public np.ndarray pieces 21 | 22 | cpdef list get_legal_moves(self) 23 | cpdef bint has_legal_moves(self) 24 | cpdef tuple get_win_state(self) 25 | cpdef void execute_move(self, tuple move, int color) 26 | 27 | -------------------------------------------------------------------------------- /alphazero/envs/gobang/GobangLogic.pyx: 
-------------------------------------------------------------------------------- 1 | # cython: language_level=3 2 | # cython: boundscheck=False 3 | # cython: wraparound=False 4 | # cython: nonecheck=False 5 | # cython: overflowcheck=False 6 | # cython: initializedcheck=False 7 | # cython: cdivision=True 8 | # cython: auto_pickle=True 9 | 10 | import numpy as np 11 | cimport numpy as np 12 | 13 | 14 | cdef class Board: 15 | """ 16 | Author: MBoss. Modified by Kevaday for Cython 17 | Date: Jan 17, 2018. 18 | Board class. 19 | Board data: 20 | 1=white, -1=black, 0=empty 21 | first dim is the column, 2nd is the row: 22 | pieces[1][7] is the square in column 2, 23 | at the opposite end of the board in row 8. 24 | Squares are stored and manipulated as (x,y) tuples. 25 | x is the column, y is the row. 26 | """ 27 | 28 | def __init__(self, int n, int n_in_row, _pieces=None): 29 | """Set up initial board configuration.""" 30 | self.n = n 31 | self.n_in_row = n_in_row 32 | 33 | if _pieces is not None: 34 | self.pieces = _pieces 35 | else: 36 | # Create the empty board array. 37 | self.pieces = np.zeros((n, n), dtype=np.intc) 38 | 39 | # add [][] indexer syntax to the Board 40 | def __getitem__(self, tuple index): 41 | return self.pieces[index] 42 | 43 | def __setitem__(self, tuple index, int value): 44 | self.pieces[index] = value 45 | 46 | cpdef list get_legal_moves(self): 47 | """Returns all the legal moves on the board, 48 | i.e. every empty square as an (x, y) tuple. 49 | """ 50 | cdef tuple pos 51 | return [tuple(reversed(pos)) for pos in zip(*np.where(self.pieces == 0))] 52 | 53 | cpdef bint has_legal_moves(self): 54 | """Returns True if there is at least one legal move, else False. 55 | """ 56 | return 0 in self.pieces 57 | 58 | cpdef tuple get_win_state(self): 59 | cdef Py_ssize_t n, w, h, i, j, k, l 60 | n = self.n_in_row 61 | 62 | for w in range(self.n): 63 | for h in range(self.n): 64 | if (w in range(self.n - n + 1) and self[w, h] != 0 and 65 | len(set([self[i, h] for i in range(w, w + n)])) == 1): 66 | return True, self[w, h] 67 | if (h in range(self.n - n + 1) and self[w, h] != 0 and 68 | len(set([self[w, j] for j in range(h, h + n)])) == 1): 69 | return True, self[w, h] 70 | if (w in range(self.n - n + 1) and h in range(self.n - n + 1) and self[w, 71 | h] != 0 and 72 | len(set([self[w + k, h + k] for k in range(n)])) == 1): 73 | return True, self[w, h] 74 | if (w in range(self.n - n + 1) and h in range(n - 1, self.n) and self[w, 75 | h] != 0 and 76 | len(set([self[w + l, h - l] for l in range(n)])) == 1): 77 | return True, self[w, h] 78 | 79 | if self.has_legal_moves(): 80 | return False, 0 81 | return True, 0 82 | 83 | cpdef void execute_move(self, tuple move, int color): 84 | """Perform the given move on the board. 
85 | color gives the color of the piece to play (1=white,-1=black) 86 | """ 87 | assert self[move] == 0, f'invalid move {move}' 88 | self[move] = color 89 | -------------------------------------------------------------------------------- /alphazero/envs/gobang/GobangPlayers.py: -------------------------------------------------------------------------------- 1 | from alphazero.GenericPlayers import BasePlayer 2 | from alphazero.Game import GameState 3 | 4 | 5 | class HumanGobangPlayer(BasePlayer): 6 | def play(self, state: GameState) -> int: 7 | valid = state.valid_moves() 8 | """ 9 | for i in range(len(valid)): 10 | if valid[i]: 11 | print(int(i / state._board.n), int(i % state._board.n)) 12 | """ 13 | 14 | while True: 15 | a = input('Enter a move: ') 16 | try: 17 | x, y = [int(v) for v in a.split()] 18 | a = state._board.n * x + y if x != -1 else state._board.n ** 2 19 | if valid[a]: break 20 | except (ValueError, IndexError): 21 | pass 22 | print('Invalid move entered.') 23 | 24 | return a 25 | 26 | 27 | class GreedyGobangPlayer(BasePlayer): 28 | def play(self, state: GameState) -> int: 29 | valids = state.valid_moves() 30 | candidates = [] 31 | 32 | for a in range(state.action_size()): 33 | if not valids[a]: continue 34 | 35 | next_state = state.clone() 36 | next_state.play_action(a) 37 | candidates += [(-int(next_state.win_state()[state.player]), a)] 38 | 39 | candidates.sort() 40 | return candidates[0][1] 41 | -------------------------------------------------------------------------------- /alphazero/envs/gobang/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevaday/alphazero-general/99f74f19766923e98523dfc310a1ffef7379ca83/alphazero/envs/gobang/__init__.py -------------------------------------------------------------------------------- /alphazero/envs/gobang/gobang.pyx: -------------------------------------------------------------------------------- 1 | # cython: language_level=3 2 | # cython: boundscheck=False 3 | # cython: wraparound=False 4 | # cython: nonecheck=False 5 | # cython: overflowcheck=False 6 | # cython: initializedcheck=False 7 | # cython: cdivision=True 8 | # cython: auto_pickle=True 9 | 10 | import numpy as np 11 | cimport numpy as np 12 | 13 | import pyximport 14 | 15 | pyximport.install(setup_args={'include_dirs': np.get_include()}) 16 | 17 | from alphazero.envs.gobang.GobangLogic import Board 18 | from alphazero.envs.gobang.GobangLogic cimport Board 19 | 20 | 21 | DTYPE = np.uint8 22 | ctypedef np.uint8_t DTYPE_t 23 | 24 | cdef int NUM_PLAYERS = 2 25 | cdef int BOARD_SIZE = 15 26 | cdef int NUM_IN_ROW = 5 27 | cdef int MAX_MOVES = BOARD_SIZE ** 2 28 | 29 | cdef int ACTION_SIZE = BOARD_SIZE ** 2 30 | cdef bint MULTI_PLANE_OBSERVATION = True 31 | cdef int NUM_CHANNELS = 4 if MULTI_PLANE_OBSERVATION else 1 32 | cdef tuple OBSERVATION_SIZE = (NUM_CHANNELS, BOARD_SIZE, BOARD_SIZE) 33 | 34 | 35 | cpdef tuple get_move(int action, int n): 36 | return action // n, action % n 37 | 38 | 39 | cpdef int get_action(tuple move, int n): 40 | return n * move[1] + move[0] 41 | 42 | 43 | cdef class Game:#(GameState): 44 | cdef public Board _board 45 | cdef public int _player 46 | cdef public int _turns 47 | 48 | def __init__(self, _board=None): 49 | self._board = _board or self._get_board() 50 | self._player = 0 51 | self._turns = 0 52 | 53 | @staticmethod 54 | def _get_board(*args, **kwargs) -> Board: 55 | return Board(BOARD_SIZE, NUM_IN_ROW, *args, **kwargs) 56 | 57 | def __eq__(self, other: 'Game') -> bool: 58 | return ( 59 | 
np.array_equal(self._board.pieces, other._board.pieces) 60 | and self._board.n == other._board.n 61 | and self._board.n_in_row == other._board.n_in_row 62 | and self._player == other._player 63 | and self.turns == other.turns 64 | ) 65 | 66 | cpdef Game clone(self): 67 | cdef Board board = self._get_board(_pieces=np.copy(self._board.pieces)) 68 | cdef Game g = Game(_board=board) 69 | g._player = self._player 70 | g._turns = self.turns 71 | return g 72 | 73 | @property 74 | def player(self) -> int: 75 | return self._player 76 | 77 | @property 78 | def turns(self): 79 | return self._turns 80 | 81 | cpdef int _next_player(self, int player, int turns=1): 82 | return (player + turns) % NUM_PLAYERS 83 | 84 | cpdef void _update_turn(self): 85 | """Should be called at the end of play_action""" 86 | self._player = self._next_player(self._player) 87 | self._turns += 1 88 | 89 | @staticmethod 90 | def num_players(): 91 | return NUM_PLAYERS 92 | 93 | @staticmethod 94 | def action_size(): 95 | return ACTION_SIZE 96 | 97 | @staticmethod 98 | def observation_size(): 99 | return OBSERVATION_SIZE 100 | 101 | @staticmethod 102 | def max_turns(): 103 | return MAX_MOVES 104 | 105 | @staticmethod 106 | def has_draw(): 107 | return True 108 | 109 | cpdef np.ndarray valid_moves(self): 110 | # return a fixed size binary vector 111 | cdef list valids = [0] * self.action_size() 112 | cdef tuple move 113 | 114 | for move in self._board.get_legal_moves(): 115 | valids[get_action(move, self._board.n)] = 1 116 | 117 | return np.array(valids, dtype=np.uint8) 118 | 119 | cpdef void play_action(self, int action): 120 | cdef tuple move = get_move(action, self._board.n) 121 | self._board.execute_move(move, (1, -1)[self.player]) 122 | self._update_turn() 123 | 124 | cpdef np.ndarray win_state(self): 125 | cdef list result = [False] * (NUM_PLAYERS + 1) 126 | cdef bint game_over 127 | cdef int player 128 | cdef Py_ssize_t index 129 | game_over, player = self._board.get_win_state() 130 | 131 | if game_over: 132 | index = NUM_PLAYERS 133 | if player == 1: 134 | index = 0 135 | elif player == -1: 136 | index = 1 137 | result[index] = True 138 | 139 | return np.array(result, dtype=np.uint8) 140 | 141 | cpdef np.ndarray observation(self): 142 | if MULTI_PLANE_OBSERVATION: 143 | pieces = np.asarray(self._board.pieces) 144 | player1 = np.where(pieces == 1, 1, 0) 145 | player2 = np.where(pieces == -1, 1, 0) 146 | colour = np.full_like(pieces, self.player) 147 | turn = np.full_like(pieces, self.turns / self._board.n**2, dtype=np.float32) 148 | return np.array([player1, player2, colour, turn], dtype=np.float32) 149 | 150 | else: 151 | return np.expand_dims(np.asarray(self._board.pieces), axis=0) 152 | 153 | cpdef list symmetries(self, np.ndarray pi): 154 | # mirror, rotational 155 | 156 | cdef np.ndarray[np.float32_t, ndim=2] pi_board = np.reshape(pi, (self._board.n, self._board.n)) 157 | cdef np.ndarray[np.int32_t, ndim=2] new_b 158 | cdef np.ndarray[np.float32_t, ndim=2] new_pi 159 | cdef list result = [] 160 | cdef Game gs 161 | cdef Py_ssize_t i 162 | cdef bint j 163 | 164 | for i in range(1, 5): 165 | for j in [True, False]: 166 | new_b = np.rot90(np.asarray(self._board.pieces), i) 167 | new_pi = np.rot90(pi_board, i) 168 | if j: 169 | new_b = np.fliplr(new_b) 170 | new_pi = np.fliplr(new_pi) 171 | 172 | gs = self.clone() 173 | gs._board.pieces = new_b 174 | result.append((gs, new_pi.ravel())) 175 | 176 | return result 177 | 178 | 179 | cpdef void display(Game gs, int action=-1): 180 | cdef np.ndarray[np.int32_t, ndim=2] board = gs._board.pieces 
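# Renders the position as ASCII: column indices across the top, row indices
# down the left edge, with 'b' for black (-1), 'W' for white (1) and '-' for empty.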
181 | cdef int n = board.shape[0] 182 | cdef Py_ssize_t y, x 183 | cdef int piece 184 | cdef str prefix = ' ' 185 | 186 | if action != -1: 187 | print(f'Action: {action}, Move: {get_move(action, n)}') 188 | 189 | print(' ' * 4 + '|'.join([str(x) for x in range(n)])) 190 | print(' ' * 4 + '-' * (n * 2)) 191 | 192 | for y in range(n): 193 | if y > 9: 194 | prefix = '' 195 | print(prefix + f'{y} |', end='') # print the row # 196 | 197 | for x in range(n): 198 | piece = board[x, y] # get the piece to print 199 | if piece == -1: 200 | print('b ', end='') 201 | elif piece == 1: 202 | print('W ', end='') 203 | else: 204 | print('- ', end='') 205 | print('|') 206 | 207 | print(' ' * 4 + '-' * (n * 2)) 208 | -------------------------------------------------------------------------------- /alphazero/envs/gobang/pit.py: -------------------------------------------------------------------------------- 1 | from pyximport import install 2 | from numpy import get_include 3 | install(setup_args={'include_dirs': get_include()}) 4 | 5 | from alphazero.Arena import Arena 6 | from alphazero.GenericPlayers import * 7 | from alphazero.NNetWrapper import NNetWrapper as NNet 8 | 9 | 10 | """ 11 | use this script to play any two agents against each other, or to play manually 12 | against any agent. 13 | """ 14 | if __name__ == '__main__': 15 | from alphazero.envs.gobang.gobang import Game, display 16 | from alphazero.envs.gobang.train import args 17 | import random 18 | 19 | batched_arena = False 20 | args.numMCTSSims = 800 21 | #args.arena_batch_size = 64 22 | args.temp_scaling_fn = lambda x, y, z: 0.2 23 | #args2.temp_scaling_fn = args.temp_scaling_fn 24 | #args.cuda = False 25 | args.add_root_noise = args.add_root_temp = False 26 | 27 | # nnet players 28 | nn1 = NNet(Game, args) 29 | nn1.load_checkpoint('./checkpoint/' + args.run_name, 'iteration-0009.pkl') 30 | #nn2 = NNet(Game, args) 31 | #nn2.load_checkpoint('./checkpoint/brandubh2', 'iteration-0112.pkl') 32 | #player1 = nn1.process 33 | #player2 = nn2.process 34 | 35 | player1 = MCTSPlayer(Game, nn1, args=args, verbose=True) 36 | player2 = MCTSPlayer(Game, nn1, args=args, verbose=True) 37 | #player2 = RandomPlayer() 38 | #player2 = GreedyTaflPlayer() 39 | #player2 = RandomPlayer() 40 | #player2 = OneStepLookaheadConnect4Player() 41 | #player2 = RawMCTSPlayer(Game, args) 42 | #player2 = HumanFastaflPlayer() 43 | 44 | players = [player1, player2] 45 | #random.shuffle(players) 46 | 47 | arena = Arena(players, Game, use_batched_mcts=batched_arena, args=args, display=display) 48 | if batched_arena: 49 | wins, draws, winrates = arena.play_games(args.arenaCompare) 50 | for i in range(len(wins)): 51 | print(f'player{i+1}:\n\twins: {wins[i]}\n\twin rate: {winrates[i]}') 52 | print('draws: ', draws) 53 | else: 54 | arena.play_game(verbose=True) 55 | 56 | -------------------------------------------------------------------------------- /alphazero/envs/gobang/train.py: -------------------------------------------------------------------------------- 1 | import numpy, pyximport 2 | pyximport.install(setup_args={'include_dirs': numpy.get_include()}) 3 | 4 | from alphazero.Coach import Coach, get_args 5 | from alphazero.NNetWrapper import NNetWrapper as nn 6 | from alphazero.envs.gobang.gobang import Game 7 | from alphazero.GenericPlayers import RawMCTSPlayer 8 | #from alphazero.envs.gobang.GobangPlayers import GreedyGobangPlayer 9 | 10 | 11 | args = get_args( 12 | run_name='gobang', 13 | max_moves=225, 14 | cpuct=2, 15 | fpu_reduction=0.1, 16 | 
symmetricSamples=True, 17 | numMCTSSims=250, 18 | numFastSims=50, 19 | numWarmupSims=5, 20 | probFastSim=0.75, 21 | #skipSelfPlayIters=1, 22 | train_on_past_data=True, 23 | past_data_run_name='gobang', 24 | numWarmupIters=1, 25 | baselineCompareFreq=3, 26 | pastCompareFreq=1, 27 | process_batch_size=128, 28 | train_batch_size=512, 29 | arena_batch_size=64, 30 | arenaCompare=64*4, 31 | arenaCompareBaseline=64*4, 32 | gamesPerIteration=128*4, 33 | #train_steps_per_iteration=150, 34 | autoTrainSteps=True, 35 | train_sample_ratio=1, 36 | 37 | lr=0.01, 38 | depth=8, 39 | num_channels=128, 40 | value_head_channels=16, 41 | policy_head_channels=16, 42 | value_dense_layers=[2048, 128], 43 | policy_dense_layers=[2048] 44 | ) 45 | args.scheduler_args.milestones = [75, 100] 46 | 47 | 48 | if __name__ == "__main__": 49 | nnet = nn(Game, args) 50 | c = Coach(Game, nnet, args) 51 | c.learn() 52 | -------------------------------------------------------------------------------- /alphazero/envs/hnefatafl/__init__.py: -------------------------------------------------------------------------------- 1 | from .hnefatafl import * 2 | -------------------------------------------------------------------------------- /alphazero/envs/hnefatafl/brandubh.pyx: -------------------------------------------------------------------------------- 1 | # cython: language_level=3 2 | # cython: profile=True 3 | 4 | import pyximport; pyximport.install() 5 | 6 | from alphazero.Game import GameState 7 | from hnefatafl.engine import Board, Move, PieceType, variants 8 | from typing import List, Tuple, Any 9 | 10 | import numpy as np 11 | 12 | 13 | def _get_board(): 14 | return Board( 15 | GAME_VARIANT, 16 | max_repeats=MAX_REPEATS, 17 | _store_past_states=False, 18 | _max_past_states=14#max((MAX_REPEATS + 1) * NUM_PLAYERS, NUM_STACKED_OBSERVATIONS - 1) 19 | ) 20 | 21 | 22 | GAME_VARIANT = variants.brandubh 23 | MAX_REPEATS = 3 # N-fold repetition loss 24 | NUM_PLAYERS = 2 25 | NUM_STACKED_OBSERVATIONS = 1 26 | NUM_BASE_CHANNELS = 5 27 | NUM_CHANNELS = NUM_BASE_CHANNELS * NUM_STACKED_OBSERVATIONS 28 | 29 | b = _get_board() 30 | ACTION_SIZE = b.board_width * b.board_height * (b.board_width + b.board_height - 2) 31 | OBS_SIZE = (NUM_CHANNELS, b.board_width, b.board_height) 32 | del b 33 | 34 | DRAW_MOVE_COUNT = 100 35 | 36 | 37 | def _board_from_numpy(np_board: np.ndarray) -> Board: 38 | return Board(custom_board='\n'.join([''.join([str(i) for i in row]) for row in np_board])) 39 | 40 | 41 | def _board_to_numpy(board: Board) -> np.ndarray: 42 | return np.array([[int(tile) for tile in row] for row in board._board]) 43 | 44 | 45 | def get_move(board: Board, action: int) -> Move: 46 | size = (board.board_width + board.board_height - 2) 47 | move_type = action % size 48 | a = action // size 49 | start_x = a % board.board_width 50 | start_y = a // board.board_width 51 | 52 | if move_type < board.board_height - 1: 53 | new_x = start_x 54 | new_y = move_type 55 | if move_type >= start_y: new_y += 1 56 | else: 57 | new_x = move_type - board.board_height + 1 58 | if new_x >= start_x: new_x += 1 59 | new_y = start_y 60 | 61 | return Move(board, int(start_x), int(start_y), int(new_x), int(new_y), _check_in_bounds=False) 62 | 63 | 64 | def get_action(board: Board, move: Move) -> int: 65 | new_x = move.new_tile.x 66 | new_y = move.new_tile.y 67 | 68 | if move.is_vertical: 69 | move_type = new_y if new_y < move.tile.y else new_y - 1 70 | else: 71 | move_type = board.board_height + new_x - 1 72 | if new_x >= move.tile.x: move_type -= 1 73 | 74 | return 
(board.board_width + board.board_height - 2) * (move.tile.x + move.tile.y * board.board_width) + move_type 75 | 76 | 77 | def _get_observation(board: Board, const_max_player: int, const_max_turns: int, past_obs: int = 1): 78 | obs = [] 79 | 80 | def add_obs(b): 81 | game_board = _board_to_numpy(b) 82 | black = np.where(game_board == PieceType.black.value, 1., 0.) 83 | white = np.where((game_board == PieceType.white.value) | (game_board == PieceType.king.value), 1., 0.) 84 | king = np.where(game_board == PieceType.king.value, 1., 0.) 85 | turn_colour = np.full_like( 86 | game_board, 87 | 2 - b.to_play().value / (const_max_player - 1) if const_max_player > 1 else 0 88 | ) 89 | turn_number = np.full_like( 90 | game_board, 91 | b.num_turns / const_max_turns if const_max_turns else 0, dtype=np.float32 92 | ) 93 | obs.extend([black, white, king, turn_colour, turn_number]) 94 | 95 | def add_empty(): 96 | obs.extend([[[0] * board.board_width] * board.board_height] * NUM_BASE_CHANNELS) 97 | 98 | if board._store_past_states: 99 | past = board._past_states.copy() 100 | past.insert(0, (board, None)) 101 | for i in range(past_obs): 102 | if board.num_turns < i: 103 | add_empty() 104 | else: 105 | add_obs(past[i][0]) 106 | else: 107 | add_obs(board) 108 | 109 | return np.array(obs, dtype=np.float32) 110 | 111 | 112 | class Game(GameState): 113 | def __init__(self, _board=None): 114 | super().__init__(_board or _get_board()) 115 | 116 | def __eq__(self, other: 'Game') -> bool: 117 | return self.__dict__ == other.__dict__ 118 | 119 | @staticmethod 120 | def _get_piece_type(player: int) -> PieceType: 121 | return PieceType(2 - player) 122 | 123 | @staticmethod 124 | def _get_player_int(player: PieceType) -> int: 125 | return (1, -1)[2 - player.value] 126 | 127 | def clone(self) -> 'GameState': 128 | g = Game(self._board.copy(store_past_states=self._board._store_past_states)) 129 | g._player = self._player 130 | g._turns = self.turns 131 | return g 132 | 133 | @staticmethod 134 | def num_players() -> int: 135 | return NUM_PLAYERS 136 | 137 | @staticmethod 138 | def action_size() -> int: 139 | return ACTION_SIZE 140 | 141 | @staticmethod 142 | def observation_size() -> Tuple[int, int, int]: 143 | return OBS_SIZE 144 | 145 | def valid_moves(self): 146 | valids = [0] * self.action_size() 147 | legal_moves = self._board.all_valid_moves(self._board.to_play()) 148 | 149 | for move in legal_moves: 150 | valids[get_action(self._board, move)] = 1 151 | 152 | return np.array(valids, dtype=np.intc) 153 | 154 | def play_action(self, action: int) -> None: 155 | move = get_move(self._board, action) 156 | self._board.move(move, _check_game_end=False, _check_valid=False) 157 | self._update_turn() 158 | 159 | def win_state(self) -> Tuple[bool, ...]: 160 | result = [False] * (NUM_PLAYERS + 1) 161 | 162 | # Check if maximum moves have been exceeded 163 | if self.turns >= DRAW_MOVE_COUNT: 164 | result[-1] = True 165 | else: 166 | winner: PieceType = self._board.get_winner() 167 | if winner: 168 | result[2 - winner.value] = True 169 | 170 | return tuple(result) 171 | 172 | def observation(self): 173 | return _get_observation( 174 | self._board, 175 | NUM_PLAYERS, 176 | DRAW_MOVE_COUNT, 177 | NUM_STACKED_OBSERVATIONS 178 | ) 179 | 180 | def symmetries(self, pi: np.ndarray) -> List[Tuple[Any, int]]: 181 | action_size = self.action_size() 182 | assert (len(pi) == action_size) 183 | syms = [None] * 8 184 | 185 | for i in range(1, 5): 186 | for flip in (False, True): 187 | state = np.rot90(np.array(self._board._board), i) 188 | if 
flip: 189 | state = np.fliplr(state) 190 | 191 | if self._board._store_past_states: 192 | num_past_states = min( 193 | NUM_STACKED_OBSERVATIONS - 1, 194 | len(self._board._past_states) 195 | ) 196 | past_states = [None] * num_past_states 197 | for idx in range(num_past_states): 198 | past = self._board._past_states[idx] 199 | b = np.rot90(np.array(past[0]._board), i) 200 | if flip: 201 | b = np.fliplr(b) 202 | past_states[idx] = (self._board.copy(store_past_states=False, state=b.tolist()), past[1]) 203 | else: 204 | past_states = None 205 | 206 | new_b = self._board.copy( 207 | store_past_states=self._board._store_past_states, 208 | state=state.tolist(), 209 | past_states=past_states 210 | ) 211 | if not past_states: 212 | new_b._past_states = [s.copy(new_b) for s in new_b._past_states] 213 | 214 | new_pi = [0] * action_size 215 | for action, prob in enumerate(pi): 216 | move = get_move(self._board, action) 217 | 218 | x = move.tile.x 219 | new_x = move.new_tile.x 220 | y = move.tile.y 221 | new_y = move.new_tile.y 222 | 223 | for _ in range(i): 224 | temp_x = x 225 | temp_new_x = new_x 226 | x = self._board.board_width - 1 - y 227 | new_x = self._board.board_width - 1 - new_y 228 | y = temp_x 229 | new_y = temp_new_x 230 | if flip: 231 | x = self._board.board_width - 1 - x 232 | new_x = self._board.board_width - 1 - new_x 233 | 234 | move = Move(new_b, x, y, new_x, new_y) 235 | new_action = get_action(new_b, move) 236 | new_pi[new_action] = prob 237 | 238 | new_state = self.clone() 239 | new_state._board = new_b 240 | syms[(i - 1) * 2 + int(flip)] = (new_state, np.array(new_pi, dtype=np.float32)) 241 | 242 | return syms 243 | 244 | def crude_value(self) -> int: 245 | result = self.win_state() 246 | white_pieces = len(list(filter(lambda p: p.is_white, self._board.pieces))) 247 | black_pieces = len(list(filter(lambda p: p.is_black, self._board.pieces))) 248 | return (1, -1)[self.player] * (1000 * (int(result[0]) - int(result[1])) + black_pieces - white_pieces) 249 | 250 | 251 | def display(state: Game, action: int = None): 252 | if action is not None: print(f'Action: {action}, Move: {get_move(state._board, action)}') 253 | print(state) 254 | 255 | 256 | def test_repeat(n): 257 | global GAME_VARIANT 258 | GAME_VARIANT = variants.hnefatafl 259 | g = Game() 260 | # g.board[0][0].piece = Piece(PieceType(3), 0, 0, 0) 261 | board = _get_board() 262 | for _ in range(n): 263 | board.move(Move(board, 3, 0, 2, 0)) 264 | board.move(Move(board, 5, 3, 5, 2)) 265 | board.move(Move(board, 2, 0, 3, 0)) 266 | board.move(Move(board, 5, 2, 5, 3)) 267 | print(board.num_repeats(PieceType.black), board.num_repeats(PieceType.white)) 268 | g._board = board 269 | print(g.win_state()) 270 | -------------------------------------------------------------------------------- /alphazero/envs/hnefatafl/fastafl.pyx: -------------------------------------------------------------------------------- 1 | # cython: language_level=3 2 | # cython: boundscheck=False 3 | # cython: wraparound=False 4 | # cython: nonecheck=False 5 | # cython: overflowcheck=False 6 | # cython: initializedcheck=False 7 | # cython: cdivision=True 8 | # cython: auto_pickle=True 9 | 10 | import numpy as np 11 | cimport numpy as np 12 | import pyximport 13 | 14 | pyximport.install(setup_args={'include_dirs': np.get_include()}) 15 | 16 | #from alphazero.cGame import GameState 17 | #from alphazero.cGame cimport GameState 18 | from boardgame import Square 19 | from boardgame.board cimport Square 20 | from fastafl.cengine import Board 21 | from fastafl.cengine cimport Board 22 | from fastafl import variants 23 | 24 | 
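# Flat action encoding used by get_move()/get_action() below (a sketch of the
# scheme, inferred from those two functions): a move is a rook-like slide from
# a start square to a destination in the same row or column. With board width W
# and height H there are (H - 1) vertical plus (W - 1) horizontal destinations
# per start square, so:
#     action = (W + H - 2) * (start_x + start_y * W) + move_type
# Worked example on an 11x11 board (W = H = 11, so W + H - 2 = 20): the move
# (3, 0) -> (2, 0) is horizontal, giving move_type = H + new_x - 1 = 12 (no
# shift, since new_x < start_x), so action = 20 * (3 + 0 * 11) + 12 = 72, and
# get_move(board, 72) recovers ((3, 0), (2, 0)).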
25 | DTYPE = np.float32 26 | ctypedef np.float32_t DTYPE_t 27 | 28 | 29 | cpdef Board _get_board(): 30 | return Board(*GAME_VARIANT) 31 | 32 | 33 | cdef tuple GAME_VARIANT = variants.hnefatafl_args 34 | cdef int NUM_PLAYERS = 2 35 | cdef int NUM_STACKED_OBSERVATIONS = 1 36 | cdef int NUM_BASE_CHANNELS = 5 37 | cdef int NUM_CHANNELS = NUM_BASE_CHANNELS * NUM_STACKED_OBSERVATIONS 38 | 39 | cdef Board b = _get_board() 40 | cdef int ACTION_SIZE = b.width * b.height * (b.width + b.height - 2) 41 | cdef tuple OBS_SIZE = (NUM_CHANNELS, b.width, b.height) 42 | 43 | cdef int DRAW_MOVE_COUNT = 512 44 | 45 | 46 | cpdef tuple get_move(Board board, int action): 47 | cdef int size = board.width + board.height - 2 48 | cdef int move_type = action % size 49 | cdef int a = action // size 50 | cdef int start_x = a % board.width 51 | cdef int start_y = a // board.width 52 | cdef int new_x, new_y 53 | 54 | if move_type < board.height - 1: 55 | new_x = start_x 56 | new_y = move_type 57 | if move_type >= start_y: new_y += 1 58 | else: 59 | new_x = move_type - board.height + 1 60 | if new_x >= start_x: new_x += 1 61 | new_y = start_y 62 | 63 | return Square(int(start_x), int(start_y)), Square(int(new_x), int(new_y)) 64 | 65 | 66 | cpdef int get_action(Board board, tuple move): 67 | cdef int x = move[0].x 68 | cdef int y = move[0].y 69 | cdef int new_x = move[1].x 70 | cdef int new_y = move[1].y 71 | cdef int move_type 72 | 73 | if (x - new_x) == 0: 74 | move_type = new_y if new_y < y else new_y - 1 75 | else: 76 | move_type = board.height + new_x - 1 77 | if new_x >= x: move_type -= 1 78 | 79 | return (board.width + board.height - 2) * (x + y * board.width) + move_type 80 | 81 | 82 | cpdef list _add_obs(Board b, int const_max_player, int const_max_turns): 83 | cdef np.ndarray[np.uint8_t, ndim=2, cast=True] black_mask = b.get_mask((2,)) 84 | cdef np.ndarray[np.uint8_t, ndim=2, cast=True] white_mask = b.get_mask((1,)) 85 | cdef np.ndarray[np.uint8_t, ndim=2, cast=True] king_mask = b.get_mask((3, 7, 8)) 86 | 87 | cdef np.ndarray[np.float32_t, ndim=2] black = np.array(np.where(black_mask, 1., 0.), dtype=np.float32) 88 | cdef np.ndarray[np.float32_t, ndim=2] white = np.array(np.where(white_mask, 1., 0.), dtype=np.float32) 89 | cdef np.ndarray[np.float32_t, ndim=2] king = np.array(np.where(king_mask, 1., 0.), dtype=np.float32) 90 | cdef np.ndarray[np.float32_t, ndim=2] turn_colour = np.full_like( 91 | b._state, 2 - b.to_play() / (const_max_player - 1) if const_max_player > 1 else 0, dtype=np.float32 92 | ) 93 | cdef np.ndarray[np.float32_t, ndim=2] turn_number = np.full_like( 94 | b._state, b.num_turns / const_max_turns if const_max_turns else 0, dtype=np.float32 95 | ) 96 | 97 | return [black, white, king, turn_colour, turn_number] 98 | 99 | 100 | cpdef list _add_empty(Board board): 101 | return [[[0] * board.width] * board.height] * NUM_BASE_CHANNELS 102 | 103 | 104 | cpdef np.ndarray _get_observation(Board board, int const_max_players, int const_max_turns, int past_obs=1, list past_states=[]): 105 | cdef list past, obs = [] 106 | cdef Py_ssize_t i 107 | 108 | if past_states: 109 | past = past_states.copy() 110 | past.insert(0, board) 111 | for i in range(past_obs): 112 | if board.num_turns < i: 113 | obs.extend(_add_empty(board)) 114 | else: 115 | obs.extend(_add_obs(past[i], const_max_players, const_max_turns)) 116 | else: 117 | obs.extend(_add_obs(board, const_max_players, const_max_turns)) 118 | 119 | return np.array(obs, dtype=np.float32) 120 | 121 | 122 | cdef class Game:#(GameState): 123 | cdef public Board 
_board 124 | cdef public int _player 125 | cdef public int _turns 126 | cdef public int last_action 127 | 128 | def __init__(self, _board=None): 129 | self._board = _board or _get_board() 130 | self._player = 0 131 | self._turns = 0 132 | self.last_action = -1 133 | 134 | def __eq__(self, other: 'Game') -> bool: 135 | return self.__dict__ == other.__dict__ 136 | 137 | def __str__(self): 138 | return str(self._board) + '\n' 139 | 140 | @property 141 | def player(self) -> int: 142 | return self._player 143 | 144 | @property 145 | def turns(self): 146 | return self._turns 147 | 148 | @staticmethod 149 | cdef int _get_player_int(int player): 150 | return (1, -1)[2 - player] 151 | 152 | cpdef Game clone(self): 153 | cdef Game g = Game(self._board.copy()) 154 | g._player = self._player 155 | g._turns = self.turns 156 | g.last_action = self.last_action 157 | return g 158 | 159 | @staticmethod 160 | def num_players(): 161 | return NUM_PLAYERS 162 | 163 | @staticmethod 164 | def action_size(): 165 | return ACTION_SIZE 166 | 167 | @staticmethod 168 | def observation_size(): 169 | return OBS_SIZE 170 | 171 | cpdef int _next_player(self, int player, int turns=1): 172 | return (player + turns) % Game.num_players() 173 | 174 | cpdef void _update_turn(self): 175 | """Should be called at the end of play_action""" 176 | self._player = self._next_player(self._player) 177 | self._turns += 1 178 | 179 | cpdef np.ndarray valid_moves(self): 180 | cdef list valids = [0] * ACTION_SIZE 181 | cdef tuple move 182 | 183 | for move in self._board.legal_moves(pieces=(), piece_type=self._board.to_play()): 184 | valids[get_action(self._board, move)] = 1 185 | 186 | return np.array(valids, dtype=np.uint8) 187 | 188 | cpdef void play_action(self, int action): 189 | self.last_action = action 190 | cdef tuple move = get_move(self._board, action) 191 | self._board.move(move[0], move[1], check_turn=False, _check_valid=False, _check_win=False) 192 | self._update_turn() 193 | 194 | cpdef np.ndarray win_state(self): 195 | cdef np.ndarray[dtype=np.uint8_t, ndim=1] result = np.zeros(NUM_PLAYERS + 1, dtype=np.uint8) 196 | cdef int winner 197 | 198 | # Check if maximum moves have been exceeded 199 | if self.turns >= DRAW_MOVE_COUNT: 200 | result[NUM_PLAYERS] = 1 201 | else: 202 | winner = self._board.get_winner() 203 | if winner != 0: 204 | result[2 - winner] = 1 205 | 206 | return result 207 | 208 | cpdef np.ndarray observation(self): 209 | return _get_observation( 210 | self._board, 211 | NUM_PLAYERS, 212 | DRAW_MOVE_COUNT, 213 | NUM_STACKED_OBSERVATIONS 214 | ) 215 | 216 | cpdef list symmetries(self, np.ndarray pi): 217 | cdef list syms = [None] * 8 218 | cdef int i 219 | cdef bint flip 220 | cdef np.ndarray[np.float32_t, ndim=2] state 221 | cdef np.ndarray[np.float32_t, ndim=1] new_pi 222 | cdef Board new_b 223 | cdef Game new_state 224 | 225 | for i in range(1, 5): 226 | for flip in (False, True): 227 | state = np.rot90(np.array(self._board._state, dtype=np.float32), i) 228 | if flip: 229 | state = np.fliplr(state) 230 | 231 | new_b = self._board.copy() 232 | new_b._state = state 233 | new_pi = np.zeros(ACTION_SIZE, dtype=np.float32) 234 | for action, prob in enumerate(pi): 235 | move = get_move(self._board, action) 236 | x = move[0].x 237 | new_x = move[1].x 238 | y = move[0].y 239 | new_y = move[1].y 240 | 241 | for _ in range(i): 242 | temp_x = x 243 | temp_new_x = new_x 244 | x = self._board.width - 1 - y 245 | new_x = self._board.width - 1 - new_y 246 | y = temp_x 247 | new_y = temp_new_x 248 | if flip: 249 | x = 
self._board.width - 1 - x 250 | new_x = self._board.width - 1 - new_x 251 | 252 | new_action = get_action(new_b, (Square(x, y), Square(new_x, new_y))) 253 | new_pi[new_action] = prob 254 | 255 | new_state = self.clone() 256 | new_state._board = new_b 257 | syms[(i - 1) * 2 + int(flip)] = (new_state, new_pi) 258 | 259 | return syms 260 | 261 | def crude_value(self) -> int: 262 | result = self.win_state() 263 | white_pieces = len(list(filter(lambda p: p.is_white, self._board.pieces))) 264 | black_pieces = len(list(filter(lambda p: p.is_black, self._board.pieces))) 265 | return (1, -1)[self.player] * ( 266 | 1000 * (result[0] - result[1]) 267 | + black_pieces - white_pieces 268 | - result[result.size - 1] * 100 269 | - self.turns 270 | ) 271 | 272 | 273 | cpdef void display(Game state, int action=-1): 274 | if action != -1: print(f'Action: {action}, Move: {get_move(state._board, action)}') 275 | print(state) 276 | -------------------------------------------------------------------------------- /alphazero/envs/hnefatafl/hnefatafl.py: -------------------------------------------------------------------------------- 1 | import pyximport, numpy 2 | pyximport.install(setup_args={'include_dirs': numpy.get_include()}) 3 | from .fastafl import * 4 | -------------------------------------------------------------------------------- /alphazero/envs/hnefatafl/pit.py: -------------------------------------------------------------------------------- 1 | from pyximport import install 2 | from numpy import get_include 3 | install(setup_args={'include_dirs': get_include()}) 4 | 5 | from alphazero.Arena import Arena 6 | from alphazero.GenericPlayers import * 7 | from alphazero.NNetWrapper import NNetWrapper as NNet 8 | 9 | 10 | """ 11 | use this script to play any two agents against each other, or to play manually 12 | against any agent. 
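Set batched_arena = True below to play args.arenaCompare batched games and
print win rates instead of a single verbose game.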
13 | """ 14 | if __name__ == '__main__': 15 | from alphazero.envs.hnefatafl.fastafl import Game, display 16 | from alphazero.envs.hnefatafl.train_fastafl import args 17 | #from alphazero.envs.tafl.train_brandubh import args as args2 18 | from alphazero.envs.hnefatafl.players import HumanFastaflPlayer 19 | import random 20 | 21 | batched_arena = False 22 | args.numMCTSSims = 2000 23 | #args.arena_batch_size = 64 24 | args.temp_scaling_fn = lambda x, y, z: 0.25 25 | #args2.temp_scaling_fn = args.temp_scaling_fn 26 | #args.cuda = False 27 | args.add_root_noise = args.add_root_temp = False 28 | 29 | # nnet players 30 | nn1 = NNet(Game, args) 31 | nn1.load_checkpoint('./checkpoint/' + args.run_name, 'iteration-0036.pkl') 32 | #nn2 = NNet(Game, args) 33 | #nn2.load_checkpoint('./checkpoint/brandubh2', 'iteration-0112.pkl') 34 | #player1 = nn1.process 35 | #player2 = nn2.process 36 | 37 | player1 = MCTSPlayer(Game, nn1, args=args, verbose=True) 38 | #player2 = MCTSPlayer(Game, nn1, args=args, verbose=True) 39 | #player2 = RandomPlayer() 40 | #player2 = GreedyTaflPlayer() 41 | player2 = RandomPlayer() 42 | #player2 = OneStepLookaheadConnect4Player() 43 | #player2 = RawMCTSPlayer(Game, args) 44 | #player2 = HumanFastaflPlayer() 45 | 46 | players = [player2, player1] 47 | #random.shuffle(players) 48 | 49 | arena = Arena(players, Game, use_batched_mcts=batched_arena, args=args, display=display) 50 | if batched_arena: 51 | wins, draws, winrates = arena.play_games(args.arenaCompare) 52 | for i in range(len(wins)): 53 | print(f'player{i+1}:\n\twins: {wins[i]}\n\twin rate: {winrates[i]}') 54 | print('draws: ', draws) 55 | else: 56 | arena.play_game(verbose=True) 57 | -------------------------------------------------------------------------------- /alphazero/envs/hnefatafl/players.py: -------------------------------------------------------------------------------- 1 | import pyximport, numpy 2 | pyximport.install(setup_args={'include_dirs': numpy.get_include()}) 3 | 4 | #from hnefatafl.engine import Move, BoardGameException 5 | #from alphazero.envs.hnefatafl.tafl_old import get_action 6 | from alphazero.envs.hnefatafl.fastafl import get_action 7 | from alphazero.GenericPlayers import BasePlayer 8 | from alphazero.Game import GameState 9 | 10 | from boardgame.board import Square 11 | from boardgame.errors import InvalidMoveError 12 | 13 | 14 | """ 15 | class HumanTaflPlayer(BasePlayer): 16 | def play(self, state: GameState): 17 | valid_moves = state.valid_moves() 18 | 19 | def string_to_action(player_inp: str) -> int: 20 | try: 21 | move_lst = [int(x) for x in player_inp.split()] 22 | move = Move(state._board, move_lst) 23 | return get_action(state._board, move) 24 | except (ValueError, AttributeError, BoardGameException): 25 | return -1 26 | 27 | action = string_to_action(input(f"Enter the move to play for the player {state.player}: ")) 28 | while action == -1 or not valid_moves[action]: 29 | action = string_to_action(input(f"Illegal move (action={action}, " 30 | f"in valids: {bool(valid_moves[action])}). 
Enter a valid move: ")) 31 | 32 | return action 33 | """ 34 | 35 | 36 | class HumanFastaflPlayer(BasePlayer): 37 | @staticmethod 38 | def is_human() -> bool: 39 | return True 40 | 41 | def play(self, state: GameState): 42 | valid_moves = state.valid_moves() 43 | 44 | def string_to_action(player_inp: str) -> int: 45 | try: 46 | move_lst = [int(x) for x in player_inp.split()] 47 | return get_action(state._board, (Square(*move_lst[:2]), Square(*move_lst[2:]))) 48 | except (ValueError, AttributeError, InvalidMoveError): 49 | return -1 50 | 51 | action = string_to_action(input(f"Enter the move to play for the player {state.player}: ")) 52 | while action == -1 or not valid_moves[action]: 53 | action = string_to_action(input(f"Illegal move (action={action}, " 54 | f"in valids: {bool(valid_moves[action])}). Enter a valid move: ")) 55 | 56 | return action 57 | 58 | 59 | class GreedyTaflPlayer(BasePlayer): 60 | def play(self, state: GameState): 61 | valids = state.valid_moves() 62 | candidates = [] 63 | 64 | for a in range(state.action_size()): 65 | if not valids[a]: continue 66 | new_state = state.clone() 67 | new_state.play_action(a) 68 | candidates.append((-new_state.crude_value(), a)) 69 | 70 | candidates.sort() 71 | return candidates[0][1] 72 | -------------------------------------------------------------------------------- /alphazero/envs/hnefatafl/train.py: -------------------------------------------------------------------------------- 1 | import pyximport; pyximport.install() 2 | 3 | from alphazero.Coach import Coach, get_args 4 | from alphazero.NNetWrapper import NNetWrapper as nn 5 | from alphazero.envs.hnefatafl.tafl_old import Game, NUM_STACKED_OBSERVATIONS, DRAW_MOVE_COUNT 6 | from alphazero.GenericPlayers import RawMCTSPlayer 7 | from alphazero.utils import dotdict 8 | 9 | args = get_args( 10 | run_name='hnefatafl', 11 | max_moves=DRAW_MOVE_COUNT, 12 | num_stacked_observations=NUM_STACKED_OBSERVATIONS, 13 | cpuct=1.25, 14 | symmetricSamples=False, 15 | numMCTSSims=100, 16 | numFastSims=15, 17 | numWarmupSims=5, 18 | probFastSim=0.75, 19 | 20 | selfPlayModelIter=None, 21 | skipSelfPlayIters=None, 22 | #train_on_past_data=True, 23 | #past_data_run_name='brandubh', 24 | model_gating=True, 25 | max_gating_iters=None, 26 | numWarmupIters=2, 27 | arenaMCTS=True, 28 | baselineCompareFreq=3, 29 | pastCompareFreq=3, 30 | train_steps_per_iteration=40, 31 | min_next_model_winrate=0.52, 32 | use_draws_for_winrate=True, 33 | 34 | process_batch_size=64, 35 | train_batch_size=4096, 36 | arena_batch_size=32, 37 | arenaCompare=32*4, 38 | arenaCompareBaseline=32*4, 39 | gamesPerIteration=64*4, 40 | 41 | lr=1e-2, 42 | optimizer_args=dotdict({ 43 | 'momentum': 0.9, 44 | 'weight_decay': 1e-3 45 | }), 46 | 47 | depth=8, 48 | num_channels=64, 49 | value_head_channels=16, 50 | policy_head_channels=16, 51 | value_dense_layers=[2048, 256], 52 | policy_dense_layers=[2048] 53 | ) 54 | args.scheduler_args.milestones = [75, 150] 55 | 56 | 57 | if __name__ == "__main__": 58 | nnet = nn(Game, args) 59 | c = Coach(Game, nnet, args) 60 | c.learn() 61 | -------------------------------------------------------------------------------- /alphazero/envs/hnefatafl/train_brandubh.py: -------------------------------------------------------------------------------- 1 | import pyximport, numpy 2 | 3 | pyximport.install(setup_args={'include_dirs': numpy.get_include()}) 4 | 5 | from alphazero.Coach import Coach, get_args 6 | from alphazero.NNetWrapper import NNetWrapper as nn 7 | from alphazero.envs.hnefatafl.fastafl import Game 
as Game 8 | from alphazero.GenericPlayers import RawMCTSPlayer 9 | from alphazero.utils import dotdict 10 | 11 | args = get_args( 12 | run_name='brandubh_fastafl', 13 | #workers=1, 14 | max_moves=100, 15 | num_stacked_observations=1, 16 | cpuct=1.25, 17 | symmetricSamples=True, 18 | numMCTSSims=250, 19 | numFastSims=50, 20 | numWarmupSims=5, 21 | probFastSim=0.8, 22 | 23 | selfPlayModelIter=None, 24 | skipSelfPlayIters=None, 25 | # train_on_past_data=True, 26 | # past_data_run_name='brandubh', 27 | model_gating=True, 28 | max_gating_iters=None, 29 | numWarmupIters=1, 30 | arenaMCTS=True, 31 | baselineCompareFreq=1, 32 | pastCompareFreq=1, 33 | train_steps_per_iteration=16, 34 | min_next_model_winrate=0.52, 35 | use_draws_for_winrate=True, 36 | 37 | process_batch_size=512, 38 | train_batch_size=4096, 39 | arena_batch_size=64, 40 | arenaCompare=64 * 4, 41 | arenaCompareBaseline=64 * 4, 42 | gamesPerIteration=512 * 4, 43 | 44 | lr=1e-2, 45 | optimizer_args=dotdict({ 46 | 'momentum': 0.9, 47 | 'weight_decay': 1e-3 48 | }), 49 | 50 | depth=4, 51 | num_channels=64, 52 | value_head_channels=16, 53 | policy_head_channels=16, 54 | value_dense_layers=[1024, 128], 55 | policy_dense_layers=[1024] 56 | ) 57 | args.scheduler_args.milestones = [75, 150] 58 | 59 | if __name__ == "__main__": 60 | nnet = nn(Game, args) 61 | c = Coach(Game, nnet, args) 62 | c.learn() 63 | -------------------------------------------------------------------------------- /alphazero/envs/hnefatafl/train_fastafl.py: -------------------------------------------------------------------------------- 1 | import pyximport, numpy 2 | 3 | pyximport.install(setup_args={'include_dirs': numpy.get_include()}) 4 | 5 | from alphazero.Coach import Coach, get_args 6 | from alphazero.NNetWrapper import NNetWrapper as nn 7 | from alphazero.envs.hnefatafl.fastafl import Game as Game 8 | from alphazero.GenericPlayers import RawMCTSPlayer 9 | from alphazero.utils import dotdict 10 | 11 | args = get_args( 12 | run_name='hnefatafl_fastafl', 13 | #workers=1, 14 | max_moves=512, 15 | num_stacked_observations=1, 16 | cpuct=1.25, 17 | symmetricSamples=True, 18 | numMCTSSims=250, 19 | numFastSims=50, 20 | numWarmupSims=5, 21 | probFastSim=0.8, 22 | 23 | selfPlayModelIter=None, 24 | skipSelfPlayIters=None, 25 | #train_on_past_data=True, 26 | #past_data_run_name='hnefatafl_fastafl', 27 | model_gating=True, 28 | max_gating_iters=None, 29 | numWarmupIters=1, 30 | arenaMCTS=True, 31 | baselineCompareFreq=1, 32 | pastCompareFreq=1, 33 | train_steps_per_iteration=80, 34 | min_next_model_winrate=0.52, 35 | use_draws_for_winrate=True, 36 | 37 | process_batch_size=128, 38 | train_batch_size=2048, 39 | arena_batch_size=32, 40 | arenaCompare=32 * 4, 41 | arenaCompareBaseline=32 * 4, 42 | gamesPerIteration=128 * 4, 43 | 44 | lr=1e-2, 45 | optimizer_args=dotdict({ 46 | 'momentum': 0.9, 47 | 'weight_decay': 1e-3 48 | }), 49 | 50 | depth=10, 51 | num_channels=128, 52 | value_head_channels=32, 53 | policy_head_channels=32, 54 | value_dense_layers=[4096, 128], 55 | policy_dense_layers=[4096] 56 | ) 57 | args.scheduler_args.milestones = [75, 150] 58 | 59 | if __name__ == "__main__": 60 | nnet = nn(Game, args) 61 | c = Coach(Game, nnet, args) 62 | c.learn() 63 | -------------------------------------------------------------------------------- /alphazero/envs/hnefatafl/train_test.py: -------------------------------------------------------------------------------- 1 | import pyximport; 2 | 3 | pyximport.install() 4 | 5 | from alphazero.Coach import Coach, get_args 6 | from 
alphazero.NNetWrapper import NNetWrapper as nn 7 | from alphazero.envs.hnefatafl.tafl_old import Game, NUM_STACKED_OBSERVATIONS, DRAW_MOVE_COUNT 8 | from alphazero.GenericPlayers import RawMCTSPlayer 9 | from alphazero.utils import dotdict 10 | 11 | args = get_args( 12 | run_name='hnefatafl_test', 13 | max_moves=DRAW_MOVE_COUNT, 14 | num_stacked_observations=NUM_STACKED_OBSERVATIONS, 15 | cpuct=1.25, 16 | symmetricSamples=False, 17 | numMCTSSims=5, 18 | numFastSims=5, 19 | numWarmupSims=5, 20 | probFastSim=0.75, 21 | 22 | selfPlayModelIter=None, 23 | skipSelfPlayIters=None, 24 | model_gating=True, 25 | max_gating_iters=None, 26 | numWarmupIters=1, 27 | arenaMCTS=False, 28 | baselineCompareFreq=1, 29 | pastCompareFreq=1, 30 | train_sample_ratio=3, 31 | min_next_model_winrate=0.52, 32 | use_draws_for_winrate=False, 33 | 34 | process_batch_size=16, 35 | train_batch_size=1024, 36 | arena_batch_size=32, 37 | arenaCompare=32 * 4, 38 | arenaCompareBaseline=32 * 4, 39 | gamesPerIteration=1, 40 | 41 | lr=1e-2, 42 | optimizer_args=dotdict({ 43 | 'momentum': 0.9, 44 | 'weight_decay': 1e-4 45 | }), 46 | 47 | depth=5, 48 | num_channels=32, 49 | value_head_channels=1, 50 | policy_head_channels=2, 51 | value_dense_layers=[32], 52 | policy_dense_layers=[32] 53 | ) 54 | args.scheduler_args.milestones = [75, 150] 55 | 56 | if __name__ == "__main__": 57 | nnet = nn(Game, args) 58 | c = Coach(Game, nnet, args) 59 | c.learn() 60 | -------------------------------------------------------------------------------- /alphazero/envs/othello/OthelloLogic.pyx: -------------------------------------------------------------------------------- 1 | # cython: language_level=3 2 | # cython: boundscheck=False 3 | # cython: wraparound=False 4 | # cython: nonecheck=False 5 | # cython: overflowcheck=False 6 | # cython: initializedcheck=False 7 | # cython: cdivision=True 8 | """ 9 | Author: Eric P. Nichols 10 | Date: Feb 8, 2008. 11 | Board class. 12 | Board data: 13 | 1=white, -1=black, 0=empty 14 | first dim is the column, 2nd is the row: 15 | pieces[1][7] is the square in column 2, 16 | at the opposite end of the board in row 8. 17 | Squares are stored and manipulated as (x,y) tuples. 18 | x is the column, y is the row. 19 | """ 20 | 21 | import numpy as np 22 | 23 | # list of all 8 directions on the board, as (x,y) offsets 24 | cdef list __directions = [(1,1),(1,0),(1,-1),(0,-1),(-1,-1),(-1,0),(-1,1),(0,1)] 25 | cdef (Py_ssize_t, Py_ssize_t) null_point = (-1, -1) 26 | 27 | cdef class Board: 28 | cdef public int n 29 | cdef public int[:,:] pieces 30 | 31 | def __init__(self, int n, _pieces=None): 32 | """Set up initial board configuration.""" 33 | 34 | self.n = n 35 | 36 | if _pieces is not None: 37 | self.pieces = _pieces 38 | else: 39 | # Create the empty board array. 40 | self.pieces = np.zeros((self.n, self.n), dtype=np.intc) 41 | 42 | # Set up the initial 4 pieces. 
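# For the default n=8 board this fills the four centre squares: white (1) at
# (3,4) and (4,3), black (-1) at (3,3) and (4,4), where the first index is the
# column and the second the row (see the module docstring above).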
43 | self.pieces[self.n//2-1,self.n//2] = 1 44 | self.pieces[self.n//2,self.n//2-1] = 1 45 | self.pieces[self.n//2-1,self.n//2-1] = -1 46 | self.pieces[self.n//2,self.n//2] = -1 47 | 48 | def __getstate__(self): 49 | return self.n, np.asarray(self.pieces) 50 | 51 | def __setstate__(self, state): 52 | self.n, pieces = state 53 | self.pieces = np.asarray(pieces) 54 | 55 | # add [][] indexer syntax to the Board 56 | def __getitem__(self, Py_ssize_t index): 57 | return self.pieces[index] 58 | 59 | def count_diff(self, int color): 60 | """Counts the # pieces of the given color 61 | (1 for white, -1 for black, 0 for empty spaces)""" 62 | cdef int count = 0 63 | cdef Py_ssize_t x, y 64 | 65 | for y in range(self.n): 66 | for x in range(self.n): 67 | count += self.pieces[x,y] * color 68 | 69 | return count 70 | 71 | cpdef set get_legal_moves(self, int color): 72 | """Returns all the legal moves for the given color.""" 73 | cdef set moves = set() # stores the legal moves. 74 | cdef Py_ssize_t x, y 75 | 76 | # Get all the squares of the given color. 77 | for y in range(self.n): 78 | for x in range(self.n): 79 | if self.pieces[x,y] == color: 80 | moves.update(self.get_moves_for_square((x,y))) 81 | return moves 82 | 83 | cpdef bint has_legal_moves(self, int color): 84 | cdef Py_ssize_t x, y 85 | for y in range(self.n): 86 | for x in range(self.n): 87 | if self.pieces[x,y] == color: 88 | if len(self.get_moves_for_square((x,y)))>0: 89 | return True 90 | return False 91 | 92 | cpdef list get_moves_for_square(self, (Py_ssize_t, Py_ssize_t) square): 93 | """Returns all the legal moves that use the given square as a base. 94 | That is, if the given square is (3,4) and it contains a black piece, 95 | and (3,5) and (3,6) contain white pieces, and (3,7) is empty, one 96 | of the returned moves is (3,7) because everything from there to (3,4) 97 | is flipped. 98 | """ 99 | cdef Py_ssize_t x, y 100 | (x,y) = square 101 | 102 | # determine the color of the piece. 103 | cdef int color = self.pieces[x,y] 104 | 105 | # skip empty source squares. 106 | if color==0: 107 | return [] 108 | 109 | # search all possible directions. 110 | cdef list moves = [] 111 | cdef (Py_ssize_t, Py_ssize_t) move 112 | for direction in __directions: 113 | move = self._discover_move(square, direction) 114 | if move != null_point: 115 | # print(square,move,direction) 116 | moves.append(move) 117 | 118 | # return the generated move list 119 | return moves 120 | 121 | def execute_move(self, (Py_ssize_t, Py_ssize_t) move, int color): 122 | """Perform the given move on the board; flips pieces as necessary. 123 | color gives the color of the piece to play (1=white,-1=black) 124 | """ 125 | 126 | #Much like move generation, start at the new piece's square and 127 | #follow it on all 8 directions to look for a piece allowing flipping. 128 | 129 | # Add the piece to the empty square. 
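# _get_flips() returns the origin square plus every opposing piece bracketed in
# one direction, or an empty list if that direction is not closed by a friendly
# piece; flattened over all 8 directions, this yields the played square and
# every captured piece, all of which are recoloured below.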
130 | # print(move) 131 | cdef list flips = [flip for direction in __directions 132 | for flip in self._get_flips(move, direction, color)] 133 | #assert len(list(flips))>0 134 | cdef Py_ssize_t x, y 135 | 136 | for x, y in flips: 137 | #print(self.pieces_raw[x,y],color) 138 | self.pieces[x,y] = color 139 | 140 | cdef (Py_ssize_t, Py_ssize_t) _discover_move(self, (Py_ssize_t, Py_ssize_t) origin, (Py_ssize_t, Py_ssize_t) direction): 141 | """ Returns the endpoint for a legal move, starting at the given origin, 142 | moving by the given increment.""" 143 | cdef Py_ssize_t x, y 144 | x, y = origin 145 | cdef int color = self.pieces[x,y] 146 | cdef list flips = [] 147 | 148 | for x, y in Board._increment_move(origin, direction, self.n): 149 | if self.pieces[x,y] == 0: 150 | if flips: 151 | # print("Found", x,y) 152 | return x, y 153 | else: 154 | return null_point 155 | 156 | elif self.pieces[x,y] == color: 157 | return null_point 158 | 159 | elif self.pieces[x,y] == -1*color: 160 | # print("Flip",x,y) 161 | flips.append((x, y)) 162 | 163 | return null_point 164 | 165 | cdef list _get_flips(self, (Py_ssize_t, Py_ssize_t) origin, (Py_ssize_t, Py_ssize_t) direction, int color): 166 | """ Gets the list of flips for a vertex and direction to use with the 167 | execute_move function """ 168 | #initialize variables 169 | cdef list flips = [origin] 170 | cdef Py_ssize_t x, y 171 | 172 | for x, y in Board._increment_move(origin, direction, self.n): 173 | #print(x,y) 174 | if self.pieces[x,y] == 0: 175 | return [] 176 | if self.pieces[x,y] == -color: 177 | flips.append((x, y)) 178 | elif self.pieces[x,y] == color and len(flips) > 0: 179 | #print(flips) 180 | return flips 181 | 182 | return [] 183 | 184 | @staticmethod 185 | cdef list _increment_move((Py_ssize_t, Py_ssize_t) move, (Py_ssize_t, Py_ssize_t) direction, int n): 186 | """ Generator expression for incrementing moves """ 187 | #move = list(map(sum, zip(move, direction))) 188 | cdef list moves = [] 189 | 190 | move = (move[0]+direction[0], move[1]+direction[1]) 191 | #while all(map(lambda x: 0 <= x < n, move)): 192 | while 0 <= move[0] < n and 0 <= move[1] < n: 193 | moves.append(move) 194 | #move=list(map(sum,zip(move,direction))) 195 | move = (move[0]+direction[0],move[1]+direction[1]) 196 | 197 | return moves 198 | 199 | -------------------------------------------------------------------------------- /alphazero/envs/othello/OthelloPlayers.py: -------------------------------------------------------------------------------- 1 | from alphazero.GenericPlayers import BasePlayer 2 | from alphazero.Game import GameState 3 | 4 | 5 | class HumanOthelloPlayer(BasePlayer): 6 | def play(self, state: GameState) -> int: 7 | valid = state.valid_moves() 8 | """ 9 | for i in range(len(valid)): 10 | if valid[i]: 11 | print(int(i / state._board.n), int(i % state._board.n)) 12 | """ 13 | 14 | while True: 15 | a = input('Enter a move: ') 16 | 17 | x, y = [int(x) for x in a.split(' ')] 18 | a = state._board.n * x + y if x != -1 else state._board.n ** 2 19 | if valid[a]: 20 | break 21 | else: 22 | print('Invalid move entered.') 23 | 24 | return a 25 | 26 | 27 | class GreedyOthelloPlayer(BasePlayer): 28 | def play(self, state: GameState) -> int: 29 | valids = state.valid_moves() 30 | candidates = [] 31 | 32 | for a in range(state.action_size()): 33 | if not valids[a]: continue 34 | 35 | next_state = state.clone() 36 | next_state.play_action(a) 37 | candidates += [(-next_state._board.count_diff(next_state.player), a)] 38 | 39 | candidates.sort() 40 | return 
candidates[0][1] 41 | -------------------------------------------------------------------------------- /alphazero/envs/othello/__init__.py: -------------------------------------------------------------------------------- 1 | from .othello import * 2 | -------------------------------------------------------------------------------- /alphazero/envs/othello/othello.pyx: -------------------------------------------------------------------------------- 1 | # cython: language_level=3 2 | from typing import List, Tuple, Any 3 | 4 | from alphazero.envs.othello.OthelloLogic import Board 5 | from alphazero.Game import GameState 6 | 7 | import numpy as np 8 | 9 | NUM_PLAYERS = 2 10 | NUM_CHANNELS = 1 11 | BOARD_SIZE = 8 12 | MAX_TURNS = BOARD_SIZE * BOARD_SIZE 13 | ACTION_SIZE = BOARD_SIZE ** 2 14 | OBSERVATION_SIZE = (NUM_CHANNELS, BOARD_SIZE, BOARD_SIZE) 15 | 16 | 17 | class Game(GameState): 18 | def __init__(self, _board=None): 19 | super().__init__(_board or self._get_board()) 20 | 21 | def __hash__(self) -> int: 22 | return hash(self._board.pieces.tobytes() + bytes([self.turns]) + bytes([self._player])) 23 | 24 | def __eq__(self, other: 'Game') -> bool: 25 | return ( 26 | np.asarray(self._board.pieces) == np.asarray(other._board.pieces) 27 | and self._player == other._player 28 | and self.turns == other.turns 29 | ) 30 | 31 | def display(self): 32 | display(self._board.pieces) 33 | 34 | @staticmethod 35 | def _get_board(*args, **kwargs): 36 | return Board(BOARD_SIZE, *args, **kwargs) 37 | 38 | def clone(self) -> 'Game': 39 | board = self._get_board(_pieces=np.copy(np.asarray(self._board.pieces))) 40 | game = Game(_board=board) 41 | game._player = self._player 42 | game._turns = self.turns 43 | return game 44 | 45 | @staticmethod 46 | def action_size() -> int: 47 | return ACTION_SIZE 48 | 49 | @staticmethod 50 | def observation_size() -> Tuple[int, int, int]: 51 | return OBSERVATION_SIZE 52 | 53 | @staticmethod 54 | def num_players() -> int: 55 | return NUM_PLAYERS 56 | 57 | @staticmethod 58 | def max_turns() -> int: 59 | return MAX_TURNS 60 | 61 | @staticmethod 62 | def has_draw() -> bool: 63 | return True 64 | 65 | def _player_range(self): 66 | return (1, -1)[self.player] 67 | 68 | def valid_moves(self): 69 | # return a fixed size binary vector 70 | valids = [0] * self.action_size() 71 | 72 | for x, y in self._board.get_legal_moves(self._player_range()): 73 | valids[self._board.n * x + y] = 1 74 | 75 | return np.array(valids, dtype=np.intc) 76 | 77 | def play_action(self, action: int) -> None: 78 | super().play_action(action) 79 | move = (action // self._board.n, action % self._board.n) 80 | self._board.execute_move(move, self._player_range()) 81 | self._update_turn() 82 | 83 | def win_state(self) -> np.ndarray: 84 | result = [False] * (NUM_PLAYERS + 1) 85 | player = self._player_range() 86 | 87 | if not self._board.has_legal_moves(player): 88 | diff = self._board.count_diff(player) 89 | if diff > 0: 90 | result[self.player] = True 91 | elif diff < 0: 92 | result[self._next_player(self.player)] = True 93 | else: 94 | result[NUM_PLAYERS] = True 95 | 96 | return np.array(result, dtype=np.uint8) 97 | 98 | def observation(self): 99 | return np.expand_dims(np.asarray(self._board.pieces), axis=0) 100 | 101 | def symmetries(self, pi) -> List[Tuple[Any, int]]: 102 | # mirror, rotational 103 | assert len(pi) == self._board.n ** 2 104 | 105 | pi_board = np.reshape(pi, (self._board.n, self._board.n)) 106 | result = [] 107 | 108 | for i in range(1, 5): 109 | for j in [True, False]: 110 | new_b = 
np.rot90(np.asarray(self._board.pieces), i) 111 | new_pi = np.rot90(pi_board, i) 112 | if j: 113 | new_b = np.fliplr(new_b) 114 | new_pi = np.fliplr(new_pi) 115 | 116 | gs = self.clone() 117 | gs._board.pieces = new_b 118 | result.append((gs, new_pi.ravel())) 119 | 120 | return result 121 | 122 | 123 | def display(board: np.ndarray): 124 | n = board.shape[0] 125 | 126 | for y in range(n): 127 | print(y, "|", end="") 128 | print("") 129 | print(" -----------------------") 130 | for y in range(n): 131 | print(y, "|", end="") # print the row # 132 | for x in range(n): 133 | piece = board[y][x] # get the piece to print 134 | if piece == -1: 135 | print("b ", end="") 136 | elif piece == 1: 137 | print("W ", end="") 138 | else: 139 | if x == n: 140 | print("-", end="") 141 | else: 142 | print("- ", end="") 143 | print("|") 144 | 145 | print(" -----------------------") 146 | -------------------------------------------------------------------------------- /alphazero/envs/othello/train.py: -------------------------------------------------------------------------------- 1 | import pyximport; pyximport.install() 2 | 3 | from alphazero.Coach import Coach, get_args 4 | from alphazero.NNetWrapper import NNetWrapper as nn 5 | from alphazero.envs.othello.othello import Game 6 | #from alphazero.envs.othello.OthelloPlayers import GreedyOthelloPlayer 7 | 8 | 9 | args = get_args( 10 | run_name='othello', 11 | workers=2, 12 | cpuct=4, 13 | numWarmupIters=1, 14 | baselineCompareFreq=1, 15 | pastCompareFreq=1, 16 | numMCTSSims=100, 17 | numFastSims=25, 18 | probFastSim=0.75, 19 | numWarmupSims=1, 20 | #baselineTester=GreedyOthelloPlayer, 21 | process_batch_size=128, 22 | train_batch_size=1024, 23 | gamesPerIteration=128*4, 24 | lr=0.01, 25 | num_channels=64, 26 | depth=4, 27 | value_head_channels=16, 28 | policy_head_channels=16, 29 | value_dense_layers=[512, 256], 30 | policy_dense_layers=[512] 31 | ) 32 | 33 | 34 | if __name__ == "__main__": 35 | nnet = nn(Game, args) 36 | c = Coach(Game, nnet, args) 37 | c.learn() 38 | -------------------------------------------------------------------------------- /alphazero/envs/stratego/__init__.py: -------------------------------------------------------------------------------- 1 | import pyximport, numpy 2 | pyximport.install(setup_args={'include_dirs': numpy.get_include()}) 3 | 4 | from .engine import * 5 | -------------------------------------------------------------------------------- /alphazero/envs/stratego/engine.pxd: -------------------------------------------------------------------------------- 1 | # cython: language_level=3 2 | # cython: boundscheck=False 3 | # cython: wraparound=False 4 | # cython: nonecheck=False 5 | # cython: overflowcheck=False 6 | # cython: initializedcheck=False 7 | # cython: cdivision=True 8 | # cython: auto_pickle=True 9 | 10 | from boardgame.board cimport BaseBoard, Square 11 | 12 | 13 | cdef class Board(BaseBoard): 14 | cdef public list red_exploded_bombs 15 | cdef public list blue_exploded_bombs 16 | cdef public bint _red_flag_captured 17 | cdef public bint _blue_flag_captured 18 | cdef public list _red_pieces_to_place 19 | cdef public list _blue_pieces_to_place 20 | 21 | cpdef void clear_pieces_to_place(self) 22 | cpdef int _base_piece(self, int piece_type) 23 | 24 | cpdef bint _is_valid(self, Square dest_square, int piece_type) 25 | cpdef list legal_moves(self, tuple pieces=*, int piece_type=*) 26 | cpdef bint _has_legals_check(self, Square piece_square) 27 | cpdef bint has_legal_moves(self, tuple pieces=*, int piece_type=*) 28 | 
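# NOTE: the `=*` markers above and below are Cython's way of declaring that
# these arguments carry default values in the corresponding .pyx
# implementation, so e.g. legal_moves() and has_legal_moves() are callable
# with no arguments from Python once the extension is compiled.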
29 | cpdef int get_winner(self) 30 | cpdef void move(self, source, dest, bint check_turn=*, bint _check_valid=*, bint _check_win=*) 31 | cpdef int to_play(self) 32 | 33 | cpdef int _get_raw_value(self, int piece_type) 34 | cpdef int _get_visible_value(self, int piece_value) 35 | -------------------------------------------------------------------------------- /alphazero/envs/stratego/pit.py: -------------------------------------------------------------------------------- 1 | from pyximport import install 2 | from numpy import get_include 3 | install(setup_args={'include_dirs': get_include()}) 4 | 5 | from alphazero.Arena import Arena 6 | from alphazero.GenericPlayers import * 7 | from alphazero.NNetWrapper import NNetWrapper as NNet 8 | from alphazero.envs.stratego.stratego import Game as Game, display 9 | from alphazero.envs.stratego.train import args 10 | from alphazero.envs.stratego.players import HumanStrategoPlayer 11 | import random 12 | 13 | batched_arena = False 14 | args.arenaCompare = 1 15 | #args.arena_batch_size = 8 16 | #args.workers = 4 17 | #args.arenaCompare = args.workers * args.arena_batch_size 18 | #args.numMCTSSims = 2000 19 | #args.arena_batch_size = 64 20 | #args.temp_scaling_fn = lambda x,y,z:0.25 21 | #args2.temp_scaling_fn = args.temp_scaling_fn 22 | #args.cuda = False 23 | #args.add_root_noise = args.add_root_temp = False 24 | 25 | # nnet players 26 | #nn1 = NNet(Game, args) 27 | #nn1.load_checkpoint('./checkpoint/' + args.run_name, 'iteration-0001.pkl') 28 | #nn2 = NNet(Game, args) 29 | #nn2.load_checkpoint('./checkpoint/brandubh2', 'iteration-0112.pkl') 30 | #player1 = nn1.process 31 | #player2 = nn1.process 32 | 33 | #player1 = MCTSPlayer(Game, args, nn1, verbose=False) 34 | #player2 = MCTSPlayer(Game, args, nn1, verbose=False) 35 | #player2 = RandomPlayer() 36 | #player2 = GreedyTaflPlayer() 37 | #player2 = RandomPlayer() 38 | #player2 = OneStepLookaheadConnect4Player() 39 | player1 = RawMCTSPlayer(Game, args) 40 | player2 = RawMCTSPlayer(Game, args) 41 | #player2 = HumanFastaflPlayer() 42 | 43 | players = [player1, player2] 44 | #random.shuffle(players) 45 | 46 | arena = Arena(players, Game, use_batched_mcts=batched_arena, args=args, display=display) 47 | if batched_arena: 48 | wins, draws, winrates = arena.play_games(args.arenaCompare) 49 | for i in range(len(wins)): 50 | print(f'player{i+1}:\n\twins: {wins[i]}\n\twin rate: {winrates[i]}') 51 | print('draws: ', draws) 52 | else: 53 | arena.play_game(verbose=True) 54 | -------------------------------------------------------------------------------- /alphazero/envs/stratego/players.py: -------------------------------------------------------------------------------- 1 | from alphazero.GenericPlayers import BasePlayer 2 | 3 | import pyximport, numpy 4 | pyximport.install(setup_args={'include_dirs': numpy.get_include()}) 5 | 6 | from alphazero.envs.stratego.stratego import get_action, Game, Square 7 | from boardgame.errors import InvalidMoveError 8 | 9 | 10 | class HumanStrategoPlayer(BasePlayer): 11 | def play(self, state: Game) -> int: 12 | valid_moves = state.valid_moves() 13 | 14 | def string_to_action(player_inp: str) -> int: 15 | try: 16 | move_lst = [int(x) for x in player_inp.split()] 17 | 18 | if state._board.play_phase: 19 | return get_action(state._board, (Square(*move_lst[:2]), Square(*move_lst[2:]))) 20 | else: 21 | return get_action(state._board, (move_lst[0], Square(*move_lst[1:]))) 22 | except (ValueError, AttributeError, InvalidMoveError): 23 | return -1 24 | 25 | action = 
string_to_action(input(f"Enter the move to play for the player {state.player}: ")) 26 | while action == -1 or not valid_moves[action]: 27 | action = string_to_action(input( 28 | f"Illegal move (action={action}, in valids: {bool(valid_moves[action])}). Enter a valid move: " 29 | )) 30 | 31 | return action 32 | -------------------------------------------------------------------------------- /alphazero/envs/stratego/stratego.pyx: -------------------------------------------------------------------------------- 1 | # cython: language_level=3 2 | # cython: boundscheck=False 3 | # cython: wraparound=False 4 | # cython: nonecheck=False 5 | # cython: overflowcheck=False 6 | # cython: initializedcheck=False 7 | # cython: cdivision=True 8 | # cython: auto_pickle=True 9 | 10 | import numpy as np 11 | cimport numpy as np 12 | import pyximport 13 | 14 | pyximport.install(setup_args={'include_dirs': np.get_include()}) 15 | 16 | from boardgame import Square 17 | from boardgame.board cimport Square 18 | from alphazero.envs.stratego.engine import Board 19 | from alphazero.envs.stratego.engine cimport Board 20 | 21 | DTYPE = np.float32 22 | ctypedef np.float32_t DTYPE_t 23 | 24 | 25 | cdef int other_team_offset = 20 26 | cdef int visible_offset = 100 27 | cdef int RED_TEAM_COLOUR = 1 28 | cdef tuple ALL_RED_PIECES = tuple(range(1, 13)) 29 | cdef tuple ALL_BLUE_PIECES = tuple((p + other_team_offset for p in ALL_RED_PIECES)) 30 | 31 | cdef int NUM_PLAYERS = 2 32 | cdef int NUM_STACKED_OBSERVATIONS = 1 33 | cdef int NUM_BASE_CHANNELS = 30 34 | cdef int NUM_CHANNELS = NUM_BASE_CHANNELS * NUM_STACKED_OBSERVATIONS 35 | cdef int NUM_PIECES = len(ALL_RED_PIECES) 36 | cdef int DRAW_MOVE_COUNT = 512 37 | 38 | cdef Board b = Board() 39 | cdef int ACTION_SIZE = max( 40 | # size is 1048 for width 8, height 10, num_pieces 12 41 | b.width + b.height * b.width + NUM_PIECES * b.width * b.height, 42 | # 1280 for width 8, height 10 43 | b.width * b.height * (b.width + b.height - 2) 44 | ) 45 | cdef tuple OBS_SIZE = (NUM_CHANNELS, b.height, b.width) 46 | 47 | 48 | cpdef int get_action(Board board, tuple move): 49 | cdef int x 50 | cdef int y 51 | cdef int new_x = move[1].x 52 | cdef int new_y = move[1].y 53 | cdef int move_type 54 | 55 | if isinstance(move[0], int): 56 | # phase 1 of game 57 | return new_x + new_y * board.width + (move[0] - other_team_offset * super(Board, board).to_play()) * board.width * board.height 58 | 59 | else: 60 | # phase 2 61 | x = move[0].x 62 | y = move[0].y 63 | 64 | if (x - new_x) == 0: 65 | move_type = new_y if new_y < y else new_y - 1 66 | else: 67 | move_type = board.height + new_x - 1 68 | if new_x >= x: move_type -= 1 69 | 70 | return (board.width + board.height - 2) * (x + y * board.width) + move_type 71 | 72 | cpdef tuple get_move(Board board, int action): 73 | cdef int a, size 74 | cdef int move_type 75 | cdef int start_x, start_y 76 | cdef int new_x, new_y 77 | 78 | if board.play_phase: 79 | size = board.width + board.height - 2 80 | move_type = action % size 81 | a = action // size 82 | start_x = a % board.width 83 | start_y = a // board.width 84 | 85 | if move_type < board.height - 1: 86 | new_x = start_x 87 | new_y = move_type 88 | if move_type >= start_y: new_y += 1 89 | else: 90 | new_x = move_type - board.height + 1 91 | if new_x >= start_x: new_x += 1 92 | new_y = start_y 93 | 94 | return Square(int(start_x), int(start_y)), Square(int(new_x), int(new_y)) 95 | 96 | else: 97 | size = board.width * board.height 98 | a = action % size 99 | return (action // size) + other_team_offset * 
super(Board, board).to_play(), Square(a % board.width, a // board.width) 100 | 101 | 102 | cpdef list _add_obs(Board b, int const_max_player, int const_max_turns): 103 | cdef list obs_planes = [] 104 | cdef np.ndarray[DTYPE_t, ndim=2] red_bombs = np.zeros_like(b._state, dtype=DTYPE) 105 | cdef np.ndarray[DTYPE_t, ndim=2] blue_bombs = red_bombs.copy() 106 | cdef Square bomb_square 107 | cdef int i 108 | 109 | # add binary planes for the positions of pieces on both teams 110 | obs_planes.extend(( 111 | np.array(np.where(np.in1d(b._state % visible_offset, ALL_RED_PIECES), 1., 0.), dtype=DTYPE).reshape((b._state.shape[0], b._state.shape[1])), 112 | np.array(np.where(np.in1d(b._state % visible_offset, ALL_BLUE_PIECES), 1., 0.), dtype=DTYPE).reshape((b._state.shape[0], b._state.shape[1])) 113 | )) 114 | 115 | # add binary planes for the visible pieces of both teams 116 | for i in range(1, NUM_PIECES + 1): 117 | obs_planes.extend(( 118 | np.array(np.where(b.get_mask((i + visible_offset,)), 1., 0.), dtype=DTYPE), 119 | np.array(np.where(b.get_mask((i + visible_offset + other_team_offset,)), 1., 0.), dtype=DTYPE) 120 | )) 121 | 122 | # add binary planes for the locations of exploded bombs on both teams 123 | for bomb_square in b.red_exploded_bombs: 124 | red_bombs[bomb_square.y, bomb_square.x] = 1. 125 | 126 | for bomb_square in b.blue_exploded_bombs: 127 | blue_bombs[bomb_square.y, bomb_square.x] = 1. 128 | 129 | obs_planes.extend((red_bombs, blue_bombs)) 130 | 131 | # add binary plane for the current player and relative plane for the current turn number 132 | obs_planes.extend(( 133 | np.full_like(b._state, 134 | super(Board, b).to_play() / (const_max_player - 1) if const_max_player > 1 else 0., dtype=DTYPE 135 | ), 136 | np.full_like(b._state, 137 | b.num_turns / const_max_turns if const_max_turns else 0., dtype=DTYPE 138 | ) 139 | )) 140 | 141 | return obs_planes 142 | 143 | 144 | cpdef list _add_empty(Board board): 145 | return [np.zeros_like(board._state, dtype=DTYPE)] * NUM_BASE_CHANNELS 146 | 147 | 148 | cpdef np.ndarray _get_observation(Board board, int const_max_players, int const_max_turns, int past_obs=1, list past_states=[]): 149 | cdef list past, obs = [] 150 | cdef Py_ssize_t i 151 | 152 | if past_states: 153 | past = past_states.copy() 154 | past.insert(0, board) 155 | for i in range(past_obs): 156 | if board.num_turns < i: 157 | obs.extend(_add_empty(board)) 158 | else: 159 | obs.extend(_add_obs(past[i], const_max_players, const_max_turns)) 160 | else: 161 | obs = _add_obs(board, const_max_players, const_max_turns) 162 | 163 | return np.array(obs, dtype=DTYPE) 164 | 165 | cdef class Game: #(GameState): 166 | cdef public Board _board 167 | 168 | def __init__(self, _board=None): 169 | self._board = _board or Board() 170 | 171 | def __eq__(self, other: 'Game') -> bool: 172 | return self._board == other._board 173 | 174 | def __str__(self): 175 | return str(self._board) + '\n' 176 | 177 | @property 178 | def player(self) -> int: 179 | return super(Board, self._board).to_play() 180 | 181 | @property 182 | def turns(self): 183 | return self._board.num_turns 184 | 185 | cpdef Game clone(self): 186 | return Game(self._board.copy()) 187 | 188 | @staticmethod 189 | def num_players(): 190 | return NUM_PLAYERS 191 | 192 | @staticmethod 193 | def action_size(): 194 | return ACTION_SIZE 195 | 196 | @staticmethod 197 | def observation_size(): 198 | return OBS_SIZE 199 | 200 | cpdef int _next_player(self, int player, int turns=1): 201 | return (player + turns) % Game.num_players() 202 | 203 | 
cpdef np.ndarray valid_moves(self): 204 | cdef list valids = [0] * ACTION_SIZE 205 | cdef tuple move 206 | 207 | for move in self._board.legal_moves(pieces=(), piece_type=self._board.to_play()): 208 | valids[get_action(self._board, move)] = 1 209 | 210 | return np.array(valids, dtype=np.uint8) 211 | 212 | cpdef void play_action(self, int action): 213 | cdef tuple move = get_move(self._board, action) 214 | self._board.move(move[0], move[1], check_turn=False, _check_valid=False, _check_win=False) 215 | 216 | cpdef np.ndarray win_state(self): 217 | cdef np.ndarray[dtype=np.uint8_t, ndim=1] result = np.zeros(NUM_PLAYERS + 1, dtype=np.uint8) 218 | cdef int winner 219 | 220 | # Check if maximum moves have been exceeded 221 | if self.turns >= DRAW_MOVE_COUNT: 222 | result[NUM_PLAYERS] = 1 223 | else: 224 | winner = self._board.get_winner() 225 | if winner != 0: 226 | result[0 if winner == RED_TEAM_COLOUR else 1] = 1 227 | 228 | return result 229 | 230 | cpdef np.ndarray observation(self): 231 | return _get_observation( 232 | self._board, 233 | NUM_PLAYERS, 234 | DRAW_MOVE_COUNT, 235 | NUM_STACKED_OBSERVATIONS 236 | ) 237 | 238 | cpdef list symmetries(self, np.ndarray pi): 239 | cdef np.ndarray[DTYPE_t, ndim=2] new_state = np.fliplr(self._board._state.astype(DTYPE)) 240 | cdef np.ndarray[DTYPE_t, ndim=1] new_pi = np.zeros(ACTION_SIZE, dtype=DTYPE) 241 | cdef Board new_b = self._board.copy() 242 | cdef Square dest 243 | cdef tuple new_move 244 | 245 | new_b._state = new_state 246 | for action, prob in enumerate(pi): 247 | move = get_move(self._board, action) 248 | dest = Square(self._board.width - 1 - move[1].x, move[1].y) 249 | 250 | if self._board.play_phase: 251 | new_move = (Square(self._board.width - 1 - move[0].x, move[0].y), dest) 252 | else: 253 | new_move = (move[0], dest) 254 | 255 | new_pi[get_action(new_b, new_move)] = prob 256 | 257 | return [(self.clone(), pi), (Game(new_b), new_pi)] 258 | 259 | 260 | cpdef void display(Game g, int action=-1): 261 | if action != -1: print(f'Action: {action}, Move: {get_move(g._board, action)}') 262 | print(g) 263 | -------------------------------------------------------------------------------- /alphazero/envs/stratego/train.py: -------------------------------------------------------------------------------- 1 | import pyximport; pyximport.install() 2 | 3 | from alphazero.Coach import Coach, get_args 4 | from alphazero.NNetWrapper import NNetWrapper as nn 5 | from alphazero.envs.stratego.stratego import Game 6 | from alphazero.GenericPlayers import RawMCTSPlayer 7 | from alphazero.utils import dotdict 8 | 9 | args = get_args( 10 | run_name='stratego', 11 | workers=6, 12 | max_moves=512, 13 | num_stacked_observations=1, 14 | cpuct=1.25, 15 | symmetricSamples=True, 16 | numMCTSSims=100, 17 | numFastSims=15, 18 | numWarmupSims=5, 19 | probFastSim=0.75, 20 | 21 | selfPlayModelIter=None, 22 | skipSelfPlayIters=None, 23 | #train_on_past_data=True, 24 | #past_data_run_name='brandubh', 25 | model_gating=True, 26 | max_gating_iters=None, 27 | numWarmupIters=2, 28 | arenaMCTS=True, 29 | baselineCompareFreq=5, 30 | pastCompareFreq=5, 31 | #train_steps_per_iteration=40, 32 | min_next_model_winrate=0.52, 33 | use_draws_for_winrate=True, 34 | 35 | process_batch_size=64, 36 | train_batch_size=256, 37 | arena_batch_size=32, 38 | arenaCompare=32*4, 39 | arenaCompareBaseline=32*4, 40 | gamesPerIteration=64*6, 41 | 42 | lr=1e-2, 43 | optimizer_args=dotdict({ 44 | 'momentum': 0.9, 45 | 'weight_decay': 1e-3 46 | }), 47 | 48 | depth=4, 49 | num_channels=64, 50 | 
value_head_channels=16, 51 | policy_head_channels=16, 52 | value_dense_layers=[1024, 128], 53 | policy_dense_layers=[1024] 54 | ) 55 | args.scheduler_args.milestones = [75, 150] 56 | 57 | 58 | if __name__ == "__main__": 59 | nnet = nn(Game, args) 60 | c = Coach(Game, nnet, args) 61 | c.learn() 62 | -------------------------------------------------------------------------------- /alphazero/envs/tictactoe/TicTacToeLogic.py: -------------------------------------------------------------------------------- 1 | class Board: 2 | """ 3 | Board class for the game of TicTacToe. 4 | Default board size is 3x3. 5 | Board data: 6 | 1=white(O), -1=black(X), 0=empty 7 | first dim is column , 2nd is row: 8 | pieces[0][0] is the top left square, 9 | pieces[2][0] is the bottom left square, 10 | Squares are stored and manipulated as (x,y) tuples. 11 | 12 | Author: Evgeny Tyurin, github.com/evg-tyurin 13 | Date: Jan 5, 2018. 14 | 15 | Based on the board for the game of Othello by Eric P. Nichols. 16 | """ 17 | 18 | # list of all 8 directions on the board, as (x,y) offsets 19 | __directions = [(1, 1), (1, 0), (1, -1), (0, -1), (-1, -1), (-1, 0), (-1, 1), (0, 1)] 20 | 21 | def __init__(self, n=3, _pieces=None): 22 | """Set up initial board configuration.""" 23 | 24 | self.n = n 25 | 26 | if _pieces is not None: 27 | self.pieces = _pieces 28 | else: 29 | # Create the empty board array. 30 | self.pieces = [None] * self.n 31 | for i in range(self.n): 32 | self.pieces[i] = [0] * self.n 33 | 34 | # add [][] indexer syntax to the Board 35 | def __getitem__(self, index): 36 | return self.pieces[index] 37 | 38 | def get_legal_moves(self): 39 | """Returns all the legal moves for the given color. 40 | (1 for white, -1 for black) 41 | """ 42 | moves = [] # stores the legal moves. 43 | 44 | # Get all the empty squares (color==0) 45 | for y in range(self.n): 46 | for x in range(self.n): 47 | if self[x][y] == 0: 48 | newmove = (x, y) 49 | moves.append(newmove) 50 | return moves 51 | 52 | def has_legal_moves(self): 53 | for y in range(self.n): 54 | for x in range(self.n): 55 | if self[x][y] == 0: 56 | return True 57 | return False 58 | 59 | def is_win(self, color): 60 | """Check whether the given player has collected a triplet in any direction; 61 | @param color (1=white,-1=black) 62 | """ 63 | win = self.n 64 | # check y-strips 65 | for y in range(self.n): 66 | count = 0 67 | for x in range(self.n): 68 | if self[x][y] == color: 69 | count += 1 70 | if count == win: 71 | return True 72 | 73 | # check x-strips 74 | for x in range(self.n): 75 | count = 0 76 | for y in range(self.n): 77 | if self[x][y] == color: 78 | count += 1 79 | if count == win: 80 | return True 81 | 82 | # check two diagonal strips 83 | count = 0 84 | for d in range(self.n): 85 | if self[d][d] == color: 86 | count += 1 87 | if count == win: 88 | return True 89 | 90 | count = 0 91 | for d in range(self.n): 92 | if self[d][self.n - d - 1] == color: 93 | count += 1 94 | if count == win: 95 | return True 96 | 97 | return False 98 | 99 | def execute_move(self, move, color): 100 | """Perform the given move on the board; 101 | color gives the color pf the piece to play (1=white,-1=black) 102 | """ 103 | (x, y) = move 104 | 105 | # Add the piece to the empty square. 
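# Usage sketch (hypothetical values): on a fresh 3x3 board,
# execute_move((1, 1), -1) passes the emptiness assertion below and places
# an X (-1) on the centre square, i.e. self[1][1] == -1 afterwards.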
106 | assert self[x][y] == 0 107 | self[x][y] = color 108 | -------------------------------------------------------------------------------- /alphazero/envs/tictactoe/TicTacToePlayers.py: -------------------------------------------------------------------------------- 1 | from alphazero.GenericPlayers import BasePlayer 2 | from alphazero.Game import GameState 3 | 4 | 5 | class HumanTicTacToePlayer(BasePlayer): 6 | def play(self, state: GameState) -> int: 7 | valid = state.valid_moves() 8 | """ 9 | for i in range(len(valid)): 10 | if valid[i]: 11 | print(int(i / state._board.n), int(i % state._board.n)) 12 | """ 13 | 14 | while True: 15 | a = input('Enter a move: ') 16 | 17 | x, y = [int(x) for x in a.split(' ')] 18 | a = state._board.n * x + y if x != -1 else state._board.n ** 2 19 | if valid[a]: 20 | break 21 | else: 22 | print('Invalid move entered.') 23 | 24 | return a 25 | -------------------------------------------------------------------------------- /alphazero/envs/tictactoe/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevaday/alphazero-general/99f74f19766923e98523dfc310a1ffef7379ca83/alphazero/envs/tictactoe/__init__.py -------------------------------------------------------------------------------- /alphazero/envs/tictactoe/tictactoe.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Any 2 | 3 | from alphazero.Game import GameState 4 | from alphazero.envs.tictactoe.TicTacToeLogic import Board 5 | 6 | import numpy as np 7 | 8 | NUM_PLAYERS = 2 9 | NUM_CHANNELS = 1 10 | BOARD_SIZE = 3 11 | ACTION_SIZE = BOARD_SIZE ** 2 12 | OBSERVATION_SIZE = (NUM_CHANNELS, BOARD_SIZE, BOARD_SIZE) 13 | 14 | 15 | class Game(GameState): 16 | def __init__(self, _board=None): 17 | super().__init__(_board or self._get_board()) 18 | 19 | @staticmethod 20 | def _get_board(): 21 | return Board(BOARD_SIZE) 22 | 23 | def __eq__(self, other: 'Game') -> bool: 24 | return ( 25 | self._board.pieces == other._board.pieces 26 | and self._board.n == other._board.n 27 | and self._player == other._player 28 | and self.turns == other.turns 29 | ) 30 | 31 | def clone(self) -> 'Game': 32 | g = Game() 33 | g._board.pieces = np.copy(self._board.pieces) 34 | g._player = self._player 35 | g._turns = self.turns 36 | return g 37 | 38 | @staticmethod 39 | def num_players() -> int: 40 | return NUM_PLAYERS 41 | 42 | @staticmethod 43 | def action_size() -> int: 44 | return ACTION_SIZE 45 | 46 | @staticmethod 47 | def observation_size() -> Tuple[int, int, int]: 48 | return OBSERVATION_SIZE 49 | 50 | def _player_range(self): 51 | return (1, -1)[self.player] 52 | 53 | def valid_moves(self): 54 | # return a fixed size binary vector 55 | valids = [0] * self.action_size() 56 | 57 | for x, y in self._board.get_legal_moves(): 58 | valids[self._board.n * x + y] = 1 59 | 60 | return np.array(valids, dtype=np.uint8) 61 | 62 | def play_action(self, action: int) -> None: 63 | move = (action // self._board.n, action % self._board.n) 64 | self._board.execute_move(move, self._player_range()) 65 | self._update_turn() 66 | 67 | def win_state(self): 68 | result = [False] * (NUM_PLAYERS + 1) 69 | player = self._player_range() 70 | 71 | if self._board.is_win(player): 72 | result[self.player] = True 73 | elif self._board.is_win(-player): 74 | result[self._next_player(self.player)] = True 75 | elif not self._board.has_legal_moves(): 76 | result[-1] = True 77 | 78 | return np.array(result, 
dtype=np.uint8) 79 | 80 | def observation(self): 81 | return np.expand_dims(np.asarray(self._board.pieces), axis=0).astype(np.float32) 82 | 83 | def symmetries(self, pi: np.ndarray) -> List[Tuple[Any, int]]: 84 | # mirror, rotational 85 | assert (len(pi) == self._board.n ** 2) 86 | 87 | pi_board = np.reshape(pi, (self._board.n, self._board.n)) 88 | result = [] 89 | 90 | for i in range(1, 5): 91 | for j in [True, False]: 92 | new_b = np.rot90(np.asarray(self._board.pieces), i) 93 | new_pi = np.rot90(pi_board, i) 94 | if j: 95 | new_b = np.fliplr(new_b) 96 | new_pi = np.fliplr(new_pi) 97 | 98 | gs = self.clone() 99 | gs._board.pieces = new_b 100 | result.append((gs, new_pi.ravel())) 101 | 102 | return result 103 | 104 | 105 | def display(board): 106 | n = board.shape[0] 107 | 108 | print(" ", end="") 109 | for y in range(n): 110 | print(y, "", end="") 111 | print("") 112 | print(" ", end="") 113 | for _ in range(n): 114 | print("-", end="-") 115 | print("--") 116 | for y in range(n): 117 | print(y, "|", end="") # print the row # 118 | for x in range(n): 119 | piece = board[y][x] # get the piece to print 120 | if piece == -1: 121 | print("X ", end="") 122 | elif piece == 1: 123 | print("O ", end="") 124 | else: 125 | if x == n: 126 | print("-", end="") 127 | else: 128 | print("- ", end="") 129 | print("|") 130 | 131 | print(" ", end="") 132 | for _ in range(n): 133 | print("-", end="-") 134 | print("--") 135 | -------------------------------------------------------------------------------- /alphazero/envs/tictactoe/train.py: -------------------------------------------------------------------------------- 1 | import pyximport; pyximport.install() 2 | 3 | from alphazero.Coach import Coach, get_args 4 | from alphazero.NNetWrapper import NNetWrapper as nn 5 | from alphazero.envs.tictactoe.tictactoe import Game 6 | 7 | 8 | args = get_args( 9 | run_name='tictactoe', 10 | workers=2, 11 | cpuct=2, 12 | numMCTSSims=100, 13 | probFastSim=0.5, 14 | numWarmupIters=1, 15 | baselineCompareFreq=1, 16 | pastCompareFreq=1, 17 | arenaBatchSize=2048, 18 | arenaCompare=2*2048, 19 | arenaCompareBaseline=2*2048, 20 | process_batch_size=512, 21 | train_batch_size=2048, 22 | gamesPerIteration=2*512, 23 | lr=0.01, 24 | num_channels=32, 25 | depth=4, 26 | value_head_channels=4, 27 | policy_head_channels=4, 28 | value_dense_layers=[128, 64], 29 | policy_dense_layers=[128], 30 | skipSelfPlayIters=1, 31 | ) 32 | 33 | 34 | if __name__ == "__main__": 35 | nnet = nn(Game, args) 36 | c = Coach(Game, nnet, args) 37 | c.learn() 38 | -------------------------------------------------------------------------------- /alphazero/pit-multi.py: -------------------------------------------------------------------------------- 1 | import pyximport; pyximport.install() 2 | 3 | from alphazero.GenericPlayers import * 4 | from alphazero.MCTS import MCTS 5 | from alphazero.Arena import Arena 6 | from alphazero.NNetWrapper import NNetWrapper as nn 7 | from alphazero.envs.othello import NNetSpecialWrapper as nns 8 | from alphazero.envs.othello import OthelloGame as Game 9 | from alphazero.utils import * 10 | from tensorboardX import SummaryWriter 11 | from pathlib import Path 12 | from glob import glob 13 | 14 | import numpy as np 15 | import pprint 16 | 17 | 18 | """ 19 | use this script to play every x agents against a single agent and graph win rate. 
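Typical workflow (a sketch; the checkpoint layout is assumed): set
`benchmark_agent` and the `args` dict in this file, make sure the
checkpoints to compare exist under ./checkpoint, then run
`python alphazero/pit-multi.py` from the repository root. Every x-th
network is pitted against the benchmark agent and the resulting win
rates are written to TensorBoard under runs/<run_name>.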
20 | """ 21 | 22 | args = dotdict({ 23 | 'run_name': 'othello_better_teacher', 24 | 'arenaCompare': 100, 25 | 'arenaTemp': 0, 26 | 'temp': 1, 27 | 'tempThreshold': 10, 28 | # use zero if no montecarlo 29 | 'numMCTSSims': 50, 30 | 'cpuct': 1, 31 | 'x': 10, 32 | }) 33 | 34 | if __name__ == '__main__': 35 | print('Args:') 36 | pprint.pprint(args) 37 | benchmark_agent = "othello/special/6x6_153checkpoints_best.pth.tar" 38 | 39 | if args.run_name != '': 40 | writer = SummaryWriter(log_dir='runs/'+args.run_name) 41 | else: 42 | writer = SummaryWriter() 43 | if not Path('checkpoint').exists(): 44 | Path('checkpoint').mkdir() 45 | print('Beginning comparison') 46 | networks = sorted(glob('checkpoint/*')) 47 | temp = networks[::args['x']] 48 | if temp[-1] != networks[-1]: 49 | temp.append(networks[-1]) 50 | 51 | networks = temp 52 | model_count = len(networks) 53 | 54 | if model_count < 1: 55 | print( 56 | "Too few models for pit multi.") 57 | exit() 58 | 59 | total_games = model_count * args.arenaCompare 60 | print( 61 | f'Comparing {model_count} different models in {total_games} total games') 62 | 63 | g = Game(6) 64 | nnet1 = nns(g) 65 | nnet2 = nn(g) 66 | 67 | nnet1.load_checkpoint(folder="", filename=benchmark_agent) 68 | short_name = Path(benchmark_agent).stem 69 | 70 | if args.numMCTSSims <= 0: 71 | p1 = NNPlayer(g, nnet1, args.arenaTemp).play 72 | else: 73 | mcts1 = MCTS(g, nnet1, args) 74 | 75 | def p1(x, turn): 76 | if turn <= 2: 77 | mcts1.reset() 78 | temp = args.temp if turn <= args.tempThreshold else args.arenaTemp 79 | policy = mcts1.getActionProb(x, temp=temp) 80 | return np.random.choice(len(policy), p=policy) 81 | 82 | for i in range(model_count): 83 | file = Path(networks[i]) 84 | print(f'{short_name} vs {file.stem}') 85 | 86 | nnet2.load_checkpoint(folder='checkpoint', filename=file.name) 87 | if args.numMCTSSims <= 0: 88 | p2 = NNPlayer(g, nnet2, args.arenaTemp).play 89 | else: 90 | mcts2 = MCTS(g, nnet2, args) 91 | 92 | def p2(x, turn): 93 | if turn <= 2: 94 | mcts2.reset() 95 | temp = args.temp if turn <= args.tempThreshold else args.arenaTemp 96 | policy = mcts2.getActionProb(x, temp=temp) 97 | return np.random.choice(len(policy), p=policy) 98 | 99 | arena = Arena(p1, p2, g) 100 | p1wins, p2wins, draws = arena.play_games(args.arenaCompare) 101 | writer.add_scalar( 102 | f'Win Rate vs {short_name}', (p2wins + 0.5*draws)/args.arenaCompare, i*args.x) 103 | print(f'wins: {p1wins}, ties: {draws}, losses:{p2wins}\n') 104 | writer.close() -------------------------------------------------------------------------------- /alphazero/pit.py: -------------------------------------------------------------------------------- 1 | import numpy, pyximport 2 | 3 | pyximport.install(setup_args={'include_dirs': numpy.get_include()}) 4 | 5 | from alphazero.Arena import Arena 6 | from alphazero.GenericPlayers import * 7 | from alphazero.NNetWrapper import NNetWrapper as NNet 8 | 9 | 10 | """ 11 | use this script to play any two agents against each other, or play manually with 12 | any agent. 
13 | """ 14 | if __name__ == '__main__': 15 | from alphazero.envs.gobang.GobangGame import GobangGame as Game 16 | #from alphazero.envs.tafl.players import GreedyTaflPlayer 17 | from alphazero.envs.gobang.train import args 18 | 19 | args.numMCTSSims = 800 20 | #args.arena_batch_size = 64 21 | 22 | # all players 23 | # rp = RandomPlayer(g).play 24 | # gp = OneStepLookaheadConnect4Player(g).play 25 | # hp = HumanTaflPlayer(g).play 26 | 27 | # nnet players 28 | nn1 = NNet(Game, args) 29 | nn1.load_checkpoint('./checkpoint/hnefatafl', 'iteration-0001.pkl') 30 | #nn2 = NNet(Game, args) 31 | #nn2.load_checkpoint('./checkpoint/hnefatafl', 'iteration-0000.pkl') 32 | #player1 = nn1.process 33 | #player2 = nn2.process 34 | 35 | player1 = MCTSPlayer(nn1, args=args) 36 | #player2 = MCTSPlayer(nn2, args=args) 37 | #player2 = RandomPlayer() 38 | player2 = GreedyTaflPlayer() 39 | 40 | players = [player2, player1] 41 | arena = Arena(players, Game, use_batched_mcts=False, args=args, display=print) 42 | wins, draws, winrates = arena.play_game(verbose=True) 43 | for i in range(len(wins)): 44 | print(f'player{i+1}:\n\twins: {wins[i]}\n\twin rate: {winrates[i]}') 45 | print('draws: ', draws) 46 | -------------------------------------------------------------------------------- /alphazero/pytorch_classification/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevaday/alphazero-general/99f74f19766923e98523dfc310a1ffef7379ca83/alphazero/pytorch_classification/__init__.py -------------------------------------------------------------------------------- /alphazero/pytorch_classification/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Useful utils 2 | """ 3 | from .misc import * 4 | from .logger import * 5 | from .eval import * 6 | 7 | # progress bar 8 | import os, sys 9 | sys.path.append(os.path.join(os.path.dirname(__file__), "progress")) 10 | from progress.bar import Bar as Bar -------------------------------------------------------------------------------- /alphazero/pytorch_classification/utils/eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | 3 | __all__ = ['accuracy'] 4 | 5 | def accuracy(output, target, topk=(1,)): 6 | """Computes the precision@k for the specified values of k""" 7 | maxk = max(topk) 8 | batch_size = target.size(0) 9 | 10 | _, pred = output.topk(maxk, 1, True, True) 11 | pred = pred.t() 12 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 13 | 14 | res = [] 15 | for k in topk: 16 | correct_k = correct[:k].view(-1).float().sum(0) 17 | res.append(correct_k.mul_(100.0 / batch_size)) 18 | return res -------------------------------------------------------------------------------- /alphazero/pytorch_classification/utils/images/cifar.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevaday/alphazero-general/99f74f19766923e98523dfc310a1ffef7379ca83/alphazero/pytorch_classification/utils/images/cifar.png -------------------------------------------------------------------------------- /alphazero/pytorch_classification/utils/images/imagenet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevaday/alphazero-general/99f74f19766923e98523dfc310a1ffef7379ca83/alphazero/pytorch_classification/utils/images/imagenet.png 
-------------------------------------------------------------------------------- /alphazero/pytorch_classification/utils/logger.py: -------------------------------------------------------------------------------- 1 | # A simple torch style logger 2 | # (C) Wei YANG 2017 3 | from __future__ import absolute_import 4 | #import matplotlib.pyplot as plt 5 | import os 6 | import sys 7 | import numpy as np 8 | 9 | __all__ = ['Logger', 'LoggerMonitor', 'savefig'] 10 | 11 | def savefig(fname, dpi=None): 12 | dpi = 150 if dpi == None else dpi 13 | plt.savefig(fname, dpi=dpi) 14 | 15 | def plot_overlap(logger, names=None): 16 | names = logger.names if names == None else names 17 | numbers = logger.numbers 18 | for _, name in enumerate(names): 19 | x = np.arange(len(numbers[name])) 20 | plt.plot(x, np.asarray(numbers[name])) 21 | return [logger.title + '(' + name + ')' for name in names] 22 | 23 | class Logger(object): 24 | '''Save training process to log file with simple plot function.''' 25 | def __init__(self, fpath, title=None, resume=False): 26 | self.file = None 27 | self.resume = resume 28 | self.title = '' if title == None else title 29 | if fpath is not None: 30 | if resume: 31 | self.file = open(fpath, 'r') 32 | name = self.file.readline() 33 | self.names = name.rstrip().split('\t') 34 | self.numbers = {} 35 | for _, name in enumerate(self.names): 36 | self.numbers[name] = [] 37 | 38 | for numbers in self.file: 39 | numbers = numbers.rstrip().split('\t') 40 | for i in range(0, len(numbers)): 41 | self.numbers[self.names[i]].append(numbers[i]) 42 | self.file.close() 43 | self.file = open(fpath, 'a') 44 | else: 45 | self.file = open(fpath, 'w') 46 | 47 | def set_names(self, names): 48 | if self.resume: 49 | pass 50 | # initialize numbers as empty list 51 | self.numbers = {} 52 | self.names = names 53 | for _, name in enumerate(self.names): 54 | self.file.write(name) 55 | self.file.write('\t') 56 | self.numbers[name] = [] 57 | self.file.write('\n') 58 | self.file.flush() 59 | 60 | 61 | def append(self, numbers): 62 | assert len(self.names) == len(numbers), 'Numbers do not match names' 63 | for index, num in enumerate(numbers): 64 | self.file.write("{0:.6f}".format(num)) 65 | self.file.write('\t') 66 | self.numbers[self.names[index]].append(num) 67 | self.file.write('\n') 68 | self.file.flush() 69 | 70 | def plot(self, names=None): 71 | names = self.names if names == None else names 72 | numbers = self.numbers 73 | for _, name in enumerate(names): 74 | x = np.arange(len(numbers[name])) 75 | plt.plot(x, np.asarray(numbers[name])) 76 | plt.legend([self.title + '(' + name + ')' for name in names]) 77 | plt.grid(True) 78 | 79 | def close(self): 80 | if self.file is not None: 81 | self.file.close() 82 | 83 | class LoggerMonitor(object): 84 | '''Load and visualize multiple logs.''' 85 | def __init__ (self, paths): 86 | '''paths is a distionary with {name:filepath} pair''' 87 | self.loggers = [] 88 | for title, path in paths.items(): 89 | logger = Logger(path, title=title, resume=True) 90 | self.loggers.append(logger) 91 | 92 | def plot(self, names=None): 93 | plt.figure() 94 | plt.subplot(121) 95 | legend_text = [] 96 | for logger in self.loggers: 97 | legend_text += plot_overlap(logger, names) 98 | plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.) 
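# NOTE: the plotting helpers in this module depend on matplotlib as `plt`,
# but the import near the top of the file is commented out; it has to be
# restored before plot(), plot_overlap() or savefig() will work.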
99 | plt.grid(True) 100 | 101 | if __name__ == '__main__': 102 | # # Example 103 | # logger = Logger('test.txt') 104 | # logger.set_names(['Train loss', 'Valid loss','Test loss']) 105 | 106 | # length = 100 107 | # t = np.arange(length) 108 | # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 109 | # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 110 | # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 111 | 112 | # for i in range(0, length): 113 | # logger.append([train_loss[i], valid_loss[i], test_loss[i]]) 114 | # logger.plot() 115 | 116 | # Example: logger monitor 117 | paths = { 118 | 'resadvnet20':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt', 119 | 'resadvnet32':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt', 120 | 'resadvnet44':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt', 121 | } 122 | 123 | field = ['Valid Acc.'] 124 | 125 | monitor = LoggerMonitor(paths) 126 | monitor.plot(names=field) 127 | savefig('test.eps') -------------------------------------------------------------------------------- /alphazero/pytorch_classification/utils/misc.py: -------------------------------------------------------------------------------- 1 | '''Some helper functions for PyTorch, including: 2 | - get_mean_and_std: calculate the mean and std value of dataset. 3 | - msr_init: net parameter initialization. 4 | - progress_bar: progress bar mimic xlua.progress. 5 | ''' 6 | import errno 7 | import os 8 | import sys 9 | import time 10 | import math 11 | 12 | __all__ = ['AverageMeter'] 13 | 14 | class AverageMeter(object): 15 | """Computes and stores the average and current value 16 | Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262 17 | """ 18 | def __init__(self): 19 | self.reset() 20 | 21 | def reset(self): 22 | self.val = 0 23 | self.avg = 0 24 | self.sum = 0 25 | self.count = 0 26 | 27 | def update(self, val, n=1): 28 | self.val = val 29 | self.sum += val * n 30 | self.count += n 31 | self.avg = self.sum / self.count -------------------------------------------------------------------------------- /alphazero/pytorch_classification/utils/progress/LICENSE: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2012 Giorgos Verigakis 2 | # 3 | # Permission to use, copy, modify, and distribute this software for any 4 | # purpose with or without fee is hereby granted, provided that the above 5 | # copyright notice and this permission notice appear in all copies. 6 | # 7 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 8 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 9 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 10 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 11 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 12 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 13 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
14 | -------------------------------------------------------------------------------- /alphazero/pytorch_classification/utils/progress/MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.rst LICENSE 2 | -------------------------------------------------------------------------------- /alphazero/pytorch_classification/utils/progress/README.rst: -------------------------------------------------------------------------------- 1 | Easy progress reporting for Python 2 | ================================== 3 | 4 | |pypi| 5 | 6 | |demo| 7 | 8 | .. |pypi| image:: https://img.shields.io/pypi/v/progress.svg 9 | .. |demo| image:: https://raw.github.com/verigak/progress/master/demo.gif 10 | :alt: Demo 11 | 12 | Bars 13 | ---- 14 | 15 | There are 7 progress bars to choose from: 16 | 17 | - ``Bar`` 18 | - ``ChargingBar`` 19 | - ``FillingSquaresBar`` 20 | - ``FillingCirclesBar`` 21 | - ``IncrementalBar`` 22 | - ``PixelBar`` 23 | - ``ShadyBar`` 24 | 25 | To use them, just call ``next`` to advance and ``finish`` to finish: 26 | 27 | .. code-block:: python 28 | 29 | from progress.bar import Bar 30 | 31 | bar = Bar('Processing', max=20) 32 | for i in range(20): 33 | # Do some work 34 | bar.next() 35 | bar.finish() 36 | 37 | The result will be a bar like the following: :: 38 | 39 | Processing |############# | 42/100 40 | 41 | To simplify the common case where the work is done in an iterator, you can 42 | use the ``iter`` method: 43 | 44 | .. code-block:: python 45 | 46 | for i in Bar('Processing').iter(it): 47 | # Do some work 48 | 49 | Progress bars are very customizable, you can change their width, their fill 50 | character, their suffix and more: 51 | 52 | .. code-block:: python 53 | 54 | bar = Bar('Loading', fill='@', suffix='%(percent)d%%') 55 | 56 | This will produce a bar like the following: :: 57 | 58 | Loading |@@@@@@@@@@@@@ | 42% 59 | 60 | You can use a number of template arguments in ``message`` and ``suffix``: 61 | 62 | ========== ================================ 63 | Name Value 64 | ========== ================================ 65 | index current value 66 | max maximum value 67 | remaining max - index 68 | progress index / max 69 | percent progress * 100 70 | avg simple moving average time per item (in seconds) 71 | elapsed elapsed time in seconds 72 | elapsed_td elapsed as a timedelta (useful for printing as a string) 73 | eta avg * remaining 74 | eta_td eta as a timedelta (useful for printing as a string) 75 | ========== ================================ 76 | 77 | Instead of passing all configuration options on instatiation, you can create 78 | your custom subclass: 79 | 80 | .. code-block:: python 81 | 82 | class FancyBar(Bar): 83 | message = 'Loading' 84 | fill = '*' 85 | suffix = '%(percent).1f%% - %(eta)ds' 86 | 87 | You can also override any of the arguments or create your own: 88 | 89 | .. code-block:: python 90 | 91 | class SlowBar(Bar): 92 | suffix = '%(remaining_hours)d hours remaining' 93 | @property 94 | def remaining_hours(self): 95 | return self.eta // 3600 96 | 97 | 98 | Spinners 99 | ======== 100 | 101 | For actions with an unknown number of steps you can use a spinner: 102 | 103 | .. 
code-block:: python 104 | 105 | from progress.spinner import Spinner 106 | 107 | spinner = Spinner('Loading ') 108 | while state != 'FINISHED': 109 | # Do some work 110 | spinner.next() 111 | 112 | There are 5 predefined spinners: 113 | 114 | - ``Spinner`` 115 | - ``PieSpinner`` 116 | - ``MoonSpinner`` 117 | - ``LineSpinner`` 118 | - ``PixelSpinner`` 119 | 120 | 121 | Other 122 | ===== 123 | 124 | There are a number of other classes available too, please check the source or 125 | subclass one of them to create your own. 126 | 127 | 128 | License 129 | ======= 130 | 131 | progress is licensed under ISC 132 | -------------------------------------------------------------------------------- /alphazero/pytorch_classification/utils/progress/demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevaday/alphazero-general/99f74f19766923e98523dfc310a1ffef7379ca83/alphazero/pytorch_classification/utils/progress/demo.gif -------------------------------------------------------------------------------- /alphazero/pytorch_classification/utils/progress/progress/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2012 Giorgos Verigakis 2 | # 3 | # Permission to use, copy, modify, and distribute this software for any 4 | # purpose with or without fee is hereby granted, provided that the above 5 | # copyright notice and this permission notice appear in all copies. 6 | # 7 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 8 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 9 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 10 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 11 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 12 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 13 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
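# Intended use, mirroring the README in this package (a sketch; Progress
# itself renders nothing, the concrete bars live in progress/bar.py):
#
#     bar = Progress(max=100)
#     for _ in range(100):
#         pass           # do some work
#         bar.next()     # advances index and updates the moving average
#     bar.finish()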
14 | 15 | from __future__ import division 16 | 17 | from collections import deque 18 | from datetime import timedelta 19 | from math import ceil 20 | from sys import stderr, stdout 21 | from time import time 22 | 23 | 24 | __version__ = '1.3' 25 | 26 | 27 | class Infinite(object): 28 | file = stdout 29 | sma_window = 10 # Simple Moving Average window 30 | 31 | def __init__(self, *args, **kwargs): 32 | self.index = 0 33 | self.start_ts = time() 34 | self.avg = 0 35 | self._ts = self.start_ts 36 | self._uts = self.start_ts 37 | self._xput = deque(maxlen=self.sma_window) 38 | for key, val in kwargs.items(): 39 | setattr(self, key, val) 40 | 41 | def __getitem__(self, key): 42 | if key.startswith('_'): 43 | return None 44 | return getattr(self, key, None) 45 | 46 | @property 47 | def elapsed(self): 48 | return int(time() - self.start_ts) 49 | 50 | @property 51 | def elapsed_td(self): 52 | return timedelta(seconds=self.elapsed) 53 | 54 | def update_avg(self, n, dt): 55 | if n > 0: 56 | self._xput.append(dt / n) 57 | self.avg = sum(self._xput) / len(self._xput) 58 | 59 | def update(self): 60 | pass 61 | 62 | def start(self): 63 | pass 64 | 65 | def finish(self): 66 | pass 67 | 68 | def next(self, n=1): 69 | now = time() 70 | dt = now - self._ts 71 | self.update_avg(n, dt) 72 | self._ts = now 73 | self.index = self.index + n 74 | if now - self._uts >= 1: 75 | self._uts = now 76 | self.update() 77 | 78 | def iter(self, it): 79 | try: 80 | for x in it: 81 | yield x 82 | self.next() 83 | finally: 84 | self.finish() 85 | 86 | 87 | class Progress(Infinite): 88 | def __init__(self, *args, **kwargs): 89 | super(Progress, self).__init__(*args, **kwargs) 90 | self.max = kwargs.get('max', 100) 91 | 92 | @property 93 | def eta(self): 94 | return int(ceil(self.avg * self.remaining)) 95 | 96 | @property 97 | def eta_td(self): 98 | return timedelta(seconds=self.eta) 99 | 100 | @property 101 | def percent(self): 102 | return self.progress * 100 103 | 104 | @property 105 | def progress(self): 106 | return min(1, self.index / self.max) 107 | 108 | @property 109 | def remaining(self): 110 | return max(self.max - self.index, 0) 111 | 112 | def start(self): 113 | self.update() 114 | 115 | def goto(self, index): 116 | incr = index - self.index 117 | self.next(incr) 118 | 119 | def iter(self, it): 120 | try: 121 | self.max = len(it) 122 | except TypeError: 123 | pass 124 | 125 | try: 126 | for x in it: 127 | yield x 128 | self.next() 129 | finally: 130 | self.finish() 131 | -------------------------------------------------------------------------------- /alphazero/pytorch_classification/utils/progress/progress/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevaday/alphazero-general/99f74f19766923e98523dfc310a1ffef7379ca83/alphazero/pytorch_classification/utils/progress/progress/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /alphazero/pytorch_classification/utils/progress/progress/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevaday/alphazero-general/99f74f19766923e98523dfc310a1ffef7379ca83/alphazero/pytorch_classification/utils/progress/progress/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- 
/alphazero/pytorch_classification/utils/progress/progress/__pycache__/bar.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevaday/alphazero-general/99f74f19766923e98523dfc310a1ffef7379ca83/alphazero/pytorch_classification/utils/progress/progress/__pycache__/bar.cpython-37.pyc -------------------------------------------------------------------------------- /alphazero/pytorch_classification/utils/progress/progress/__pycache__/bar.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevaday/alphazero-general/99f74f19766923e98523dfc310a1ffef7379ca83/alphazero/pytorch_classification/utils/progress/progress/__pycache__/bar.cpython-38.pyc -------------------------------------------------------------------------------- /alphazero/pytorch_classification/utils/progress/progress/__pycache__/helpers.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevaday/alphazero-general/99f74f19766923e98523dfc310a1ffef7379ca83/alphazero/pytorch_classification/utils/progress/progress/__pycache__/helpers.cpython-37.pyc -------------------------------------------------------------------------------- /alphazero/pytorch_classification/utils/progress/progress/__pycache__/helpers.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevaday/alphazero-general/99f74f19766923e98523dfc310a1ffef7379ca83/alphazero/pytorch_classification/utils/progress/progress/__pycache__/helpers.cpython-38.pyc -------------------------------------------------------------------------------- /alphazero/pytorch_classification/utils/progress/progress/bar.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright (c) 2012 Giorgos Verigakis 4 | # 5 | # Permission to use, copy, modify, and distribute this software for any 6 | # purpose with or without fee is hereby granted, provided that the above 7 | # copyright notice and this permission notice appear in all copies. 8 | # 9 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 10 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 11 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 12 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 13 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 14 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 15 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 16 | 17 | from __future__ import unicode_literals 18 | from . 
import Progress 19 | from .helpers import WritelnMixin 20 | 21 | 22 | class Bar(WritelnMixin, Progress): 23 | width = 32 24 | message = '' 25 | suffix = '%(index)d/%(max)d' 26 | bar_prefix = ' |' 27 | bar_suffix = '| ' 28 | empty_fill = ' ' 29 | fill = '#' 30 | hide_cursor = True 31 | 32 | def update(self): 33 | filled_length = int(self.width * self.progress) 34 | empty_length = self.width - filled_length 35 | 36 | message = self.message % self 37 | bar = self.fill * filled_length 38 | empty = self.empty_fill * empty_length 39 | suffix = self.suffix % self 40 | line = ''.join([message, self.bar_prefix, bar, empty, self.bar_suffix, 41 | suffix]) 42 | self.writeln(line) 43 | 44 | 45 | class ChargingBar(Bar): 46 | suffix = '%(percent)d%%' 47 | bar_prefix = ' ' 48 | bar_suffix = ' ' 49 | empty_fill = '∙' 50 | fill = '█' 51 | 52 | 53 | class FillingSquaresBar(ChargingBar): 54 | empty_fill = '▢' 55 | fill = '▣' 56 | 57 | 58 | class FillingCirclesBar(ChargingBar): 59 | empty_fill = '◯' 60 | fill = '◉' 61 | 62 | 63 | class IncrementalBar(Bar): 64 | phases = (' ', '▏', '▎', '▍', '▌', '▋', '▊', '▉', '█') 65 | 66 | def update(self): 67 | nphases = len(self.phases) 68 | filled_len = self.width * self.progress 69 | nfull = int(filled_len) # Number of full chars 70 | phase = int((filled_len - nfull) * nphases) # Phase of last char 71 | nempty = self.width - nfull # Number of empty chars 72 | 73 | message = self.message % self 74 | bar = self.phases[-1] * nfull 75 | current = self.phases[phase] if phase > 0 else '' 76 | empty = self.empty_fill * max(0, nempty - len(current)) 77 | suffix = self.suffix % self 78 | line = ''.join([message, self.bar_prefix, bar, current, empty, 79 | self.bar_suffix, suffix]) 80 | self.writeln(line) 81 | 82 | 83 | class PixelBar(IncrementalBar): 84 | phases = ('⡀', '⡄', '⡆', '⡇', '⣇', '⣧', '⣷', '⣿') 85 | 86 | 87 | class ShadyBar(IncrementalBar): 88 | phases = (' ', '░', '▒', '▓', '█') 89 | -------------------------------------------------------------------------------- /alphazero/pytorch_classification/utils/progress/progress/counter.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright (c) 2012 Giorgos Verigakis 4 | # 5 | # Permission to use, copy, modify, and distribute this software for any 6 | # purpose with or without fee is hereby granted, provided that the above 7 | # copyright notice and this permission notice appear in all copies. 8 | # 9 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 10 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 11 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 12 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 13 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 14 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 15 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 16 | 17 | from __future__ import unicode_literals 18 | from . 
import Infinite, Progress 19 | from .helpers import WriteMixin 20 | 21 | 22 | class Counter(WriteMixin, Infinite): 23 | message = '' 24 | hide_cursor = True 25 | 26 | def update(self): 27 | self.write(str(self.index)) 28 | 29 | 30 | class Countdown(WriteMixin, Progress): 31 | hide_cursor = True 32 | 33 | def update(self): 34 | self.write(str(self.remaining)) 35 | 36 | 37 | class Stack(WriteMixin, Progress): 38 | phases = (' ', '▁', '▂', '▃', '▄', '▅', '▆', '▇', '█') 39 | hide_cursor = True 40 | 41 | def update(self): 42 | nphases = len(self.phases) 43 | i = min(nphases - 1, int(self.progress * nphases)) 44 | self.write(self.phases[i]) 45 | 46 | 47 | class Pie(Stack): 48 | phases = ('○', '◔', '◑', '◕', '●') 49 | -------------------------------------------------------------------------------- /alphazero/pytorch_classification/utils/progress/progress/helpers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2012 Giorgos Verigakis 2 | # 3 | # Permission to use, copy, modify, and distribute this software for any 4 | # purpose with or without fee is hereby granted, provided that the above 5 | # copyright notice and this permission notice appear in all copies. 6 | # 7 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 8 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 9 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 10 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 11 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 12 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 13 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 14 | 15 | from __future__ import print_function 16 | 17 | 18 | HIDE_CURSOR = '\x1b[?25l' 19 | SHOW_CURSOR = '\x1b[?25h' 20 | 21 | 22 | class WriteMixin(object): 23 | hide_cursor = False 24 | 25 | def __init__(self, message=None, **kwargs): 26 | super(WriteMixin, self).__init__(**kwargs) 27 | self._width = 0 28 | if message: 29 | self.message = message 30 | 31 | if self.file.isatty(): 32 | if self.hide_cursor: 33 | print(HIDE_CURSOR, end='', file=self.file) 34 | print(self.message, end='', file=self.file) 35 | self.file.flush() 36 | 37 | def write(self, s): 38 | if self.file.isatty(): 39 | b = '\b' * self._width 40 | c = s.ljust(self._width) 41 | print(b + c, end='', file=self.file) 42 | self._width = max(self._width, len(s)) 43 | self.file.flush() 44 | 45 | def finish(self): 46 | if self.file.isatty() and self.hide_cursor: 47 | print(SHOW_CURSOR, end='', file=self.file) 48 | 49 | 50 | class WritelnMixin(object): 51 | hide_cursor = False 52 | 53 | def __init__(self, message=None, **kwargs): 54 | super(WritelnMixin, self).__init__(**kwargs) 55 | if message: 56 | self.message = message 57 | 58 | if self.file.isatty() and self.hide_cursor: 59 | print(HIDE_CURSOR, end='', file=self.file) 60 | 61 | def clearln(self): 62 | #if self.file.isatty(): 63 | print('\r\x1b[K', end='', file=self.file) 64 | 65 | def writeln(self, line): 66 | #if self.file.isatty(): 67 | self.clearln() 68 | print(line, end='', file=self.file) 69 | self.file.flush() 70 | 71 | def finish(self): 72 | #if self.file.isatty(): - snair, to print to stderr logfile 73 | print(file=self.file) 74 | if self.hide_cursor: 75 | print(SHOW_CURSOR, end='', file=self.file) 76 | 77 | 78 | from signal import signal, SIGINT 79 | from sys import exit 80 | 81 | 82 | class SigIntMixin(object): 83 | """Registers a 
signal handler that calls finish on SIGINT""" 84 | 85 | def __init__(self, *args, **kwargs): 86 | super(SigIntMixin, self).__init__(*args, **kwargs) 87 | signal(SIGINT, self._sigint_handler) 88 | 89 | def _sigint_handler(self, signum, frame): 90 | self.finish() 91 | exit(0) 92 | -------------------------------------------------------------------------------- /alphazero/pytorch_classification/utils/progress/progress/spinner.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright (c) 2012 Giorgos Verigakis 4 | # 5 | # Permission to use, copy, modify, and distribute this software for any 6 | # purpose with or without fee is hereby granted, provided that the above 7 | # copyright notice and this permission notice appear in all copies. 8 | # 9 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 10 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 11 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 12 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 13 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 14 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 15 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 16 | 17 | from __future__ import unicode_literals 18 | from . import Infinite 19 | from .helpers import WriteMixin 20 | 21 | 22 | class Spinner(WriteMixin, Infinite): 23 | message = '' 24 | phases = ('-', '\\', '|', '/') 25 | hide_cursor = True 26 | 27 | def update(self): 28 | i = self.index % len(self.phases) 29 | self.write(self.phases[i]) 30 | 31 | 32 | class PieSpinner(Spinner): 33 | phases = ['◷', '◶', '◵', '◴'] 34 | 35 | 36 | class MoonSpinner(Spinner): 37 | phases = ['◑', '◒', '◐', '◓'] 38 | 39 | 40 | class LineSpinner(Spinner): 41 | phases = ['⎺', '⎻', '⎼', '⎽', '⎼', '⎻'] 42 | 43 | class PixelSpinner(Spinner): 44 | phases = ['⣾','⣷', '⣯', '⣟', '⡿', '⢿', '⣻', '⣽'] 45 | -------------------------------------------------------------------------------- /alphazero/pytorch_classification/utils/progress/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import setup 4 | 5 | import progress 6 | 7 | 8 | setup( 9 | name='progress', 10 | version=progress.__version__, 11 | description='Easy to use progress bars', 12 | long_description=open('README.rst').read(), 13 | author='Giorgos Verigakis', 14 | author_email='verigak@gmail.com', 15 | url='http://github.com/verigak/progress/', 16 | license='ISC', 17 | packages=['progress'], 18 | classifiers=[ 19 | 'Environment :: Console', 20 | 'Intended Audience :: Developers', 21 | 'License :: OSI Approved :: ISC License (ISCL)', 22 | 'Programming Language :: Python :: 2.6', 23 | 'Programming Language :: Python :: 2.7', 24 | 'Programming Language :: Python :: 3.3', 25 | 'Programming Language :: Python :: 3.4', 26 | 'Programming Language :: Python :: 3.5', 27 | 'Programming Language :: Python :: 3.6', 28 | ] 29 | ) 30 | -------------------------------------------------------------------------------- /alphazero/pytorch_classification/utils/progress/test_progress.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from __future__ import print_function 4 | 5 | import random 6 | import time 7 | 8 | from progress.bar import (Bar, ChargingBar, FillingSquaresBar, 9 | 
FillingCirclesBar, IncrementalBar, PixelBar, 10 | ShadyBar) 11 | from progress.spinner import (Spinner, PieSpinner, MoonSpinner, LineSpinner, 12 | PixelSpinner) 13 | from progress.counter import Counter, Countdown, Stack, Pie 14 | 15 | 16 | def sleep(): 17 | t = 0.01 18 | t += t * random.uniform(-0.1, 0.1) # Add some variance 19 | time.sleep(t) 20 | 21 | 22 | for bar_cls in (Bar, ChargingBar, FillingSquaresBar, FillingCirclesBar): 23 | suffix = '%(index)d/%(max)d [%(elapsed)d / %(eta)d / %(eta_td)s]' 24 | bar = bar_cls(bar_cls.__name__, suffix=suffix) 25 | for i in bar.iter(range(200)): 26 | sleep() 27 | 28 | for bar_cls in (IncrementalBar, PixelBar, ShadyBar): 29 | suffix = '%(percent)d%% [%(elapsed_td)s / %(eta)d / %(eta_td)s]' 30 | bar = bar_cls(bar_cls.__name__, suffix=suffix) 31 | for i in bar.iter(range(200)): 32 | sleep() 33 | 34 | for spin in (Spinner, PieSpinner, MoonSpinner, LineSpinner, PixelSpinner): 35 | for i in spin(spin.__name__ + ' ').iter(range(100)): 36 | sleep() 37 | print() 38 | 39 | for singleton in (Counter, Countdown, Stack, Pie): 40 | for i in singleton(singleton.__name__ + ' ').iter(range(100)): 41 | sleep() 42 | print() 43 | 44 | bar = IncrementalBar('Random', suffix='%(index)d') 45 | for i in range(100): 46 | bar.goto(random.randint(0, 100)) 47 | sleep() 48 | bar.finish() 49 | -------------------------------------------------------------------------------- /alphazero/requirements.txt: -------------------------------------------------------------------------------- 1 | setuptools>=41.6.0 2 | Cython>=0.29.14 3 | numpy>=1.17.3 4 | torch>=1.3.1 5 | choix>=0.3.3 6 | tensorboardX>=1.9 7 | -------------------------------------------------------------------------------- /alphazero/roundrobin.py: -------------------------------------------------------------------------------- 1 | import pyximport 2 | pyximport.install() 3 | 4 | from alphazero.NNetWrapper import NNetWrapper as nn 5 | from alphazero.GenericPlayers import * 6 | from alphazero.Arena import Arena 7 | from pathlib import Path 8 | from glob import glob 9 | 10 | import numpy as np 11 | import pprint 12 | import choix 13 | 14 | if __name__ == '__main__': 15 | from alphazero.envs.hnefatafl.fastafl import Game as Game 16 | from alphazero.envs.hnefatafl.train_fastafl import args 17 | 18 | print('Args:') 19 | pprint.pprint(args) 20 | if not Path('roundrobin').exists(): 21 | Path('roundrobin').mkdir() 22 | print('Beginning round robin') 23 | networks = sorted(glob('roundrobin/*'), reverse=True) 24 | model_count = len(networks) + int(args.compareWithBaseline) 25 | 26 | if model_count <= 2: 27 | print( 28 | "Too few models for round robin. 
Please add models to the roundrobin/ directory" 29 | ) 30 | exit() 31 | 32 | total_games = 0 33 | for i in range(model_count): 34 | total_games += i 35 | total_games *= args.arenaCompare 36 | print( 37 | f'Comparing {model_count} different models in {total_games} total games') 38 | win_matrix = np.zeros((model_count, model_count)) 39 | 40 | nnet1 = nn(Game, args) 41 | nnet2 = nn(Game, args) 42 | 43 | for i in range(model_count - 1): 44 | for j in range(i + 1, model_count): 45 | file1 = Path(networks[i]) 46 | file2 = Path('random' if args.compareWithBaseline and j == model_count - 1 else networks[j]) 47 | print(f'{file1.stem} vs {file2.stem}') 48 | nnet1.load_checkpoint(folder='roundrobin', filename=file1.name) 49 | 50 | if file2.name != 'random': 51 | nnet2.load_checkpoint(folder='roundrobin', filename=file2.name) 52 | 53 | if args.arenaBatched: 54 | if not args.arenaMCTS: 55 | args.arenaMCTS = True 56 | print( 57 | 'WARNING: Batched arena comparison is enabled which uses MCTS, but arena MCTS is set to False.' 58 | ' Ignoring this, and continuing with batched MCTS in arena.') 59 | 60 | p1 = nnet1.process 61 | p2 = nnet2.process 62 | else: 63 | cls = MCTSPlayer if args.arenaMCTS else NNPlayer 64 | p1 = cls(nnet1, args=args) 65 | p2 = cls(nnet2, args=args) 66 | else: 67 | p1 = nnet1.process #(MCTSPlayer if args.arenaMCTS else NNPlayer)(Game, nnet1, args=args) 68 | p2 = RawMCTSPlayer(Game, args).process 69 | 70 | arena = Arena([p1, p2], Game, use_batched_mcts=args.arenaBatched, args=args) 71 | wins, draws, winrates = arena.play_games(args.arenaCompare) 72 | win_matrix[i, j] = wins[0] + 0.5 * draws 73 | win_matrix[j, i] = wins[1] + 0.5 * draws 74 | print(f'wins: {wins[0]}, ties: {draws}, losses: {wins[1]}\n') 75 | 76 | print("\nWin Matrix (row beats column):") 77 | print(win_matrix) 78 | try: 79 | with np.errstate(divide='ignore', invalid='ignore'): 80 | params = choix.ilsr_pairwise_dense(win_matrix) 81 | print("\nRankings:") 82 | for i, player in enumerate(np.argsort(params)[::-1]): 83 | name = 'random' if args.compareWithBaseline and player == model_count - \ 84 | 1 else Path(networks[player]).stem 85 | print(f"{i + 1}.
{name} with {params[player]:0.2f} rating") 86 | print( 87 | "\n(Rating Diff, Winrate) -> (0.5, 62%), (1, 73%), (2, 88%), (3, 95%), (5, 99%)") 88 | except Exception: 89 | print("\nNot Enough data to calculate rankings") 90 | -------------------------------------------------------------------------------- /alphazero/utils.py: -------------------------------------------------------------------------------- 1 | class dotdict(dict): 2 | def __getattr__(self, name): 3 | if name.startswith('__'): 4 | raise AttributeError 5 | return self[name] 6 | 7 | def __setattr__(self, key, value): 8 | self[key] = value 9 | 10 | def copy(self): 11 | data = super().copy() 12 | return self.__class__(data) 13 | 14 | 15 | def get_iter_file(iteration: int): 16 | return f'iteration-{iteration:04d}.pkl' 17 | 18 | 19 | def scale_temp(scale_factor: float, min_temp: float, cur_temp: float, turns: int, const_max_turns: int) -> float: 20 | if const_max_turns and (turns + 1) % int(scale_factor * const_max_turns) == 0: 21 | return max(min_temp, cur_temp / 2) 22 | else: 23 | return cur_temp 24 | 25 | 26 | def default_temp_scaling(*args, **kwargs) -> float: 27 | return scale_temp(0.15, 0.2, *args, **kwargs) 28 | 29 | 30 | def const_temp_scaling(temp, *args, **kwargs) -> float: 31 | return temp 32 | 33 | 34 | def get_game_results(result_queue, game_cls, _get_index=None): 35 | player_to_index = {p: i for i, p in enumerate(range(game_cls.num_players()))} 36 | 37 | num_games = result_queue.qsize() 38 | wins = [0] * game_cls.num_players() 39 | draws = 0 40 | game_len_sum = 0 41 | 42 | for _ in range(num_games): 43 | state, winstate, agent_id = result_queue.get() 44 | game_len_sum += state.turns 45 | 46 | for player, is_win in enumerate(winstate): 47 | if is_win: 48 | if player == len(wins): 49 | draws += 1 50 | else: 51 | index = _get_index(player, agent_id) if _get_index else player_to_index[player] 52 | wins[index] += 1 53 | 54 | return wins, draws, game_len_sum / num_games if num_games else 0 55 | 56 | 57 | def plot_mcts_tree(mcts, max_depth=2): 58 | import networkx as nx 59 | import matplotlib.pyplot as plt 60 | G = nx.Graph() 61 | 62 | global node_idx 63 | node_idx = 0 64 | 65 | def find_nodes(cur_node, _past_node=None, _past_i=None, _depth=0): 66 | if _depth > max_depth: return 67 | global node_idx 68 | cur_idx = node_idx 69 | 70 | G.add_node(cur_idx, a=cur_node.a, q=round(cur_node.q, 2), n=cur_node.n, v=round(cur_node.v, 2)) 71 | if _past_node: 72 | G.add_edge(cur_idx, _past_i) 73 | node_idx += 1 74 | 75 | for node in cur_node._children: 76 | find_nodes(node, cur_node, cur_idx, _depth+1) 77 | 78 | find_nodes(mcts._root) 79 | labels = {node: '\n'.join(['{}: {}'.format(k, v) for k, v in G.nodes[node].items()]) for node in G.nodes} 80 | #pos = nx.spring_layout(G, k=0.15, iterations=50) 81 | pos = nx.nx_agraph.graphviz_layout(G, prog='dot', args='-Gnodesep=1.0 -Goverlap=false') 82 | nx.draw(G, pos, labels=labels) 83 | plt.show() 84 | 85 | 86 | def convert_checkpoint_file(filepath: str, game_cls, args: dotdict, overwrite_args=False): 87 | from alphazero.NNetWrapper import NNetWrapper 88 | nnet = NNetWrapper(game_cls, args) 89 | nnet.load_checkpoint('', filepath, use_saved_args=not overwrite_args) 90 | nnet.save_checkpoint('', filepath, make_dirs=False) 91 | 92 | 93 | def map_value(value, in_min, in_max, out_min, out_max): 94 | return (value - in_min) * (out_max - out_min) / (in_max - in_min) + out_min 95 | -------------------------------------------------------------------------------- /boardgame/__init__.py: 
-------------------------------------------------------------------------------- 1 | import pyximport, numpy 2 | pyximport.install(setup_args={'include_dirs': numpy.get_include()}) 3 | 4 | from .board import * 5 | 6 | -------------------------------------------------------------------------------- /boardgame/__init__.pyo: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevaday/alphazero-general/99f74f19766923e98523dfc310a1ffef7379ca83/boardgame/__init__.pyo -------------------------------------------------------------------------------- /boardgame/board.pxd: -------------------------------------------------------------------------------- 1 | # cython: language_level=3 2 | # cython: boundscheck=False 3 | # cython: wraparound=False 4 | # cython: nonecheck=False 5 | # cython: overflowcheck=False 6 | # cython: initializedcheck=False 7 | # cython: cdivision=True 8 | # cython: auto_pickle=True 9 | 10 | import numpy as np 11 | cimport numpy as np 12 | 13 | DTYPE = np.uint8 14 | ctypedef np.uint8_t DTYPE_t 15 | 16 | ctypedef fused RawState_T: 17 | str 18 | np.ndarray 19 | 20 | cpdef void _raise_invalid_board(RawState_T board_data) 21 | 22 | cdef class Square: 23 | cdef public Py_ssize_t x 24 | cdef public Py_ssize_t y 25 | cpdef tuple _get_tuple(self) 26 | 27 | cdef class Team: 28 | cdef public int colour 29 | cdef public tuple pieces 30 | 31 | cdef class BaseBoard: 32 | cdef public tuple teams 33 | cdef public bint use_load_whitespace 34 | cdef public tuple _valid_squares 35 | cdef public tuple _all_pieces 36 | cdef public int _empty_square 37 | cdef public np.ndarray _state 38 | cdef public Py_ssize_t width 39 | cdef public Py_ssize_t height 40 | cdef public int num_turns 41 | 42 | cpdef void _load_str_inner(self, str data) 43 | cpdef void _load_str(self, str data, bint _skip_error_check=*) 44 | cpdef BaseBoard copy(self) 45 | 46 | cpdef bint _in_bounds(self, Square square) 47 | cpdef list _iter_pieces(self, tuple pieces, int piece_type=*) 48 | cpdef list legal_moves(self, tuple pieces=*, int piece_type=*) 49 | cpdef bint _has_legals_check(self, Square piece_square) 50 | cpdef bint has_legal_moves(self, tuple pieces=*, int piece_type=*) 51 | 52 | cpdef int get_winner(self) 53 | cpdef bint is_game_over(self) 54 | 55 | cpdef void move(self, source, dest, bint check_turn=*, bint _check_valid=*, bint _check_win=*) 56 | cpdef BaseBoard move_(self, source, dest, bint check_turn=*, bint _check_valid=*, bint _check_win=*) 57 | cpdef void random_move(self, bint print_move=*) 58 | cpdef BaseBoard random_move_(self, bint print_move=*) 59 | 60 | cpdef Square _relative_square(self, Square source, tuple direction) 61 | cpdef list _surrounding_squares(self, Square source) 62 | cpdef void _set_square(self, Square square, int new_val) 63 | cpdef tuple _get_team(self, int piece_type, bint enemy=*) 64 | cpdef int _team_colour(self, int piece_type) 65 | cpdef void add_piece(self, Square square, int piece, bint replace=*, _check_valid=*) 66 | cpdef int remove_piece(self, Square square, bint raise_no_piece=*) 67 | cpdef np.ndarray get_mask(self, tuple piece_types) 68 | cpdef list get_squares(self, tuple piece_types) 69 | cpdef int to_play(self) 70 | cpdef bint is_turn(self, Square square) 71 | -------------------------------------------------------------------------------- /boardgame/boardgame.pyo: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kevaday/alphazero-general/99f74f19766923e98523dfc310a1ffef7379ca83/boardgame/boardgame.pyo -------------------------------------------------------------------------------- /boardgame/errors.py: -------------------------------------------------------------------------------- 1 | class BoardgameError(Exception): 2 | pass 3 | 4 | 5 | class LoadError(BoardgameError): 6 | pass 7 | 8 | 9 | class InvalidBoardState(LoadError): 10 | pass 11 | 12 | 13 | class InvalidMoveError(BoardgameError): 14 | pass 15 | 16 | 17 | class PositionError(BoardgameError): 18 | pass 19 | -------------------------------------------------------------------------------- /boardgame/errors.pyo: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevaday/alphazero-general/99f74f19766923e98523dfc310a1ffef7379ca83/boardgame/errors.pyo -------------------------------------------------------------------------------- /boardgame/net.pyo: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevaday/alphazero-general/99f74f19766923e98523dfc310a1ffef7379ca83/boardgame/net.pyo -------------------------------------------------------------------------------- /fastafl/__init__.py: -------------------------------------------------------------------------------- 1 | #from .engine import * 2 | import pyximport, numpy 3 | pyximport.install(setup_args={'include_dirs': numpy.get_include()}) 4 | 5 | from .cengine import * 6 | 7 | -------------------------------------------------------------------------------- /fastafl/cengine.pxd: -------------------------------------------------------------------------------- 1 | # cython: language_level=3 2 | # cython: boundscheck=False 3 | # cython: wraparound=False 4 | # cython: nonecheck=False 5 | # cython: overflowcheck=False 6 | # cython: initializedcheck=False 7 | # cython: cdivision=True 8 | # cython: auto_pickle=True 9 | 10 | from boardgame.board cimport BaseBoard, Square 11 | 12 | 13 | cdef class Board(BaseBoard): 14 | cdef public bint king_two_sided_capture 15 | cdef public bint move_over_throne 16 | cdef public bint king_can_enter_throne 17 | cdef public bint _king_captured 18 | cdef public bint _king_escaped 19 | 20 | cpdef bint _is_valid(self, Square square, bint is_king=*) 21 | cpdef list legal_moves(self, tuple pieces=*, int piece_type=*) 22 | cpdef bint _has_legals_check(self, Square piece_square) 23 | 24 | cpdef bint king_escaped(self) 25 | cpdef bint __king_captured_lambda(self, Square x) 26 | cpdef bint king_captured(self) 27 | 28 | cpdef void _check_capture(self, Square moved_piece) 29 | cpdef tuple __next_check_squares(self, tuple enemy, squares) 30 | cpdef bint __blocked(self, Square square) 31 | cpdef tuple __recurse_check(self, Square square, list checked, tuple enemy) 32 | cpdef void _check_surround(self, Square moved_piece) 33 | cpdef void move(self, source, dest, bint check_turn=*, bint _check_valid=*, bint _check_win=*) 34 | 35 | cpdef int to_play(self) 36 | -------------------------------------------------------------------------------- /fastafl/variants.py: -------------------------------------------------------------------------------- 1 | hnefatafl = """50022222005 2 | 00000200000 3 | 00000000000 4 | 20000100002 5 | 20001110002 6 | 22011711022 7 | 20001110002 8 | 20000100002 9 | 00000000000 10 | 00000200000 11 | 50022222005""" 12 | 13 | brandubh = """5002005 14 | 0002000 15 | 0001000 16 | 2217122 17 | 0001000 18 | 0002000 
19 | 5002005""" 20 | 21 | hnefatafl_args = (hnefatafl, False) 22 | brandubh_args = (brandubh, True) 23 | -------------------------------------------------------------------------------- /hnefatafl/__init__.pyo: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevaday/alphazero-general/99f74f19766923e98523dfc310a1ffef7379ca83/hnefatafl/__init__.pyo -------------------------------------------------------------------------------- /hnefatafl/_gui.pyo: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevaday/alphazero-general/99f74f19766923e98523dfc310a1ffef7379ca83/hnefatafl/_gui.pyo -------------------------------------------------------------------------------- /hnefatafl/engine/__init__.pyo: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevaday/alphazero-general/99f74f19766923e98523dfc310a1ffef7379ca83/hnefatafl/engine/__init__.pyo -------------------------------------------------------------------------------- /hnefatafl/engine/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevaday/alphazero-general/99f74f19766923e98523dfc310a1ffef7379ca83/hnefatafl/engine/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /hnefatafl/engine/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevaday/alphazero-general/99f74f19766923e98523dfc310a1ffef7379ca83/hnefatafl/engine/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /hnefatafl/engine/__pycache__/board.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevaday/alphazero-general/99f74f19766923e98523dfc310a1ffef7379ca83/hnefatafl/engine/__pycache__/board.cpython-37.pyc -------------------------------------------------------------------------------- /hnefatafl/engine/__pycache__/board.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevaday/alphazero-general/99f74f19766923e98523dfc310a1ffef7379ca83/hnefatafl/engine/__pycache__/board.cpython-38.pyc -------------------------------------------------------------------------------- /hnefatafl/engine/__pycache__/game.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevaday/alphazero-general/99f74f19766923e98523dfc310a1ffef7379ca83/hnefatafl/engine/__pycache__/game.cpython-37.pyc -------------------------------------------------------------------------------- /hnefatafl/engine/__pycache__/game.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevaday/alphazero-general/99f74f19766923e98523dfc310a1ffef7379ca83/hnefatafl/engine/__pycache__/game.cpython-38.pyc -------------------------------------------------------------------------------- /hnefatafl/engine/__pycache__/variants.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kevaday/alphazero-general/99f74f19766923e98523dfc310a1ffef7379ca83/hnefatafl/engine/__pycache__/variants.cpython-37.pyc -------------------------------------------------------------------------------- /hnefatafl/engine/__pycache__/variants.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevaday/alphazero-general/99f74f19766923e98523dfc310a1ffef7379ca83/hnefatafl/engine/__pycache__/variants.cpython-38.pyc -------------------------------------------------------------------------------- /hnefatafl/engine/board.pyo: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevaday/alphazero-general/99f74f19766923e98523dfc310a1ffef7379ca83/hnefatafl/engine/board.pyo -------------------------------------------------------------------------------- /hnefatafl/engine/game.pyo: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevaday/alphazero-general/99f74f19766923e98523dfc310a1ffef7379ca83/hnefatafl/engine/game.pyo -------------------------------------------------------------------------------- /hnefatafl/net/__init__.pyo: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevaday/alphazero-general/99f74f19766923e98523dfc310a1ffef7379ca83/hnefatafl/net/__init__.pyo -------------------------------------------------------------------------------- /hnefatafl/net/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevaday/alphazero-general/99f74f19766923e98523dfc310a1ffef7379ca83/hnefatafl/net/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /hnefatafl/net/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevaday/alphazero-general/99f74f19766923e98523dfc310a1ffef7379ca83/hnefatafl/net/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /hnefatafl/net/__pycache__/client.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevaday/alphazero-general/99f74f19766923e98523dfc310a1ffef7379ca83/hnefatafl/net/__pycache__/client.cpython-37.pyc -------------------------------------------------------------------------------- /hnefatafl/net/__pycache__/client.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevaday/alphazero-general/99f74f19766923e98523dfc310a1ffef7379ca83/hnefatafl/net/__pycache__/client.cpython-38.pyc -------------------------------------------------------------------------------- /hnefatafl/net/__pycache__/server.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevaday/alphazero-general/99f74f19766923e98523dfc310a1ffef7379ca83/hnefatafl/net/__pycache__/server.cpython-37.pyc -------------------------------------------------------------------------------- /hnefatafl/net/client.pyo: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kevaday/alphazero-general/99f74f19766923e98523dfc310a1ffef7379ca83/hnefatafl/net/client.pyo -------------------------------------------------------------------------------- /remove_train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import shutil 4 | 5 | 6 | if len(sys.argv) > 1: 7 | run_name = sys.argv[1] 8 | else: 9 | print('Please provide a run name as a command line argument to remove training files.') 10 | sys.exit() 11 | 12 | for run_dir in ('checkpoint', 'data', 'runs'): 13 | shutil.rmtree(os.path.join(run_dir, run_name), onerror=lambda _, path, exc_info: print(f'Failed to remove path {path}: {exc_info}')) 14 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch<2.5 2 | numpy<1.25 3 | Cython<3.1 4 | tensorboard<2.15 5 | tensorboardX<2.7 6 | choix<0.4 7 | PySide2<5.16 8 | six<1.17 --------------------------------------------------------------------------------
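As a closing usage note (the run name is hypothetical): after installing the pinned top-level requirements with pip install -r requirements.txt, the artifacts of a training run can be cleared with

    python remove_train.py my_run

which attempts to delete checkpoint/my_run, data/my_run and runs/my_run, printing a message for any path it fails to remove.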