# Repository layout:
#   ├── MCTS Deep Dive.pdf
#   ├── README.md
#   ├── naive_impl.py
#   └── numpy_impl.py
#
# /MCTS Deep Dive.pdf:
# ----------------------------------------------------------------------------
# https://raw.githubusercontent.com/brilee/python_uct/8b49ebd17ad10a945011379e6d44cc112e446e5c/MCTS Deep Dive.pdf
#
# /README.md:
# ----------------------------------------------------------------------------
# # Implementing MCTS in Python efficiently with Numpy
#
# This repo contains code for [this essay](https://www.moderndescartes.com/essays/deep_dive_mcts/).
#
# See the "[vloss](https://github.com/brilee/python_uct/tree/vloss)" branch for virtual losses.
#
# NOTE: the two files reproduced below are independent scripts that define the
# same names (UCTNode, UCT_search, NeuralNet, GameState, num_reads). Each
# section is meant to be read, and run, as its own module.

# /naive_impl.py:
# ----------------------------------------------------------------------------
import math

import numpy as np


class UCTNode:
    """Search-tree node that stores its own statistics.

    Every node keeps its prior, accumulated value and visit count as plain
    attributes, and `expand` eagerly creates one child object per move.
    """

    def __init__(self, game_state, parent=None, prior=0):
        self.game_state = game_state
        self.is_expanded = False
        self.parent = parent  # Optional[UCTNode]; None for the root
        self.children = {}  # Dict[move, UCTNode]
        self.prior = prior  # float
        self.total_value = 0  # float
        self.number_visits = 0  # int

    def Q(self) -> float:
        """Return the mean value; the +1 keeps an unvisited node at 0."""
        return self.total_value / (1 + self.number_visits)

    def U(self) -> float:
        """Return the exploration bonus derived from the parent's visits.

        Reads `self.parent.number_visits`, so it must not be called on a
        node whose parent is None.
        """
        return (math.sqrt(self.parent.number_visits)
                * self.prior / (1 + self.number_visits))

    def best_child(self):
        """Return the child node with the highest Q() + U()."""
        return max(self.children.values(),
                   key=lambda node: node.Q() + node.U())

    def select_leaf(self):
        """Follow `best_child` downwards until an unexpanded node is reached."""
        current = self
        while current.is_expanded:
            current = current.best_child()
        return current

    def expand(self, child_priors) -> None:
        """Mark this node expanded and create a child for every prior.

        The index of each entry in `child_priors` is used as the move.
        """
        self.is_expanded = True
        for move, prior in enumerate(child_priors):
            self.add_child(move, prior)

    def add_child(self, move, prior) -> None:
        """Create the child reached by playing `move` and store it."""
        self.children[move] = UCTNode(
            self.game_state.play(move), parent=self, prior=prior)

    def backup(self, value_estimate: float) -> None:
        """Propagate `value_estimate` from this node up to the root.

        The root is included: its visit count feeds `U()` of its children,
        so leaving it at 0 would zero the exploration term at the top of
        the tree.
        """
        current = self
        while current is not None:
            current.number_visits += 1
            # NOTE(review): the sign comes from the leaf's (self) to_play for
            # every ancestor, not from each ancestor's own game state. Kept
            # as in the original; confirm this is the intended convention.
            current.total_value += (value_estimate *
                                    self.game_state.to_play)
            current = current.parent


def UCT_search(game_state, num_reads):
    """Run `num_reads` select/evaluate/expand/backup iterations.

    Returns the `(move, node)` pair of the root child with the most visits.
    Raises ValueError (from `max`) when `num_reads` is 0, because the root
    then has no children.
    """
    root = UCTNode(game_state)
    for _ in range(num_reads):
        leaf = root.select_leaf()
        child_priors, value_estimate = NeuralNet.evaluate(leaf.game_state)
        leaf.expand(child_priors)
        leaf.backup(value_estimate)
    return max(root.children.items(),
               key=lambda item: item[1].number_visits)


class NeuralNet:
    """Stand-in evaluator that returns random numbers."""

    @classmethod
    def evaluate(cls, game_state):
        """Return `(priors, value)`: 362 random priors and a random scalar.

        `game_state` is ignored.
        """
        # 362 is presumably 19 * 19 board points plus a pass move -- confirm.
        return np.random.random([362]), np.random.random()


class GameState:
    """Minimal game state that only tracks whose turn it is."""

    def __init__(self, to_play=1):
        self.to_play = to_play

    def play(self, move):
        """Return a new state with the side to play flipped; `move` is unused."""
        return GameState(-self.to_play)


num_reads = 10000

if __name__ == "__main__":
    # Benchmark only when executed as a script, so that importing the module
    # does not trigger a long search or the Unix-only `resource` import.
    import resource
    import time

    tick = time.time()
    UCT_search(GameState(), num_reads)
    tock = time.time()
    print("Took %s sec to run %s times" % (tock - tick, num_reads))
    # ru_maxrss units are platform-dependent (kilobytes on Linux, bytes on
    # macOS), so the "B" suffix is only accurate on some systems.
    print("Consumed %sB memory"
          % resource.getrusage(resource.RUSAGE_SELF).ru_maxrss)


