├── ruff.toml
├── tinyzero.png
├── connect2
│   ├── out
│   │   ├── model.pth
│   │   └── optimizer.pth
│   ├── eval.py
│   ├── game.py
│   └── train.py
├── tictactoe
│   ├── one_dim
│   │   ├── out
│   │   │   ├── model.pth
│   │   │   └── optimizer.pth
│   │   ├── game.py
│   │   ├── train.py
│   │   └── eval.py
│   └── two_dim
│       ├── out
│       │   ├── model.pth
│       │   └── optimizer.pth
│       ├── game.py
│       ├── train.py
│       └── eval.py
├── requirements.txt
├── replay_buffer.py
├── LICENSE
├── models.py
├── README.md
├── agents.py
└── mcts.py
--------------------------------------------------------------------------------
/ruff.toml:
--------------------------------------------------------------------------------
indent-width = 2
line-length = 120
--------------------------------------------------------------------------------
/tinyzero.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/s-casci/tinyzero/HEAD/tinyzero.png
--------------------------------------------------------------------------------
/connect2/out/model.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/s-casci/tinyzero/HEAD/connect2/out/model.pth
--------------------------------------------------------------------------------
/connect2/out/optimizer.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/s-casci/tinyzero/HEAD/connect2/out/optimizer.pth
--------------------------------------------------------------------------------
/tictactoe/one_dim/out/model.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/s-casci/tinyzero/HEAD/tictactoe/one_dim/out/model.pth
--------------------------------------------------------------------------------
/tictactoe/two_dim/out/model.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/s-casci/tinyzero/HEAD/tictactoe/two_dim/out/model.pth
--------------------------------------------------------------------------------
/tictactoe/one_dim/out/optimizer.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/s-casci/tinyzero/HEAD/tictactoe/one_dim/out/optimizer.pth
--------------------------------------------------------------------------------
/tictactoe/two_dim/out/optimizer.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/s-casci/tinyzero/HEAD/tictactoe/two_dim/out/optimizer.pth
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
numba==0.58.1
numpy==1.26.2
torch==2.1.2
tqdm==4.66.1
wandb==0.16.1
--------------------------------------------------------------------------------
/tictactoe/two_dim/game.py:
--------------------------------------------------------------------------------
import numpy as np
import os
import sys

sys.path.append(os.getcwd())
from tictactoe.one_dim.game import TicTacToe as TicTacToe1D  # noqa: E402


class TicTacToe(TicTacToe1D):
  def __init__(self):
    super().__init__()

  def to_observation(self):
    # 3x3 board from the current player's perspective, with a leading channel dimension
    obs = np.zeros((3, 3), dtype=np.float32)
    for i, x in enumerate(self.state):
      if x == self.turn:
        obs[i // 3, i % 3] = 1
      elif x == -self.turn:
        obs[i // 3, i % 3] = -1
    return np.array([obs])
--------------------------------------------------------------------------------
/replay_buffer.py:
--------------------------------------------------------------------------------
from collections import deque
import numpy as np


class ReplayBuffer:
  def __init__(self, max_size):
    self.observations = deque(maxlen=max_size)
    self.actions_dist = deque(maxlen=max_size)
    self.results = deque(maxlen=max_size)

  def __len__(self):
    return len(self.observations)

  def add_sample(self, observation, actions_dist, result):
    self.observations.append(observation)
    self.actions_dist.append(actions_dist)
    self.results.append(result)

  def sample(self, batch_size):
    indices = np.random.choice(len(self), batch_size, replace=False)
    observations = np.array([self.observations[i] for i in indices], dtype=np.float32)
    actions_dist = np.array([self.actions_dist[i] for i in indices], dtype=np.float32)
    results = np.array([self.results[i] for i in indices], dtype=np.float32)
    return observations, actions_dist, results
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 s-casci

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/connect2/eval.py:
--------------------------------------------------------------------------------
from game import Connect2
import torch
from train import OUT_DIR, SEARCH_ITERATIONS
from tqdm import tqdm

import os
import sys

sys.path.append(os.getcwd())
from models import LinearNetwork  # noqa: E402
from agents import AlphaZeroAgent  # noqa: E402
from mcts import pit  # noqa: E402

EVAL_GAMES = 100

if __name__ == "__main__":
  game = Connect2()

  model = LinearNetwork(game.observation_shape, game.action_space)
  model.load_state_dict(torch.load(f"{OUT_DIR}/model.pth"))

  agent = AlphaZeroAgent(model)
  agent_play_kwargs = {"search_iterations": SEARCH_ITERATIONS, "c_puct": 1.5, "dirichlet_alpha": 0.3}

  print(f"Playing {EVAL_GAMES} games against itself")

  results = {0: 0, 1: 0, -1: 0}
  for _ in tqdm(range(EVAL_GAMES)):
    # start every evaluation game from a fresh board, as the tic-tac-toe eval scripts do
    game.reset()
    result = pit(
      game,
      agent,
      agent,
      agent_play_kwargs,
      agent_play_kwargs,
    )
    results[result] += 1

  print("Results:")
  print(f"First player wins: {results[1]}")
  print(f"Second player wins: {results[-1]}")
  print(f"Draws: {results[0]}")
--------------------------------------------------------------------------------
/connect2/game.py:
--------------------------------------------------------------------------------
import numpy as np


class Connect2:
  STATE_LEN = 4

  def __init__(self):
    self.reset()

    self.observation_shape = self.to_observation().shape
    self.action_space = self.STATE_LEN

  def reset(self):
    self.state = [0] * self.STATE_LEN
    self.actions_stack = []
    self.turn = 1

  def __str__(self):
    return str(self.state)

  def to_observation(self):
    obs = np.zeros(self.STATE_LEN, dtype=np.float32)
    for i, x in enumerate(self.state):
      if x == self.turn:
        obs[i] = 1
      elif x == -self.turn:
        obs[i] = -1
    return obs

  def get_legal_actions(self):
    return [i for i, x in enumerate(self.state) if x == 0]

  def step(self, action):
    if self.state[action] != 0:
      raise ValueError(f"Action {action} is illegal")
    self.state[action] = self.turn
    self.actions_stack.append(action)
    self.turn *= -1

  def undo_last_action(self):
    self.state[self.actions_stack.pop()] = 0
    self.turn *= -1

  def get_result(self):
    for x, y in zip(self.state[:-1], self.state[1:]):
      if x == y != 0:
        return x
    if len(self.get_legal_actions()) == 0:
      return 0

  # get result from the point of view of the current player
  def get_first_person_result(self):
    result = self.get_result()
    if result is not None:
      return result * self.turn

  @staticmethod
  def swap_result(result):
    return -result
--------------------------------------------------------------------------------
/tictactoe/one_dim/game.py:
--------------------------------------------------------------------------------
import numpy as np


class TicTacToe:
  def __init__(self):
    self.reset()

    self.observation_shape = self.to_observation().shape
    self.action_space = 9

  def reset(self):
    self.state = [0] * 9
    self.actions = []
    self.turn = 1

  def __str__(self):
    return "\n".join([" ".join([str(x) for x in
self.state[i : i + 3]]) for i in range(0, 9, 3)]) 18 | 19 | def get_legal_actions(self): 20 | return [i for i, x in enumerate(self.state) if x == 0] 21 | 22 | def step(self, action): 23 | if self.state[action] != 0: 24 | raise ValueError(f"Action {action} is illegal") 25 | self.state[action] = self.turn 26 | self.actions.append(action) 27 | self.turn *= -1 28 | 29 | def undo_last_action(self): 30 | self.state[self.actions.pop()] = 0 31 | self.turn *= -1 32 | 33 | def get_result(self): 34 | if len(self.actions) < 5: 35 | return 36 | for x, y, z in [(0, 1, 2), (3, 4, 5), (6, 7, 8), (0, 3, 6), (1, 4, 7), (2, 5, 8), (0, 4, 8), (2, 4, 6)]: 37 | if self.state[x] == self.state[y] == self.state[z] != 0: 38 | return self.state[x] 39 | if len(self.get_legal_actions()) == 0: 40 | return 0 41 | 42 | def get_first_person_result(self): 43 | result = self.get_result() 44 | if result is not None: 45 | return result * self.turn 46 | 47 | @staticmethod 48 | def swap_result(result): 49 | return -result 50 | 51 | def to_observation(self): 52 | obs = np.zeros(9, dtype=np.float32) 53 | for i, x in enumerate(self.state): 54 | if x == self.turn: 55 | obs[i] = 1 56 | elif x == -self.turn: 57 | obs[i] = -1 58 | return obs 59 | -------------------------------------------------------------------------------- /connect2/train.py: -------------------------------------------------------------------------------- 1 | from game import Connect2 2 | from datetime import datetime 3 | import torch 4 | import wandb 5 | from tqdm import tqdm 6 | import os 7 | import sys 8 | 9 | sys.path.append(os.getcwd()) 10 | from models import LinearNetwork # noqa: E402 11 | from agents import AlphaZeroAgentTrainer # noqa: E402 12 | 13 | OUT_DIR = "connect2/out" 14 | INIT_FROM_CHECKPOINT = False 15 | SELFPLAY_GAMES = 256 16 | SELFPLAY_GAMES_PER_SAVE = SELFPLAY_GAMES // 4 17 | BATCH_SIZE = 64 18 | SEARCH_ITERATIONS = 16 19 | MAX_REPLAY_BUFFER_SIZE = BATCH_SIZE * 4 20 | TRAINING_EPOCHS = 3 21 | LEARNING_RATE = 1e-3 22 | WEIGHT_DECAY = 1e-4 23 | C_PUCT = 1.5 24 | DIRICHLET_ALPHA = 0.3 # set to None to disable 25 | WANDB_LOG = True 26 | WANDB_PROJECT_NAME = "tinyalphazero-connect2" 27 | WANDB_RUN_NAME = "run" + datetime.now().strftime("%Y%m%d-%H%M%S") 28 | 29 | if __name__ == "__main__": 30 | game = Connect2() 31 | 32 | model = LinearNetwork(game.observation_shape, game.action_space) 33 | optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY) 34 | 35 | agent = AlphaZeroAgentTrainer(model, optimizer, MAX_REPLAY_BUFFER_SIZE) 36 | 37 | if INIT_FROM_CHECKPOINT: 38 | agent.load_training_state(f"{OUT_DIR}/model.pth", f"{OUT_DIR}/optimizer.pth") 39 | 40 | if WANDB_LOG: 41 | wandb_run = wandb.init(project=WANDB_PROJECT_NAME, name=WANDB_RUN_NAME) 42 | 43 | os.makedirs(OUT_DIR, exist_ok=True) 44 | print("Starting training") 45 | 46 | for i in tqdm(range(SELFPLAY_GAMES)): 47 | game.reset() 48 | 49 | values_losses, policies_losses = agent.train_step( 50 | game, SEARCH_ITERATIONS, BATCH_SIZE, TRAINING_EPOCHS, c_puct=C_PUCT, dirichlet_alpha=DIRICHLET_ALPHA 51 | ) 52 | 53 | if WANDB_LOG: 54 | for values_loss, policies_loss in zip(values_losses, policies_losses): 55 | wandb.log({"values_loss": values_loss, "policies_loss": policies_loss}) 56 | 57 | if i > 0 and i % SELFPLAY_GAMES_PER_SAVE == 0: 58 | print("Saving training state") 59 | agent.save_training_state(f"{OUT_DIR}/model.pth", f"{OUT_DIR}/optimizer.pth") 60 | 61 | if WANDB_LOG: 62 | wandb_run.finish() 63 | 64 | print("Training complete") 65 | 66 | print("Saving final 
training state") 67 | agent.save_training_state(f"{OUT_DIR}/model.pth", f"{OUT_DIR}/optimizer.pth") 68 | -------------------------------------------------------------------------------- /tictactoe/one_dim/train.py: -------------------------------------------------------------------------------- 1 | from game import TicTacToe 2 | from datetime import datetime 3 | import torch 4 | import wandb 5 | from tqdm import tqdm 6 | import os 7 | import sys 8 | 9 | sys.path.append(os.getcwd()) 10 | from models import LinearNetwork # noqa: E402 11 | from agents import AlphaZeroAgentTrainer # noqa: E402 12 | 13 | OUT_DIR = "tictactoe/one_dim/out" 14 | INIT_FROM_CHECKPOINT = False 15 | SELFPLAY_GAMES = 5000 16 | SELFPLAY_GAMES_PER_SAVE = SELFPLAY_GAMES // 4 17 | BATCH_SIZE = 128 18 | SEARCH_ITERATIONS = 32 19 | MAX_REPLAY_BUFFER_SIZE = BATCH_SIZE * 4 20 | TRAINING_EPOCHS = 5 21 | LEARNING_RATE = 1e-3 22 | WEIGHT_DECAY = 1e-1 23 | C_PUCT = 1.9 24 | DIRICHLET_ALPHA = 0.3 # set to None to disable 25 | WANDB_LOG = True 26 | WANDB_PROJECT_NAME = "tinyalphazero-tictactoe1d" 27 | WANDB_RUN_NAME = "run" + datetime.now().strftime("%Y%m%d-%H%M%S") 28 | 29 | if __name__ == "__main__": 30 | game = TicTacToe() 31 | 32 | model = LinearNetwork(game.observation_shape, game.action_space) 33 | optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY) 34 | 35 | agent = AlphaZeroAgentTrainer(model, optimizer, MAX_REPLAY_BUFFER_SIZE) 36 | 37 | if INIT_FROM_CHECKPOINT: 38 | agent.load_training_state(f"{OUT_DIR}/model.pth", f"{OUT_DIR}/optimizer.pth") 39 | 40 | if WANDB_LOG: 41 | wandb_run = wandb.init(project=WANDB_PROJECT_NAME, name=WANDB_RUN_NAME) 42 | 43 | os.makedirs(OUT_DIR, exist_ok=True) 44 | print("Starting training") 45 | 46 | for i in tqdm(range(SELFPLAY_GAMES)): 47 | game.reset() 48 | 49 | values_losses, policies_losses = agent.train_step( 50 | game, SEARCH_ITERATIONS, BATCH_SIZE, TRAINING_EPOCHS, c_puct=C_PUCT, dirichlet_alpha=DIRICHLET_ALPHA 51 | ) 52 | 53 | if WANDB_LOG: 54 | for values_loss, policies_loss in zip(values_losses, policies_losses): 55 | wandb.log({"values_loss": values_loss, "policies_loss": policies_loss}) 56 | 57 | if i > 0 and i % SELFPLAY_GAMES_PER_SAVE == 0: 58 | print("Saving training state") 59 | agent.save_training_state(f"{OUT_DIR}/model.pth", f"{OUT_DIR}/optimizer.pth") 60 | 61 | if WANDB_LOG: 62 | wandb_run.finish() 63 | 64 | print("Training complete") 65 | 66 | print("Saving final training state") 67 | agent.save_training_state(f"{OUT_DIR}/model.pth", f"{OUT_DIR}/optimizer.pth") 68 | -------------------------------------------------------------------------------- /tictactoe/two_dim/train.py: -------------------------------------------------------------------------------- 1 | from game import TicTacToe 2 | from datetime import datetime 3 | import torch 4 | import wandb 5 | from tqdm import tqdm 6 | import os 7 | import sys 8 | 9 | sys.path.append(os.getcwd()) 10 | from models import TicTacToe2DNetwork # noqa: E402 11 | from agents import AlphaZeroAgentTrainer # noqa: E402 12 | 13 | OUT_DIR = "tictactoe/two_dim/out" 14 | INIT_FROM_CHECKPOINT = False 15 | SELFPLAY_GAMES = 5000 16 | SELFPLAY_GAMES_PER_SAVE = SELFPLAY_GAMES // 4 17 | BATCH_SIZE = 128 18 | SEARCH_ITERATIONS = 32 19 | MAX_REPLAY_BUFFER_SIZE = BATCH_SIZE * 4 20 | TRAINING_EPOCHS = 5 21 | LEARNING_RATE = 1e-3 22 | WEIGHT_DECAY = 1e-1 23 | C_PUCT = 1.8 24 | DIRICHLET_ALPHA = 0.3 # set to None to disable 25 | WANDB_LOG = True 26 | WANDB_PROJECT_NAME = "tinyalphazero-tictactoe2d" 27 | 
WANDB_RUN_NAME = "run" + datetime.now().strftime("%Y%m%d-%H%M%S") 28 | 29 | if __name__ == "__main__": 30 | game = TicTacToe() 31 | 32 | model = TicTacToe2DNetwork(game.observation_shape, game.action_space) 33 | optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY) 34 | 35 | agent = AlphaZeroAgentTrainer(model, optimizer, MAX_REPLAY_BUFFER_SIZE) 36 | 37 | if INIT_FROM_CHECKPOINT: 38 | agent.load_training_state(f"{OUT_DIR}/model.pth", f"{OUT_DIR}/optimizer.pth") 39 | 40 | if WANDB_LOG: 41 | wandb_run = wandb.init(project=WANDB_PROJECT_NAME, name=WANDB_RUN_NAME) 42 | 43 | os.makedirs(OUT_DIR, exist_ok=True) 44 | print("Starting training") 45 | 46 | for i in tqdm(range(SELFPLAY_GAMES)): 47 | game.reset() 48 | 49 | values_losses, policies_losses = agent.train_step( 50 | game, SEARCH_ITERATIONS, BATCH_SIZE, TRAINING_EPOCHS, c_puct=C_PUCT, dirichlet_alpha=DIRICHLET_ALPHA 51 | ) 52 | 53 | if WANDB_LOG: 54 | for values_loss, policies_loss in zip(values_losses, policies_losses): 55 | wandb.log({"values_loss": values_loss, "policies_loss": policies_loss}) 56 | 57 | if i > 0 and i % SELFPLAY_GAMES_PER_SAVE == 0: 58 | print("Saving training state") 59 | agent.save_training_state(f"{OUT_DIR}/model.pth", f"{OUT_DIR}/optimizer.pth") 60 | 61 | if WANDB_LOG: 62 | wandb_run.finish() 63 | 64 | print("Training complete") 65 | 66 | print("Saving final training state") 67 | agent.save_training_state(f"{OUT_DIR}/model.pth", f"{OUT_DIR}/optimizer.pth") 68 | -------------------------------------------------------------------------------- /tictactoe/one_dim/eval.py: -------------------------------------------------------------------------------- 1 | from game import TicTacToe 2 | import torch 3 | from train import OUT_DIR, SEARCH_ITERATIONS 4 | from tqdm import tqdm 5 | 6 | import os 7 | import sys 8 | 9 | sys.path.append(os.getcwd()) 10 | from models import LinearNetwork # noqa: E402 11 | from agents import AlphaZeroAgent, ClassicMCTSAgent # noqa: E402 12 | from mcts import pit # noqa: E402 13 | 14 | EVAL_GAMES = 100 15 | 16 | if __name__ == "__main__": 17 | game = TicTacToe() 18 | 19 | model = LinearNetwork(game.observation_shape, game.action_space) 20 | model.load_state_dict(torch.load(f"{OUT_DIR}/model.pth")) 21 | 22 | agent = AlphaZeroAgent(model) 23 | agent_play_kwargs = {"search_iterations": SEARCH_ITERATIONS * 2, "c_puct": 1.0, "dirichlet_alpha": None} 24 | 25 | print(f"Playing {EVAL_GAMES} games against itself") 26 | 27 | results = {0: 0, 1: 0, -1: 0} 28 | for _ in tqdm(range(EVAL_GAMES)): 29 | game.reset() 30 | result = pit( 31 | game, 32 | agent, 33 | agent, 34 | agent_play_kwargs, 35 | agent_play_kwargs, 36 | ) 37 | results[result] += 1 38 | 39 | print("Results:") 40 | print(f"First player wins: {results[1]}") 41 | print(f"Second player wins: {results[-1]}") 42 | print(f"Draws: {results[0]}") 43 | 44 | classic_mcts_agent = ClassicMCTSAgent 45 | classic_mcts_agent_play_kwargs = {"search_iterations": 100, "c_puct": 1.0, "dirichlet_alpha": None} 46 | 47 | print(f"Playing {EVAL_GAMES} games against classic MCTS agent (starting first)") 48 | 49 | results = {0: 0, 1: 0, -1: 0} 50 | for _ in tqdm(range(EVAL_GAMES)): 51 | game.reset() 52 | result = pit( 53 | game, 54 | agent, 55 | classic_mcts_agent, 56 | agent_play_kwargs, 57 | classic_mcts_agent_play_kwargs, 58 | ) 59 | results[result] += 1 60 | 61 | print("Results:") 62 | print(f"AlphaZero agent wins: {results[1]}") 63 | print(f"Classic MCTS agent wins: {results[-1]}") 64 | print(f"Draws: {results[0]}") 65 | 66 | 
print(f"Playing {EVAL_GAMES} games against classic MCTS agent (starting second)") 67 | 68 | results = {0: 0, 1: 0, -1: 0} 69 | for _ in tqdm(range(EVAL_GAMES)): 70 | game.reset() 71 | result = pit( 72 | game, 73 | classic_mcts_agent, 74 | agent, 75 | classic_mcts_agent_play_kwargs, 76 | agent_play_kwargs, 77 | ) 78 | results[result] += 1 79 | 80 | print("Results:") 81 | print(f"Classic MCTS agent wins: {results[1]}") 82 | print(f"AlphaZero agent wins: {results[-1]}") 83 | print(f"Draws: {results[0]}") 84 | -------------------------------------------------------------------------------- /tictactoe/two_dim/eval.py: -------------------------------------------------------------------------------- 1 | from game import TicTacToe 2 | import torch 3 | from train import OUT_DIR, SEARCH_ITERATIONS 4 | from tqdm import tqdm 5 | 6 | import os 7 | import sys 8 | 9 | sys.path.append(os.getcwd()) 10 | from models import TicTacToe2DNetwork # noqa: E402 11 | from agents import AlphaZeroAgent, ClassicMCTSAgent # noqa: E402 12 | from mcts import pit # noqa: E402 13 | 14 | EVAL_GAMES = 100 15 | 16 | if __name__ == "__main__": 17 | game = TicTacToe() 18 | 19 | model = TicTacToe2DNetwork(game.observation_shape, game.action_space) 20 | model.load_state_dict(torch.load(f"{OUT_DIR}/model.pth")) 21 | 22 | agent = AlphaZeroAgent(model) 23 | agent_play_kwargs = {"search_iterations": SEARCH_ITERATIONS * 2, "c_puct": 1.0, "dirichlet_alpha": None} 24 | 25 | print(f"Playing {EVAL_GAMES} games against itself") 26 | 27 | results = {0: 0, 1: 0, -1: 0} 28 | for _ in tqdm(range(EVAL_GAMES)): 29 | game.reset() 30 | result = pit( 31 | game, 32 | agent, 33 | agent, 34 | agent_play_kwargs, 35 | agent_play_kwargs, 36 | ) 37 | results[result] += 1 38 | 39 | print("Results:") 40 | print(f"First player wins: {results[1]}") 41 | print(f"Second player wins: {results[-1]}") 42 | print(f"Draws: {results[0]}") 43 | 44 | classic_mcts_agent = ClassicMCTSAgent 45 | classic_mcts_agent_play_kwargs = {"search_iterations": 100, "c_puct": 1.0, "dirichlet_alpha": None} 46 | 47 | print(f"Playing {EVAL_GAMES} games against classic MCTS agent (starting first)") 48 | 49 | results = {0: 0, 1: 0, -1: 0} 50 | for _ in tqdm(range(EVAL_GAMES)): 51 | game.reset() 52 | result = pit( 53 | game, 54 | agent, 55 | classic_mcts_agent, 56 | agent_play_kwargs, 57 | classic_mcts_agent_play_kwargs, 58 | ) 59 | results[result] += 1 60 | 61 | print("Results:") 62 | print(f"AlphaZero agent wins: {results[1]}") 63 | print(f"Classic MCTS agent wins: {results[-1]}") 64 | print(f"Draws: {results[0]}") 65 | 66 | print(f"Playing {EVAL_GAMES} games against classic MCTS agent (starting second)") 67 | 68 | results = {0: 0, 1: 0, -1: 0} 69 | for _ in tqdm(range(EVAL_GAMES)): 70 | game.reset() 71 | result = pit( 72 | game, 73 | classic_mcts_agent, 74 | agent, 75 | classic_mcts_agent_play_kwargs, 76 | agent_play_kwargs, 77 | ) 78 | results[result] += 1 79 | 80 | print("Results:") 81 | print(f"Classic MCTS agent wins: {results[1]}") 82 | print(f"AlphaZero agent wins: {results[-1]}") 83 | print(f"Draws: {results[0]}") 84 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class LinearNetwork(nn.Module): 7 | def __init__(self, input_shape, action_space, first_layer_size=512, second_layer_size=256): 8 | super().__init__() 9 | self.first_layer = nn.Linear(input_shape[0], 
first_layer_size)
    self.second_layer = nn.Linear(first_layer_size, second_layer_size)
    self.value_head = nn.Linear(second_layer_size, 1)
    self.policy_head = nn.Linear(second_layer_size, action_space)

    self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    self.to(self.device)

  def __call__(self, observations):
    self.train()
    x = F.relu(self.first_layer(observations))
    x = F.relu(self.second_layer(x))
    value = F.tanh(self.value_head(x))
    log_policy = F.log_softmax(self.policy_head(x), dim=-1)
    return value, log_policy

  def value_forward(self, observation):
    self.eval()
    with torch.no_grad():
      x = F.relu(self.first_layer(observation))
      x = F.relu(self.second_layer(x))
      value = F.tanh(self.value_head(x))
    return value

  def policy_forward(self, observation):
    # unlike __call__, this returns probabilities rather than log-probabilities
    self.eval()
    with torch.no_grad():
      x = F.relu(self.first_layer(observation))
      x = F.relu(self.second_layer(x))
      policy = F.softmax(self.policy_head(x), dim=-1)
    return policy


class TicTacToe2DNetwork(nn.Module):
  def __init__(self, input_shape, action_space, first_linear_size=512, second_linear_size=256):
    super().__init__()
    self.conv1 = nn.Conv2d(1, 32, kernel_size=1)
    self.conv2 = nn.Conv2d(32, 32, kernel_size=1)
    self.conv3 = nn.Conv2d(32, 64, kernel_size=1)
    self.dropout = nn.Dropout2d(p=0.3)
    self.fc1 = nn.Linear(3 * 3 * 64, first_linear_size)
    self.fc2 = nn.Linear(first_linear_size, second_linear_size)
    self.value_head = nn.Linear(second_linear_size, 1)
    self.policy_head = nn.Linear(second_linear_size, action_space)

    self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    self.to(self.device)

  def __call__(self, observations):
    self.train()
    x = F.relu(self.conv1(observations))
    x = F.relu(self.conv2(x))
    x = F.relu(self.conv3(x))
    x = self.dropout(x)
    x = x.view(-1, 3 * 3 * 64)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    value = F.tanh(self.value_head(x))
    log_policy = F.log_softmax(self.policy_head(x), dim=-1)
    return value, log_policy

  def value_forward(self, observation):
    self.eval()
    with torch.no_grad():
      x = F.relu(self.conv1(observation))
      x = F.relu(self.conv2(x))
      x = F.relu(self.conv3(x))
      x = x.view(-1, 3 * 3 * 64)
      x = F.relu(self.fc1(x))
      x = F.relu(self.fc2(x))
      value = F.tanh(self.value_head(x))
    return value[0]

  def policy_forward(self, observation):
    # returns a probability distribution over actions for a single observation
    self.eval()
    with torch.no_grad():
      x = F.relu(self.conv1(observation))
      x = F.relu(self.conv2(x))
      x = F.relu(self.conv3(x))
      x = x.view(-1, 3 * 3 * 64)
      x = F.relu(self.fc1(x))
      x = F.relu(self.fc2(x))
      policy = F.softmax(self.policy_head(x), dim=-1)
    return policy[0]
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# tinyzero


Easily train AlphaZero-like agents on any environment you want!

## Usage
Make sure you have Python >= 3.8 installed. After that, run `pip install -r requirements.txt` to install the necessary dependencies.
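For example, a minimal setup from a fresh clone might look like this (the `.venv` name and the use of a virtual environment are just one option; any environment with the listed packages works):
```bash
python3 -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt
```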

Then, to train an agent on one of the existing environments, run:
```bash
python3 tictactoe/two_dim/train.py
```
where `tictactoe/two_dim` is the environment you want to train on.

Inside the training script, you can change parameters such as the number of self-play games and search iterations, and enable [wandb](https://wandb.ai/site) logging.

Similarly, to evaluate the trained agent, run:
```bash
python3 tictactoe/two_dim/eval.py
```

## Add an environment

To add a new environment, you can follow the `game.py` files in the existing examples.

The environment you add should implement the following methods:
- `reset()`: resets the environment to its initial state
- `step(action)`: takes an action and modifies the state of the environment accordingly
- `get_legal_actions()`: returns a list of legal actions
- `undo_last_action()`: undoes the last action taken
- `to_observation()`: returns the current state of the environment as an observation (a NumPy array) to be used as input to the model
- `get_result()`: returns the result of the game (for example, 1 if the first player won, -1 if the second player won, 0 if it's a draw, and None if the game is not over yet)
- `get_first_person_result()`: returns the result of the game from the perspective of the current player (for example, 1 if the current player won, -1 if the opponent won, 0 if it's a draw, and None if the game is not over yet)
- `swap_result(result)`: swaps the result of the game (for example, a result of 1 becomes -1, and vice versa). It's needed to cover all of the possible game types (single-player, two-player, zero-sum, non-zero-sum, etc.)

## Add a model

To add a new model, you can follow the existing examples in `models.py`.

The model you add should implement the following methods:
- `__call__(observations)`: takes as input a batch of observations and returns a value and a policy for each
- `value_forward(observation)`: takes as input a single observation and returns a value
- `policy_forward(observation)`: takes as input a single observation and returns a distribution over the actions (the policy)

The latter two methods are used to speed up the MCTS.

The AlphaZero agent computes the policy loss as the Kullback-Leibler divergence between the distribution produced by the model and the one given by the MCTS. Therefore, the policy returned by the `__call__` method should consist of log-probabilities. The policy returned by the `policy_forward` method, on the other hand, should represent a probability distribution.

## Add a new agent

Thanks to the way the value and policy functions are called by the search tree, it's possible to use or train any agent that implements them. To add a new agent, you can follow the existing examples in `agents.py`.

The agent you add should implement the following methods:
- `value_fn(game)`: takes as input a game and returns a value (a float)
- `policy_fn(game)`: takes as input a game and returns a policy (a NumPy array)

No other method is used directly by the MCTS, so anything else is optional and depends on the agent you want to implement. For example, `AlphaZeroAgent` is extended by the `AlphaZeroAgentTrainer` class, which adds methods to train the model after each self-play game.
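
As an illustration, here is a minimal sketch of such an agent (it is not part of the repository, and the class name and rollout count are made up): it scores positions by averaging a few random playouts and uses a uniform prior, which is all the search tree needs.

```python
import copy

import numpy as np


class RandomRolloutAgent:
  # hypothetical example agent: averages a few random playouts instead of querying a model
  def __init__(self, n_rollouts=8):
    self.n_rollouts = n_rollouts

  def value_fn(self, game):
    total = 0.0
    for _ in range(self.n_rollouts):
      playout = copy.deepcopy(game)
      steps = 0
      while (result := playout.get_first_person_result()) is None:
        playout.step(np.random.choice(playout.get_legal_actions()))
        steps += 1
      # get_first_person_result() is relative to the player to move at the end of the
      # playout, so flip it back to the player to move in the original position
      total += playout.swap_result(result) if steps % 2 == 1 else result
    return total / self.n_rollouts

  def policy_fn(self, game):
    # uniform prior over the whole action space, like ClassicMCTSAgent
    return np.ones(game.action_space) / game.action_space
```

An agent like this can be passed straight to `play` or `pit` from `mcts.py`, for example to benchmark a trained `AlphaZeroAgent` against it.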

## Train in Google Colab

To train in Google Colab, install `wandb` first:
```bash
!pip install wandb
```
Then clone the repository:
```bash
!git clone https://github.com/s-casci/tinyzero.git
```
Train on one of the environments (select a GPU runtime for faster training):
```bash
!cd tinyzero; python3 tictactoe/two_dim/train.py
```
And evaluate:
```bash
!cd tinyzero; python3 tictactoe/two_dim/eval.py
```
--------------------------------------------------------------------------------
/agents.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
import numpy as np
from replay_buffer import ReplayBuffer
import copy
from mcts import search


class ClassicMCTSAgent:
  @staticmethod
  def value_fn(game):
    # estimate the value of the position with a single random rollout to the end of the game
    game = copy.deepcopy(game)
    while (first_person_result := game.get_first_person_result()) is None:
      game.step(np.random.choice(game.get_legal_actions()))
    return first_person_result

  @staticmethod
  def policy_fn(game):
    return np.ones(game.action_space) / game.action_space


class AlphaZeroAgent:
  def __init__(self, model):
    self.model = model

  def value_fn(self, game):
    observation = torch.tensor(game.to_observation(), device=self.model.device, requires_grad=False)
    value = self.model.value_forward(observation)
    return value.item()

  def policy_fn(self, game):
    observation = torch.tensor(game.to_observation(), device=self.model.device, requires_grad=False)
    policy = self.model.policy_forward(observation)
    return policy.cpu().numpy()


class AlphaZeroAgentTrainer(AlphaZeroAgent):
  def __init__(self, model, optimizer, replay_buffer_max_size):
    super().__init__(model)
    self.optimizer = optimizer
    self.replay_buffer = ReplayBuffer(max_size=replay_buffer_max_size)

  def _selfplay(self, game, search_iterations, c_puct=1.0, dirichlet_alpha=None):
    buffer = []
    while (first_person_result := game.get_first_person_result()) is None:
      root_node = search(
        game, self.value_fn, self.policy_fn, search_iterations, c_puct=c_puct, dirichlet_alpha=dirichlet_alpha
      )
      visits_dist = root_node.children_visits / root_node.children_visits.sum()

      action = root_node.children_actions[np.random.choice(len(root_node.children), p=visits_dist)]

      actions_dist = np.zeros(game.action_space, dtype=np.float32)
      actions_dist[root_node.children_actions] = visits_dist
      buffer.append((game.to_observation(), actions_dist))

      game.step(action)

    return first_person_result, buffer

  def train_step(self, game, search_iterations, batch_size, epochs, c_puct=1.0, dirichlet_alpha=None):
    first_person_result, game_buffer = self._selfplay(
      game, search_iterations, c_puct=c_puct, dirichlet_alpha=dirichlet_alpha
    )

    result = game.swap_result(first_person_result)
    while len(game_buffer) > 0:
      observation, action_dist = game_buffer.pop()
      self.replay_buffer.add_sample(observation, action_dist, result)
      result = game.swap_result(result)

    values_losses, policies_losses = [], []
    if len(self.replay_buffer) >= batch_size:
      for _ in range(epochs):
        observations, actions_dist, results = self.replay_buffer.sample(batch_size)
        observations = torch.tensor(observations,
device=self.model.device) 77 | actions_dist = torch.tensor(actions_dist, device=self.model.device) 78 | results = torch.tensor(results, device=self.model.device) 79 | 80 | self.optimizer.zero_grad() 81 | values, log_policies = self.model(observations) 82 | 83 | # mean squared error 84 | values_loss = F.mse_loss(values.squeeze(1), results) 85 | # Kullback–Leibler divergence 86 | policies_loss = F.kl_div(log_policies, actions_dist, reduction="batchmean") 87 | 88 | (values_loss + policies_loss).backward() 89 | self.optimizer.step() 90 | 91 | values_losses.append(values_loss.item()) 92 | policies_losses.append(policies_loss.item()) 93 | 94 | return values_losses, policies_losses 95 | 96 | def save_training_state(self, model_out_path, optimizer_out_path): 97 | torch.save(self.model.state_dict(), model_out_path) 98 | torch.save(self.optimizer.state_dict(), optimizer_out_path) 99 | 100 | def load_training_state(self, model_out_path, optimizer_out_path): 101 | self.model.load_state_dict(torch.load(model_out_path)) 102 | self.optimizer.load_state_dict(torch.load(optimizer_out_path)) 103 | -------------------------------------------------------------------------------- /mcts.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | from numba import njit 4 | 5 | 6 | class RootNode: 7 | def __init__(self): 8 | self.parent = None 9 | self.visits = 0 10 | self.children = None 11 | 12 | 13 | class Node(RootNode): 14 | def __init__(self, idx, parent): 15 | self.idx = idx 16 | self.parent = parent 17 | self.children = None 18 | 19 | @property 20 | def visits(self): 21 | return self.parent.children_visits[self.idx] 22 | 23 | @visits.setter 24 | def visits(self, x): 25 | self.parent.children_visits[self.idx] = x 26 | 27 | @property 28 | def action(self): 29 | return self.parent.children_actions[self.idx] 30 | 31 | @property 32 | def value(self): 33 | return self.parent.children_values[self.idx] 34 | 35 | @value.setter 36 | def value(self, x): 37 | self.parent.children_values[self.idx] = x 38 | 39 | 40 | @njit(fastmath=True, parallel=True) 41 | def get_ucb_scores_jitted(children_values, children_priors, visits, children_visits, c_puct): 42 | return children_values + c_puct * children_priors * math.sqrt(visits) / (children_visits + 1) 43 | 44 | 45 | def get_ucb_scores(node, c_puct): 46 | return get_ucb_scores_jitted(node.children_values, node.children_priors, node.visits, node.children_visits, c_puct) 47 | 48 | 49 | def select(root, game, c_puct): 50 | current = root 51 | while current.children: 52 | ucb_scores = get_ucb_scores(current, c_puct) 53 | # every child needs at least 1 visit 54 | ucb_scores[current.children_visits == 0] = np.inf 55 | current = current.children[np.argmax(ucb_scores)] 56 | game.step(current.action) 57 | return current 58 | 59 | 60 | def expand(leaf, children_actions, children_priors): 61 | leaf.children = [Node(idx, leaf) for idx, _ in enumerate(children_actions)] 62 | leaf.children_actions = children_actions 63 | leaf.children_priors = children_priors 64 | leaf.children_values = np.zeros_like(leaf.children_priors) 65 | leaf.children_visits = np.zeros_like(leaf.children_priors) 66 | 67 | 68 | def backpropagate(leaf, game, result): 69 | current = leaf 70 | while current.parent: 71 | # different games might have different result representations 72 | result = game.swap_result(result) 73 | # incremental mean update 74 | current.value = (current.value * current.visits + result) / (current.visits + 1) 75 | current.visits 
+= 1 76 | current = current.parent 77 | game.undo_last_action() 78 | current.visits += 1 79 | 80 | 81 | def search(game, value_fn, policy_fn, iterations, c_puct=1.0, dirichlet_alpha=None): 82 | root = RootNode() 83 | # expand the root so that there's no need to check if it's necessary to add dirichlet noise 84 | # at every iteration of the search loop 85 | children_actions = game.get_legal_actions() 86 | children_priors = policy_fn(game)[children_actions] 87 | if dirichlet_alpha: 88 | children_priors = 0.75 * children_priors + 0.25 * np.random.default_rng().dirichlet( 89 | dirichlet_alpha * np.ones_like(children_priors) 90 | ) 91 | expand(root, game.get_legal_actions(), children_priors) 92 | 93 | for _ in range(iterations): 94 | leaf = select(root, game, c_puct) 95 | result = game.get_first_person_result() 96 | if result is None: 97 | children_actions = game.get_legal_actions() 98 | children_priors = policy_fn(game)[children_actions] 99 | expand(leaf, children_actions, children_priors) 100 | result = value_fn(game) 101 | backpropagate(leaf, game, result) 102 | return root 103 | 104 | 105 | def play(game, agent, search_iterations, c_puct=1.0, dirichlet_alpha=None): 106 | root = search( 107 | game, agent.value_fn, agent.policy_fn, search_iterations, c_puct=c_puct, dirichlet_alpha=dirichlet_alpha 108 | ) 109 | return root.children_actions[np.argmax(root.children_visits)] 110 | 111 | 112 | def pit(game, agent1, agent2, agent1_play_kwargs, agent2_play_kwargs): 113 | current_agent, other_agent = agent1, agent2 114 | current_agent_play_kwargs, other_agent_play_kwargs = agent1_play_kwargs, agent2_play_kwargs 115 | while (result := game.get_result()) is None: 116 | action = play(game, current_agent, **current_agent_play_kwargs) 117 | game.step(action) 118 | current_agent, other_agent = other_agent, current_agent 119 | current_agent_play_kwargs, other_agent_play_kwargs = other_agent_play_kwargs, current_agent_play_kwargs 120 | return result 121 | --------------------------------------------------------------------------------
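
To see how the pieces fit together end to end, here is a small, self-contained sketch (not part of the repository) that pits an untrained `AlphaZeroAgent` against the classic MCTS agent on Connect2. It assumes it is run from the repository root, and the search settings are illustrative values borrowed from the eval scripts:

```python
import os
import sys

sys.path.append(os.getcwd())

from agents import AlphaZeroAgent, ClassicMCTSAgent  # noqa: E402
from connect2.game import Connect2  # noqa: E402
from mcts import pit  # noqa: E402
from models import LinearNetwork  # noqa: E402

game = Connect2()
model = LinearNetwork(game.observation_shape, game.action_space)  # randomly initialized weights

alphazero_agent = AlphaZeroAgent(model)
alphazero_kwargs = {"search_iterations": 16, "c_puct": 1.5, "dirichlet_alpha": None}
classic_kwargs = {"search_iterations": 100, "c_puct": 1.0, "dirichlet_alpha": None}

# pit plays a single game: 1 means the first agent won, -1 the second, 0 a draw
result = pit(game, alphazero_agent, ClassicMCTSAgent, alphazero_kwargs, classic_kwargs)
print(f"Result: {result}")
```

Loading the checkpoint from `connect2/out/model.pth` with `model.load_state_dict(...)`, as in `connect2/eval.py`, turns the same sketch into an evaluation of the trained agent.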