├── README.md
├── antas.py
├── mancala.py
└── mcts.py

/README.md:
--------------------------------------------------------------------------------
# MCTS
Python implementations of Monte Carlo Tree Search for experimentation.

Monte Carlo tree search (MCTS) is a newly emerging and promising algorithm in the AI literature. See http://pubs.doc.ic.ac.uk/survey-mcts-methods/survey-mcts-methods.pdf for a survey on MCTS.

This code is a toy implementation for playing with the algorithm. While MCTS applies to many settings, here we apply it to a simple but nonetheless interesting game.

The State is a game where you have NUM_TURNS turns, and at turn i you choose an integer from [-2,2,3,-3]*(NUM_TURNS+1-i). For example, in a game of 4 turns, on turn 1 you can choose from [-8,8,12,-12], and on turn 2 from [-6,6,9,-9]. At each turn the chosen number is added to an accumulated value. The goal of the game is to bring the accumulated value as close to 0 as possible.

The game may not be very interesting, but it allows one to study MCTS, which is. By design, moves in this game do not commute, and early mistakes are more costly.

USAGE:
python mcts.py --num_sims 10000 --levels 8

num_sims is the number of simulations to perform, and levels is the number of times to use MCTS to pick a best child.

In a 10-turn game, here is an optimal solution:
[-20, 27, -16, 14, 18, -15, -8, 6, -4, -2]

Here is a suboptimal solution, a local plateau you may end up on:
[20, -18, -16, 14, 12, -10, -8, 6, 4, -3]
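
To make the move structure concrete, here is a minimal standalone sketch (not part of the repo's code) that enumerates the legal moves at each 1-indexed turn and verifies the optimal solution above:

```python
NUM_TURNS = 10

def legal_moves(i, num_turns=NUM_TURNS):
    # At 1-indexed turn i, the base moves [2, -2, 3, -3] are scaled by (NUM_TURNS + 1 - i).
    return [m * (num_turns + 1 - i) for m in (2, -2, 3, -3)]

optimal = [-20, 27, -16, 14, 18, -15, -8, 6, -4, -2]
assert all(v in legal_moves(i + 1) for i, v in enumerate(optimal))
assert sum(optimal) == 0  # the accumulated value hits the goal exactly
```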

--------------------------------------------------------------------------------
/antas.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
import random
import math
import hashlib
import logging
import argparse
from mcts import *

"""
Another game using MCTS, based on a comment from atanas1054 on Jun 27, 2017:

I want to have a 2-player game where they take turns.
In the beginning there are 114 possible actions, and they decrease by 1 every time a player makes a move.
The game is played for 10 turns (that's the terminal state). I have my own function for the reward.

Here is a sample game tree:

START - available actions to both players -> [1,2,3,4,5,6,...,112,113,114]
Player 1 takes action 5  -> [5,0,0,0,0,0,0,0,0,0]    - remove action 5 from the available actions
Player 2 takes action 32 -> [5,0,0,0,0,32,0,0,0,0]   - remove action 32 from the available actions
Player 1 takes action 97 -> [5,97,0,0,0,32,0,0,0,0]  - remove action 97 from the available actions
Player 2 takes action 56 -> [5,97,0,0,0,32,56,0,0,0] - remove action 56 from the available actions
....
Final (example) game state after each player makes 5 actions -> [5,97,3,5,1,32,56,87,101,8]
The first 5 entries are the actions taken by Player 1; the second 5 entries are the actions taken by Player 2.

Finally, I apply a reward function to this vector [5,97,3,5,1,32,56,87,101,8]
"""
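
# A worked illustration of the action bookkeeping described above (hypothetical
# values, independent of the AntasState class below): after two moves per
# player, four actions have left the pool of 114.
_example_state = [5, 97, 0, 0, 0, 32, 56, 0, 0, 0]
_example_available = [a for a in range(1, 115) if a not in _example_state]
assert len(_example_available) == 114 - 4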

NUM_TURNS = 5

class AntasState():
    #a None sentinel avoids the mutable-default-argument pitfall
    def __init__(self, current=None, turn=0):
        self.current=current if current is not None else [0]*2*NUM_TURNS
        self.turn=turn
        #both players draw from the remaining pool of 114-2*turn actions, without replacement
        self.num_moves=(114-2*self.turn)*(114-2*self.turn-1)

    def next_state(self):
        availableActions=[a for a in range(1,115) if a not in self.current]
        player1action=random.choice(availableActions)
        availableActions.remove(player1action)
        nextcurrent=self.current[:]
        nextcurrent[self.turn]=player1action
        player2action=random.choice(availableActions)
        availableActions.remove(player2action)
        nextcurrent[self.turn+NUM_TURNS]=player2action
        return AntasState(current=nextcurrent,turn=self.turn+1)

    def terminal(self):
        return self.turn == NUM_TURNS

    def reward(self):
        r = random.uniform(0,1) #ANTAS, put your own function here
        return r

    def __hash__(self):
        return int(hashlib.md5(str(self.current).encode('utf-8')).hexdigest(),16)

    def __eq__(self,other):
        return hash(self)==hash(other)

    def __repr__(self):
        return "CurrentState: %s; turn %d"%(self.current,self.turn)


if __name__=="__main__":
    parser = argparse.ArgumentParser(description='MCTS research code')
    parser.add_argument('--num_sims', action="store", required=True, type=int, help="Number of simulations to run, should be more than 114*113")
    args=parser.parse_args()

    current_node=Node(AntasState())
    for l in range(NUM_TURNS):
        current_node=UCTSEARCH(args.num_sims//(l+1),current_node)
        print("level %d"%l)
        print("Num Children: %d"%len(current_node.children))
        for i,c in enumerate(current_node.children):
            print(i,c)
        print("Best Child: %s"%current_node.state)
        print("--------------------------------")
--------------------------------------------------------------------------------
/mancala.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
import random
import math
import hashlib
import logging
import argparse
from copy import deepcopy
from mcts import *

"""
Mancala using MCTS.

First to reach >=25 points (a majority of the 48 stones) wins; 24-24 is a draw.
The board starts as rows = [r1,r2] with ri = [4,4,4,4,4,4], indexed as:
r1 = [0,1,2,3,4,5]  (player 1's pits, sown left to right toward player 1's store)
r2 = [0,1,2,3,4,5]  (player 2's pits, sown right to left toward player 2's store)
"""
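
# A worked illustration of the sowing mechanic implemented in MancalaState.play1
# below (a hypothetical move, standalone): player 1 picks up the 4 stones in pit 2
# of the starting board and drops one stone in each pit to the right; the one
# stone left over would land in player 1's store.
_board = [[4,4,4,4,4,4],[4,4,4,4,4,4]]
_stones, _pit = _board[0][2], 2
_board[0][2] = 0
while _stones > 0 and _pit < 5:
    _pit += 1
    _stones -= 1
    _board[0][_pit] += 1
assert _board[0] == [4,4,0,5,5,5] and _stones == 1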

NUM_TURNS = 3

class MancalaState():
    #None sentinels avoid sharing one mutable default board/move list across instances
    def __init__(self,player1_points=0,player2_points=0,board=None,played_moves=None):
        self.player1_points = player1_points
        self.player2_points = player2_points
        self.board = board if board is not None else [[4,4,4,4,4,4],[4,4,4,4,4,4]]
        self.num_moves = 6
        self.played_moves = played_moves if played_moves is not None else []

    def play2(self):
        logger.info("PLAYING 2:: %s"%self)
        moves2 = []
        for ind,val in enumerate(self.board[1]):
            if val>0:
                moves2.append((ind,val))
        if not moves2:
            return
        ind2,val2 = random.choice(moves2)
        logger.info("Moving %d,%d"%(ind2,val2))
        self.played_moves.append("PLAYER2: ind:%d,val:%d"%(ind2,val2))
        lind = "NEGATIVE" #where the last stone landed: own pit index, "HOME", or the opponent's row
        #pick up the stones
        self.board[1][ind2]=0
        #sow leftward along player 2's own row
        while val2>0 and ind2>0:
            ind2-=1
            val2-=1
            self.board[1][ind2]+=1
            lind = ind2
        if val2>0:
            #drop one stone into player 2's store
            self.player2_points+=1
            val2-=1
            ind2=-1
            lind = "HOME"
        #continue rightward along player 1's row
        while val2>0 and ind2<5:
            ind2+=1
            val2-=1
            self.board[0][ind2]+=1
            lind = "NEGATIVE"
        if val2>0:
            #wrap around to player 2's row again
            ind2=6
            while val2>0 and ind2>0:
                ind2-=1
                val2-=1
                self.board[1][ind2]+=1
                lind = ind2
            if val2>0:
                self.player2_points+=1
                val2-=1
                ind2=-1
                lind = "HOME"
            while val2>0 and ind2<5:
                ind2+=1
                val2-=1
                self.board[0][ind2]+=1
                lind = "NEGATIVE"
        if lind == "HOME":
            #last stone landed in the store: move again
            if self.check_for_remaining():
                self.play2()
        elif lind != "NEGATIVE":
            #last stone landed in one of player 2's own pits: capture if it was empty
            if self.board[1][lind]==1:
                captured = self.board[0][lind]
                self.player2_points += captured + 1
                self.board[0][lind] = 0
                self.board[1][lind] = 0
            self.check_for_remaining()

    def play1(self):
        logger.info("PLAYING 1:: %s"%self)
        moves1 = []
        for ind,val in enumerate(self.board[0]):
            if val>0:
                moves1.append((ind,val))
        if not moves1:
            return
        ind1,val1 = random.choice(moves1)
        logger.info("Moving %d,%d"%(ind1,val1))
        self.played_moves.append("PLAYER1: ind:%d,val:%d"%(ind1,val1))
        lind = "NEGATIVE"
        #pick up the stones
        self.board[0][ind1]=0
        #sow rightward along player 1's own row
        while val1>0 and ind1<5:
            ind1+=1
            val1-=1
            self.board[0][ind1]+=1
            lind = ind1
        if val1>0:
            #drop one stone into player 1's store
            self.player1_points+=1
            val1-=1
            ind1=6
            lind = "HOME"
        #continue leftward along player 2's row
        while val1>0 and ind1>0:
            ind1-=1
            val1-=1
            self.board[1][ind1]+=1
            lind = "NEGATIVE"
        if val1>0:
            #wrap around to player 1's row again
            ind1=-1
            while val1>0 and ind1<5:
                ind1+=1
                val1-=1
                self.board[0][ind1]+=1
                lind = ind1
            if val1>0:
                self.player1_points+=1
                val1-=1
                ind1=6
                lind = "HOME"
            while val1>0 and ind1>0:
                ind1-=1
                val1-=1
                self.board[1][ind1]+=1
                lind = "NEGATIVE"
        if lind == "HOME":
            if self.check_for_remaining():
                self.play1()
        elif lind != "NEGATIVE":
            if self.board[0][lind]==1:
                captured = self.board[1][lind]
                self.player1_points += captured + 1
                self.board[0][lind] = 0
                self.board[1][lind] = 0
            self.check_for_remaining()

    def check_for_remaining(self):
        #if either row is empty, each player sweeps their own remaining stones
        s1 = sum(self.board[0])
        s2 = sum(self.board[1])
        if s1==0 or s2==0:
            self.player1_points+=s1
            self.player2_points+=s2
            self.board=[[0,0,0,0,0,0],[0,0,0,0,0,0]]
            return False
        return True

    def next_state(self):
        #play on a copy so that expanding a node does not mutate the node's own state
        nextstate = MancalaState(self.player1_points,self.player2_points,deepcopy(self.board),deepcopy(self.played_moves))
        if nextstate.check_for_remaining():
            nextstate.play1()
        if nextstate.check_for_remaining():
            nextstate.play2()
        return nextstate

    def terminal(self):
        self.check_for_remaining()
        p1_wins = self.player1_points>=25
        p2_wins = self.player2_points>=25
        if p1_wins or p2_wins:
            return True
        if sum(self.board[0]+self.board[1])==0:
            return True
        return False

    def reward(self):
        if self.player1_points>=25:
            return 1
        elif self.player1_points==24: #24-24 is a draw
            return 0.5
        else:
            return 0

    def __hash__(self):
        return int(hashlib.md5(str(self.board).encode('utf-8')).hexdigest(),16)

    def __eq__(self,other):
        return hash(self)==hash(other)

    def __repr__(self):
        return "CurrentState: %s; points1: %d, points2: %d\nplayed_moves:\n%s"%(self.board,self.player1_points,self.player2_points,"\n".join(self.played_moves))


if __name__=="__main__":
    parser = argparse.ArgumentParser(description='MCTS research code')
    parser.add_argument('--num_sims', action="store", required=True, type=int, help="Number of simulations to run")
    args=parser.parse_args()

    current_node=Node(MancalaState())
    #cap expansion at the number of nonempty pits player 1 can actually move from
    num_moves_lambda = lambda node: len([x for x in node.state.board[0] if x>0])
    for l in range(NUM_TURNS):
        current_node=UCTSEARCH(args.num_sims//(l+1),current_node,num_moves_lambda)
        print("level %d"%l)
        print("Num Children: %d"%len(current_node.children))
        for i,c in enumerate(current_node.children):
            print(i,c,c.state.board)
        print("Best Child: %s"%current_node.state)
        print("--------------------------------")
--------------------------------------------------------------------------------
/mcts.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
import random
import math
import hashlib
import logging
import argparse


"""
A quick Monte Carlo Tree Search implementation. For more details on MCTS see http://pubs.doc.ic.ac.uk/survey-mcts-methods/survey-mcts-methods.pdf

The State is a game where you have NUM_TURNS turns, and at turn i you choose an integer from [-2,2,3,-3]*(NUM_TURNS+1-i). For example, in a game of 4 turns, on turn 1 you can choose from [-8,8,12,-12], and on turn 2 from [-6,6,9,-9]. At each turn the chosen number is added to an accumulated value. The goal of the game is to bring the accumulated value as close to 0 as possible.

The game is not very interesting, but it allows one to study MCTS, which is. By design, moves in the example do not commute, and early mistakes are more costly.

In particular, there are two models of best child that one can use: BESTCHILD is called with the exploration scalar during the search and with scalar 0 for the final move choice.
"""

#MCTS scalar. A larger scalar will increase exploration, a smaller one will increase exploitation.
SCALAR=1/(2*math.sqrt(2.0))
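
# Worked example of the UCT score computed in BESTCHILD below, with made-up
# visit counts: a child visited 10 times with accumulated reward 6.0 under a
# parent visited 50 times scores
#   6.0/10 + SCALAR*sqrt(2*ln(50)/10) ~= 0.600 + 0.313
assert abs((6.0/10 + SCALAR*math.sqrt(2.0*math.log(50)/10)) - 0.913) < 1e-3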

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger('MyLogger')


class State():
    NUM_TURNS = 10
    GOAL = 0
    MOVES=[2,-2,3,-3]
    MAX_VALUE= (5.0*(NUM_TURNS-1)*NUM_TURNS)/2
    num_moves=len(MOVES)
    def __init__(self, value=0, moves=None, turn=NUM_TURNS):
        self.value=value
        self.turn=turn
        self.moves=moves if moves is not None else []
    def next_state(self):
        nextmove=random.choice([x*self.turn for x in self.MOVES])
        return State(self.value+nextmove, self.moves+[nextmove], self.turn-1)
    def terminal(self):
        return self.turn == 0
    def reward(self):
        return 1.0-(abs(self.value-self.GOAL)/self.MAX_VALUE)
    def __hash__(self):
        return int(hashlib.md5(str(self.moves).encode('utf-8')).hexdigest(),16)
    def __eq__(self,other):
        return hash(self)==hash(other)
    def __repr__(self):
        return "Value: %d; Moves: %s"%(self.value,self.moves)


class Node():
    def __init__(self, state, parent=None):
        self.visits=1
        self.reward=0.0
        self.state=state
        self.children=[]
        self.parent=parent
    def add_child(self,child_state):
        child=Node(child_state,self)
        self.children.append(child)
    def update(self,reward):
        self.reward+=reward
        self.visits+=1
    def fully_expanded(self, num_moves_lambda):
        num_moves = self.state.num_moves
        if num_moves_lambda is not None:
            num_moves = num_moves_lambda(self)
        return len(self.children)==num_moves
    def __repr__(self):
        return "Node; children: %d; visits: %d; reward: %f"%(len(self.children),self.visits,self.reward)

def UCTSEARCH(budget,root,num_moves_lambda=None):
    for i in range(int(budget)):
        if i%10000==9999:
            logger.info("simulation: %d"%i)
            logger.info(root)
        front=TREEPOLICY(root, num_moves_lambda)
        reward=DEFAULTPOLICY(front.state)
        BACKUP(front,reward)
    return BESTCHILD(root,0)

def TREEPOLICY(node, num_moves_lambda):
    #a hack to force some 'exploitation' in games with many options, where you may never want to (or be able to) fully expand first
    while node.state.terminal()==False:
        if len(node.children)==0:
            return EXPAND(node)
        elif random.uniform(0,1)<.5:
            node=BESTCHILD(node,SCALAR)
        else:
            if node.fully_expanded(num_moves_lambda)==False:
                return EXPAND(node)
            else:
                node=BESTCHILD(node,SCALAR)
    return node

def EXPAND(node):
    tried_children=[c.state for c in node.children]
    new_state=node.state.next_state()
    #resample until we find an untried child state (terminal duplicates are allowed through)
    while new_state in tried_children and new_state.terminal()==False:
        new_state=node.state.next_state()
    node.add_child(new_state)
    return node.children[-1]

#currently this uses the most vanilla MCTS formula; it is worth experimenting with THRESHOLD ASCENT (TAGS)
def BESTCHILD(node,scalar):
    bestscore=0.0
    bestchildren=[]
    for c in node.children:
        exploit=c.reward/c.visits
        explore=math.sqrt(2.0*math.log(node.visits)/float(c.visits))
        score=exploit+scalar*explore
        if score==bestscore:
            bestchildren.append(c)
        if score>bestscore:
            bestchildren=[c]
            bestscore=score
    if len(bestchildren)==0:
        logger.warning("OOPS: no best child found, probably fatal")
    return random.choice(bestchildren)

def DEFAULTPOLICY(state):
    while state.terminal()==False:
        state=state.next_state()
    return state.reward()
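
# Illustrative sanity check (not part of the original code): one random rollout
# of the toy State yields a reward normalized into [0, 1].
assert 0.0 <= DEFAULTPOLICY(State()) <= 1.0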

def BACKUP(node,reward):
    while node is not None:
        node.visits+=1
        node.reward+=reward
        node=node.parent
    return

if __name__=="__main__":
    parser = argparse.ArgumentParser(description='MCTS research code')
    parser.add_argument('--num_sims', action="store", required=True, type=int)
    parser.add_argument('--levels', action="store", required=True, type=int, choices=range(State.NUM_TURNS+1))
    args=parser.parse_args()

    current_node=Node(State())
    for l in range(args.levels):
        current_node=UCTSEARCH(args.num_sims//(l+1),current_node)
        print("level %d"%l)
        print("Num Children: %d"%len(current_node.children))
        for i,c in enumerate(current_node.children):
            print(i,c)
        print("Best Child: %s"%current_node.state)
        print("--------------------------------")
--------------------------------------------------------------------------------