├── .gitignore
├── LICENSE
├── README.rst
├── mcts
│   ├── __init__.py
│   └── uct.py
└── setup.py
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
*~
*.pyc
*.egg-info
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
Copyright (c) 2015 Jeff Bradberry

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
--------------------------------------------------------------------------------
/README.rst:
--------------------------------------------------------------------------------
Monte Carlo Tree Search
=======================

This is an implementation of a board game AI in Python using the UCT
Monte Carlo Tree Search algorithm.

The Monte Carlo Tree Search AIs included here are designed to work
with `jbradberry/boardgame-socketserver
<https://github.com/jbradberry/boardgame-socketserver>`_ and
`jbradberry/boardgame-socketplayer
<https://github.com/jbradberry/boardgame-socketplayer>`_.


Requirements
------------

* Python 2.7, 3.5+; PyPy; PyPy3
* six


Getting Started
---------------

To set up your local environment, you should create a virtualenv and
install everything into it. ::

    $ mkvirtualenv mcts

Pip install this repo, either from a local copy, ::

    $ pip install -e mcts

or from GitHub, ::

    $ pip install git+https://github.com/jbradberry/mcts#egg=mcts

Additionally, you will need to have `jbradberry/boardgame-socketplayer
<https://github.com/jbradberry/boardgame-socketplayer>`_ installed in
order to make use of the players.

This project currently comes with two different Monte Carlo Tree
Search players. The first, ``jrb.mcts.uct``, uses the number of wins
recorded for a node to make its decisions. The second,
``jrb.mcts.uctv``, instead keeps track of the evaluated point value of
the board for the playouts from a given node. ::

    $ board-play.py t3 jrb.mcts.uct   # number of wins metric
    $ board-play.py t3 jrb.mcts.uctv  # point value of the board metric

These AI players can also take additional arguments:

time (default: 30)
    The amount of thinking time allowed for the AI to make its
    decision, in seconds. Ex: ``$ board-play.py t3 jrb.mcts.uct -e
    time=5``

max_actions (default: 1000)
    The maximum number of actions, or plays, to allow in one of the
    simulated playouts before giving up. Ex: ``$ board-play.py t3
    jrb.mcts.uct -e max_actions=500``

C (default: 1.4)
    The exploration vs. exploitation coefficient at the heart of the
    UCT algorithm. Larger values prioritize exploring inadequately
    covered actions from a node; smaller values prioritize exploiting
    actions already known to have higher values. Experimenting with
    this variable to find reasonable values for a given game is
    recommended. Ex: ``$ board-play.py t3 jrb.mcts.uct -e C=3.5``
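
For reference, ``C`` enters through the UCB1 selection rule
implemented in ``mcts/uct.py``: during the selection phase of each
simulation, every child node of the current node is scored as ::

    (value / visits) + C * sqrt(ln(total_visits) / visits)

where ``total_visits`` sums the visit counts of all of the children,
and a child with the maximum score is chosen. Larger values of ``C``
therefore boost the bonus given to rarely-visited children.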

The ``-e`` flag may be used multiple times to set additional
variables.


Games
-----

Compatible games that have been implemented include:

* Reversi
* Connect Four
* Ultimate (or 9x9) Tic Tac Toe
* Chong
--------------------------------------------------------------------------------
/mcts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jbradberry/mcts/2480de920fca47d5b2d7f62916f771d59679c24f/mcts/__init__.py
--------------------------------------------------------------------------------
/mcts/uct.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time
from math import log, sqrt
from random import choice

from six.moves import range


class Stat(object):
    __slots__ = ('value', 'visits')

    def __init__(self, value=0.0, visits=0):
        self.value = value
        self.visits = visits

    def __repr__(self):
        return u"Stat(value={}, visits={})".format(self.value, self.visits)


class UCT(object):
    def __init__(self, board, **kwargs):
        self.board = board
        self.history = []
        self.stats = {}

        self.max_depth = 0
        self.data = {}

        self.calculation_time = float(kwargs.get('time', 30))
        self.max_actions = int(kwargs.get('max_actions', 1000))

        # Exploration constant: increase for more exploratory actions,
        # decrease to prefer actions with known higher win rates.
        self.C = float(kwargs.get('C', 1.4))

    def update(self, state):
        self.history.append(self.board.to_compact_state(state))

    def display(self, state, action):
        return self.board.display(state, action)

    def winner_message(self, winners):
        return self.board.winner_message(winners)

    def get_action(self):
        # Calculate the best action from the current game state and
        # return it.

        self.max_depth = 0
        self.data = {'C': self.C, 'max_actions': self.max_actions, 'name': self.name}
        self.stats.clear()

        state = self.history[-1]
        player = self.board.current_player(state)
        legal = self.board.legal_actions(state)

        # Bail out early if there is no real choice to be made.
        if not legal:
            return {'type': 'action', 'message': None, 'extras': self.data.copy()}
        if len(legal) == 1:
            return {
                'type': 'action',
                'message': self.board.to_json_action(legal[0]),
                'extras': self.data.copy(),
            }

        games = 0
        begin = time.time()
        while time.time() - begin < self.calculation_time:
            self.run_simulation()
            games += 1

        # Display the number of calls of `run_simulation` and the
        # time elapsed.
        self.data.update(games=games, max_depth=self.max_depth,
                         time=str(time.time() - begin))
        print(self.data['games'], self.data['time'])
        print("Maximum depth searched:", self.max_depth)

        # Store and display the stats for each possible action.
        self.data['actions'] = self.calculate_action_values(self.history, player, legal)
        for m in self.data['actions']:
            print(self.action_template.format(**m))

        # Return the action with the highest average value.
        return {
            'type': 'action',
            'message': self.board.to_json_action(self.data['actions'][0]['action']),
            'extras': self.data.copy(),
        }
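
    # run_simulation below performs the four classic MCTS phases in
    # one pass: selection (descend through nodes whose children all
    # have stats, choosing by UCB1), expansion (create Stat entries
    # for the children of the first node reached that has unvisited
    # children), simulation (play uniformly random actions until the
    # game ends or max_actions is reached), and back-propagation
    # (update the value and visit count of every visited state that
    # has a Stat entry).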
    def run_simulation(self):
        # Play out a "random" game from the current position, then
        # update the statistics tables with the result.

        # A bit of an optimization here, so that we have a local
        # variable lookup instead of an attribute access each loop.
        C, stats = self.C, self.stats

        visited_states = []
        history_copy = self.history[:]
        state = history_copy[-1]

        expand = True
        for t in range(1, self.max_actions + 1):
            legal = self.board.legal_actions(state)
            actions_states = [(a, self.board.next_state(history_copy, a)) for a in legal]

            if expand and not all(S in stats for a, S in actions_states):
                # Expansion: create entries for the unvisited children,
                # then switch over to the random playout phase.
                stats.update((S, Stat()) for a, S in actions_states if S not in stats)
                expand = False
                if t > self.max_depth:
                    self.max_depth = t

            if expand:
                # Selection: we have stats on all of the legal actions
                # here, so use UCB1.
                actions_states = [(a, S, stats[S]) for a, S in actions_states]
                log_total = log(sum(e.visits for a, S, e in actions_states) or 1)
                values_actions = [
                    (a, S, (e.value / (e.visits or 1)) + C * sqrt(log_total / (e.visits or 1)))
                    for a, S, e in actions_states
                ]
                max_value = max(v for _, _, v in values_actions)
                # Filter down to only those actions with maximum value
                # under UCB1.
                actions_states = [(a, S) for a, S, v in values_actions if v == max_value]

            action, state = choice(actions_states)
            visited_states.append(state)
            history_copy.append(state)

            if self.board.is_ended(state):
                break

        # Back-propagation: credit every visited state that has a Stat
        # entry with the end value for the player who moved into it.
        end_values = self.end_values(state)
        for state in visited_states:
            if state not in stats:
                continue
            S = stats[state]
            S.visits += 1
            S.value += end_values[self.board.previous_player(state)]
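

# Two concrete players follow, registered as entry points in setup.py.
# Each subclass supplies the board callback used to score finished
# playouts: UCTWins backs up win/loss results, while UCTValues backs up
# the evaluated point value of the final board.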
class UCTWins(UCT):
    name = "jrb.mcts.uct"
    action_template = "{action}: {percent:.2f}% ({wins} / {plays})"

    def __init__(self, board, **kwargs):
        super(UCTWins, self).__init__(board, **kwargs)
        self.end_values = board.win_values

    def calculate_action_values(self, history, player, legal):
        actions_states = ((a, self.board.next_state(history, a)) for a in legal)
        return sorted(
            ({'action': a,
              'percent': 100 * self.stats[S].value / (self.stats[S].visits or 1),
              'wins': self.stats[S].value,
              'plays': self.stats[S].visits}
             for a, S in actions_states),
            key=lambda x: (x['percent'], x['plays']),
            reverse=True
        )


class UCTValues(UCT):
    name = "jrb.mcts.uctv"
    action_template = "{action}: {average:.1f} ({sum} / {plays})"

    def __init__(self, board, **kwargs):
        super(UCTValues, self).__init__(board, **kwargs)
        self.end_values = board.points_values

    def calculate_action_values(self, history, player, legal):
        actions_states = ((a, self.board.next_state(history, a)) for a in legal)
        return sorted(
            ({'action': a,
              'average': self.stats[S].value / (self.stats[S].visits or 1),
              'sum': self.stats[S].value,
              'plays': self.stats[S].visits}
             for a, S in actions_states),
            key=lambda x: (x['average'], x['plays']),
            reverse=True
        )
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
# setuptools, rather than distutils, is needed for the entry_points
# and install_requires arguments used below.
from setuptools import setup

setup(
    name='MonteCarloTreeSearch',
    version='0.1dev',
    author='Jeff Bradberry',
    author_email='jeff.bradberry@gmail.com',
    packages=['mcts'],
    entry_points={
        'jrb_board.players': ['jrb.mcts.uct = mcts.uct:UCTWins',
                              'jrb.mcts.uctv = mcts.uct:UCTValues'],
    },
    install_requires=['six'],
    license='MIT',
    description="An implementation of UCT Monte Carlo Tree Search.",
)
--------------------------------------------------------------------------------
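
As a closing aid for anyone adapting these players to a new game, the
following is a minimal sketch of the board interface that mcts/uct.py
assumes. The method names and rough signatures are inferred from the
calls in the source above; the skeleton class itself and its
docstrings are illustrative placeholders, not part of this repository.

class BoardInterface(object):
    """Hypothetical skeleton of the game object passed to UCT(board).

    Method names mirror the attribute accesses in mcts/uct.py; every
    body below is a placeholder to be implemented by a real game.
    """

    def to_compact_state(self, state):
        """Convert a received state into a compact, hashable form."""
        raise NotImplementedError

    def current_player(self, state):
        """Return the player whose turn it is in `state`."""
        raise NotImplementedError

    def previous_player(self, state):
        """Return the player who acted to produce `state`."""
        raise NotImplementedError

    def legal_actions(self, state):
        """Return the list of legal actions in `state`."""
        raise NotImplementedError

    def next_state(self, history, action):
        """Return the state reached by applying `action` to the
        latest state in `history`."""
        raise NotImplementedError

    def is_ended(self, state):
        """Return whether the game is over in `state`."""
        raise NotImplementedError

    def win_values(self, state):
        """Return a mapping of player -> win credit for a finished
        playout (used by UCTWins)."""
        raise NotImplementedError

    def points_values(self, state):
        """Return a mapping of player -> point score for a finished
        playout (used by UCTValues)."""
        raise NotImplementedError

    def to_json_action(self, action):
        """Return a JSON-serializable representation of `action`."""
        raise NotImplementedError

    def display(self, state, action):
        """Return a human-readable rendering of `state`."""
        raise NotImplementedError

    def winner_message(self, winners):
        """Return a human-readable message describing `winners`."""
        raise NotImplementedError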