# /numpy_impl.py:
# ----------------------------------------------------------------------------
import collections
import math

import numpy as np


class UCTNode:
    """Search-tree node whose statistics live in its parent's arrays.

    Each node owns three length-362 arrays describing its children (priors,
    total values, visit counts). A node's own `number_visits` and
    `total_value` are read from and written to the parent's arrays at index
    `self.move`, so the root must be given a parent object that provides
    `child_number_visits` and `child_total_value` (see `DummyNode`).
    """

    def __init__(self, game_state, move, parent=None):
        self.game_state = game_state
        self.move = move  # index of this node in the parent's arrays
        self.is_expanded = False
        self.parent = parent  # Optional[UCTNode]
        self.children = {}  # Dict[move, UCTNode]
        # 362 is presumably 19 * 19 board points plus a pass move -- confirm.
        self.child_priors = np.zeros([362], dtype=np.float32)
        self.child_total_value = np.zeros([362], dtype=np.float32)
        self.child_number_visits = np.zeros([362], dtype=np.float32)

    @property
    def number_visits(self):
        """Visit count of this node, stored in the parent."""
        return self.parent.child_number_visits[self.move]

    @number_visits.setter
    def number_visits(self, value):
        self.parent.child_number_visits[self.move] = value

    @property
    def total_value(self):
        """Accumulated value of this node, stored in the parent."""
        return self.parent.child_total_value[self.move]

    @total_value.setter
    def total_value(self, value):
        self.parent.child_total_value[self.move] = value

    def child_Q(self):
        """Return the per-child mean values as an array."""
        return self.child_total_value / (1 + self.child_number_visits)

    def child_U(self):
        """Return the per-child exploration bonuses as an array."""
        return math.sqrt(self.number_visits) * (
            self.child_priors / (1 + self.child_number_visits))

    def best_child(self):
        """Return the index (move) that maximises child_Q() + child_U()."""
        return np.argmax(self.child_Q() + self.child_U())

    def select_leaf(self):
        """Descend along best moves, creating nodes lazily, to a leaf."""
        current = self
        while current.is_expanded:
            best_move = current.best_child()
            current = current.maybe_add_child(best_move)
        return current

    def expand(self, child_priors) -> None:
        """Mark this node expanded and store the priors of its children.

        No child objects are created here; see `maybe_add_child`.
        """
        self.is_expanded = True
        self.child_priors = child_priors

    def maybe_add_child(self, move):
        """Return the child for `move`, creating it on first use."""
        if move not in self.children:
            self.children[move] = UCTNode(
                self.game_state.play(move), move, parent=self)
        return self.children[move]

    def backup(self, value_estimate: float) -> None:
        """Propagate `value_estimate` to every node that has a parent.

        With a `DummyNode` above the root, the root itself is updated and
        the walk stops at the dummy, whose `parent` is None.
        """
        current = self
        while current.parent is not None:
            current.number_visits += 1
            # NOTE(review): the sign comes from the leaf's (self) to_play for
            # every ancestor; kept as in the original -- confirm intended.
            current.total_value += (value_estimate *
                                    self.game_state.to_play)
            current = current.parent


class DummyNode:
    """Placeholder parent that holds the root's value and visit count."""

    def __init__(self):
        self.parent = None
        self.child_total_value = collections.defaultdict(float)
        self.child_number_visits = collections.defaultdict(float)


def UCT_search(game_state, num_reads):
    """Run `num_reads` select/evaluate/expand/backup iterations.

    Returns the index of the root child with the most visits, as given by
    `np.argmax` over the root's `child_number_visits`.
    """
    root = UCTNode(game_state, move=None, parent=DummyNode())
    for _ in range(num_reads):
        leaf = root.select_leaf()
        child_priors, value_estimate = NeuralNet.evaluate(leaf.game_state)
        leaf.expand(child_priors)
        leaf.backup(value_estimate)
    return np.argmax(root.child_number_visits)


class NeuralNet:
    """Stand-in evaluator that returns random numbers."""

    @classmethod
    def evaluate(cls, game_state):
        """Return `(priors, value)`: 362 random priors and a random scalar.

        `game_state` is ignored.
        """
        return np.random.random([362]), np.random.random()


class GameState:
    """Minimal game state that only tracks whose turn it is."""

    def __init__(self, to_play=1):
        self.to_play = to_play

    def play(self, move):
        """Return a new state with the side to play flipped; `move` is unused."""
        return GameState(-self.to_play)


num_reads = 10000

if __name__ == "__main__":
    # Benchmark only when executed as a script, so that importing the module
    # does not trigger a long search or the Unix-only `resource` import.
    import resource
    import time

    tick = time.time()
    UCT_search(GameState(), num_reads)
    tock = time.time()
    print("Took %s sec to run %s times" % (tock - tick, num_reads))
    # ru_maxrss units are platform-dependent (kilobytes on Linux, bytes on
    # macOS), so the "B" suffix is only accurate on some systems.
    print("Consumed %sB memory"
          % resource.getrusage(resource.RUSAGE_SELF).ru_maxrss)