├── .gitignore ├── .travis.yml ├── LICENSE ├── README.md ├── experiments └── toy_world.py ├── mcts ├── __init__.py ├── backups.py ├── default_policies.py ├── graph.py ├── mcts.py ├── states │ ├── __init__.py │ └── toy_world_state.py ├── tree_policies.py └── utils.py ├── setup.py └── tests ├── test_toy_world_state.py └── test_uct.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | 5 | # C extensions 6 | *.so 7 | 8 | # Distribution / packaging 9 | .Python 10 | env/ 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | lib/ 17 | lib64/ 18 | parts/ 19 | sdist/ 20 | virtualenvs/ 21 | var/ 22 | *.egg-info/ 23 | .installed.cfg 24 | *.egg 25 | 26 | # PyInstaller 27 | # Usually these files are written by a python script from a template 28 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 29 | *.manifest 30 | *.spec 31 | 32 | # Installer logs 33 | pip-log.txt 34 | pip-delete-this-directory.txt 35 | 36 | # Unit test / coverage reports 37 | htmlcov/ 38 | .tox/ 39 | .coverage 40 | .cache 41 | nosetests.xml 42 | coverage.xml 43 | 44 | # Translations 45 | *.mo 46 | *.pot 47 | 48 | # Django stuff: 49 | *.log 50 | 51 | # Sphinx documentation 52 | docs/_build/ 53 | 54 | # PyBuilder 55 | target/ 56 | 57 | *ipynb 58 | .idea/ 59 | *pkl 60 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | python: 3 | - "3.3" 4 | - "2.7" 5 | 6 | before_install: 7 | - wget http://repo.continuum.io/miniconda/Miniconda-latest-Linux-x86_64.sh -O miniconda.sh 8 | - chmod +x miniconda.sh 9 | - ./miniconda.sh -b 10 | - export PATH=/home/travis/miniconda/bin:$PATH 11 | 12 | install: 13 | - conda update --yes conda 14 | - conda create -q -n test-environment --yes python=$TRAVIS_PYTHON_VERSION scipy 15 | - source activate test-environment 16 | - pip install --force-reinstall pytest pytest-cov coveralls 17 | - pip install -e . 18 | 19 | script: 20 | - py.test --cov mcts 21 | 22 | after_success: 23 | - coveralls 24 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2015, Johannes Kulick 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | 1. Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 2. Redistributions in binary form must reproduce the above copyright notice, 10 | this list of conditions and the following disclaimer in the documentation 11 | and/or other materials provided with the distribution. 12 | 13 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 14 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 15 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 16 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 17 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 18 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 19 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 20 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 21 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 22 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Build Status](https://travis-ci.org/hildensia/mcts.svg?branch=master)](https://travis-ci.org/hildensia/mcts) 2 | [![Coverage Status](https://coveralls.io/repos/hildensia/mcts/badge.svg)](https://coveralls.io/r/hildensia/mcts) 3 | # scikit.mcts 4 | 5 | Version: 0.1 (It's still alpha, don't use it for your production website!) 6 | 7 | Website: https://github.com/hildensia/mcts 8 | 9 | An implementation of Monte Carlo Tree Search in Python. 10 | 11 | ## Setup 12 | Requirements: 13 | * numpy 14 | * scipy 15 | * pytest for tests 16 | 17 | Then installation is as simple as `python setup.py install`. Or use `pip`: `pip install scikit.mcts`. 18 | 19 | ## Usage 20 | Assume you have a very simple 3x3 maze. An action could be 'up', 'down', 'left' or 'right'. You start at `[0, 0]` and there is a reward at `[2, 2]`. 21 | 22 | class MazeAction(object): 23 | def __init__(self, move): 24 | self.move = np.asarray(move) 25 | 26 | def __eq__(self, other): 27 | return all(self.move == other.move) 28 | 29 | def __hash__(self): 30 | return 10*self.move[0] + self.move[1] 31 | 32 | class MazeState(object): 33 | def __init__(self, pos): 34 | self.pos = np.asarray(pos) 35 | self.actions = [MazeAction([1, 0]), 36 | MazeAction([0, 1]), 37 | MazeAction([-1, 0]), 38 | MazeAction([0, -1])] 39 | 40 | def perform(self, action): 41 | pos = self.pos + action.move 42 | pos = np.clip(pos, 0, 2) 43 | return MazeState(pos) 44 | 45 | def reward(self, parent, action): 46 | if all(self.pos == np.array([2, 2])): 47 | return 10 48 | else: 49 | return -1 50 | 51 | def is_terminal(self): 52 | return False 53 | 54 | def __eq__(self, other): 55 | return all(self.pos == other.pos) 56 | 57 | def __hash__(self): 58 | return 10 * self.pos[0] + self.pos[1] 59 | 60 | This is already a complete, minimal state implementation.
Now let's run MCTS on top: 61 | 62 | mcts = MCTS(tree_policy=UCB1(c=1.41), 63 | default_policy=immediate_reward, 64 | backup=monte_carlo) 65 | 66 | root = StateNode(parent=None, state=MazeState([0, 0])) 67 | best_action = mcts(root) 68 | 69 | 70 | ## License 71 | See LICENSE 72 | 73 | ## Authors 74 | Johannes Kulick 75 | 76 | 77 | -------------------------------------------------------------------------------- /experiments/toy_world.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | 4 | import random 5 | import argparse 6 | 7 | import numpy as np 8 | 9 | from mcts.mcts import MCTS 10 | from mcts.states import toy_world_state as state 11 | from mcts.graph import StateNode 12 | from mcts import backups, default_policies, tree_policies 13 | 14 | try: 15 | import cPickle as pickle 16 | except ImportError: 17 | import pickle 18 | import datetime 19 | 20 | __author__ = 'johannes' 21 | 22 | 23 | def run_experiment(intrinsic_motivation, gamma, c, mc_n, runs, steps): 24 | trajectories = [] 25 | start = np.array([50, 50]) 26 | true_belief = True 27 | search = MCTS(tree_policies.UCB1(c), default_policies.immediate_reward, backups.Bellman(gamma)) # UCT search; replaces the removed mcts_search helper 28 | for _ in range(runs): 29 | goal = draw_goal(start, 6) 30 | manual = draw_goal(start, 3) 31 | print("Goal: {}".format(goal)) 32 | print("Manual: {}".format(manual)) 33 | 34 | world = state.ToyWorld([100, 100], intrinsic_motivation, goal, manual) 35 | belief = None 36 | if true_belief: 37 | belief = dict(zip([state.ToyWorldAction(np.array([0, 1])), 38 | state.ToyWorldAction(np.array([0, -1])), 39 | state.ToyWorldAction(np.array([1, 0])), 40 | state.ToyWorldAction(np.array([-1, 0]))], 41 | [[10, 10, 10, 10], [10, 10, 10, 10], 42 | [10, 10, 10, 10], [10, 10, 10, 10]])) 43 | root_state = state.ToyWorldState(start, world, belief=belief) 44 | print(root_state.pos) 45 | next_state = StateNode(None, root_state) 46 | trajectory = [] 47 | for _ in range(steps): 48 | try: 49 | ba = search(next_state, n=mc_n) 50 | print("") 51 | print("=" * 80) 52 | print("State: {}".format(next_state.state)) 53 | print("Belief: {}".format(next_state.state.belief)) 54 | print("Reward: {}".format(next_state.reward)) 55 | print("N: {}".format(next_state.n)) 56 | print("Q: {}".format(next_state.q)) 57 | print("Action: {}".format(ba.action)) 58 | trajectory.append(next_state.state.pos) 59 | if (next_state.state.pos == np.array(goal)).all(): 60 | break 61 | next_s = next_state.children[ba].sample_state(real_world=True) 62 | next_state = next_s 63 | next_state.parent = None 64 | except KeyboardInterrupt: 65 | break 66 | trajectories.append(trajectory) 67 | with open(gen_name("trajectories", "pkl"), "wb") as f: 68 | pickle.dump(trajectories, f) 69 | print("=" * 80) 70 | 71 | 72 | def draw_goal(start, dist): 73 | delta_x = random.randint(0, dist) 74 | delta_y = dist - delta_x 75 | return start - np.array([delta_x, delta_y]) 76 | 77 | 78 | def gen_name(name, suffix): 79 | datestr = datetime.datetime.strftime(datetime.datetime.now(), 80 | '%Y-%m-%d-%H:%M:%S') 81 | return name + datestr + "." + suffix 82 | 83 | 84 | if __name__ == '__main__': 85 | parser = argparse.ArgumentParser(description='Run experiment for UCT with ' 86 | 'intrinsic motivation.') 87 | parser.add_argument('--intrinsic', '-i', action='store_true', 88 | help='Should intrinsic motivation be used?') 89 | parser.add_argument('--mcsamples', '-m', type=int, default=500, 90 | help='How many Monte Carlo runs should be made.') 91 | parser.add_argument('--runs', '-r', type=int, default=10, 92 | help='How many runs should be made.') 93 |
parser.add_argument('--steps', '-s', type=int, default=100, 94 | help="Maximum number of steps performed.") 95 | parser.add_argument('--gamma', '-g', type=float, default=0.6, 96 | help='The discount factor.') 97 | parser.add_argument('--uct_c', '-c', type=float, default=10, 98 | help='The UCT parameter Cp.') 99 | 100 | args = parser.parse_args() 101 | run_experiment(intrinsic_motivation=args.intrinsic, gamma=args.gamma, 102 | mc_n=args.mcsamples, runs=args.runs, steps=args.steps, 103 | c=args.uct_c) 104 | 105 | 106 | -------------------------------------------------------------------------------- /mcts/__init__.py: -------------------------------------------------------------------------------- 1 | from . import * 2 | 3 | __version__ = '0.1' 4 | __author__ = "Johannes Kulick" 5 | -------------------------------------------------------------------------------- /mcts/backups.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from .graph import StateNode, ActionNode 3 | 4 | 5 | class Bellman(object): 6 | """ 7 | A dynamic programming update which resembles the Bellman equation 8 | of value iteration. 9 | 10 | See Feldman and Domshlak (2014) for reference. 11 | """ 12 | def __init__(self, gamma): 13 | self.gamma = gamma 14 | 15 | def __call__(self, node): 16 | """ 17 | :param node: The node to start the backups from 18 | """ 19 | while node is not None: 20 | node.n += 1 21 | if isinstance(node, StateNode): 22 | node.q = max([x.q for x in node.children.values()]) 23 | elif isinstance(node, ActionNode): 24 | n = sum([x.n for x in node.children.values()]) 25 | node.q = sum([(self.gamma * x.q + x.reward) * x.n 26 | for x in node.children.values()]) / n 27 | node = node.parent 28 | 29 | 30 | def monte_carlo(node): 31 | """ 32 | A Monte Carlo update as in classical UCT. 33 | 34 | See Feldman and Domshlak (2014) for reference. 35 | :param node: The node to start the backup from 36 | """ 37 | r = node.reward 38 | while node is not None: 39 | node.n += 1 40 | node.q = ((node.n - 1)/node.n) * node.q + 1/node.n * r 41 | node = node.parent 42 | -------------------------------------------------------------------------------- /mcts/default_policies.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | 4 | def immediate_reward(state_node): 5 | """ 6 | Estimate the reward with the immediate return of that state. 7 | :param state_node: The state node to evaluate. 8 | :return: The immediate reward of that state. 9 | """ 10 | return state_node.state.reward(state_node.parent.parent.state, 11 | state_node.parent.action) 12 | 13 | 14 | class RandomKStepRollOut(object): 15 | """ 16 | Estimate the reward with the sum of returns of a k-step rollout. 17 | """ 18 | def __init__(self, k): 19 | self.k = k 20 | 21 | def __call__(self, state_node): 22 | self.current_k = 0 23 | 24 | def stop_k_step(state): 25 | self.current_k += 1 26 | return self.current_k > self.k or state.is_terminal() 27 | 28 | return _roll_out(state_node, stop_k_step) 29 | 30 | 31 | def random_terminal_roll_out(state_node): 32 | """ 33 | Estimate the reward with the sum of a rollout till a terminal state. 34 | Typical for terminal-only-reward situations such as games with no 35 | evaluation of the board as reward.
36 | 37 | :param state_node: The state node from which the rollout starts. 38 | :return: The accumulated reward of the rollout. 39 | """ 40 | def stop_terminal(state): 41 | return state.is_terminal() 42 | 43 | return _roll_out(state_node, stop_terminal) 44 | 45 | 46 | def _roll_out(state_node, stopping_criterion): 47 | reward = 0 48 | state = state_node.state 49 | parent = state_node.parent.parent.state 50 | action = state_node.parent.action 51 | while not stopping_criterion(state): 52 | reward += state.reward(parent, action) 53 | 54 | action = random.choice(state.actions)  # choose from the current state's actions 55 | parent = state 56 | state = parent.perform(action) 57 | 58 | return reward 59 | -------------------------------------------------------------------------------- /mcts/graph.py: -------------------------------------------------------------------------------- 1 | class Node(object): 2 | def __init__(self, parent): 3 | self.parent = parent 4 | self.children = {} 5 | self.q = 0 6 | self.n = 0 7 | 8 | 9 | class ActionNode(Node): 10 | """ 11 | A node holding an action in the tree. 12 | """ 13 | def __init__(self, parent, action): 14 | super(ActionNode, self).__init__(parent) 15 | self.action = action 16 | self.n = 0 17 | 18 | def sample_state(self, real_world=False): 19 | """ 20 | Samples a state from this action and adds it to the tree if the 21 | state never occurred before. 22 | 23 | :param real_world: If planning in belief states is used, this can 24 | be set to True when a real world action is taken. The belief is then 25 | taken from the real world outcome instead of the sampled belief state. 26 | :return: The state node, which was sampled. 27 | """ 28 | if real_world: 29 | state = self.parent.state.real_world_perform(self.action) 30 | else: 31 | state = self.parent.state.perform(self.action) 32 | 33 | if state not in self.children: 34 | self.children[state] = StateNode(self, state) 35 | 36 | if real_world: 37 | self.children[state].state.belief = state.belief 38 | 39 | return self.children[state] 40 | 41 | def __str__(self): 42 | return "Action: {}".format(self.action) 43 | 44 | 45 | class StateNode(Node): 46 | """ 47 | A node holding a state in the tree. 48 | """ 49 | def __init__(self, parent, state): 50 | super(StateNode, self).__init__(parent) 51 | self.state = state 52 | self.reward = 0 53 | for action in state.actions: 54 | self.children[action] = ActionNode(self, action) 55 | 56 | @property 57 | def untried_actions(self): 58 | """ 59 | All actions which have never been performed 60 | :return: A list of the untried actions. 61 | """ 62 | return [a for a in self.children if self.children[a].n == 0] 63 | 64 | @untried_actions.setter 65 | def untried_actions(self, value): 66 | raise ValueError("Untried actions can not be set.") 67 | 68 | def __str__(self): 69 | return "State: {}".format(self.state) 70 | 71 | 72 | def breadth_first_search(root, fnc=None): 73 | """ 74 | A breadth first search (BFS) over the subtree starting from root. A 75 | function can be run on all visited nodes. It gets the current visited 76 | node and a data object, which it can update and should return it. This 77 | data is returned by the function but never altered from the BFS itself. 78 | :param root: The node to start the BFS from 79 | :param fnc: The function to run on the nodes 80 | :return: A data object, which can be altered from fnc.
81 | """ 82 | data = None 83 | queue = [root] 84 | while queue: 85 | node = queue.pop(0) 86 | data = fnc(node, data) 87 | for child in node.children.values(): 88 | queue.append(child) 89 | return data 90 | 91 | 92 | def depth_first_search(root, fnc=None): 93 | """ 94 | A depth first search (DFS) over the subtree starting from root. A 95 | function can be run on all visited nodes. It gets the current visited 96 | node and a data object, which it can update and should return it. This 97 | data is returned by the function but never altered from the DFS itself. 98 | :param root: The node to start the DFS from 99 | :param fnc: The function to run on the nodes 100 | :return: A data object, which can be altered from fnc. 101 | """ 102 | data = None 103 | stack = [root] 104 | while stack: 105 | node = stack.pop() 106 | data = fnc(node, data) 107 | for child in node.children.values(): 108 | stack.append(child) 109 | return data 110 | 111 | 112 | def get_actions_and_states(node): 113 | """ 114 | Returns a tuple of two lists containing the action and the state nodes 115 | under the given node. 116 | :param node: 117 | :return: A tuple of two lists 118 | """ 119 | return depth_first_search(node, _get_actions_and_states) 120 | 121 | 122 | def _get_actions_and_states(node, data): 123 | if data is None: 124 | data = ([], []) 125 | 126 | action_nodes, state_nodes = data 127 | 128 | if isinstance(node, ActionNode): 129 | action_nodes.append(node) 130 | elif isinstance(node, StateNode): 131 | state_nodes.append(node) 132 | 133 | return action_nodes, state_nodes -------------------------------------------------------------------------------- /mcts/mcts.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import random 4 | from . import utils 5 | 6 | 7 | class MCTS(object): 8 | """ 9 | The central MCTS class, which performs the tree search. It gets a 10 | tree policy, a default policy, and a backup strategy. 11 | See e.g. Browne et al. (2012) for a survey on monte carlo tree search 12 | """ 13 | def __init__(self, tree_policy, default_policy, backup): 14 | self.tree_policy = tree_policy 15 | self.default_policy = default_policy 16 | self.backup = backup 17 | 18 | def __call__(self, root, n=1500): 19 | """ 20 | Run the monte carlo tree search. 
21 | 22 | :param root: The StateNode 23 | :param n: The number of roll-outs to be performed 24 | :return: 25 | """ 26 | if root.parent is not None: 27 | raise ValueError("Root's parent must be None.") 28 | 29 | for _ in range(n): 30 | node = _get_next_node(root, self.tree_policy) 31 | node.reward = self.default_policy(node) 32 | self.backup(node) 33 | 34 | return utils.rand_max(root.children.values(), key=lambda x: x.q).action 35 | 36 | 37 | def _expand(state_node): 38 | action = random.choice(state_node.untried_actions) 39 | return state_node.children[action].sample_state() 40 | 41 | 42 | def _best_child(state_node, tree_policy): 43 | best_action_node = utils.rand_max(state_node.children.values(), 44 | key=tree_policy) 45 | return best_action_node.sample_state() 46 | 47 | 48 | def _get_next_node(state_node, tree_policy): 49 | while not state_node.state.is_terminal(): 50 | if state_node.untried_actions: 51 | return _expand(state_node) 52 | else: 53 | state_node = _best_child(state_node, tree_policy) 54 | return state_node 55 | -------------------------------------------------------------------------------- /mcts/states/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'johannes' 2 | -------------------------------------------------------------------------------- /mcts/states/toy_world_state.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | 4 | import numpy as np 5 | from scipy.stats import rv_discrete, entropy 6 | from copy import deepcopy 7 | 8 | 9 | class ToyWorldAction(object): 10 | def __init__(self, action): 11 | self.action = action 12 | self._hash = 10*(action[0]+2) + action[1]+2 13 | 14 | def __hash__(self): 15 | return int(self._hash) 16 | 17 | def __eq__(self, other): 18 | return (self.action == other.action).all() 19 | 20 | def __str__(self): 21 | return str(self.action) 22 | 23 | def __repr__(self): 24 | return str(self.action) 25 | 26 | 27 | class ToyWorld(object): 28 | def __init__(self, size, information_gain, goal, manual): 29 | self.size = np.asarray(size) 30 | self.information_gain = information_gain 31 | self.goal = np.asarray(goal) 32 | self.manual = manual 33 | 34 | 35 | class ToyWorldState(object): 36 | def __init__(self, pos, world, belief=None): 37 | self.pos = pos 38 | self.world = world 39 | self.actions = [ToyWorldAction(np.array([0, 1])), 40 | ToyWorldAction(np.array([0, -1])), 41 | ToyWorldAction(np.array([1, 0])), 42 | ToyWorldAction(np.array([-1, 0]))] 43 | if belief: 44 | self.belief = belief 45 | else: 46 | self.belief = dict((a, np.array([1] * 4)) for a in self.actions) 47 | 48 | def _correct_position(self, pos): 49 | upper = np.min(np.vstack((pos, self.world.size)), 0) 50 | return np.max(np.vstack((upper, np.array([0, 0]))), 0) 51 | 52 | def perform(self, action): 53 | # get distribution about outcomes 54 | probabilities = self.belief[action] / np.sum(self.belief[action]) 55 | distrib = rv_discrete(values=(range(len(probabilities)), 56 | probabilities)) 57 | 58 | # draw sample 59 | sample = distrib.rvs() 60 | 61 | # update belief accordingly 62 | belief = deepcopy(self.belief) 63 | belief[action][sample] += 1 64 | 65 | # manual found 66 | if (self.pos == self.world.manual).all(): 67 | print("m", end="") 68 | belief = {ToyWorldAction(np.array([0, 1])): [50, 1, 1, 1], 69 | ToyWorldAction(np.array([0, -1])): [1, 50, 1, 1], 70 | ToyWorldAction(np.array([1, 0])): [1, 1, 50, 1], 71 | 
ToyWorldAction(np.array([-1, 0])): [1, 1, 1, 50]} 72 | 73 | # build next state 74 | pos = self._correct_position(self.pos + self.actions[sample].action) 75 | 76 | return ToyWorldState(pos, self.world, belief) 77 | 78 | def real_world_perform(self, action): 79 | # update belief accordingly 80 | belief = deepcopy(self.belief) 81 | if (action.action == np.array([0, 1])).all(): 82 | real_action = 0 83 | elif (action.action == np.array([0, -1])).all(): 84 | real_action = 1 85 | elif (action.action == np.array([1, 0])).all(): 86 | real_action = 2 87 | elif (action.action == np.array([-1, 0])).all(): 88 | real_action = 3 89 | belief[action][real_action] += 1 90 | 91 | # manual found 92 | if (self.pos == self.world.manual).all(): 93 | print("M", end="") 94 | belief = {ToyWorldAction(np.array([0, 1])): [50, 1, 1, 1], 95 | ToyWorldAction(np.array([0, -1])): [1, 50, 1, 1], 96 | ToyWorldAction(np.array([1, 0])): [1, 1, 50, 1], 97 | ToyWorldAction(np.array([-1, 0])): [1, 1, 1, 50]} 98 | 99 | pos = self._correct_position(self.pos + action.action) 100 | return ToyWorldState(pos, self.world, belief) 101 | 102 | def is_terminal(self): 103 | return False 104 | 105 | def __eq__(self, other): 106 | return (self.pos == other.pos).all() 107 | 108 | def __hash__(self): 109 | return int(self.pos[0]*100 + self.pos[1]) 110 | 111 | def __str__(self): 112 | return str(self.pos) 113 | 114 | def __repr__(self): 115 | return str(self.pos) 116 | 117 | def reward(self, parent, action): 118 | if (self.pos == self.world.goal).all(): 119 | print("g", end="") 120 | return 100 121 | else: 122 | reward = -1 123 | if self.world.information_gain: 124 | for a in self.actions: 125 | reward += entropy(parent.belief[a], self.belief[a]) 126 | return reward 127 | -------------------------------------------------------------------------------- /mcts/tree_policies.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import numpy as np 3 | 4 | 5 | class UCB1(object): 6 | """ 7 | The typical bandit upper confidence bounds algorithm. 8 | """ 9 | def __init__(self, c): 10 | self.c = c 11 | 12 | def __call__(self, action_node): 13 | if self.c == 0: # avoid nan values being returned 14 | # for action_node.n == 0 15 | return action_node.q 16 | 17 | return (action_node.q + 18 | self.c * np.sqrt(2 * np.log(action_node.parent.n) / 19 | action_node.n)) 20 | 21 | 22 | def flat(_): 23 | """ 24 | All actions are considered equally useful 25 | :param _: 26 | :return: 27 | """ 28 | return 0 -------------------------------------------------------------------------------- /mcts/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | 4 | 5 | def rand_max(iterable, key=None): 6 | """ 7 | A max function that tie breaks randomly instead of first-wins as in 8 | built-in max(). 9 | :param iterable: The container to take the max from 10 | :param key: A function to compute the max from. E.g.: 11 | >>> rand_max([-2, 1], key=lambda x: x**2) 12 | -2 13 | If key is None the identity is used. 14 | :return: The entry of the iterable which has the maximum value. Tie 15 | breaks are random.
16 | """ 17 | if key is None: 18 | key = lambda x: x 19 | 20 | max_v = -np.inf 21 | max_l = [] 22 | 23 | for item, value in zip(iterable, [key(i) for i in iterable]): 24 | if value == max_v: 25 | max_l.append(item) 26 | elif value > max_v: 27 | max_l = [item] 28 | max_v = value 29 | 30 | return random.choice(max_l) -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | 3 | from setuptools import setup 4 | import mcts 5 | 6 | setup( 7 | name='mcts', 8 | version=mcts.__version__, 9 | description='Monte Carlo Tree Search in Python', 10 | author='Johannes Kulick', 11 | author_email='johannes.kulick@ipvs.uni-stuttgart.de', 12 | url='http://github.com/hildensia/mcts', 13 | packages=['mcts'], 14 | requires=['numpy', 'scipy', 'pytest'] 15 | ) 16 | 17 | -------------------------------------------------------------------------------- /tests/test_toy_world_state.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | from mcts.states.toy_world_state import * 4 | 5 | 6 | def test_perform(): 7 | n = 1000 8 | 9 | world = ToyWorld((100, 100), False, (0, 0), np.array([100, 100])) 10 | belief = dict(zip([ToyWorldAction(np.array([0, 1])), 11 | ToyWorldAction(np.array([0, -1])), 12 | ToyWorldAction(np.array([1, 0])), 13 | ToyWorldAction(np.array([-1, 0]))], 14 | [[10, 1, 1, 1], [1, 10, 1, 1], [1, 1, 10, 1], 15 | [1, 1, 1, 10]])) 16 | 17 | state = ToyWorldState((0,0), world, belief) 18 | 19 | outcomes = np.array([0., 0, 0, 0]) 20 | for i in range(n): 21 | new_state = state.perform(state.actions[0]) 22 | #print(new_state.belief[state.actions[0]]) 23 | 24 | if new_state.belief[state.actions[0]][0] == 11: 25 | outcomes[0] += 1 26 | elif new_state.belief[state.actions[0]][1] == 2: 27 | outcomes[1] += 1 28 | elif new_state.belief[state.actions[0]][2] == 2: 29 | outcomes[2] += 1 30 | elif new_state.belief[state.actions[0]][3] == 2: 31 | outcomes[3] += 1 32 | 33 | print(outcomes) 34 | 35 | deviation = 3./np.sqrt(n) 36 | outcomes /= float(n) 37 | print(outcomes) 38 | expectation = np.array(belief[state.actions[0]])/\ 39 | sum(belief[state.actions[0]]) 40 | 41 | assert (expectation - deviation < outcomes).all() 42 | assert (outcomes < expectation + deviation).all() 43 | 44 | 45 | -------------------------------------------------------------------------------- /tests/test_uct.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import random 3 | 4 | from mcts.graph import (depth_first_search, _get_actions_and_states, StateNode) 5 | from mcts.mcts import * 6 | from mcts.utils import rand_max 7 | from mcts.states.toy_world_state import * 8 | 9 | import mcts.tree_policies as tree_policies 10 | import mcts.default_policies as default_policies 11 | import mcts.backups as backups 12 | 13 | 14 | parametrize_gamma = pytest.mark.parametrize("gamma", 15 | [.1, .2, .3, .4, .5, .6, .7, .8, 16 | .9]) 17 | 18 | parametrize_n = pytest.mark.parametrize("n", [1, 10, 23, 100, 101]) 19 | 20 | 21 | @pytest.fixture 22 | def eps(): 23 | return 10e-3 24 | 25 | 26 | class UCBTestState(object): 27 | def __init__(self, id=0): 28 | self.actions = [0] 29 | self.hash = id 30 | 31 | def perform(self, action): 32 | return UCBTestState(self.hash+1) 33 | 34 | def is_terminal(self): 35 | return False 36 | 37 | def reward(self, parent, action): 38 | return -1 39 | 40 | def 
__hash__(self): 41 | return self.hash 42 | 43 | def __eq__(self, other): 44 | return self.hash == other.hash 45 | 46 | 47 | def test_ucb1(): 48 | ucb1 = tree_policies.UCB1(1) 49 | parent = StateNode(None, UCBTestState()) 50 | an = parent.children[0] 51 | 52 | an.n = 1 53 | parent.n = 1 54 | assert ucb1(an) == 0 55 | 56 | an.n = 0 57 | parent.n = 1 58 | assert np.isnan(ucb1(an)) 59 | 60 | an.n = 1 61 | parent.n = 0 62 | assert np.isnan(ucb1(an)) 63 | 64 | an.q = 1 65 | an.n = 1 66 | parent.n = 1 67 | assert ucb1(an) == 1 68 | 69 | 70 | def test_ucb1_c0(): 71 | ucb1 = tree_policies.UCB1(0) 72 | parent = StateNode(None, UCBTestState()) 73 | an = parent.children[0] 74 | 75 | an.q = 19 76 | an.n = 0 77 | assert ucb1(an) == 19 78 | 79 | 80 | class ComplexTestState(object): 81 | def __init__(self, name): 82 | self.actions = [ComplexTestAction('a'), ComplexTestAction('b')] 83 | self.name = name 84 | 85 | def perform(self, action): 86 | return ComplexTestState(action.name) 87 | 88 | def is_terminal(self): 89 | return False 90 | 91 | def reward(self, parent, action): 92 | return -1 93 | 94 | def __hash__(self): 95 | return self.name.__hash__() 96 | 97 | def __eq__(self, other): 98 | return self.name == other.name 99 | 100 | 101 | class ComplexTestAction(object): 102 | def __init__(self, name): 103 | self.name = name 104 | 105 | def __hash__(self): 106 | return self.name.__hash__() 107 | 108 | def __eq__(self, other): 109 | return self.name == other.name 110 | 111 | 112 | def test_best_child(): 113 | parent = StateNode(None, ComplexTestState('root')) 114 | an0 = parent.children[ComplexTestAction('a')] 115 | an1 = parent.children[ComplexTestAction('b')] 116 | 117 | an0.q = 2 118 | an0.n = 1 119 | an1.q = 1 120 | an1.n = 1 121 | 122 | assert len(parent.children.values()) == 2 123 | 124 | child_state = utils.rand_max(parent.children.values(), 125 | key=lambda x: x.q).sample_state() 126 | assert child_state.state.name == 'a' 127 | 128 | 129 | def test_rand_max(): 130 | i = [1, 4, 5, 3] 131 | assert rand_max(i) == 5 132 | 133 | i = [1, -5, 3, 2] 134 | assert rand_max(i, key=lambda x:x**2) == -5 135 | 136 | parent = StateNode(None, ComplexTestState('root')) 137 | an0 = parent.children[ComplexTestAction('a')] 138 | an1 = parent.children[ComplexTestAction('b')] 139 | 140 | an0.q = 2 141 | an0.n = 1 142 | an1.q = 1 143 | an1.n = 1 144 | 145 | assert rand_max(parent.children.values(), 146 | lambda x: x.q).action.name == 'a' 147 | 148 | assert rand_max(parent.children.values(), 149 | tree_policies.UCB1(0)).action.name == 'a' 150 | 151 | 152 | def test_untried_actions(): 153 | s = ComplexTestState('root') 154 | sn = StateNode(None, s) 155 | assert ComplexTestAction('a') in sn.untried_actions 156 | assert ComplexTestAction('b') in sn.untried_actions 157 | 158 | sn.children[ComplexTestAction('a')].n = 1 159 | assert ComplexTestAction('a') not in sn.untried_actions 160 | assert ComplexTestAction('b') in sn.untried_actions 161 | 162 | 163 | def test_sample_state(): 164 | s = ComplexTestState('root') 165 | root = StateNode(None, s) 166 | 167 | child = root.children[ComplexTestAction('a')] 168 | child.sample_state() 169 | 170 | assert len(child.children.values()) == 1 171 | assert ComplexTestState('a') in child.children 172 | 173 | child.sample_state() 174 | assert len(child.children.values()) == 1 175 | assert ComplexTestState('a') in child.children 176 | 177 | 178 | @pytest.fixture 179 | def toy_world_root(): 180 | world = ToyWorld((100, 100), False, (10, 10), np.array([100, 100])) 181 | state = ToyWorldState((0, 
0), world) 182 | root = StateNode(None, state) 183 | return root, state 184 | 185 | 186 | @parametrize_gamma 187 | def test_single_run_uct_search(toy_world_root, gamma): 188 | root, state = toy_world_root 189 | random.seed() 190 | 191 | uct = MCTS(tree_policies.UCB1(1.41), default_policies.immediate_reward, 192 | backups.Bellman(gamma)) 193 | 194 | best_child = uct(root=root, n=1) 195 | 196 | states = [state for states in [action.children.values() 197 | for action in root.children.values()] 198 | for state in states] 199 | 200 | assert len(states) == 1 201 | 202 | assert (len(list(root.children[best_child].children.values())) == 0) 203 | 204 | expanded = None 205 | for action in root.children.values(): 206 | if (action.action != best_child and 207 | len(list(action.children.values())) == 1): 208 | assert expanded is None 209 | expanded = action 210 | 211 | for state in states: 212 | assert (np.sum(np.array(list(state.state.belief.values()))) - 1 == 213 | np.sum(np.array(list(root.state.belief.values())))) 214 | assert root.n == 1 215 | 216 | for action in root.children.values(): 217 | if action.action == expanded.action: 218 | assert action.q == -1.0 219 | else: 220 | assert action.q == 0.0 221 | 222 | 223 | @parametrize_gamma 224 | @parametrize_n 225 | def test_n_run_uct_search(toy_world_root, gamma, n): 226 | root, state = toy_world_root 227 | random.seed() 228 | 229 | uct = MCTS(tree_policies.UCB1(1.41), default_policies.immediate_reward, 230 | backups.Bellman(gamma)) 231 | uct(root=root, n=n) 232 | 233 | assert root.n == n 234 | 235 | action_nodes, state_nodes = depth_first_search(root, 236 | _get_actions_and_states) 237 | 238 | for action in action_nodes: 239 | assert action.n == np.sum([state.n 240 | for state in action.children.values()]) 241 | 242 | for state in state_nodes: 243 | assert state.n >= np.sum([action.n 244 | for action 245 | in state.children.values()]) >= state.n - 1 246 | if state.parent is not None: 247 | assert (np.array(list(state.state.belief.values())).sum() - 1 == 248 | np.array(list(state.parent.parent.state.belief.values())). 249 | sum()) 250 | 251 | 252 | @parametrize_gamma 253 | def test_q_value_simple_state(gamma, eps): 254 | root = StateNode(None, UCBTestState(0)) 255 | uct = MCTS(tree_policies.UCB1(1.41), default_policies.immediate_reward, 256 | backups.Bellman(gamma)) 257 | uct(root=root, n=250) 258 | assert root.q - (-1./(1 - gamma)) < eps 259 | 260 | 261 | @parametrize_gamma 262 | def test_q_value_complex_state(gamma, eps): 263 | if gamma > 0.5: # with bigger gamma UCT converges too slow 264 | return 265 | root = StateNode(None, ComplexTestState(0)) 266 | uct = MCTS(tree_policies.UCB1(1.41), default_policies.immediate_reward, 267 | backups.Bellman(gamma)) 268 | uct(root=root, n=1500) 269 | assert root.q - (-1./(1 - gamma)) < eps 270 | --------------------------------------------------------------------------------
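
Putting the pieces above together, here is a minimal end-to-end usage sketch. It runs plain UCT (UCB1 tree policy, immediate-reward default policy, Monte Carlo backups) on the bundled `ToyWorldState`; the grid size, goal and manual positions below are arbitrary example values, not taken from the original experiments.

    import numpy as np

    from mcts.backups import monte_carlo
    from mcts.default_policies import immediate_reward
    from mcts.graph import StateNode
    from mcts.mcts import MCTS
    from mcts.states.toy_world_state import ToyWorld, ToyWorldState
    from mcts.tree_policies import UCB1

    # A 10x10 toy world without intrinsic motivation; the goal at (5, 5) and
    # the manual at (2, 2) are example values only.
    world = ToyWorld([10, 10], False, [5, 5], np.array([2, 2]))
    root = StateNode(parent=None, state=ToyWorldState(np.array([0, 0]), world))

    # Plain UCT: UCB1 tree policy, immediate-reward default policy,
    # Monte Carlo backups.
    search = MCTS(tree_policy=UCB1(c=1.41),
                  default_policy=immediate_reward,
                  backup=monte_carlo)
    best_action = search(root, n=500)  # 500 roll-outs from the root
    print(best_action)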