├── README.md ├── example.gif ├── game.py ├── human_play.py ├── mcts_alphaZero.py ├── mcts_pure.py ├── model ├── checkpoint ├── tf_policy_11_11_5_model.data-00000-of-00001 ├── tf_policy_11_11_5_model.index ├── tf_policy_11_11_5_model.meta ├── tf_policy_8_8_5_model.data-00000-of-00001 ├── tf_policy_8_8_5_model.index └── tf_policy_8_8_5_model.meta ├── tf_policy_value_net.py └── train.py /README.md: -------------------------------------------------------------------------------- 1 | # AlphaZero_Gomoku-tensorflow 2 | 3 | Forked from [junxiaosong/AlphaZero_Gomoku](https://github.com/junxiaosong/AlphaZero_Gomoku) with some changes: 4 | 5 | * rewrote the network code with TensorFlow 6 | * trained on an 11 * 11 board 7 | * added a GUI 8 | 9 | ## Usage 10 | To play against the AI: 11 | 12 | $ python human_play.py 13 | 14 | To train the model: 15 | 16 | $ python train.py 17 | 18 | 19 | ### Example Game 20 | 21 | ![Example](https://github.com/zouyih/AlphaZero_Gomoku-tensorflow/blob/master/example.gif) 22 | 23 | There is another interesting reinforcement-learning implementation: [DQN-tensorflow-gluttonous_snake](https://github.com/zouyih/DQN-tensorflow-gluttonous_snake) -------------------------------------------------------------------------------- /example.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zouyih/AlphaZero_Gomoku-tensorflow/45aa66b01df1f4618a17d3055bfe183828bec4f7/example.gif -------------------------------------------------------------------------------- /game.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tkinter 3 | 4 | class Board(object): 5 | """ 6 | board for the game 7 | """ 8 | 9 | def __init__(self, **kwargs): 10 | self.width = int(kwargs.get('width', 8)) 11 | self.height = int(kwargs.get('height', 8)) 12 | self.states = {} # board states, key: move (location on the board), value: player (piece type) 13 | self.n_in_row = int(kwargs.get('n_in_row', 5)) # how many pieces in a row are needed to win 14 | self.players = [1, 2] # player1 and player2 15 | 16 | def init_board(self, start_player=0): 17 | if self.width < self.n_in_row or self.height < self.n_in_row: 18 | raise Exception('board width and height can not be less than %d' % self.n_in_row) 19 | self.current_player = self.players[start_player] # start player 20 | self.availables = list(range(self.width * self.height)) # available moves 21 | self.states = {} # board states, key: move (location on the board), value: player (piece type) 22 | self.last_move = -1 23 | 24 | def move_to_location(self, move): 25 | """ 26 | 3*3 board's moves look like: 27 | 6 7 8 28 | 3 4 5 29 | 0 1 2 30 | and move 5's location is (1,2) 31 | """ 32 | h = move // self.width 33 | w = move % self.width 34 | return [h, w] 35 | 36 | def location_to_move(self, location): 37 | if(len(location) != 2): 38 | return -1 39 | h = location[0] 40 | w = location[1] 41 | move = h * self.width + w 42 | if(move not in range(self.width * self.height)): 43 | return -1 44 | return move 45 | 46 | def current_state(self): 47 | """return the board state from the perspective of the current player 48 | shape: 4*width*height""" 49 | 50 | square_state = np.zeros((4, self.width, self.height)) 51 | if self.states: 52 | moves, players = np.array(list(zip(*self.states.items()))) 53 | move_curr = moves[players == self.current_player] 54 | move_oppo = moves[players != self.current_player] 55 | square_state[0][move_curr // self.width, move_curr % 
self.height] = 1.0 56 | square_state[1][move_oppo // self.width, move_oppo % self.height] = 1.0 57 | square_state[2][self.last_move //self.width, self.last_move % self.height] = 1.0 # last move indication 58 | if len(self.states)%2 == 0: 59 | square_state[3][:,:] = 1.0 60 | 61 | return square_state[:,::-1,:] 62 | 63 | def do_move(self, move): 64 | self.states[move] = self.current_player 65 | self.availables.remove(move) 66 | self.current_player = self.players[0] if self.current_player == self.players[1] else self.players[1] 67 | self.last_move = move 68 | 69 | def has_a_winner(self): 70 | width = self.width 71 | height = self.height 72 | states = self.states 73 | n = self.n_in_row 74 | 75 | moved = list(set(range(width * height)) - set(self.availables)) 76 | if(len(moved) < self.n_in_row*2 - 1): 77 | return False, -1 78 | 79 | for m in moved: 80 | h = m // width 81 | w = m % width 82 | player = states[m] 83 | 84 | if (w in range(width - n + 1) and 85 | len(set(states.get(i, -1) for i in range(m, m + n))) == 1): 86 | return True, player 87 | 88 | if (h in range(height - n + 1) and 89 | len(set(states.get(i, -1) for i in range(m, m + n * width, width))) == 1): 90 | return True, player 91 | 92 | if (w in range(width - n + 1) and h in range(height - n + 1) and 93 | len(set(states.get(i, -1) for i in range(m, m + n * (width + 1), width + 1))) == 1): 94 | return True, player 95 | 96 | if (w in range(n - 1, width) and h in range(height - n + 1) and 97 | len(set(states.get(i, -1) for i in range(m, m + n * (width - 1), width - 1))) == 1): 98 | return True, player 99 | 100 | return False, -1 101 | 102 | def game_end(self): 103 | """Check whether the game is ended or not""" 104 | win, winner = self.has_a_winner() 105 | if win: 106 | return True, winner 107 | elif not len(self.availables):# 108 | return True, -1 109 | return False, -1 110 | 111 | def get_current_player(self): 112 | return self.current_player 113 | 114 | class Point: 115 | 116 | def __init__(self, x, y): 117 | self.x = x; 118 | self.y = y; 119 | self.pixel_x = 30 + 30 * self.x 120 | self.pixel_y = 30 + 30 * self.y 121 | 122 | class Game(object): 123 | """ 124 | game server 125 | """ 126 | def __init__(self, board, **kwargs): 127 | self.board = board 128 | 129 | def click1(self, event): #click1 because keyword repetition 130 | 131 | current_player = self.board.get_current_player() 132 | if current_player == 1: 133 | i = (event.x) // 30 134 | j = (event.y) // 30 135 | ri = (event.x) % 30 136 | rj = (event.y) % 30 137 | i = i-1 if ri<15 else i 138 | j = j-1 if rj<15 else j 139 | move = self.board.location_to_move((i, j)) 140 | if move in self.board.availables: 141 | self.cv.create_oval(self.chess_board_points[i][j].pixel_x-10, self.chess_board_points[i][j].pixel_y-10, self.chess_board_points[i][j].pixel_x+10, self.chess_board_points[i][j].pixel_y+10, fill='black') 142 | self.board.do_move(move) 143 | 144 | def run(self): 145 | current_player = self.board.get_current_player() 146 | 147 | end, winner = self.board.game_end() 148 | 149 | if current_player == 2 and not end: 150 | player_in_turn = self.players[current_player] 151 | move = player_in_turn.get_action(self.board) 152 | self.board.do_move(move) 153 | i, j = self.board.move_to_location(move) 154 | self.cv.create_oval(self.chess_board_points[i][j].pixel_x-10, self.chess_board_points[i][j].pixel_y-10, self.chess_board_points[i][j].pixel_x+10, self.chess_board_points[i][j].pixel_y+10, fill='white') 155 | 156 | end, winner = self.board.game_end() 157 | 158 | if end: 159 | if winner != 
-1: 160 | self.cv.create_text(self.board.width*15+15, self.board.height*30+30, text="Game over. Winner is {}".format(self.players[winner])) 161 | self.cv.unbind('<Button-1>') 162 | else: 163 | self.cv.create_text(self.board.width*15+15, self.board.height*30+30, text="Game end. Tie") 164 | 165 | return winner 166 | else: 167 | self.cv.after(100, self.run) 168 | 169 | def graphic(self, board, player1, player2): 170 | """ 171 | Draw the board and show game info 172 | """ 173 | width = board.width 174 | height = board.height 175 | 176 | p1, p2 = self.board.players 177 | player1.set_player_ind(p1) 178 | player2.set_player_ind(p2) 179 | self.players = {p1: player1, p2: player2} 180 | 181 | window = tkinter.Tk() 182 | self.cv = tkinter.Canvas(window, height=height*30+60, width=width*30 + 30, bg = 'white') 183 | self.chess_board_points = [[None for i in range(height)] for j in range(width)] 184 | 185 | for i in range(width): 186 | for j in range(height): 187 | self.chess_board_points[i][j] = Point(i, j); 188 | for i in range(width): # vertical lines 189 | self.cv.create_line(self.chess_board_points[i][0].pixel_x, self.chess_board_points[i][0].pixel_y, self.chess_board_points[i][height-1].pixel_x, self.chess_board_points[i][height-1].pixel_y) 190 | 191 | for j in range(height): # horizontal lines 192 | self.cv.create_line(self.chess_board_points[0][j].pixel_x, self.chess_board_points[0][j].pixel_y, self.chess_board_points[width-1][j].pixel_x, self.chess_board_points[width-1][j].pixel_y) 193 | 194 | self.button = tkinter.Button(window, text="start game!", command=self.run) 195 | self.cv.bind('<Button-1>', self.click1) 196 | self.cv.pack() 197 | self.button.pack() 198 | window.mainloop() 199 | 200 | def start_play(self, player1, player2, start_player=0, is_shown=1): 201 | """ 202 | start a game between two players 203 | """ 204 | if start_player not in (0,1): 205 | raise Exception('start_player should be 0 (player1 first) or 1 (player2 first)') 206 | self.board.init_board(start_player) 207 | 208 | if is_shown: 209 | self.graphic(self.board, player1, player2) 210 | else: 211 | p1, p2 = self.board.players 212 | player1.set_player_ind(p1) 213 | player2.set_player_ind(p2) 214 | players = {p1: player1, p2: player2} 215 | while(1): 216 | current_player = self.board.get_current_player() 217 | print(current_player) 218 | player_in_turn = players[current_player] 219 | move = player_in_turn.get_action(self.board) 220 | self.board.do_move(move) 221 | if is_shown: 222 | self.graphic(self.board, player1.player, player2.player) 223 | end, winner = self.board.game_end() 224 | if end: 225 | return winner 226 | 227 | def start_self_play(self, player, is_shown=0, temp=1e-3): 228 | """ start a self-play game using an MCTS player, reusing the search tree, 229 | and store the self-play data: (state, mcts_probs, z) 230 | """ 231 | self.board.init_board() 232 | p1, p2 = self.board.players 233 | states, mcts_probs, current_players = [], [], [] 234 | while(1): 235 | move, move_probs = player.get_action(self.board, temp=temp, return_prob=1) 236 | # store the data 237 | states.append(self.board.current_state()) 238 | mcts_probs.append(move_probs) 239 | current_players.append(self.board.current_player) 240 | # perform a move 241 | self.board.do_move(move) 242 | end, winner = self.board.game_end() 243 | if end: 244 | # winner from the perspective of the current player of each state 245 | winners_z = np.zeros(len(current_players)) 246 | if winner != -1: 247 | winners_z[np.array(current_players) == winner] = 1.0 248 | winners_z[np.array(current_players) != winner] = 
-1.0 249 | #reset MCTS root node 250 | player.reset_player() 251 | if is_shown: 252 | if winner != -1: 253 | print("Game end. Winner is player:", winner) 254 | else: 255 | print("Game end. Tie") 256 | return winner, zip(states, mcts_probs, winners_z) 257 | -------------------------------------------------------------------------------- /human_play.py: -------------------------------------------------------------------------------- 1 | 2 | from game import Board, Game 3 | 4 | from tf_policy_value_net import PolicyValueNet 5 | from mcts_alphaZero import MCTSPlayer 6 | 7 | 8 | class Human(object): 9 | """ 10 | human player 11 | """ 12 | 13 | def __init__(self): 14 | self.player = None 15 | 16 | def set_player_ind(self, p): 17 | self.player = p 18 | 19 | def get_action(self, board): 20 | try: 21 | location = input("Your move: ") 22 | if isinstance(location, str): 23 | location = [int(n, 10) for n in location.split(",")] # for python3 24 | move = board.location_to_move(location) 25 | except Exception as e: 26 | move = -1 27 | if move == -1 or move not in board.availables: 28 | print("invalid move") 29 | move = self.get_action(board) 30 | return move 31 | 32 | def __str__(self): 33 | return "Human {}".format(self.player) 34 | 35 | 36 | def run(): 37 | n_row = 5 38 | width, height = 11, 11 39 | 40 | try: 41 | board = Board(width=width, height=height, n_in_row=n_row) 42 | game = Game(board) 43 | 44 | ################ human VS AI ################### 45 | 46 | best_policy = PolicyValueNet(width, height, n_row) 47 | mcts_player = MCTSPlayer(best_policy.policy_value_fn, c_puct=5, n_playout=400) # set larger n_playout for better performance 48 | 49 | human = Human() 50 | 51 | # set start_player=0 for human first 52 | game.start_play(human, mcts_player, start_player=1, is_shown=1) 53 | except KeyboardInterrupt: 54 | print('\n\rquit') 55 | 56 | if __name__ == '__main__': 57 | run() 58 | 59 | 60 | -------------------------------------------------------------------------------- /mcts_alphaZero.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import copy 4 | 5 | 6 | def softmax(x): 7 | probs = np.exp(x - np.max(x)) 8 | probs /= np.sum(probs) 9 | return probs 10 | 11 | class TreeNode(object): 12 | """A node in the MCTS tree. Each node keeps track of its own value Q, prior probability P, and 13 | its visit-count-adjusted prior score u. 14 | """ 15 | 16 | def __init__(self, parent, prior_p): 17 | self._parent = parent 18 | self._children = {} # a map from action to TreeNode 19 | self._n_visits = 0 20 | self._Q = 0 21 | self._u = 0 22 | self._P = prior_p 23 | 24 | def expand(self, action_priors): 25 | """Expand tree by creating new children. 26 | action_priors -- output from policy function - a list of tuples of actions 27 | and their prior probability according to the policy function. 28 | """ 29 | for action, prob in action_priors: 30 | if action not in self._children: 31 | self._children[action] = TreeNode(self, prob) 32 | 33 | def select(self, c_puct): 34 | """Select action among children that gives maximum action value, Q plus bonus u(P). 35 | Returns: 36 | A tuple of (action, next_node) 37 | """ 38 | return max(self._children.items(), key=lambda act_node: act_node[1].get_value(c_puct)) 39 | 40 | def update(self, leaf_value): 41 | """Update node values from leaf evaluation. 42 | """ 43 | # Count visit. 44 | self._n_visits += 1 45 | # Update Q, a running average of values for all visits. 
46 | self._Q += 1.0*(leaf_value - self._Q) / self._n_visits 47 | 48 | def update_recursive(self, leaf_value): 49 | """Like a call to update(), but applied recursively for all ancestors. 50 | """ 51 | # If it is not root, this node's parent should be updated first. 52 | 53 | if self._parent: 54 | self._parent.update_recursive(-leaf_value) 55 | self.update(leaf_value) 56 | 57 | def get_value(self, c_puct): 58 | """Calculate and return the value for this node: a combination of leaf evaluations, Q, and 59 | this node's prior adjusted for its visit count, u 60 | c_puct -- a number in (0, inf) controlling the relative impact of values, Q, and 61 | prior probability, P, on this node's score. 62 | """ 63 | self._u = c_puct * self._P * np.sqrt(self._parent._n_visits) / (1 + self._n_visits) 64 | return self._Q + self._u 65 | 66 | def is_leaf(self): 67 | """Check if leaf node (i.e. no nodes below this have been expanded). 68 | """ 69 | return self._children == {} 70 | 71 | def is_root(self): 72 | return self._parent is None 73 | 74 | 75 | class MCTS(object): 76 | """A simple implementation of Monte Carlo Tree Search. 77 | """ 78 | 79 | def __init__(self, policy_value_fn, c_puct=5, n_playout=10000): 80 | """Arguments: 81 | policy_value_fn -- a function that takes in a board state and outputs a list of (action, probability) 82 | tuples and also a score in [-1, 1] (i.e. the expected value of the end game score from 83 | the current player's perspective) for the current player. 84 | c_puct -- a number in (0, inf) that controls how quickly exploration converges to the 85 | maximum-value policy, where a higher value means relying on the prior more 86 | """ 87 | self._root = TreeNode(None, 1.0) 88 | self._policy = policy_value_fn 89 | self._c_puct = c_puct 90 | self._n_playout = n_playout 91 | 92 | def _playout(self, state): 93 | """Run a single playout from the root to the leaf, getting a value at the leaf and 94 | propagating it back through its parents. State is modified in-place, so a copy must be 95 | provided. 96 | Arguments: 97 | state -- a copy of the state. 98 | """ 99 | node = self._root 100 | while(1): 101 | if node.is_leaf(): 102 | break 103 | # Greedily select next move. 104 | action, node = node.select(self._c_puct) 105 | state.do_move(action) 106 | 107 | # Evaluate the leaf using a network which outputs a list of (action, probability) 108 | # tuples p and also a score v in [-1, 1] for the current player. 109 | action_probs, leaf_value = self._policy(state) 110 | 111 | # Check for end of game. 112 | end, winner = state.game_end() 113 | if not end: 114 | node.expand(action_probs) 115 | else: 116 | # for end state,return the "true" leaf_value 117 | if winner == -1: # tie 118 | leaf_value = 0.0 119 | else: 120 | leaf_value = 1.0 if winner == state.get_current_player() else -1.0 121 | 122 | # Update value and visit count of nodes in this traversal. 123 | node.update_recursive(-leaf_value) 124 | 125 | def get_move_probs(self, state, temp=1e-3): 126 | """Runs all playouts sequentially and returns the available actions and their corresponding probabilities 127 | Arguments: 128 | state -- the current state, including both game state and the current player. 
129 | temp -- temperature parameter in (0, 1] that controls the level of exploration 130 | Returns: 131 | the available actions and the corresponding probabilities 132 | """ 133 | for n in range(self._n_playout): 134 | state_copy = copy.deepcopy(state) 135 | self._playout(state_copy) 136 | 137 | # calc the move probabilities based on the visit counts at the root node 138 | act_visits = [(act, node._n_visits) for act, node in self._root._children.items()] 139 | acts, visits = zip(*act_visits) 140 | act_probs = softmax(1.0/temp * np.log(visits)) 141 | 142 | return acts, act_probs 143 | 144 | def update_with_move(self, last_move): 145 | """Step forward in the tree, keeping everything we already know about the subtree. 146 | """ 147 | if last_move in self._root._children: 148 | self._root = self._root._children[last_move] 149 | self._root._parent = None 150 | else: 151 | self._root = TreeNode(None, 1.0) 152 | 153 | def __str__(self): 154 | return "MCTS" 155 | 156 | 157 | class MCTSPlayer(object): 158 | """AI player based on MCTS""" 159 | def __init__(self, policy_value_function, c_puct=5, n_playout=2000, is_selfplay=0): 160 | self.mcts = MCTS(policy_value_function, c_puct, n_playout) 161 | self._is_selfplay = is_selfplay 162 | 163 | def set_player_ind(self, p): 164 | self.player = p 165 | 166 | def reset_player(self): 167 | self.mcts.update_with_move(-1) 168 | 169 | def get_action(self, board, temp=1e-3, return_prob=0): 170 | sensible_moves = board.availables 171 | move_probs = np.zeros(board.width*board.height) # the pi vector returned by MCTS as in the alphaGo Zero paper 172 | if len(sensible_moves) > 0: 173 | acts, probs = self.mcts.get_move_probs(board, temp) 174 | move_probs[list(acts)] = probs 175 | if self._is_selfplay: 176 | # add Dirichlet Noise for exploration (needed for self-play training) 177 | move = np.random.choice(acts, p=0.75*probs + 0.25*np.random.dirichlet(0.3*np.ones(len(probs)))) 178 | self.mcts.update_with_move(move) # update the root node and reuse the search tree 179 | else: 180 | # with the default temp=1e-3, this is almost equivalent to choosing the move with the highest prob 181 | move = np.random.choice(acts, p=probs) 182 | # reset the root node 183 | self.mcts.update_with_move(-1) 184 | 185 | if return_prob: 186 | return move, move_probs 187 | else: 188 | return move 189 | else: 190 | print("WARNING: the board is full") 191 | 192 | def __str__(self): 193 | return "MCTS {}".format(self.player) -------------------------------------------------------------------------------- /mcts_pure.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sat Feb 10 16:37:43 2018 4 | 5 | @author: zou 6 | """ 7 | 8 | import numpy as np 9 | import copy 10 | from operator import itemgetter 11 | 12 | def rollout_policy_fn(board): 13 | """rollout_policy_fn -- a coarse, fast version of policy_fn used in the rollout phase.""" 14 | # rollout randomly 15 | action_probs = np.random.rand(len(board.availables)) 16 | return zip(board.availables, action_probs) 17 | 18 | def policy_value_fn(board): 19 | """a function that takes in a state and outputs a list of (action, probability) 20 | tuples and a score for the state""" 21 | # return uniform probabilities and 0 score for pure MCTS 22 | action_probs = np.ones(len(board.availables))/len(board.availables) 23 | return zip(board.availables, action_probs), 0 24 | 25 | class TreeNode(object): 26 | """A node in the MCTS tree. 
Each node keeps track of its own value Q, prior probability P, and 27 | its visit-count-adjusted prior score u. 28 | """ 29 | 30 | def __init__(self, parent, prior_p): 31 | self._parent = parent 32 | self._children = {} # a map from action to TreeNode 33 | self._n_visits = 0 34 | self._Q = 0 35 | self._u = 0 36 | self._P = prior_p 37 | 38 | def expand(self, action_priors): 39 | """Expand tree by creating new children. 40 | action_priors -- output from policy function - a list of tuples of actions 41 | and their prior probability according to the policy function. 42 | """ 43 | for action, prob in action_priors: 44 | if action not in self._children: 45 | self._children[action] = TreeNode(self, prob) 46 | 47 | def select(self, c_puct): 48 | """Select action among children that gives maximum action value, Q plus bonus u(P). 49 | Returns: 50 | A tuple of (action, next_node) 51 | """ 52 | return max(self._children.items(), key=lambda act_node: act_node[1].get_value(c_puct)) 53 | 54 | def update(self, leaf_value): 55 | """Update node values from leaf evaluation. 56 | Arguments: 57 | leaf_value -- the value of subtree evaluation from the current player's perspective. 58 | """ 59 | # Count visit. 60 | self._n_visits += 1 61 | # Update Q, a running average of values for all visits. 62 | self._Q += 1.0*(leaf_value - self._Q) / self._n_visits 63 | 64 | def update_recursive(self, leaf_value): 65 | """Like a call to update(), but applied recursively for all ancestors. 66 | """ 67 | # If it is not root, this node's parent should be updated first. 68 | if self._parent: 69 | self._parent.update_recursive(-leaf_value) 70 | self.update(leaf_value) 71 | 72 | def get_value(self, c_puct): 73 | """Calculate and return the value for this node: a combination of leaf evaluations, Q, and 74 | this node's prior adjusted for its visit count, u 75 | c_puct -- a number in (0, inf) controlling the relative impact of values, Q, and 76 | prior probability, P, on this node's score. 77 | """ 78 | self._u = c_puct * self._P * np.sqrt(self._parent._n_visits) / (1 + self._n_visits) 79 | return self._Q + self._u 80 | 81 | def is_leaf(self): 82 | """Check if leaf node (i.e. no nodes below this have been expanded). 83 | """ 84 | return self._children == {} 85 | 86 | def is_root(self): 87 | return self._parent is None 88 | 89 | 90 | class MCTS(object): 91 | """A simple implementation of Monte Carlo Tree Search. 92 | """ 93 | 94 | def __init__(self, policy_value_fn, c_puct=5, n_playout=10000): 95 | """Arguments: 96 | policy_value_fn -- a function that takes in a board state and outputs a list of (action, probability) 97 | tuples and also a score in [-1, 1] (i.e. the expected value of the end game score from 98 | the current player's perspective) for the current player. 99 | c_puct -- a number in (0, inf) that controls how quickly exploration converges to the 100 | maximum-value policy, where a higher value means relying on the prior more 101 | """ 102 | self._root = TreeNode(None, 1.0) 103 | self._policy = policy_value_fn 104 | self._c_puct = c_puct 105 | self._n_playout = n_playout 106 | 107 | def _playout(self, state): 108 | """Run a single playout from the root to the leaf, getting a value at the leaf and 109 | propagating it back through its parents. State is modified in-place, so a copy must be 110 | provided. 111 | Arguments: 112 | state -- a copy of the state. 113 | """ 114 | node = self._root 115 | while(1): 116 | if node.is_leaf(): 117 | 118 | break 119 | # Greedily select next move. 
120 | action, node = node.select(self._c_puct) 121 | state.do_move(action) 122 | 123 | action_probs, _ = self._policy(state) 124 | # Check for end of game 125 | end, winner = state.game_end() 126 | if not end: 127 | node.expand(action_probs) 128 | # Evaluate the leaf node by random rollout 129 | leaf_value = self._evaluate_rollout(state) 130 | # Update value and visit count of nodes in this traversal. 131 | node.update_recursive(-leaf_value) 132 | 133 | def _evaluate_rollout(self, state, limit=1000): 134 | """Use the rollout policy to play until the end of the game, returning +1 if the current 135 | player wins, -1 if the opponent wins, and 0 if it is a tie. 136 | """ 137 | player = state.get_current_player() 138 | for i in range(limit): 139 | end, winner = state.game_end() 140 | if end: 141 | break 142 | action_probs = rollout_policy_fn(state) 143 | max_action = max(action_probs, key=itemgetter(1))[0] 144 | state.do_move(max_action) 145 | else: 146 | # If no break from the loop, issue a warning. 147 | print("WARNING: rollout reached move limit") 148 | if winner == -1: # tie 149 | return 0 150 | else: 151 | return 1 if winner == player else -1 152 | 153 | def get_move(self, state): 154 | """Runs all playouts sequentially and returns the most visited action. 155 | Arguments: 156 | state -- the current state, including both game state and the current player. 157 | Returns: 158 | the selected action 159 | """ 160 | for n in range(self._n_playout): 161 | state_copy = copy.deepcopy(state) 162 | self._playout(state_copy) 163 | return max(self._root._children.items(), key=lambda act_node: act_node[1]._n_visits)[0] 164 | 165 | def update_with_move(self, last_move): 166 | """Step forward in the tree, keeping everything we already know about the subtree. 167 | """ 168 | if last_move in self._root._children: 169 | self._root = self._root._children[last_move] 170 | self._root._parent = None 171 | else: 172 | self._root = TreeNode(None, 1.0) 173 | 174 | def __str__(self): 175 | return "MCTS" 176 | 177 | 178 | class MCTSPlayer(object): 179 | """AI player based on MCTS""" 180 | def __init__(self, c_puct=5, n_playout=2000): 181 | self.mcts = MCTS(policy_value_fn, c_puct, n_playout) 182 | 183 | def set_player_ind(self, p): 184 | self.player = p 185 | 186 | def reset_player(self): 187 | self.mcts.update_with_move(-1) 188 | 189 | def get_action(self, board): 190 | sensible_moves = board.availables 191 | if len(sensible_moves) > 0: 192 | move = self.mcts.get_move(board) 193 | self.mcts.update_with_move(-1) 194 | return move 195 | else: 196 | print("WARNING: the board is full") 197 | 198 | def __str__(self): 199 | return "MCTS {}".format(self.player) -------------------------------------------------------------------------------- /model/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "tf_policy_11_11_5_model" 2 | all_model_checkpoint_paths: "tf_policy_11_11_5_model" 3 | -------------------------------------------------------------------------------- /model/tf_policy_11_11_5_model.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zouyih/AlphaZero_Gomoku-tensorflow/45aa66b01df1f4618a17d3055bfe183828bec4f7/model/tf_policy_11_11_5_model.data-00000-of-00001 -------------------------------------------------------------------------------- /model/tf_policy_11_11_5_model.index: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zouyih/AlphaZero_Gomoku-tensorflow/45aa66b01df1f4618a17d3055bfe183828bec4f7/model/tf_policy_11_11_5_model.index -------------------------------------------------------------------------------- /model/tf_policy_11_11_5_model.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zouyih/AlphaZero_Gomoku-tensorflow/45aa66b01df1f4618a17d3055bfe183828bec4f7/model/tf_policy_11_11_5_model.meta -------------------------------------------------------------------------------- /model/tf_policy_8_8_5_model.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zouyih/AlphaZero_Gomoku-tensorflow/45aa66b01df1f4618a17d3055bfe183828bec4f7/model/tf_policy_8_8_5_model.data-00000-of-00001 -------------------------------------------------------------------------------- /model/tf_policy_8_8_5_model.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zouyih/AlphaZero_Gomoku-tensorflow/45aa66b01df1f4618a17d3055bfe183828bec4f7/model/tf_policy_8_8_5_model.index -------------------------------------------------------------------------------- /model/tf_policy_8_8_5_model.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zouyih/AlphaZero_Gomoku-tensorflow/45aa66b01df1f4618a17d3055bfe183828bec4f7/model/tf_policy_8_8_5_model.meta -------------------------------------------------------------------------------- /tf_policy_value_net.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Fri Feb 9 12:29:58 2018 4 | 5 | @author: zou 6 | """ 7 | import tensorflow as tf 8 | import os 9 | 10 | 11 | class PolicyValueNet(): 12 | """policy-value network """ 13 | def __init__(self, board_width, board_height, n_in_row): 14 | tf.reset_default_graph() 15 | self.board_width = board_width 16 | self.board_height = board_height 17 | self.model_file = './model/tf_policy_{}_{}_{}_model'.format(board_width, board_height, n_in_row) 18 | self.sess = tf.Session() 19 | self.l2_const = 1e-4 # coef of l2 penalty 20 | self._create_policy_value_net() 21 | self._loss_train_op() 22 | self.saver = tf.train.Saver() 23 | self.restore_model() 24 | 25 | def _create_policy_value_net(self): 26 | """create the policy value network """ 27 | with tf.name_scope("inputs"): 28 | self.state_input = tf.placeholder(tf.float32, shape=[None, 4, self.board_width, self.board_height], name="state") 29 | # self.state = tf.transpose(self.state_input, [0, 2, 3, 1]) 30 | 31 | self.winner = tf.placeholder(tf.float32, shape=[None], name="winner") 32 | self.winner_reshape = tf.reshape(self.winner, [-1,1]) 33 | self.mcts_probs = tf.placeholder(tf.float32, shape=[None, self.board_width*self.board_height], name="mcts_probs") 34 | 35 | # conv layers 36 | conv1 = tf.layers.conv2d(self.state_input, filters=32, kernel_size=3, 37 | strides=1, padding="SAME", data_format='channels_first', 38 | activation=tf.nn.relu, name="conv1") 39 | conv2 = tf.layers.conv2d(conv1, filters=64, kernel_size=3, 40 | strides=1, padding="SAME", data_format='channels_first', 41 | activation=tf.nn.relu, name="conv2") 42 | conv3 = tf.layers.conv2d(conv2, filters=128, kernel_size=3, 43 | strides=1, padding="SAME", data_format='channels_first', 44 | activation=tf.nn.relu, name="conv3") 45 | 46 | # action policy 
layers 47 | policy_net = tf.layers.conv2d(conv3, filters=4, kernel_size=1, 48 | strides=1, padding="SAME", data_format='channels_first', 49 | activation=tf.nn.relu, name="policy_net") 50 | policy_net_flat = tf.reshape(policy_net, shape=[-1, 4*self.board_width*self.board_height]) 51 | self.policy_net_out = tf.layers.dense(policy_net_flat, self.board_width*self.board_height, name="output") 52 | self.action_probs = tf.nn.softmax(self.policy_net_out, name="policy_net_proba") 53 | 54 | # state value layers 55 | value_net = tf.layers.conv2d(conv3, filters=2, kernel_size=1, data_format='channels_first', 56 | name='value_conv', activation=tf.nn.relu) 57 | value_net = tf.layers.dense(tf.contrib.layers.flatten(value_net), 64, activation=tf.nn.relu) 58 | self.value = tf.layers.dense(value_net, units=1, activation=tf.nn.tanh) 59 | 60 | def _loss_train_op(self): 61 | """ 62 | Three loss terms: 63 | loss = (z - v)^2 + pi^T * log(p) + c||theta||^2 64 | """ 65 | l2_penalty = 0 66 | for v in tf.trainable_variables(): 67 | if not 'bias' in v.name.lower(): 68 | l2_penalty += tf.nn.l2_loss(v) 69 | value_loss = tf.reduce_mean(tf.square(self.winner_reshape - self.value)) 70 | cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=self.policy_net_out, labels=self.mcts_probs) 71 | policy_loss = tf.reduce_mean(cross_entropy) 72 | self.loss = value_loss + policy_loss + self.l2_const*l2_penalty 73 | # policy entropy,for monitoring only 74 | self.entropy = policy_loss 75 | # get the train op 76 | self.learning_rate = tf.placeholder(tf.float32) 77 | optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate) 78 | self.training_op = optimizer.minimize(self.loss) 79 | 80 | def get_policy_value(self, state_batch): 81 | # get action probs and state score value 82 | action_probs, value = self.sess.run([self.action_probs, self.value], 83 | feed_dict={self.state_input: state_batch}) 84 | return action_probs, value 85 | 86 | def policy_value_fn(self, board): 87 | """ 88 | input: board 89 | output: a list of (action, probability) tuples for each available action and the score of the board state 90 | """ 91 | legal_positions = board.availables 92 | current_state = board.current_state() 93 | act_probs, value = self.sess.run([self.action_probs, self.value], 94 | feed_dict={self.state_input: current_state.reshape(-1, 4, self.board_width, self.board_height)}) 95 | act_probs = zip(legal_positions, act_probs.flatten()[legal_positions]) 96 | return act_probs, value[0][0] 97 | 98 | def train_step(self, state_batch, mcts_probs_batch, winner_batch, lr): 99 | feed_dict = {self.state_input : state_batch, 100 | self.mcts_probs : mcts_probs_batch, 101 | self.winner : winner_batch, 102 | self.learning_rate: lr} 103 | 104 | loss, entropy, _ = self.sess.run([self.loss, self.entropy, self.training_op], 105 | feed_dict=feed_dict) 106 | return loss, entropy 107 | 108 | 109 | def restore_model(self): 110 | if os.path.exists(self.model_file + '.meta'): 111 | self.saver.restore(self.sess, self.model_file) 112 | else: 113 | self.sess.run(tf.global_variables_initializer()) 114 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | 2 | import random 3 | import numpy as np 4 | 5 | from collections import defaultdict 6 | from collections import deque 7 | 8 | from game import Board, Game 9 | from tf_policy_value_net import PolicyValueNet 10 | from mcts_pure import MCTSPlayer as MCTS_Pure 11 | from mcts_alphaZero 
import MCTSPlayer 12 | 13 | #import os 14 | #os.environ["CUDA_VISIBLE_DEVICES"] = "2" 15 | 16 | 17 | class TrainPipeline(): 18 | def __init__(self): 19 | # params of the board and the game 20 | self.board_width = 11 21 | self.board_height = 11 22 | self.n_in_row = 5 23 | self.board = Board(width=self.board_width, height=self.board_height, n_in_row=self.n_in_row) 24 | self.game = Game(self.board) 25 | # training params 26 | self.learn_rate = 0.001 27 | self.lr_multiplier = 1.0 # adaptively adjust the learning rate based on KL 28 | self.temp = 1.0 # the temperature param 29 | self.n_playout = 400 # num of simulations for each move 30 | self.c_puct = 5 31 | self.buffer_size = 10000 32 | self.batch_size = 128 # mini-batch size for training 33 | self.data_buffer = deque(maxlen=self.buffer_size) 34 | self.play_batch_size = 1 35 | self.epochs = 5 # num of train_steps for each update 36 | self.kl_targ = 0.02 37 | self.check_freq = 1000 38 | self.game_batch_num = 50000000 39 | self.best_win_ratio = 0.0 40 | # num of simulations used for the pure mcts, which is used as the opponent to evaluate the trained policy 41 | self.pure_mcts_playout_num = 3000 42 | 43 | # start training from a new policy-value net 44 | self.policy_value_net = PolicyValueNet(self.board_width, self.board_height, self.n_in_row) 45 | self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn, 46 | c_puct=self.c_puct, 47 | n_playout=self.n_playout, is_selfplay=1) 48 | 49 | def get_equi_data(self, play_data): 50 | """ 51 | augment the data set by rotation and flipping 52 | play_data: [(state, mcts_prob, winner_z), ..., ...]""" 53 | extend_data = [] 54 | for state, mcts_porb, winner in play_data: 55 | for i in [1,2,3,4]: 56 | # rotate counterclockwise 57 | equi_state = np.array([np.rot90(s,i) for s in state]) 58 | equi_mcts_prob = np.rot90(np.flipud(mcts_porb.reshape(self.board_height, self.board_width)), i) 59 | extend_data.append((equi_state, np.flipud(equi_mcts_prob).flatten(), winner)) 60 | # flip horizontally 61 | equi_state = np.array([np.fliplr(s) for s in equi_state]) 62 | equi_mcts_prob = np.fliplr(equi_mcts_prob) 63 | extend_data.append((equi_state, np.flipud(equi_mcts_prob).flatten(), winner)) 64 | return extend_data 65 | 66 | def collect_selfplay_data(self, n_games=1): 67 | """collect self-play data for training""" 68 | for i in range(n_games): 69 | winner, play_data = self.game.start_self_play(self.mcts_player, temp=self.temp) 70 | # augment the data 71 | play_data = self.get_equi_data(play_data) 72 | self.episode_len = len(play_data) / 8 73 | self.data_buffer.extend(play_data) 74 | 75 | def policy_update(self, verbose=False): 76 | """update the policy-value net""" 77 | mini_batch = random.sample(self.data_buffer, self.batch_size) 78 | state_batch = [data[0] for data in mini_batch] 79 | mcts_probs_batch = [data[1] for data in mini_batch] 80 | winner_batch = [data[2] for data in mini_batch] 81 | 82 | old_probs, old_v = self.policy_value_net.get_policy_value(state_batch) 83 | 84 | loss_list = [] 85 | entropy_list = [] 86 | for i in range(self.epochs): 87 | loss, entropy = self.policy_value_net.train_step(state_batch, 88 | mcts_probs_batch, 89 | winner_batch, 90 | self.learn_rate*self.lr_multiplier) 91 | 92 | loss_list.append(loss) 93 | entropy_list.append(entropy) 94 | 95 | new_probs, new_v = self.policy_value_net.get_policy_value(state_batch) 96 | kl = np.mean(np.sum(old_probs * ( 97 | np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)), 98 | axis=1) 99 | ) 100 | if kl > self.kl_targ * 4: # early stopping if 
D_KL diverges badly 101 | break 102 | 103 | if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1: 104 | self.lr_multiplier /= 1.5 105 | elif kl < self.kl_targ / 2 and self.lr_multiplier < 10: 106 | self.lr_multiplier *= 1.5 107 | 108 | if verbose: 109 | explained_var_old = (1 - 110 | np.var(np.array(winner_batch) - old_v.flatten()) / 111 | np.var(np.array(winner_batch))) 112 | explained_var_new = (1 - 113 | np.var(np.array(winner_batch) - new_v.flatten()) / 114 | np.var(np.array(winner_batch))) 115 | 116 | print(("kl: {:.3f}, " 117 | "lr_multiplier: {:.3f}\n" 118 | "loss: {:.3f}, " 119 | "entropy: {:.3f}\n" 120 | "explained old: {:.3f}, " 121 | "explained new: {:.3f}\n" 122 | ).format(kl, 123 | self.lr_multiplier, 124 | np.mean(loss_list), 125 | np.mean(entropy_list), 126 | explained_var_old, 127 | explained_var_new)) 128 | 129 | 130 | def policy_evaluate(self, n_games=10): 131 | """ 132 | Evaluate the trained policy by playing games against the pure MCTS player 133 | Note: this is only for monitoring the progress of training 134 | """ 135 | current_mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn, c_puct=self.c_puct, n_playout=self.n_playout) 136 | pure_mcts_player = MCTS_Pure(c_puct=5, n_playout=self.pure_mcts_playout_num) 137 | win_cnt = defaultdict(int) 138 | for i in range(n_games): 139 | winner = self.game.start_play(current_mcts_player, pure_mcts_player, start_player=i%2, is_shown=0) 140 | print("winner is {}".format(winner)) 141 | win_cnt[winner] += 1 142 | win_ratio = 1.0*(win_cnt[1] + 0.5*win_cnt[-1])/n_games 143 | print("num_playouts:{}, win: {}, lose: {}, tie:{}".format(self.pure_mcts_playout_num, win_cnt[1], win_cnt[2], win_cnt[-1])) 144 | return win_ratio 145 | 146 | def run(self): 147 | """run the training pipeline""" 148 | try: 149 | for i in range(self.game_batch_num): 150 | self.collect_selfplay_data(self.play_batch_size) 151 | 152 | if len(self.data_buffer) > self.batch_size: 153 | print("#### batch i:{}, episode_len:{} ####\n".format(i+1, self.episode_len)) 154 | for j in range(5): # separate counter so the self-play batch counter i is not shadowed 155 | verbose = j % 5 == 0 156 | self.policy_update(verbose) 157 | # check the performance of the current model, and save the model params 158 | if (i+1) % self.check_freq == 0: 159 | print("current self-play batch: {}".format(i+1)) 160 | self.policy_value_net.saver.save(self.policy_value_net.sess, self.policy_value_net.model_file) 161 | win_ratio = self.policy_evaluate() 162 | print('*****win ratio: {:.2f}%\n'.format(win_ratio*100)) 163 | 164 | if win_ratio > self.best_win_ratio: 165 | print("New best policy!!!!!!!!") 166 | self.best_win_ratio = win_ratio 167 | self.policy_value_net.saver.save(self.policy_value_net.sess, self.policy_value_net.model_file) # update the best_policy 168 | if self.best_win_ratio == 1.0 and self.pure_mcts_playout_num < 5000: 169 | self.pure_mcts_playout_num += 100 170 | self.best_win_ratio = 0.0 171 | except KeyboardInterrupt: 172 | self.policy_value_net.saver.save(self.policy_value_net.sess, self.policy_value_net.model_file) 173 | print('\n\rquit') 174 | 175 | if __name__ == '__main__': 176 | 177 | training_pipeline = TrainPipeline() 178 | training_pipeline.run() 179 | --------------------------------------------------------------------------------
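A minimal usage sketch of how the pieces fit together: `PolicyValueNet` exposes `policy_value_fn`, which `MCTSPlayer` uses to guide the tree search over a `Board`. This is an illustration only, not a file in the repository; it assumes the 11 * 11 checkpoint shipped in `./model/` is available (otherwise `restore_model` initializes the network randomly).

    # usage sketch (hypothetical script, not part of this repo)
    from game import Board
    from tf_policy_value_net import PolicyValueNet
    from mcts_alphaZero import MCTSPlayer

    width, height, n_in_row = 11, 11, 5
    board = Board(width=width, height=height, n_in_row=n_in_row)
    board.init_board(start_player=0)                    # player 1 moves first

    policy = PolicyValueNet(width, height, n_in_row)    # restores ./model/tf_policy_11_11_5_model if present
    ai = MCTSPlayer(policy.policy_value_fn, c_puct=5, n_playout=400)
    ai.set_player_ind(board.players[0])

    move = ai.get_action(board)                         # runs 400 playouts and returns a move index
    print("AI plays", board.move_to_location(move))     # [row, col] on the board
    board.do_move(move)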