├── .gitignore ├── LICENSE ├── README.md ├── best_policy_6_6_4.model ├── best_policy_6_6_4.model2 ├── best_policy_8_8_5.model ├── best_policy_8_8_5.model2 ├── game.py ├── human_play.py ├── mcts_alphaZero.py ├── mcts_pure.py ├── playout400.gif ├── policy_value_net.py ├── policy_value_net_keras.py ├── policy_value_net_numpy.py ├── policy_value_net_pytorch.py ├── policy_value_net_tensorflow.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | __pycache__/ 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 junxiaosong 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## AlphaZero-Gomoku 2 | This is an implementation of the AlphaZero algorithm for playing the simple board game Gomoku (also called Gobang or Five in a Row) from pure self-play training. The game Gomoku is much simpler than Go or chess, so that we can focus on the training scheme of AlphaZero and obtain a pretty good AI model on a single PC in a few hours. 3 | 4 | References: 5 | 1. AlphaZero: Mastering Chess and Shogi by Self-Play with a General Reinforcement Learning Algorithm 6 | 2. AlphaGo Zero: Mastering the game of Go without human knowledge 7 | 8 | ### Update 2018.2.24: supports training with TensorFlow! 9 | ### Update 2018.1.17: supports training with PyTorch! 
10 | 11 | ### Example Games Between Trained Models 12 | - Each move uses 400 MCTS playouts: 13 | ![playout400](https://raw.githubusercontent.com/junxiaosong/AlphaZero_Gomoku/master/playout400.gif) 14 | 15 | ### Requirements 16 | To play with the trained AI models, you only need: 17 | - Python >= 2.7 18 | - Numpy >= 1.11 19 | 20 | To train the AI model from scratch, you further need either: 21 | - Theano >= 0.7 and Lasagne >= 0.1 22 | or 23 | - PyTorch >= 0.2.0 24 | or 25 | - TensorFlow 26 | 27 | **PS**: if your Theano version is > 0.7, please follow this [issue](https://github.com/aigamedev/scikit-neuralnetwork/issues/235) to install Lasagne; 28 | otherwise, force pip to downgrade Theano to 0.7: ``pip install --upgrade theano==0.7.0`` 29 | 30 | If you would like to train the model using another DL framework, you only need to rewrite policy_value_net.py. 31 | 32 | ### Getting Started 33 | To play with the provided models, run the following script from the repository root: 34 | ``` 35 | python human_play.py 36 | ``` 37 | You may modify human_play.py to try different provided models or the pure MCTS player. 38 | 39 | To train the AI model from scratch with Theano and Lasagne, directly run: 40 | ``` 41 | python train.py 42 | ``` 43 | With PyTorch or TensorFlow, first modify the file [train.py](https://github.com/junxiaosong/AlphaZero_Gomoku/blob/master/train.py), i.e., comment out the line 44 | ``` 45 | from policy_value_net import PolicyValueNet # Theano and Lasagne 46 | ``` 47 | and uncomment the line 48 | ``` 49 | # from policy_value_net_pytorch import PolicyValueNet # Pytorch 50 | or 51 | # from policy_value_net_tensorflow import PolicyValueNet # Tensorflow 52 | ``` 53 | and then execute: ``python train.py`` (To use a GPU with PyTorch, set ``use_gpu=True`` and use ``return loss.item(), entropy.item()`` in the train_step function of policy_value_net_pytorch.py if your PyTorch version is greater than 0.5.) 54 | 55 | The models (best_policy.model and current_policy.model) will be saved every few updates (every 50 updates by default). 56 | 57 | **Note:** the 4 provided models were trained using Theano/Lasagne; to use them with PyTorch, please refer to [issue 5](https://github.com/junxiaosong/AlphaZero_Gomoku/issues/5). 58 | 59 | **Tips for training:** 60 | 1. It is good to start with a 6 * 6 board and 4 in a row. In this case, you can obtain a reasonably good model within 500~1000 self-play games, in about 2 hours. 61 | 2. For an 8 * 8 board and 5 in a row, it may take 2000~3000 self-play games to get a good model, which is about 2 days on a single PC.
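**How the pieces fit together:** train.py simply alternates between collecting self-play games with the current MCTS + policy-value-net player and updating the network on the collected data, evaluating against a pure-MCTS baseline every so often. The snippet below is only an illustrative sketch of that loop, written against the `TrainPipeline` helpers defined in train.py (it assumes the default Theano/Lasagne backend is installed, and the checkpoint filename is just an example); for actual training, run ``python train.py`` as described above.
```
# Illustrative sketch only -- the real entry point is `python train.py`.
from train import TrainPipeline

pipeline = TrainPipeline()  # 6*6 board with 4 in a row by default
for i in range(pipeline.game_batch_num):
    # 1. generate self-play games with the current MCTS + policy-value net player
    pipeline.collect_selfplay_data(pipeline.play_batch_size)
    # 2. once enough positions are buffered, update the policy-value net
    if len(pipeline.data_buffer) > pipeline.batch_size:
        loss, entropy = pipeline.policy_update()
    # 3. periodically evaluate against pure MCTS and checkpoint the model
    if (i + 1) % pipeline.check_freq == 0:
        win_ratio = pipeline.policy_evaluate()
        pipeline.policy_value_net.save_model('./current_policy.model')  # example path
```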
62 | 63 | ### Further reading 64 | My article describing some details about the implementation in Chinese: [https://zhuanlan.zhihu.com/p/32089487](https://zhuanlan.zhihu.com/p/32089487) 65 | -------------------------------------------------------------------------------- /best_policy_6_6_4.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junxiaosong/AlphaZero_Gomoku/0b4917ccd84ea2771706701015eb26182c1159f4/best_policy_6_6_4.model -------------------------------------------------------------------------------- /best_policy_6_6_4.model2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junxiaosong/AlphaZero_Gomoku/0b4917ccd84ea2771706701015eb26182c1159f4/best_policy_6_6_4.model2 -------------------------------------------------------------------------------- /best_policy_8_8_5.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junxiaosong/AlphaZero_Gomoku/0b4917ccd84ea2771706701015eb26182c1159f4/best_policy_8_8_5.model -------------------------------------------------------------------------------- /best_policy_8_8_5.model2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junxiaosong/AlphaZero_Gomoku/0b4917ccd84ea2771706701015eb26182c1159f4/best_policy_8_8_5.model2 -------------------------------------------------------------------------------- /game.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author: Junxiao Song 4 | """ 5 | 6 | from __future__ import print_function 7 | import numpy as np 8 | 9 | 10 | class Board(object): 11 | """board for the game""" 12 | 13 | def __init__(self, **kwargs): 14 | self.width = int(kwargs.get('width', 8)) 15 | self.height = int(kwargs.get('height', 8)) 16 | # board states stored as a dict, 17 | # key: move as location on the board, 18 | # value: player as pieces type 19 | self.states = {} 20 | # need how many pieces in a row to win 21 | self.n_in_row = int(kwargs.get('n_in_row', 5)) 22 | self.players = [1, 2] # player1 and player2 23 | 24 | def init_board(self, start_player=0): 25 | if self.width < self.n_in_row or self.height < self.n_in_row: 26 | raise Exception('board width and height can not be ' 27 | 'less than {}'.format(self.n_in_row)) 28 | self.current_player = self.players[start_player] # start player 29 | # keep available moves in a list 30 | self.availables = list(range(self.width * self.height)) 31 | self.states = {} 32 | self.last_move = -1 33 | 34 | def move_to_location(self, move): 35 | """ 36 | 3*3 board's moves like: 37 | 6 7 8 38 | 3 4 5 39 | 0 1 2 40 | and move 5's location is (1,2) 41 | """ 42 | h = move // self.width 43 | w = move % self.width 44 | return [h, w] 45 | 46 | def location_to_move(self, location): 47 | if len(location) != 2: 48 | return -1 49 | h = location[0] 50 | w = location[1] 51 | move = h * self.width + w 52 | if move not in range(self.width * self.height): 53 | return -1 54 | return move 55 | 56 | def current_state(self): 57 | """return the board state from the perspective of the current player. 
58 | state shape: 4*width*height 59 | """ 60 | 61 | square_state = np.zeros((4, self.width, self.height)) 62 | if self.states: 63 | moves, players = np.array(list(zip(*self.states.items()))) 64 | move_curr = moves[players == self.current_player] 65 | move_oppo = moves[players != self.current_player] 66 | square_state[0][move_curr // self.width, 67 | move_curr % self.height] = 1.0 68 | square_state[1][move_oppo // self.width, 69 | move_oppo % self.height] = 1.0 70 | # indicate the last move location 71 | square_state[2][self.last_move // self.width, 72 | self.last_move % self.height] = 1.0 73 | if len(self.states) % 2 == 0: 74 | square_state[3][:, :] = 1.0 # indicate the colour to play 75 | return square_state[:, ::-1, :] 76 | 77 | def do_move(self, move): 78 | self.states[move] = self.current_player 79 | self.availables.remove(move) 80 | self.current_player = ( 81 | self.players[0] if self.current_player == self.players[1] 82 | else self.players[1] 83 | ) 84 | self.last_move = move 85 | 86 | def has_a_winner(self): 87 | width = self.width 88 | height = self.height 89 | states = self.states 90 | n = self.n_in_row 91 | 92 | moved = list(set(range(width * height)) - set(self.availables)) 93 | if len(moved) < self.n_in_row *2-1: 94 | return False, -1 95 | 96 | for m in moved: 97 | h = m // width 98 | w = m % width 99 | player = states[m] 100 | 101 | if (w in range(width - n + 1) and 102 | len(set(states.get(i, -1) for i in range(m, m + n))) == 1): 103 | return True, player 104 | 105 | if (h in range(height - n + 1) and 106 | len(set(states.get(i, -1) for i in range(m, m + n * width, width))) == 1): 107 | return True, player 108 | 109 | if (w in range(width - n + 1) and h in range(height - n + 1) and 110 | len(set(states.get(i, -1) for i in range(m, m + n * (width + 1), width + 1))) == 1): 111 | return True, player 112 | 113 | if (w in range(n - 1, width) and h in range(height - n + 1) and 114 | len(set(states.get(i, -1) for i in range(m, m + n * (width - 1), width - 1))) == 1): 115 | return True, player 116 | 117 | return False, -1 118 | 119 | def game_end(self): 120 | """Check whether the game is ended or not""" 121 | win, winner = self.has_a_winner() 122 | if win: 123 | return True, winner 124 | elif not len(self.availables): 125 | return True, -1 126 | return False, -1 127 | 128 | def get_current_player(self): 129 | return self.current_player 130 | 131 | 132 | class Game(object): 133 | """game server""" 134 | 135 | def __init__(self, board, **kwargs): 136 | self.board = board 137 | 138 | def graphic(self, board, player1, player2): 139 | """Draw the board and show game info""" 140 | width = board.width 141 | height = board.height 142 | 143 | print("Player", player1, "with X".rjust(3)) 144 | print("Player", player2, "with O".rjust(3)) 145 | print() 146 | for x in range(width): 147 | print("{0:8}".format(x), end='') 148 | print('\r\n') 149 | for i in range(height - 1, -1, -1): 150 | print("{0:4d}".format(i), end='') 151 | for j in range(width): 152 | loc = i * width + j 153 | p = board.states.get(loc, -1) 154 | if p == player1: 155 | print('X'.center(8), end='') 156 | elif p == player2: 157 | print('O'.center(8), end='') 158 | else: 159 | print('_'.center(8), end='') 160 | print('\r\n\r\n') 161 | 162 | def start_play(self, player1, player2, start_player=0, is_shown=1): 163 | """start a game between two players""" 164 | if start_player not in (0, 1): 165 | raise Exception('start_player should be either 0 (player1 first) ' 166 | 'or 1 (player2 first)') 167 | self.board.init_board(start_player) 
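        # map the board's internal player ids (1 and 2) to the two player objects, then alternate get_action / do_move until game_end() reports a winner or a tie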
168 | p1, p2 = self.board.players 169 | player1.set_player_ind(p1) 170 | player2.set_player_ind(p2) 171 | players = {p1: player1, p2: player2} 172 | if is_shown: 173 | self.graphic(self.board, player1.player, player2.player) 174 | while True: 175 | current_player = self.board.get_current_player() 176 | player_in_turn = players[current_player] 177 | move = player_in_turn.get_action(self.board) 178 | self.board.do_move(move) 179 | if is_shown: 180 | self.graphic(self.board, player1.player, player2.player) 181 | end, winner = self.board.game_end() 182 | if end: 183 | if is_shown: 184 | if winner != -1: 185 | print("Game end. Winner is", players[winner]) 186 | else: 187 | print("Game end. Tie") 188 | return winner 189 | 190 | def start_self_play(self, player, is_shown=0, temp=1e-3): 191 | """ start a self-play game using a MCTS player, reuse the search tree, 192 | and store the self-play data: (state, mcts_probs, z) for training 193 | """ 194 | self.board.init_board() 195 | p1, p2 = self.board.players 196 | states, mcts_probs, current_players = [], [], [] 197 | while True: 198 | move, move_probs = player.get_action(self.board, 199 | temp=temp, 200 | return_prob=1) 201 | # store the data 202 | states.append(self.board.current_state()) 203 | mcts_probs.append(move_probs) 204 | current_players.append(self.board.current_player) 205 | # perform a move 206 | self.board.do_move(move) 207 | if is_shown: 208 | self.graphic(self.board, p1, p2) 209 | end, winner = self.board.game_end() 210 | if end: 211 | # winner from the perspective of the current player of each state 212 | winners_z = np.zeros(len(current_players)) 213 | if winner != -1: 214 | winners_z[np.array(current_players) == winner] = 1.0 215 | winners_z[np.array(current_players) != winner] = -1.0 216 | # reset MCTS root node 217 | player.reset_player() 218 | if is_shown: 219 | if winner != -1: 220 | print("Game end. Winner is player:", winner) 221 | else: 222 | print("Game end. 
Tie") 223 | return winner, zip(states, mcts_probs, winners_z) 224 | -------------------------------------------------------------------------------- /human_play.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | human VS AI models 4 | Input your move in the format: 2,3 5 | 6 | @author: Junxiao Song 7 | """ 8 | 9 | from __future__ import print_function 10 | import pickle 11 | from game import Board, Game 12 | from mcts_pure import MCTSPlayer as MCTS_Pure 13 | from mcts_alphaZero import MCTSPlayer 14 | from policy_value_net_numpy import PolicyValueNetNumpy 15 | # from policy_value_net import PolicyValueNet # Theano and Lasagne 16 | # from policy_value_net_pytorch import PolicyValueNet # Pytorch 17 | # from policy_value_net_tensorflow import PolicyValueNet # Tensorflow 18 | # from policy_value_net_keras import PolicyValueNet # Keras 19 | 20 | 21 | class Human(object): 22 | """ 23 | human player 24 | """ 25 | 26 | def __init__(self): 27 | self.player = None 28 | 29 | def set_player_ind(self, p): 30 | self.player = p 31 | 32 | def get_action(self, board): 33 | try: 34 | location = input("Your move: ") 35 | if isinstance(location, str): # for python3 36 | location = [int(n, 10) for n in location.split(",")] 37 | move = board.location_to_move(location) 38 | except Exception as e: 39 | move = -1 40 | if move == -1 or move not in board.availables: 41 | print("invalid move") 42 | move = self.get_action(board) 43 | return move 44 | 45 | def __str__(self): 46 | return "Human {}".format(self.player) 47 | 48 | 49 | def run(): 50 | n = 5 51 | width, height = 8, 8 52 | model_file = 'best_policy_8_8_5.model' 53 | try: 54 | board = Board(width=width, height=height, n_in_row=n) 55 | game = Game(board) 56 | 57 | # ############### human VS AI ################### 58 | # load the trained policy_value_net in either Theano/Lasagne, PyTorch or TensorFlow 59 | 60 | # best_policy = PolicyValueNet(width, height, model_file = model_file) 61 | # mcts_player = MCTSPlayer(best_policy.policy_value_fn, c_puct=5, n_playout=400) 62 | 63 | # load the provided model (trained in Theano/Lasagne) into a MCTS player written in pure numpy 64 | try: 65 | policy_param = pickle.load(open(model_file, 'rb')) 66 | except: 67 | policy_param = pickle.load(open(model_file, 'rb'), 68 | encoding='bytes') # To support python3 69 | best_policy = PolicyValueNetNumpy(width, height, policy_param) 70 | mcts_player = MCTSPlayer(best_policy.policy_value_fn, 71 | c_puct=5, 72 | n_playout=400) # set larger n_playout for better performance 73 | 74 | # uncomment the following line to play with pure MCTS (it's much weaker even with a larger n_playout) 75 | # mcts_player = MCTS_Pure(c_puct=5, n_playout=1000) 76 | 77 | # human player, input your move in the format: 2,3 78 | human = Human() 79 | 80 | # set start_player=0 for human first 81 | game.start_play(human, mcts_player, start_player=1, is_shown=1) 82 | except KeyboardInterrupt: 83 | print('\n\rquit') 84 | 85 | 86 | if __name__ == '__main__': 87 | run() 88 | -------------------------------------------------------------------------------- /mcts_alphaZero.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Monte Carlo Tree Search in AlphaGo Zero style, which uses a policy-value 4 | network to guide the tree search and evaluate the leaf nodes 5 | 6 | @author: Junxiao Song 7 | """ 8 | 9 | import numpy as np 10 | import copy 11 | 12 | 13 | def softmax(x): 14 | 
probs = np.exp(x - np.max(x)) 15 | probs /= np.sum(probs) 16 | return probs 17 | 18 | 19 | class TreeNode(object): 20 | """A node in the MCTS tree. 21 | 22 | Each node keeps track of its own value Q, prior probability P, and 23 | its visit-count-adjusted prior score u. 24 | """ 25 | 26 | def __init__(self, parent, prior_p): 27 | self._parent = parent 28 | self._children = {} # a map from action to TreeNode 29 | self._n_visits = 0 30 | self._Q = 0 31 | self._u = 0 32 | self._P = prior_p 33 | 34 | def expand(self, action_priors): 35 | """Expand tree by creating new children. 36 | action_priors: a list of tuples of actions and their prior probability 37 | according to the policy function. 38 | """ 39 | for action, prob in action_priors: 40 | if action not in self._children: 41 | self._children[action] = TreeNode(self, prob) 42 | 43 | def select(self, c_puct): 44 | """Select action among children that gives maximum action value Q 45 | plus bonus u(P). 46 | Return: A tuple of (action, next_node) 47 | """ 48 | return max(self._children.items(), 49 | key=lambda act_node: act_node[1].get_value(c_puct)) 50 | 51 | def update(self, leaf_value): 52 | """Update node values from leaf evaluation. 53 | leaf_value: the value of subtree evaluation from the current player's 54 | perspective. 55 | """ 56 | # Count visit. 57 | self._n_visits += 1 58 | # Update Q, a running average of values for all visits. 59 | self._Q += 1.0*(leaf_value - self._Q) / self._n_visits 60 | 61 | def update_recursive(self, leaf_value): 62 | """Like a call to update(), but applied recursively for all ancestors. 63 | """ 64 | # If it is not root, this node's parent should be updated first. 65 | if self._parent: 66 | self._parent.update_recursive(-leaf_value) 67 | self.update(leaf_value) 68 | 69 | def get_value(self, c_puct): 70 | """Calculate and return the value for this node. 71 | It is a combination of leaf evaluations Q, and this node's prior 72 | adjusted for its visit count, u. 73 | c_puct: a number in (0, inf) controlling the relative impact of 74 | value Q, and prior probability P, on this node's score. 75 | """ 76 | self._u = (c_puct * self._P * 77 | np.sqrt(self._parent._n_visits) / (1 + self._n_visits)) 78 | return self._Q + self._u 79 | 80 | def is_leaf(self): 81 | """Check if leaf node (i.e. no nodes below this have been expanded).""" 82 | return self._children == {} 83 | 84 | def is_root(self): 85 | return self._parent is None 86 | 87 | 88 | class MCTS(object): 89 | """An implementation of Monte Carlo Tree Search.""" 90 | 91 | def __init__(self, policy_value_fn, c_puct=5, n_playout=10000): 92 | """ 93 | policy_value_fn: a function that takes in a board state and outputs 94 | a list of (action, probability) tuples and also a score in [-1, 1] 95 | (i.e. the expected value of the end game score from the current 96 | player's perspective) for the current player. 97 | c_puct: a number in (0, inf) that controls how quickly exploration 98 | converges to the maximum-value policy. A higher value means 99 | relying on the prior more. 100 | """ 101 | self._root = TreeNode(None, 1.0) 102 | self._policy = policy_value_fn 103 | self._c_puct = c_puct 104 | self._n_playout = n_playout 105 | 106 | def _playout(self, state): 107 | """Run a single playout from the root to the leaf, getting a value at 108 | the leaf and propagating it back through its parents. 109 | State is modified in-place, so a copy must be provided. 
110 | """ 111 | node = self._root 112 | while(1): 113 | if node.is_leaf(): 114 | break 115 | # Greedily select next move. 116 | action, node = node.select(self._c_puct) 117 | state.do_move(action) 118 | 119 | # Evaluate the leaf using a network which outputs a list of 120 | # (action, probability) tuples p and also a score v in [-1, 1] 121 | # for the current player. 122 | action_probs, leaf_value = self._policy(state) 123 | # Check for end of game. 124 | end, winner = state.game_end() 125 | if not end: 126 | node.expand(action_probs) 127 | else: 128 | # for end state,return the "true" leaf_value 129 | if winner == -1: # tie 130 | leaf_value = 0.0 131 | else: 132 | leaf_value = ( 133 | 1.0 if winner == state.get_current_player() else -1.0 134 | ) 135 | 136 | # Update value and visit count of nodes in this traversal. 137 | node.update_recursive(-leaf_value) 138 | 139 | def get_move_probs(self, state, temp=1e-3): 140 | """Run all playouts sequentially and return the available actions and 141 | their corresponding probabilities. 142 | state: the current game state 143 | temp: temperature parameter in (0, 1] controls the level of exploration 144 | """ 145 | for n in range(self._n_playout): 146 | state_copy = copy.deepcopy(state) 147 | self._playout(state_copy) 148 | 149 | # calc the move probabilities based on visit counts at the root node 150 | act_visits = [(act, node._n_visits) 151 | for act, node in self._root._children.items()] 152 | acts, visits = zip(*act_visits) 153 | act_probs = softmax(1.0/temp * np.log(np.array(visits) + 1e-10)) 154 | 155 | return acts, act_probs 156 | 157 | def update_with_move(self, last_move): 158 | """Step forward in the tree, keeping everything we already know 159 | about the subtree. 160 | """ 161 | if last_move in self._root._children: 162 | self._root = self._root._children[last_move] 163 | self._root._parent = None 164 | else: 165 | self._root = TreeNode(None, 1.0) 166 | 167 | def __str__(self): 168 | return "MCTS" 169 | 170 | 171 | class MCTSPlayer(object): 172 | """AI player based on MCTS""" 173 | 174 | def __init__(self, policy_value_function, 175 | c_puct=5, n_playout=2000, is_selfplay=0): 176 | self.mcts = MCTS(policy_value_function, c_puct, n_playout) 177 | self._is_selfplay = is_selfplay 178 | 179 | def set_player_ind(self, p): 180 | self.player = p 181 | 182 | def reset_player(self): 183 | self.mcts.update_with_move(-1) 184 | 185 | def get_action(self, board, temp=1e-3, return_prob=0): 186 | sensible_moves = board.availables 187 | # the pi vector returned by MCTS as in the alphaGo Zero paper 188 | move_probs = np.zeros(board.width*board.height) 189 | if len(sensible_moves) > 0: 190 | acts, probs = self.mcts.get_move_probs(board, temp) 191 | move_probs[list(acts)] = probs 192 | if self._is_selfplay: 193 | # add Dirichlet Noise for exploration (needed for 194 | # self-play training) 195 | move = np.random.choice( 196 | acts, 197 | p=0.75*probs + 0.25*np.random.dirichlet(0.3*np.ones(len(probs))) 198 | ) 199 | # update the root node and reuse the search tree 200 | self.mcts.update_with_move(move) 201 | else: 202 | # with the default temp=1e-3, it is almost equivalent 203 | # to choosing the move with the highest prob 204 | move = np.random.choice(acts, p=probs) 205 | # reset the root node 206 | self.mcts.update_with_move(-1) 207 | # location = board.move_to_location(move) 208 | # print("AI move: %d,%d\n" % (location[0], location[1])) 209 | 210 | if return_prob: 211 | return move, move_probs 212 | else: 213 | return move 214 | else: 215 | print("WARNING: 
the board is full") 216 | 217 | def __str__(self): 218 | return "MCTS {}".format(self.player) 219 | -------------------------------------------------------------------------------- /mcts_pure.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | A pure implementation of the Monte Carlo Tree Search (MCTS) 4 | 5 | @author: Junxiao Song 6 | """ 7 | 8 | import numpy as np 9 | import copy 10 | from operator import itemgetter 11 | 12 | 13 | def rollout_policy_fn(board): 14 | """a coarse, fast version of policy_fn used in the rollout phase.""" 15 | # rollout randomly 16 | action_probs = np.random.rand(len(board.availables)) 17 | return zip(board.availables, action_probs) 18 | 19 | 20 | def policy_value_fn(board): 21 | """a function that takes in a state and outputs a list of (action, probability) 22 | tuples and a score for the state""" 23 | # return uniform probabilities and 0 score for pure MCTS 24 | action_probs = np.ones(len(board.availables))/len(board.availables) 25 | return zip(board.availables, action_probs), 0 26 | 27 | 28 | class TreeNode(object): 29 | """A node in the MCTS tree. Each node keeps track of its own value Q, 30 | prior probability P, and its visit-count-adjusted prior score u. 31 | """ 32 | 33 | def __init__(self, parent, prior_p): 34 | self._parent = parent 35 | self._children = {} # a map from action to TreeNode 36 | self._n_visits = 0 37 | self._Q = 0 38 | self._u = 0 39 | self._P = prior_p 40 | 41 | def expand(self, action_priors): 42 | """Expand tree by creating new children. 43 | action_priors: a list of tuples of actions and their prior probability 44 | according to the policy function. 45 | """ 46 | for action, prob in action_priors: 47 | if action not in self._children: 48 | self._children[action] = TreeNode(self, prob) 49 | 50 | def select(self, c_puct): 51 | """Select action among children that gives maximum action value Q 52 | plus bonus u(P). 53 | Return: A tuple of (action, next_node) 54 | """ 55 | return max(self._children.items(), 56 | key=lambda act_node: act_node[1].get_value(c_puct)) 57 | 58 | def update(self, leaf_value): 59 | """Update node values from leaf evaluation. 60 | leaf_value: the value of subtree evaluation from the current player's 61 | perspective. 62 | """ 63 | # Count visit. 64 | self._n_visits += 1 65 | # Update Q, a running average of values for all visits. 66 | self._Q += 1.0*(leaf_value - self._Q) / self._n_visits 67 | 68 | def update_recursive(self, leaf_value): 69 | """Like a call to update(), but applied recursively for all ancestors. 70 | """ 71 | # If it is not root, this node's parent should be updated first. 72 | if self._parent: 73 | self._parent.update_recursive(-leaf_value) 74 | self.update(leaf_value) 75 | 76 | def get_value(self, c_puct): 77 | """Calculate and return the value for this node. 78 | It is a combination of leaf evaluations Q, and this node's prior 79 | adjusted for its visit count, u. 80 | c_puct: a number in (0, inf) controlling the relative impact of 81 | value Q, and prior probability P, on this node's score. 82 | """ 83 | self._u = (c_puct * self._P * 84 | np.sqrt(self._parent._n_visits) / (1 + self._n_visits)) 85 | return self._Q + self._u 86 | 87 | def is_leaf(self): 88 | """Check if leaf node (i.e. no nodes below this have been expanded). 
89 | """ 90 | return self._children == {} 91 | 92 | def is_root(self): 93 | return self._parent is None 94 | 95 | 96 | class MCTS(object): 97 | """A simple implementation of Monte Carlo Tree Search.""" 98 | 99 | def __init__(self, policy_value_fn, c_puct=5, n_playout=10000): 100 | """ 101 | policy_value_fn: a function that takes in a board state and outputs 102 | a list of (action, probability) tuples and also a score in [-1, 1] 103 | (i.e. the expected value of the end game score from the current 104 | player's perspective) for the current player. 105 | c_puct: a number in (0, inf) that controls how quickly exploration 106 | converges to the maximum-value policy. A higher value means 107 | relying on the prior more. 108 | """ 109 | self._root = TreeNode(None, 1.0) 110 | self._policy = policy_value_fn 111 | self._c_puct = c_puct 112 | self._n_playout = n_playout 113 | 114 | def _playout(self, state): 115 | """Run a single playout from the root to the leaf, getting a value at 116 | the leaf and propagating it back through its parents. 117 | State is modified in-place, so a copy must be provided. 118 | """ 119 | node = self._root 120 | while(1): 121 | if node.is_leaf(): 122 | 123 | break 124 | # Greedily select next move. 125 | action, node = node.select(self._c_puct) 126 | state.do_move(action) 127 | 128 | action_probs, _ = self._policy(state) 129 | # Check for end of game 130 | end, winner = state.game_end() 131 | if not end: 132 | node.expand(action_probs) 133 | # Evaluate the leaf node by random rollout 134 | leaf_value = self._evaluate_rollout(state) 135 | # Update value and visit count of nodes in this traversal. 136 | node.update_recursive(-leaf_value) 137 | 138 | def _evaluate_rollout(self, state, limit=1000): 139 | """Use the rollout policy to play until the end of the game, 140 | returning +1 if the current player wins, -1 if the opponent wins, 141 | and 0 if it is a tie. 142 | """ 143 | player = state.get_current_player() 144 | for i in range(limit): 145 | end, winner = state.game_end() 146 | if end: 147 | break 148 | action_probs = rollout_policy_fn(state) 149 | max_action = max(action_probs, key=itemgetter(1))[0] 150 | state.do_move(max_action) 151 | else: 152 | # If no break from the loop, issue a warning. 153 | print("WARNING: rollout reached move limit") 154 | if winner == -1: # tie 155 | return 0 156 | else: 157 | return 1 if winner == player else -1 158 | 159 | def get_move(self, state): 160 | """Runs all playouts sequentially and returns the most visited action. 161 | state: the current game state 162 | 163 | Return: the selected action 164 | """ 165 | for n in range(self._n_playout): 166 | state_copy = copy.deepcopy(state) 167 | self._playout(state_copy) 168 | return max(self._root._children.items(), 169 | key=lambda act_node: act_node[1]._n_visits)[0] 170 | 171 | def update_with_move(self, last_move): 172 | """Step forward in the tree, keeping everything we already know 173 | about the subtree. 
174 | """ 175 | if last_move in self._root._children: 176 | self._root = self._root._children[last_move] 177 | self._root._parent = None 178 | else: 179 | self._root = TreeNode(None, 1.0) 180 | 181 | def __str__(self): 182 | return "MCTS" 183 | 184 | 185 | class MCTSPlayer(object): 186 | """AI player based on MCTS""" 187 | def __init__(self, c_puct=5, n_playout=2000): 188 | self.mcts = MCTS(policy_value_fn, c_puct, n_playout) 189 | 190 | def set_player_ind(self, p): 191 | self.player = p 192 | 193 | def reset_player(self): 194 | self.mcts.update_with_move(-1) 195 | 196 | def get_action(self, board): 197 | sensible_moves = board.availables 198 | if len(sensible_moves) > 0: 199 | move = self.mcts.get_move(board) 200 | self.mcts.update_with_move(-1) 201 | return move 202 | else: 203 | print("WARNING: the board is full") 204 | 205 | def __str__(self): 206 | return "MCTS {}".format(self.player) 207 | -------------------------------------------------------------------------------- /playout400.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junxiaosong/AlphaZero_Gomoku/0b4917ccd84ea2771706701015eb26182c1159f4/playout400.gif -------------------------------------------------------------------------------- /policy_value_net.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | An implementation of the policyValueNet in Theano and Lasagne 4 | 5 | @author: Junxiao Song 6 | """ 7 | 8 | from __future__ import print_function 9 | import theano 10 | import theano.tensor as T 11 | import lasagne 12 | import pickle 13 | 14 | 15 | class PolicyValueNet(): 16 | """policy-value network """ 17 | def __init__(self, board_width, board_height, model_file=None): 18 | self.board_width = board_width 19 | self.board_height = board_height 20 | self.learning_rate = T.scalar('learning_rate') 21 | self.l2_const = 1e-4 # coef of l2 penalty 22 | self.create_policy_value_net() 23 | self._loss_train_op() 24 | if model_file: 25 | try: 26 | net_params = pickle.load(open(model_file, 'rb')) 27 | except: 28 | # To support loading pretrained model in python3 29 | net_params = pickle.load(open(model_file, 'rb'), 30 | encoding='bytes') 31 | lasagne.layers.set_all_param_values( 32 | [self.policy_net, self.value_net], net_params 33 | ) 34 | 35 | def create_policy_value_net(self): 36 | """create the policy value network """ 37 | self.state_input = T.tensor4('state') 38 | self.winner = T.vector('winner') 39 | self.mcts_probs = T.matrix('mcts_probs') 40 | network = lasagne.layers.InputLayer( 41 | shape=(None, 4, self.board_width, self.board_height), 42 | input_var=self.state_input 43 | ) 44 | # conv layers 45 | network = lasagne.layers.Conv2DLayer( 46 | network, num_filters=32, filter_size=(3, 3), pad='same') 47 | network = lasagne.layers.Conv2DLayer( 48 | network, num_filters=64, filter_size=(3, 3), pad='same') 49 | network = lasagne.layers.Conv2DLayer( 50 | network, num_filters=128, filter_size=(3, 3), pad='same') 51 | # action policy layers 52 | policy_net = lasagne.layers.Conv2DLayer( 53 | network, num_filters=4, filter_size=(1, 1)) 54 | self.policy_net = lasagne.layers.DenseLayer( 55 | policy_net, num_units=self.board_width*self.board_height, 56 | nonlinearity=lasagne.nonlinearities.softmax) 57 | # state value layers 58 | value_net = lasagne.layers.Conv2DLayer( 59 | network, num_filters=2, filter_size=(1, 1)) 60 | value_net = lasagne.layers.DenseLayer(value_net, num_units=64) 61 | self.value_net 
= lasagne.layers.DenseLayer( 62 | value_net, num_units=1, 63 | nonlinearity=lasagne.nonlinearities.tanh) 64 | # get action probs and state score value 65 | self.action_probs, self.value = lasagne.layers.get_output( 66 | [self.policy_net, self.value_net]) 67 | self.policy_value = theano.function([self.state_input], 68 | [self.action_probs, self.value], 69 | allow_input_downcast=True) 70 | 71 | def policy_value_fn(self, board): 72 | """ 73 | input: board 74 | output: a list of (action, probability) tuples for each available 75 | action and the score of the board state 76 | """ 77 | legal_positions = board.availables 78 | current_state = board.current_state() 79 | act_probs, value = self.policy_value( 80 | current_state.reshape(-1, 4, self.board_width, self.board_height) 81 | ) 82 | act_probs = zip(legal_positions, act_probs.flatten()[legal_positions]) 83 | return act_probs, value[0][0] 84 | 85 | def _loss_train_op(self): 86 | """ 87 | Three loss terms: 88 | loss = (z - v)^2 - pi^T * log(p) + c||theta||^2 89 | """ 90 | params = lasagne.layers.get_all_params( 91 | [self.policy_net, self.value_net], trainable=True) 92 | value_loss = lasagne.objectives.squared_error( 93 | self.winner, self.value.flatten()) 94 | policy_loss = lasagne.objectives.categorical_crossentropy( 95 | self.action_probs, self.mcts_probs) 96 | l2_penalty = lasagne.regularization.apply_penalty( 97 | params, lasagne.regularization.l2) 98 | self.loss = self.l2_const*l2_penalty + lasagne.objectives.aggregate( 99 | value_loss + policy_loss, mode='mean') 100 | # policy entropy,for monitoring only 101 | self.entropy = -T.mean(T.sum( 102 | self.action_probs * T.log(self.action_probs + 1e-10), axis=1)) 103 | # get the train op 104 | updates = lasagne.updates.adam(self.loss, params, 105 | learning_rate=self.learning_rate) 106 | self.train_step = theano.function( 107 | [self.state_input, self.mcts_probs, self.winner, self.learning_rate], 108 | [self.loss, self.entropy], 109 | updates=updates, 110 | allow_input_downcast=True 111 | ) 112 | 113 | def get_policy_param(self): 114 | net_params = lasagne.layers.get_all_param_values( 115 | [self.policy_net, self.value_net]) 116 | return net_params 117 | 118 | def save_model(self, model_file): 119 | """ save model params to file """ 120 | net_params = self.get_policy_param() # get model params 121 | pickle.dump(net_params, open(model_file, 'wb'), protocol=2) 122 | -------------------------------------------------------------------------------- /policy_value_net_keras.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | An implementation of the policyValueNet with Keras 4 | Tested under Keras 2.0.5 with tensorflow-gpu 1.2.1 as backend 5 | 6 | @author: Mingxu Zhang 7 | """ 8 | 9 | from __future__ import print_function 10 | 11 | from keras.engine.topology import Input 12 | from keras.engine.training import Model 13 | from keras.layers.convolutional import Conv2D 14 | from keras.layers.core import Activation, Dense, Flatten 15 | from keras.layers.merge import Add 16 | from keras.layers.normalization import BatchNormalization 17 | from keras.regularizers import l2 18 | from keras.optimizers import Adam 19 | import keras.backend as K 20 | 21 | from keras.utils import np_utils 22 | 23 | import numpy as np 24 | import pickle 25 | 26 | 27 | class PolicyValueNet(): 28 | """policy-value network """ 29 | def __init__(self, board_width, board_height, model_file=None): 30 | self.board_width = board_width 31 | self.board_height = 
board_height 32 | self.l2_const = 1e-4 # coef of l2 penalty 33 | self.create_policy_value_net() 34 | self._loss_train_op() 35 | 36 | if model_file: 37 | net_params = pickle.load(open(model_file, 'rb')) 38 | self.model.set_weights(net_params) 39 | 40 | def create_policy_value_net(self): 41 | """create the policy value network """ 42 | in_x = network = Input((4, self.board_width, self.board_height)) 43 | 44 | # conv layers 45 | network = Conv2D(filters=32, kernel_size=(3, 3), padding="same", data_format="channels_first", activation="relu", kernel_regularizer=l2(self.l2_const))(network) 46 | network = Conv2D(filters=64, kernel_size=(3, 3), padding="same", data_format="channels_first", activation="relu", kernel_regularizer=l2(self.l2_const))(network) 47 | network = Conv2D(filters=128, kernel_size=(3, 3), padding="same", data_format="channels_first", activation="relu", kernel_regularizer=l2(self.l2_const))(network) 48 | # action policy layers 49 | policy_net = Conv2D(filters=4, kernel_size=(1, 1), data_format="channels_first", activation="relu", kernel_regularizer=l2(self.l2_const))(network) 50 | policy_net = Flatten()(policy_net) 51 | self.policy_net = Dense(self.board_width*self.board_height, activation="softmax", kernel_regularizer=l2(self.l2_const))(policy_net) 52 | # state value layers 53 | value_net = Conv2D(filters=2, kernel_size=(1, 1), data_format="channels_first", activation="relu", kernel_regularizer=l2(self.l2_const))(network) 54 | value_net = Flatten()(value_net) 55 | value_net = Dense(64, kernel_regularizer=l2(self.l2_const))(value_net) 56 | self.value_net = Dense(1, activation="tanh", kernel_regularizer=l2(self.l2_const))(value_net) 57 | 58 | self.model = Model(in_x, [self.policy_net, self.value_net]) 59 | 60 | def policy_value(state_input): 61 | state_input_union = np.array(state_input) 62 | results = self.model.predict_on_batch(state_input_union) 63 | return results 64 | self.policy_value = policy_value 65 | 66 | def policy_value_fn(self, board): 67 | """ 68 | input: board 69 | output: a list of (action, probability) tuples for each available action and the score of the board state 70 | """ 71 | legal_positions = board.availables 72 | current_state = board.current_state() 73 | act_probs, value = self.policy_value(current_state.reshape(-1, 4, self.board_width, self.board_height)) 74 | act_probs = zip(legal_positions, act_probs.flatten()[legal_positions]) 75 | return act_probs, value[0][0] 76 | 77 | def _loss_train_op(self): 78 | """ 79 | Three loss terms: 80 | loss = (z - v)^2 + pi^T * log(p) + c||theta||^2 81 | """ 82 | 83 | # get the train op 84 | opt = Adam() 85 | losses = ['categorical_crossentropy', 'mean_squared_error'] 86 | self.model.compile(optimizer=opt, loss=losses) 87 | 88 | def self_entropy(probs): 89 | return -np.mean(np.sum(probs * np.log(probs + 1e-10), axis=1)) 90 | 91 | def train_step(state_input, mcts_probs, winner, learning_rate): 92 | state_input_union = np.array(state_input) 93 | mcts_probs_union = np.array(mcts_probs) 94 | winner_union = np.array(winner) 95 | loss = self.model.evaluate(state_input_union, [mcts_probs_union, winner_union], batch_size=len(state_input), verbose=0) 96 | action_probs, _ = self.model.predict_on_batch(state_input_union) 97 | entropy = self_entropy(action_probs) 98 | K.set_value(self.model.optimizer.lr, learning_rate) 99 | self.model.fit(state_input_union, [mcts_probs_union, winner_union], batch_size=len(state_input), verbose=0) 100 | return loss[0], entropy 101 | 102 | self.train_step = train_step 103 | 104 | def 
get_policy_param(self): 105 | net_params = self.model.get_weights() 106 | return net_params 107 | 108 | def save_model(self, model_file): 109 | """ save model params to file """ 110 | net_params = self.get_policy_param() 111 | pickle.dump(net_params, open(model_file, 'wb'), protocol=2) 112 | -------------------------------------------------------------------------------- /policy_value_net_numpy.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Implement the policy value network using numpy, so that we can play with the 4 | trained AI model without installing any DL framwork 5 | 6 | @author: Junxiao Song 7 | """ 8 | 9 | from __future__ import print_function 10 | import numpy as np 11 | 12 | 13 | # some utility functions 14 | def softmax(x): 15 | probs = np.exp(x - np.max(x)) 16 | probs /= np.sum(probs) 17 | return probs 18 | 19 | 20 | def relu(X): 21 | out = np.maximum(X, 0) 22 | return out 23 | 24 | 25 | def conv_forward(X, W, b, stride=1, padding=1): 26 | n_filters, d_filter, h_filter, w_filter = W.shape 27 | # theano conv2d flips the filters (rotate 180 degree) first 28 | # while doing the calculation 29 | W = W[:, :, ::-1, ::-1] 30 | n_x, d_x, h_x, w_x = X.shape 31 | h_out = (h_x - h_filter + 2 * padding) / stride + 1 32 | w_out = (w_x - w_filter + 2 * padding) / stride + 1 33 | h_out, w_out = int(h_out), int(w_out) 34 | X_col = im2col_indices(X, h_filter, w_filter, 35 | padding=padding, stride=stride) 36 | W_col = W.reshape(n_filters, -1) 37 | out = (np.dot(W_col, X_col).T + b).T 38 | out = out.reshape(n_filters, h_out, w_out, n_x) 39 | out = out.transpose(3, 0, 1, 2) 40 | return out 41 | 42 | 43 | def fc_forward(X, W, b): 44 | out = np.dot(X, W) + b 45 | return out 46 | 47 | 48 | def get_im2col_indices(x_shape, field_height, 49 | field_width, padding=1, stride=1): 50 | # First figure out what the size of the output should be 51 | N, C, H, W = x_shape 52 | assert (H + 2 * padding - field_height) % stride == 0 53 | assert (W + 2 * padding - field_width) % stride == 0 54 | out_height = int((H + 2 * padding - field_height) / stride + 1) 55 | out_width = int((W + 2 * padding - field_width) / stride + 1) 56 | 57 | i0 = np.repeat(np.arange(field_height), field_width) 58 | i0 = np.tile(i0, C) 59 | i1 = stride * np.repeat(np.arange(out_height), out_width) 60 | j0 = np.tile(np.arange(field_width), field_height * C) 61 | j1 = stride * np.tile(np.arange(out_width), out_height) 62 | i = i0.reshape(-1, 1) + i1.reshape(1, -1) 63 | j = j0.reshape(-1, 1) + j1.reshape(1, -1) 64 | 65 | k = np.repeat(np.arange(C), field_height * field_width).reshape(-1, 1) 66 | 67 | return (k.astype(int), i.astype(int), j.astype(int)) 68 | 69 | 70 | def im2col_indices(x, field_height, field_width, padding=1, stride=1): 71 | """ An implementation of im2col based on some fancy indexing """ 72 | # Zero-pad the input 73 | p = padding 74 | x_padded = np.pad(x, ((0, 0), (0, 0), (p, p), (p, p)), mode='constant') 75 | 76 | k, i, j = get_im2col_indices(x.shape, field_height, 77 | field_width, padding, stride) 78 | 79 | cols = x_padded[:, k, i, j] 80 | C = x.shape[1] 81 | cols = cols.transpose(1, 2, 0).reshape(field_height * field_width * C, -1) 82 | return cols 83 | 84 | 85 | class PolicyValueNetNumpy(): 86 | """policy-value network in numpy """ 87 | def __init__(self, board_width, board_height, net_params): 88 | self.board_width = board_width 89 | self.board_height = board_height 90 | self.params = net_params 91 | 92 | def policy_value_fn(self, board): 93 
| """ 94 | input: board 95 | output: a list of (action, probability) tuples for each available 96 | action and the score of the board state 97 | """ 98 | legal_positions = board.availables 99 | current_state = board.current_state() 100 | 101 | X = current_state.reshape(-1, 4, self.board_width, self.board_height) 102 | # first 3 conv layers with ReLu nonlinearity 103 | for i in [0, 2, 4]: 104 | X = relu(conv_forward(X, self.params[i], self.params[i+1])) 105 | # policy head 106 | X_p = relu(conv_forward(X, self.params[6], self.params[7], padding=0)) 107 | X_p = fc_forward(X_p.flatten(), self.params[8], self.params[9]) 108 | act_probs = softmax(X_p) 109 | # value head 110 | X_v = relu(conv_forward(X, self.params[10], 111 | self.params[11], padding=0)) 112 | X_v = relu(fc_forward(X_v.flatten(), self.params[12], self.params[13])) 113 | value = np.tanh(fc_forward(X_v, self.params[14], self.params[15]))[0] 114 | act_probs = zip(legal_positions, act_probs.flatten()[legal_positions]) 115 | return act_probs, value 116 | -------------------------------------------------------------------------------- /policy_value_net_pytorch.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | An implementation of the policyValueNet in PyTorch 4 | Tested in PyTorch 0.2.0 and 0.3.0 5 | 6 | @author: Junxiao Song 7 | """ 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.optim as optim 12 | import torch.nn.functional as F 13 | from torch.autograd import Variable 14 | import numpy as np 15 | 16 | 17 | def set_learning_rate(optimizer, lr): 18 | """Sets the learning rate to the given value""" 19 | for param_group in optimizer.param_groups: 20 | param_group['lr'] = lr 21 | 22 | 23 | class Net(nn.Module): 24 | """policy-value network module""" 25 | def __init__(self, board_width, board_height): 26 | super(Net, self).__init__() 27 | 28 | self.board_width = board_width 29 | self.board_height = board_height 30 | # common layers 31 | self.conv1 = nn.Conv2d(4, 32, kernel_size=3, padding=1) 32 | self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1) 33 | self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1) 34 | # action policy layers 35 | self.act_conv1 = nn.Conv2d(128, 4, kernel_size=1) 36 | self.act_fc1 = nn.Linear(4*board_width*board_height, 37 | board_width*board_height) 38 | # state value layers 39 | self.val_conv1 = nn.Conv2d(128, 2, kernel_size=1) 40 | self.val_fc1 = nn.Linear(2*board_width*board_height, 64) 41 | self.val_fc2 = nn.Linear(64, 1) 42 | 43 | def forward(self, state_input): 44 | # common layers 45 | x = F.relu(self.conv1(state_input)) 46 | x = F.relu(self.conv2(x)) 47 | x = F.relu(self.conv3(x)) 48 | # action policy layers 49 | x_act = F.relu(self.act_conv1(x)) 50 | x_act = x_act.view(-1, 4*self.board_width*self.board_height) 51 | x_act = F.log_softmax(self.act_fc1(x_act)) 52 | # state value layers 53 | x_val = F.relu(self.val_conv1(x)) 54 | x_val = x_val.view(-1, 2*self.board_width*self.board_height) 55 | x_val = F.relu(self.val_fc1(x_val)) 56 | x_val = F.tanh(self.val_fc2(x_val)) 57 | return x_act, x_val 58 | 59 | 60 | class PolicyValueNet(): 61 | """policy-value network """ 62 | def __init__(self, board_width, board_height, 63 | model_file=None, use_gpu=False): 64 | self.use_gpu = use_gpu 65 | self.board_width = board_width 66 | self.board_height = board_height 67 | self.l2_const = 1e-4 # coef of l2 penalty 68 | # the policy value net module 69 | if self.use_gpu: 70 | self.policy_value_net = Net(board_width, 
board_height).cuda() 71 | else: 72 | self.policy_value_net = Net(board_width, board_height) 73 | self.optimizer = optim.Adam(self.policy_value_net.parameters(), 74 | weight_decay=self.l2_const) 75 | 76 | if model_file: 77 | net_params = torch.load(model_file) 78 | self.policy_value_net.load_state_dict(net_params) 79 | 80 | def policy_value(self, state_batch): 81 | """ 82 | input: a batch of states 83 | output: a batch of action probabilities and state values 84 | """ 85 | if self.use_gpu: 86 | state_batch = Variable(torch.FloatTensor(state_batch).cuda()) 87 | log_act_probs, value = self.policy_value_net(state_batch) 88 | act_probs = np.exp(log_act_probs.data.cpu().numpy()) 89 | return act_probs, value.data.cpu().numpy() 90 | else: 91 | state_batch = Variable(torch.FloatTensor(state_batch)) 92 | log_act_probs, value = self.policy_value_net(state_batch) 93 | act_probs = np.exp(log_act_probs.data.numpy()) 94 | return act_probs, value.data.numpy() 95 | 96 | def policy_value_fn(self, board): 97 | """ 98 | input: board 99 | output: a list of (action, probability) tuples for each available 100 | action and the score of the board state 101 | """ 102 | legal_positions = board.availables 103 | current_state = np.ascontiguousarray(board.current_state().reshape( 104 | -1, 4, self.board_width, self.board_height)) 105 | if self.use_gpu: 106 | log_act_probs, value = self.policy_value_net( 107 | Variable(torch.from_numpy(current_state)).cuda().float()) 108 | act_probs = np.exp(log_act_probs.data.cpu().numpy().flatten()) 109 | value = value.data.cpu().numpy()[0][0] 110 | else: 111 | log_act_probs, value = self.policy_value_net( 112 | Variable(torch.from_numpy(current_state)).float()) 113 | act_probs = np.exp(log_act_probs.data.numpy().flatten()) 114 | value = value.data.numpy()[0][0] 115 | act_probs = zip(legal_positions, act_probs[legal_positions]) 116 | return act_probs, value 117 | 118 | def train_step(self, state_batch, mcts_probs, winner_batch, lr): 119 | """perform a training step""" 120 | # wrap in Variable 121 | if self.use_gpu: 122 | state_batch = Variable(torch.FloatTensor(state_batch).cuda()) 123 | mcts_probs = Variable(torch.FloatTensor(mcts_probs).cuda()) 124 | winner_batch = Variable(torch.FloatTensor(winner_batch).cuda()) 125 | else: 126 | state_batch = Variable(torch.FloatTensor(state_batch)) 127 | mcts_probs = Variable(torch.FloatTensor(mcts_probs)) 128 | winner_batch = Variable(torch.FloatTensor(winner_batch)) 129 | 130 | # zero the parameter gradients 131 | self.optimizer.zero_grad() 132 | # set learning rate 133 | set_learning_rate(self.optimizer, lr) 134 | 135 | # forward 136 | log_act_probs, value = self.policy_value_net(state_batch) 137 | # define the loss = (z - v)^2 - pi^T * log(p) + c||theta||^2 138 | # Note: the L2 penalty is incorporated in optimizer 139 | value_loss = F.mse_loss(value.view(-1), winner_batch) 140 | policy_loss = -torch.mean(torch.sum(mcts_probs*log_act_probs, 1)) 141 | loss = value_loss + policy_loss 142 | # backward and optimize 143 | loss.backward() 144 | self.optimizer.step() 145 | # calc policy entropy, for monitoring only 146 | entropy = -torch.mean( 147 | torch.sum(torch.exp(log_act_probs) * log_act_probs, 1) 148 | ) 149 | return loss.data[0], entropy.data[0] 150 | #for pytorch version >= 0.5 please use the following line instead. 
151 | #return loss.item(), entropy.item() 152 | 153 | def get_policy_param(self): 154 | net_params = self.policy_value_net.state_dict() 155 | return net_params 156 | 157 | def save_model(self, model_file): 158 | """ save model params to file """ 159 | net_params = self.get_policy_param() # get model params 160 | torch.save(net_params, model_file) 161 | -------------------------------------------------------------------------------- /policy_value_net_tensorflow.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | An implementation of the policyValueNet in Tensorflow 4 | Tested in Tensorflow 1.4 and 1.5 5 | 6 | @author: Xiang Zhong 7 | """ 8 | 9 | import numpy as np 10 | import tensorflow as tf 11 | 12 | 13 | class PolicyValueNet(): 14 | def __init__(self, board_width, board_height, model_file=None): 15 | self.board_width = board_width 16 | self.board_height = board_height 17 | 18 | # Define the tensorflow neural network 19 | # 1. Input: 20 | self.input_states = tf.placeholder( 21 | tf.float32, shape=[None, 4, board_height, board_width]) 22 | self.input_state = tf.transpose(self.input_states, [0, 2, 3, 1]) 23 | # 2. Common Networks Layers 24 | self.conv1 = tf.layers.conv2d(inputs=self.input_state, 25 | filters=32, kernel_size=[3, 3], 26 | padding="same", data_format="channels_last", 27 | activation=tf.nn.relu) 28 | self.conv2 = tf.layers.conv2d(inputs=self.conv1, filters=64, 29 | kernel_size=[3, 3], padding="same", 30 | data_format="channels_last", 31 | activation=tf.nn.relu) 32 | self.conv3 = tf.layers.conv2d(inputs=self.conv2, filters=128, 33 | kernel_size=[3, 3], padding="same", 34 | data_format="channels_last", 35 | activation=tf.nn.relu) 36 | # 3-1 Action Networks 37 | self.action_conv = tf.layers.conv2d(inputs=self.conv3, filters=4, 38 | kernel_size=[1, 1], padding="same", 39 | data_format="channels_last", 40 | activation=tf.nn.relu) 41 | # Flatten the tensor 42 | self.action_conv_flat = tf.reshape( 43 | self.action_conv, [-1, 4 * board_height * board_width]) 44 | # 3-2 Full connected layer, the output is the log probability of moves 45 | # on each slot on the board 46 | self.action_fc = tf.layers.dense(inputs=self.action_conv_flat, 47 | units=board_height * board_width, 48 | activation=tf.nn.log_softmax) 49 | # 4 Evaluation Networks 50 | self.evaluation_conv = tf.layers.conv2d(inputs=self.conv3, filters=2, 51 | kernel_size=[1, 1], 52 | padding="same", 53 | data_format="channels_last", 54 | activation=tf.nn.relu) 55 | self.evaluation_conv_flat = tf.reshape( 56 | self.evaluation_conv, [-1, 2 * board_height * board_width]) 57 | self.evaluation_fc1 = tf.layers.dense(inputs=self.evaluation_conv_flat, 58 | units=64, activation=tf.nn.relu) 59 | # output the score of evaluation on current state 60 | self.evaluation_fc2 = tf.layers.dense(inputs=self.evaluation_fc1, 61 | units=1, activation=tf.nn.tanh) 62 | 63 | # Define the Loss function 64 | # 1. Label: the array containing if the game wins or not for each state 65 | self.labels = tf.placeholder(tf.float32, shape=[None, 1]) 66 | # 2. Predictions: the array containing the evaluation score of each state 67 | # which is self.evaluation_fc2 68 | # 3-1. Value Loss function 69 | self.value_loss = tf.losses.mean_squared_error(self.labels, 70 | self.evaluation_fc2) 71 | # 3-2. 
Policy Loss function 72 | self.mcts_probs = tf.placeholder( 73 | tf.float32, shape=[None, board_height * board_width]) 74 | self.policy_loss = tf.negative(tf.reduce_mean( 75 | tf.reduce_sum(tf.multiply(self.mcts_probs, self.action_fc), 1))) 76 | # 3-3. L2 penalty (regularization) 77 | l2_penalty_beta = 1e-4 78 | vars = tf.trainable_variables() 79 | l2_penalty = l2_penalty_beta * tf.add_n( 80 | [tf.nn.l2_loss(v) for v in vars if 'bias' not in v.name.lower()]) 81 | # 3-4 Add up to be the Loss function 82 | self.loss = self.value_loss + self.policy_loss + l2_penalty 83 | 84 | # Define the optimizer we use for training 85 | self.learning_rate = tf.placeholder(tf.float32) 86 | self.optimizer = tf.train.AdamOptimizer( 87 | learning_rate=self.learning_rate).minimize(self.loss) 88 | 89 | # Make a session 90 | self.session = tf.Session() 91 | 92 | # calc policy entropy, for monitoring only 93 | self.entropy = tf.negative(tf.reduce_mean( 94 | tf.reduce_sum(tf.exp(self.action_fc) * self.action_fc, 1))) 95 | 96 | # Initialize variables 97 | init = tf.global_variables_initializer() 98 | self.session.run(init) 99 | 100 | # For saving and restoring 101 | self.saver = tf.train.Saver() 102 | if model_file is not None: 103 | self.restore_model(model_file) 104 | 105 | def policy_value(self, state_batch): 106 | """ 107 | input: a batch of states 108 | output: a batch of action probabilities and state values 109 | """ 110 | log_act_probs, value = self.session.run( 111 | [self.action_fc, self.evaluation_fc2], 112 | feed_dict={self.input_states: state_batch} 113 | ) 114 | act_probs = np.exp(log_act_probs) 115 | return act_probs, value 116 | 117 | def policy_value_fn(self, board): 118 | """ 119 | input: board 120 | output: a list of (action, probability) tuples for each available 121 | action and the score of the board state 122 | """ 123 | legal_positions = board.availables 124 | current_state = np.ascontiguousarray(board.current_state().reshape( 125 | -1, 4, self.board_width, self.board_height)) 126 | act_probs, value = self.policy_value(current_state) 127 | act_probs = zip(legal_positions, act_probs[0][legal_positions]) 128 | return act_probs, value 129 | 130 | def train_step(self, state_batch, mcts_probs, winner_batch, lr): 131 | """perform a training step""" 132 | winner_batch = np.reshape(winner_batch, (-1, 1)) 133 | loss, entropy, _ = self.session.run( 134 | [self.loss, self.entropy, self.optimizer], 135 | feed_dict={self.input_states: state_batch, 136 | self.mcts_probs: mcts_probs, 137 | self.labels: winner_batch, 138 | self.learning_rate: lr}) 139 | return loss, entropy 140 | 141 | def save_model(self, model_path): 142 | self.saver.save(self.session, model_path) 143 | 144 | def restore_model(self, model_path): 145 | self.saver.restore(self.session, model_path) 146 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | An implementation of the training pipeline of AlphaZero for Gomoku 4 | 5 | @author: Junxiao Song 6 | """ 7 | 8 | from __future__ import print_function 9 | import random 10 | import numpy as np 11 | from collections import defaultdict, deque 12 | from game import Board, Game 13 | from mcts_pure import MCTSPlayer as MCTS_Pure 14 | from mcts_alphaZero import MCTSPlayer 15 | from policy_value_net import PolicyValueNet # Theano and Lasagne 16 | # from policy_value_net_pytorch import PolicyValueNet # Pytorch 17 | # from 
/train.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | An implementation of the training pipeline of AlphaZero for Gomoku
4 |
5 | @author: Junxiao Song
6 | """
7 |
8 | from __future__ import print_function
9 | import random
10 | import numpy as np
11 | from collections import defaultdict, deque
12 | from game import Board, Game
13 | from mcts_pure import MCTSPlayer as MCTS_Pure
14 | from mcts_alphaZero import MCTSPlayer
15 | from policy_value_net import PolicyValueNet  # Theano and Lasagne
16 | # from policy_value_net_pytorch import PolicyValueNet  # Pytorch
17 | # from policy_value_net_tensorflow import PolicyValueNet  # Tensorflow
18 | # from policy_value_net_keras import PolicyValueNet  # Keras
19 |
20 |
21 | class TrainPipeline():
22 |     def __init__(self, init_model=None):
23 |         # params of the board and the game
24 |         self.board_width = 6
25 |         self.board_height = 6
26 |         self.n_in_row = 4
27 |         self.board = Board(width=self.board_width,
28 |                            height=self.board_height,
29 |                            n_in_row=self.n_in_row)
30 |         self.game = Game(self.board)
31 |         # training params
32 |         self.learn_rate = 2e-3
33 |         self.lr_multiplier = 1.0  # adaptively adjust the learning rate based on KL
34 |         self.temp = 1.0  # the temperature param
35 |         self.n_playout = 400  # num of simulations for each move
36 |         self.c_puct = 5
37 |         self.buffer_size = 10000
38 |         self.batch_size = 512  # mini-batch size for training
39 |         self.data_buffer = deque(maxlen=self.buffer_size)
40 |         self.play_batch_size = 1
41 |         self.epochs = 5  # num of train_steps for each update
42 |         self.kl_targ = 0.02
43 |         self.check_freq = 50
44 |         self.game_batch_num = 1500
45 |         self.best_win_ratio = 0.0
46 |         # num of simulations used for the pure mcts, which is used as
47 |         # the opponent to evaluate the trained policy
48 |         self.pure_mcts_playout_num = 1000
49 |         if init_model:
50 |             # start training from an initial policy-value net
51 |             self.policy_value_net = PolicyValueNet(self.board_width,
52 |                                                    self.board_height,
53 |                                                    model_file=init_model)
54 |         else:
55 |             # start training from a new policy-value net
56 |             self.policy_value_net = PolicyValueNet(self.board_width,
57 |                                                    self.board_height)
58 |         self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
59 |                                       c_puct=self.c_puct,
60 |                                       n_playout=self.n_playout,
61 |                                       is_selfplay=1)
62 |
63 |     def get_equi_data(self, play_data):
64 |         """augment the data set by rotation and flipping
65 |         play_data: [(state, mcts_prob, winner_z), ..., ...]
66 |         """
67 |         extend_data = []
68 |         for state, mcts_prob, winner in play_data:
69 |             for i in [1, 2, 3, 4]:
70 |                 # rotate counterclockwise
71 |                 equi_state = np.array([np.rot90(s, i) for s in state])
72 |                 equi_mcts_prob = np.rot90(np.flipud(
73 |                     mcts_prob.reshape(self.board_height, self.board_width)), i)
74 |                 extend_data.append((equi_state,
75 |                                     np.flipud(equi_mcts_prob).flatten(),
76 |                                     winner))
77 |                 # flip horizontally
78 |                 equi_state = np.array([np.fliplr(s) for s in equi_state])
79 |                 equi_mcts_prob = np.fliplr(equi_mcts_prob)
80 |                 extend_data.append((equi_state,
81 |                                     np.flipud(equi_mcts_prob).flatten(),
82 |                                     winner))
83 |         return extend_data
84 |
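`get_equi_data` above expands every self-play position into its eight board symmetries (four rotations, each with and without a horizontal flip), which is valid because Gomoku is invariant under these transforms. A standalone sketch of the same idea on a raw 2-D array (illustrative only, not part of train.py):

```
import numpy as np

# Illustrative only: the eight symmetries of a square board plane (the dihedral
# group of the square) that get_equi_data() above exploits for data augmentation.
def board_symmetries(plane):
    """Return the 8 rotated/flipped copies of a 2-D numpy array."""
    variants = []
    for k in range(4):                       # 0, 90, 180, 270 degree rotations
        rotated = np.rot90(plane, k)
        variants.append(rotated)
        variants.append(np.fliplr(rotated))  # plus the mirrored version
    return variants
```

The additional `np.flipud` calls in `get_equi_data` itself keep the flattened probability vector aligned with the orientation in which the state planes are encoded, so the augmented probabilities still correspond to the transformed states.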
85 |     def collect_selfplay_data(self, n_games=1):
86 |         """collect self-play data for training"""
87 |         for i in range(n_games):
88 |             winner, play_data = self.game.start_self_play(self.mcts_player,
89 |                                                           temp=self.temp)
90 |             play_data = list(play_data)[:]
91 |             self.episode_len = len(play_data)
92 |             # augment the data
93 |             play_data = self.get_equi_data(play_data)
94 |             self.data_buffer.extend(play_data)
95 |
96 |     def policy_update(self):
97 |         """update the policy-value net"""
98 |         mini_batch = random.sample(self.data_buffer, self.batch_size)
99 |         state_batch = [data[0] for data in mini_batch]
100 |         mcts_probs_batch = [data[1] for data in mini_batch]
101 |         winner_batch = [data[2] for data in mini_batch]
102 |         old_probs, old_v = self.policy_value_net.policy_value(state_batch)
103 |         for i in range(self.epochs):
104 |             loss, entropy = self.policy_value_net.train_step(
105 |                     state_batch,
106 |                     mcts_probs_batch,
107 |                     winner_batch,
108 |                     self.learn_rate*self.lr_multiplier)
109 |             new_probs, new_v = self.policy_value_net.policy_value(state_batch)
110 |             kl = np.mean(np.sum(old_probs * (
111 |                     np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)),
112 |                     axis=1)
113 |             )
114 |             if kl > self.kl_targ * 4:  # early stopping if D_KL diverges badly
115 |                 break
116 |         # adaptively adjust the learning rate
117 |         if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1:
118 |             self.lr_multiplier /= 1.5
119 |         elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
120 |             self.lr_multiplier *= 1.5
121 |
122 |         explained_var_old = (1 -
123 |                              np.var(np.array(winner_batch) - old_v.flatten()) /
124 |                              np.var(np.array(winner_batch)))
125 |         explained_var_new = (1 -
126 |                              np.var(np.array(winner_batch) - new_v.flatten()) /
127 |                              np.var(np.array(winner_batch)))
128 |         print(("kl:{:.5f},"
129 |                "lr_multiplier:{:.3f},"
130 |                "loss:{},"
131 |                "entropy:{},"
132 |                "explained_var_old:{:.3f},"
133 |                "explained_var_new:{:.3f}"
134 |                ).format(kl,
135 |                         self.lr_multiplier,
136 |                         loss,
137 |                         entropy,
138 |                         explained_var_old,
139 |                         explained_var_new))
140 |         return loss, entropy
141 |
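`policy_update` above guards each update with a KL check: it stops the inner epoch loop early once the mean KL divergence between the pre-update and post-update move distributions exceeds 4 * kl_targ, and afterwards scales `lr_multiplier` down when the policy moved too far, or up when it barely moved. Distilled into a standalone form (illustrative only, not part of train.py):

```
import numpy as np

# Illustrative only: the KL check and learning-rate adjustment used in
# policy_update() above, on hypothetical old/new probability batches.
def mean_kl(old_probs, new_probs, eps=1e-10):
    """Average KL(old || new) over a batch of move distributions."""
    return np.mean(np.sum(old_probs * (np.log(old_probs + eps) -
                                       np.log(new_probs + eps)), axis=1))

def adjust_lr_multiplier(kl, kl_targ, lr_multiplier):
    """Shrink the step size if the policy moved too far, grow it if too little."""
    if kl > kl_targ * 2 and lr_multiplier > 0.1:
        lr_multiplier /= 1.5
    elif kl < kl_targ / 2 and lr_multiplier < 10:
        lr_multiplier *= 1.5
    return lr_multiplier
```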
142 |     def policy_evaluate(self, n_games=10):
143 |         """
144 |         Evaluate the trained policy by playing against the pure MCTS player
145 |         Note: this is only for monitoring the progress of training
146 |         """
147 |         current_mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
148 |                                          c_puct=self.c_puct,
149 |                                          n_playout=self.n_playout)
150 |         pure_mcts_player = MCTS_Pure(c_puct=5,
151 |                                      n_playout=self.pure_mcts_playout_num)
152 |         win_cnt = defaultdict(int)
153 |         for i in range(n_games):
154 |             winner = self.game.start_play(current_mcts_player,
155 |                                           pure_mcts_player,
156 |                                           start_player=i % 2,
157 |                                           is_shown=0)
158 |             win_cnt[winner] += 1
159 |         win_ratio = 1.0*(win_cnt[1] + 0.5*win_cnt[-1]) / n_games
160 |         print("num_playouts:{}, win: {}, lose: {}, tie:{}".format(
161 |                 self.pure_mcts_playout_num,
162 |                 win_cnt[1], win_cnt[2], win_cnt[-1]))
163 |         return win_ratio
164 |
165 |     def run(self):
166 |         """run the training pipeline"""
167 |         try:
168 |             for i in range(self.game_batch_num):
169 |                 self.collect_selfplay_data(self.play_batch_size)
170 |                 print("batch i:{}, episode_len:{}".format(
171 |                         i+1, self.episode_len))
172 |                 if len(self.data_buffer) > self.batch_size:
173 |                     loss, entropy = self.policy_update()
174 |                 # check the performance of the current model,
175 |                 # and save the model params
176 |                 if (i+1) % self.check_freq == 0:
177 |                     print("current self-play batch: {}".format(i+1))
178 |                     win_ratio = self.policy_evaluate()
179 |                     self.policy_value_net.save_model('./current_policy.model')
180 |                     if win_ratio > self.best_win_ratio:
181 |                         print("New best policy!!!!!!!!")
182 |                         self.best_win_ratio = win_ratio
183 |                         # update the best_policy
184 |                         self.policy_value_net.save_model('./best_policy.model')
185 |                         if (self.best_win_ratio == 1.0 and
186 |                                 self.pure_mcts_playout_num < 5000):
187 |                             self.pure_mcts_playout_num += 1000
188 |                             self.best_win_ratio = 0.0
189 |         except KeyboardInterrupt:
190 |             print('\n\rquit')
191 |
192 |
193 | if __name__ == '__main__':
194 |     training_pipeline = TrainPipeline()
195 |     training_pipeline.run()
196 |
--------------------------------------------------------------------------------
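The `__main__` block above always starts a fresh `TrainPipeline`. Since `__init__` accepts an `init_model` path and `run()` periodically writes `./current_policy.model` and `./best_policy.model`, an interrupted run can be resumed from the last checkpoint. A hedged sketch of such an entry point (not part of the repository, and it assumes the checkpoint was produced by the same policy_value_net backend selected in the imports):

```
# Illustrative resume script, not part of train.py: continue training from the
# most recent checkpoint if one exists, otherwise start from scratch.
import os
from train import TrainPipeline

if __name__ == '__main__':
    checkpoint = './current_policy.model'
    init_model = checkpoint if os.path.exists(checkpoint) else None
    training_pipeline = TrainPipeline(init_model=init_model)
    training_pipeline.run()
```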