├── .gitignore ├── playout400.gif ├── playout800.gif ├── best_policy_6_6_4.model ├── best_policy_6_6_4.model2 ├── best_policy_8_8_5.model ├── best_policy_8_8_5.model2 ├── LICENSE ├── README.md ├── human_play.py ├── policy_value_net.py ├── policy_value_net_numpy.py ├── train.py ├── mcts_pure.py ├── mcts_alphaZero.py └── game.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | __pycache__/ 3 | -------------------------------------------------------------------------------- /playout400.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cafe/AlphaZero_Gomoku/master/playout400.gif -------------------------------------------------------------------------------- /playout800.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cafe/AlphaZero_Gomoku/master/playout800.gif -------------------------------------------------------------------------------- /best_policy_6_6_4.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cafe/AlphaZero_Gomoku/master/best_policy_6_6_4.model -------------------------------------------------------------------------------- /best_policy_6_6_4.model2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cafe/AlphaZero_Gomoku/master/best_policy_6_6_4.model2 -------------------------------------------------------------------------------- /best_policy_8_8_5.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cafe/AlphaZero_Gomoku/master/best_policy_8_8_5.model -------------------------------------------------------------------------------- /best_policy_8_8_5.model2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cafe/AlphaZero_Gomoku/master/best_policy_8_8_5.model2 -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 junxiaosong 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## AlphaZero-Gomoku 2 | This is an implementation of the AlphaZero algorithm for the simple board game Gomoku (also known as Gobang or Five in a Row), trained purely through self-play. Gomoku is much simpler than Go or chess, so we can focus on the AlphaZero training scheme and still obtain a fairly strong AI model on a single PC within a few hours. 3 | 4 | References: 5 | 1. AlphaZero: Mastering Chess and Shogi by Self-Play with a General Reinforcement Learning Algorithm 6 | 2. AlphaGo Zero: Mastering the game of Go without human knowledge 7 | 8 | ### Example Games Between Trained Models 9 | - Each move with 400 playouts: 10 | ![playout400](https://raw.githubusercontent.com/junxiaosong/AlphaZero_Gomoku/master/playout400.gif) 11 | - Each move with 800 playouts: 12 | ![playout800](https://raw.githubusercontent.com/junxiaosong/AlphaZero_Gomoku/master/playout800.gif) 13 | 14 | ### Requirements 15 | To play against the trained AI models, you only need: 16 | - Python >= 2.7 17 | - Numpy >= 1.11 18 | 19 | To train an AI model from scratch, you additionally need: 20 | - Theano >= 0.7 21 | - Lasagne >= 0.1 22 | 23 | **PS**: if your Theano version is above 0.7, please follow this [issue](https://github.com/aigamedev/scikit-neuralnetwork/issues/235) to install Lasagne; 24 | otherwise, force pip to downgrade Theano to 0.7 with ``pip install --upgrade theano==0.7.0`` 25 | 26 | If you would like to train the model with another DL framework, such as TensorFlow or PyTorch, you only need to rewrite policy_value_net.py (a minimal PyTorch sketch is given at the end of this README). 27 | 28 | ### Getting Started 29 | To play against the provided models, run the following script from the repository directory: 30 | > python human_play.py 31 | 32 | You may modify human_play.py to try a different provided model or the pure MCTS player (see the example below the training tips). 33 | 34 | To train an AI model from scratch, run: 35 | > python train.py 36 | 37 | The models (best_policy.model and current_policy.model) will be saved every few updates (every 50 self-play batches by default). 38 | 39 | **Tips for training:** 40 | 1. It is best to start with a 6 * 6 board and 4 in a row. For this case, you can usually obtain a reasonably good model within 500~1000 self-play games, in about 2 hours. 41 | 2. For an 8 * 8 board with 5 in a row, it may take 2000~3000 self-play games to get a good model, which is roughly 2 days on a single PC.
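### Trying a Different Provided Model
For example, to play against the provided 6 * 6 model instead of the default 8 * 8 one, only a few values near the top of ``run()`` in human_play.py need to change. A sketch of the relevant lines (the board size, win condition, and model file must match each other):

```python
# inside run() in human_play.py
n = 4                                   # win condition: 4 in a row
width, height = 6, 6                    # 6 * 6 board
model_file = 'best_policy_6_6_4.model'  # provided model trained for this setting

# or, to face the (much weaker) pure MCTS opponent instead of the trained net:
# mcts_player = MCTS_Pure(c_puct=5, n_playout=1000)
```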
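### Porting the Network to PyTorch
The policy-value network itself is small: three shared 3x3 conv layers, a policy head (1x1 conv plus a fully connected softmax over board positions), and a value head (1x1 conv plus two fully connected layers with a tanh output). Below is a minimal PyTorch sketch of that architecture; the class names are illustrative (not part of this repo), the training step is omitted, and only the ``policy_value_fn`` interface that MCTSPlayer expects is shown:

```python
# policy_value_net_pytorch_sketch.py -- illustrative only, not a file in this repo
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    """Same layer layout as policy_value_net.py, written in PyTorch."""
    def __init__(self, board_width, board_height):
        super(Net, self).__init__()
        self.board_width = board_width
        self.board_height = board_height
        # shared conv layers
        self.conv1 = nn.Conv2d(4, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        # action policy head
        self.act_conv1 = nn.Conv2d(128, 4, kernel_size=1)
        self.act_fc1 = nn.Linear(4 * board_width * board_height,
                                 board_width * board_height)
        # state value head
        self.val_conv1 = nn.Conv2d(128, 2, kernel_size=1)
        self.val_fc1 = nn.Linear(2 * board_width * board_height, 64)
        self.val_fc2 = nn.Linear(64, 1)

    def forward(self, state_input):
        x = F.relu(self.conv1(state_input))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        # policy head: a probability for every board position
        x_act = F.relu(self.act_conv1(x))
        x_act = x_act.view(-1, 4 * self.board_width * self.board_height)
        x_act = F.softmax(self.act_fc1(x_act), dim=1)
        # value head: scalar score in [-1, 1]
        x_val = F.relu(self.val_conv1(x))
        x_val = x_val.view(-1, 2 * self.board_width * self.board_height)
        x_val = F.relu(self.val_fc1(x_val))
        x_val = torch.tanh(self.val_fc2(x_val))
        return x_act, x_val


class PolicyValueNetPytorch(object):
    """Exposes the policy_value_fn interface used by mcts_alphaZero.MCTSPlayer."""
    def __init__(self, board_width, board_height):
        self.board_width = board_width
        self.board_height = board_height
        self.net = Net(board_width, board_height)

    def policy_value_fn(self, board):
        legal_positions = board.availables
        current_state = board.current_state().reshape(
            -1, 4, self.board_width, self.board_height)
        with torch.no_grad():
            probs, value = self.net(torch.from_numpy(current_state).float())
        probs = probs.numpy().flatten()
        act_probs = zip(legal_positions, probs[legal_positions])
        return act_probs, value.item()
```

Training such a network still requires implementing the objective (value MSE plus policy cross-entropy plus L2 regularization, as computed in policy_value_net.py) and an optimizer step, but MCTSPlayer can consume ``policy_value_fn`` unchanged.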
42 | 43 | -------------------------------------------------------------------------------- /human_play.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | human VS AI models 4 | Input your move in the format: 2,3 5 | 6 | @author: Junxiao Song 7 | """ 8 | 9 | from __future__ import print_function 10 | from game import Board, Game 11 | # from policy_value_net import PolicyValueNet 12 | from policy_value_net_numpy import PolicyValueNetNumpy 13 | from mcts_pure import MCTSPlayer as MCTS_Pure 14 | from mcts_alphaZero import MCTSPlayer 15 | # import cPickle as pickle 16 | import pickle 17 | 18 | class Human(object): 19 | """ 20 | human player 21 | """ 22 | 23 | def __init__(self): 24 | self.player = None 25 | 26 | def set_player_ind(self, p): 27 | self.player = p 28 | 29 | def get_action(self, board): 30 | try: 31 | location = input("Your move: ") 32 | if isinstance(location, str): 33 | location = [int(n, 10) for n in location.split(",")] # for python3 34 | move = board.location_to_move(location) 35 | except Exception as e: 36 | move = -1 37 | if move == -1 or move not in board.availables: 38 | print("invalid move") 39 | move = self.get_action(board) 40 | return move 41 | 42 | def __str__(self): 43 | return "Human {}".format(self.player) 44 | 45 | 46 | def run(): 47 | n = 5 48 | width, height = 8, 8 49 | model_file = 'best_policy_8_8_5.model' 50 | try: 51 | board = Board(width=width, height=height, n_in_row=n) 52 | game = Game(board) 53 | 54 | ################ human VS AI ################### 55 | # MCTS player with the policy_value_net trained by AlphaZero algorithm 56 | # policy_param = pickle.load(open(model_file, 'rb')) 57 | # best_policy = PolicyValueNet(width, height, net_params = policy_param) 58 | # mcts_player = MCTSPlayer(best_policy.policy_value_fn, c_puct=5, n_playout=400) 59 | 60 | # MCTS player with the trained policy_value_net written in pure numpy 61 | try: 62 | policy_param = pickle.load(open(model_file, 'rb')) 63 | except: 64 | policy_param = pickle.load(open(model_file, 'rb'), encoding = 'bytes') # To support python3 65 | best_policy = PolicyValueNetNumpy(width, height, policy_param) 66 | mcts_player = MCTSPlayer(best_policy.policy_value_fn, c_puct=5, n_playout=400) # set larger n_playout for better performance 67 | 68 | # uncomment the following line to play with pure MCTS (its much weaker even with a larger n_playout) 69 | # mcts_player = MCTS_Pure(c_puct=5, n_playout=1000) 70 | 71 | # human player, input your move in the format: 2,3 72 | human = Human() 73 | 74 | # set start_player=0 for human first 75 | game.start_play(human, mcts_player, start_player=1, is_shown=1) 76 | except KeyboardInterrupt: 77 | print('\n\rquit') 78 | 79 | if __name__ == '__main__': 80 | run() 81 | 82 | 83 | -------------------------------------------------------------------------------- /policy_value_net.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author: Junxiao Song 4 | """ 5 | 6 | from __future__ import print_function 7 | import theano 8 | import theano.tensor as T 9 | import lasagne 10 | 11 | class PolicyValueNet(): 12 | """policy-value network """ 13 | def __init__(self, board_width, board_height, net_params=None): 14 | self.board_width = board_width 15 | self.board_height = board_height 16 | self.learning_rate = T.scalar('learning_rate') 17 | self.l2_const = 1e-4 # coef of l2 penalty 18 | self.create_policy_value_net() 19 | 
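        # build the training objective (value MSE + policy cross-entropy + L2 penalty) and the Adam train_step function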
self._loss_train_op() 20 | if net_params: 21 | lasagne.layers.set_all_param_values([self.policy_net, self.value_net], net_params) 22 | 23 | def create_policy_value_net(self): 24 | """create the policy value network """ 25 | self.state_input = T.tensor4('state') 26 | self.winner = T.vector('winner') 27 | self.mcts_probs = T.matrix('mcts_probs') 28 | network = lasagne.layers.InputLayer(shape=(None, 4, self.board_width, self.board_height), 29 | input_var=self.state_input) 30 | # conv layers 31 | network = lasagne.layers.Conv2DLayer(network, num_filters=32, filter_size=(3, 3), pad='same') 32 | network = lasagne.layers.Conv2DLayer(network, num_filters=64, filter_size=(3, 3), pad='same') 33 | network = lasagne.layers.Conv2DLayer(network, num_filters=128, filter_size=(3, 3), pad='same') 34 | # action policy layers 35 | policy_net = lasagne.layers.Conv2DLayer(network, num_filters=4, filter_size=(1, 1)) 36 | self.policy_net = lasagne.layers.DenseLayer(policy_net, num_units=self.board_width*self.board_height, 37 | nonlinearity=lasagne.nonlinearities.softmax) 38 | # state value layers 39 | value_net = lasagne.layers.Conv2DLayer(network, num_filters=2, filter_size=(1, 1)) 40 | value_net = lasagne.layers.DenseLayer(value_net, num_units=64) 41 | self.value_net = lasagne.layers.DenseLayer(value_net, num_units=1, nonlinearity=lasagne.nonlinearities.tanh) 42 | # get action probs and state score value 43 | self.action_probs, self.value = lasagne.layers.get_output([self.policy_net, self.value_net]) 44 | self.policy_value = theano.function([self.state_input], [self.action_probs, self.value] ,allow_input_downcast=True) 45 | 46 | def policy_value_fn(self, board): 47 | """ 48 | input: board 49 | output: a list of (action, probability) tuples for each available action and the score of the board state 50 | """ 51 | legal_positions = board.availables 52 | current_state = board.current_state() 53 | act_probs, value = self.policy_value(current_state.reshape(-1, 4, self.board_width, self.board_height)) 54 | act_probs = zip(legal_positions, act_probs.flatten()[legal_positions]) 55 | return act_probs, value[0][0] 56 | 57 | def _loss_train_op(self): 58 | """ 59 | Three loss terms: 60 | loss = (z - v)^2 + pi^T * log(p) + c||theta||^2 61 | """ 62 | params = lasagne.layers.get_all_params([self.policy_net, self.value_net], trainable=True) 63 | value_loss = lasagne.objectives.squared_error(self.winner, self.value.flatten()) 64 | policy_loss = lasagne.objectives.categorical_crossentropy(self.action_probs, self.mcts_probs) 65 | l2_penalty = lasagne.regularization.apply_penalty(params, lasagne.regularization.l2) 66 | self.loss = lasagne.objectives.aggregate(value_loss + policy_loss, mode='mean') + self.l2_const*l2_penalty 67 | # policy entropy,for monitoring only 68 | self.entropy = -T.mean(T.sum(self.action_probs * T.log(self.action_probs + 1e-10), axis=1)) 69 | # get the train op 70 | updates = lasagne.updates.adam(self.loss, params, learning_rate=self.learning_rate) 71 | self.train_step = theano.function([self.state_input, self.mcts_probs, self.winner, self.learning_rate], 72 | [self.loss, self.entropy], updates=updates, allow_input_downcast=True) 73 | 74 | def get_policy_param(self): 75 | net_params = lasagne.layers.get_all_param_values([self.policy_net, self.value_net]) 76 | return net_params 77 | -------------------------------------------------------------------------------- /policy_value_net_numpy.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | 
Implement the policy value network using numpy, so that we can play with the 4 | trained AI model without installing any DL framwork 5 | 6 | @author: Junxiao Song 7 | """ 8 | 9 | from __future__ import print_function 10 | import numpy as np 11 | 12 | # some utility functions 13 | def softmax(x): 14 | probs = np.exp(x - np.max(x)) 15 | probs /= np.sum(probs) 16 | return probs 17 | 18 | def relu(X): 19 | out = np.maximum(X, 0) 20 | return out 21 | 22 | def conv_forward(X, W, b, stride=1, padding=1): 23 | n_filters, d_filter, h_filter, w_filter = W.shape 24 | W = W[:,:,::-1,::-1] # theano conv2d flips the filters (rotate 180 degree) first while doing the calculation 25 | n_x, d_x, h_x, w_x = X.shape 26 | h_out = (h_x - h_filter + 2 * padding) / stride + 1 27 | w_out = (w_x - w_filter + 2 * padding) / stride + 1 28 | h_out, w_out = int(h_out), int(w_out) 29 | X_col = im2col_indices(X, h_filter, w_filter, padding=padding, stride=stride) 30 | W_col = W.reshape(n_filters, -1) 31 | out = (np.dot(W_col, X_col).T + b).T 32 | out = out.reshape(n_filters, h_out, w_out, n_x) 33 | out = out.transpose(3, 0, 1, 2) 34 | return out 35 | 36 | def fc_forward(X, W, b): 37 | out = np.dot(X, W) + b 38 | return out 39 | 40 | def get_im2col_indices(x_shape, field_height, field_width, padding=1, stride=1): 41 | # First figure out what the size of the output should be 42 | N, C, H, W = x_shape 43 | assert (H + 2 * padding - field_height) % stride == 0 44 | assert (W + 2 * padding - field_height) % stride == 0 45 | out_height = int((H + 2 * padding - field_height) / stride + 1) 46 | out_width = int((W + 2 * padding - field_width) / stride + 1) 47 | 48 | i0 = np.repeat(np.arange(field_height), field_width) 49 | i0 = np.tile(i0, C) 50 | i1 = stride * np.repeat(np.arange(out_height), out_width) 51 | j0 = np.tile(np.arange(field_width), field_height * C) 52 | j1 = stride * np.tile(np.arange(out_width), out_height) 53 | i = i0.reshape(-1, 1) + i1.reshape(1, -1) 54 | j = j0.reshape(-1, 1) + j1.reshape(1, -1) 55 | 56 | k = np.repeat(np.arange(C), field_height * field_width).reshape(-1, 1) 57 | 58 | return (k.astype(int), i.astype(int), j.astype(int)) 59 | 60 | 61 | def im2col_indices(x, field_height, field_width, padding=1, stride=1): 62 | """ An implementation of im2col based on some fancy indexing """ 63 | # Zero-pad the input 64 | p = padding 65 | x_padded = np.pad(x, ((0, 0), (0, 0), (p, p), (p, p)), mode='constant') 66 | 67 | k, i, j = get_im2col_indices(x.shape, field_height, field_width, padding, stride) 68 | 69 | cols = x_padded[:, k, i, j] 70 | C = x.shape[1] 71 | cols = cols.transpose(1, 2, 0).reshape(field_height * field_width * C, -1) 72 | return cols 73 | 74 | class PolicyValueNetNumpy(): 75 | """policy-value network in numpy """ 76 | def __init__(self, board_width, board_height, net_params): 77 | self.board_width = board_width 78 | self.board_height = board_height 79 | self.params = net_params 80 | 81 | def policy_value_fn(self, board): 82 | """ 83 | input: board 84 | output: a list of (action, probability) tuples for each available action and the score of the board state 85 | """ 86 | legal_positions = board.availables 87 | current_state = board.current_state() 88 | 89 | X = current_state.reshape(-1, 4, self.board_width, self.board_height) 90 | # first 3 conv layers with ReLu nonlinearity 91 | for i in [0,2,4]: 92 | X = relu(conv_forward(X, self.params[i], self.params[i+1])) 93 | # policy head 94 | X_p = relu(conv_forward(X, self.params[6], self.params[7], padding=0)) 95 | X_p = fc_forward(X_p.flatten(), 
self.params[8], self.params[9]) 96 | act_probs = softmax(X_p) 97 | # value head 98 | X_v = relu(conv_forward(X, self.params[10], self.params[11], padding=0)) 99 | X_v = relu(fc_forward(X_v.flatten(), self.params[12], self.params[13])) 100 | value = np.tanh(fc_forward(X_v, self.params[14], self.params[15]))[0] 101 | 102 | act_probs = zip(legal_positions, act_probs.flatten()[legal_positions]) 103 | return act_probs, value -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author: Junxiao Song 4 | """ 5 | 6 | from __future__ import print_function 7 | import random 8 | import numpy as np 9 | import cPickle as pickle 10 | from collections import defaultdict, deque 11 | from game import Board, Game 12 | from policy_value_net import PolicyValueNet 13 | from mcts_pure import MCTSPlayer as MCTS_Pure 14 | from mcts_alphaZero import MCTSPlayer 15 | 16 | 17 | class TrainPipeline(): 18 | def __init__(self): 19 | # params of the board and the game 20 | self.board_width = 6 21 | self.board_height = 6 22 | self.n_in_row = 4 23 | self.board = Board(width=self.board_width, height=self.board_height, n_in_row=self.n_in_row) 24 | self.game = Game(self.board) 25 | # training params 26 | self.learn_rate = 5e-3 27 | self.lr_multiplier = 1.0 # adaptively adjust the learning rate based on KL 28 | self.temp = 1.0 # the temperature param 29 | self.n_playout = 400 # num of simulations for each move 30 | self.c_puct = 5 31 | self.buffer_size = 10000 32 | self.batch_size = 512 # mini-batch size for training 33 | self.data_buffer = deque(maxlen=self.buffer_size) 34 | self.play_batch_size = 1 35 | self.epochs = 5 # num of train_steps for each update 36 | self.kl_targ = 0.025 37 | self.check_freq = 50 38 | self.game_batch_num = 1500 39 | self.best_win_ratio = 0.0 40 | # num of simulations used for the pure mcts, which is used as the opponent to evaluate the trained policy 41 | self.pure_mcts_playout_num = 1000 42 | # start training from a given policy-value net 43 | # policy_param = pickle.load(open('current_policy.model', 'rb')) 44 | # self.policy_value_net = PolicyValueNet(self.board_width, self.board_height, net_params = policy_param) 45 | # start training from a new policy-value net 46 | self.policy_value_net = PolicyValueNet(self.board_width, self.board_height) 47 | self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn, c_puct=self.c_puct, n_playout=self.n_playout, is_selfplay=1) 48 | 49 | def get_equi_data(self, play_data): 50 | """ 51 | augment the data set by rotation and flipping 52 | play_data: [(state, mcts_prob, winner_z), ..., ...]""" 53 | extend_data = [] 54 | for state, mcts_porb, winner in play_data: 55 | for i in [1,2,3,4]: 56 | # rotate counterclockwise 57 | equi_state = np.array([np.rot90(s,i) for s in state]) 58 | equi_mcts_prob = np.rot90(np.flipud(mcts_porb.reshape(self.board_height, self.board_width)), i) 59 | extend_data.append((equi_state, np.flipud(equi_mcts_prob).flatten(), winner)) 60 | # flip horizontally 61 | equi_state = np.array([np.fliplr(s) for s in equi_state]) 62 | equi_mcts_prob = np.fliplr(equi_mcts_prob) 63 | extend_data.append((equi_state, np.flipud(equi_mcts_prob).flatten(), winner)) 64 | return extend_data 65 | 66 | def collect_selfplay_data(self, n_games=1): 67 | """collect self-play data for training""" 68 | for i in range(n_games): 69 | winner, play_data = self.game.start_self_play(self.mcts_player, 
temp=self.temp) 70 | self.episode_len = len(play_data) 71 | # augment the data 72 | play_data = self.get_equi_data(play_data) 73 | self.data_buffer.extend(play_data) 74 | 75 | def policy_update(self): 76 | """update the policy-value net""" 77 | mini_batch = random.sample(self.data_buffer, self.batch_size) 78 | state_batch = [data[0] for data in mini_batch] 79 | mcts_probs_batch = [data[1] for data in mini_batch] 80 | winner_batch = [data[2] for data in mini_batch] 81 | old_probs, old_v = self.policy_value_net.policy_value(state_batch) 82 | for i in range(self.epochs): 83 | loss, entropy = self.policy_value_net.train_step(state_batch, mcts_probs_batch, winner_batch, self.learn_rate*self.lr_multiplier) 84 | new_probs, new_v = self.policy_value_net.policy_value(state_batch) 85 | kl = np.mean(np.sum(old_probs * (np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)), axis=1)) 86 | if kl > self.kl_targ * 4: # early stopping if D_KL diverges badly 87 | break 88 | # adaptively adjust the learning rate 89 | if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1: 90 | self.lr_multiplier /= 1.5 91 | elif kl < self.kl_targ / 2 and self.lr_multiplier < 10: 92 | self.lr_multiplier *= 1.5 93 | 94 | explained_var_old = 1 - np.var(np.array(winner_batch) - old_v.flatten())/np.var(np.array(winner_batch)) 95 | explained_var_new = 1 - np.var(np.array(winner_batch) - new_v.flatten())/np.var(np.array(winner_batch)) 96 | print("kl:{:.5f},lr_multiplier:{:.3f},loss:{},entropy:{},explained_var_old:{:.3f},explained_var_new:{:.3f}".format( 97 | kl, self.lr_multiplier, loss, entropy, explained_var_old, explained_var_new)) 98 | return loss, entropy 99 | 100 | def policy_evaluate(self, n_games=10): 101 | """ 102 | Evaluate the trained policy by playing games against the pure MCTS player 103 | Note: this is only for monitoring the progress of training 104 | """ 105 | current_mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn, c_puct=self.c_puct, n_playout=self.n_playout) 106 | pure_mcts_player = MCTS_Pure(c_puct=5, n_playout=self.pure_mcts_playout_num) 107 | win_cnt = defaultdict(int) 108 | for i in range(n_games): 109 | winner = self.game.start_play(current_mcts_player, pure_mcts_player, start_player=i%2, is_shown=0) 110 | win_cnt[winner] += 1 111 | win_ratio = 1.0*(win_cnt[1] + 0.5*win_cnt[-1])/n_games 112 | print("num_playouts:{}, win: {}, lose: {}, tie:{}".format(self.pure_mcts_playout_num, win_cnt[1], win_cnt[2], win_cnt[-1])) 113 | return win_ratio 114 | 115 | def run(self): 116 | """run the training pipeline""" 117 | try: 118 | for i in range(self.game_batch_num): 119 | self.collect_selfplay_data(self.play_batch_size) 120 | print("batch i:{}, episode_len:{}".format(i+1, self.episode_len)) 121 | if len(self.data_buffer) > self.batch_size: 122 | loss, entropy = self.policy_update() 123 | # check the performance of the current model,and save the model params 124 | if (i+1) % self.check_freq == 0: 125 | print("current self-play batch: {}".format(i+1)) 126 | win_ratio = self.policy_evaluate() 127 | net_params = self.policy_value_net.get_policy_param() # get model params 128 | pickle.dump(net_params, open('current_policy.model', 'wb'), pickle.HIGHEST_PROTOCOL) # save model param to file 129 | if win_ratio > self.best_win_ratio: 130 | print("New best policy!!!!!!!!") 131 | self.best_win_ratio = win_ratio 132 | pickle.dump(net_params, open('best_policy.model', 'wb'), pickle.HIGHEST_PROTOCOL) # update the best_policy 133 | if self.best_win_ratio == 1.0 and self.pure_mcts_playout_num < 5000: 134 | 
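                        # the current policy won every evaluation game against the pure-MCTS
                        # opponent, so give the opponent more playouts and reset the benchmark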
self.pure_mcts_playout_num += 1000 135 | self.best_win_ratio = 0.0 136 | except KeyboardInterrupt: 137 | print('\n\rquit') 138 | 139 | 140 | if __name__ == '__main__': 141 | training_pipeline = TrainPipeline() 142 | training_pipeline.run() 143 | -------------------------------------------------------------------------------- /mcts_pure.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | A pure implementation of the Monte Carlo Tree Search (MCTS) 4 | 5 | @author: Junxiao Song 6 | """ 7 | import numpy as np 8 | import copy 9 | from operator import itemgetter 10 | 11 | def rollout_policy_fn(board): 12 | """rollout_policy_fn -- a coarse, fast version of policy_fn used in the rollout phase.""" 13 | # rollout randomly 14 | action_probs = np.random.rand(len(board.availables)) 15 | return zip(board.availables, action_probs) 16 | 17 | def policy_value_fn(board): 18 | """a function that takes in a state and outputs a list of (action, probability) 19 | tuples and a score for the state""" 20 | # return uniform probabilities and 0 score for pure MCTS 21 | action_probs = np.ones(len(board.availables))/len(board.availables) 22 | return zip(board.availables, action_probs), 0 23 | 24 | class TreeNode(object): 25 | """A node in the MCTS tree. Each node keeps track of its own value Q, prior probability P, and 26 | its visit-count-adjusted prior score u. 27 | """ 28 | 29 | def __init__(self, parent, prior_p): 30 | self._parent = parent 31 | self._children = {} # a map from action to TreeNode 32 | self._n_visits = 0 33 | self._Q = 0 34 | self._u = 0 35 | self._P = prior_p 36 | 37 | def expand(self, action_priors): 38 | """Expand tree by creating new children. 39 | action_priors -- output from policy function - a list of tuples of actions 40 | and their prior probability according to the policy function. 41 | """ 42 | for action, prob in action_priors: 43 | if action not in self._children: 44 | self._children[action] = TreeNode(self, prob) 45 | 46 | def select(self, c_puct): 47 | """Select action among children that gives maximum action value, Q plus bonus u(P). 48 | Returns: 49 | A tuple of (action, next_node) 50 | """ 51 | return max(self._children.iteritems(), key=lambda act_node: act_node[1].get_value(c_puct)) 52 | 53 | def update(self, leaf_value): 54 | """Update node values from leaf evaluation. 55 | Arguments: 56 | leaf_value -- the value of subtree evaluation from the current player's perspective. 57 | """ 58 | # Count visit. 59 | self._n_visits += 1 60 | # Update Q, a running average of values for all visits. 61 | self._Q += 1.0*(leaf_value - self._Q) / self._n_visits 62 | 63 | def update_recursive(self, leaf_value): 64 | """Like a call to update(), but applied recursively for all ancestors. 65 | """ 66 | # If it is not root, this node's parent should be updated first. 67 | if self._parent: 68 | self._parent.update_recursive(-leaf_value) 69 | self.update(leaf_value) 70 | 71 | def get_value(self, c_puct): 72 | """Calculate and return the value for this node: a combination of leaf evaluations, Q, and 73 | this node's prior adjusted for its visit count, u 74 | c_puct -- a number in (0, inf) controlling the relative impact of values, Q, and 75 | prior probability, P, on this node's score. 76 | """ 77 | self._u = c_puct * self._P * np.sqrt(self._parent._n_visits) / (1 + self._n_visits) 78 | return self._Q + self._u 79 | 80 | def is_leaf(self): 81 | """Check if leaf node (i.e. no nodes below this have been expanded). 
82 | """ 83 | return self._children == {} 84 | 85 | def is_root(self): 86 | return self._parent is None 87 | 88 | 89 | class MCTS(object): 90 | """A simple implementation of Monte Carlo Tree Search. 91 | """ 92 | 93 | def __init__(self, policy_value_fn, c_puct=5, n_playout=10000): 94 | """Arguments: 95 | policy_value_fn -- a function that takes in a board state and outputs a list of (action, probability) 96 | tuples and also a score in [-1, 1] (i.e. the expected value of the end game score from 97 | the current player's perspective) for the current player. 98 | c_puct -- a number in (0, inf) that controls how quickly exploration converges to the 99 | maximum-value policy, where a higher value means relying on the prior more 100 | """ 101 | self._root = TreeNode(None, 1.0) 102 | self._policy = policy_value_fn 103 | self._c_puct = c_puct 104 | self._n_playout = n_playout 105 | 106 | def _playout(self, state): 107 | """Run a single playout from the root to the leaf, getting a value at the leaf and 108 | propagating it back through its parents. State is modified in-place, so a copy must be 109 | provided. 110 | Arguments: 111 | state -- a copy of the state. 112 | """ 113 | node = self._root 114 | while(1): 115 | if node.is_leaf(): 116 | 117 | break 118 | # Greedily select next move. 119 | action, node = node.select(self._c_puct) 120 | state.do_move(action) 121 | 122 | action_probs, _ = self._policy(state) 123 | # Check for end of game 124 | end, winner = state.game_end() 125 | if not end: 126 | node.expand(action_probs) 127 | # Evaluate the leaf node by random rollout 128 | leaf_value = self._evaluate_rollout(state) 129 | # Update value and visit count of nodes in this traversal. 130 | node.update_recursive(-leaf_value) 131 | 132 | def _evaluate_rollout(self, state, limit=1000): 133 | """Use the rollout policy to play until the end of the game, returning +1 if the current 134 | player wins, -1 if the opponent wins, and 0 if it is a tie. 135 | """ 136 | player = state.get_current_player() 137 | for i in range(limit): 138 | end, winner = state.game_end() 139 | if end: 140 | break 141 | action_probs = rollout_policy_fn(state) 142 | max_action = max(action_probs, key=itemgetter(1))[0] 143 | state.do_move(max_action) 144 | else: 145 | # If no break from the loop, issue a warning. 146 | print("WARNING: rollout reached move limit") 147 | if winner == -1: # tie 148 | return 0 149 | else: 150 | return 1 if winner == player else -1 151 | 152 | def get_move(self, state): 153 | """Runs all playouts sequentially and returns the most visited action. 154 | Arguments: 155 | state -- the current state, including both game state and the current player. 156 | Returns: 157 | the selected action 158 | """ 159 | for n in range(self._n_playout): 160 | state_copy = copy.deepcopy(state) 161 | self._playout(state_copy) 162 | return max(self._root._children.iteritems(), key=lambda act_node: act_node[1]._n_visits)[0] 163 | 164 | def update_with_move(self, last_move): 165 | """Step forward in the tree, keeping everything we already know about the subtree. 
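        If last_move is not a child of the current root (e.g. -1, as passed by reset_player), the old tree is discarded and a fresh root is created.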
166 | """ 167 | if last_move in self._root._children: 168 | self._root = self._root._children[last_move] 169 | self._root._parent = None 170 | else: 171 | self._root = TreeNode(None, 1.0) 172 | 173 | def __str__(self): 174 | return "MCTS" 175 | 176 | 177 | class MCTSPlayer(object): 178 | """AI player based on MCTS""" 179 | def __init__(self, c_puct=5, n_playout=2000): 180 | self.mcts = MCTS(policy_value_fn, c_puct, n_playout) 181 | 182 | def set_player_ind(self, p): 183 | self.player = p 184 | 185 | def reset_player(self): 186 | self.mcts.update_with_move(-1) 187 | 188 | def get_action(self, board): 189 | sensible_moves = board.availables 190 | if len(sensible_moves) > 0: 191 | move = self.mcts.get_move(board) 192 | self.mcts.update_with_move(-1) 193 | return move 194 | else: 195 | print("WARNING: the board is full") 196 | 197 | def __str__(self): 198 | return "MCTS {}".format(self.player) -------------------------------------------------------------------------------- /mcts_alphaZero.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Monte Carlo Tree Search in AlphaGo Zero style, which uses a policy-value network 4 | to guide the tree search and evaluate the leaf nodes 5 | 6 | @author: Junxiao Song 7 | """ 8 | import numpy as np 9 | import copy 10 | 11 | 12 | def softmax(x): 13 | probs = np.exp(x - np.max(x)) 14 | probs /= np.sum(probs) 15 | return probs 16 | 17 | class TreeNode(object): 18 | """A node in the MCTS tree. Each node keeps track of its own value Q, prior probability P, and 19 | its visit-count-adjusted prior score u. 20 | """ 21 | 22 | def __init__(self, parent, prior_p): 23 | self._parent = parent 24 | self._children = {} # a map from action to TreeNode 25 | self._n_visits = 0 26 | self._Q = 0 27 | self._u = 0 28 | self._P = prior_p 29 | 30 | def expand(self, action_priors): 31 | """Expand tree by creating new children. 32 | action_priors -- output from policy function - a list of tuples of actions 33 | and their prior probability according to the policy function. 34 | """ 35 | for action, prob in action_priors: 36 | if action not in self._children: 37 | self._children[action] = TreeNode(self, prob) 38 | 39 | def select(self, c_puct): 40 | """Select action among children that gives maximum action value, Q plus bonus u(P). 41 | Returns: 42 | A tuple of (action, next_node) 43 | """ 44 | return max(self._children.items(), key=lambda act_node: act_node[1].get_value(c_puct)) 45 | 46 | def update(self, leaf_value): 47 | """Update node values from leaf evaluation. 48 | Arguments: 49 | leaf_value -- the value of subtree evaluation from the current player's perspective. 50 | """ 51 | # Count visit. 52 | self._n_visits += 1 53 | # Update Q, a running average of values for all visits. 54 | self._Q += 1.0*(leaf_value - self._Q) / self._n_visits 55 | 56 | def update_recursive(self, leaf_value): 57 | """Like a call to update(), but applied recursively for all ancestors. 58 | """ 59 | # If it is not root, this node's parent should be updated first. 60 | if self._parent: 61 | self._parent.update_recursive(-leaf_value) 62 | self.update(leaf_value) 63 | 64 | def get_value(self, c_puct): 65 | """Calculate and return the value for this node: a combination of leaf evaluations, Q, and 66 | this node's prior adjusted for its visit count, u 67 | c_puct -- a number in (0, inf) controlling the relative impact of values, Q, and 68 | prior probability, P, on this node's score. 
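        Concretely, u = c_puct * P * sqrt(N_parent) / (1 + N_visits), and the returned score is Q + u.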
69 | """ 70 | self._u = c_puct * self._P * np.sqrt(self._parent._n_visits) / (1 + self._n_visits) 71 | return self._Q + self._u 72 | 73 | def is_leaf(self): 74 | """Check if leaf node (i.e. no nodes below this have been expanded). 75 | """ 76 | return self._children == {} 77 | 78 | def is_root(self): 79 | return self._parent is None 80 | 81 | 82 | class MCTS(object): 83 | """A simple implementation of Monte Carlo Tree Search. 84 | """ 85 | 86 | def __init__(self, policy_value_fn, c_puct=5, n_playout=10000): 87 | """Arguments: 88 | policy_value_fn -- a function that takes in a board state and outputs a list of (action, probability) 89 | tuples and also a score in [-1, 1] (i.e. the expected value of the end game score from 90 | the current player's perspective) for the current player. 91 | c_puct -- a number in (0, inf) that controls how quickly exploration converges to the 92 | maximum-value policy, where a higher value means relying on the prior more 93 | """ 94 | self._root = TreeNode(None, 1.0) 95 | self._policy = policy_value_fn 96 | self._c_puct = c_puct 97 | self._n_playout = n_playout 98 | 99 | def _playout(self, state): 100 | """Run a single playout from the root to the leaf, getting a value at the leaf and 101 | propagating it back through its parents. State is modified in-place, so a copy must be 102 | provided. 103 | Arguments: 104 | state -- a copy of the state. 105 | """ 106 | node = self._root 107 | while(1): 108 | if node.is_leaf(): 109 | break 110 | # Greedily select next move. 111 | action, node = node.select(self._c_puct) 112 | state.do_move(action) 113 | 114 | # Evaluate the leaf using a network which outputs a list of (action, probability) 115 | # tuples p and also a score v in [-1, 1] for the current player. 116 | action_probs, leaf_value = self._policy(state) 117 | # Check for end of game. 118 | end, winner = state.game_end() 119 | if not end: 120 | node.expand(action_probs) 121 | else: 122 | # for end state,return the "true" leaf_value 123 | if winner == -1: # tie 124 | leaf_value = 0.0 125 | else: 126 | leaf_value = 1.0 if winner == state.get_current_player() else -1.0 127 | 128 | # Update value and visit count of nodes in this traversal. 129 | node.update_recursive(-leaf_value) 130 | 131 | def get_move_probs(self, state, temp=1e-3): 132 | """Runs all playouts sequentially and returns the available actions and their corresponding probabilities 133 | Arguments: 134 | state -- the current state, including both game state and the current player. 135 | temp -- temperature parameter in (0, 1] that controls the level of exploration 136 | Returns: 137 | the available actions and the corresponding probabilities 138 | """ 139 | for n in range(self._n_playout): 140 | state_copy = copy.deepcopy(state) 141 | self._playout(state_copy) 142 | 143 | # calc the move probabilities based on the visit counts at the root node 144 | act_visits = [(act, node._n_visits) for act, node in self._root._children.items()] 145 | acts, visits = zip(*act_visits) 146 | act_probs = softmax(1.0/temp * np.log(visits)) 147 | 148 | return acts, act_probs 149 | 150 | def update_with_move(self, last_move): 151 | """Step forward in the tree, keeping everything we already know about the subtree. 
152 | """ 153 | if last_move in self._root._children: 154 | self._root = self._root._children[last_move] 155 | self._root._parent = None 156 | else: 157 | self._root = TreeNode(None, 1.0) 158 | 159 | def __str__(self): 160 | return "MCTS" 161 | 162 | 163 | class MCTSPlayer(object): 164 | """AI player based on MCTS""" 165 | def __init__(self, policy_value_function, c_puct=5, n_playout=2000, is_selfplay=0): 166 | self.mcts = MCTS(policy_value_function, c_puct, n_playout) 167 | self._is_selfplay = is_selfplay 168 | 169 | def set_player_ind(self, p): 170 | self.player = p 171 | 172 | def reset_player(self): 173 | self.mcts.update_with_move(-1) 174 | 175 | def get_action(self, board, temp=1e-3, return_prob=0): 176 | sensible_moves = board.availables 177 | move_probs = np.zeros(board.width*board.height) # the pi vector returned by MCTS as in the alphaGo Zero paper 178 | if len(sensible_moves) > 0: 179 | acts, probs = self.mcts.get_move_probs(board, temp) 180 | move_probs[list(acts)] = probs 181 | if self._is_selfplay: 182 | # add Dirichlet Noise for exploration (needed for self-play training) 183 | move = np.random.choice(acts, p=0.75*probs + 0.25*np.random.dirichlet(0.3*np.ones(len(probs)))) 184 | self.mcts.update_with_move(move) # update the root node and reuse the search tree 185 | else: 186 | # with the default temp=1e-3, this is almost equivalent to choosing the move with the highest prob 187 | move = np.random.choice(acts, p=probs) 188 | # reset the root node 189 | self.mcts.update_with_move(-1) 190 | # location = board.move_to_location(move) 191 | # print("AI move: %d,%d\n" % (location[0], location[1])) 192 | 193 | if return_prob: 194 | return move, move_probs 195 | else: 196 | return move 197 | else: 198 | print("WARNING: the board is full") 199 | 200 | def __str__(self): 201 | return "MCTS {}".format(self.player) -------------------------------------------------------------------------------- /game.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author: Junxiao Song 4 | """ 5 | 6 | from __future__ import print_function 7 | import numpy as np 8 | 9 | class Board(object): 10 | """ 11 | board for the game 12 | """ 13 | 14 | def __init__(self, **kwargs): 15 | self.width = int(kwargs.get('width', 8)) 16 | self.height = int(kwargs.get('height', 8)) 17 | self.states = {} # board states, key:move as location on the board, value:player as pieces type 18 | self.n_in_row = int(kwargs.get('n_in_row', 5)) # need how many pieces in a row to win 19 | self.players = [1, 2] # player1 and player2 20 | 21 | def init_board(self, start_player=0): 22 | if self.width < self.n_in_row or self.height < self.n_in_row: 23 | raise Exception('board width and height can not less than %d' % self.n_in_row) 24 | self.current_player = self.players[start_player] # start player 25 | self.availables = list(range(self.width * self.height)) # available moves 26 | self.states = {} # board states, key:move as location on the board, value:player as pieces type 27 | self.last_move = -1 28 | 29 | def move_to_location(self, move): 30 | """ 31 | 3*3 board's moves like: 32 | 6 7 8 33 | 3 4 5 34 | 0 1 2 35 | and move 5's location is (1,2) 36 | """ 37 | h = move // self.width 38 | w = move % self.width 39 | return [h, w] 40 | 41 | def location_to_move(self, location): 42 | if(len(location) != 2): 43 | return -1 44 | h = location[0] 45 | w = location[1] 46 | move = h * self.width + w 47 | if(move not in range(self.width * self.height)): 48 | return -1 49 | 
return move 50 | 51 | def current_state(self): 52 | """return the board state from the perspective of the current player 53 | shape: 4*width*height""" 54 | 55 | square_state = np.zeros((4, self.width, self.height)) 56 | if self.states: 57 | moves, players = np.array(list(zip(*self.states.items()))) 58 | move_curr = moves[players == self.current_player] 59 | move_oppo = moves[players != self.current_player] 60 | square_state[0][move_curr // self.width, move_curr % self.height] = 1.0 61 | square_state[1][move_oppo // self.width, move_oppo % self.height] = 1.0 62 | square_state[2][self.last_move //self.width, self.last_move % self.height] = 1.0 # last move indication 63 | if len(self.states)%2 == 0: 64 | square_state[3][:,:] = 1.0 65 | return square_state[:,::-1,:] 66 | 67 | def do_move(self, move): 68 | self.states[move] = self.current_player 69 | self.availables.remove(move) 70 | self.current_player = self.players[0] if self.current_player == self.players[1] else self.players[1] 71 | self.last_move = move 72 | 73 | def has_a_winner(self): 74 | width = self.width 75 | height = self.height 76 | states = self.states 77 | n = self.n_in_row 78 | 79 | moved = list(set(range(width * height)) - set(self.availables)) 80 | if(len(moved) < self.n_in_row + 2): 81 | return False, -1 82 | 83 | for m in moved: 84 | h = m // width 85 | w = m % width 86 | player = states[m] 87 | 88 | if (w in range(width - n + 1) and 89 | len(set(states.get(i, -1) for i in range(m, m + n))) == 1): 90 | return True, player 91 | 92 | if (h in range(height - n + 1) and 93 | len(set(states.get(i, -1) for i in range(m, m + n * width, width))) == 1): 94 | return True, player 95 | 96 | if (w in range(width - n + 1) and h in range(height - n + 1) and 97 | len(set(states.get(i, -1) for i in range(m, m + n * (width + 1), width + 1))) == 1): 98 | return True, player 99 | 100 | if (w in range(n - 1, width) and h in range(height - n + 1) and 101 | len(set(states.get(i, -1) for i in range(m, m + n * (width - 1), width - 1))) == 1): 102 | return True, player 103 | 104 | return False, -1 105 | 106 | def game_end(self): 107 | """Check whether the game is ended or not""" 108 | win, winner = self.has_a_winner() 109 | if win: 110 | return True, winner 111 | elif not len(self.availables):# 112 | return True, -1 113 | return False, -1 114 | 115 | def get_current_player(self): 116 | return self.current_player 117 | 118 | 119 | class Game(object): 120 | """ 121 | game server 122 | """ 123 | 124 | def __init__(self, board, **kwargs): 125 | self.board = board 126 | 127 | def graphic(self, board, player1, player2): 128 | """ 129 | Draw the board and show game info 130 | """ 131 | width = board.width 132 | height = board.height 133 | 134 | print("Player", player1, "with X".rjust(3)) 135 | print("Player", player2, "with O".rjust(3)) 136 | print() 137 | for x in range(width): 138 | print("{0:8}".format(x), end='') 139 | print('\r\n') 140 | for i in range(height - 1, -1, -1): 141 | print("{0:4d}".format(i), end='') 142 | for j in range(width): 143 | loc = i * width + j 144 | p = board.states.get(loc, -1) 145 | if p == player1: 146 | print('X'.center(8), end='') 147 | elif p == player2: 148 | print('O'.center(8), end='') 149 | else: 150 | print('_'.center(8), end='') 151 | print('\r\n\r\n') 152 | 153 | def start_play(self, player1, player2, start_player=0, is_shown=1): 154 | """ 155 | start a game between two players 156 | """ 157 | if start_player not in (0,1): 158 | raise Exception('start_player should be 0 (player1 first) or 1 (player2 first)') 159 | 
self.board.init_board(start_player) 160 | p1, p2 = self.board.players 161 | player1.set_player_ind(p1) 162 | player2.set_player_ind(p2) 163 | players = {p1: player1, p2:player2} 164 | if is_shown: 165 | self.graphic(self.board, player1.player, player2.player) 166 | while(1): 167 | current_player = self.board.get_current_player() 168 | player_in_turn = players[current_player] 169 | move = player_in_turn.get_action(self.board) 170 | self.board.do_move(move) 171 | if is_shown: 172 | self.graphic(self.board, player1.player, player2.player) 173 | end, winner = self.board.game_end() 174 | if end: 175 | if is_shown: 176 | if winner != -1: 177 | print("Game end. Winner is", players[winner]) 178 | else: 179 | print("Game end. Tie") 180 | return winner 181 | 182 | 183 | def start_self_play(self, player, is_shown=0, temp=1e-3): 184 | """ start a self-play game using a MCTS player, reuse the search tree 185 | store the self-play data: (state, mcts_probs, z) 186 | """ 187 | self.board.init_board() 188 | p1, p2 = self.board.players 189 | states, mcts_probs, current_players = [], [], [] 190 | while(1): 191 | move, move_probs = player.get_action(self.board, temp=temp, return_prob=1) 192 | # store the data 193 | states.append(self.board.current_state()) 194 | mcts_probs.append(move_probs) 195 | current_players.append(self.board.current_player) 196 | # perform a move 197 | self.board.do_move(move) 198 | if is_shown: 199 | self.graphic(self.board, p1, p2) 200 | end, winner = self.board.game_end() 201 | if end: 202 | # winner from the perspective of the current player of each state 203 | winners_z = np.zeros(len(current_players)) 204 | if winner != -1: 205 | winners_z[np.array(current_players) == winner] = 1.0 206 | winners_z[np.array(current_players) != winner] = -1.0 207 | #reset MCTS root node 208 | player.reset_player() 209 | if is_shown: 210 | if winner != -1: 211 | print("Game end. Winner is player:", winner) 212 | else: 213 | print("Game end. Tie") 214 | return winner, zip(states, mcts_probs, winners_z) 215 | --------------------------------------------------------------------------------
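As a quick smoke test of game.py together with mcts_pure.py, independent of any trained model, two pure-MCTS players can be pitted against each other on a small board. This is a usage sketch, not a file from the repo (and since mcts_pure.py uses ``dict.iteritems()``, it runs under Python 2 as written):

```python
# pit two pure-MCTS players against each other on a 6x6 board, 4 in a row
from __future__ import print_function
from game import Board, Game
from mcts_pure import MCTSPlayer as MCTS_Pure

board = Board(width=6, height=6, n_in_row=4)
game = Game(board)
# more playouts per move -> stronger (and slower) play
player1 = MCTS_Pure(c_puct=5, n_playout=400)
player2 = MCTS_Pure(c_puct=5, n_playout=1000)
winner = game.start_play(player1, player2, start_player=0, is_shown=1)
print("winner:", winner)  # 1, 2, or -1 for a tie
```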