├── Additional Comments.pdf ├── README.md ├── Taxali_Lee_Final_Report.pdf ├── ataxali_final_presentation.pdf ├── bayesSparse.py ├── componentTesters.py ├── global_constants.py ├── gpPosterior.py ├── historyManager.py ├── inputReader.py ├── logger.py ├── main.py ├── mdpSimulator.py ├── thompsonSampling.py └── world.py /Additional Comments.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ataxali/bayesian_reinforcement_learning/18a6ffd893945a11e13cc53b5eb237452cb8c0a3/Additional Comments.pdf -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Bayesian Reinforcement Learning 2 | 3 | Aman Taxali, Ray Lee 4 | 5 | ### Motivation 6 | 7 | In this project, we describe a general Bayesian strategy, known as sparse sampling, for approximating optimal actions in Partially 8 | Observable Markov Decision Processes. Our experimental results confirm the greedy-optimal behavior of this methodology. We also explore ways of augmenting the sparse sampling algorithm by introducing additional exploration conditions. Our experimental results show that this approach yields a more robust model. 9 | 10 | ### Results and Discussion 11 | 12 | [See the included report for the details of our methodologies and findings.](./Taxali_Lee_Final_Report.pdf) 13 | 14 | [Presentation Slides](./ataxali_final_presentation.pdf) 15 | 16 | #### Running the code 17 | 18 | To run our code, please copy all .py files into a directory. Then, within that directory, run: 19 | 20 | python main.py batch_id=1 name=sparse_sampling move_limit=100 root_path="./" 21 | 22 | Dependencies: Python 3, scikit-learn 0.19.1 23 | 24 | More about the script parameters (example commands for each variant are given at the end of this README): 25 | * batch_id and name are unique identifiers used for batch jobs on Flux 26 | * move_limit sets the training time (total number of agent moves) for the algorithm 27 | * root_path is the directory where the final models are saved 28 | * the parameters above will run the sparse sampling algorithm 29 | * to run sparse sampling with Thompson sampling, add the parameters: 30 | * prune=T 31 | * ts_hyper_param=25 32 | * where ts_hyper_param determines how quickly the additional exploration condition on sparse sampling is removed (we suggest ts_hyper_param = move_limit * 0.25) 33 | * to run sparse sampling with episodic reset and bootstrapping, add the parameters: 34 | * bootstrap=T 35 | * ep_len=1 36 | * where ep_len determines how many games make up one training episode 37 | 38 | The Bayesian sparse sampling algorithm (Kearns et al., 2001) is implemented in bayesSparse.py. The file gpPosterior.py fits the internal belief-based models (for the belief-based positions of terminal states). The file mdpSimulator.py allows the agent to switch between belief-based models of the MDP and the real MDP. The Beta/Dirichlet posteriors used for Thompson sampling are defined in thompsonSampling.py.
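For example, combining the base command with the optional parameters listed above (the name values below are just arbitrary run identifiers), the two augmented variants could be launched as:

    python main.py batch_id=1 name=sparse_sampling_ts move_limit=100 root_path="./" prune=T ts_hyper_param=25

    python main.py batch_id=1 name=sparse_sampling_bs move_limit=100 root_path="./" bootstrap=T ep_len=1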
39 | 40 | -------------------------------------------------------------------------------- /Taxali_Lee_Final_Report.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ataxali/bayesian_reinforcement_learning/18a6ffd893945a11e13cc53b5eb237452cb8c0a3/Taxali_Lee_Final_Report.pdf -------------------------------------------------------------------------------- /ataxali_final_presentation.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ataxali/bayesian_reinforcement_learning/18a6ffd893945a11e13cc53b5eb237452cb8c0a3/ataxali_final_presentation.pdf -------------------------------------------------------------------------------- /bayesSparse.py: -------------------------------------------------------------------------------- 1 | # Pseudocode 2 | # 3 | # GrowSparseTree (node, branchfactor, horizon) 4 | # If node.depth = horizon; return 5 | # If node.type = “decision” 6 | # For each a ∈ A 7 | # child = (“outcome”, depth, node.belstate, a) 8 | # GrowSparseTree (child, branchfactor, horizon) 9 | # If node.type = “outcome” 10 | # For i = 1...branchfactor 11 | # [rew,obs] = sample(node.belstate, node.act) 12 | # post = posterior(node.belstate, obs) 13 | # child = (“decision”, depth+1, post, [rew,obs]) 14 | # GrowSparseTree (child, branchfactor, horizon) 15 | # 16 | # EvaluateSubTree (node, horizon) 17 | # If node.children = empty 18 | # immed = MaxExpectedValue(node.belstate) 19 | # return immed * (horizon - node.depth) 20 | # If node.type = “decision” 21 | # return max(EvaluateSubTree(node.children)) 22 | # If node.type = “outcome” 23 | # values = EvaluateSubTree(node.children) 24 | # return avg(node.rewards + values) 25 | 26 | import enum 27 | import numpy as np 28 | from global_constants import print_debug 29 | from mdpSimulator import MDPSimulator 30 | 31 | 32 | NodeType = enum.Enum("NodeType", "Outcome Decision") 33 | 34 | 35 | class SparseTree(object): 36 | 37 | class Node(object): 38 | def __init__(self, type, depth, state, value): 39 | self.type = type 40 | self.depth = depth 41 | self.state = state 42 | self.value = value 43 | 44 | def __str__(self): 45 | return "[" + str(self.type) + ":" + str(self.depth) + ":" + str(self.value) + "]" 46 | 47 | def __init__(self, node, parent, actions=None): 48 | self.node = node 49 | self.parent = parent 50 | self.actions = actions 51 | self.children = [] 52 | 53 | def add_child(self, child): 54 | self.children.append(child) 55 | 56 | def append_val_to_parent(self, value): 57 | self.parent.node.value.append(value) 58 | 59 | def __str__(self): 60 | children_str = "{" 61 | for child in self.children: 62 | children_str += " " + str(child.node) 63 | children_str += "}" 64 | return str(self.node) + " -> " + children_str 65 | 66 | def get_tree_size(self): 67 | if not self.children and self.node.type == NodeType.Decision: 68 | return 1 69 | if not self.children: 70 | return 0 71 | return 1 + sum(map(lambda child: child.get_tree_size(), self.children)) 72 | 73 | def get_tree_depth(self): 74 | if len(self.children) == 0: 75 | return self.node.depth 76 | return max(map(lambda child: child.get_tree_depth(), self.children)) 77 | 78 | 79 | class SparseTreeEvaluator(object): 80 | 81 | def __init__(self, mdp_simulator, root_state, action_set, horizon, 82 | history_manager, state_posterior, goal_state, goal_reward, 83 | loss_penalty, thompson_sampler=None, 84 | discount_factor=0.05): 85 | self.simulator = mdp_simulator 86 | self.root_state 
= root_state 87 | self.action_set = action_set 88 | self.horizon = horizon 89 | self.lookahead_tree = None 90 | self.thompson_sampler = thompson_sampler 91 | self.discount_factor = discount_factor 92 | self.history_manager = history_manager 93 | self.state_posterior = state_posterior 94 | self.goal_state = goal_state 95 | self.loss_penalty = loss_penalty 96 | self.goal_reward = goal_reward 97 | self.ignored_specials = [] 98 | 99 | def evaluate(self, t): 100 | root_node = SparseTree.Node(NodeType.Decision, 0, self.root_state, []) 101 | lookahead_tree = SparseTree(root_node, None) 102 | specials = [] 103 | for i in range(-1, self.horizon+2): 104 | specials_t = [] 105 | self.__predict_specials(specials_t, t + i) 106 | specials.append(set(specials_t)) 107 | self.__grow_sparse_tree(lookahead_tree, specials) 108 | self.__eval_sparse_tree(lookahead_tree, specials) 109 | self.lookahead_tree = lookahead_tree 110 | 111 | def __str__(self): 112 | children_str = "{" 113 | for child in self.lookahead_tree.children: 114 | children_str += " " + str(child.node) 115 | children_str += "}" 116 | return str(self.lookahead_tree.node) + " -> " + children_str 117 | 118 | def __predict_specials(self, specials, time): 119 | x_preds, y_preds = self.state_posterior.predict(time) 120 | for x in x_preds[0]: 121 | for y in y_preds[0]: 122 | specials.append((int(round(x[0])), int(round(y[0])), "red", self.loss_penalty, "NA")) 123 | 124 | def __grow_sparse_tree(self, lookahead_tree, specials): 125 | if (lookahead_tree.node.depth >= self.horizon) and (lookahead_tree.node.type == NodeType.Decision): 126 | # leaves of sparse tree should be outcome nodes 127 | return 128 | 129 | if lookahead_tree.node.type == NodeType.Decision: 130 | specials_t = list(set(specials[lookahead_tree.node.depth]) | 131 | set(specials[lookahead_tree.node.depth + 1]) | 132 | set(specials[lookahead_tree.node.depth + 2])) 133 | specials_t.append(( 134 | self.goal_state[0], self.goal_state[1], "green", 135 | self.goal_reward, "NA")) 136 | statics = self.state_posterior.get_static_states() 137 | # if we are at root node, and asked to evaluate decision tree 138 | # we assume that special and tree root cannot overlap 139 | # otherwise, there is no tree to construct 140 | filtered_specials = specials_t.copy() 141 | if lookahead_tree.node.depth == 0: 142 | for idx, (i, j, c, r, v) in enumerate(specials_t): 143 | if lookahead_tree.node.state[0] == i and lookahead_tree.node.state[1] == j: 144 | print("Root at special", (i, j, c, r, v)) 145 | if not c == "green": 146 | self.ignored_specials.append([i, j]) 147 | filtered_specials.pop(idx) 148 | if lookahead_tree.node.depth == 0 and self.thompson_sampler: 149 | move_pool = self.__get_actions(lookahead_tree, filtered_specials, statics, True) 150 | lookahead_tree.actions = move_pool 151 | else: 152 | move_pool = self.__get_actions(lookahead_tree, filtered_specials, statics, False) 153 | lookahead_tree.actions = move_pool 154 | for action in move_pool: 155 | orig_state, child_action, child_reward, child_state, _ = \ 156 | self.simulator.sim(lookahead_tree.node.state, action, 157 | specials=filtered_specials, walls=statics) 158 | if list(child_state) == list(orig_state): 159 | continue 160 | child = SparseTree(SparseTree.Node(NodeType.Outcome, lookahead_tree.node.depth, 161 | child_state, [child_reward]), lookahead_tree) 162 | lookahead_tree.add_child(child) 163 | if print_debug: print("Added outcome child depth", child) 164 | self.__grow_sparse_tree(child, specials) 165 | 166 | if lookahead_tree.node.type == 
NodeType.Outcome: 167 | child = SparseTree(SparseTree.Node(NodeType.Decision, 168 | lookahead_tree.node.depth + 1, 169 | lookahead_tree.node.state, []), lookahead_tree) 170 | lookahead_tree.add_child(child) 171 | self.__grow_sparse_tree(child, specials) 172 | 173 | def __eval_sparse_tree(self, lookahead_tree, t): 174 | for child in lookahead_tree.children: 175 | self.__eval_sparse_tree(child, t) 176 | 177 | if lookahead_tree.node.type == NodeType.Outcome: 178 | state_reward = lookahead_tree.node.value.pop(0) 179 | if lookahead_tree.node.value: 180 | reward_avg = state_reward + (sum(lookahead_tree.node.value) / float(len(lookahead_tree.node.value))) 181 | else: 182 | reward_avg = state_reward 183 | if len(lookahead_tree.children) == 0: 184 | depth_factor = max(self.horizon, lookahead_tree.node.depth) - lookahead_tree.node.depth + 1 185 | lookahead_tree.append_val_to_parent(reward_avg * float(depth_factor) * self.discount_factor) 186 | else: 187 | # average present and future rewards 188 | lookahead_tree.append_val_to_parent(reward_avg * self.discount_factor) 189 | 190 | if lookahead_tree.node.type == NodeType.Decision: 191 | if lookahead_tree.node.depth == 0: 192 | # set sparse tree root value to 193 | # (best_action_index, max_avg_reward_value_discounted) 194 | if len(lookahead_tree.node.value) == 0: 195 | print(lookahead_tree.node.state) 196 | print(lookahead_tree.actions) 197 | print(lookahead_tree.node) 198 | max_value = max(lookahead_tree.node.value) 199 | max_idxs = [i for i, j in enumerate(lookahead_tree.node.value) if j == max_value] 200 | lookahead_tree.node.value = (max_idxs, max_value, [lookahead_tree.node.value]) 201 | else: 202 | # maximize the averages and discount the max 203 | if len(lookahead_tree.node.value): 204 | present_reward = max(lookahead_tree.node.value) 205 | if len(lookahead_tree.children) == 0: 206 | depth_factor = max(self.horizon, 207 | lookahead_tree.node.depth) - lookahead_tree.node.depth + 1 208 | lookahead_tree.append_val_to_parent(present_reward * 209 | float(depth_factor)) 210 | else: 211 | lookahead_tree.append_val_to_parent(present_reward) 212 | 213 | def __get_actions(self, root, specials, statics, use_tsampler): 214 | if use_tsampler: 215 | valid_actions = self.simulator.get_valid_actions(root.node.state, 216 | self.action_set, 217 | specials=specials, 218 | walls=statics) 219 | return self.thompson_sampler.get_action_set(valid_actions) 220 | else: 221 | return self.simulator.get_valid_actions(root.node.state, 222 | self.action_set, 223 | specials=specials, 224 | walls=statics) 225 | 226 | def __get_states(self, root, specials, statics): 227 | ## complete neighbor set 228 | neighbors = [] 229 | for action in self.action_set: 230 | n_orig_state, n_action, n_reward, n_new_state, _ = \ 231 | self.simulator.sim(root.node.state, action, 232 | specials=specials, walls=statics) 233 | neighbors.append(n_new_state) 234 | return neighbors -------------------------------------------------------------------------------- /componentTesters.py: -------------------------------------------------------------------------------- 1 | import world 2 | import threading 3 | import time 4 | import random 5 | import inputReader 6 | import logger 7 | import numpy as np 8 | from mdpSimulator import WorldSimulator 9 | from bayesSparse import SparseTreeEvaluator 10 | from historyManager import HistoryManager, BootstrapHistoryManager 11 | from thompsonSampling import ThompsonSampler 12 | from gpPosterior import GPPosterior 13 | from sklearn.gaussian_process.kernels import 
ExpSineSquared 14 | from matplotlib import pyplot as plt, colors 15 | import pickle 16 | import sys 17 | import os 18 | 19 | 20 | terminal_state_win = [world.static_specials[2][0], world.static_specials[2][1]] 21 | terminal_state_loss_0 = [world.static_specials[0][0], world.static_specials[0][1]] 22 | terminal_state_loss_1 = [world.static_specials[1][0], world.static_specials[1][1]] 23 | 24 | 25 | def test_world_simulator(): 26 | w = WorldSimulator() 27 | # specials at (4,0) and (4,1) 28 | print(w.sim([3, 3], "up")) # ([3, 3], 'up', -0.040000000000000036, (3, 2)) 29 | print(w.sim([3, 2], "up")) # ([3, 2], 'up', -0.040000000000000036, (3, 1)) 30 | print(w.sim([3, 1], "right")) # ([3, 1], 'right', -1.0, (4, 1)) 31 | print(w.sim([3, 0], "right")) # ([3, 0], 'right', 1.0, (4, 0)) 32 | print(w.sim([3, 2], "left")) # ([3, 0], 'right', 1.0, (4, 0)) 33 | 34 | 35 | def sparse_tree_tester(): 36 | t0 = time.time() 37 | simulator = WorldSimulator() 38 | root_state = [0, 4] 39 | action_set = ["up", "down", "left", "right"] 40 | print(simulator.get_valid_actions(root_state, action_set, specials=[], walls=[])) 41 | 42 | simulator = WorldSimulator() 43 | action_set = ["up", "down", "left", "right"] 44 | horizon = 6 45 | branch_factor = 5 46 | history_manager = HistoryManager(action_set) 47 | ste = SparseTreeEvaluator(simulator, root_state, action_set, history_manager, horizon) 48 | ste.evaluate() 49 | print(ste) 50 | print(random.choice(ste.lookahead_tree.node.value[0])) 51 | t1 = time.time() 52 | print("Runtime:", t1-t0) 53 | 54 | 55 | def thompson_sampler_tester(): 56 | action_set = ["up", "down", "left", "right"] 57 | branching_factor = 2 58 | history = HistoryManager(action_set) 59 | #tsampler = ThompsonSampler(history, branching_factor) 60 | 61 | w = WorldSimulator() 62 | move_1 = w.sim([0, 4], "up") 63 | history.add(move_1) 64 | 65 | move_2 = w.sim([0, 3], "right") 66 | history.add(move_2) 67 | 68 | move_3 = w.sim([1, 3], "up") 69 | history.add(move_3) 70 | print(history.get_action_count_reward_dict()) 71 | #print(tsampler.get_action_set()) 72 | 73 | 74 | def bootstrap_history_tester(): 75 | action_set = ["up", "down", "left", "right"] 76 | history = BootstrapHistoryManager(action_set, 0.5) 77 | 78 | w = WorldSimulator() 79 | move_1 = w.sim([0, 4], "up") 80 | history.add(move_1 + (1,)) 81 | 82 | move_2 = w.sim([0, 3], "right") 83 | history.add(move_2 + (2, )) 84 | 85 | move_3 = w.sim([1, 3], "up") 86 | history.add(move_3 + (3, )) 87 | 88 | print(history.history) 89 | 90 | 91 | def gp_posterior_tester(log): 92 | origin_state = [6, 6] 93 | root_state = origin_state 94 | time = 0 95 | action_set = ["up", "down", "left", "right"] 96 | orig_specials = world.static_specials.copy() 97 | orig_walls = world.static_walls.copy() 98 | simulator = WorldSimulator() 99 | history_manager = HistoryManager(action_set) 100 | kernel = ExpSineSquared(length_scale=1, periodicity=1.0, 101 | periodicity_bounds=(2, 100), 102 | length_scale_bounds=(1, 50)) 103 | gp = GPPosterior(history_manager=history_manager, kernel=kernel, log=None) 104 | # warmup no logging, just gp training 105 | # for i in range(1000): 106 | # next_move = np.random.choice(action_set) 107 | # state, action, sim_r, sim_n_s = simulator.sim(root_state, next_move) 108 | # history_manager.add((root_state, action, sim_r, sim_n_s, time)) 109 | # root_state = sim_n_s 110 | # time += 1 111 | # if abs(sim_r) > 1: 112 | # print("Restarting game", sim_r, time) 113 | # root_state = origin_state 114 | # time = 0 115 | # simulator.specials = orig_specials.copy() 116 | 
# gp.update_posterior() 117 | # time = 0 118 | # 119 | # t = np.atleast_2d(np.linspace(0, 1000, 1000)).T 120 | # x_preds, y_preds = gp.predict(t) 121 | # cmap_x = ['m', 'c', 'k', 'g'] 122 | # cmap_y = ['r', 'b', 'y', 'teal'] 123 | # total_x_obs = 0 124 | # total_y_obs = 0 125 | # for obs in gp.x_obs: 126 | # total_x_obs += len(obs) 127 | # for obs in gp.y_obs: 128 | # total_y_obs += len(obs) 129 | # print(">>> There are " + str(len(x_preds[0])) + " X gaussian procs for " + str(total_x_obs) + " obs <<<") 130 | # print(">>> There are " + str(len(y_preds[0])) + " Y gaussian procs for " + str(total_y_obs) + " obs <<<") 131 | # for i, preds in enumerate(x_preds[0]): 132 | # plt.plot(t, preds, cmap_x[i]+":", label='x_predictions') 133 | # plt.plot(list(map(lambda x: x[0], gp.x_obs[i])), 134 | # list(map(lambda x: x[1], gp.x_obs[i])), cmap_x[i]+"*", markersize=10) 135 | # for i, preds in enumerate(y_preds[0]): 136 | # plt.plot(t, preds, cmap_y[i]+":", label='y_predictions') 137 | # plt.plot(list(map(lambda y: y[0], gp.y_obs[i])), 138 | # list(map(lambda y: y[1], gp.y_obs[i])), cmap_y[i]+"*", markersize=10) 139 | # plt.xlim(0, 100) 140 | # plt.show(block=True) 141 | # end of warmup 142 | 143 | def predict(time, type): 144 | x_preds, y_preds = gp.predict(time) 145 | for x_pred in x_preds[0]: 146 | for y_pred in y_preds[0]: 147 | msg = "add" + type + str(int(round(x_pred[0]))) + "," + str( 148 | int(round(y_pred[0]))) 149 | logger.log(msg, logger=log) 150 | 151 | for i in range(1000): 152 | next_move = np.random.choice(action_set) 153 | state, action, sim_r, sim_n_s, ns = simulator.sim(root_state, next_move, specials=orig_specials, walls=orig_walls) 154 | orig_specials = ns 155 | if list(sim_n_s) == list(state): 156 | if state not in gp.static_states: 157 | if next_move == "up": 158 | logger.log('addw' + str(sim_n_s[0]) + "," + str(sim_n_s[1] - 1), logger=log) 159 | gp.update_static_states([sim_n_s[0], sim_n_s[1] - 1]) 160 | elif next_move == "down": 161 | logger.log('addw' + str(sim_n_s[0]) + "," + str(sim_n_s[1] + 1),logger=log) 162 | gp.update_static_states([sim_n_s[0], sim_n_s[1] + 1]) 163 | elif next_move == "left": 164 | logger.log('addw' + str(sim_n_s[0] - 1) + "," + str(sim_n_s[1]),logger=log) 165 | gp.update_static_states([sim_n_s[0] - 1, sim_n_s[1]]) 166 | elif next_move == "right": 167 | logger.log('addw' + str(sim_n_s[0] + 1) + "," + str(sim_n_s[1]),logger=log) 168 | gp.update_static_states([sim_n_s[0] + 1, sim_n_s[1]]) 169 | history_manager.add((root_state, action, sim_r, sim_n_s, time)) 170 | logger.log('clr', logger=log) 171 | predict(time - 1, "c") 172 | predict(time + 1, "c") 173 | predict(time, "r") 174 | logger.log(next_move, logger=log) 175 | root_state = sim_n_s 176 | time += 1 177 | if abs(sim_r) > 1: 178 | print("Restarting game", sim_r, time) 179 | root_state = origin_state 180 | time = 0 181 | # simulator.specials = orig_specials.copy() 182 | logger.log("reset", logger=log) 183 | gp.update_posterior() 184 | 185 | with open('gp.out', 'wb') as output: 186 | pickle.dump(gp, output, pickle.HIGHEST_PROTOCOL) 187 | 188 | 189 | def plot_gp(filename): 190 | gp = pickle.load(open(filename, "rb")) 191 | kernel = ExpSineSquared(length_scale=2, periodicity=3.0, 192 | periodicity_bounds=(2, 10), 193 | length_scale_bounds=(1, 10)) 194 | #gp = GPPosterior(history_manager=gp_orig.history_manager, kernel=gp_orig.kernel, log=None) 195 | #gp.update_posterior() 196 | #gp.static_states = gp_orig.static_states 197 | t = np.atleast_2d(np.linspace(0, 50, 50)).T 198 | x_preds, y_preds = 
gp.predict(t) 199 | cmap_x = ['m', 'c', 'k', 'g'] 200 | cmap_y = ['r', 'b', 'y', 'k'] 201 | total_x_obs = 0 202 | total_y_obs = 0 203 | for obs in gp.x_obs: 204 | total_x_obs += len(obs) 205 | for obs in gp.y_obs: 206 | total_y_obs += len(obs) 207 | print(">>> There are " + str(len(x_preds[0])) + " X gaussian procs for " + str(total_x_obs) + " obs <<<") 208 | print(">>> There are " + str(len(y_preds[0])) + " Y gaussian procs for " + str(total_y_obs) + " obs <<<") 209 | #for i, preds in enumerate(x_preds[0]): 210 | # plt.plot(t, preds, cmap_x[i]+":", label='x_predictions') 211 | # plt.plot(list(map(lambda x: x[0], gp.x_obs[i])), 212 | # list(map(lambda x: x[1], gp.x_obs[i])), cmap_x[i]+"*", markersize=10) 213 | # plt.fill(np.concatenate([t, t[::-1]]), np.concatenate([preds - 1.9600 * x_preds[1][i], 214 | # (preds + 1.9600 * x_preds[1][i])[::-1]]), 215 | # alpha=.25, fc='k', ec='None', label='95% confidence interval') 216 | for i, preds in enumerate(y_preds[0]): 217 | plt.plot(t, preds, cmap_y[i]+":", label='y_predictions') 218 | plt.plot(list(map(lambda y: y[0], gp.y_obs[i])), 219 | list(map(lambda y: y[1], gp.y_obs[i])), cmap_y[i]+"*", markersize=10) 220 | plt.fill(np.concatenate([t, t[::-1]]), np.concatenate([preds - 1.9600 * y_preds[1][i], 221 | (preds + 1.9600 * y_preds[1][i])[::-1]]), 222 | alpha=.25, fc='k', ec='None', label='95% confidence interval') 223 | plt.show(block=True) 224 | 225 | 226 | def sparse_tree_model_tester(arg_dict): 227 | ###### Model Variables ##### 228 | root_state = [0, 3] 229 | goal_state = [9, 6] 230 | goal_reward = 10 231 | loss_penalty = -10 232 | original_root = root_state.copy() 233 | horizon = 10 234 | if 'ep_len' in arg_dict and int(arg_dict['ep_len']): 235 | print("Setting episode length:", arg_dict['ep_len'], "...") 236 | episode_length = int(arg_dict['ep_len']) 237 | else: 238 | episode_length = 0 # number of games before posterior distributions are reset 239 | action_set = ["up", "down", "left", "right"] 240 | episode_move_limit = 100 241 | history_manager = HistoryManager(action_set) 242 | if 'bootstrap' in arg_dict: 243 | print("Setting history manager to Bootstrapped...") 244 | history_manager = BootstrapHistoryManager(action_set, 0.25) 245 | if episode_length: 246 | ts_history_manager = HistoryManager(action_set) 247 | else: 248 | ts_history_manager = history_manager 249 | thompson_sampler = None 250 | if 'prune' in arg_dict: 251 | move_wght = float(arg_dict['move_weight']) 252 | if not move_wght: 253 | raise Exception("Cannot start thompson sampler without move weight!") 254 | print("Creating thompson sampler, with move weight", move_wght, "...") 255 | 256 | thompson_sampler = ThompsonSampler(ts_history_manager, use_constant_boundary=0.5, 257 | move_weight=move_wght, move_discount=0.5, 258 | num_dirch_samples=100) 259 | discount_factor = 0.5 260 | is_testing = False 261 | if arg_dict["testing_file"]: 262 | is_testing = True 263 | 264 | def update_world_root(new_root): 265 | world.static_specials[4] = (new_root[0], new_root[1], "green", 10, "NA") 266 | 267 | def new_goal_state(): 268 | new_y = random.randint(0, 6) 269 | while new_y == 5: 270 | new_y = random.randint(0, 6) 271 | nonlocal goal_state 272 | goal_state = [9, new_y] 273 | print(">> New Goal State:", goal_state, "<<") 274 | update_world_root(goal_state) 275 | 276 | if is_testing: 277 | new_goal_state() 278 | 279 | 280 | ############################ 281 | batch_id = arg_dict['batch_id'] 282 | test_name = arg_dict['name'] 283 | move_limit = int(arg_dict['move_limit']) 284 | root_path = 
arg_dict['root_path'] 285 | simulator = WorldSimulator() 286 | true_specials = world.static_specials.copy() 287 | true_walls = world.static_walls.copy() 288 | total_move_count = 0 289 | game_move_count = 0 290 | episode_count = 0 291 | running_score = 0 292 | log = None 293 | #log = logger.DataLogger(root_path + "\\" + "testing_" + test_name + ".txt", replace=True) 294 | kernel = ExpSineSquared(length_scale=2, periodicity=3.0, 295 | periodicity_bounds=(2, 10), 296 | length_scale_bounds=(1, 10)) 297 | gp = GPPosterior(history_manager=history_manager, kernel=kernel, log=None) 298 | ############################ 299 | # used for testing purposes 300 | ############################ 301 | if is_testing: 302 | gp = pickle.load(open(arg_dict["testing_file"], "rb")) 303 | history_manager.history = gp.history_manager.history 304 | history_manager.action_count_reward_dict = gp.history_manager.action_count_reward_dict 305 | history_manager.state_count_dict = gp.history_manager.state_count_dict 306 | history_manager.total_rewards = gp.history_manager.total_rewards 307 | history_manager.action_set = gp.history_manager.action_set 308 | print(">> Loaded trained model from," + arg_dict["testing_file"] + "<<") 309 | #history_manager = pickle.load(open("hm_ts_es10.out", "rb")) 310 | #history_manager.history = gp.history_manager.history 311 | #history_manager.action_count_reward_dict = gp.history_manager.action_count_reward_dict 312 | #history_manager.state_count_dict = gp.history_manager.state_count_dict 313 | #history_manager.total_rewards = gp.history_manager.total_rewards 314 | #history_manager.action_set = gp.history_manager.action_set 315 | #history_manager = gp.history_manager 316 | #gp = GPPosterior(history_manager=gp_orig.history_manager, kernel=kernel, log=None) 317 | #gp.update_posterior() 318 | #gp.history_manager = history_manager 319 | 320 | def eval_sparse_tree(sim, root_s, actions, horizon, tsampler=None): 321 | ste = SparseTreeEvaluator(sim, root_s, actions, horizon, 322 | history_manager=history_manager, 323 | thompson_sampler=tsampler, 324 | discount_factor=discount_factor, 325 | state_posterior=gp, 326 | goal_state=goal_state, 327 | goal_reward=goal_reward, 328 | loss_penalty=loss_penalty) 329 | ste.evaluate(game_move_count) 330 | print(ste) 331 | optimal_action_index = random.choice(ste.lookahead_tree.node.value[0]) 332 | possible_actions = ste.lookahead_tree.actions 333 | print("Possible actions: ", possible_actions) 334 | optimal_action = possible_actions[optimal_action_index] 335 | print("Optimal action:", str(optimal_action), ":", optimal_action_index) 336 | print("Tree size: ", ste.lookahead_tree.get_tree_size()) 337 | return optimal_action, optimal_action_index, possible_actions, ste 338 | 339 | while True: 340 | print("Evaluating tree at ", root_state) 341 | # belief based 342 | optimal_action, optimal_action_index, possible_actions, ste = \ 343 | eval_sparse_tree(simulator, root_state, action_set, horizon, thompson_sampler) 344 | # real world 345 | orig_state, action, new_reward, new_state, new_specials = simulator.sim(root_state, optimal_action, 346 | specials=true_specials, 347 | walls=true_walls) 348 | # record change in true specials after move 349 | true_specials = new_specials 350 | 351 | # prev_root = root_state.copy() 352 | root_state = list(new_state) 353 | print("Moving to ", root_state, "...") 354 | print("Move count:", total_move_count) 355 | print("Game move count:", game_move_count) 356 | 357 | history_manager.add((orig_state, action, new_reward, new_state, 
game_move_count)) 358 | if episode_length: 359 | ts_history_manager.add((orig_state, action, new_reward, new_state, game_move_count)) 360 | 361 | running_score += new_reward 362 | print("Score:", running_score) 363 | 364 | # check for walls 365 | if list(new_state) == list(orig_state): 366 | if tuple(new_state) not in gp.static_states: 367 | if action == "up": 368 | logger.log( 369 | 'addw' + str(new_state[0]) + "," + str(new_state[1] - 1), 370 | logger=log) 371 | gp.update_static_states([new_state[0], new_state[1] - 1]) 372 | elif action == "down": 373 | logger.log( 374 | 'addw' + str(new_state[0]) + "," + str(new_state[1] + 1), 375 | logger=log) 376 | gp.update_static_states([new_state[0], new_state[1] + 1]) 377 | elif action == "left": 378 | logger.log( 379 | 'addw' + str(new_state[0] - 1) + "," + str(new_state[1]), 380 | logger=log) 381 | gp.update_static_states([new_state[0] - 1, new_state[1]]) 382 | elif action == "right": 383 | logger.log( 384 | 'addw' + str(new_state[0] + 1) + "," + str(new_state[1]), 385 | logger=log) 386 | gp.update_static_states([new_state[0] + 1, new_state[1]]) 387 | 388 | # update belief game 389 | def predict(time, type): 390 | x_preds, y_preds = gp.predict(time) 391 | for x_pred in x_preds[0]: 392 | for y_pred in y_preds[0]: 393 | if [int(round(x_pred[0])), int(round(y_pred[0]))] in ste.ignored_specials: 394 | print("Ignoring special for belief world", int(round(x_pred[0])), int(round(y_pred[0]))) 395 | else: 396 | msg = "add" + type + str(int(round(x_pred[0]))) + "," + str(int(round(y_pred[0]))) 397 | logger.log(msg, logger=log) 398 | 399 | logger.log('clr', logger=log) 400 | predict(game_move_count - 1, "c") 401 | predict(game_move_count + 1, "c") 402 | predict(game_move_count, "r") 403 | logger.log(action, logger=log) 404 | 405 | total_move_count += 1 406 | game_move_count += 1 407 | 408 | if total_move_count == move_limit: 409 | if not is_testing: 410 | with open(root_path + "/" + test_name + batch_id + '.out', 'wb') as output: 411 | pickle.dump(gp, output, pickle.HIGHEST_PROTOCOL) 412 | return 413 | 414 | # check terminal conditions 415 | #if abs(new_reward) > 1 or (episode_length and (game_move_count > episode_move_limit)): 416 | if abs(new_reward) > 1 or (game_move_count > episode_move_limit): 417 | episode_count += 1 418 | if new_reward > 0: 419 | print("Agent Won in ", game_move_count, " moves!") 420 | sys.stdout.flush() 421 | print("Restarting game", new_reward, game_move_count) 422 | if is_testing: 423 | new_goal_state() 424 | root_state = original_root.copy() 425 | true_specials = world.static_specials.copy() 426 | #if not (episode_length and (game_move_count > episode_move_limit)): 427 | if not (game_move_count > episode_move_limit): 428 | gp.update_posterior() 429 | game_move_count = 0 430 | logger.log("reset", logger=log) 431 | # check if end of training episode 432 | if episode_length and episode_count >= 1 and episode_count % episode_length == 0: 433 | print('>> End of Episode <<') 434 | ts_history_manager.reset_history() 435 | episode_count = 0 436 | 437 | sys.stdout.flush() 438 | 439 | #if thompson_sampler: 440 | # with open('ts_test_ts_es1_bs.out', 'wb') as output: 441 | # pickle.dump(thompson_sampler, output, pickle.HIGHEST_PROTOCOL) 442 | 443 | #if episode_length: 444 | # with open('ts_hm_test_ts_es1_bs.out', 'wb') as output: 445 | # pickle.dump(ts_history_manager, output, pickle.HIGHEST_PROTOCOL) 446 | 447 | #with open('hm_test_ts_es1_bs.out', 'wb') as output: 448 | # pickle.dump(history_manager, output, pickle.HIGHEST_PROTOCOL) 449 | 
450 | 451 | ############# 452 | # gp tester # 453 | ############# 454 | #fake_history_logger = logger.DataLogger("./fake_history.txt", replace=True) 455 | #gp_posterior_tester(fake_history_logger) 456 | 457 | 458 | ################## 459 | # move re-player # 460 | ################## 461 | def launch_belief_world(): 462 | world.World(init_x=0, init_y=6, input_reader=key_handler, specials=[(9, 0, "green", 10, "NA")], 463 | do_belief=True, walls=[]) 464 | 465 | def launch_real_world(): 466 | world.World(init_x=0, init_y=6, input_reader=key_handler) 467 | 468 | #log = logger.ConsoleLogger() 469 | #key_handler = inputReader.KeyInputHandler(log) 470 | #file_tailer = inputReader.FileTailer("./complete_models/input_test_ts_es1.txt", key_handler, log) 471 | #t = threading.Thread(target=launch_belief_world) 472 | #t.daemon = True 473 | #t.start() 474 | 475 | 476 | # plot_gp("bayes_opt4.out.") 477 | 478 | arg_dict = dict() 479 | args = sys.argv 480 | for arg in args: 481 | if "=" in arg: 482 | arg_dict[arg.split("=")[0]] = arg.split("=")[1] 483 | 484 | if arg_dict["testing"]: 485 | for filename in os.listdir(arg_dict["testing"]+"\\"): 486 | arg_dict["testing_file"] = arg_dict["testing"] + "\\" +filename 487 | sparse_tree_model_tester(arg_dict) 488 | else: 489 | sparse_tree_model_tester(arg_dict) 490 | -------------------------------------------------------------------------------- /global_constants.py: -------------------------------------------------------------------------------- 1 | 2 | print_debug = False 3 | -------------------------------------------------------------------------------- /gpPosterior.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import logger 3 | from sklearn.gaussian_process import GaussianProcessRegressor 4 | from sklearn.gaussian_process.kernels import WhiteKernel, ExpSineSquared 5 | import global_constants 6 | 7 | 8 | class GPPosterior: 9 | 10 | def __init__(self, history_manager, kernel=None, penalty_threshold=-1, log=None): 11 | self.history_manager = history_manager 12 | self.fitted_models_x = [] 13 | self.fitted_models_y = [] 14 | self.penalty_threshold = penalty_threshold 15 | self.log = log 16 | self.x_obs = [] 17 | self.y_obs = [] 18 | self.static_states = [] 19 | if not kernel: 20 | self.kernel = ExpSineSquared(length_scale=1, periodicity=1.0, 21 | periodicity_bounds=(2, 10), 22 | length_scale_bounds=(1, 3)) 23 | else: 24 | self.kernel = kernel 25 | 26 | def update_static_states(self, state): 27 | self.static_states.append(tuple(state)) 28 | 29 | def get_static_states(self): 30 | return self.static_states.copy() 31 | 32 | def update_posterior(self, n_restarts=10, a=0.01): 33 | # each history obs is 34 | history = list(filter(lambda obs: obs[2] < self.penalty_threshold, self.history_manager.get_history())) 35 | if not len(history): return 36 | classified_x = self.__classify_history(history, 0) 37 | classified_y = self.__classify_history(history, 1) 38 | 39 | self.x_obs = classified_x 40 | self.y_obs = classified_y 41 | 42 | def fit_models(classified_dat, parent): 43 | for dat in classified_dat: 44 | t_obs = np.atleast_2d(list(map(lambda obs: obs[0], dat))).T 45 | i_obs = np.array(list(map(lambda obs: obs[1], dat))) 46 | gp = GaussianProcessRegressor(kernel=self.kernel, 47 | n_restarts_optimizer=n_restarts, 48 | alpha=a).fit(t_obs, i_obs) 49 | parent.append(gp) 50 | 51 | self.fitted_models_x = [] 52 | self.fitted_models_y = [] 53 | fit_models(classified_x, self.fitted_models_x) 54 | 
fit_models(classified_y, self.fitted_models_y) 55 | 56 | def __classify_history(self, history, new_state_idx): 57 | hist_ts = set(map(lambda obs: obs[4], history)) 58 | hist_vals = set(map(lambda obs: obs[3][new_state_idx], history)) 59 | 60 | hist_dict = {key: list() for key in hist_ts} 61 | for obs in history: hist_dict[obs[4]].append(obs[3][new_state_idx]) 62 | 63 | hist_dict_vals = {key: list() for key in hist_vals} 64 | for obs in history: hist_dict_vals[obs[3][new_state_idx]].append(obs[4]) 65 | 66 | collisions = max(map(len, map(lambda x: set(x), hist_dict.values()))) 67 | classes = [list() for _ in range(collisions)] 68 | 69 | def classify_obs(obs, t): 70 | # trivial case, only 1 class 71 | if len(classes) == 1: 72 | classes[0].append((t, obs)) 73 | return 74 | 75 | #last_class_vals = list(map(lambda lst: lst[len(lst) - 1] if lst else None, classes)) 76 | max_class_vals = [max(map(lambda val: val[1], c)) for c in classes] 77 | sorted_max_class_vals = list(zip(sorted(range(len(max_class_vals)), key=lambda k: max_class_vals[k]), 78 | sorted(max_class_vals))) 79 | # first try to classify to one-off sequence 80 | #for i, c in enumerate(classes): 81 | # for val in c: 82 | # if val and abs(val[0] - t) == 1 and abs(val[1] - obs) == 1: 83 | # classes[i].append((t, obs)) 84 | # return 85 | 86 | for sort_val in sorted_max_class_vals: 87 | if obs <= sort_val[1]: 88 | classes[sort_val[0]].append((t, obs)) 89 | return 90 | 91 | # whoops, rough classification. Need to estimate 92 | error_message = "Cannot cleanly classify: " + str((t,obs)) + " -> " + str(classes) 93 | if self.log: 94 | logger.log(error_message, logger=self.log) 95 | else: 96 | if global_constants.print_debug: print(error_message) 97 | 98 | # toss into next nearest neighbor, frobenius manhattan distance 99 | min_coord_diff = min(sorted_max_class_vals, key=lambda x: abs(x[1] - obs)) 100 | classes[sorted_max_class_vals.index(min_coord_diff)].append((t, obs)) 101 | 102 | info_message = "Classified " + str(obs) + " to classes " + str(classes) 103 | if self.log: 104 | logger.log(info_message, logger=self.log) 105 | else: 106 | if global_constants.print_debug: print(info_message) 107 | 108 | # initialize classes based on time with greatest collisions 109 | coll_arr = [] 110 | for t in sorted(hist_dict.keys(), key=int): 111 | t_obs = hist_dict[t] 112 | if len(set(t_obs)) == collisions: 113 | coll_arr.append((t, max(t_obs) - min(t_obs))) 114 | 115 | if coll_arr: 116 | coll_ranges = list(map(lambda val: val[1], coll_arr)) 117 | max_range_idx = coll_ranges.index(min(coll_ranges)) 118 | t_obs = hist_dict[coll_arr[max_range_idx][0]] 119 | for i, o in enumerate(sorted(set(t_obs))): 120 | classes[i].append((coll_arr[max_range_idx][0], o)) 121 | del hist_dict[coll_arr[max_range_idx][0]] 122 | 123 | for val in sorted(hist_dict_vals.keys(), key=int): 124 | unique_obs = set(hist_dict_vals[val]) 125 | if len(unique_obs) > 1: 126 | # the set below doesnt duplicate data 127 | for i, t in enumerate(sorted(set(hist_dict_vals[val]))): 128 | classify_obs(val, t) 129 | #classes[i].append((t, val)) 130 | else: 131 | classify_obs(val, list(unique_obs)[0]) 132 | 133 | # validate classes 134 | for i, c in enumerate(classes): 135 | if not c: 136 | classes.pop(i) 137 | if self.log: 138 | logger.log("Expected more classes than classified!", logger=self.log) 139 | logger.log(str(history), logger=self.log) 140 | else: 141 | if global_constants.print_debug: print("Expected more classes than classified!", str(history)) 142 | 143 | return classes 144 | 145 | def 
predict(self, time): 146 | x_preds = [] 147 | x_stds = [] 148 | for gp in self.fitted_models_x: 149 | preds, stds = gp.predict(time, return_std=True) 150 | x_preds.append(preds) 151 | x_stds.append(stds) 152 | 153 | y_preds = [] 154 | y_stds = [] 155 | for gp in self.fitted_models_y: 156 | preds, stds = gp.predict(time, return_std=True) 157 | y_preds.append(preds) 158 | y_stds.append(stds) 159 | 160 | return (x_preds, x_stds), (y_preds, y_stds) 161 | -------------------------------------------------------------------------------- /historyManager.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class HistoryManager(object): 5 | def __init__(self, actions): 6 | self.history = list() 7 | self.action_count_reward_dict = dict.fromkeys(actions, (0, 0)) 8 | self.state_count_dict = dict() 9 | self.total_rewards = 0 10 | self.action_set = actions 11 | 12 | def reset_history(self): 13 | self.history = list() 14 | self.action_count_reward_dict = dict.fromkeys(self.action_set, (0, 0)) 15 | self.state_count_dict = dict() 16 | self.total_rewards = 0 17 | 18 | def get_history(self): 19 | return self.history 20 | 21 | def get_action_set(self): 22 | return self.action_set 23 | 24 | def get_action_count_reward_dict(self): 25 | return self.action_count_reward_dict 26 | 27 | def get_total_rewards(self): 28 | return self.total_rewards 29 | 30 | def add(self, observation): 31 | # each observation must be (state, action, reward, new_state, time) 32 | if not isinstance(observation, tuple): 33 | observation = tuple(observation) 34 | if not len(observation) == 5: 35 | raise Exception("observation must be a 5-tuple of (state, action, reward, new_state, time)") 36 | self.history.append(observation) 37 | if observation[1] in self.action_count_reward_dict: 38 | count, reward = self.action_count_reward_dict[observation[1]] 39 | self.action_count_reward_dict[observation[1]] = (count+1, reward+observation[2]) 40 | self.total_rewards += observation[2] 41 | else: 42 | raise Exception(str(observation[1]) + 43 | " does not exist in action set dictionary") 44 | if not self.state_count_dict.keys(): 45 | # add the initial state to the count dictionary 46 | self.state_count_dict[tuple(observation[0])] = 1 47 | if tuple(observation[3]) in self.state_count_dict: 48 | self.state_count_dict[tuple(observation[3])] += 1 49 | else: 50 | self.state_count_dict[tuple(observation[3])] = 1 51 | 52 | 53 | class BootstrapHistoryManager(HistoryManager): 54 | def __init__(self, actions, batch_prop, penalty_threshold=-1): 55 | super(BootstrapHistoryManager, self).__init__(actions) 56 | self.batch_prop = batch_prop 57 | self.penalty_threshold = penalty_threshold 58 | 59 | def get_history(self): 60 | history = list(filter(lambda obs: obs[2] < self.penalty_threshold, self.history)) 61 | if not history: 62 | return history 63 | bootstrap_sample_size = max(int(round(self.batch_prop * len(history))), 0) 64 | if not bootstrap_sample_size: 65 | return history 66 | bootstrap_idxs = np.random.choice(len(history), bootstrap_sample_size, replace=True) 67 | bootstrap_sample = [history[i] for i in bootstrap_idxs] 68 | print("Bootstrap history will add", len(bootstrap_sample), " samples to ", len(history)) 69 | #latest_t = history[len(history)-1][4] 70 | multiplier = 2 71 | for sample in bootstrap_sample: 72 | local_multiplier = multiplier 73 | while ((sample[0], sample[1], sample[2], sample[3], sample[4]*local_multiplier) in history): 74 | local_multiplier += 1 75 | history.append((sample[0], sample[1], sample[2], sample[3], sample[4]*local_multiplier)) 76 | return history 77 | 78 | def get_action_count_reward_dict(self): 79
| raise Exception("Thompson sampler shouldn't be calling bootstrapped history!") 80 | 81 | if len(self.history) == 0: 82 | return dict.fromkeys(self.action_set, (0, 0)) 83 | 84 | bootstrap_sample_size = max(int(round(self.batch_prop * len(self.history))), 1) 85 | bootstrap_idxs = np.random.choice(len(self.history), bootstrap_sample_size, replace=True) 86 | bootstrap_sample = [self.history[i] for i in bootstrap_idxs] 87 | # history_with_bootstrap = self.history + bootstrap_sample 88 | action_reward_dict = self.action_count_reward_dict.copy() 89 | for sample in bootstrap_sample: 90 | if sample[1] in action_reward_dict: 91 | count, reward = action_reward_dict[sample[1]] 92 | action_reward_dict[sample[1]] = ( 93 | count + 1, reward + sample[2]) 94 | else: 95 | action_reward_dict[sample[1]] = (1, sample[2]) 96 | return action_reward_dict 97 | -------------------------------------------------------------------------------- /inputReader.py: -------------------------------------------------------------------------------- 1 | import time 2 | import logger 3 | import threading 4 | import enum 5 | import queue 6 | import global_constants 7 | 8 | 9 | class InputHandler(object): 10 | """Abstract class for input handler, given to game or game simulator""" 11 | def get_next_key(self): 12 | raise NotImplementedError("Unimplemented method!") 13 | 14 | 15 | class KeyInputHandler(InputHandler): 16 | # current only supports inputs for maze game 17 | # add to enum if other keys need to be handled 18 | keys = enum.Enum('keys', 'UP DOWN LEFT RIGHT RESET') 19 | 20 | def __init__(self, log): 21 | self.log = log 22 | self.key_q = queue.Queue() 23 | logger.log("Key stroke handler created...", logger.Level.INFO, log) 24 | 25 | # handle needs to be thread safe (a handler can be shared between threads)! 
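    # note: queue.Queue does its own locking, so handle() can be called from the file tailer /
    # key listener threads while the game thread blocks in get_next_key(), without extra synchronization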
26 | def handle(self, line): 27 | # clean input 28 | inputs = line.lower().split() 29 | # put into queue, with blocking enabled 30 | for input in inputs: 31 | next_key = None 32 | if input == 'up': 33 | next_key = self.keys.UP 34 | elif input == 'down': 35 | next_key = self.keys.DOWN 36 | elif input == 'left': 37 | next_key = self.keys.LEFT 38 | elif input == 'right': 39 | next_key = self.keys.RIGHT 40 | elif input == 'reset': 41 | next_key = self.keys.RESET 42 | elif input[:4] == 'addr': 43 | next_key = input 44 | elif input[:4] == 'addc': 45 | next_key = input 46 | elif input[:4] == 'addw': 47 | next_key = input 48 | elif input[:3] == 'clr': 49 | next_key = input 50 | else: 51 | if global_constants.print_debug: 52 | logger.log("Unrecognized input:" + line.strip(), logger.Level.ERROR, self.log) 53 | if next_key is not None: 54 | self.key_q.put(next_key, block=True) 55 | # logger.log("Pushing:" + str(next_key), logger.Level.DEBUG, self.log) 56 | 57 | def get_next_key(self): 58 | # this is a blocking call, will wait until an item is on queue 59 | return self.key_q.get(block=True) 60 | 61 | 62 | class FileTailer(threading.Thread): 63 | def __init__(self, filepath, handler, log, tail_polling_secs = 1.0): 64 | super(FileTailer, self).__init__() 65 | self.file = open(filepath, 'r') 66 | self.log = log 67 | self.alive = True 68 | self.handler = handler 69 | self.tail_polling_secs = tail_polling_secs 70 | logger.log("File tailer created for:" + filepath + "...", 71 | logger.Level.INFO, log) 72 | # start the thread in ctor 73 | self.start() 74 | 75 | def tail(self): 76 | while self.alive: 77 | read_ptr = self.file.tell() 78 | line = self.file.readline() 79 | if not line: 80 | time.sleep(self.tail_polling_secs) 81 | self.file.seek(read_ptr) 82 | else: 83 | yield line 84 | 85 | # start() invokes run 86 | def run(self): 87 | logger.log("File tailer for " + str(self.file) + " started...", 88 | logger.Level.INFO, self.log) 89 | for line in self.tail(): 90 | self.handler.handle(line) 91 | 92 | 93 | class KeyListener(threading.Thread): 94 | # key inputs captured by std-in 95 | # constants for user input 96 | m_keys = {'w': 'up', 's': 'down', 'a': 'left', 'd': 'right'} 97 | 98 | def __init__(self, handler, log): 99 | super(KeyListener, self).__init__() 100 | self.log = log 101 | self.alive = True 102 | self.handler = handler 103 | logger.log("KeyListener created for stdin...", logger.Level.INFO, log) 104 | # start the thread in ctor 105 | self.start() 106 | 107 | # start() invokes run 108 | def run(self): 109 | while self.alive: 110 | player_move = input() 111 | if player_move in self.m_keys: 112 | self.handler.handle(self.m_keys[player_move]) 113 | else: 114 | logger.log("Unrecognized input:" + player_move, logger.Level.INFO, self.log) 115 | -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import threading 3 | import os 4 | from enum import Enum 5 | 6 | 7 | # global threading lock for all logging operations 8 | ACTIVE_LOGGERS = [] 9 | ACTIVE_LOGGERS_LK = threading.Lock() 10 | 11 | 12 | class Level(Enum): 13 | 14 | DEBUG = logging.DEBUG 15 | INFO = logging.INFO 16 | ERROR = logging.ERROR 17 | WARN = logging.WARN 18 | FATAL = logging.FATAL 19 | 20 | def get_value(self): 21 | return self.value 22 | 23 | def __str__(self): 24 | return self.name 25 | 26 | 27 | def log(message, level=Level.DEBUG, logger=None): 28 | if logger is None: 29 | __log_all(message, level) 
30 | else: 31 | with ACTIVE_LOGGERS_LK: 32 | if isinstance(logger, list): 33 | # in python3 map is lazily evaluated, need list wrapper 34 | list(map(lambda lg: __send_to_logger(lg, message, level), logger)) 35 | else: 36 | __send_to_logger(logger, message, level) 37 | 38 | 39 | def __log_all(message, level): 40 | times_logged = 0 41 | with ACTIVE_LOGGERS_LK: 42 | for logger in ACTIVE_LOGGERS: 43 | times_logged += __send_to_logger(logger, message, level) 44 | 45 | if times_logged is not len(ACTIVE_LOGGERS): 46 | raise Exception("Failed to write to all loggers!") 47 | 48 | 49 | def __send_to_logger(logger, message, level): 50 | log = logger.get_logger() 51 | if level is Level.DEBUG: 52 | log.debug(message) 53 | return 1 54 | elif level is Level.INFO: 55 | log.info(message) 56 | return 1 57 | elif level is Level.ERROR: 58 | log.error(message) 59 | return 1 60 | elif level is Level.WARN: 61 | log.warn(message) 62 | return 1 63 | elif level is Level.FATAL: 64 | log.fatal(message) 65 | return 1 66 | return 0 67 | 68 | 69 | def append_active_logger(logger): 70 | with ACTIVE_LOGGERS_LK: 71 | ACTIVE_LOGGERS.append(logger) 72 | 73 | 74 | class FileLogger: 75 | 76 | def __init__(self, filename=".log", level=Level.DEBUG, name='', format=None, 77 | header=True): 78 | logger = logging.getLogger("FileLogger" + name) 79 | handler = logging.FileHandler(name + filename) 80 | if format is None: 81 | formatter = logging.Formatter("%(asctime)s [%(threadName)s] %(message)s") 82 | else: 83 | formatter = logging.Formatter(format) 84 | handler.setFormatter(formatter) 85 | logger.addHandler(handler) 86 | logger.setLevel(level.get_value()) 87 | self.logger = logger 88 | self.name = name 89 | self.filename = filename 90 | append_active_logger(self) 91 | if header: 92 | logger.info('==============================') 93 | logger.info('File-Logger Started...') 94 | 95 | def get_logger(self): 96 | return self.logger 97 | 98 | def __str__(self): 99 | return "Name:" + self.name + " Filename:" + self.filename + " logger:"\ 100 | + str(self.logger) 101 | 102 | 103 | class ConsoleLogger: 104 | 105 | def __init__(self, level=Level.DEBUG, name=''): 106 | logger = logging.getLogger("ConsoleLogger" + name) 107 | handler = logging.StreamHandler() 108 | formatter = logging.Formatter("%(asctime)s [%(threadName)s] [logger: " + name + "] %(message)s") 109 | handler.setFormatter(formatter) 110 | logger.addHandler(handler) 111 | logger.setLevel(level.get_value()) 112 | self.logger = logger 113 | self.name = name 114 | append_active_logger(self) 115 | logger.info('================================') 116 | logger.info('Console-Logger Started...') 117 | 118 | def get_logger(self): 119 | return self.logger 120 | 121 | def __str__(self): 122 | return "Name:" + self.name + " Filename:" + "Console" + " logger:"\ 123 | + str(self.logger) 124 | 125 | 126 | class DataLogger(FileLogger): 127 | def __init__(self, filename="data.log", level=Level.DEBUG, name='', replace=False): 128 | if replace: 129 | # replace data file if it already exists 130 | if os.path.isfile(filename): 131 | os.remove(filename) 132 | super(DataLogger, self).__init__(filename, level, name, "%(message)s", header=False) 133 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import world 2 | import random 3 | import logger 4 | from mdpSimulator import WorldSimulator 5 | from bayesSparse import SparseTreeEvaluator 6 | from historyManager import 
HistoryManager, BootstrapHistoryManager 7 | from thompsonSampling import ThompsonSampler 8 | from gpPosterior import GPPosterior 9 | from sklearn.gaussian_process.kernels import ExpSineSquared 10 | import pickle 11 | import sys 12 | import os 13 | 14 | 15 | def sparse_tree_model_tester(arg_dict): 16 | ###### Model Variables ##### 17 | root_state = [0, 3] 18 | goal_state = [9, 6] 19 | goal_reward = 10 20 | loss_penalty = -10 21 | original_root = root_state.copy() 22 | horizon = 10 23 | if 'ep_len' in arg_dict and int(arg_dict['ep_len']): 24 | print("Setting episode length:", arg_dict['ep_len'], "...") 25 | episode_length = int(arg_dict['ep_len']) 26 | else: 27 | episode_length = 0 # number of games before posterior distributions are reset 28 | action_set = ["up", "down", "left", "right"] 29 | episode_move_limit = 100 30 | history_manager = HistoryManager(action_set) 31 | if 'bootstrap' in arg_dict: 32 | print("Setting history manager to Bootstrapped...") 33 | history_manager = BootstrapHistoryManager(action_set, 0.25) 34 | if episode_length: 35 | ts_history_manager = HistoryManager(action_set) 36 | else: 37 | ts_history_manager = history_manager 38 | thompson_sampler = None 39 | if 'prune' in arg_dict: 40 | move_wght = float(arg_dict['ts_hyper_param']) 41 | if not move_wght: 42 | raise Exception("Cannot start thompson sampler without move weight!") 43 | print("Creating thompson sampler, with move weight", move_wght, "...") 44 | 45 | thompson_sampler = ThompsonSampler(ts_history_manager, use_constant_boundary=0.5, 46 | move_weight=move_wght, move_discount=0.5, 47 | num_dirch_samples=100) 48 | discount_factor = 0.5 49 | is_testing = False 50 | if "testing_file" in arg_dict: 51 | is_testing = True 52 | 53 | def update_world_root(new_root): 54 | world.static_specials[4] = (new_root[0], new_root[1], "green", 10, "NA") 55 | 56 | def new_goal_state(): 57 | new_y = random.randint(0, 6) 58 | while new_y == 5: 59 | new_y = random.randint(0, 6) 60 | nonlocal goal_state 61 | goal_state = [9, new_y] 62 | print(">> New Goal State:", goal_state, "<<") 63 | update_world_root(goal_state) 64 | 65 | if is_testing: 66 | new_goal_state() 67 | 68 | ############################ 69 | batch_id = arg_dict['batch_id'] 70 | test_name = arg_dict['name'] 71 | move_limit = int(arg_dict['move_limit']) 72 | root_path = arg_dict['root_path'] 73 | simulator = WorldSimulator() 74 | true_specials = world.static_specials.copy() 75 | true_walls = world.static_walls.copy() 76 | total_move_count = 0 77 | game_move_count = 0 78 | episode_count = 0 79 | running_score = 0 80 | log = None 81 | kernel = ExpSineSquared(length_scale=2, periodicity=3.0, 82 | periodicity_bounds=(2, 10), 83 | length_scale_bounds=(1, 10)) 84 | gp = GPPosterior(history_manager=history_manager, kernel=kernel, log=None) 85 | ############################ 86 | # used for testing purposes 87 | ############################ 88 | if is_testing: 89 | gp = pickle.load(open(arg_dict["testing_file"], "rb")) 90 | history_manager.history = gp.history_manager.history 91 | history_manager.action_count_reward_dict = gp.history_manager.action_count_reward_dict 92 | history_manager.state_count_dict = gp.history_manager.state_count_dict 93 | history_manager.total_rewards = gp.history_manager.total_rewards 94 | history_manager.action_set = gp.history_manager.action_set 95 | print(">> Loaded trained model from," + arg_dict["testing_file"] + "<<") 96 | 97 | 98 | def eval_sparse_tree(sim, root_s, actions, horizon, tsampler=None): 99 | ste = SparseTreeEvaluator(sim, root_s, actions, 
horizon, 100 | history_manager=history_manager, 101 | thompson_sampler=tsampler, 102 | discount_factor=discount_factor, 103 | state_posterior=gp, 104 | goal_state=goal_state, 105 | goal_reward=goal_reward, 106 | loss_penalty=loss_penalty) 107 | ste.evaluate(game_move_count) 108 | print(ste) 109 | optimal_action_index = random.choice(ste.lookahead_tree.node.value[0]) 110 | possible_actions = ste.lookahead_tree.actions 111 | print("Possible actions: ", possible_actions) 112 | optimal_action = possible_actions[optimal_action_index] 113 | print("Optimal action:", str(optimal_action), ":", optimal_action_index) 114 | print("Tree size: ", ste.lookahead_tree.get_tree_size()) 115 | return optimal_action, optimal_action_index, possible_actions, ste 116 | 117 | while True: 118 | print("Evaluating tree at ", root_state) 119 | # belief based 120 | optimal_action, optimal_action_index, possible_actions, ste = \ 121 | eval_sparse_tree(simulator, root_state, action_set, horizon, thompson_sampler) 122 | # real world 123 | orig_state, action, new_reward, new_state, new_specials = simulator.sim(root_state, optimal_action, 124 | specials=true_specials, 125 | walls=true_walls) 126 | # record change in true specials after move 127 | true_specials = new_specials 128 | 129 | # prev_root = root_state.copy() 130 | root_state = list(new_state) 131 | print("Moving to ", root_state, "...") 132 | print("Move count:", total_move_count) 133 | print("Game move count:", game_move_count) 134 | 135 | history_manager.add((orig_state, action, new_reward, new_state, game_move_count)) 136 | if episode_length: 137 | ts_history_manager.add((orig_state, action, new_reward, new_state, game_move_count)) 138 | 139 | running_score += new_reward 140 | print("Score:", running_score) 141 | 142 | # check for walls 143 | if list(new_state) == list(orig_state): 144 | if tuple(new_state) not in gp.static_states: 145 | if action == "up": 146 | logger.log( 147 | 'addw' + str(new_state[0]) + "," + str(new_state[1] - 1), 148 | logger=log) 149 | gp.update_static_states([new_state[0], new_state[1] - 1]) 150 | elif action == "down": 151 | logger.log( 152 | 'addw' + str(new_state[0]) + "," + str(new_state[1] + 1), 153 | logger=log) 154 | gp.update_static_states([new_state[0], new_state[1] + 1]) 155 | elif action == "left": 156 | logger.log( 157 | 'addw' + str(new_state[0] - 1) + "," + str(new_state[1]), 158 | logger=log) 159 | gp.update_static_states([new_state[0] - 1, new_state[1]]) 160 | elif action == "right": 161 | logger.log( 162 | 'addw' + str(new_state[0] + 1) + "," + str(new_state[1]), 163 | logger=log) 164 | gp.update_static_states([new_state[0] + 1, new_state[1]]) 165 | 166 | # update belief game 167 | def predict(time, type): 168 | x_preds, y_preds = gp.predict(time) 169 | for x_pred in x_preds[0]: 170 | for y_pred in y_preds[0]: 171 | if [int(round(x_pred[0])), int(round(y_pred[0]))] in ste.ignored_specials: 172 | print("Ignoring special for belief world", int(round(x_pred[0])), int(round(y_pred[0]))) 173 | else: 174 | msg = "add" + type + str(int(round(x_pred[0]))) + "," + str(int(round(y_pred[0]))) 175 | logger.log(msg, logger=log) 176 | 177 | logger.log('clr', logger=log) 178 | predict(game_move_count - 1, "c") 179 | predict(game_move_count + 1, "c") 180 | predict(game_move_count, "r") 181 | logger.log(action, logger=log) 182 | 183 | total_move_count += 1 184 | game_move_count += 1 185 | 186 | if total_move_count == move_limit: 187 | if not is_testing: 188 | with open(root_path + "/" + test_name + batch_id + '.out', 'wb') as output: 
189 | pickle.dump(gp, output, pickle.HIGHEST_PROTOCOL) 190 | return 191 | 192 | # check terminal conditions 193 | if abs(new_reward) > 1 or (game_move_count > episode_move_limit): 194 | episode_count += 1 195 | if new_reward > 0: 196 | print("Agent Won in ", game_move_count, " moves!") 197 | sys.stdout.flush() 198 | print("Restarting game", new_reward, game_move_count) 199 | if is_testing: 200 | new_goal_state() 201 | root_state = original_root.copy() 202 | true_specials = world.static_specials.copy() 203 | if not (game_move_count > episode_move_limit): 204 | gp.update_posterior() 205 | game_move_count = 0 206 | logger.log("reset", logger=log) 207 | # check if end of training episode 208 | if episode_length and episode_count >= 1 and episode_count % episode_length == 0: 209 | print('>> End of Episode <<') 210 | ts_history_manager.reset_history() 211 | episode_count = 0 212 | 213 | sys.stdout.flush() 214 | 215 | ############################### 216 | # Required script parameters # 217 | ############################### 218 | # name (string) 219 | # batch_id (int) 220 | # move_limit (int) 221 | # root_path (directory path) 222 | 223 | ############################### 224 | # Optional script parameters # 225 | ############################### 226 | # prune (T/F) 227 | # bootstrap (T/F) 228 | # ep_len (int) 229 | # testing (directory path) 230 | 231 | arg_dict = dict() 232 | args = sys.argv 233 | for arg in args: 234 | if "=" in arg: 235 | arg_dict[arg.split("=")[0]] = arg.split("=")[1] 236 | 237 | if 'testing' in arg_dict: 238 | for filename in os.listdir(arg_dict["testing"]+"\\"): 239 | arg_dict["testing_file"] = arg_dict["testing"] + "\\" +filename 240 | sparse_tree_model_tester(arg_dict) 241 | else: 242 | sparse_tree_model_tester(arg_dict) -------------------------------------------------------------------------------- /mdpSimulator.py: -------------------------------------------------------------------------------- 1 | import world 2 | 3 | 4 | class MDPSimulator(object): 5 | """Abstract class for MDP Simulator""" 6 | def sim(self, state, action, specials, walls): 7 | """ Must return tuple """ 8 | raise NotImplementedError("Unimplemented method!") 9 | 10 | def get_valid_actions(self, root, actions, specials, walls): 11 | raise NotImplementedError("Unimplemented method!") 12 | 13 | 14 | class WorldSimulator(MDPSimulator): 15 | WORLD_SIM_CACHE = dict() 16 | WORLD_VALID_ACTIONS_CACHE = dict() 17 | 18 | def __init__(self, do_render=False): 19 | # perhaps init threadpool here 20 | self.do_render = do_render 21 | 22 | def __run(self, sim_world, sim_state, sim_action): 23 | # maze doesnt need current state to simulate 24 | # sim_world has initialized agent position 25 | r = -sim_world.score 26 | if sim_action == sim_world.actions[0]: 27 | sim_world.try_move_idx(0) 28 | elif sim_action == sim_world.actions[1]: 29 | sim_world.try_move_idx(1) 30 | elif sim_action == sim_world.actions[2]: 31 | sim_world.try_move_idx(2) 32 | elif sim_action == sim_world.actions[3]: 33 | sim_world.try_move_idx(3) 34 | else: 35 | return 36 | s2 = sim_world.player 37 | r += sim_world.score 38 | return r, s2 39 | 40 | def sim(self, state, action, specials, walls): 41 | init_x, init_y = self.get_x_y(state) 42 | sim_world = world.World(self.do_render, init_x=init_x, init_y=init_y, 43 | specials=specials, walls=walls) 44 | sim_r, sim_n_s = self.__run(sim_world, state, action) 45 | if self.do_render: sim_world.destroy() 46 | # return values are: 47 | # print("Sim Result: ", state, action, sim_r, sim_n_s) 48 | return state, 
action, sim_r, sim_n_s, sim_world.specials 49 | 50 | def get_x_y(self, state): 51 | return state[0], state[1] 52 | 53 | def get_valid_actions(self, root, actions, specials, walls): 54 | valid_actions = [] 55 | init_x, init_y = self.get_x_y(root) 56 | for action in actions: 57 | sim_world = world.World(self.do_render, init_x=init_x, init_y=init_y, 58 | specials=specials, walls=walls) 59 | sim_r, sim_n_s = self.__run(sim_world, root, action) 60 | if not list(sim_n_s) == list(root): 61 | valid_actions.append(action) 62 | if self.do_render: sim_world.destroy() 63 | return valid_actions 64 | 65 | 66 | -------------------------------------------------------------------------------- /thompsonSampling.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from global_constants import print_debug 3 | import sys 4 | 5 | 6 | class ThompsonSampler(object): 7 | def __init__(self, history_manager, use_constant_boundary=None, move_weight=0.05, 8 | move_discount=0.5, num_dirch_samples=100): 9 | self.history_manager = history_manager 10 | self.use_constant_boundary = use_constant_boundary 11 | self.move_weight = move_weight 12 | self.move_discount = move_discount 13 | self.num_dirch_samples = num_dirch_samples 14 | 15 | def get_action_set(self, action_set): 16 | # exploration vs exploitation 17 | # exploration means not backtracking 18 | # exploitation means that nearly the complete action set should be considered 19 | # history for thompson sampler is not assumed to be restricted to current game 20 | 21 | # under beta posterior: exploitation => for each move, alpha > 5, beta <= 1 22 | # under beta posterior: exploration => maintain velocity in up/down, left/right directions 23 | 24 | action_psuedo_counts = self.history_manager.get_action_count_reward_dict() 25 | move_counts = list(map(lambda x: action_psuedo_counts[x][0], action_set)) 26 | history_len = max(sum(move_counts), 1) 27 | #weighted_history = float(history_len * self.move_weight) 28 | 29 | # first we use a beta sample to determine hyper parameter 30 | # we want to pick a number between 2 and 4, representing number to actions to reduce to 31 | # when history length > 1/move_weight, we select 3 or 4 moves 32 | n_sample_hyper = np.random.beta(a=history_len, b=self.move_weight, size=1)[0] 33 | #n_sample_hyper = np.mean(np.random.beta(a=weighted_history, b=1, size=self.num_dirch_samples)) 34 | # we want to avoid trivial trees, so choose between 2 and 4 moves 35 | if n_sample_hyper < 1.0/3.0: 36 | sample_hyper = 2 37 | elif n_sample_hyper < 2.0/3.0: 38 | sample_hyper = 3 39 | else: 40 | sample_hyper = 4 41 | 42 | print("TS Hyper:", n_sample_hyper) 43 | sys.stdout.flush() 44 | 45 | sample_hyper = min(sample_hyper, len(action_set)) 46 | 47 | def weighted_sum(type): 48 | running_sum = 0 49 | discount_power = 0 50 | for obs in reversed(self.history_manager.history): 51 | if obs[1] == type: 52 | running_sum += (self.move_discount ** discount_power) 53 | discount_power += 1 54 | 55 | return max(running_sum, 1) 56 | 57 | action_move_counts = list(map(weighted_sum, action_set)) 58 | 59 | dirch_samples = np.random.dirichlet(action_move_counts, self.num_dirch_samples).transpose() 60 | dirch_means = list(map(np.mean, dirch_samples)) 61 | 62 | reduced_action_set = [] 63 | for i in range(sample_hyper): 64 | next_max_idx = dirch_means.index(max(dirch_means)) 65 | reduced_action_set.append(action_set[next_max_idx]) 66 | dirch_means.pop(next_max_idx) 67 | action_set.pop(next_max_idx) 68 | 69 | return 
reduced_action_set 70 | 71 | 72 | -------------------------------------------------------------------------------- /world.py: -------------------------------------------------------------------------------- 1 | import tkinter as tk 2 | import time 3 | import threading 4 | import numpy as np 5 | from inputReader import KeyInputHandler 6 | 7 | static_x_dim, static_y_dim = (10, 7) 8 | static_time_between_moves = 0.1 9 | 10 | # (x, y, type, reward, velocity) 11 | # 4 cat case 12 | static_specials = [(7, 3, "red", -10, "up"), (2, 4, "red", -10, "left"), 13 | (8, 5, "red", -10, "left"), (3, 0, "red", -10, "down"), 14 | (9, 6, "green", 10, "NA")] 15 | static_walls = [(1, 1), (1, 2), (2, 2), (3, 4), (5, 2), (5, 3), (6, 6), (5, 0), (4, 5)] 16 | 17 | 18 | # 2 cat case 19 | #static_specials = [(7, 3, "red", -10, "up"), (8, 5, "red", -10, "left"), (9, 1, "green", 10, "NA")] 20 | #static_walls = [(1, 1), (1, 2), (2, 1), (2, 2), (3, 4), (5, 3), (5, 4), (5, 5), (5, 0)] 21 | 22 | # deterministic case 23 | #static_specials = [(4, 1, "red", -1), (4, 0, "green", 1)] 24 | #static_x_dim, static_y_dim = (5, 5) 25 | #static_walls = [(1, 1), (1, 2), (2, 1), (2, 2)] 26 | 27 | 28 | class World(object): 29 | 30 | def __init__(self, do_render=True, init_x=None, init_y=None, move_pool=None, 31 | input_reader=None, specials=static_specials, walls=static_walls, 32 | do_restart=False, do_belief=False): 33 | self.do_render = do_render 34 | if self.do_render: self.master = tk.Tk() 35 | 36 | self.triangle_size = 0.2 37 | self.cell_score_min = -0.2 38 | self.cell_score_max = 0.2 39 | self.Width = 50 40 | self.x, self.y = static_x_dim, static_y_dim 41 | self.actions = ["up", "down", "left", "right"] 42 | self.do_restart = do_restart 43 | self.do_belief = do_belief 44 | 45 | if self.do_render: 46 | self.board = tk.Canvas(self.master, width=self.x*self.Width, 47 | height=self.y*self.Width) 48 | self.board.pack(fill=tk.BOTH, expand=tk.YES) 49 | self.score = 0 50 | self.restart = False 51 | self.walk_reward = -0.1 52 | 53 | self.walls = walls 54 | self.belief_walls = [] 55 | self.original_specials = specials.copy() 56 | self.specials = specials 57 | self.belief_states = list() 58 | self.cell_scores = {} 59 | 60 | if do_render: self.render_grid() 61 | if self.do_render: 62 | self.master.bind("", self.call_up) 63 | self.master.bind("", self.call_down) 64 | self.master.bind("", self.call_right) 65 | self.master.bind("", self.call_left) 66 | 67 | if not all(map(lambda x: isinstance(x, int), [init_x, init_y])): 68 | self.player = (0, self.y - 1) 69 | self.origin = (0, self.y - 1) 70 | else: 71 | self.origin = (init_x, init_y) 72 | self.player = self.origin 73 | 74 | if self.do_render: 75 | self.board.grid(row=0, column=0) 76 | self.me = self.board.create_rectangle( 77 | self.player[0] * self.Width + self.Width * 2 / 10, 78 | self.player[1] * self.Width + self.Width * 2 / 10, 79 | self.player[0] * self.Width + self.Width * 8 / 10, 80 | self.player[1] * self.Width + self.Width * 8 / 10, fill="orange", width=1, 81 | tag="me") 82 | 83 | if move_pool: 84 | t = threading.Thread(target=self.run_pooled_moves, args=(move_pool,)) 85 | t.daemon = True 86 | t.start() 87 | 88 | if input_reader: 89 | t = threading.Thread(target=self.run_input_moves, args=(input_reader,)) 90 | t.daemon = True 91 | t.start() 92 | 93 | if do_render: self.master.mainloop() 94 | 95 | def run_input_moves(self, input_reader): 96 | time.sleep(1) 97 | while True: 98 | next_key = input_reader.get_next_key() 99 | if next_key == KeyInputHandler.keys.UP: 100 | 
time.sleep(static_time_between_moves) 101 | self.call_up(None) 102 | elif next_key == KeyInputHandler.keys.DOWN: 103 | time.sleep(static_time_between_moves) 104 | self.call_down(None) 105 | elif next_key == KeyInputHandler.keys.LEFT: 106 | time.sleep(static_time_between_moves) 107 | self.call_left(None) 108 | elif next_key == KeyInputHandler.keys.RIGHT: 109 | time.sleep(static_time_between_moves) 110 | self.call_right(None) 111 | elif next_key == KeyInputHandler.keys.RESET: 112 | time.sleep(static_time_between_moves) 113 | self.restart_game() 114 | elif next_key[:4] == 'addr': 115 | self.add_belief_node(next_key[4:], "R") 116 | elif next_key[:4] == 'addc': 117 | self.add_belief_node(next_key[4:], "C") 118 | elif next_key[:4] == 'addw': 119 | self.add_belief_walls(next_key[4:]) 120 | elif next_key[:3] == 'clrw': 121 | self.clear_belief_walls() 122 | elif next_key[:3] == 'clr': 123 | self.clear_belief_nodes() 124 | else: 125 | print("Unknown key input:", str(next_key)) 126 | 127 | def clear_belief_walls(self): 128 | if not self.do_belief: return 129 | n = len(self.belief_walls) 130 | for i in range(n): 131 | self.board.delete(self.belief_walls[i]) 132 | self.belief_walls = list() 133 | 134 | def add_belief_walls(self, coords): 135 | if not self.do_belief: return 136 | x_y_vals = list(map(lambda x: int(x), coords.split(','))) 137 | if not len(x_y_vals) == 2: 138 | raise Exception("Cannot add belief wall: " + str(coords)) 139 | new_rect = self.board.create_rectangle(x_y_vals[0]*self.Width, x_y_vals[1]*self.Width, 140 | (x_y_vals[0]+1)*self.Width, (x_y_vals[1]+1)*self.Width, 141 | fill="black", width=1) 142 | self.board.tag_raise(self.me) 143 | self.walls.append((x_y_vals[0], x_y_vals[1])) 144 | self.belief_walls.append(new_rect) 145 | 146 | def clear_belief_nodes(self): 147 | if not self.do_belief: return 148 | n = len(self.belief_states) 149 | for i in range(n): 150 | self.board.delete(self.belief_states[i]) 151 | self.belief_states = list() 152 | 153 | def add_belief_node(self, coords, type): 154 | if not self.do_belief: return 155 | col = None 156 | if type == "R": 157 | col = "red" 158 | else: 159 | col = "light salmon" 160 | x_y_vals = list(map(lambda x: int(x), coords.split(','))) 161 | if not len(x_y_vals) == 2: 162 | raise Exception("Cannot add belief coord: " + str(coords)) 163 | for wall in self.walls: 164 | if x_y_vals[0] == wall[0] and x_y_vals[1] == wall[1]: 165 | return 166 | for special in self.specials: 167 | if x_y_vals[0] == special[0] and x_y_vals[1] == special[1]: 168 | return 169 | new_rect = self.board.create_rectangle(x_y_vals[0] * self.Width, 170 | x_y_vals[1] * self.Width, 171 | (x_y_vals[0] + 1) * self.Width, 172 | (x_y_vals[1] + 1) * self.Width, fill=col, 173 | width=1) 174 | self.board.tag_raise(self.me) 175 | self.belief_states.append(new_rect) 176 | 177 | def run_pooled_moves(self, move_pool): 178 | time.sleep(1) 179 | while len(move_pool) != 0: 180 | action = move_pool[0] 181 | if action == self.actions[0]: 182 | self.try_move(0, -1) 183 | elif action == self.actions[1]: 184 | self.try_move(0, 1) 185 | elif action == self.actions[2]: 186 | self.try_move(-1, 0) 187 | elif action == self.actions[3]: 188 | self.try_move(1, 0) 189 | else: 190 | print("Unknown move", action) 191 | move_pool.pop(0) 192 | time.sleep(static_time_between_moves) 193 | 194 | def create_triangle(self, i, j, action): 195 | if action == self.actions[0]: 196 | return self.board.create_polygon((i+0.5-self.triangle_size)*self.Width, (j+self.triangle_size)*self.Width, 197 | 
(i+0.5+self.triangle_size)*self.Width, (j+self.triangle_size)*self.Width, 198 | (i+0.5)*self.Width, j*self.Width, 199 | fill="white", width=1) 200 | elif action == self.actions[1]: 201 | return self.board.create_polygon((i+0.5-self.triangle_size)*self.Width, (j+1-self.triangle_size)*self.Width, 202 | (i+0.5+self.triangle_size)*self.Width, (j+1-self.triangle_size)*self.Width, 203 | (i+0.5)*self.Width, (j+1)*self.Width, 204 | fill="white", width=1) 205 | elif action == self.actions[2]: 206 | return self.board.create_polygon((i+self.triangle_size)*self.Width, (j+0.5-self.triangle_size)*self.Width, 207 | (i+self.triangle_size)*self.Width, (j+0.5+self.triangle_size)*self.Width, 208 | i*self.Width, (j+0.5)*self.Width, 209 | fill="white", width=1) 210 | elif action == self.actions[3]: 211 | return self.board.create_polygon((i+1-self.triangle_size)*self.Width, (j+0.5-self.triangle_size)*self.Width, 212 | (i+1-self.triangle_size)*self.Width, (j+0.5+self.triangle_size)*self.Width, 213 | (i+1)*self.Width, (j+0.5)*self.Width, 214 | fill="white", width=1) 215 | 216 | def render_reset_grid(self): 217 | for i in range(self.x): 218 | for j in range(self.y): 219 | self.board.create_rectangle(i*self.Width, j*self.Width, 220 | (i+1)*self.Width, (j+1)*self.Width, fill="black", width=1) 221 | 222 | randn = np.random.choice(range(len(self.actions))) 223 | for action in self.actions[0:randn]: 224 | self.create_triangle(i, j, action) 225 | 226 | def render_grid(self): 227 | for i in range(self.x): 228 | for j in range(self.y): 229 | self.board.create_rectangle(i*self.Width, j*self.Width, 230 | (i+1)*self.Width, (j+1)*self.Width, fill="white", width=1) 231 | #temp = {} 232 | #for action in self.actions: 233 | # temp[action] = self.create_triangle(i, j, action) 234 | # self.cell_scores[(i,j)] = temp 235 | for (i, j, c, w, v) in self.specials: 236 | self.board.create_rectangle(i*self.Width, j*self.Width, 237 | (i+1)*self.Width, (j+1)*self.Width, fill=c, width=1) 238 | for (i, j) in self.walls: 239 | wall_shape = self.board.create_rectangle(i*self.Width, j*self.Width, 240 | (i+1)*self.Width, (j+1)*self.Width, fill="black", width=1) 241 | self.belief_walls.append(wall_shape) 242 | 243 | def set_cell_score(self, state, action, val): 244 | triangle = self.cell_scores[state][action] 245 | green_dec = int(min(255, max(0, (val - self.cell_score_min) * 255.0 / (self.cell_score_max - self.cell_score_min)))) 246 | green = hex(green_dec)[2:] 247 | red = hex(255-green_dec)[2:] 248 | if len(red) == 1: 249 | red += "0" 250 | if len(green) == 1: 251 | green += "0" 252 | color = "#" + red + green + "00" 253 | self.board.itemconfigure(triangle, fill=color) 254 | 255 | def update_specials(self): 256 | # constant specials 257 | # return self.specials 258 | red_specials = [] 259 | green_specials = [] 260 | updated_red_specials = [] 261 | for special in self.specials: 262 | if special[2] == "red": 263 | red_specials.append(special) 264 | if special[2] == "green": 265 | green_specials.append(special) 266 | for (i, j, c, w, v) in red_specials: 267 | if v == "up": 268 | j -= 1 269 | if (j >= 0) and (j < self.y) and not ((i, j) in self.walls): 270 | pass # pass, all good 271 | else: 272 | v = "down" 273 | j += 2 274 | elif v == "down": 275 | j += 1 276 | if (j >= 0) and (j < self.y) and not ((i, j) in self.walls): 277 | pass # pass, all good 278 | else: 279 | v = "up" 280 | j -= 2 281 | elif v == "left": 282 | i -= 1 283 | if (i >= 0) and (i < self.x) and not ((i, j) in self.walls): 284 | pass # pass, all good 285 | else: 286 | v = "right" 
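# The leftward step has already been applied above (i -= 1), so adding 2 on the
# next line both undoes the blocked move and advances one cell in the reversed
# direction; the other velocity branches of update_specials bounce red specials
# off walls and grid edges in the same way.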
287 | i += 2 288 | elif v == "right": 289 | i += 1 290 | if (i >= 0) and (i < self.x) and not ((i, j) in self.walls): 291 | pass # pass, all good 292 | else: 293 | v = "left" 294 | i -= 2 295 | updated_red_specials.append((i, j, c, w, v)) 296 | return updated_red_specials + green_specials 297 | 298 | def try_move_idx(self, move_idx): 299 | if move_idx == 0: 300 | self.try_move(0, -1) 301 | elif move_idx == 1: 302 | self.try_move(0, 1) 303 | elif move_idx == 2: 304 | self.try_move(-1, 0) 305 | elif move_idx == 3: 306 | self.try_move(1, 0) 307 | else: 308 | print("Unknown move index", move_idx) 309 | 310 | def try_move(self, dx, dy): 311 | # no movement out of terminal states 312 | for (i, j, c, w, v) in self.specials: 313 | if self.player[0] == i and self.player[1] == j: 314 | if self.do_restart: 315 | print("Restarting game...") 316 | self.restart_game() 317 | print("Game restarted...") 318 | return 319 | else: 320 | self.score += w 321 | return 322 | 323 | old_specials = self.specials.copy() 324 | self.specials = self.update_specials() 325 | new_x = self.player[0] + dx 326 | new_y = self.player[1] + dy 327 | self.score += self.walk_reward 328 | if (new_x >= 0) and (new_x < self.x) and (new_y >= 0) and (new_y < self.y) and not ((new_x, new_y) in self.walls): 329 | self.player = (new_x, new_y) 330 | 331 | if self.do_render: 332 | self.board.tag_raise(self.me) 333 | for (i, j, c, w, v) in old_specials: 334 | if c == "red": 335 | self.board.create_rectangle(i * self.Width, 336 | j * self.Width, 337 | (i + 1) * self.Width, 338 | (j + 1) * self.Width, 339 | fill='white', 340 | width=1) 341 | self.board.coords(self.me, 342 | self.player[ 343 | 0] * self.Width + self.Width * 2 / 10, 344 | self.player[ 345 | 1] * self.Width + self.Width * 2 / 10, 346 | self.player[ 347 | 0] * self.Width + self.Width * 8 / 10, 348 | self.player[ 349 | 1] * self.Width + self.Width * 8 / 10) 350 | self.board.tag_raise(self.me) 351 | for (i, j, c, w, v) in self.specials: 352 | self.board.create_rectangle(i * self.Width, 353 | j * self.Width, 354 | (i + 1) * self.Width, 355 | (j + 1) * self.Width, 356 | fill=c, 357 | width=1) 358 | self.board.tag_raise(self.me) 359 | 360 | for (i, j, c, w, v) in self.specials: 361 | if self.player[0] == i and self.player[1] == j: 362 | self.score -= self.walk_reward 363 | self.score += w 364 | 365 | def call_up(self, event): 366 | self.try_move(0, -1) 367 | 368 | def call_down(self, event): 369 | self.try_move(0, 1) 370 | 371 | def call_left(self, event): 372 | self.try_move(-1, 0) 373 | 374 | def call_right(self, event): 375 | self.try_move(1, 0) 376 | 377 | def restart_game(self): 378 | self.player = self.origin 379 | self.specials = self.original_specials.copy() 380 | self.restart = False 381 | if self.do_render: 382 | self.render_reset_grid() 383 | time.sleep(static_time_between_moves) 384 | self.board.delete('all') 385 | self.render_grid() 386 | self.me = self.board.create_rectangle( 387 | self.player[0] * self.Width + self.Width * 2 / 10, 388 | self.player[1] * self.Width + self.Width * 2 / 10, 389 | self.player[0] * self.Width + self.Width * 8 / 10, 390 | self.player[1] * self.Width + self.Width * 8 / 10, fill="orange", width=1, 391 | tag="me") 392 | self.board.tag_raise(self.me) 393 | time.sleep(static_time_between_moves) 394 | 395 | def has_restarted(self): 396 | return self.restart 397 | 398 | def _close(self): 399 | self.quit() 400 | 401 | def quit(self): 402 | self.master.destroy() 403 | 404 | --------------------------------------------------------------------------------
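For readers who want to poke at the grid world without the full training loop, here is a minimal, hypothetical smoke test (not part of the repository) that drives WorldSimulator headlessly. It relies only on the sim and get_valid_actions signatures defined in mdpSimulator.py and on the static_specials / static_walls definitions in world.py, and assumes the project's .py files are on the Python path:

# smoke_test.py -- hypothetical example, not included in the repository
from mdpSimulator import WorldSimulator
import world

simulator = WorldSimulator()                # do_render defaults to False (headless)
specials = world.static_specials.copy()    # (x, y, type, reward, velocity) tuples
walls = world.static_walls.copy()
state = [0, 3]                             # same root state used in main.py
actions = ["up", "down", "left", "right"]

for _ in range(5):
    valid = simulator.get_valid_actions(state, actions, specials, walls)
    if not valid:
        break
    # sim returns: original state, action, reward, next state, updated specials
    orig, act, reward, new_state, specials = simulator.sim(state, valid[0], specials, walls)
    print(orig, act, reward, new_state)
    state = list(new_state)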