├── images
│   ├── README.md
│   └── 1.png
├── mcts
│   ├── README.md
│   ├── __init__.py
│   ├── search.py
│   └── nodes.py
├── README.md
├── run.py
└── tictactoe.py

--------------------------------------------------------------------------------
/images/README.md:
--------------------------------------------------------------------------------


--------------------------------------------------------------------------------
/mcts/README.md:
--------------------------------------------------------------------------------


--------------------------------------------------------------------------------
/mcts/__init__.py:
--------------------------------------------------------------------------------


--------------------------------------------------------------------------------
/images/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cryer/monte-carlo-tree-search/HEAD/images/1.png

--------------------------------------------------------------------------------
/mcts/search.py:
--------------------------------------------------------------------------------
from mcts.nodes import MonteCarloTreeSearchNode


class MonteCarloTreeSearch:
    def __init__(self, node: MonteCarloTreeSearchNode):
        self.root = node

    def best_action(self, simulations_number):
        # Run the MCTS phases (select, expand, simulate, backpropagate)
        # for the requested number of simulations.
        for _ in range(simulations_number):
            v = self.tree_policy()
            reward = v.rollout()
            v.backpropagate(reward)
        # exploitation only: pick the child with the best average value
        return self.root.best_child(c_param=0.)

    def tree_policy(self):
        # Descend the tree, expanding the first node that still has untried moves.
        current_node = self.root
        while not current_node.is_terminal_node():
            if not current_node.is_fully_expanded():
                return current_node.expand()
            else:
                current_node = current_node.best_child()
        return current_node

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# monte-carlo-tree-search

A simple Tic-Tac-Toe game played against Monte Carlo Tree Search.

It looks like this:

![](./images/1.png)

## Run

Run the code:

```
git clone https://github.com/cryer/monte-carlo-tree-search.git
cd monte-carlo-tree-search
python run.py
```

## Note

As is well known, in Tic-Tac-Toe the starting player never loses, and at worst ties, assuming both players play perfectly.
What fewer people know is that the best opening move for the starting player is not the center square but a corner.
MCTS confirms this as well, but only in terms of winning probability; between perfect players every opening still leads to a tie.
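A quick way to see this for yourself is to run the search from the empty board and look at how often each opening move was explored. The snippet below is only a sketch and is not part of the repository; it assumes you run it from the project root so the imports resolve.

```
import numpy as np
from mcts.nodes import MonteCarloTreeSearchNode
from mcts.search import MonteCarloTreeSearch
from tictactoe import TicTacToeGameState

# Search from the empty board with X (player 1) to move.
root = MonteCarloTreeSearchNode(
    state=TicTacToeGameState(state=np.zeros((3, 3)), next_to_move=1),
    parent=None)
MonteCarloTreeSearch(root).best_action(1000)

# Each child of the root corresponds to one opening move: its board has
# exactly one non-zero cell. Print how often each opening was visited.
for child in root.children:
    x, y = np.argwhere(child.state.board != 0)[0]
    print("opening ({},{}): visits={:.0f}, q={:.0f}".format(x, y, child.n, child.q))
```

If the claim above holds, the corner openings should end up with the highest visit counts once you use enough simulations.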
* I made this small game only to show how Monte Carlo Tree Search works; feel free to focus on the `mcts` package and apply it to your own applications (a tiny sketch of the required interface follows).
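The search code only touches the game state through a handful of members: `next_to_move` (which must be `1` or `-1`), `get_legal_actions()`, `move(action)`, `is_game_over()` and `game_result`. As an illustration, here is a hypothetical `NimState` (not part of this repository) for the subtraction game "take 1 or 2 stones, whoever takes the last stone wins" plugged into the same search:

```
from mcts.nodes import MonteCarloTreeSearchNode
from mcts.search import MonteCarloTreeSearch


class NimState(object):
    """Take 1 or 2 stones per turn; whoever takes the last stone wins."""

    def __init__(self, stones, next_to_move=1):
        self.stones = stones
        self.next_to_move = next_to_move  # players are 1 and -1, as in tictactoe.py

    @property
    def game_result(self):
        # The player who just moved took the last stone and therefore won.
        return -self.next_to_move if self.stones == 0 else None

    def is_game_over(self):
        return self.game_result is not None

    def get_legal_actions(self):
        return [take for take in (1, 2) if take <= self.stones]

    def move(self, take):
        return NimState(self.stones - take, -self.next_to_move)


root = MonteCarloTreeSearchNode(state=NimState(10), parent=None)
best = MonteCarloTreeSearch(root).best_action(1000)
print("take", 10 - best.state.stones, "stone(s)")
```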
## Results

The transcript below just shows the flow of a game; don't mind the quality of my moves.
```
0 | _ _ _
1 | _ X _
2 | _ _ _
______________________________
Your move: 0,0

0 | O _ _
1 | _ X _
2 | _ _ _
______________________________

0 | O _ _
1 | X X _
2 | _ _ _
______________________________
Your move: 1,2

0 | O _ _
1 | X X O
2 | _ _ _
______________________________

0 | O X _
1 | X X O
2 | _ _ _
______________________________
Your move: 2,0

0 | O X _
1 | X X O
2 | O _ _
______________________________

0 | O X _
1 | X X O
2 | O X _
______________________________
You lose!

```
MCTS is the starting player. It is also easy to change the code so that the human moves first; just modify run.py.
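For example, the main loop in run.py could be replaced with something along these lines. This is only a sketch, not code from the repository; it assumes the helpers `graphics`, `get_action` and `judge` defined in run.py are in scope, and it lets the human (playing O) open the game:

```
import numpy as np
from mcts.nodes import MonteCarloTreeSearchNode
from mcts.search import MonteCarloTreeSearch
from tictactoe import TicTacToeGameState

# Start from an empty board with the human (O, value -1) to move.
c_state = TicTacToeGameState(state=np.zeros((3, 3)), next_to_move=-1)
graphics(c_state.board)

while True:
    # Human move first.
    c_state = c_state.move(get_action(c_state))
    graphics(c_state.board)
    if judge(c_state) == 1:
        break

    # MCTS (X, value 1) replies.
    root = MonteCarloTreeSearchNode(state=c_state, parent=None)
    c_state = MonteCarloTreeSearch(root).best_action(1000).state
    graphics(c_state.board)
    if judge(c_state) == 1:
        break
```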
## Inspiration

* [int8](https://github.com/int8/monte-carlo-tree-search)

--------------------------------------------------------------------------------
/mcts/nodes.py:
--------------------------------------------------------------------------------
import numpy as np
from collections import defaultdict

from tictactoe import TicTacToeGameState


class MonteCarloTreeSearchNode(object):
    def __init__(self, state: TicTacToeGameState, parent=None):
        self._number_of_visits = 0.
        self._results = defaultdict(int)
        self.state = state
        self.parent = parent
        self.children = []

    @property
    def untried_actions(self):
        # Lazily computed list of legal moves that have not been expanded yet.
        if not hasattr(self, '_untried_actions'):
            self._untried_actions = self.state.get_legal_actions()
        return self._untried_actions

    @property
    def q(self):
        # Wins minus losses from the point of view of the player who moves at the parent,
        # i.e. the player who made the move leading into this node.
        wins = self._results[self.parent.state.next_to_move]
        loses = self._results[-1 * self.parent.state.next_to_move]
        return wins - loses

    @property
    def n(self):
        return self._number_of_visits

    def expand(self):
        # Take one untried action, create the corresponding child node and return it.
        action = self.untried_actions.pop()
        next_state = self.state.move(action)
        child_node = MonteCarloTreeSearchNode(next_state, parent=self)
        self.children.append(child_node)
        return child_node

    def is_terminal_node(self):
        return self.state.is_game_over()

    def rollout(self):
        # Play random moves until the game ends and return the result.
        current_rollout_state = self.state
        while not current_rollout_state.is_game_over():
            possible_moves = current_rollout_state.get_legal_actions()
            action = self.rollout_policy(possible_moves)
            current_rollout_state = current_rollout_state.move(action)
        return current_rollout_state.game_result

    def backpropagate(self, result):
        # Update visit and result statistics on the path back to the root.
        self._number_of_visits += 1.
        self._results[result] += 1.
        if self.parent:
            self.parent.backpropagate(result)

    def is_fully_expanded(self):
        return len(self.untried_actions) == 0

    def best_child(self, c_param=1.4):
        # UCT: average value plus an exploration bonus.
        choices_weights = [
            (c.q / c.n) + c_param * np.sqrt((2 * np.log(self.n) / c.n))
            for c in self.children
        ]
        return self.children[np.argmax(choices_weights)]

    def rollout_policy(self, possible_moves):
        # Default rollout policy: a uniformly random legal move.
        return possible_moves[np.random.randint(len(possible_moves))]

--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
import numpy as np

from mcts.nodes import MonteCarloTreeSearchNode
from mcts.search import MonteCarloTreeSearch
from tictactoe import TicTacToeGameState, TicTacToeMove


def init():
    # MCTS (X) makes the opening move on an empty board.
    state = np.zeros((3, 3))
    initial_board_state = TicTacToeGameState(state=state, next_to_move=1)
    root = MonteCarloTreeSearchNode(state=initial_board_state, parent=None)
    mcts = MonteCarloTreeSearch(root)
    best_node = mcts.best_action(1000)
    c_state = best_node.state
    c_board = c_state.board
    return c_state, c_board


def graphics(board):
    # Print the board: X is MCTS (1), O is the human player (-1).
    for i in range(3):
        print("")
        print("{0:3}".format(i).center(8) + "|", end='')
        for j in range(3):
            if board[i][j] == 0:
                print('_'.center(8), end='')
            if board[i][j] == 1:
                print('X'.center(8), end='')
            if board[i][j] == -1:
                print('O'.center(8), end='')
    print("")
    print("______________________________")


def get_action(state):
    # Read a move like "0,2" from the human player and validate it.
    try:
        location = input("Your move: ")
        if isinstance(location, str):
            location = [int(n, 10) for n in location.split(",")]
        if len(location) != 2:
            return -1
        x = location[0]
        y = location[1]
        move = TicTacToeMove(x, y, -1)
    except Exception:
        move = -1
    if move == -1 or not state.is_move_legal(move):
        print("invalid move")
        move = get_action(state)
    return move


def judge(state):
    # Report the result (from the human player's point of view) if the game is over.
    if state.is_game_over():
        if state.game_result == 1.0:
            print("You lose!")
        if state.game_result == 0.0:
            print("Tie!")
        if state.game_result == -1.0:
            print("You Win!")
        return 1
    else:
        return -1


c_state, c_board = init()
graphics(c_board)


while True:
    # Human (O) moves.
    move1 = get_action(c_state)
    c_state = c_state.move(move1)
    c_board = c_state.board
    graphics(c_board)
    if judge(c_state) == 1:
        break

    # MCTS (X) replies from the current position.
    board_state = TicTacToeGameState(state=c_board, next_to_move=1)
    root = MonteCarloTreeSearchNode(state=board_state, parent=None)
    mcts = MonteCarloTreeSearch(root)
    best_node = mcts.best_action(1000)
    c_state = best_node.state
    c_board = c_state.board
    graphics(c_board)
    if judge(c_state) == 1:
        break

--------------------------------------------------------------------------------
/tictactoe.py:
--------------------------------------------------------------------------------
import numpy as np


class TicTacToeMove(object):
    def __init__(self, x_coordinate, y_coordinate, value):
        self.x_coordinate = x_coordinate
        self.y_coordinate = y_coordinate
        self.value = value

    def __repr__(self):
        return "x:" + str(self.x_coordinate) + " y:" + str(self.y_coordinate) + " v:" + str(self.value)


class TicTacToeGameState(object):
    x = 1
    o = -1

    def __init__(self, state, next_to_move=1):
        if len(state.shape) != 2 or state.shape[0] != state.shape[1]:
            raise ValueError("Please play on 2D square board")
        self.board = state
        self.board_size = state.shape[0]
        self.next_to_move = next_to_move

    @property
    def game_result(self):
        # check if the game is over and, if so, who won
        rowsum = np.sum(self.board, 0)
        colsum = np.sum(self.board, 1)
        diag_sum_tl = self.board.trace()
        diag_sum_tr = self.board[::-1].trace()

        if any(rowsum == self.board_size) or any(
                colsum == self.board_size) or diag_sum_tl == self.board_size or diag_sum_tr == self.board_size:
            return 1.
        elif any(rowsum == -self.board_size) or any(
                colsum == -self.board_size) or diag_sum_tl == -self.board_size or diag_sum_tr == -self.board_size:
            return -1.
        elif np.all(self.board != 0):
            return 0.
        else:
            # if not over - no result
            return None

    def is_game_over(self):
        return self.game_result is not None

    def is_move_legal(self, move):
        # check if the correct player moves
        if move.value != self.next_to_move:
            return False

        # check if the move is inside the board
        x_in_range = move.x_coordinate < self.board_size and move.x_coordinate >= 0
        if not x_in_range:
            return False

        y_in_range = move.y_coordinate < self.board_size and move.y_coordinate >= 0
        if not y_in_range:
            return False

        # finally check that the board field is not occupied yet
        return self.board[move.x_coordinate, move.y_coordinate] == 0

    def move(self, move):
        if not self.is_move_legal(move):
            raise ValueError("move {} on board {} is not legal".format(move, self.board))
        new_board = np.copy(self.board)
        new_board[move.x_coordinate, move.y_coordinate] = move.value
        next_to_move = TicTacToeGameState.o if self.next_to_move == TicTacToeGameState.x else TicTacToeGameState.x
        return TicTacToeGameState(new_board, next_to_move)

    def get_legal_actions(self):
        # every empty field is a legal move for the player to move
        indices = np.where(self.board == 0)
        return [TicTacToeMove(coords[0], coords[1], self.next_to_move) for coords in list(zip(indices[0], indices[1]))]
--------------------------------------------------------------------------------