├── .gitignore ├── LICENSE.md ├── README.md ├── abstraction_experiments.py ├── action_abs ├── ActionAbstractionClass.py ├── ContainsPredicateClass.py ├── CovPredicateClass.py ├── EqPredicateClass.py ├── NotPredicateClass.py ├── OptionClass.py ├── PolicyClass.py ├── PolicyFromDictClass.py ├── PredicateClass.py ├── __init__.py ├── aa_baselines.py └── aa_helpers.py ├── chain.py ├── hierarch ├── ActionAbstractionStackClass.py ├── DynamicHierarchyAgentClass.py ├── HRMaxAgentClass.py ├── HierarchicalValueIterationClass.py ├── HierarchyAgentClass.py ├── HierarchyStateClass.py ├── RewardFuncClass.py ├── StateAbstractionStackClass.py ├── TransitionFuncClass.py ├── __init__.py ├── action_abstr_stack_helpers.py ├── hierarchy_experiments.py ├── hierarchy_helpers.py ├── make_abstr_mdp.py └── state_abstr_stack_helpers.py ├── hierarch_rooms.txt ├── octogrid.txt ├── pblocks_grid.txt ├── run_icml_learning_experiments.py ├── run_icml_planning_experiments.py ├── simple_planning_experiments.py ├── simple_sa_experiments.py ├── state_abs ├── StateAbstractionClass.py ├── __init__.py ├── indicator_funcs.py └── sa_helpers.py └── utils ├── .DS_Store ├── AbstractValueIterationClass.py ├── AbstractionWrapperClass.py ├── ColorMDPClass.py ├── ColorStateClass.py ├── StochasticSAPolicyClass.py ├── __init__.py ├── hierarch_rooms.txt ├── make_mdp.py ├── octogrid.txt ├── pblocks_grid.txt ├── planning_experiments.py ├── run_abstr_combo_experiments.py ├── run_dir_opt_core_experiments.py └── visualize_abstractions.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.p 2 | *.pyc 3 | results/ 4 | images/ 5 | paper_results/ 6 | good_results/ 7 | .DS_Store -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) [2017] [David Abel] 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | A collection of code for learning, using, evaluating, and visualizing abstractions in Reinforcement Learning in Python. 
This code is used to run experiments for our 2018 ICML paper: [State Abstractions for Lifelong Reinforcement Learning](https://david-abel.github.io/papers/lifelong_sa_icml_18.pdf) and the earlier workshop paper [Toward Good Abstractions for Lifelong Learning](http://cs.brown.edu/~dabel/papers/nips_hrl_good_abstr.pdf), presented at the NIPS Hierarchical Reinforcement Learning Workshop in 2017. 2 | 3 | Experiments require [simple_rl](https://github.com/david-abel/simple_rl), which can be installed with the usual: 4 | 5 | pip install simple_rl 6 | 7 | For the ICML paper, run _run_icml_learning_experiments.py_ to reproduce plots from Figure 3a/3b and Figure 4. Run _run_icml_planning_experiments.py_ to reproduce plots from Figure 5. Run _chain.py_ to reproduce the plot in Figure 2. 8 | 9 | Authors: David Abel and Dilip Arumugam. Let us know if you have issues! -------------------------------------------------------------------------------- /abstraction_experiments.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Python imports. 4 | import random as r 5 | from collections import defaultdict 6 | import os 7 | import argparse 8 | 9 | # Other imports. 10 | from simple_rl.utils import make_mdp 11 | from simple_rl.agents import RandomAgent, RMaxAgent, QLearningAgent, FixedPolicyAgent 12 | from simple_rl.run_experiments import run_agents_lifelong, run_agents_on_mdp 13 | from simple_rl.tasks import TaxiOOMDP 14 | from simple_rl.mdp import State, MDPDistribution 15 | from simple_rl.abstraction.AbstractionWrapperClass import AbstractionWrapper 16 | from state_abs.StateAbstractionClass import StateAbstraction 17 | from action_abs.ActionAbstractionClass import ActionAbstraction 18 | import state_abs 19 | import action_abs 20 | from state_abs import indicator_funcs as ind_funcs 21 | 22 | # ----------------------- 23 | # -- Make Abstractions -- 24 | # ----------------------- 25 | 26 | def get_abstractions(mdp, indic_func, directed=True, max_options=100): 27 | ''' 28 | Args: 29 | mdp (MDP or MDPDistribution) 30 | indic_func (lambda): Property tester for the state abstraction. 31 | directed (bool) 32 | max_options (int) 33 | 34 | Returns: 35 | (StateAbstraction, ActionAbstraction) 36 | ''' 37 | if directed: 38 | return get_directed_option_sa_pair(mdp, indic_func=indic_func, max_options=max_options) 39 | else: 40 | sa = get_sa(mdp, indic_func=indic_func) 41 | aa = get_aa(mdp) 42 | return sa, aa 43 | 44 | def get_directed_option_sa_pair(mdp_distr, indic_func, max_options=100): 45 | ''' 46 | Args: 47 | mdp_distr (MDPDistribution) 48 | indic_func 49 | max_options (int) 50 | 51 | Returns: 52 | (StateAbstraction, ActionAbstraction) 53 | ''' 54 | 55 | # Get Abstractions by iterating over epsilons. 56 | found_small_option_set = False 57 | sa_epsilon, sa_eps_incr = 0.1, 0.01 58 | 59 | if isinstance(mdp_distr.get_all_mdps()[0], TaxiOOMDP): 60 | sa_epsilon = 0.02 61 | 62 | if "whirlpool" in str(mdp_distr.get_all_mdps()[0]): 63 | sa_eps_incr = 0.002 64 | 65 | if "color" in str(mdp_distr.get_all_mdps()[0]): 66 | sa_epsilon = 0.00 67 | 68 | while sa_epsilon <= 1.0 / (1 - mdp_distr.get_gamma()): 69 | print "Epsilon:", sa_epsilon 70 | 71 | # Compute the SA-AA pair. 72 | # NOTE: Track act_opt_pr is TRUE 73 | sa = get_sa(mdp_distr, indic_func=indic_func, default=False, epsilon=sa_epsilon, track_act_opt_pr=False) 74 | 75 | if sa.get_num_abstr_states() == 1: 76 | # We can't have only 1 abstract state. 77 | print "Abstraction Error: only 1 abstract state." 
78 |             quit()
79 | 
80 |         aa = get_directed_aa(mdp_distr, sa, max_options=max_options)
81 |         if aa:
82 |             # If this is a good aa, we're done.
83 |             break
84 | 
85 |         sa_epsilon += sa_eps_incr
86 | 
87 |     print "\nFound", len(aa.get_actions()), "Options."
88 | 
89 |     return sa, aa
90 | 
91 | # ------------------------
92 | # -- State Abstractions --
93 | # ------------------------
94 | 
95 | def get_sa(mdp_distr, indic_func=None, default=False, epsilon=0.0):
96 |     '''
97 |     Args:
98 |         mdp_distr (MDPDistribution)
99 |         indic_func (lambda): Indicator function from state_abs/indicator_funcs.py
100 |         default (bool): If true, returns a blank StateAbstraction
101 |         epsilon (float): Determines approximation for clustering.
102 | 
103 |     Returns:
104 |         (StateAbstraction)
105 |     '''
106 | 
107 |     if default:
108 |         return StateAbstraction(phi={})
109 | 
110 |     state_abstr = state_abs.sa_helpers.make_sa(mdp_distr, indic_func=indic_func, state_class=State, epsilon=epsilon)
111 | 
112 |     return state_abstr
113 | 
114 | def compute_pac_sa(mdp_distr, indic_func=None, default=False, phi_epsilon=0.05, pac_delta=0.2):
115 |     '''
116 |     Args:
117 |         mdp_distr (MDPDistribution)
118 |         indic_func (lambda): Indicator function from state_abs/indicator_funcs.py
119 |         default (bool): If true, returns a blank StateAbstraction
120 |         phi_epsilon (float): Determines approximation for Q^*_epsilon clustering.
121 |         pac_delta (float): Determines how confident the resulting p_hat should be.
122 | 
123 |     Returns:
124 |         (StateAbstraction)
125 |     '''
126 | 
127 |     state_abstr = state_abs.sa_helpers.get_pac_sa_from_samples(mdp_distr, indic_func=indic_func, phi_epsilon=phi_epsilon, pac_delta=pac_delta)
128 | 
129 |     return state_abstr
130 | 
131 | # -------------------------
132 | # -- Action Abstractions --
133 | # -------------------------
134 | 
135 | def get_aa(mdp_distr, default=False):
136 |     '''
137 |     Args:
138 |         mdp_distr (MDPDistribution)
139 |         default (bool): If true, returns a blank ActionAbstraction
140 | 
141 |     Returns:
142 |         (ActionAbstraction)
143 |     '''
144 | 
145 |     if default:
146 |         return ActionAbstraction(options=mdp_distr.get_actions(), prim_actions=mdp_distr.get_actions())
147 | 
148 |     return action_abs.aa_helpers.make_greedy_options(mdp_distr)
149 | 
150 | def get_directed_aa(mdp_distr, state_abs, incl_prim_actions=False, max_options=100):
151 |     '''
152 |     Args:
153 |         mdp_distr (MDPDistribution)
154 |         state_abs (StateAbstraction)
155 |         incl_prim_actions (bool)
156 |         max_options (int)
157 | 
158 |     Returns:
159 |         (ActionAbstraction)
160 |     '''
161 |     directed_options = action_abs.aa_helpers.get_directed_options_for_sa(mdp_distr, state_abs, incl_self_loops=True, max_options=max_options)
162 |     term_prob = 1 - mdp_distr.get_gamma()
163 | 
164 |     if not directed_options:
165 |         # No good option set found.
166 |         return False
167 | 
168 |     if incl_prim_actions:
169 |         # Include the primitives.
170 |         aa = ActionAbstraction(options=mdp_distr.get_actions(), prim_actions=mdp_distr.get_actions(), prims_on_failure=False, term_prob=term_prob)
171 |         for o in directed_options:
172 |             aa.add_option(o)
173 |         return aa
174 |     else:
175 |         # Return just the options.
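        # term_prob = 1 - gamma, so an executing option also terminates stochastically at the rate implied by the discount;
        # prims_on_failure=True lets the agent fall back to primitive actions in states where no directed option initiates
        # (see ActionAbstractionClass.act).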
176 | return ActionAbstraction(options=directed_options, prim_actions=mdp_distr.get_actions(), prims_on_failure=True, term_prob=term_prob) 177 | 178 | 179 | def parse_args(): 180 | ''' 181 | Summary: 182 | Parse all arguments 183 | ''' 184 | parser = argparse.ArgumentParser() 185 | parser.add_argument("-task", type = str, default = "octo", nargs = '?', help = "Choose the mdp type (one of {octo, hall, grid, taxi, four_room}).") 186 | parser.add_argument("-samples", type = int, default = 500, nargs = '?', help = "Number of samples from the MDP Distribution.") 187 | parser.add_argument("-steps", type = int, default = 100, nargs = '?', help = "Number of steps for the experiment.") 188 | parser.add_argument("-episodes", type = int, default = 1, nargs = '?', help = "Number of episodes for the experiment.") 189 | parser.add_argument("-grid_dim", type = int, default = 11, nargs = '?', help = "Dimensions of the grid world.") 190 | parser.add_argument("-track_options", type = bool, default = False, nargs = '?', help = "Plot in terms of option executions (if True).") 191 | parser.add_argument("-agent", type = str, default='ql', nargs = '?', help = "Specify agent class (one of {'ql', 'rmax'})..") 192 | parser.add_argument("-max_options", type = int, default=50, nargs = '?', help = "Specify maximum number of options.") 193 | parser.add_argument("-exp_type", type = str, default="core", nargs = '?', help = "Choose which experiment we're running. One of {core, combo}.") 194 | args = parser.parse_args() 195 | 196 | return args.task, args.samples, args.episodes, args.steps, args.grid_dim, bool(args.track_options), args.agent, args.max_options, args.exp_type 197 | 198 | def main(): 199 | 200 | # Grab experiment params. 201 | mdp_class, task_samples, episodes, steps, grid_dim, x_axis_num_options, agent_class_str, max_options, exp_type = parse_args() 202 | 203 | gamma = 0.9 204 | 205 | # ======================== 206 | # === Make Environment === 207 | # ======================== 208 | multi_task = True 209 | max_option_steps = 50 if x_axis_num_options else 0 210 | environment = make_mdp.make_mdp_distr(mdp_class=mdp_class, grid_dim=grid_dim) if multi_task else make_mdp.make_mdp(mdp_class=mdp_class) 211 | actions = environment.get_actions() 212 | environment.set_gamma(gamma) 213 | 214 | # Indicator functions. 215 | v_indic = ind_funcs._v_approx_indicator 216 | q_indic = ind_funcs._q_eps_approx_indicator 217 | v_disc_indic = ind_funcs._v_disc_approx_indicator 218 | rand_indic = ind_funcs._random 219 | 220 | # ========================= 221 | # === Make Abstractions === 222 | # ========================= 223 | 224 | # Directed Variants. 225 | v_directed_sa, v_directed_aa = get_abstractions(environment, v_disc_indic, directed=True, max_options=max_options) 226 | # v_directed_sa, v_directed_aa = get_abstractions(environment, v_indic, directed=True, max_options=max_options) 227 | 228 | # Identity action abstraction. 229 | identity_sa, identity_aa = get_sa(environment, default=True), get_aa(environment, default=True) 230 | 231 | if exp_type == "core": 232 | # Core only abstraction types. 
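        # Baseline abstractions for the "core" experiment: a Q_epsilon^*-based SA/AA pair, a random clustering,
        # and the Policy Blocks option baseline from action_abs/aa_baselines.py paired with a blank state abstraction.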
233 |         q_directed_sa, q_directed_aa = get_abstractions(environment, q_indic, directed=True, max_options=max_options)
234 |         rand_directed_sa, rand_directed_aa = get_abstractions(environment, rand_indic, directed=True, max_options=max_options)
235 |         pblocks_sa, pblocks_aa = get_sa(environment, default=True), action_abs.aa_baselines.get_policy_blocks_aa(environment, incl_prim_actions=True, num_options=max_options)
236 | 
237 |     # ===================
238 |     # === Make Agents ===
239 |     # ===================
240 | 
241 |     # Base Agents.
242 |     agent_class = QLearningAgent if agent_class_str == "ql" else RMaxAgent
243 |     rand_agent = RandomAgent(actions)
244 |     baseline_agent = agent_class(actions, gamma=gamma)
245 | 
246 |     if mdp_class == "pblocks":
247 |         baseline_agent.epsilon = 0.01
248 | 
249 |     # Abstraction Extensions.
250 |     agents = []
251 |     vabs_agent_directed = AbstractionWrapper(agent_class, actions, str(environment), max_option_steps=max_option_steps, state_abstr=v_directed_sa, action_abstr=v_directed_aa, name_ext="v-sa+aa")
252 | 
253 |     if exp_type == "core":
254 |         # Core only agents.
255 |         qabs_agent_directed = AbstractionWrapper(agent_class, actions, str(environment), max_option_steps=max_option_steps, state_abstr=q_directed_sa, action_abstr=q_directed_aa, name_ext="q-sa+aa")
256 |         rabs_agent_directed = AbstractionWrapper(agent_class, actions, str(environment), max_option_steps=max_option_steps, state_abstr=rand_directed_sa, action_abstr=rand_directed_aa, name_ext="rand-sa+aa")
257 |         pblocks_agent = AbstractionWrapper(agent_class, actions, str(environment), max_option_steps=max_option_steps, state_abstr=pblocks_sa, action_abstr=pblocks_aa, name_ext="pblocks")
258 |         agents = [vabs_agent_directed, qabs_agent_directed, rabs_agent_directed, pblocks_agent, baseline_agent]
259 |     elif exp_type == "combo":
260 |         # Combo only agents.
261 |         aa_agent = AbstractionWrapper(agent_class, actions, str(environment), max_option_steps=max_option_steps, state_abstr=identity_sa, action_abstr=v_directed_aa, name_ext="aa")
262 |         sa_agent = AbstractionWrapper(agent_class, actions, str(environment), max_option_steps=max_option_steps, state_abstr=v_directed_sa, action_abstr=identity_aa, name_ext="sa")
263 |         agents = [vabs_agent_directed, sa_agent, aa_agent, baseline_agent]
264 | 
265 |     # Run experiments.
266 |     if multi_task:
267 |         steps = 999999 if x_axis_num_options else steps
268 |         run_agents_lifelong(agents, environment, task_samples=task_samples, steps=steps, episodes=episodes, reset_at_terminal=True)
269 |     else:
270 |         run_agents_on_mdp(agents, environment, instances=20, episodes=30, reset_at_terminal=True)
271 | 
272 | 
273 | if __name__ == "__main__":
274 |     main()
--------------------------------------------------------------------------------
/action_abs/ActionAbstractionClass.py:
--------------------------------------------------------------------------------
1 | # Python imports.
2 | from collections import defaultdict
3 | import random
4 | 
5 | # Other imports.
6 | from OptionClass import Option
7 | from PredicateClass import Predicate
8 | 
9 | class ActionAbstraction(object):
10 | 
11 |     def __init__(self, options, prim_actions, term_prob=0.0, prims_on_failure=True):
12 |         self.options = self._convert_to_options(options)
13 |         self.is_cur_executing = False
14 |         self.cur_option = None # The option we're executing currently.
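        # prim_actions: names of the ground MDP's primitive actions (used for the fallback below).
        # term_prob: per-step probability of forcing the executing option to terminate early.
        # prims_on_failure: if True, act() hands the agent the primitives whenever no option initiates in the current state.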
15 | self.prim_actions = prim_actions 16 | self.term_prob = term_prob 17 | self.prims_on_failure = prims_on_failure 18 | 19 | def act(self, agent, abstr_state, ground_state, reward): 20 | ''' 21 | Args: 22 | agent (Agent) 23 | abstr_state (State) 24 | ground_state (State) 25 | reward (float) 26 | 27 | Returns: 28 | (str) 29 | ''' 30 | if self.is_next_step_continuing_option(ground_state) and random.random() > self.term_prob: 31 | # We're in an option and not terminating. 32 | return self.get_next_ground_action(ground_state) 33 | else: 34 | # We're not in an option, check with agent. 35 | active_options = self.get_active_options(ground_state) 36 | 37 | if len(active_options) == 0: 38 | if self.prims_on_failure: 39 | # In a rare failure state, back off to primitives. 40 | agent.actions = self._convert_to_options(self.prim_actions) 41 | else: 42 | # No actions available. 43 | print "Error: no actions available in state " + str(ground_state) + ". (r:" + str(reward) + "," + str(ground_state.is_terminal()) + ")." 44 | quit() 45 | else: 46 | # Give agent available options. 47 | agent.actions = active_options 48 | 49 | abstr_action = agent.act(abstr_state, reward) 50 | self.set_option_executing(abstr_action) 51 | 52 | return self.abs_to_ground(ground_state, abstr_action) 53 | 54 | def get_active_options(self, state): 55 | ''' 56 | Args: 57 | state (State) 58 | 59 | Returns: 60 | (list): Contains all active options. 61 | ''' 62 | return [o for o in self.options if o.is_init_true(state)] 63 | 64 | def _convert_to_options(self, action_list): 65 | ''' 66 | Args: 67 | action_list (list) 68 | 69 | Returns: 70 | (list of Option) 71 | ''' 72 | options = [] 73 | for ground_action in action_list: 74 | o = ground_action 75 | if type(ground_action) is str: 76 | o = Option(init_predicate=Predicate(make_lambda(True)), 77 | term_predicate=Predicate(make_lambda(True)), 78 | policy=make_lambda(ground_action), 79 | name="prim." + ground_action) 80 | options.append(o) 81 | return options 82 | 83 | def is_next_step_continuing_option(self, ground_state): 84 | ''' 85 | Returns: 86 | (bool): True iff an option was executing and should continue next step. 87 | ''' 88 | return self.is_cur_executing and not self.cur_option.is_term_true(ground_state) 89 | 90 | def set_option_executing(self, option): 91 | if option not in self.options and "prim" not in option.name: 92 | print "Error: agent chose a non-existent option (" + str(option) + ")." 93 | quit() 94 | 95 | self.cur_option = option 96 | self.is_cur_executing = True 97 | 98 | def get_next_ground_action(self, ground_state): 99 | return self.cur_option.act(ground_state) 100 | 101 | def get_actions(self): 102 | return self.options 103 | 104 | def abs_to_ground(self, ground_state, abstr_action): 105 | return abstr_action.act(ground_state) 106 | 107 | def add_option(self, option): 108 | self.options += [option] 109 | 110 | def reset(self): 111 | self.is_cur_executing = False 112 | self.cur_option = None # The option we're executing currently. 
113 | 114 | def end_of_episode(self): 115 | self.reset() 116 | 117 | 118 | def make_lambda(result): 119 | return lambda x : result -------------------------------------------------------------------------------- /action_abs/ContainsPredicateClass.py: -------------------------------------------------------------------------------- 1 | class ContainsPredicate(object): 2 | 3 | def __init__(self, list_of_items): 4 | self.list_of_items = list_of_items 5 | 6 | def is_true(self, x): 7 | return x in self.list_of_items 8 | 9 | -------------------------------------------------------------------------------- /action_abs/CovPredicateClass.py: -------------------------------------------------------------------------------- 1 | class CovPredicate(object): 2 | def __init__(self, y, policy): 3 | self.y = y 4 | self.policy = policy 5 | 6 | def is_true(self, x): 7 | return (x in self.policy.keys()) == self.y 8 | -------------------------------------------------------------------------------- /action_abs/EqPredicateClass.py: -------------------------------------------------------------------------------- 1 | class EqPredicate(object): 2 | 3 | def __init__(self, y, func): 4 | self.y = y 5 | self.func = func 6 | 7 | def is_true(self, x): 8 | return self.func(x) == self.y 9 | 10 | 11 | class NeqPredicate(object): 12 | 13 | def __init__(self, y, func): 14 | self.y = y 15 | self.func = func 16 | 17 | def is_true(self, x): 18 | return not(self.func(x) == self.y) 19 | 20 | 21 | -------------------------------------------------------------------------------- /action_abs/NotPredicateClass.py: -------------------------------------------------------------------------------- 1 | from PredicateClass import Predicate 2 | 3 | class NotPredicate(Predicate): 4 | 5 | def __init__(self, predicate): 6 | self.predicate = predicate 7 | 8 | def is_true(self, x): 9 | return not self.predicate.is_true(x) 10 | 11 | -------------------------------------------------------------------------------- /action_abs/OptionClass.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import random 3 | from simple_rl.mdp.StateClass import State 4 | 5 | 6 | class Option(object): 7 | 8 | def __init__(self, init_predicate, term_predicate, policy, name="o", term_prob=0.01): 9 | ''' 10 | Args: 11 | init_func (S --> {0,1}) 12 | init_func (S --> {0,1}) 13 | policy (S --> A) 14 | ''' 15 | self.init_predicate = init_predicate 16 | self.term_predicate = term_predicate 17 | self.term_flag = False 18 | self.name = name 19 | self.term_prob = term_prob 20 | 21 | if type(policy) is defaultdict or type(policy) is dict: 22 | self.policy_dict = dict(policy) 23 | self.policy = self.policy_from_dict 24 | else: 25 | self.policy = policy 26 | 27 | def is_init_true(self, ground_state): 28 | return self.init_predicate.is_true(ground_state) 29 | 30 | def is_term_true(self, ground_state): 31 | return self.term_predicate.is_true(ground_state) or self.term_flag or self.term_prob > random.random() 32 | 33 | def act(self, ground_state): 34 | return self.policy(ground_state) 35 | 36 | def set_policy(self, policy): 37 | self.policy = policy 38 | 39 | def set_name(self, new_name): 40 | self.name = new_name 41 | 42 | def act_until_terminal(self, cur_state, transition_func): 43 | ''' 44 | Summary: 45 | Executes the option until termination. 
46 | ''' 47 | if self.is_init_true(cur_state): 48 | cur_state = transition_func(cur_state, self.act(cur_state)) 49 | while not self.is_term_true(cur_state): 50 | cur_state = transition_func(cur_state, self.act(cur_state)) 51 | 52 | return cur_state 53 | 54 | def rollout(self, cur_state, reward_func, transition_func): #, step_cost=0): 55 | ''' 56 | Summary: 57 | Executes the option until termination. 58 | 59 | Returns: 60 | (tuple): 61 | 1. (State): state we landed in. 62 | 2. (float): Reward from the trajectory. 63 | ''' 64 | total_reward = 0 65 | if self.is_init_true(cur_state): 66 | # First step. 67 | total_reward += reward_func(cur_state, self.act(cur_state)) # - step_cost 68 | cur_state = transition_func(cur_state, self.act(cur_state)) 69 | 70 | # Act until terminal. 71 | while not self.is_term_true(cur_state): 72 | cur_state = transition_func(cur_state, self.act(cur_state)) 73 | total_reward += reward_func(cur_state, self.act(cur_state))# - step_cost 74 | 75 | return cur_state, total_reward 76 | 77 | def policy_from_dict(self, state): 78 | if state not in self.policy_dict.keys(): 79 | self.term_flag = True 80 | return random.choice(list(set(self.policy_dict.values()))) 81 | else: 82 | self.term_flag = False 83 | return self.policy_dict[state] 84 | 85 | def term_func_from_list(self, state): 86 | return state in self.term_list 87 | 88 | 89 | def __str__(self): 90 | return "option." + str(self.name) -------------------------------------------------------------------------------- /action_abs/PolicyClass.py: -------------------------------------------------------------------------------- 1 | class Policy(object): 2 | 3 | def __init__(self, action=""): 4 | self.action = action 5 | 6 | def get_action(self, state): 7 | return self.action -------------------------------------------------------------------------------- /action_abs/PolicyFromDictClass.py: -------------------------------------------------------------------------------- 1 | # Python imports. 2 | import random 3 | 4 | # Other imports. 5 | from PolicyClass import Policy 6 | from collections import defaultdict 7 | 8 | class PolicyFromDict(Policy): 9 | 10 | def __init__(self, policy_dict={}): 11 | self.policy_dict = policy_dict 12 | 13 | def get_action(self, state): 14 | # if state not in self.policy_dict.keys(): 15 | # print "(PolicyFromDict) Abstraction Error:", state, "never seen before." 16 | # quit() 17 | 18 | if state not in self.policy_dict.keys(): 19 | # print "(PolicyFromDict) Warning: unseen state (" + str(state) + "). Acting randomly." 
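            # Unseen state: fall back to a uniformly random choice among the actions this policy prescribes elsewhere.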
20 | return random.choice(list(set(self.policy_dict.values()))) 21 | else: 22 | # print "Seen state!:" + str(state) 23 | return self.policy_dict[state] 24 | 25 | def make_dict_from_lambda(policy_func, state_list): 26 | policy_dict = {} #defaultdict(str) 27 | for s in state_list: 28 | policy_dict[s] = policy_func(s) 29 | 30 | return policy_dict -------------------------------------------------------------------------------- /action_abs/PredicateClass.py: -------------------------------------------------------------------------------- 1 | class Predicate(object): 2 | 3 | def __init__(self, func, params={}): 4 | self.func = func 5 | 6 | def is_true(self, x): 7 | return self.func(x) 8 | 9 | -------------------------------------------------------------------------------- /action_abs/__init__.py: -------------------------------------------------------------------------------- 1 | import aa_helpers 2 | import aa_baselines -------------------------------------------------------------------------------- /action_abs/aa_baselines.py: -------------------------------------------------------------------------------- 1 | # Python imports. 2 | import sys 3 | import numpy as np 4 | from collections import defaultdict 5 | import random 6 | 7 | # Other imports. 8 | from action_abs.ActionAbstractionClass import ActionAbstraction 9 | from action_abs.ContainsPredicateClass import ContainsPredicate 10 | from action_abs.CovPredicateClass import CovPredicate 11 | from action_abs.EqPredicateClass import EqPredicate, NeqPredicate 12 | from action_abs.NotPredicateClass import NotPredicate 13 | from action_abs.OptionClass import Option 14 | from action_abs.PolicyClass import Policy 15 | from action_abs.PolicyFromDictClass import make_dict_from_lambda, PolicyFromDict 16 | from simple_rl.planning import ValueIteration 17 | from simple_rl.run_experiments import run_agents_lifelong 18 | from simple_rl.tasks.grid_world import GridWorldMDPClass 19 | 20 | 21 | def get_aa_high_prob_opt_single_act(mdp_distr, state_abstr, delta=0.2): 22 | ''' 23 | Args: 24 | mdp_distr 25 | state_abstr (StateAbstraction) 26 | 27 | Summary: 28 | Computes an action abstraction where there exists an option that repeats a 29 | single primitive action, for each primitive action that was optimal *with 30 | high probability* in the ground state in the cluster. 31 | ''' 32 | # K: state, V: dict (K: act, V: probability) 33 | action_optimality_dict = state_abstr.get_act_opt_dict() 34 | 35 | # Compute options. 36 | options = [] 37 | for s_a in state_abstr.get_abs_states(): 38 | 39 | ground_states = state_abstr.get_ground_states_in_abs_state(s_a) 40 | 41 | # One option per action. 42 | for action in mdp_distr.get_actions(): 43 | list_of_state_with_a_optimal_high_pr = [] 44 | 45 | # Compute which states have high prob of being optimal. 
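            # A ground state joins this option's initiation set only if the action is optimal there with
            # probability greater than 1 - delta across the task distribution.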
46 | for s_g in ground_states: 47 | print "Pr(a = a^* \mid s_g)", s_g, action, action_optimality_dict[s_g][action] 48 | if action_optimality_dict[s_g][action] > (1-delta): 49 | list_of_state_with_a_optimal_high_pr.append(s_g) 50 | 51 | if len(list_of_state_with_a_optimal_high_pr) == 0: 52 | continue 53 | 54 | init_predicate = ContainsPredicate(list_of_items=list_of_state_with_a_optimal_high_pr) 55 | term_predicate = NotPredicate(init_predicate) 56 | policy_obj = Policy(action) 57 | 58 | o = Option(init_predicate=init_predicate, 59 | term_predicate=term_predicate, 60 | policy=policy_obj.get_action) 61 | 62 | options.append(o) 63 | 64 | return ActionAbstraction(options=options, prim_actions=mdp_distr.get_actions(), prims_on_failure=True) 65 | 66 | 67 | def get_aa_opt_only_single_act(mdp_distr, state_abstr): 68 | ''' 69 | Args: 70 | mdp_distr 71 | state_abstr (StateAbstraction) 72 | 73 | Summary: 74 | Computes an action abstraction where there exists an option that repeats a 75 | single primitive action, for each primitive action that was optimal in 76 | the ground state in the cluster. 77 | ''' 78 | action_optimality_dict = state_abstr.get_act_opt_dict() 79 | 80 | # Compute options. 81 | options = [] 82 | for s_a in state_abstr.get_abs_states(): 83 | 84 | ground_states = state_abstr.get_ground_states_in_abs_state(s_a) 85 | 86 | # One option per action. 87 | for action in mdp_distr.get_actions(): 88 | list_of_state_with_a_optimal = [] 89 | 90 | for s_g in ground_states: 91 | if action in action_optimality_dict[s_g]: 92 | list_of_state_with_a_optimal.append(s_g) 93 | 94 | if len(list_of_state_with_a_optimal) == 0: 95 | continue 96 | 97 | init_predicate = ContainsPredicate(list_of_items=list_of_state_with_a_optimal) 98 | term_predicate = NotPredicate(init_predicate) 99 | policy_obj = Policy(action) 100 | 101 | o = Option(init_predicate=init_predicate, 102 | term_predicate=term_predicate, 103 | policy=policy_obj.get_action, 104 | term_prob=1-mdp_distr.get_gamma()) 105 | 106 | options.append(o) 107 | 108 | return ActionAbstraction(options=options, prim_actions=mdp_distr.get_actions()) 109 | 110 | 111 | def get_aa_single_act(mdp_distr, state_abstr): 112 | ''' 113 | Args: 114 | mdp_distr 115 | state_abstr (StateAbstraction) 116 | 117 | Summary: 118 | Computes an action abstraction where there exists an option that repeats a 119 | single primitive action, for each primitive action that was optimal in the 120 | cluster. 121 | ''' 122 | 123 | action_optimality_dict = state_abstr.get_act_opt_dict() 124 | 125 | options = [] 126 | # Compute options. 
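    # One repeat-a-single-primitive option per (cluster, optimal action) pair: it initiates anywhere in the
    # abstract state and terminates as soon as the agent leaves that cluster.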
127 | for s_a in state_abstr.get_abs_states(): 128 | init_predicate = EqPredicate(y=s_a, func=state_abstr.phi) 129 | term_predicate = NeqPredicate(y=s_a, func=state_abstr.phi) 130 | 131 | ground_states = state_abstr.get_ground_states_in_abs_state(s_a) 132 | 133 | unique_a_star_in_cluster = set([]) 134 | for s_g in ground_states: 135 | for a_star in action_optimality_dict[s_g]: 136 | unique_a_star_in_cluster.add(a_star) 137 | 138 | for action in unique_a_star_in_cluster: 139 | policy_obj = Policy(action) 140 | 141 | o = Option(init_predicate=init_predicate, 142 | term_predicate=term_predicate, 143 | policy=policy_obj.get_action) 144 | options.append(o) 145 | 146 | return ActionAbstraction(options=options, prim_actions=mdp_distr.get_actions()) 147 | 148 | 149 | # --------------------- 150 | # --- POLICY BLOCKS --- 151 | # --------------------- 152 | 153 | def get_policy_blocks_aa(mdp_distr, num_options=10, task_samples=20, incl_prim_actions=False): 154 | pb_options = make_policy_blocks_options(mdp_distr, num_options=num_options, task_samples=task_samples) 155 | 156 | if type(mdp_distr) is dict: 157 | first_mdp = mdp_distr.keys()[0] 158 | else: 159 | first_mdp = mdp_distr 160 | 161 | if incl_prim_actions: 162 | # Include the primitives. 163 | aa = ActionAbstraction(options=first_mdp.get_actions(), prim_actions=first_mdp.get_actions()) 164 | for o in pb_options: 165 | aa.add_option(o) 166 | return aa 167 | else: 168 | # Return just the options. 169 | return ActionAbstraction(options=pb_options, prim_actions=first_mdp.get_actions()) 170 | 171 | 172 | def policy_blocks_merge_pair(pi1, pi2): 173 | ''' 174 | Perform pairwise merge between two partial policies to compute their intersection (also a partial policy) 175 | :param pi1: Partial policy 1 176 | :param pi2: Partial policy 2 177 | :return: Merged partial policy 178 | ''' 179 | ret = {} 180 | for state in set(pi1.keys() + pi2.keys()): 181 | a1 = pi1.get(state, None) 182 | a2 = pi2.get(state, None) 183 | if a1 == a2 and a1 is not None: 184 | ret[state] = a1 185 | return ret 186 | 187 | def policy_blocks_subtract_pair(pi1, pi2): 188 | ''' 189 | Perform pairwise subtraction between two partial policies to compute their difference (also a partial policy) 190 | :param pi1: Partial policy 1 191 | :param pi2: Partial policy 2 192 | :return: Difference partial policy 193 | ''' 194 | ret = {} 195 | for state in set(pi1.keys() + pi2.keys()): 196 | a1 = pi1.get(state, None) 197 | a2 = pi2.get(state, None) 198 | if a1 is not None and a2 is None: 199 | ret[state] = a1 200 | return ret 201 | 202 | 203 | def policy_blocks_merge(policy_set): 204 | ''' 205 | :param policy_set: Set of policies to merge 206 | :return: partial policy representing the merge of all policies in the policy set 207 | ''' 208 | policy_set = list(policy_set) 209 | merged = policy_set[0] 210 | for i in range(1, len(policy_set)): 211 | merged = policy_blocks_merge_pair(merged, policy_set[i]) 212 | return merged 213 | 214 | 215 | def policy_blocks_contains_pairwise(containee, container): 216 | ''' 217 | Determine if one policy is contained within another 218 | :param containee: Policy that may be contained 219 | :param container: Policy that is a container 220 | :return: True if containee(s) = container(s) for all s in containee, otherwise False 221 | ''' 222 | for state in containee.keys(): 223 | a = container.get(state, None) 224 | if a is None: 225 | return False 226 | 227 | if a != containee[state]: 228 | return False 229 | return True 230 | 231 | 232 | def 
policy_blocks_num_contains_policy(pi, policy_set): 233 | ''' 234 | Compute the number of policies in policy set which contain the candidate policy 235 | :param pi: Policy to search for within policy set 236 | :param policy_set: Set of policies to check containment 237 | :return: Number of policies in policy set which contain the candidate policy (value in [0, len(policy_set)] 238 | ''' 239 | ret = 0 240 | for policy in policy_set: 241 | if policy_blocks_contains_pairwise(pi, policy): 242 | ret += 1 243 | return ret 244 | 245 | 246 | def policy_blocks_score_policy(pi_unscored, policy_set): 247 | ''' 248 | The score is the size of the partial policy multiplied by the number of solution policies that contain it 249 | :return: Quick metric for determining how good this option policy is with respect to the sampled initial tasks 250 | ''' 251 | scale_contain = policy_blocks_num_contains_policy(pi_unscored, policy_set) 252 | return len(pi_unscored.keys()) * scale_contain 253 | 254 | 255 | def get_power_set(policy_set): 256 | ''' 257 | Compute subset of the power set of the input policy set by computing all possible pairs and triples 258 | :param policy_set: Set to compute partial power set 259 | :return: Partial power set consisting of all pairwise and triplet policy combinations 260 | ''' 261 | print 'Computing partial power set over solution policies...' 262 | ret = [] 263 | # Compute all pairs 264 | for i in xrange(len(policy_set)-1): 265 | for j in xrange(i+1, len(policy_set)): 266 | ret.append([policy_set[i], policy_set[j]]) 267 | 268 | print 'Finished computing all pairs...starting triples...' 269 | 270 | # Generate all triplets 271 | for i in xrange(len(ret)): 272 | for p in policy_set: 273 | # Sort to maintain consistency for duplicate checks 274 | to_add = sorted(ret[i] + [p]) 275 | if p not in ret[i] and to_add not in ret: 276 | ret.append(to_add) 277 | 278 | return ret 279 | 280 | def make_policy_blocks_options(mdp_distr, num_options, task_samples): 281 | ''' 282 | Args: 283 | mdp_distr (MDPDistribution) 284 | num_options (int) 285 | task_samples (int) 286 | 287 | Returns: 288 | (list): Contains policy blocks options. 289 | ''' 290 | option_set = [] 291 | # Fill solution set for task_samples draws from MDP distribution 292 | L = [] 293 | for new_task in xrange(task_samples): 294 | print " Sample " + str(new_task + 1) + " of " + str(task_samples) + "." 295 | 296 | # Sample the MDP. 297 | mdp = mdp_distr.sample() 298 | 299 | # Run VI to get a policy for the MDP as well as the list of states 300 | print "\tRunning VI...", 301 | sys.stdout.flush() 302 | # Run VI 303 | vi = ValueIteration(mdp, delta=0.0001, max_iterations=5000, sample_rate=5) 304 | iters, val = vi.run_vi() 305 | print " done." 
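        # Flatten the VI policy into a state -> action dict; the collected solution policies form the set L
        # that policy blocks repeatedly merges and subtracts to extract option policies.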
306 | 307 | policy = make_dict_from_lambda(vi.policy, vi.get_states()) 308 | L.append(policy) 309 | 310 | power_L = get_power_set(L) 311 | num_iters = 1 312 | print 'Beginning policy blocks for {2} options with {0} solution policies and power set of size {1}'\ 313 | .format(len(L), len(power_L), num_options) 314 | 315 | while len(power_L) > 0 and len(option_set) < num_options: 316 | print 'Running iteration {0} of policy blocks...'.format(num_iters) 317 | # Initialize empty set of candidate option policies 318 | C = [] 319 | # Iterate over the power set of solution policies 320 | for policy_set in power_L: 321 | # Compute candidate policy as merge over policy set 322 | candidate = policy_blocks_merge(policy_set) 323 | if candidate not in C: 324 | # Compute score of each candidate policy 325 | C.append((candidate, policy_blocks_score_policy(candidate, L))) 326 | # Identify the candidate policy with highest score and add to option set 327 | C = sorted(C, key=lambda x: x[1]) 328 | pi_star = C[-1][0] 329 | if pi_star not in option_set: 330 | option_set.append(pi_star) 331 | 332 | # Subtract chosen candidate from L by iterating through power set 333 | power_L = map(lambda policy_set: [policy_blocks_subtract_pair(p, pi_star) for p in policy_set], power_L) 334 | 335 | # Remove empty elements of power set 336 | power_L = filter(lambda policy_set: sum(map(lambda x: len(x), policy_set)) > 0, power_L) 337 | 338 | num_iters += 1 339 | 340 | # Generate true option set 341 | ret = [] 342 | for o in option_set: 343 | init_predicate = CovPredicate(y=True, policy=o) 344 | term_predicate = CovPredicate(y=False, policy=o) 345 | print map(str, o.keys()) 346 | print o.values() 347 | print '**' 348 | opt = Option(init_predicate=init_predicate, term_predicate=term_predicate, policy=o) 349 | ret.append(opt) 350 | 351 | print 'Policy blocks returning with {0} options'.format(len(ret)) 352 | 353 | return ret 354 | 355 | if __name__ == '__main__': 356 | 357 | # MDP Setting. 358 | mdp_class = "pblocks_grid" 359 | num_mdps = 40 360 | mdp_distr = {} 361 | mdp_prob = 1.0 / num_mdps 362 | 363 | for i in range(num_mdps): 364 | new_mdp = GridWorldMDPClass.make_grid_world_from_file("action_abs/pblocks_grid.txt", randomize=True) 365 | mdp_distr[new_mdp] = mdp_prob 366 | 367 | actions = mdp_distr.keys()[0].actions 368 | gamma = mdp_distr.keys()[0].gamma 369 | 370 | 371 | ql_agent = QLearningAgent(actions, gamma=gamma) 372 | 373 | pblocks_aa = get_policy_blocks_aa(mdp_distr, num_options=5, task_samples=20, incl_prim_actions=True) 374 | regular_sa = get_sa(mdp_distr, default=True) 375 | 376 | pblocks_ql_agent = AbstractionWrapper(QLearningAgent, actions, state_abs=regular_sa, action_abs=pblocks_aa, name_ext="aa") 377 | 378 | agents = [pblocks_ql_agent, ql_agent] 379 | 380 | mdp_distr = MDPDistribution(mdp_distr) 381 | run_agents_lifelong(agents, mdp_distr, task_samples=100, episodes=1, steps=10000) 382 | 383 | from visualize_abstractions import visualize_options_grid 384 | 385 | visualize_options_grid(mdp_distr.sample(1), regular_sa.get_ground_states(), pblocks_aa) 386 | 387 | -------------------------------------------------------------------------------- /action_abs/aa_helpers.py: -------------------------------------------------------------------------------- 1 | # Python imports. 2 | from collections import defaultdict 3 | import Queue 4 | import random 5 | import os 6 | import sys 7 | import cPickle 8 | 9 | # Other imports. 
10 | from ActionAbstractionClass import ActionAbstraction 11 | from OptionClass import Option 12 | from simple_rl.planning.ValueIterationClass import ValueIteration 13 | from simple_rl.mdp.MDPClass import MDP 14 | from EqPredicateClass import EqPredicate, NeqPredicate 15 | from PolicyFromDictClass import * 16 | from simple_rl.tasks import GridWorldMDP 17 | 18 | # ---------------------- 19 | # -- Directed Options -- 20 | # ---------------------- 21 | def get_directed_options_for_sa(mdp_distr, state_abstr, incl_self_loops=True, max_options=100): 22 | ''' 23 | Args: 24 | mdp_distr (MDPDistribution) 25 | state_abstr (StateAbstraction) 26 | incl_self_loops (bool) 27 | max_options (int) 28 | 29 | Returns: 30 | (ActionAbstraction) 31 | ''' 32 | 33 | print " Computing directed options." 34 | sys.stdout.flush() 35 | 36 | abs_states = state_abstr.get_abs_states() 37 | 38 | # Check max # options. 39 | total_clique_options = len(abs_states) * (len(abs_states) - 1) 40 | if total_clique_options > max_options: 41 | print "\tToo many options (" + str(total_clique_options) + "), need < " + str(max_options) + ". Increasing compression rate and continuing.\n" 42 | return False 43 | 44 | 45 | g_start_state = mdp_distr.get_init_state() 46 | 47 | # Compute all directed options that transition between abstract states. 48 | options = [] 49 | state_pairs = [] 50 | random_policy = lambda s : random.choice(mdp_distr.get_actions()) 51 | # For each s_{a,1} s_{a,2} pair. 52 | for s_a in abs_states: 53 | for s_a_prime in abs_states: 54 | if not(s_a == s_a_prime): 55 | # Make a non-self loop option. 56 | init_predicate = EqPredicate(y=s_a, func=state_abstr.phi) 57 | term_predicate = EqPredicate(y=s_a_prime, func=state_abstr.phi) 58 | o = Option(init_predicate=init_predicate, 59 | term_predicate=term_predicate, 60 | policy=random_policy) 61 | options.append(o) 62 | state_pairs.append((s_a, s_a_prime)) 63 | 64 | elif incl_self_loops: 65 | # Self loop. 66 | init_predicate = EqPredicate(y=s_a, func=state_abstr.phi) 67 | term_predicate = NeqPredicate(y=s_a, func=state_abstr.phi) # Terminate in any other abstract state. 68 | o = Option(init_predicate=init_predicate, 69 | term_predicate=term_predicate, 70 | policy=random_policy) 71 | 72 | # Initialize with random policy, we'll update it later. 73 | options.append(o) 74 | state_pairs.append((s_a, s_a_prime)) 75 | 76 | 77 | print "\tMade", len(options), "options (formed clique over S_A)." 78 | print "\tPruning..." 79 | 80 | sys.stdout.flush() 81 | 82 | # Prune. 83 | pruned_option_set = _prune_non_directed_options(options, state_pairs, state_abstr, mdp_distr) 84 | 85 | print "\tFinished Pruning. Reduced to", len(pruned_option_set), "options." 86 | 87 | return pruned_option_set 88 | 89 | def _prune_non_directed_options(options, state_pairs, state_abstr, mdp_distr): 90 | ''' 91 | Args: 92 | Options(list) 93 | state_pairs (list) 94 | state_abstr (StateAbstraction) 95 | mdp_distr (MDPDistribution) 96 | 97 | Returns: 98 | (list of Options) 99 | 100 | Summary: 101 | Removes redundant options. That is, if o_1 goes from s_A1 to s_A2, and 102 | o_2 goes from s_A1 *through s_A2 to s_A3, then we get rid of o_2. 103 | ''' 104 | good_options = set([]) 105 | bad_options = set([]) 106 | transition_func = mdp_distr.get_all_mdps()[0].get_transition_func() 107 | 108 | # For each option we created, we'll check overlap. 
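    # For each candidate option: solve a miniature MDP whose goal is the option's terminal cluster (via value
    # iteration) to get its policy, then discard the option if its planned path cuts through the initiation
    # set of another option (see _check_overlap).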
109 | for i, o in enumerate(options): 110 | print "\t Option", i + 1, "of", len(options) 111 | pre_abs_state, post_abs_state = state_pairs[i] 112 | 113 | # Get init and terminal lower level states. 114 | ground_init_states = state_abstr.get_lower_states_in_abs_state(pre_abs_state) 115 | ground_term_states = state_abstr.get_lower_states_in_abs_state(post_abs_state) 116 | rand_init_g_state = random.choice(ground_init_states) 117 | 118 | # R and T for Option Mini MDP. 119 | def _directed_option_reward_lambda(s, a): 120 | s_prime = transition_func(s, a) 121 | return int(s_prime in ground_term_states and not s in ground_term_states) 122 | 123 | def new_trans_func(s, a): 124 | original = s.is_terminal() 125 | s.set_terminal(s in ground_term_states) 126 | s_prime = transition_func(s,a) 127 | # print s, s_prime, s.is_terminal(), s_prime.is_terminal(), pre_abs_state, post_abs_state, s == s_prime 128 | s.set_terminal(original) 129 | return s_prime 130 | 131 | if pre_abs_state == post_abs_state: 132 | # Self looping option. 133 | mini_mdp_init_states = defaultdict(list) 134 | 135 | # Self loop. Make an option per goal in the cluster. 136 | goal_mdps = [] 137 | goal_state_action_pairs = defaultdict(list) 138 | for i, mdp in enumerate(mdp_distr.get_all_mdps()): 139 | add = False 140 | 141 | # Check if there is a goal for this MDP in one of the ground states. 142 | for s_g in ground_term_states: 143 | for a in mdp.get_actions(): 144 | if mdp.get_reward_func()(s_g, a) > 0.0 and a not in goal_state_action_pairs[s_g]: 145 | goal_state_action_pairs[s_g].append(a) 146 | if isinstance(mdp, GridWorldMDP): 147 | goals = tuple(mdp.get_goal_locs()) 148 | else: 149 | goals = tuple(s_g) 150 | mini_mdp_init_states[goals].append(s_g) 151 | add = True 152 | 153 | if add: 154 | goal_mdps.append(mdp) 155 | 156 | # For each goal. 157 | for goal_mdp in goal_mdps: 158 | 159 | def goal_new_trans_func(s, a): 160 | original = s.is_terminal() 161 | s.set_terminal(s not in ground_term_states or original) 162 | s_prime = goal_mdp.get_transition_func()(s,a) 163 | s.set_terminal(original) 164 | return s_prime 165 | 166 | if isinstance(goal_mdp, GridWorldMDP): 167 | cluster_init_state = random.choice(mini_mdp_init_states[tuple(goal_mdp.get_goal_locs())]) 168 | else: 169 | cluster_init_state = random.choice(ground_init_states) 170 | 171 | # Make a new MDP. 172 | mini_mdp = MDP(actions=goal_mdp.get_actions(), 173 | init_state=cluster_init_state, 174 | transition_func=goal_new_trans_func, 175 | reward_func=goal_mdp.get_reward_func()) 176 | 177 | o_policy, mini_mdp_vi = _make_mini_mdp_option_policy(mini_mdp) 178 | 179 | # Make new option. 180 | new_option = Option(o.init_predicate, o.term_predicate, o_policy) 181 | new_option.set_name(str(ground_init_states[0]) + "-sl") 182 | good_options.add(new_option) 183 | 184 | 185 | continue 186 | else: 187 | # This is a non-self looping option. 188 | mini_mdp = MDP(actions=mdp_distr.get_actions(), 189 | init_state=rand_init_g_state, 190 | transition_func=new_trans_func, 191 | reward_func=_directed_option_reward_lambda) 192 | 193 | o_policy, mini_mdp_vi = _make_mini_mdp_option_policy(mini_mdp) 194 | # Compute overlap w.r.t. plans from each state. 195 | for init_g_state in ground_init_states: 196 | # Prune overlapping ones. 
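            # Plan from this ground start state; if the trajectory never passes through another option's
            # initiation set mid-way, keep the option (one acceptable start suffices, hence the break below).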
197 | plan, state_seq = mini_mdp_vi.plan(init_g_state) 198 | opt_name = str(ground_init_states[0]) + "-" + str(ground_term_states[0]) 199 | o.set_name(opt_name) 200 | options[i] = o 201 | 202 | if not _check_overlap(o, state_seq, options, bad_options): 203 | # Give the option the new directed policy and name. 204 | o.set_policy(o_policy) 205 | good_options.add(o) 206 | break 207 | else: 208 | # The option overlaps, don't include it. 209 | bad_options.add(o) 210 | 211 | return good_options 212 | 213 | def _make_mini_mdp_option_policy(mini_mdp): 214 | ''' 215 | Args: 216 | mini_mdp (MDP) 217 | 218 | Returns: 219 | Policy 220 | ''' 221 | # Solve the MDP defined by the terminal abstract state. 222 | mini_mdp_vi = ValueIteration(mini_mdp, delta=0.005, max_iterations=1000, sample_rate=30) 223 | iters, val = mini_mdp_vi.run_vi() 224 | 225 | o_policy_dict = make_dict_from_lambda(mini_mdp_vi.policy, mini_mdp_vi.get_states()) 226 | o_policy = PolicyFromDict(o_policy_dict) 227 | 228 | return o_policy.get_action, mini_mdp_vi 229 | 230 | def _check_overlap(option, state_seq, options, bad_options): 231 | ''' 232 | Args: 233 | state_seq (list of State) 234 | options 235 | 236 | Returns: 237 | (bool): If true, we should remove this option. 238 | ''' 239 | terminal_is_reachable = False 240 | bad_options = set(bad_options) 241 | 242 | for i, s_g in enumerate(state_seq): 243 | for o_prime in options: 244 | 245 | if o_prime in bad_options: 246 | continue 247 | 248 | is_in_middle = not (option.is_term_true(s_g) or option.is_init_true(s_g)) 249 | 250 | if is_in_middle and o_prime.is_init_true(s_g): 251 | # We should get rid of @option, because it's path goes through another init. 252 | return True 253 | 254 | # Only keep options whose terminal states are reachable from the initiation set. 255 | if option.is_term_true(s_g): 256 | terminal_is_reachable = True 257 | 258 | if not terminal_is_reachable: 259 | # Can't reach the terminal state. 260 | return True 261 | 262 | return False 263 | 264 | def compute_sub_opt_func_for_mdp_distr(mdp_distr): 265 | ''' 266 | Args: 267 | mdp_distr (dict) 268 | 269 | Returns: 270 | (list): Contains the suboptimality function for each MDP in mdp_distr. 271 | subopt: V^*(s) - Q^(s,a) 272 | ''' 273 | actions = mdp_distr.get_actions() 274 | sub_opt_funcs = [] 275 | 276 | i = 0 277 | for mdp in mdp_distr.get_mdps(): 278 | print "\t mdp", i + 1, "of", mdp_distr.get_num_mdps() 279 | vi = ValueIteration(mdp, delta=0.001, max_iterations=1000) 280 | iters, value = vi.run_vi() 281 | 282 | new_sub_opt_func = defaultdict(float) 283 | for s in vi.get_states(): 284 | max_q = float("-inf") 285 | for a in actions: 286 | next_q = vi.get_q_value(s, a) 287 | if next_q > max_q: 288 | max_q = next_q 289 | 290 | for a in actions: 291 | new_sub_opt_func[(s, a)] = max_q - vi.get_q_value(s,a) 292 | 293 | sub_opt_funcs.append(new_sub_opt_func) 294 | i+=1 295 | 296 | return sub_opt_funcs 297 | 298 | def _compute_agreement(sub_opt_funcs, mdp_distr, state, action, epsilon=0.00): 299 | ''' 300 | Args: 301 | sub_opt_funcs (list of dicts) 302 | mdp_distr (dict) 303 | state (simple_rl.State) 304 | action (str) 305 | epsilon (float) 306 | 307 | Returns: 308 | (list) 309 | 310 | Summary: 311 | Computes the MDPs for which @action is epsilon-optimal in @state. 
312 | ''' 313 | all_sub_opt_vals = [sof[(state, action)] for sof in sub_opt_funcs] 314 | eps_opt_mdps = [int(sov <= epsilon) for sov in all_sub_opt_vals] 315 | 316 | return eps_opt_mdps 317 | 318 | def add_next_option(mdp_distr, next_decis_state, sub_opt_funcs): 319 | ''' 320 | Args: 321 | 322 | Returns: 323 | (Option) 324 | ''' 325 | 326 | # Init func and terminal func. 327 | init_func = lambda s : s == next_decis_state 328 | term_func = lambda s : True 329 | term_func_states = [] 330 | 331 | # Misc. 332 | reachable_states = Queue.Queue() 333 | reachable_states.put(next_decis_state) 334 | visited_states = set([next_decis_state]) 335 | policy_dict = defaultdict(str) 336 | actions = mdp_distr.get_actions() 337 | transition_func = mdp_distr.get_mdps()[0].get_transition_func() 338 | 339 | # Tracks which MDPs share near-optimal action sequences. 340 | mdps_active = [1 for m in range(len(sub_opt_funcs))] 341 | 342 | while not reachable_states.empty(): 343 | # Pointers for this iteration. 344 | cur_state = reachable_states.get() 345 | next_action = random.choice(actions) 346 | max_agreement = 0 # agreement for this state. 347 | 348 | # Compute action with max agreement (num active MDPs with shared eps-opt action.) 349 | for a in actions: 350 | agreement_ls = _compute_agreement(sub_opt_funcs, mdp_distr, cur_state, a) 351 | active_agreement_ls = [mdps_active[i] & agreement_ls[i] for i in range(len(agreement_ls))] 352 | agreement = sum(active_agreement_ls) 353 | if agreement > max_agreement: 354 | next_action = a 355 | max_agreement = agreement 356 | 357 | # Set policy for this state to the action with maximal agreement. 358 | policy_dict[cur_state] = next_action 359 | max_agreement_ls = _compute_agreement(sub_opt_funcs, mdp_distr, cur_state, next_action) 360 | mdps_active = [mdps_active[i] & max_agreement_ls[i] for i in range(len(max_agreement_ls))] 361 | agreement = sum(mdps_active) 362 | 363 | # Move to the next state according to max agreement action. 364 | next_state = transition_func(cur_state, next_action) 365 | 366 | if agreement <= 2 or next_state.is_terminal(): 367 | term_func_states.append(next_state) 368 | 369 | if next_state not in visited_states: 370 | reachable_states.put(next_state) 371 | visited_states.add(next_state) 372 | 373 | if len(term_func_states): 374 | term_func_states.append(next_state) 375 | 376 | # Turn policy dict into a function and make the option. 377 | o = Option(init_func, term_func=term_func_states, policy=policy_dict) 378 | 379 | return o 380 | 381 | def make_greedy_options(mdp_distr): 382 | ''' 383 | Assumptions: 384 | Shared S, A, start state, T, gamma between all M in mdp_distr. 385 | ''' 386 | 387 | if isinstance(mdp_distr, MDP): 388 | print "Warning: attempting to create options for a single MDP." 389 | mdp_distr = {1.0:mdp_distr} 390 | 391 | # Grab relevant MDP distr. components. 392 | init_state = mdp_distr.keys()[0].get_init_state() 393 | transition_func = mdp_distr.keys()[0].get_transition_func() 394 | actions = mdp_distr.keys()[0].get_actions() 395 | 396 | # Setup data structures. 397 | print "Computing advantage functions." 398 | sub_opt_funcs = compute_sub_opt_func_for_mdp_distr(mdp_distr) 399 | decision_states = Queue.Queue() 400 | decision_states.put(init_state) 401 | new_aa = ActionAbstraction(options=actions, prim_actions=actions) 402 | 403 | visited_states = set([init_state]) 404 | # Loop over reachable states. 
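    # Greedy option construction: from each decision state, repeatedly pick the action that is epsilon-optimal
    # in the largest number of still-active MDPs; the option is cut off once agreement drops (<= 2 active MDPs)
    # or a terminal state is reached, and wherever it ends becomes the next decision state.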
405 | num_options = 0 406 | print "Learning:" 407 | while num_options < 2 and (not decision_states.empty()): 408 | print "\toption", num_options + 1 409 | # Add option as long as we have a decision state. 410 | # A decision state is a state where we don't have a good option. 411 | 412 | next_decis_state = decision_states.get() 413 | o = add_next_option(mdp_distr, next_decis_state, sub_opt_funcs) 414 | new_aa.add_option(o) 415 | num_options += 1 416 | new_state = o.act_until_terminal(next_decis_state, transition_func) 417 | if new_state not in visited_states: 418 | decision_states.put(new_state) 419 | visited_states.add(new_state) 420 | 421 | return new_aa 422 | 423 | def print_aa(action_abstr, state_space): 424 | ''' 425 | Args: 426 | action_abstr (ActionAbstraction) 427 | state_space (list of State) 428 | 429 | Summary: 430 | Prints out options in a convenient way. 431 | ''' 432 | 433 | options = action_abstr.get_actions() 434 | for o in options: 435 | inits = [s for s in state_space if o.is_init_true(s)] 436 | terms = [s for s in state_space if o.is_term_true(s)] 437 | print o 438 | print "\tinit:", 439 | for s in inits: 440 | print s, 441 | print 442 | print "\tterm:", 443 | for s in terms: 444 | print s, 445 | print 446 | print 447 | print -------------------------------------------------------------------------------- /chain.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Other imports. 4 | from simple_rl.agents import RMaxAgent, DelayedQAgent 5 | from simple_rl.run_experiments import run_agents_on_mdp 6 | from simple_rl.mdp import MDP 7 | from simple_rl.tasks.chain.ChainStateClass import ChainState 8 | from simple_rl.abstraction.AbstractionWrapperClass import AbstractionWrapper 9 | from simple_rl.abstraction.state_abs import indicator_funcs 10 | from abstraction_experiments import get_sa 11 | 12 | class BadChainMDP(MDP): 13 | 14 | ACTIONS = ["left", "right", "loop"] 15 | 16 | def __init__(self, gamma, kappa=0.001): 17 | MDP.__init__(self, BadChainMDP.ACTIONS, self._transition_func, self._reward_func, init_state=ChainState(1), gamma=gamma) 18 | self.num_states = 4 19 | self.kappa = kappa 20 | 21 | def _reward_func(self, state, action): 22 | ''' 23 | Args: 24 | state (State) 25 | action (str) 26 | statePrime 27 | 28 | Returns 29 | (float) 30 | ''' 31 | if state.is_terminal(): 32 | return 0 33 | elif action == "right" and state.num + 1 == self.num_states: 34 | return 1 # RMax. 35 | elif action == "loop" and state.num < self.num_states: 36 | return self.kappa 37 | else: 38 | return 0 39 | 40 | def _transition_func(self, state, action): 41 | ''' 42 | Args: 43 | state (State) 44 | action (str) 45 | 46 | Returns 47 | (State) 48 | ''' 49 | if state.is_terminal(): 50 | # Terminal, done. 51 | return state 52 | elif action == "right" and state.num + 1 == self.num_states: 53 | # Applied right in s2, move to terminal. 54 | terminal_state = ChainState(self.num_states) 55 | terminal_state.set_terminal(True) 56 | return terminal_state 57 | elif action == "right" and state.num < self.num_states - 1: 58 | # If in s0 or s1, move to s2. 59 | return ChainState(state.num + 1) 60 | elif action == "left" and state.num > 1: 61 | # If in s1, or s2, move left. 62 | return ChainState(state.num - 1) 63 | else: 64 | # Otherwise, stay in the same state. 65 | return state 66 | 67 | def __str__(self): 68 | return "Bad_chain" 69 | 70 | def main(): 71 | 72 | # Grab experiment params. 
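    # BadChainMDP: a four-state chain where taking "right" out of the penultimate state pays 1 and "loop" pays a
    # small kappa in non-terminal states; R-Max and Delayed-Q are each run with and without the phi_{Q_epsilon^*}
    # abstraction to reproduce the Figure 2 plot referenced in the README.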
73 | mdp = BadChainMDP(gamma=0.95, kappa=0.001) 74 | actions = mdp.get_actions() 75 | 76 | # ======================= 77 | # == Make Abstractions == 78 | # ======================= 79 | sa_q_eps = get_sa(mdp, indic_func=indicator_funcs._q_eps_approx_indicator, epsilon=0.1) 80 | 81 | # RMax Agents. 82 | rmax_agent = RMaxAgent(actions) 83 | abstr_rmax_agent = AbstractionWrapper(RMaxAgent, state_abstr=sa_q_eps, agent_params={"actions":actions}, name_ext="-$\\phi_{Q_\\epsilon^*}$") 84 | 85 | # Delayed Q Agents. 86 | del_q_agent = DelayedQAgent(actions) 87 | abstr_del_q_agent = AbstractionWrapper(DelayedQAgent, state_abstr=sa_q_eps, agent_params={"actions":actions}, name_ext="-$\\phi_{Q_\\epsilon^*}$") 88 | 89 | run_agents_on_mdp([rmax_agent, abstr_rmax_agent, del_q_agent, abstr_del_q_agent], mdp, instances=50, steps=250, episodes=1) 90 | 91 | if __name__ == "__main__": 92 | main() 93 | -------------------------------------------------------------------------------- /hierarch/ActionAbstractionStackClass.py: -------------------------------------------------------------------------------- 1 | # Python imports. 2 | from collections import defaultdict 3 | from os import path 4 | import sys 5 | 6 | # Other imports. 7 | parent_dir = path.dirname(path.dirname(path.abspath(__file__))) 8 | sys.path.append(parent_dir) 9 | from simple_rl.mdp.StateClass import State 10 | from simple_rl.mdp.MDPClass import MDP 11 | from action_abs.ActionAbstractionClass import ActionAbstraction 12 | 13 | class ActionAbstractionStack(ActionAbstraction): 14 | 15 | def __init__(self, list_of_aa, prim_actions, level=0): 16 | ''' 17 | Args: 18 | list_of_aa (list) 19 | ''' 20 | self.list_of_aa = list_of_aa 21 | self.level = level 22 | self.prim_actions = prim_actions 23 | ActionAbstraction.__init__(self, options=self.get_actions(level), prim_actions=prim_actions) 24 | 25 | def get_level(self): 26 | return self.level 27 | 28 | def get_num_levels(self): 29 | return len(self.list_of_aa) 30 | 31 | def get_aa_list(self): 32 | return self.list_of_aa 33 | 34 | def get_actions(self, level=None): 35 | if level is None: 36 | level = self.level 37 | elif level == -1: 38 | level = self.get_num_levels() 39 | elif level == 0: 40 | # If we're at level 0, let the agent act with primitives. 41 | return self._convert_to_options(self.prim_actions) 42 | 43 | return self.list_of_aa[level - 1].get_actions() 44 | 45 | def set_level(self, new_level): 46 | self.level = new_level 47 | 48 | def set_option_executing(self, option): 49 | self.cur_option = option 50 | self.is_cur_executing = True 51 | 52 | def act(self, agent, state_abstr_stack, ground_state, reward, level=None): 53 | ''' 54 | Args: 55 | agent (Agent) 56 | abstr_state (State) 57 | lower_state (State): One level down from abstr_state. 58 | reward (float) 59 | 60 | Returns: 61 | (str) 62 | ''' 63 | if level is None: 64 | level = self.level 65 | elif level == -1: 66 | level = self.get_num_levels() 67 | elif level == 0: 68 | # If we're at level 0, let the agent act with primitives. 69 | agent.actions = self.prim_actions 70 | return agent.act(ground_state, reward) 71 | 72 | abstr_state = state_abstr_stack.phi(ground_state, level) 73 | lower_state = state_abstr_stack.phi(ground_state, level - 1) 74 | 75 | # Calls agent update. 76 | lower_option = self.list_of_aa[level - 1].act(agent, abstr_state, lower_state, reward) 77 | level -= 1 78 | 79 | # Descend via options. 
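        # Each option at level l, when queried with a level (l - 1) state, returns
        # a level (l - 1) action; repeating the query down the stack until level 0
        # yields the primitive action that is finally returned.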
80 | while level > 0: 81 | lower_state = state_abstr_stack.phi(ground_state, level - 1) 82 | lower_option = lower_option.act(lower_state) 83 | level -= 1 84 | 85 | return lower_option 86 | 87 | def add_aa(self, new_aa): 88 | self.list_of_aa.append(new_aa) 89 | 90 | def print_action_spaces_sizes(self): 91 | print "Action Space Sizes:" 92 | print "\tl_0:", len(self.prim_actions) 93 | for i in xrange(len(self.list_of_aa)): 94 | print "\tl_" + str(i + 1) + ":", len(self.list_of_aa[i].get_actions()) 95 | print 96 | -------------------------------------------------------------------------------- /hierarch/DynamicHierarchyAgentClass.py: -------------------------------------------------------------------------------- 1 | # Python imports. 2 | from collections import defaultdict 3 | import numpy as np 4 | 5 | # Other imports. 6 | from HierarchyAgentClass import HierarchyAgent 7 | 8 | class DynamicHierarchyAgent(HierarchyAgent): 9 | 10 | def __init__(self, SubAgentClass, sa_stack, aa_stack, cur_level=0, name_ext="", auo=30): 11 | ''' 12 | Args: 13 | sa_stack (StateAbstractionStack) 14 | aa_stack (ActionAbstractionStack) 15 | cur_level (int): Must be in [0:len(state_abstr_stack)] 16 | ''' 17 | HierarchyAgent.__init__(self, SubAgentClass, sa_stack, aa_stack, cur_level=0, name_ext=name_ext) 18 | self.num_switches = 0 19 | self.num_actions_since_open = 0 20 | self.actions_until_open = auo 21 | 22 | def act(self, ground_state, reward): 23 | ''' 24 | Args: 25 | ground_state (State) 26 | reward (float) 27 | 28 | Return: 29 | (str) 30 | ''' 31 | 32 | if self.num_actions_since_open > self.actions_until_open and not self.action_abstr_stack.is_next_step_continuing_option(ground_state): 33 | # We're in a "decision" state, so change levels. 34 | new_level = int(not(bool(self.cur_level))) 35 | if self.cur_level != new_level: 36 | self.num_switches += 1 37 | self.set_level(new_level) 38 | self.num_actions_since_open = 0 39 | 40 | action = HierarchyAgent.act(self, ground_state, reward) 41 | 42 | self.num_actions_since_open += 1 43 | 44 | return action 45 | 46 | def _compute_max_v_hat_level(self, ground_state): 47 | ''' 48 | Args: 49 | ground_state (simple_rl.mdp.State) 50 | 51 | Returns: 52 | (int): The level with the highest value estimate. 53 | ''' 54 | 55 | if self.cur_level == 1: 56 | return 0 57 | else: 58 | max_q = float("-inf") 59 | best_lvl = 0 60 | for lvl in xrange(self.get_num_levels() + 1): 61 | abstr_state = self.state_abstr_stack.phi(ground_state, lvl) 62 | v_hat = self.agent.get_max_q_value(abstr_state) 63 | # print lvl, v_hat 64 | change_cost = 0.0009 * int(lvl != self.cur_level) 65 | if v_hat - change_cost > max_q: 66 | best_lvl = lvl 67 | max_q = v_hat 68 | return best_lvl 69 | 70 | def reset(self): 71 | print "num switches this instance:", self.num_switches 72 | self.num_switches = 0 73 | HierarchyAgent.reset(self) 74 | -------------------------------------------------------------------------------- /hierarch/HRMaxAgentClass.py: -------------------------------------------------------------------------------- 1 | ''' 2 | HRMaxAgentClass.py: Class for an RMaxAgent that uses a hierarchy controller 3 | ''' 4 | 5 | # Python imports. 6 | import random 7 | from collections import defaultdict 8 | 9 | # Local classes. 10 | from simple_rl.agents import RMaxAgent 11 | 12 | class HRMaxAgent(RMaxAgent): 13 | ''' 14 | Implementation for an R-Max Agent that uses a hierarchy 15 | during its planning phase. 
16 | ''' 17 | 18 | def __init__(self, actions, sa_stack, aa_stack, level=0, gamma=0.99, horizon=4, s_a_threshold=1): 19 | self.sa_stack = sa_stack 20 | self.aa_stack = aa_stack 21 | self.level = level 22 | RMaxAgent.__init__(self, actions=actions, gamma=gamma, horizon=horizon, s_a_threshold=s_a_threshold) 23 | self.name = "hrmax-h" + str(horizon) 24 | 25 | def act(self, ground_state, reward): 26 | if not self.aa_stack.is_next_step_continuing_option(ground_state): 27 | # If we're in a decision state, set the level, update, and pick a new action. 28 | self._set_level(ground_state) 29 | 30 | # Grab the landing state in the abstract. 31 | cur_abstr_state = self.sa_stack.phi(ground_state, self.level) 32 | 33 | # Update given s, a, r, s' : self.prev_state, self.prev_action, reward, state 34 | self.update(self.prev_state, self.prev_action, reward, cur_abstr_state) 35 | 36 | # Update actions. 37 | self.actions = self.aa_stack.get_actions(self.level) 38 | 39 | # Compute best action. 40 | action = self.get_max_q_action(ground_state) 41 | 42 | self.aa_stack.set_option_executing(action) 43 | 44 | # Update pointers. 45 | self.prev_action = action 46 | self.prev_state = cur_abstr_state 47 | else: 48 | # In the middle of computing an option. 49 | action = self.aa_stack.get_next_ground_action(ground_state) 50 | 51 | return action 52 | 53 | def _compute_max_qval_action_pair(self, ground_state, horizon=None, bootstrap=False): 54 | ''' 55 | Args: 56 | ground_state (State) 57 | horizon (int): Indicates the level of recursion depth for computing Q. 58 | 59 | Returns: 60 | (tuple) --> (float, str): where the float is the Qval, str is the action. 61 | ''' 62 | # If this is the first call, use the default horizon. 63 | horizon = self.horizon if horizon is None else horizon 64 | 65 | if horizon <= 0: 66 | r_max = float("-inf") 67 | best_a = self.actions[0] 68 | for a in self.actions: 69 | r = self._get_reward(ground_state, a) 70 | if r > r_max: 71 | best_a = a 72 | r_max = r 73 | return r, a 74 | 75 | # Update level and apply phi, omega. 76 | self._set_level(ground_state, horizon=horizon-1) 77 | 78 | self.actions = self.aa_stack.get_actions(level=self.level) 79 | decision_state = self.sa_stack.phi(ground_state, level=self.level) 80 | 81 | # Grab random initial action in case all equal 82 | best_action = random.choice(self.actions) 83 | max_q_val = self.get_q_value(decision_state, best_action, horizon) 84 | 85 | # Find best action (action w/ current max predicted Q value) 86 | for action in self.actions: 87 | q_s_a = self.get_q_value(decision_state, action, horizon) 88 | if q_s_a > max_q_val: 89 | max_q_val = q_s_a 90 | best_action = action 91 | 92 | return max_q_val, best_action 93 | 94 | def _compute_exp_future_return(self, ground_state, action, horizon=None): 95 | ''' 96 | Args: 97 | ground_state (State) 98 | action (str) 99 | horizon (int): Recursion depth to compxute Q 100 | 101 | Return: 102 | (float): Discounted expected future return from applying @action in @state. 103 | ''' 104 | # If this is the first call, use the default horizon. 105 | horizon = self.horizon if horizon is None else horizon 106 | 107 | self._set_level(ground_state, horizon=horizon-1) 108 | 109 | # Compute abstracted state. 
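        # The expectation below is estimated from empirical transition counts at
        # the abstract level: each observed next state is weighted by its relative
        # frequency and evaluated with a (horizon - 1) max-Q bootstrap.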
110 | abstr_state = self.sa_stack.phi(ground_state, self.level) 111 | 112 | next_state_dict = self.transitions[abstr_state][action] 113 | 114 | denominator = float(sum(next_state_dict.values())) 115 | state_weights = defaultdict(float) 116 | for next_state in next_state_dict.keys(): 117 | count = next_state_dict[next_state] 118 | state_weights[next_state] = (count / denominator) 119 | 120 | weighted_future_returns = [self.get_max_q_value(next_state, horizon-1) * state_weights[next_state] for next_state in next_state_dict.keys()] 121 | 122 | return sum(weighted_future_returns) 123 | 124 | def _set_level(self, ground_state, horizon=None): 125 | # If this is the first call, use the default horizon. 126 | horizon = self.horizon if horizon is None else horizon 127 | 128 | max_q = float("-inf") 129 | best_lvl = 0 130 | for lvl in xrange(self.sa_stack.get_num_levels() + 1): 131 | abstr_state = self.sa_stack.phi(ground_state, lvl) 132 | v_hat = self.get_max_q_value(abstr_state, horizon=horizon) 133 | if v_hat - (lvl * 0.001) > max_q: 134 | best_lvl = lvl 135 | max_q = v_hat 136 | 137 | self.level = best_lvl 138 | -------------------------------------------------------------------------------- /hierarch/HierarchicalValueIterationClass.py: -------------------------------------------------------------------------------- 1 | # Python imports. 2 | import Queue 3 | 4 | # Other imports. 5 | from simple_rl.utils import make_mdp 6 | from simple_rl.planning.ValueIterationClass import ValueIteration 7 | import hierarchy_helpers 8 | 9 | class HierarchicalValueIteration(ValueIteration): 10 | 11 | def __init__(self, mdp, rew_func_list, trans_func_list, sa_stack, aa_stack, name="hierarch_value_iter", delta=0.001, max_iterations=200, sample_rate=3): 12 | ''' 13 | Args: 14 | mdp (MDP) 15 | delta (float): After an iteration if VI, if no change more than @\delta has occurred, terminates. 16 | max_iterations (int): Hard limit for number of iterations. 17 | sample_rate (int): Determines how many samples from @mdp to take to estimate T(s' | s, a). 18 | horizon (int): Number of steps before terminating. 19 | ''' 20 | self.rew_func_list = rew_func_list 21 | self.trans_func_list = trans_func_list 22 | self.sa_stack = sa_stack 23 | self.aa_stack = aa_stack 24 | abstr_actions = [] 25 | 26 | for aa in self.aa_stack.get_aa_list(): 27 | abstr_actions += aa.get_actions() 28 | self.actions = mdp.get_actions() + abstr_actions 29 | 30 | ValueIteration.__init__(self, mdp, name=name, delta=delta, max_iterations=max_iterations, sample_rate=sample_rate) 31 | 32 | def _compute_matrix_from_trans_func(self): 33 | if self.has_computed_matrix: 34 | # We've already run this, just return. 35 | return 36 | 37 | self.trans_dict = defaultdict(lambda:defaultdict(lambda:defaultdict(float))) 38 | # K: state 39 | # K: a 40 | # K: s_prime 41 | # V: prob 42 | 43 | for s in self.get_states(): 44 | for a in self.actions: 45 | for sample in xrange(self.sample_rate): 46 | s_prime = self.transition_func(s, a) 47 | self.trans_dict[s][a][s_prime] += 1.0 / self.sample_rate 48 | 49 | self.has_computed_matrix = True 50 | 51 | def _compute_reachable_state_space(self): 52 | ''' 53 | Summary: 54 | Starting with @self.start_state, determines all reachable states 55 | and stores them in self.states. 
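            Seeds the queue with the initial state abstracted to each level of the
            stack, then expands it by sampling every action @self.sample_rate times.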
56 | ''' 57 | state_queue = Queue.Queue() 58 | 59 | for lvl in xrange(self.sa_stack.get_num_levels()): 60 | abstr_state = self.sa_stack.phi(self.init_state, lvl) 61 | self.states.add(abstr_state) 62 | state_queue.put(abstr_state) 63 | 64 | while not state_queue.empty(): 65 | s = state_queue.get() 66 | for a in self.actions: 67 | for samples in xrange(self.sample_rate): # Take @sample_rate samples to estimate E[V] 68 | next_state = self.transition_func(s,a) # Need to use T w.r.t. the abstract MDP... 69 | 70 | if next_state not in self.states: 71 | self.states.add(next_state) 72 | state_queue.put(next_state) 73 | 74 | self.reachability_done = True 75 | 76 | def main(): 77 | 78 | 79 | # ======================== 80 | # === Make Environment === 81 | # ======================== 82 | mdp_class = "four_room" 83 | environment = make_mdp.make_mdp_distr(mdp_class=mdp_class, grid_dim=10) 84 | actions = environment.get_actions() 85 | 86 | # ========================== 87 | # === Make SA, AA Stacks === 88 | # ========================== 89 | # sa_stack, aa_stack = aa_stack_h.make_random_sa_diropt_aa_stack(environment, max_num_levels=3) 90 | sa_stack, aa_stack = hierarchy_helpers.make_hierarchy(environment, num_levels=3) 91 | 92 | mdp = environment.sample() 93 | HVI = HierarchicalValueIteration(mdp, sa_stack, aa_stack) 94 | VI = ValueIteration(mdp) 95 | 96 | h_iters, h_val = HVI.run_vi() 97 | iters, val = VI.run_vi() 98 | 99 | print "H:", h_iters, h_val 100 | print "V:", iters, val 101 | 102 | 103 | if __name__ == "__main__": 104 | main() -------------------------------------------------------------------------------- /hierarch/HierarchyAgentClass.py: -------------------------------------------------------------------------------- 1 | 2 | # Other imports. 3 | from simple_rl.agents import Agent 4 | 5 | class HierarchyAgent(Agent): 6 | 7 | def __init__(self, SubAgentClass, sa_stack, aa_stack, cur_level=0, name_ext=""): 8 | ''' 9 | Args: 10 | sa_stack (StateAbstractionStack) 11 | aa_stack (ActionAbstractionStack) 12 | cur_level (int): Must be in [0:len(state_abstr_stack)] 13 | ''' 14 | # Setup the abstracted agent. 
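        # The wrapped SubAgentClass instance acts over whichever action set the
        # current level exposes: primitive actions at level 0, options otherwise.
        # Example construction (sketch, mirroring hierarchy_experiments.py):
        #   agent = HierarchyAgent(QLearnerAgent, sa_stack=sa_stack, aa_stack=aa_stack, cur_level=1)
        #   action = agent.act(ground_state, reward)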
15 | self.state_abstr_stack = sa_stack 16 | self.action_abstr_stack = aa_stack 17 | self.cur_level = cur_level 18 | self.agent = SubAgentClass(actions=self.get_cur_actions()) 19 | Agent.__init__(self, name=self.agent.name + "-hierarch" + name_ext, actions=self.get_cur_actions()) 20 | 21 | # -- Accessors -- 22 | 23 | def get_num_levels(self): 24 | return self.state_abstr_stack.get_num_levels() 25 | 26 | def get_cur_actions(self): 27 | if self.cur_level == 0: 28 | return self.action_abstr_stack.prim_actions 29 | 30 | return self.get_cur_action_abstr().get_actions() 31 | 32 | def get_cur_action_abstr(self): 33 | return self.action_abstr_stack.get_aa_list()[self.cur_level - 1] 34 | 35 | def get_cur_abstr_state(self, state): 36 | return self.state_abstr_stack.phi(state, self.cur_level) 37 | 38 | # -- Mutators -- 39 | 40 | def add_sa_aa_pair(self, sa, aa): 41 | self.state_abstr_stack.add_sa(sa) 42 | self.action_abstr_stack.add_aa(aa) 43 | 44 | def incr_level(self): 45 | self.cur_level = min(self.cur_level + 1, self.state_abstr_stack.get_num_levels()) 46 | 47 | def decr_level(self): 48 | self.cur_level = min(self.cur_level - 1, 0) 49 | 50 | def set_level(self, new_level): 51 | if new_level < 0 or new_level > self.get_num_levels(): 52 | print "HierarchyAgentError: the given level (" + str(new_level) +") exceeds the hierarchy height (" + str(self.get_num_levels()) + ")" 53 | quit() 54 | 55 | self.cur_level = new_level 56 | 57 | # -- Central Act Method -- 58 | 59 | def act(self, ground_state, reward): 60 | ''' 61 | Args: 62 | ground_state (State) 63 | reward (float) 64 | 65 | Return: 66 | (str) 67 | ''' 68 | # Give the SA stack, ground state, and reward to the AA stack. 69 | return self.action_abstr_stack.act(self.agent, self.state_abstr_stack, ground_state, reward, level=self.cur_level) 70 | 71 | # -- Reset -- 72 | 73 | def reset(self): 74 | self.agent.reset() 75 | for aa in self.action_abstr_stack.get_aa_list(): 76 | aa.reset() 77 | 78 | def end_of_episode(self): 79 | self.agent.end_of_episode() 80 | for aa in self.action_abstr_stack.get_aa_list(): 81 | aa.end_of_episode() 82 | 83 | def _reset_reward(self): 84 | self.agent._reset_reward() 85 | -------------------------------------------------------------------------------- /hierarch/HierarchyStateClass.py: -------------------------------------------------------------------------------- 1 | from simple_rl.mdp.StateClass import State 2 | 3 | ''' HierarchyStateClass.py: Contains the HierarchyState Class. ''' 4 | 5 | class HierarchyState(State): 6 | 7 | def __init__(self, data=[], is_terminal=False, level=0): 8 | self.level = level 9 | State.__init__(self, data=data, is_terminal=is_terminal) 10 | 11 | def get_level(self): 12 | return self.level 13 | 14 | def __str__(self): 15 | return State.__str__(self) + "-lvl=" + str(self.level) 16 | -------------------------------------------------------------------------------- /hierarch/RewardFuncClass.py: -------------------------------------------------------------------------------- 1 | # Python imports. 
2 | import random 3 | from collections import defaultdict 4 | 5 | class RewardFunc(object): 6 | 7 | def __init__(self, reward_func_lambda, state_space, action_space): 8 | self.reward_dict = make_dict_from_lambda(reward_func_lambda, state_space, action_space) 9 | 10 | def reward_func(self, state, action): 11 | return self.reward_dict[state][action] 12 | 13 | def make_dict_from_lambda(reward_func_lambda, state_space, action_space, sample_rate=1): 14 | reward_dict = defaultdict(lambda:defaultdict(float)) 15 | for s in state_space: 16 | for a in action_space: 17 | for i in xrange(sample_rate): 18 | reward_dict[s][a] = reward_func_lambda(s, a) / sample_rate 19 | 20 | return reward_dict -------------------------------------------------------------------------------- /hierarch/StateAbstractionStackClass.py: -------------------------------------------------------------------------------- 1 | # Python imports. 2 | from collections import defaultdict 3 | from os import path 4 | import sys 5 | 6 | # Other imports. 7 | from simple_rl.mdp.MDPClass import MDP 8 | parent_dir = path.dirname(path.dirname(path.abspath(__file__))) 9 | sys.path.append(parent_dir) 10 | from state_abs.StateAbstractionClass import StateAbstraction 11 | from HierarchyStateClass import HierarchyState 12 | 13 | 14 | ''' 15 | NOTE: Level 0 is the ground state space. Level 1 is the first abstracted, and so on. 16 | --> Therefore, the "level 0" state abstraction is the identity. 17 | ''' 18 | 19 | class StateAbstractionStack(StateAbstraction): 20 | 21 | def __init__(self, list_of_phi, level=0): 22 | ''' 23 | Args: 24 | list_of_phi (list) 25 | ''' 26 | self.list_of_phi = list_of_phi # list where each item is a dict, key:state, val:int. (int represents an abstract state). 27 | cur_phi = {} if len(self.list_of_phi) == 0 else self.list_of_phi[level] 28 | self.level = 0 29 | StateAbstraction.__init__(self, phi=cur_phi) 30 | 31 | def get_num_levels(self): 32 | return len(self.list_of_phi) 33 | 34 | def get_level(self): 35 | return self.level 36 | 37 | def set_level_to_max(self): 38 | self.level = self.get_num_levels() 39 | 40 | def set_level(self, new_level): 41 | if new_level > self.get_num_levels() or new_level < 0: 42 | print "StateAbstractionStack Error: given level (" + str(new_level) + ") is invalid. Must be between" + \ 43 | "0 and the number of levels in the stack (" + str(self.get_num_levels()) + ")." 44 | quit() 45 | self.level = new_level 46 | 47 | def phi(self, lower_state, level=None): 48 | ''' 49 | Args: 50 | lower_state (simple_rl.State) 51 | level (int) 52 | 53 | Returns: 54 | (simple_rl.State) 55 | 56 | Notes: 57 | level: 58 | 0 --> Ground 59 | 1 --> First abstract layer, and so on. 60 | ''' 61 | 62 | # Get the level to raise the state to. 63 | if level == None: 64 | # Defaults to whatever it's set to. 65 | level = self.level 66 | elif level == -1: 67 | # Grab the last one. 68 | level = self.get_num_levels() - 1 69 | 70 | if self.get_num_levels() == 0 or level == 0: 71 | # If there are no more levels, identity function. 72 | return lower_state 73 | 74 | # Suppose level = 1. Now we'll grab the phi in the first slot and abstract it. 75 | # Suppose level = 2. We abstract once, cur_level=1, abstract again, cur_level=2 (DONE). 76 | 77 | # Get the current state's level. 78 | if isinstance(lower_state, HierarchyState): 79 | cur_level = lower_state.get_level() 80 | else: 81 | cur_level = 0 82 | 83 | 84 | # if cur_level < level: 85 | # # Get the immediate abstracted state (one lvl higher). 
86 | # print "cur_level:", cur_level, level 87 | # s_a = self.list_of_phi[cur_level][lower_state] 88 | # cur_level += 1 89 | 90 | s_a = lower_state 91 | # Iterate until we're at the right level. 92 | while cur_level < level: 93 | s_a = self.list_of_phi[cur_level][s_a] 94 | cur_level += 1 95 | 96 | return s_a 97 | 98 | def get_abs_states(self): 99 | # For each ground state, get its abstract state. 100 | if self.level == 0: 101 | # If we're at level 0, identity. 102 | return self.get_ground_states() 103 | 104 | return set([abs_state for abs_state in set(self.list_of_phi[self.level - 1].values())]) 105 | 106 | def get_ground_states_in_abs_state(self, abs_state): 107 | ''' 108 | Args: 109 | abs_state (State) 110 | 111 | Returns: 112 | (list): Contains all ground states in the cluster. 113 | ''' 114 | return [s_g for s_g in self.get_ground_states() if self.phi(s_g, level=abs_state.get_level()) == abs_state] 115 | 116 | def get_lower_states_in_abs_state(self, abs_state): 117 | ''' 118 | Args: 119 | abs_state (State) 120 | 121 | Returns: 122 | (list): Contains all ground states in the cluster. 123 | ''' 124 | return [s_g for s_g in self.get_lower_states(level=abs_state.get_level()) if self.phi(s_g, level=abs_state.get_level()) == abs_state] 125 | 126 | 127 | def get_lower_states(self, level=None): 128 | if level == None: 129 | # Defaults to whatever it's set to. 130 | level = self.level 131 | elif level == -1: 132 | # Grab the last one. 133 | level = self.get_num_levels() 134 | elif level == 0: 135 | return self.get_ground_states() 136 | 137 | return self.list_of_phi[level - 1].keys() 138 | 139 | def get_ground_states(self): 140 | return self.list_of_phi[0].keys() 141 | 142 | def add_phi(self, new_phi): 143 | self.list_of_phi.append(new_phi) 144 | 145 | def remove_last_phi(self): 146 | self.list_of_phi = self.list_of_phi[:-1] 147 | 148 | def print_state_space_sizes(self): 149 | print "State Space Sizes:" 150 | print "\tl_0:", len(self.get_ground_states()) 151 | 152 | for i, phi in enumerate(self.list_of_phi): 153 | print "\tl_" + str(i + 1) + ":", len(set(phi.values())) 154 | print 155 | -------------------------------------------------------------------------------- /hierarch/TransitionFuncClass.py: -------------------------------------------------------------------------------- 1 | # Python imports. 
2 | import random 3 | import numpy as np 4 | from collections import defaultdict 5 | 6 | class TransitionFunc(object): 7 | 8 | def __init__(self, transition_func_lambda, state_space, action_space, sample_rate=1): 9 | self.transition_dict = make_dict_from_lambda(transition_func_lambda, state_space, action_space, sample_rate) 10 | 11 | def transition_func(self, state, action): 12 | return self.transition_dict[state][action].keys()[list(np.random.multinomial(1, self.transition_dict[state][action].values()).tolist()).index(1)] 13 | 14 | def make_dict_from_lambda(transition_func_lambda, state_space, action_space, sample_rate=1): 15 | transition_dict = defaultdict(lambda:defaultdict(lambda:defaultdict(int))) 16 | for s in state_space: 17 | for a in action_space: 18 | for i in xrange(sample_rate): 19 | s_prime = transition_func_lambda(s, a) 20 | transition_dict[s][a][s_prime] += (1.0 / sample_rate) 21 | 22 | return transition_dict -------------------------------------------------------------------------------- /hierarch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/david-abel/rl_abstraction/cfe628abbee2614ff873713dceb466b293aa7329/hierarch/__init__.py -------------------------------------------------------------------------------- /hierarch/action_abstr_stack_helpers.py: -------------------------------------------------------------------------------- 1 | # Python imports. 2 | from os import path 3 | import sys 4 | 5 | # Other imports. 6 | parent_dir = path.dirname(path.dirname(path.abspath(__file__))) 7 | sys.path.append(parent_dir) 8 | import state_abstr_stack_helpers as sa_stack_helpers 9 | from StateAbstractionStackClass import StateAbstractionStack 10 | from ActionAbstractionStackClass import ActionAbstractionStack 11 | from action_abs.ActionAbstractionClass import ActionAbstraction 12 | from action_abs import aa_helpers 13 | from simple_rl.utils import make_mdp 14 | 15 | 16 | def make_random_sa_diropt_aa_stack(mdp_distr, max_num_levels=3): 17 | ''' 18 | Args: 19 | mdp_distr (MDPDistribution) 20 | max_num_levels (int) 21 | 22 | Returns: 23 | (tuple): 24 | 1. StateAbstraction 25 | 2. ActionAbstraction 26 | ''' 27 | 28 | # Get Abstractions by iterating over compression ratio. 29 | cluster_size_ratio, ratio_decr = 0.3, 0.05 30 | 31 | while cluster_size_ratio > 0.001: 32 | print "Abstraction ratio:", cluster_size_ratio 33 | 34 | # Make State Abstraction stack. 35 | sa_stack = sa_stack_helpers.make_random_sa_stack(mdp_distr, cluster_size_ratio=cluster_size_ratio, max_num_levels=max_num_levels) 36 | sa_stack.print_state_space_sizes() 37 | 38 | # Make action abstraction stack. 39 | aa_stack = make_directed_options_aa_from_sa_stack(mdp_distr, sa_stack) 40 | 41 | if not aa_stack: 42 | # Too many options. Decrement and continue. 43 | cluster_size_ratio -= ratio_decr 44 | continue 45 | else: 46 | break 47 | 48 | return sa_stack, aa_stack 49 | 50 | # ---------------------- 51 | # -- Directed Options -- 52 | # ---------------------- 53 | 54 | def make_directed_options_aa_from_sa_stack(mdp_distr, sa_stack): 55 | ''' 56 | Args: 57 | mdp_distr (MDPDistribution) 58 | sa_stack (StateAbstractionStack) 59 | 60 | Returns: 61 | (ActionAbstraction) 62 | ''' 63 | 64 | aa_stack = ActionAbstractionStack(list_of_aa=[], prim_actions=mdp_distr.get_actions()) 65 | 66 | for level in xrange(1, sa_stack.get_num_levels() + 1): 67 | 68 | # Make directed options for the current level. 
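        # Raise the SA stack to this level and build directed options between its
        # abstract states; if too many options are produced, abort the whole stack
        # construction (the caller retries with a different compression ratio).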
69 | sa_stack.set_level(level) 70 | next_options = aa_helpers.get_directed_options_for_sa(mdp_distr, sa_stack, incl_self_loops=False) 71 | 72 | if not next_options: 73 | # Too many options, decrease abstracton ratio and continue. 74 | return False 75 | 76 | next_aa = ActionAbstraction(options=next_options, prim_actions=mdp_distr.get_actions()) 77 | aa_stack.add_aa(next_aa) 78 | 79 | return aa_stack 80 | 81 | 82 | def main(): 83 | # Make MDP Distribution. 84 | mdp_class = "four_room" 85 | environment = make_mdp.make_mdp_distr(mdp_class=mdp_class, grid_dim=10) 86 | 87 | make_random_sa_diropt_aa_stack(environment, max_num_levels=3) 88 | 89 | if __name__ == "__main__": 90 | main() 91 | -------------------------------------------------------------------------------- /hierarch/hierarchy_experiments.py: -------------------------------------------------------------------------------- 1 | 2 | # Other imports. 3 | from simple_rl.utils import make_mdp 4 | from simple_rl.agents import RandomAgent, QLearnerAgent, RMaxAgent 5 | from simple_rl.run_experiments import run_agents_multi_task 6 | from HierarchyAgentClass import HierarchyAgent 7 | from HRMaxAgentClass import HRMaxAgent 8 | from DynamicHierarchyAgentClass import DynamicHierarchyAgent 9 | import action_abstr_stack_helpers as aa_stack_h 10 | import hierarchy_helpers 11 | 12 | def main(): 13 | 14 | # ======================== 15 | # === Make Environment === 16 | # ======================== 17 | mdp_class = "four_room" 18 | gamma = 1.0 19 | environment = make_mdp.make_mdp_distr(mdp_class=mdp_class, step_cost=0.01, grid_dim=15, gamma=gamma) 20 | actions = environment.get_actions() 21 | 22 | 23 | # ========================== 24 | # === Make SA, AA Stacks === 25 | # ========================== 26 | sa_stack, aa_stack = hierarchy_helpers.make_hierarchy(environment, num_levels=2) 27 | 28 | # Debug. 29 | print "\n" + ("=" * 30) + "\n== Done making abstraction. ==\n" + ("=" * 30) + "\n" 30 | sa_stack.print_state_space_sizes() 31 | aa_stack.print_action_spaces_sizes() 32 | 33 | # =================== 34 | # === Make Agents === 35 | # =================== 36 | # baseline_agent = QLearnerAgent(actions) 37 | agent_class = QLearnerAgent 38 | baseline_agent = agent_class(actions, gamma=gamma) 39 | rand_agent = RandomAgent(actions) 40 | # hierarch_r_max = HRMaxAgent(actions, sa_stack=sa_stack, aa_stack=aa_stack) 41 | l0_hierarch_agent = HierarchyAgent(agent_class, sa_stack=sa_stack, aa_stack=aa_stack, cur_level=0, name_ext="-$l_0$") 42 | l1_hierarch_agent = HierarchyAgent(agent_class, sa_stack=sa_stack, aa_stack=aa_stack, cur_level=1, name_ext="-$l_1$") 43 | # l2_hierarch_agent = HierarchyAgent(agent_class, sa_stack=sa_stack, aa_stack=aa_stack, cur_level=2, name_ext="-$l_2$") 44 | dynamic_hierarch_agent = DynamicHierarchyAgent(agent_class, sa_stack=sa_stack, aa_stack=aa_stack, cur_level=1, name_ext="-$d$") 45 | # dynamic_rmax_hierarch_agent = DynamicHierarchyAgent(RMaxAgent, sa_stack=sa_stack, aa_stack=aa_stack, cur_level=1, name_ext="-$d$") 46 | 47 | print "\n" + ("=" * 26) 48 | print "== Running experiments. 
==" 49 | print "=" * 26 + "\n" 50 | 51 | # ====================== 52 | # === Run Experiment === 53 | # ====================== 54 | agents = [dynamic_hierarch_agent, baseline_agent] 55 | 56 | run_agents_multi_task(agents, environment, task_samples=10, steps=20000, episodes=1, reset_at_terminal=True) 57 | 58 | if __name__ == "__main__": 59 | main() 60 | -------------------------------------------------------------------------------- /hierarch/hierarchy_helpers.py: -------------------------------------------------------------------------------- 1 | # Python imports. 2 | from os import path 3 | import sys 4 | 5 | # Other imports. 6 | parent_dir = path.dirname(path.dirname(path.abspath(__file__))) 7 | sys.path.append(parent_dir) 8 | import make_abstr_mdp 9 | from state_abs import sa_helpers, indicator_funcs 10 | from action_abs import aa_helpers 11 | from action_abs.ActionAbstractionClass import ActionAbstraction 12 | from simple_rl.utils import make_mdp 13 | from HierarchyStateClass import HierarchyState 14 | from StateAbstractionStackClass import StateAbstractionStack 15 | from ActionAbstractionStackClass import ActionAbstractionStack 16 | 17 | def make_hierarchy(mdp_distr, num_levels): 18 | ''' 19 | Args: 20 | mdp_distr (MDPDistribution) 21 | num_levels (int) 22 | 23 | Returns: 24 | (tuple) 25 | 1. StateAbstractionStack 26 | 2. ActionAbstractionStack 27 | 28 | Notes: 29 | A one layer hierarchy is *flat* (that is, just uses the ground MDP). A 30 | two layer hierarchy has a single abstract level and the ground level. 31 | ''' 32 | 33 | if num_levels <= 0: 34 | print "(hiearchy_helpers.py) Error: @num_levels must be > 0 (given value: " + str(num_levels) + ")." 35 | quit() 36 | 37 | sa_stack = StateAbstractionStack(list_of_phi=[]) 38 | aa_stack = ActionAbstractionStack(list_of_aa=[], prim_actions=mdp_distr.get_actions()) 39 | epsilon = 0.0 40 | 41 | for i in xrange(1, num_levels): 42 | print "\n" + "=" * 20 + "\n== Making layer " + str(i + 1) + " ==\n" + ("=" * 20) + "\n" 43 | sa_stack, aa_stack, epsilon = add_layer(mdp_distr, sa_stack, aa_stack, init_epsilon=epsilon) 44 | epsilon += 0.10 45 | 46 | return sa_stack, aa_stack 47 | 48 | def add_layer(mdp_distr, sa_stack, aa_stack, init_epsilon=0.0): 49 | ''' 50 | Args: 51 | mdp_distr (MDPDistribution) 52 | sa_stack (StateAbstractionStack) 53 | aa_stack (ActionAbstractionStack) 54 | init_epsilon (float) 55 | 56 | Returns: 57 | (tuple): 58 | 1. StateAbstractionStack 59 | 2. ActionAbstractionStack 60 | 3. (float): Final epsilon value. 61 | ''' 62 | 63 | # Get next abstractions by iterating over compression ratio. 64 | epsilon, epsilon_incr = init_epsilon, 0.005 65 | 66 | while epsilon < 1.0: 67 | print "Abstraction rate (epsilon):", epsilon 68 | 69 | # Set level to the largest shared between sa_stack and aa_stack. 70 | abstr_mdp_level = min(sa_stack.get_num_levels(), aa_stack.get_num_levels()) 71 | sa_stack.set_level(abstr_mdp_level) 72 | aa_stack.set_level(abstr_mdp_level) 73 | 74 | # Add layer to state abstraction stack. 75 | sa_stack = add_layer_to_sa_stack(mdp_distr, sa_stack, aa_stack, epsilon) 76 | 77 | # Add layer to action abstraction stack. 78 | aa_stack, is_too_many_options = add_layer_to_aa_stack(mdp_distr, sa_stack, aa_stack) 79 | 80 | if is_too_many_options: 81 | # Too many options. Decrement and continue. 
82 | epsilon += epsilon_incr 83 | sa_stack.remove_last_phi() 84 | continue 85 | else: 86 | break 87 | 88 | return sa_stack, aa_stack, epsilon 89 | 90 | def add_layer_to_sa_stack(mdp_distr, sa_stack, aa_stack, epsilon): 91 | ''' 92 | Args: 93 | mdp_distr (MDPDistribution) 94 | sa_stack (StateAbstractionStack) 95 | aa_stack (ActionAbstractionStack) 96 | epsilon (float) 97 | 98 | Returns: 99 | (StateAbstractionStack) 100 | ''' 101 | 102 | # Check stack height. 103 | if sa_stack.get_num_levels() > 0: 104 | # Make abstract MDPs to compute higher level abstractions. 105 | abstr_mdp_distr = make_abstr_mdp.make_abstr_mdp_distr_multi_level(mdp_distr, sa_stack, aa_stack) 106 | else: 107 | abstr_mdp_distr = mdp_distr 108 | 109 | # Hand coded four room. 110 | new_sa = sa_helpers.make_sa(mdp_distr, indic_func=indicator_funcs._four_rooms) 111 | 112 | # Make new phi. 113 | # new_sa = sa_helpers.make_multitask_sa(abstr_mdp_distr, epsilon=epsilon) 114 | new_phi = _convert_abstr_states(new_sa._phi, sa_stack.get_num_levels() + 1) 115 | sa_stack.add_phi(new_phi) 116 | 117 | return sa_stack 118 | 119 | def _convert_abstr_states(phi_dict, level): 120 | ''' 121 | Args: 122 | phi_dict (dict) 123 | level (int) 124 | 125 | Returns: 126 | (phi_dict) 127 | 128 | Summary: 129 | Translates int based abstract states to HierarchyStates (which track their own level). 130 | ''' 131 | for key in phi_dict.keys(): 132 | state_num = phi_dict[key] 133 | phi_dict[key] = HierarchyState(data=state_num, level=level) 134 | 135 | return phi_dict 136 | 137 | def add_layer_to_aa_stack(mdp_distr, sa_stack, aa_stack): 138 | ''' 139 | Args: 140 | mdp_distr (MDPDistribution) 141 | sa_stack (StateAbstractionStack) 142 | aa_stack (ActionAbstractionStack) 143 | 144 | Returns: 145 | (tuple): 146 | 1. (ActionAbstractionStack) 147 | 2. (MDPDistribution) 148 | 3. (bool) 149 | ''' 150 | if aa_stack.get_num_levels() > 0: 151 | abstr_mdp_distr = make_abstr_mdp.make_abstr_mdp_distr_multi_level(mdp_distr, sa_stack, aa_stack) 152 | else: 153 | abstr_mdp_distr = mdp_distr 154 | 155 | # Make options for the level + 1 height. 156 | sa_stack.set_level_to_max() 157 | next_options = aa_helpers.get_directed_options_for_sa(abstr_mdp_distr, sa_stack, incl_self_loops=False, max_options=512 / (aa_stack.get_num_levels() + 1)) 158 | 159 | if not next_options: 160 | # Too many options, decrease abstracton ratio and continue. 161 | return aa_stack, True 162 | 163 | next_aa = ActionAbstraction(options=next_options, prim_actions=mdp_distr.get_actions()) 164 | 165 | aa_stack.add_aa(next_aa) 166 | return aa_stack, False 167 | 168 | def main(): 169 | 170 | # ====================== 171 | # == Make Environment == 172 | # ====================== 173 | mdp_class = "four_room" 174 | environment = make_mdp.make_mdp_distr(mdp_class=mdp_class, grid_dim=7) 175 | actions = environment.get_actions() 176 | 177 | # ==================== 178 | # == Make Hierarchy == 179 | # ==================== 180 | sa_stack, aa_stack = make_hierarchy(environment, num_levels=3) 181 | 182 | 183 | if __name__ == "__main__": 184 | main() -------------------------------------------------------------------------------- /hierarch/make_abstr_mdp.py: -------------------------------------------------------------------------------- 1 | # Python imports. 2 | from collections import defaultdict 3 | import numpy as np 4 | 5 | # Other imports. 
6 | from simple_rl.planning import ValueIteration 7 | from simple_rl.mdp import MDP 8 | from simple_rl.mdp import MDPDistribution 9 | from RewardFuncClass import RewardFunc 10 | from TransitionFuncClass import TransitionFunc 11 | 12 | # ------------------ 13 | # -- Single Level -- 14 | # ------------------ 15 | 16 | def make_abstr_mdp(mdp, state_abstr, action_abstr, sample_rate=25): 17 | ''' 18 | Args: 19 | mdp (MDP) 20 | state_abstr (StateAbstraction) 21 | action_abstr (ActionAbstraction) 22 | sample_rate (int): Sample rate for computing the abstract R and T. 23 | 24 | Returns: 25 | (MDP) 26 | ''' 27 | 28 | # Grab ground state space. 29 | vi = ValueIteration(mdp) 30 | state_space = vi.get_states() 31 | 32 | # Make abstract reward and transition functions. 33 | def abstr_reward_lambda(abstr_state, abstr_action): 34 | # Get relevant MDP components from the lower MDP. 35 | lower_states = state_abstr.get_lower_states_in_abs_state(abstr_state) 36 | lower_reward_func = mdp.get_reward_func() 37 | lower_trans_func = mdp.get_transition_func() 38 | 39 | # Compute reward. 40 | total_reward = 0 41 | for ground_s in lower_states: 42 | for sample in xrange(sample_rate): 43 | s_prime, reward = abstr_action.rollout(ground_s, lower_reward_func, lower_trans_func) 44 | total_reward += float(reward) / (len(lower_states) * sample_rate) # Add weighted reward. 45 | 46 | # print "~"*20 47 | # print "R_A:", abstr_state, abstr_action, total_reward 48 | # print "~"*20 49 | 50 | return total_reward 51 | 52 | def abstr_transition_lambda(abstr_state, abstr_action): 53 | # print "Abstr Transition Func:" 54 | # print "\t abstr_state:", abstr_state 55 | # Get relevant MDP components from the lower MDP. 56 | lower_states = state_abstr.get_lower_states_in_abs_state(abstr_state) 57 | lower_reward_func = mdp.get_reward_func() 58 | lower_trans_func = mdp.get_transition_func() 59 | 60 | # Compute next state distribution. 61 | s_prime_prob_dict = defaultdict(int) 62 | total_reward = 0 63 | for ground_s in lower_states: 64 | for sample in xrange(sample_rate): 65 | s_prime, reward = abstr_action.rollout(ground_s, lower_reward_func, lower_trans_func) 66 | s_prime_prob_dict[s_prime] += (1.0 / (len(lower_states) * sample_rate)) # Weighted average. 67 | 68 | # Form distribution and sample s_prime. 69 | end_ground_state = s_prime_prob_dict.keys()[list(np.random.multinomial(1, s_prime_prob_dict.values()).tolist()).index(1)] 70 | end_abstr_state = state_abstr.phi(end_ground_state, level=abstr_state.get_level()) 71 | 72 | return end_abstr_state 73 | 74 | 75 | # Make the components of the MDP. 76 | abstr_init_state = state_abstr.phi(mdp.get_init_state()) 77 | abstr_action_space = action_abstr.get_actions() 78 | abstr_state_space = state_abstr.get_abs_states() 79 | abstr_reward_func = RewardFunc(abstr_reward_lambda, abstr_state_space, abstr_action_space) 80 | abstr_transition_func = TransitionFunc(abstr_transition_lambda, abstr_state_space, abstr_action_space, sample_rate=sample_rate) 81 | 82 | # Make the MDP. 
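    # The abstract MDP below is assembled from the tabulated estimates above:
    # R and T were built by Monte Carlo rollouts of each option from every
    # lower-level state in the corresponding cluster.
    # NOTE: gamma is hard-coded to 0.5 here rather than taken from @mdp.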
83 | abstr_mdp = MDP(actions=abstr_action_space, 84 | init_state=abstr_init_state, 85 | reward_func=abstr_reward_func.reward_func, 86 | transition_func=abstr_transition_func.transition_func, 87 | gamma=0.5) 88 | 89 | return abstr_mdp 90 | 91 | def make_abstr_mdp_distr(mdp_distr, state_abstr, action_abstr): #, step_cost=0.1): 92 | ''' 93 | Args: 94 | mdp_distr (MDPDistribution) 95 | state_abstr (StateAbstraction) 96 | action_abstr (ActionAbstraction) 97 | 98 | Returns: 99 | (MDPDistribution) 100 | ''' 101 | 102 | # Loop through old mdps and abstract. 103 | mdp_distr_dict = {} 104 | for mdp in mdp_distr.get_all_mdps(): 105 | abstr_mdp = make_abstr_mdp(mdp, state_abstr, action_abstr) #, step_cost=step_cost) 106 | prob_of_abstr_mdp = mdp_distr.get_prob_of_mdp(mdp) 107 | mdp_distr_dict[abstr_mdp] = prob_of_abstr_mdp 108 | 109 | return MDPDistribution(mdp_distr_dict) 110 | 111 | # ----------------- 112 | # -- Multi Level -- 113 | # ----------------- 114 | 115 | def make_abstr_mdp_multi_level(mdp, state_abstr_stack, action_abstr_stack, sample_rate=5): 116 | ''' 117 | Args: 118 | mdp (MDP) 119 | state_abstr_stack (StateAbstractionStack) 120 | action_abstr_stack (ActionAbstractionStack) 121 | sample_rate (int): Sample rate for computing the abstract R and T. 122 | 123 | Returns: 124 | (MDP) 125 | ''' 126 | mdp_level = min(state_abstr_stack.get_num_levels(), action_abstr_stack.get_num_levels()) 127 | 128 | for i in xrange(1, mdp_level + 1): 129 | state_abstr_stack.set_level(i) 130 | action_abstr_stack.set_level(i) 131 | mdp = make_abstr_mdp(mdp, state_abstr_stack, action_abstr_stack, sample_rate) 132 | 133 | return mdp 134 | 135 | 136 | def make_abstr_mdp_distr_multi_level(mdp_distr, state_abstr, action_abstr): 137 | ''' 138 | Args: 139 | mdp_distr (MDPDistribution) 140 | state_abstr (StateAbstraction) 141 | action_abstr (ActionAbstraction) 142 | 143 | Returns: 144 | (MDPDistribution) 145 | ''' 146 | 147 | # Loop through old mdps and abstract. 148 | mdp_distr_dict = {} 149 | for mdp in mdp_distr.get_all_mdps(): 150 | abstr_mdp = make_abstr_mdp_multi_level(mdp, state_abstr, action_abstr) 151 | prob_of_abstr_mdp = mdp_distr.get_prob_of_mdp(mdp) 152 | mdp_distr_dict[abstr_mdp] = prob_of_abstr_mdp 153 | 154 | return MDPDistribution(mdp_distr_dict) 155 | 156 | def _rew_dict_from_lambda(input_lambda, state_space, action_space, sample_rate): 157 | result_dict = defaultdict(lambda:defaultdict(float)) 158 | for s in state_space: 159 | for a in action_space: 160 | for i in xrange(sample_rate): 161 | result_dict[s][a] = input_lambda(s,a) / sample_rate 162 | 163 | return result_dict 164 | 165 | -------------------------------------------------------------------------------- /hierarch/state_abstr_stack_helpers.py: -------------------------------------------------------------------------------- 1 | # Python imports. 2 | import random 3 | from os import path 4 | import sys 5 | 6 | # Other imports. 
7 | from HierarchyStateClass import HierarchyState 8 | from simple_rl.utils import make_mdp 9 | from simple_rl.planning.ValueIterationClass import ValueIteration 10 | parent_dir = path.dirname(path.dirname(path.abspath(__file__))) 11 | sys.path.append(parent_dir) 12 | from state_abs.StateAbstractionClass import StateAbstraction 13 | from state_abs import sa_helpers 14 | from StateAbstractionStackClass import StateAbstractionStack 15 | import make_abstr_mdp 16 | 17 | 18 | 19 | # ---------------------------------- 20 | # -- Make State Abstraction Stack -- 21 | # ---------------------------------- 22 | 23 | def make_random_sa_stack(mdp_distr, cluster_size_ratio=0.5, max_num_levels=2): 24 | ''' 25 | Args: 26 | mdp_distr (MDPDistribution) 27 | cluster_size_ratio (float): A float in (0,1) that determines the size of the abstract state space. 28 | max_num_levels (int): Determines the _total_ number of levels in the hierarchy (includes ground). 29 | 30 | Returns: 31 | (StateAbstraction) 32 | ''' 33 | 34 | # Get ground state space. 35 | vi = ValueIteration(mdp_distr.get_all_mdps()[0], delta=0.0001, max_iterations=5000) 36 | ground_state_space = vi.get_states() 37 | sa_stack = StateAbstractionStack(list_of_phi=[]) 38 | 39 | # Each loop adds a stack. 40 | for i in xrange(max_num_levels - 1): 41 | 42 | # Grab curent state space (at level i). 43 | cur_state_space = _get_level_i_state_space(ground_state_space, sa_stack, i) 44 | cur_state_space_size = len(cur_state_space) 45 | 46 | if int(cur_state_space_size / cluster_size_ratio) <= 1: 47 | # The abstract is as small as it can get. 48 | break 49 | 50 | # Add the mapping. 51 | new_phi = {} 52 | for s in cur_state_space: 53 | new_phi[s] = HierarchyState(data=random.randint(1, max(int(cur_state_space_size * cluster_size_ratio), 1)), level=i + 1) 54 | 55 | if len(set(new_phi.values())) <= 1: 56 | # The abstract is as small as it can get. 57 | break 58 | 59 | # Add the sa to the stack. 60 | sa_stack.add_phi(new_phi) 61 | 62 | return sa_stack 63 | 64 | def _get_level_i_state_space(ground_state_space, state_abstr_stack, level): 65 | ''' 66 | Args: 67 | mdp_distr (MDPDistribution) 68 | state_abstr_stack (StateAbstractionStack) 69 | level (int) 70 | 71 | Returns: 72 | (list) 73 | ''' 74 | level_i_state_space = set([]) 75 | for s in ground_state_space: 76 | level_i_state_space.add(state_abstr_stack.phi(s, level)) 77 | 78 | return list(level_i_state_space) 79 | 80 | 81 | def main(): 82 | # Make MDP Distribution. 
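    # Quick demo: build a four-room MDP distribution and a random five-level
    # SA stack, then print the per-level state space sizes.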
83 | mdp_class = "four_room" 84 | environment = make_mdp.make_mdp_distr(mdp_class=mdp_class, grid_dim=10) 85 | 86 | sa_stack = make_random_sa_stack(environment, max_num_levels=5) 87 | sa_stack.print_state_space_sizes() 88 | 89 | 90 | if __name__ == "__main__": 91 | main() -------------------------------------------------------------------------------- /hierarch_rooms.txt: -------------------------------------------------------------------------------- 1 | gw-------w-w----------g 2 | ---w---w---w-w--w------ 3 | -w---w---w---w--w-w---- 4 | ---w---w---w-w--w------ 5 | -w---w---w-w-w--w-w---- 6 | ---w---w---w-w--w-w---- 7 | gw---w---w-w----w-----g 8 | wwww-wwwwwwwwwwww-wwwww 9 | ----------gw----------- 10 | ----w---w--wwwwwwwwwww- 11 | --w-w-www--w---------w- 12 | --w-w-www-------g----w- 13 | --w-w-www--w---------w- 14 | ----w------w-wwwww-www- 15 | a----------w----------g -------------------------------------------------------------------------------- /octogrid.txt: -------------------------------------------------------------------------------- 1 | wwwwgwgwgwwww 2 | wwww-w-w-wwww 3 | wwww-w-w-wwww 4 | wwww-w-w-wwww 5 | g-----------g 6 | wwww-----wwww 7 | g-----a-----g 8 | wwww-----wwww 9 | g-----------g 10 | wwww-w-w-wwww 11 | wwww-w-w-wwww 12 | wwww-w-w-wwww 13 | wwwwgwgwgwwww -------------------------------------------------------------------------------- /pblocks_grid.txt: -------------------------------------------------------------------------------- 1 | --w-----w---w----- 2 | --------w--------- 3 | --w-----w---w----- 4 | --w-----w---w----- 5 | wwwww-wwwwwwwww-ww 6 | ---w----w----w---- 7 | ---w---------w---- 8 | --------w--------- 9 | wwwwwwwww--------- 10 | w-------wwwwwww-ww 11 | --w-----w---w----- 12 | --------w--------- 13 | --w---------w----- 14 | --w-----w---w----- 15 | wwwww-wwwwwwwww-ww 16 | ---w-----w---w---- 17 | ---w-----w---w---- 18 | ---------w-------- -------------------------------------------------------------------------------- /run_icml_learning_experiments.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Python imports. 4 | import subprocess 5 | 6 | # Other imports. 7 | from simple_rl.agents import QLearningAgent, DelayedQAgent 8 | 9 | def spawn_subproc(task, samples, steps, episodes, grid_dim, agent_class): 10 | ''' 11 | Args: 12 | task (str) 13 | samples (int) 14 | steps (int) 15 | steps (int) 16 | grid_dim (int) 17 | 18 | Summary: 19 | Spawns a child subprocess to run the experiment. 20 | ''' 21 | cmd = ['./simple_sa_experiments.py', \ 22 | '-task=' + str(task), \ 23 | '-samples=' + str(samples), \ 24 | '-episodes=' + str(episodes), 25 | '-steps=' + str(steps), \ 26 | '-grid_dim=' + str(grid_dim), \ 27 | '-agent=' + str(agent_class)] 28 | 29 | subprocess.Popen(cmd) 30 | 31 | def main(): 32 | 33 | episodes = 100 34 | samples = 100 35 | 36 | # Color. 37 | # Figure 3a 38 | spawn_subproc(task="color", samples=samples, episodes=episodes, steps=250, grid_dim=11, agent_class=DelayedQAgent) 39 | # Figure 3b 40 | # spawn_subproc(task="color", samples=samples, episodes=episodes, steps=250, grid_dim=11, agent_class=QLearningAgent) 41 | 42 | # Four rooms. 
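    # Uncomment the call below to run the four-room learning experiment instead.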
43 | # spawn_subproc(task="four_room", samples=25, episodes=100, steps=500, grid_dim=30, agent_class=DelayedQAgent) 44 | 45 | 46 | if __name__ == "__main__": 47 | main() -------------------------------------------------------------------------------- /run_icml_planning_experiments.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Python imports. 4 | import subprocess 5 | 6 | # Other imports. 7 | 8 | def spawn_subproc(): 9 | ''' 10 | Args: 11 | 12 | Summary: 13 | Spawns a child subprocess to run the experiment. 14 | ''' 15 | cmd = ['./simple_planning_experiments.py'] 16 | 17 | subprocess.Popen(cmd) 18 | 19 | def main(): 20 | spawn_subproc() 21 | 22 | 23 | if __name__ == "__main__": 24 | main() -------------------------------------------------------------------------------- /simple_planning_experiments.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Python imports. 4 | import os 5 | import time 6 | 7 | # Other imports. 8 | from utils import make_mdp 9 | from simple_rl.planning import ValueIteration 10 | from utils.AbstractValueIterationClass import AbstractValueIteration 11 | from state_abs import indicator_funcs as ind_funcs 12 | from abstraction_experiments import get_sa 13 | 14 | 15 | def clear_files(dir_name): 16 | ''' 17 | Args: 18 | dir_name (str) 19 | 20 | Summary: 21 | Removes all csv files in @dir_name. 22 | ''' 23 | for extension in ["iters", "times"]: 24 | dir_w_extension = os.path.join(dir_name, extension) # , mdp_type) + ".csv" 25 | if not os.path.exists(dir_w_extension): 26 | os.makedirs(dir_w_extension) 27 | 28 | for mdp_type in ["vi", "vi-$\phi_{Q_d^*}$"]: 29 | if os.path.exists(os.path.join(dir_w_extension, mdp_type) + ".csv"): 30 | os.remove(os.path.join(dir_w_extension, mdp_type) + ".csv") 31 | 32 | 33 | def write_datum(file_name, datum): 34 | ''' 35 | Args: 36 | file_name (str) 37 | datum (object) 38 | ''' 39 | out_file = open(file_name, "a+") 40 | out_file.write(str(datum) + ",") 41 | out_file.close() 42 | 43 | 44 | def main(): 45 | # Grab experiment params. 46 | # Switch between Upworld and Trench 47 | mdp_class = "upworld" 48 | # mdp_class = "trench" 49 | grid_lim = 20 if mdp_class == 'upworld' else 7 50 | gamma = 0.95 51 | vanilla_file = "vi.csv" 52 | sa_file = "vi-$\phi_{Q_d^*}.csv" 53 | file_prefix = "results/planning-" + mdp_class + "/" 54 | clear_files(dir_name=file_prefix) 55 | 56 | for grid_dim in xrange(3, grid_lim): 57 | # ====================== 58 | # == Make Environment == 59 | # ====================== 60 | environment = make_mdp.make_mdp(mdp_class=mdp_class, grid_dim=grid_dim) 61 | environment.set_gamma(gamma) 62 | 63 | # ======================= 64 | # == Make Abstractions == 65 | # ======================= 66 | sa_qds = get_sa(environment, indic_func=ind_funcs._q_disc_approx_indicator, epsilon=0.01) 67 | 68 | # ============ 69 | # == Run VI == 70 | # ============ 71 | vanilla_vi = ValueIteration(environment, delta=0.0001, sample_rate=15) 72 | sa_vi = AbstractValueIteration(ground_mdp=environment, state_abstr=sa_qds) 73 | 74 | print "Running VIs." 
75 | start_time = time.clock() 76 | vanilla_iters, vanilla_val = vanilla_vi.run_vi() 77 | vanilla_time = round(time.clock() - start_time, 2) 78 | 79 | start_time = time.clock() 80 | sa_iters, sa_val = sa_vi.run_vi() 81 | sa_time = round(time.clock() - start_time, 2) 82 | 83 | print "vanilla", vanilla_iters, vanilla_val, vanilla_time 84 | print "sa:", sa_iters, sa_val, sa_time 85 | 86 | write_datum(file_prefix + "iters/" + vanilla_file, vanilla_iters) 87 | write_datum(file_prefix + "iters/" + sa_file, sa_iters) 88 | 89 | write_datum(file_prefix + "times/" + vanilla_file, vanilla_time) 90 | write_datum(file_prefix + "times/" + sa_file, sa_time) 91 | 92 | 93 | if __name__ == "__main__": 94 | main() 95 | -------------------------------------------------------------------------------- /simple_sa_experiments.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Python imports. 4 | import random 5 | from collections import defaultdict 6 | import os 7 | import argparse 8 | 9 | # Other imports. 10 | from simple_rl.agents import RandomAgent, RMaxAgent, QLearningAgent, DelayedQAgent, DoubleQAgent, FixedPolicyAgent 11 | from simple_rl.run_experiments import run_agents_lifelong, run_agents_on_mdp 12 | from simple_rl.tasks import FourRoomMDP, HanoiMDP 13 | from simple_rl.planning import ValueIteration 14 | from simple_rl.abstraction.state_abs.StateAbstractionClass import StateAbstraction 15 | from simple_rl.abstraction.action_abs.ActionAbstractionClass import ActionAbstraction 16 | from simple_rl.abstraction.AbstractValueIterationClass import AbstractValueIteration 17 | from simple_rl.abstraction.AbstractionWrapperClass import AbstractionWrapper 18 | from state_abs import indicator_funcs as ind_funcs 19 | from abstraction_experiments import compute_pac_sa, get_sa, get_directed_option_sa_pair 20 | from utils.StochasticSAPolicyClass import StochasticSAPolicy 21 | from utils import make_mdp 22 | 23 | def get_exact_vs_approx_agents(environment, incl_opt=True): 24 | ''' 25 | Args: 26 | environment (simple_rl.MDPDistribution) 27 | incl_opt (bool) 28 | 29 | Returns: 30 | (list) 31 | ''' 32 | 33 | actions = environment.get_actions() 34 | gamma = environment.get_gamma() 35 | 36 | exact_qds_test = get_sa(environment, indic_func=ind_funcs._q_eps_approx_indicator, epsilon=0.0) 37 | approx_qds_test = get_sa(environment, indic_func=ind_funcs._q_eps_approx_indicator, epsilon=0.05) 38 | 39 | ql_agent = QLearningAgent(actions, gamma=gamma, epsilon=0.1, alpha=0.05) 40 | ql_exact_agent = AbstractionWrapper(QLearningAgent, agent_params={"actions":actions}, state_abstr=exact_qds_test, name_ext="-exact") 41 | ql_approx_agent = AbstractionWrapper(QLearningAgent, agent_params={"actions":actions}, state_abstr=approx_qds_test, name_ext="-approx") 42 | ql_agents = [ql_agent, ql_exact_agent, ql_approx_agent] 43 | 44 | dql_agent = DoubleQAgent(actions, gamma=gamma, epsilon=0.1, alpha=0.05) 45 | dql_exact_agent = AbstractionWrapper(DoubleQAgent, agent_params={"actions":actions}, state_abstr=exact_qds_test, name_ext="-exact") 46 | dql_approx_agent = AbstractionWrapper(DoubleQAgent, agent_params={"actions":actions}, state_abstr=approx_qds_test, name_ext="-approx") 47 | dql_agents = [dql_agent, dql_exact_agent, dql_approx_agent] 48 | 49 | rm_agent = RMaxAgent(actions, gamma=gamma) 50 | rm_exact_agent = AbstractionWrapper(RMaxAgent, agent_params={"actions":actions}, state_abstr=exact_qds_test, name_ext="-exact") 51 | rm_approx_agent = AbstractionWrapper(RMaxAgent, 
agent_params={"actions":actions}, state_abstr=approx_qds_test, name_ext="-approx") 52 | rm_agents = [rm_agent, rm_exact_agent, rm_approx_agent] 53 | 54 | if incl_opt: 55 | vi = ValueIteration(environment) 56 | vi.run_vi() 57 | opt_agent = FixedPolicyAgent(vi.policy, name="$\pi^*$") 58 | 59 | sa_vi = AbstractValueIteration(environment, sample_rate=50, max_iterations=3000, delta=0.0001, state_abstr=approx_qds_test, action_abstr=ActionAbstraction(options=[], prim_actions=environment.get_actions())) 60 | sa_vi.run_vi() 61 | approx_opt_agent = FixedPolicyAgent(sa_vi.policy, name="$\pi_\phi^*$") 62 | 63 | dql_agents += [opt_agent, approx_opt_agent] 64 | 65 | return ql_agents 66 | 67 | 68 | def get_sa_experiment_agents(environment, AgentClass, pac=False): 69 | ''' 70 | Args: 71 | environment (simple_rl.MDPDistribution) 72 | AgentClass (Class) 73 | 74 | Returns: 75 | (list) 76 | ''' 77 | actions = environment.get_actions() 78 | gamma = environment.get_gamma() 79 | 80 | if pac: 81 | # PAC State Abstractions. 82 | sa_qds_test = compute_pac_sa(environment, indic_func=ind_funcs._q_disc_approx_indicator, epsilon=0.2) 83 | sa_qs_test = compute_pac_sa(environment, indic_func=ind_funcs._q_eps_approx_indicator, epsilon=0.2) 84 | sa_qs_exact_test = compute_pac_sa(environment, indic_func=ind_funcs._q_eps_approx_indicator, epsilon=0.0) 85 | 86 | else: 87 | # Compute state abstractions. 88 | sa_qds_test = get_sa(environment, indic_func=ind_funcs._q_disc_approx_indicator, epsilon=0.1) 89 | sa_qs_test = get_sa(environment, indic_func=ind_funcs._q_eps_approx_indicator, epsilon=0.1) 90 | sa_qs_exact_test = get_sa(environment, indic_func=ind_funcs._q_eps_approx_indicator, epsilon=0.0) 91 | 92 | # Make Agents. 93 | agent = AgentClass(actions, gamma=gamma) 94 | params = {"actions":actions} if AgentClass is not RMaxAgent else {"actions":actions, "s_a_threshold":2, "horizon":5} 95 | sa_qds_agent = AbstractionWrapper(AgentClass, agent_params=params, state_abstr=sa_qds_test, name_ext="$-\phi_{Q_d^*}$") 96 | sa_qs_agent = AbstractionWrapper(AgentClass, agent_params=params, state_abstr=sa_qs_test, name_ext="$-\phi_{Q_\epsilon^*}$") 97 | sa_qs_exact_agent = AbstractionWrapper(AgentClass, agent_params=params, state_abstr=sa_qs_exact_test, name_ext="-$\phi_{Q^*}$") 98 | 99 | agents = [agent, sa_qds_agent, sa_qs_agent, sa_qs_exact_agent] 100 | 101 | # if isinstance(environment.sample(), FourRoomMDP) or isinstance(environment.sample(), ColorMDP): 102 | # # If it's a fourroom add the handcoded one. 103 | # sa_hand_test = get_sa(environment, indic_func=ind_funcs._four_rooms) 104 | # sa_hand_agent = AbstractionWrapper(AgentClass, agent_params=params, state_abstr=sa_hand_test, name_ext="$-\phi_h$") 105 | # agents += [sa_hand_agent] 106 | 107 | return agents 108 | 109 | def get_combo_experiment_agents(environment): 110 | ''' 111 | Args: 112 | environment (simple_rl.MDPDistribution) 113 | 114 | Returns: 115 | (list) 116 | ''' 117 | actions = environment.get_actions() 118 | gamma = environment.get_gamma() 119 | 120 | sa, aa = get_directed_option_sa_pair(environment, indic_func=ind_funcs._q_disc_approx_indicator, max_options=100) 121 | sa_qds_test = get_sa(environment, indic_func=ind_funcs._q_disc_approx_indicator, epsilon=0.05) 122 | sa_qs_test = get_sa(environment, indic_func=ind_funcs._q_eps_approx_indicator, epsilon=0.1) 123 | 124 | # QLearner. 125 | ql_agent = QLearningAgent(actions, gamma=gamma, epsilon=0.1, alpha=0.05) 126 | rmax_agent = RMaxAgent(actions, gamma=gamma, epsilon=0.1, alpha=0.05) 127 | 128 | # Combos. 
129 | ql_sa_qds_agent = AbstractionWrapper(QLearningAgent, agent_params={"actions":actions}, state_abstr=sa_qds_test, name_ext="$\phi_{Q_d^*}$") 130 | ql_sa_qs_agent = AbstractionWrapper(QLearningAgent, agent_params={"actions":actions}, state_abstr=sa_qs_test, name_ext="$\phi_{Q_\epsilon^*}$") 131 | 132 | # sa_agent = AbstractionWrapper(QLearningAgent, actions, str(environment), state_abstr=sa, name_ext="sa") 133 | aa_agent = AbstractionWrapper(QLearningAgent, agent_params={"actions":actions}, action_abstr=aa, name_ext="aa") 134 | sa_aa_agent = AbstractionWrapper(QLearningAgent, agent_params={"actions":actions}, state_abstr=sa, action_abstr=aa, name_ext="$\phi_{Q_d^*}+aa$") 135 | 136 | agents = [ql_agent, ql_sa_qds_agent, ql_sa_qs_agent, aa_agent, sa_aa_agent] 137 | 138 | return agents 139 | 140 | def get_optimal_policies(environment): 141 | ''' 142 | Args: 143 | environment (simple_rl.MDPDistribution) 144 | 145 | Returns: 146 | (list) 147 | ''' 148 | 149 | # Make State Abstraction 150 | approx_qds_test = get_sa(environment, indic_func=ind_funcs._q_eps_approx_indicator, epsilon=0.05) 151 | 152 | # True Optimal 153 | true_opt_vi = ValueIteration(environment) 154 | true_opt_vi.run_vi() 155 | opt_agent = FixedPolicyAgent(true_opt_vi.policy, "$\pi^*$") 156 | 157 | # Optimal Abstraction 158 | opt_det_vi = AbstractValueIteration(environment, state_abstr=approx_qds_test, sample_rate=30) 159 | opt_det_vi.run_vi() 160 | opt_det_agent = FixedPolicyAgent(opt_det_vi.policy, name="$\pi_{\phi}^*$") 161 | 162 | stoch_policy_obj = StochasticSAPolicy(approx_qds_test, environment) 163 | stoch_agent = FixedPolicyAgent(stoch_policy_obj.policy, "$\pi(a \mid s_\phi )$") 164 | 165 | ql_agents = [opt_agent, stoch_agent, opt_det_agent] 166 | 167 | return ql_agents 168 | 169 | def parse_args(): 170 | ''' 171 | Summary: 172 | Parse all arguments 173 | ''' 174 | parser = argparse.ArgumentParser() 175 | parser.add_argument("-task", type = str, default = "four_room", nargs = '?', help = "Choose the mdp type (one of {octo, hall, grid, taxi, four_room}).") 176 | parser.add_argument("-samples", type = int, default = 50, nargs = '?', help = "Number of samples from the MDP Distribution.") 177 | parser.add_argument("-steps", type = int, default = 100, nargs = '?', help = "Number of steps for the experiment.") 178 | parser.add_argument("-episodes", type = int, default = 1, nargs = '?', help = "Number of episodes for the experiment.") 179 | parser.add_argument("-grid_dim", type = int, default = 11, nargs = '?', help = "Dimensions of the grid world.") 180 | parser.add_argument("-agent", type = str, default='ql', nargs = '?', help = "Specify agent class (one of {'ql', 'rmax'})..") 181 | args = parser.parse_args() 182 | 183 | return args.task, args.samples, args.episodes, args.steps, args.grid_dim, args.agent 184 | 185 | def get_params(set_manually=False): 186 | ''' 187 | Args: 188 | set_manually (bool) 189 | 190 | Returns: 191 | (tuple) 192 | ''' 193 | 194 | if set_manually: 195 | # Grab experiment params. 196 | mdp_class = "four_room" 197 | task_samples = 5 198 | episodes = 100 199 | steps = 250 200 | grid_dim = 9 201 | AgentClass = QLearningAgent 202 | else: 203 | # Grab experiment params. 
204 | mdp_class, task_samples, episodes, steps, grid_dim, agent_class_str = parse_args() 205 | 206 | print(mdp_class) 207 | if "DelayedQAgent" in agent_class_str: 208 | AgentClass = DelayedQAgent 209 | else: 210 | AgentClass = QLearningAgent 211 | 212 | return mdp_class, task_samples, episodes, steps, grid_dim, AgentClass 213 | 214 | def main(): 215 | 216 | # Set Params. 217 | mdp_class, task_samples, episodes, steps, grid_dim, AgentClass = get_params(set_manually=False) 218 | experiment_type = "sa" 219 | lifelong = True 220 | resample_at_terminal = False 221 | reset_at_terminal = False 222 | gamma = 0.95 223 | 224 | # ====================== 225 | # == Make Environment == 226 | # ====================== 227 | environment = make_mdp.make_mdp_distr(mdp_class=mdp_class, grid_dim=grid_dim) if lifelong else make_mdp.make_mdp(mdp_class=mdp_class, grid_dim=grid_dim) 228 | environment.set_gamma(gamma) 229 | 230 | # ================= 231 | # == Make Agents == 232 | # ================= 233 | agents = [] 234 | if experiment_type == "sa": 235 | # SA experiment. 236 | agents = get_sa_experiment_agents(environment, AgentClass) 237 | elif experiment_type == "combo": 238 | # AA experiment. 239 | agents = get_combo_experiment_agents(environment) 240 | elif experiment_type == "exact_v_approx": 241 | agents = get_exact_vs_approx_agents(environment, incl_opt=(not multi_task)) 242 | elif experiment_type == "opt": 243 | agents = get_optimal_policies(environment) 244 | else: 245 | print "Experiment Error: experiment type unknown (" + experiment_type + "). Must be one of {sa, combo, exact_v_approx}." 246 | quit() 247 | 248 | # Run! 249 | if lifelong: 250 | run_agents_lifelong(agents, environment, samples=task_samples, steps=steps, episodes=episodes, reset_at_terminal=reset_at_terminal, resample_at_terminal=resample_at_terminal, cumulative_plot=True, clear_old_results=True) 251 | else: 252 | run_agents_on_mdp(agents, environment, instances=task_samples, steps=steps, episodes=episodes, reset_at_terminal=reset_at_terminal, track_disc_reward=False) 253 | 254 | if __name__ == "__main__": 255 | main() 256 | -------------------------------------------------------------------------------- /state_abs/StateAbstractionClass.py: -------------------------------------------------------------------------------- 1 | # Python imports. 2 | from collections import defaultdict 3 | 4 | # Other imports. 5 | from simple_rl.mdp.StateClass import State 6 | from simple_rl.mdp.MDPClass import MDP 7 | 8 | class StateAbstraction(object): 9 | 10 | def __init__(self, phi, state_class=State, track_act_opt_pr=False): 11 | ''' 12 | Args: 13 | phi (dict) 14 | state_class (Class) 15 | track_act_opt_pr (bool): If true, tracks the probability with which 16 | each action is optimal in each ground state w.r.t. the distribution. 17 | ''' 18 | self._phi = phi # key:state, val:int. (int represents an abstract state). 19 | self.state_class = state_class 20 | self.track_act_opt_pr = track_act_opt_pr 21 | if self.track_act_opt_pr: 22 | self.phi_act_optimality_dict = defaultdict(lambda:defaultdict(float)) 23 | # Key: Ground State 24 | # Val: Dict 25 | # Key: Action. 26 | # Val: Probability it's optimal. 
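                # An illustrative (hypothetical) entry, assuming the lifelong setting the
                # class docstring describes -- "up" being optimal in ground state s_g for
                # 75% of the task distribution would be stored as:
                #     self.phi_act_optimality_dict[s_g]["up"] = 0.75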
27 | else: 28 | self.phi_act_optimality_dict = defaultdict(set) 29 | 30 | def get_act_opt_dict(self): 31 | return self.phi_act_optimality_dict 32 | 33 | def set_act_opt_dict(self, new_dict): 34 | if self.track_act_opt_pr and (len(new_dict.keys()) == 0 or isinstance(new_dict.keys()[0], dict)): 35 | print "State Abstraction Error: Tried setting optimality dict of incorrect type. Must be K:state, V:dict (K: action, V: probability)." 36 | quit() 37 | self.phi_act_optimality_dict = new_dict 38 | 39 | def set_actions_state_opt_dict(self, ground_state, action_set, prob_of_mdp=1.0): 40 | ''' 41 | Args: 42 | ground_state (State) 43 | action (str) 44 | 45 | Summary: 46 | Tracks optimal actions in each abstract state. 47 | ''' 48 | if self.track_act_opt_pr: 49 | for a in action_set: 50 | self.phi_act_optimality_dict[ground_state][a] = prob_of_mdp 51 | else: 52 | self.phi_act_optimality_dict[ground_state] = action_set 53 | 54 | def set_phi(self, new_phi): 55 | self._phi = new_phi 56 | 57 | def phi(self, state): 58 | ''' 59 | Args: 60 | state (State) 61 | 62 | Returns: 63 | state (State) 64 | ''' 65 | # Setup phi for new states. 66 | if state not in self._phi.keys(): 67 | if len(self._phi.values()) > 0: 68 | self._phi[state] = max(self._phi.values()) + 1 69 | else: 70 | self._phi[state] = 1 71 | 72 | abstr_state = self.state_class(self._phi[state]) 73 | abstr_state.set_terminal(state.is_terminal()) 74 | 75 | return abstr_state 76 | 77 | def make_cluster(self, list_of_ground_states): 78 | if len(list_of_ground_states) == 0: 79 | return 80 | 81 | abstract_value = 0 82 | if len(self._phi.values()) != 0: 83 | abstract_value = max(self._phi.values()) + 1 84 | 85 | for state in list_of_ground_states: 86 | self._phi[state] = abstract_value 87 | 88 | def get_ground_states_in_abs_state(self, abs_state): 89 | ''' 90 | Args: 91 | abs_state (State) 92 | 93 | Returns: 94 | (list): Contains all ground states in the cluster. 95 | ''' 96 | return [s_g for s_g in self.get_ground_states() if self.phi(s_g) == abs_state] 97 | 98 | def get_lower_states_in_abs_state(self, abs_state): 99 | ''' 100 | Args: 101 | abs_state (State) 102 | 103 | Returns: 104 | (list): Contains all ground states in the cluster. 105 | 106 | Notes: 107 | Here to simplify the state abstraction stack subclass. 108 | ''' 109 | return self.get_ground_states_in_abs_state(abs_state) 110 | 111 | 112 | def get_abs_states(self): 113 | # For each ground state, get its abstract state. 114 | return set([self.phi(val) for val in set(self._phi.keys())]) 115 | 116 | def get_abs_cluster_num(self, abs_state): 117 | # FIX: Specific to one abstract state class. 118 | return list(set(self._phi.values())).index(abs_state.data) 119 | 120 | def get_ground_states(self): 121 | return self._phi.keys() 122 | 123 | def get_lower_states(self): 124 | return self.get_ground_states() 125 | 126 | def get_num_abstr_states(self): 127 | return len(set(self._phi.values())) 128 | 129 | def get_num_ground_states(self): 130 | return len(set(self._phi.keys())) 131 | 132 | def reset(self): 133 | self._phi = {} 134 | 135 | def __add__(self, other_abs): 136 | ''' 137 | Args: 138 | other_abs 139 | ''' 140 | merged_state_abs = {} 141 | 142 | # Move the phi into a cluster dictionary. 143 | cluster_dict = defaultdict(set) 144 | for k, v in self._phi.iteritems(): 145 | # Cluster dict: v is abstract, key is ground. 146 | cluster_dict[v].add(k) 147 | 148 | # Move the phi into a cluster dictionary. 
149 | other_cluster_dict = defaultdict(set) 150 | for k, v in other_abs._phi.iteritems(): 151 | other_cluster_dict[v].add(k) 152 | 153 | for ground_state in self._phi.keys(): 154 | # Get the two clusters (ints that define abstr states) associated with a state. 155 | states_cluster = self._phi[ground_state] 156 | if ground_state in other_abs._phi.keys(): 157 | # Only add if it's in both clusters. 158 | states_other_cluster = other_abs._phi[ground_state] 159 | else: 160 | continue 161 | 162 | for s_g in cluster_dict[states_cluster]: 163 | if s_g in other_cluster_dict[states_other_cluster]: 164 | # Grab every ground state that's in both clusters and put them in the new cluster. 165 | merged_state_abs[s_g] = states_cluster 166 | 167 | new_sa = StateAbstraction(phi=merged_state_abs, track_act_opt_pr=self.track_act_opt_pr) 168 | 169 | # Build the new action optimality dictionary. 170 | if self.track_act_opt_pr: 171 | # Grab the two action optimality dictionaries. 172 | opt_dict = self.get_act_opt_dict() 173 | other_opt_dict = other_abs.get_act_opt_dict() 174 | 175 | # If we're tracking the action's probability. 176 | new_dict = defaultdict(lambda:defaultdict(float)) 177 | for s_g in self.get_ground_states(): 178 | for a_g in opt_dict[s_g].keys() + other_opt_dict[s_g].keys(): 179 | new_dict[s_g][a_g] = opt_dict[s_g][a_g] + other_opt_dict[s_g][a_g] 180 | new_sa.set_act_opt_dict(new_dict) 181 | 182 | return new_sa 183 | 184 | -------------------------------------------------------------------------------- /state_abs/__init__.py: -------------------------------------------------------------------------------- 1 | import sa_helpers -------------------------------------------------------------------------------- /state_abs/indicator_funcs.py: -------------------------------------------------------------------------------- 1 | import random 2 | from simple_rl.tasks import FourRoomMDP 3 | from decimal import Decimal 4 | 5 | def _four_rooms(state_x, state_y, vi, actions, epsilon=0.0): 6 | if not isinstance(vi.mdp, FourRoomMDP): 7 | print "Abstraction Error: four_rooms SA only available for FourRoomMDP. (" + str(vi.mdp) + "given)." 8 | quit() 9 | height, width = vi.mdp.width, vi.mdp.height 10 | 11 | if (state_x.x < width / 2.0) == (state_y.x < width / 2.0) \ 12 | and (state_x.y < height / 2.0) == (state_y.y < height / 2.0): 13 | return True 14 | return False 15 | 16 | def _random(state_x, state_y, vi, actions, epsilon=0.0): 17 | ''' 18 | Args: 19 | state_x (State) 20 | state_y (State) 21 | vi (ValueIteration) 22 | actions (list) 23 | 24 | Returns: 25 | (bool): true randomly. 
26 | ''' 27 | cluster_prob = max(100.0 / vi.get_num_states(), 0.5) 28 | return random.random() > 0.3 29 | 30 | def _v_approx_indicator(state_x, state_y, vi, actions, epsilon=0.0): 31 | ''' 32 | Args: 33 | state_x (State) 34 | state_y (State) 35 | vi (ValueIteration) 36 | actions (list) 37 | 38 | Returns: 39 | (bool): true iff: 40 | max |V(state_x) - V(state_y)| <= epsilon 41 | ''' 42 | return abs(vi.get_value(state_x) - vi.get_value(state_y)) <= epsilon 43 | 44 | def _q_eps_approx_indicator(state_x, state_y, vi, actions, epsilon=0.0): 45 | ''' 46 | Args: 47 | state_x (State) 48 | state_y (State) 49 | vi (ValueIteration) 50 | actions (list) 51 | 52 | Returns: 53 | (bool): true iff: 54 | max |Q(state_x,a) - Q(state_y, a)| <= epsilon 55 | ''' 56 | for a in actions: 57 | q_x = vi.get_q_value(state_x, a) 58 | q_y = vi.get_q_value(state_y, a) 59 | 60 | if abs(q_x - q_y) > epsilon: 61 | return False 62 | 63 | return True 64 | 65 | def _q_disc_approx_indicator(state_x, state_y, vi, actions, epsilon=0.0): 66 | ''' 67 | Args: 68 | state_x (State) 69 | state_y (State) 70 | vi (ValueIteration) 71 | actions (list) 72 | 73 | Returns: 74 | (bool): true iff: 75 | ''' 76 | v_max = 1 #/ (1 - 0.95) 77 | 78 | if epsilon == 0.0: 79 | return _q_eps_approx_indicator(state_x, state_y, vi, actions, epsilon=0) 80 | 81 | for a in actions: 82 | 83 | q_x, q_y = vi.get_q_value(state_x, a), vi.get_q_value(state_y, a) 84 | 85 | bucket_x = int( (q_x * (v_max / epsilon))) 86 | bucket_y = int( (q_y * (v_max / epsilon))) 87 | 88 | if bucket_x != bucket_y: 89 | return False 90 | 91 | return True 92 | 93 | def _v_disc_approx_indicator(state_x, state_y, vi, actions, epsilon=0.0): 94 | ''' 95 | Args: 96 | state_x (State) 97 | state_y (State) 98 | vi (ValueIteration) 99 | actions (list) 100 | 101 | Returns: 102 | (bool): true iff: 103 | ''' 104 | v_max = 1 / (1 - 0.95) 105 | 106 | if epsilon == 0.0: 107 | return _v_approx_indicator(state_x, state_y, vi, actions, epsilon=0) 108 | 109 | v_x, v_y = vi.get_value(state_x), vi.get_value(state_y) 110 | 111 | bucket_x = int( (v_x / v_max) / epsilon) 112 | bucket_y = int( (v_y / v_max) / epsilon) 113 | 114 | return bucket_x == bucket_y 115 | -------------------------------------------------------------------------------- /state_abs/sa_helpers.py: -------------------------------------------------------------------------------- 1 | # Python imports. 2 | from collections import defaultdict 3 | import cPickle 4 | import os 5 | import sys 6 | import itertools 7 | import numpy as np 8 | 9 | # Other imports. 10 | from simple_rl.planning.ValueIterationClass import ValueIteration 11 | from simple_rl.mdp import State 12 | from simple_rl.mdp import MDPDistribution 13 | import indicator_funcs as ind_funcs 14 | from StateAbstractionClass import StateAbstraction 15 | 16 | 17 | def get_pac_sa_from_samples(mdp_distr, indic_func=ind_funcs._q_eps_approx_indicator, phi_epsilon=0.0, delta=0.2): 18 | ''' 19 | Args: 20 | mdp_distr (MDPDistribution) 21 | indicator_func (S x S --> {0,1}) 22 | epsilon (float) 23 | delta (float) 24 | 25 | Returns: 26 | (StateAbstraction) 27 | 28 | 29 | Summary: 30 | Computes a PAC state abstraction. 
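        Worked example (just an illustration of the bound computed in the body): with the
        default delta = 0.2 and the hard-coded sample_eps = 1.0, the sample bound is
        max(int(ln(1/0.2) / 1.0**2), 2) = max(int(1.609...), 2) = 2, so two MDPs are sampled
        from mdp_distr, a single-task abstraction is built for each, and the two are merged.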
31 | ''' 32 | sample_eps = 1.0 33 | pac_sample_bound = max(int(np.log(1 / delta) / sample_eps**2), 2) 34 | print("PAC sample bound:", pac_sample_bound) 35 | sa_list = [] 36 | for sample in xrange(pac_sample_bound): 37 | mdp = mdp_distr.sample() 38 | sa = make_singletask_sa(mdp, indic_func, phi_epsilon) #, prob_of_mdp=mdp_distr.get_prob_of_mdp(mdp)) 39 | sa_list += [sa] 40 | 41 | pac_state_abstr = merge_state_abs(sa_list) 42 | 43 | return pac_state_abstr 44 | 45 | def merge_state_abs(list_of_sa): 46 | ''' 47 | Args: 48 | list_of_sa (list of StateAbstraction) 49 | 50 | Returns: 51 | (StateAbstraction) 52 | ''' 53 | merged = list_of_sa[0] 54 | 55 | for sa in list_of_sa[1:]: 56 | merged = merged + sa 57 | 58 | return merged 59 | 60 | def compute_planned_state_abs(mdp_class="grid", num_mdps=30): 61 | ''' 62 | Args: 63 | mdp_class (str) 64 | num_mdps (int) 65 | 66 | Returns: 67 | (StateAbstraction) 68 | ''' 69 | 70 | # Setup grid params for MDPs. 71 | goal_locs = [] 72 | width, height = 7, 4 73 | for element in itertools.product(range(1, width + 1), [height]): 74 | goal_locs.append(element) 75 | 76 | # Compute the optimal Q^* abstraction for each MDP. 77 | state_abstrs = [] 78 | for i in xrange(num_mdps): 79 | left = goal_locs[:len(goal_locs) / 2] 80 | right = goal_locs[len(goal_locs) / 2:] 81 | mdp = GridWorldMDP(width=width, height=height, init_loc=(1, 1), goal_locs=r.choice([left, right])) 82 | state_abstrs.append(make_sa(mdp, ind_funcs._q_eps_approx_indicator)) 83 | 84 | # Merge 85 | merged_sa = merge_state_abs(state_abstrs) 86 | 87 | return merged_sa 88 | 89 | def make_sa(mdp, indic_func=ind_funcs._q_eps_approx_indicator, state_class=State, epsilon=0.0): 90 | ''' 91 | Args: 92 | mdp (MDP) 93 | state_class (Class) 94 | epsilon (float) 95 | 96 | Summary: 97 | Creates and saves a state abstraction. 98 | ''' 99 | print " Making state abstraction... " 100 | new_sa = StateAbstraction(phi={}) 101 | if isinstance(mdp, MDPDistribution): 102 | new_sa = make_multitask_sa(mdp, state_class=state_class, indic_func=indic_func, epsilon=epsilon) 103 | else: 104 | new_sa = make_singletask_sa(mdp, state_class=state_class, indic_func=indic_func, epsilon=epsilon) 105 | 106 | print " (final SA) Num abstract states:", new_sa.get_num_abstr_states() 107 | 108 | return new_sa 109 | 110 | def make_multitask_sa(mdp_distr, state_class=State, indic_func=ind_funcs._q_eps_approx_indicator, epsilon=0.0, aa_single_act=True): 111 | ''' 112 | Args: 113 | mdp_distr (MDPDistribution) 114 | state_class (Class) 115 | indicator_func (S x S --> {0,1}) 116 | epsilon (float) 117 | aa_single_act (bool): If we should track optimal actions. 118 | 119 | Returns: 120 | (StateAbstraction) 121 | ''' 122 | sa_list = [] 123 | for mdp in mdp_distr.get_mdps(): 124 | sa = make_singletask_sa(mdp, indic_func, state_class, epsilon, aa_single_act=aa_single_act, prob_of_mdp=mdp_distr.get_prob_of_mdp(mdp)) 125 | sa_list += [sa] 126 | 127 | multitask_sa = merge_state_abs(sa_list) 128 | 129 | return multitask_sa 130 | 131 | def make_singletask_sa(mdp, indic_func, state_class, epsilon=0.0, aa_single_act=False, prob_of_mdp=1.0): 132 | ''' 133 | Args: 134 | mdp (MDP) 135 | indic_func (S x S --> {0,1}) 136 | state_class (Class) 137 | epsilon (float) 138 | 139 | Returns: 140 | (StateAbstraction) 141 | ''' 142 | 143 | print "\tRunning VI...", 144 | sys.stdout.flush() 145 | # Run VI 146 | if isinstance(mdp, MDPDistribution): 147 | mdp = mdp.sample() 148 | 149 | vi = ValueIteration(mdp) 150 | iters, val = vi.run_vi() 151 | print " done." 
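    # Sketch of the clustering below (hypothetical states s1, s2): every pair of states
    # returned by vi.get_states() is tested with indic_func; e.g. under
    # _q_eps_approx_indicator, s1 and s2 are marked equivalent when
    # max_a |Q*(s1, a) - Q*(s2, a)| <= epsilon. Clusters are then formed greedily from
    # these pairwise matches, and because epsilon-similarity is not transitive the
    # resulting partition can depend on the order in which states are processed.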
152 | 153 | print "\tMaking state abstraction...", 154 | sys.stdout.flush() 155 | sa = StateAbstraction(phi={}, state_class=state_class) 156 | clusters = defaultdict(set) 157 | num_states = len(vi.get_states()) 158 | actions = mdp.get_actions() 159 | 160 | # Find state pairs that satisfy the condition. 161 | for i, state_x in enumerate(vi.get_states()): 162 | sys.stdout.flush() 163 | clusters[state_x].add(state_x) 164 | 165 | for state_y in vi.get_states()[i:]: 166 | if not(state_x == state_y) and indic_func(state_x, state_y, vi, actions, epsilon=epsilon): 167 | clusters[state_x].add(state_y) 168 | clusters[state_y].add(state_x) 169 | 170 | print "making clusters...", 171 | sys.stdout.flush() 172 | 173 | # Build SA. 174 | for i, state in enumerate(clusters.keys()): 175 | new_cluster = clusters[state] 176 | sa.make_cluster(new_cluster) 177 | 178 | # Destroy old so we don't double up. 179 | for s in clusters[state]: 180 | if s in clusters.keys(): 181 | clusters.pop(s) 182 | 183 | print " done." 184 | print "\tGround States:", num_states 185 | print "\tAbstract:", sa.get_num_abstr_states() 186 | print 187 | 188 | return sa 189 | 190 | # ------------ Indicator Functions ------------ 191 | 192 | def agent_q_estimate_equal(state_x, state_y, agent, state_abs, action_abs=[], epsilon=0.0): 193 | ''' 194 | Args: 195 | state_x (State) 196 | state_y (State) 197 | agent (Agent) 198 | state_abs (StateAbstraction) 199 | action_abs (ActionAbstraction) 200 | 201 | Returns: 202 | (bool): true iff: 203 | max |agent.Q(state_x,a) - agent.Q(state_y, a)| <= epsilon 204 | ''' 205 | for a in agent.actions: 206 | q_x = agent.get_q_value(state_abs(state_x), a) 207 | q_y = agent.get_q_value(state_abs(state_y), a) 208 | if abs(q_x - q_y) > epsilon: 209 | return False 210 | 211 | return True 212 | 213 | def agent_always_false(state_x, state_y, agent): 214 | return False 215 | 216 | def load_sa(file_name): 217 | this_dir = os.path.dirname(os.path.realpath(__file__)) 218 | if os.path.isfile(this_dir + "/cached_sa/" + file_name): 219 | return cPickle.load( open( this_dir + "/cached_sa/" + file_name, "rb" ) ) 220 | else: 221 | print "Warning: no saved State Abstraction with name '" + file_name + "'." 222 | 223 | def save_sa(sa, file_name): 224 | this_dir = os.path.dirname(os.path.realpath(__file__)) 225 | cPickle.dump( sa, open( this_dir + "/cached_sa/" + file_name, "w" ) ) 226 | -------------------------------------------------------------------------------- /utils/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/david-abel/rl_abstraction/cfe628abbee2614ff873713dceb466b293aa7329/utils/.DS_Store -------------------------------------------------------------------------------- /utils/AbstractValueIterationClass.py: -------------------------------------------------------------------------------- 1 | # Python imports. 2 | import random 3 | import Queue 4 | from collections import defaultdict 5 | 6 | # Other imports. 
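# Usage sketch (hypothetical variable names; this mirrors how planning_experiments.py
# constructs it):
#   sa_vi = AbstractValueIteration(ground_mdp=environment, state_abstr=sa_qds)
#   iters, init_val = sa_vi.run_vi()
# run_vi() is inherited from simple_rl's ValueIteration here and returns the number of
# iterations together with the estimated value of the initial state.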
7 | import make_mdp 8 | import abstraction_experiments as ae 9 | from simple_rl.abstraction import ActionAbstraction 10 | from simple_rl.abstraction import StateAbstraction 11 | from simple_rl.abstraction.abstr_mdp.abstr_mdp_funcs import make_abstr_mdp 12 | from simple_rl.planning.PlannerClass import Planner 13 | from simple_rl.planning.ValueIterationClass import ValueIteration 14 | 15 | class AbstractValueIteration(ValueIteration): 16 | ''' AbstractValueIteration: Runs Value Iteration using a state and action abstraction ''' 17 | 18 | def __init__(self, ground_mdp, state_abstr=None, action_abstr=None, sample_rate=10, delta=0.001, max_iterations=1000): 19 | ''' 20 | Args: 21 | ground_mdp (MDP) 22 | state_abstr (StateAbstraction) 23 | action_abstr (ActionAbstraction) 24 | ''' 25 | self.ground_mdp = ground_mdp 26 | self.state_abstr = state_abstr if state_abstr not in [[], None] else StateAbstraction() 27 | self.action_abstr = action_abstr if action_abstr not in [[], None] else ActionAbstraction(prim_actions=ground_mdp.get_actions()) 28 | 29 | mdp = make_abstr_mdp(ground_mdp, self.state_abstr, self.action_abstr) 30 | 31 | ValueIteration.__init__(self, mdp, sample_rate, delta, max_iterations) 32 | 33 | # self.delta = delta 34 | # self.max_iterations = max_iterations 35 | # self.sample_rate = sample_rate 36 | 37 | # self.value_func = defaultdict(float) 38 | # self.reachability_done = False 39 | # self.has_run_vi = False 40 | # self._compute_reachable_state_space() 41 | 42 | # def get_num_states(self): 43 | # return len(self.states) 44 | 45 | # def get_states(self): 46 | # if self.reachability_done: 47 | # return self.states 48 | # else: 49 | # self._compute_reachable_state_space() 50 | # return self.states 51 | 52 | # def _compute_reachable_state_space(self): 53 | # ''' 54 | # Summary: 55 | # Starting with @self.start_state, determines all reachable states 56 | # and stores their abstracted counterparts in self.states. 57 | # ''' 58 | # state_queue = Queue.Queue() 59 | # s_g_init = self.mdp.get_init_state() 60 | # s_a_init = self.state_abstr.phi(s_g_init) 61 | # state_queue.put(s_g_init) 62 | # self.states.add(s_a_init) 63 | # ground_t = self.mdp.get_transition_func() 64 | 65 | # while not state_queue.empty(): 66 | # ground_state = state_queue.get() 67 | # for option in self.action_abstr.get_active_options(ground_state): 68 | # # For each active option. 69 | 70 | # # Take @sample_rate samples to estimate E[V] 71 | # for samples in xrange(self.sample_rate): 72 | 73 | # next_g_state = option.act_until_terminal(ground_state, ground_t) 74 | 75 | # if next_g_state not in self.states: 76 | # next_a_state = self.state_abstr.phi(next_g_state) 77 | # self.states.add(next_a_state) 78 | # state_queue.put(next_g_state) 79 | 80 | # self.reachability_done = True 81 | 82 | # def plan(self, ground_state=None, horizon=100): 83 | # ''' 84 | # Args: 85 | # ground_state (State) 86 | # horizon (int) 87 | 88 | # Returns: 89 | # (tuple): 90 | # (list): List of primitive actions taken. 91 | # (list): List of ground states. 92 | # (list): List of abstract actions taken. 93 | # ''' 94 | 95 | # ground_state = self.mdp.get_init_state() if ground_state is None else ground_state 96 | 97 | # if self.has_run_vi is False: 98 | # print "Warning: VI has not been run. Plan will be random." 99 | 100 | # primitive_action_seq = [] 101 | # abstr_action_seq = [] 102 | # state_seq = [ground_state] 103 | # steps = 0 104 | 105 | # ground_t = self.transition_func 106 | 107 | # # Until terminating condition is met. 
108 | # while (not ground_state.is_terminal()) and steps < horizon: 109 | 110 | # # Compute best action, roll it out. 111 | # next_option = self._get_max_q_action(ground_state) 112 | 113 | # while not next_option.is_term_true(ground_state): 114 | # # Keep applying option until it terminates. 115 | # abstr_state = self.state_abstr.phi(ground_state) 116 | # ground_action = next_option.act(ground_state) 117 | # ground_state = ground_t(ground_state, ground_action) 118 | # steps += 1 119 | # primitive_action_seq.append(ground_action) 120 | 121 | # state_seq.append(ground_state) 122 | 123 | # abstr_action_seq.append(next_option) 124 | 125 | # return primitive_action_seq, state_seq, abstr_action_seq 126 | 127 | # def run_vi(self): 128 | # ''' 129 | # Summary: 130 | # Runs ValueIteration and fills in the self.value_func. 131 | # ''' 132 | # # Algorithm bookkeeping params. 133 | # iterations = 0 134 | # max_diff = float("inf") 135 | 136 | # # Main loop. 137 | # while max_diff > self.delta and iterations < self.max_iterations: 138 | # max_diff = 0 139 | # for s_g in self.get_states(): 140 | # if s_g.is_terminal(): 141 | # continue 142 | 143 | # max_q = float("-inf") 144 | # for a in self.action_abstr.get_active_options(s_g): 145 | # # For each active option, compute it's q value. 146 | # q_s_a = self.get_q_value(s_g, a) 147 | # max_q = q_s_a if q_s_a > max_q else max_q 148 | 149 | # # Check terminating condition. 150 | # max_diff = max(abs(self.value_func[s_g] - max_q), max_diff) 151 | 152 | # # Update value. 153 | # self.value_func[s_g] = max_q 154 | 155 | # iterations += 1 156 | 157 | # value_of_init_state = self._compute_max_qval_action_pair(self.init_state)[0] 158 | 159 | # self.has_run_vi = True 160 | 161 | # return iterations, value_of_init_state 162 | 163 | # def get_q_value(self, s_g, option): 164 | # ''' 165 | # Args: 166 | # s (State) 167 | # a (Option): Assumed active option. 168 | 169 | # Returns: 170 | # (float): The Q estimate given the current value function @self.value_func. 171 | # ''' 172 | 173 | # # Take samples and track next state counts. 174 | # next_state_counts = defaultdict(int) 175 | # reward_total = 0 176 | # for samples in xrange(self.sample_rate): # Take @sample_rate samples to estimate E[V] 177 | # next_state, reward, num_steps = self.do_rollout(option, s_g) 178 | # next_state_counts[next_state] += 1 179 | # reward_total += reward 180 | 181 | # # Compute T(s' | s, option) estimate based on MLE and R(s, option). 182 | # next_state_probs = defaultdict(float) 183 | # avg_reward = 0 184 | # for state in next_state_counts: 185 | # next_state_probs[state] = float(next_state_counts[state]) / self.sample_rate 186 | 187 | # avg_reward = float(reward_total) / self.sample_rate 188 | 189 | # # Compute expected value. 190 | # expected_future_val = 0 191 | # for state in next_state_probs: 192 | # expected_future_val += next_state_probs[state] * self.value_func[state] 193 | 194 | # return avg_reward + self.gamma*expected_future_val 195 | 196 | # def do_rollout(self, option, ground_state): 197 | # ''' 198 | # Args: 199 | # option (Option) 200 | # ground_state (State) 201 | 202 | # Returns: 203 | # (tuple): 204 | # (State): Next ground state. 205 | # (float): Reward. 206 | # (int): Number of steps taken. 
207 | # ''' 208 | 209 | # ground_t = self.mdp.get_transition_func() 210 | # ground_r = self.mdp.get_reward_func() 211 | 212 | # if type(option) is str: 213 | # ground_action = option 214 | # else: 215 | # ground_action = option.act(ground_state) 216 | # total_reward = ground_r(ground_state, ground_action) 217 | # ground_state = ground_t(ground_state, ground_action) 218 | 219 | # total_steps = 1 220 | # while type(option) is not str and not option.is_term_true(ground_state): 221 | # # Keep applying option until it terminates. 222 | # ground_action = option.act(ground_state) 223 | # total_reward += ground_r(ground_state, ground_action) 224 | # ground_state = ground_t(ground_state, ground_action) 225 | # total_steps += 1 226 | 227 | # return ground_state, total_reward, total_steps 228 | 229 | # def _compute_max_qval_action_pair(self, state): 230 | # ''' 231 | # Args: 232 | # state (State) 233 | 234 | # Returns: 235 | # (tuple) --> (float, str): where the float is the Qval, str is the action. 236 | # ''' 237 | # # Grab random initial action in case all equal 238 | # max_q_val = float("-inf") 239 | # shuffled_option_list = self.action_abstr.get_active_options(state)[:] 240 | # if len(shuffled_option_list) == 0: 241 | # # Prims on failure. 242 | # shuffled_option_list = self.mdp.get_actions() 243 | 244 | # random.shuffle(shuffled_option_list) 245 | # best_action = shuffled_option_list[0] 246 | 247 | # # Find best action (action w/ current max predicted Q value) 248 | # for option in shuffled_option_list: 249 | # q_s_a = self.get_q_value(state, option) 250 | # if q_s_a > max_q_val: 251 | # max_q_val = q_s_a 252 | # best_action = option 253 | 254 | # return max_q_val, best_action 255 | 256 | # def _get_max_q_action(self, state): 257 | # ''' 258 | # Args: 259 | # state (State) 260 | 261 | # Returns: 262 | # (str): denoting the action with the max q value in the given @state. 263 | # ''' 264 | # return self._compute_max_qval_action_pair(state)[1] 265 | 266 | # def policy(self, state): 267 | # ''' 268 | # Args: 269 | # state (State) 270 | 271 | # Returns: 272 | # (str): Action 273 | 274 | # Summary: 275 | # For use in a FixedPolicyAgent. 276 | # ''' 277 | # return self._get_max_q_action(state) 278 | 279 | # def main(): 280 | # # MDP Setting. 281 | # multi_task = False 282 | # mdp_class = "grid" 283 | 284 | # # Make single/multi task environment. 285 | # environment = make_mdp.make_mdp_distr(mdp_class=mdp_class, num_mdps=3, horizon=30) if multi_task else make_mdp.make_mdp(mdp_class=mdp_class) 286 | # actions = environment.get_actions() 287 | # gamma = environment.get_gamma() 288 | 289 | # directed_sa, directed_aa = ae.get_abstractions(environment, directed=True) 290 | # default_sa, default_aa = ae.get_sa(environment, default=True), ae.get_aa(environment, default=True) 291 | 292 | # vi = ValueIteration(environment) 293 | # avi = AbstractValueIteration(environment, state_abstr=default_sa, action_abstr=default_aa) 294 | 295 | # a_num_iters, a_val = avi.run_vi() 296 | # g_num_iters, g_val = vi.run_vi() 297 | 298 | # print "a", a_num_iters, a_val 299 | # print "g", g_num_iters, g_val 300 | 301 | 302 | # if __name__ == "__main__": 303 | # main() -------------------------------------------------------------------------------- /utils/AbstractionWrapperClass.py: -------------------------------------------------------------------------------- 1 | # Python imports. 2 | from collections import defaultdict 3 | import copy 4 | import os 5 | 6 | # Other imports. 
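# Usage sketch for this local wrapper (hypothetical variables). Note that the experiment
# scripts in this repo mostly use simple_rl's own AbstractionWrapper, whose constructor
# takes agent_params rather than a positional actions/mdp_name pair:
#   agent = AbstractionWrapper(QLearningAgent, actions, str(mdp),
#                              state_abstr=sa, action_abstr=aa, name_ext="sa+aa")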
7 | from simple_rl.agents import Agent, RMaxAgent 8 | from state_abs.StateAbstractionClass import StateAbstraction 9 | from action_abs.ActionAbstractionClass import ActionAbstraction 10 | 11 | class AbstractionWrapper(Agent): 12 | 13 | def __init__(self, 14 | SubAgentClass, 15 | actions, 16 | mdp_name, 17 | max_option_steps=0, 18 | state_abstr=None, 19 | action_abstr=None, 20 | name_ext="abstr"): 21 | ''' 22 | Args: 23 | SubAgentClass (simple_rl.AgentClass) 24 | actions (list of str) 25 | mdp_name (str) 26 | state_abstr (StateAbstraction) 27 | state_abstr (ActionAbstraction) 28 | name_ext (str) 29 | ''' 30 | 31 | # Setup the abstracted agent. 32 | self._create_default_abstractions(actions, state_abstr, action_abstr) 33 | self.agent = SubAgentClass(actions=self.action_abstr.get_actions()) 34 | self.exp_directory = os.path.join(os.getcwdu(), "results", mdp_name, "options") 35 | self.reward_since_tracking = 0 36 | self.max_option_steps = max_option_steps 37 | self.num_option_steps = 0 38 | Agent.__init__(self, name=self.agent.name + "-" + name_ext, actions=self.action_abstr.get_actions()) 39 | self._setup_files() 40 | 41 | def _setup_files(self): 42 | ''' 43 | Summary: 44 | Creates and removes relevant directories/files. 45 | ''' 46 | if not os.path.exists(os.path.join(self.exp_directory)): 47 | os.makedirs(self.exp_directory) 48 | 49 | if os.path.exists(os.path.join(self.exp_directory, str(self.name)) + ".csv"): 50 | # Remove old 51 | os.remove(os.path.join(self.exp_directory, str(self.name)) + ".csv") 52 | 53 | 54 | def write_datum_to_file(self, datum): 55 | ''' 56 | Summary: 57 | Writes datum to file. 58 | ''' 59 | out_file = open(os.path.join(self.exp_directory, str(self.name)) + ".csv", "a+") 60 | out_file.write(str(datum) + ",") 61 | out_file.close() 62 | 63 | def _record_experience(self, ground_state, reward): 64 | ''' 65 | Args: 66 | abstr_state 67 | abstr_action 68 | reward 69 | next_abstr_state 70 | 71 | Summary: 72 | Tracks experiences to display plots in terms of options. 73 | ''' 74 | # if not self.action_abstr.is_next_step_continuing_option(ground_state): 75 | self.write_datum_to_file(self.reward_since_tracking) 76 | self.reward_since_tracking = 0 77 | 78 | def _create_default_abstractions(self, actions, state_abstr, action_abstr): 79 | ''' 80 | Summary: 81 | We here create the default abstractions. 82 | ''' 83 | if action_abstr is None: 84 | self.action_abstr = ActionAbstraction(options=actions, prim_actions=actions) 85 | else: 86 | self.action_abstr = action_abstr 87 | 88 | self.state_abstr = StateAbstraction(phi={}) if state_abstr is None else state_abstr 89 | 90 | def act(self, ground_state, reward): 91 | ''' 92 | Args: 93 | ground_state (State) 94 | reward (float) 95 | 96 | Return: 97 | (str) 98 | ''' 99 | self.reward_since_tracking += reward 100 | 101 | if self.max_option_steps > 0: 102 | # We're counting action steps in terms of options. 103 | if self.num_option_steps == self.max_option_steps: 104 | # We're at the limit. 105 | self._record_experience(ground_state, reward) 106 | self.num_option_steps += 1 107 | return "terminate" 108 | elif self.num_option_steps > self.max_option_steps: 109 | # Skip. 110 | return "terminate" 111 | elif not self.action_abstr.is_next_step_continuing_option(ground_state): 112 | # Taking a new option, count it and continue. 
113 | self.num_option_steps += 1 114 | self._record_experience(ground_state, reward) 115 | else: 116 | self._record_experience(ground_state, reward) 117 | 118 | abstr_state = self.state_abstr.phi(ground_state) 119 | 120 | # print ground_state, abstr_state, hash(ground_state) 121 | 122 | ground_action = self.action_abstr.act(self.agent, abstr_state, ground_state, reward) 123 | 124 | # print "ground_action", ground_action, type(ground_action), len(ground_action) 125 | 126 | return ground_action 127 | 128 | def reset(self): 129 | # Write data. 130 | out_file = open(os.path.join(self.exp_directory, str(self.name)) + ".csv", "a+") 131 | out_file.write("\n") 132 | out_file.close() 133 | self.agent.reset() 134 | self.action_abstr.reset() 135 | self.reward_since_tracking = 0 136 | self.num_option_steps = 0 137 | 138 | def new_task(self): 139 | self._reset_reward() 140 | 141 | def get_num_known_sa(self): 142 | return self.agent.get_num_known_sa() 143 | 144 | def _reset_reward(self): 145 | if isinstance(self.agent, RMaxAgent): 146 | self.agent._reset_reward() 147 | 148 | def end_of_episode(self): 149 | self.agent.end_of_episode() 150 | self.action_abstr.end_of_episode() 151 | -------------------------------------------------------------------------------- /utils/ColorMDPClass.py: -------------------------------------------------------------------------------- 1 | ''' ColorMDPClass.py: Contains the ColorMDP class. ''' 2 | 3 | # Python imports. 4 | from __future__ import print_function 5 | import random 6 | import sys 7 | import os 8 | import numpy as np 9 | import math 10 | 11 | # Other imports. 12 | from simple_rl.mdp.MDPClass import MDP 13 | from ColorStateClass import ColorState 14 | 15 | class ColorMDP(MDP): 16 | ''' Class for a Grid World MDP ''' 17 | 18 | 19 | COLOR_MAP = [] 20 | 21 | # Static constants. 22 | ACTIONS = ["up", "down", "left", "right", "paint"] 23 | 24 | def __init__(self, 25 | width=5, 26 | height=3, 27 | init_loc=(1,1), 28 | goal_locs=[(5,3)], 29 | num_colors=3, 30 | is_goal_terminal=True, 31 | gamma=0.99, 32 | init_state=None, 33 | slip_prob=0.0, 34 | name="color"): 35 | ''' 36 | Args: 37 | height (int) 38 | width (int) 39 | init_loc (tuple: (int, int)) 40 | goal_locs (list of tuples: [(int, int)...]) 41 | ''' 42 | 43 | ColorMDP.COLOR_MAP = range(num_colors) 44 | 45 | # Setup init location. 46 | self.init_loc = init_loc 47 | init_state = ColorState(init_loc[0], init_loc[1], ColorMDP.COLOR_MAP[0]) if init_state is None or rand_init else init_state 48 | 49 | MDP.__init__(self, ColorMDP.ACTIONS, self._transition_func, self._reward_func, init_state=init_state, gamma=gamma) 50 | 51 | if type(goal_locs) is not list: 52 | print("(simple_rl) color Error: argument @goal_locs needs to be a list of locations. 
For example: [(3,3), (4,3)].") 53 | quit() 54 | 55 | self.width = width 56 | self.height = height 57 | self.walls = self._compute_walls() 58 | self.goal_locs = goal_locs 59 | self.cur_state = ColorState(init_loc[0], init_loc[1], ColorMDP.COLOR_MAP[0]) 60 | self.is_goal_terminal = is_goal_terminal 61 | self.slip_prob = slip_prob 62 | self.name = name 63 | 64 | def _reward_func(self, state, action, next_state): 65 | ''' 66 | Args: 67 | state (State) 68 | action (str) 69 | 70 | Returns 71 | (float) 72 | ''' 73 | if self._is_goal_state_action(state, action): 74 | return 1.0 - self.step_cost 75 | else: 76 | return 0 - self.step_cost 77 | 78 | def _is_goal_state_action(self, state, action): 79 | ''' 80 | Args: 81 | state (State) 82 | action (str) 83 | 84 | Returns: 85 | (bool): True iff the state-action pair send the agent to the goal state. 86 | ''' 87 | if (state.x, state.y) in self.goal_locs and self.is_goal_terminal: 88 | # Already at terminal. 89 | return False 90 | 91 | goals = self.goal_locs 92 | 93 | if action == "left" and (state.x - 1, state.y) in goals: 94 | return True 95 | elif action == "right" and (state.x + 1, state.y) in goals: 96 | return True 97 | elif action == "down" and (state.x, state.y - 1) in goals: 98 | return True 99 | elif action == "up" and (state.x, state.y + 1) in goals: 100 | return True 101 | else: 102 | return False 103 | 104 | def _transition_func(self, state, action): 105 | ''' 106 | Args: 107 | state (State) 108 | action (str) 109 | 110 | Returns 111 | (State) 112 | ''' 113 | if state.is_terminal(): 114 | return state 115 | 116 | if action == "paint": 117 | return ColorState(state.x, state.y, color=random.choice(ColorMDP.COLOR_MAP)) 118 | 119 | r = random.random() 120 | if self.slip_prob > r: 121 | # Flip dir. 122 | if action == "up": 123 | action = random.choice(["left", "right"]) 124 | elif action == "down": 125 | action = random.choice(["left", "right"]) 126 | elif action == "left": 127 | action = random.choice(["up", "down"]) 128 | elif action == "right": 129 | action = random.choice(["up", "down"]) 130 | 131 | if action == "up" and state.y < self.height and not self.is_wall(state.x, state.y + 1): 132 | next_state = ColorState(state.x, state.y + 1, color=state.color) 133 | elif action == "down" and state.y > 1 and not self.is_wall(state.x, state.y - 1): 134 | next_state = ColorState(state.x, state.y - 1, color=state.color) 135 | elif action == "right" and state.x < self.width and not self.is_wall(state.x + 1, state.y): 136 | next_state = ColorState(state.x + 1, state.y, color=state.color) 137 | elif action == "left" and state.x > 1 and not self.is_wall(state.x - 1, state.y): 138 | next_state = ColorState(state.x - 1, state.y, color=state.color) 139 | else: 140 | next_state = ColorState(state.x, state.y, color=state.color) 141 | 142 | if (next_state.x, next_state.y) in self.goal_locs and self.is_goal_terminal: 143 | next_state.set_terminal(True) 144 | 145 | return next_state 146 | 147 | def is_wall(self, x, y): 148 | ''' 149 | Args: 150 | x (int) 151 | y (int) 152 | 153 | Returns: 154 | (bool): True iff (x,y) is a wall location. 
155 | ''' 156 | 157 | return (x, y) in self.walls 158 | 159 | def __str__(self): 160 | return self.name + "_h-" + str(self.height) + "_w-" + str(self.width) 161 | 162 | def get_goal_locs(self): 163 | return self.goal_locs 164 | 165 | def visualize_policy(self, policy): 166 | from simple_rl.utils import mdp_visualizer as mdpv 167 | from grid_visualizer import _draw_state 168 | ["up", "down", "left", "right"] 169 | 170 | action_char_dict = { 171 | "up":u"\u2191", 172 | "down":u"\u2193", 173 | "left":u"\u2190", 174 | "right":u"\u2192" 175 | } 176 | 177 | mdpv.visualize_policy(self, policy, _draw_state, action_char_dict) 178 | raw_input("Press anything to quit ") 179 | quit() 180 | 181 | 182 | def visualize_agent(self, agent): 183 | from simple_rl.utils import mdp_visualizer as mdpv 184 | from grid_visualizer import _draw_state 185 | mdpv.visualize_agent(self, agent, _draw_state) 186 | raw_input("Press anything to quit ") 187 | 188 | def visualize_value(self): 189 | from simple_rl.utils import mdp_visualizer as mdpv 190 | from grid_visualizer import _draw_state 191 | mdpv.visualize_value(self, _draw_state) 192 | raw_input("Press anything to quit ") 193 | 194 | def _compute_walls(self): 195 | ''' 196 | Args: 197 | width (int) 198 | height (int) 199 | 200 | Returns: 201 | (list): Contains (x,y) pairs that define wall locations. 202 | ''' 203 | walls = [] 204 | 205 | half_width = math.ceil(self.width / 2.0) 206 | half_height = math.ceil(self.height / 2.0) 207 | 208 | for i in range(1, self.width + 1): 209 | if i == (self.width + 1) / 3 or i == math.ceil(2 * (self.width + 1) / 3.0): 210 | continue 211 | walls.append((i, half_height)) 212 | 213 | for j in range(1, self.height + 1): 214 | if j == (self.height + 1) / 3 or j == math.ceil(2 * (self.height + 1) / 3.0): 215 | continue 216 | walls.append((half_width, j)) 217 | 218 | return walls 219 | 220 | 221 | 222 | def _error_check(state, action): 223 | ''' 224 | Args: 225 | state (State) 226 | action (str) 227 | 228 | Summary: 229 | Checks to make sure the received state and action are of the right type. 230 | ''' 231 | 232 | if action not in ColorMDP.ACTIONS: 233 | print("(simple_rl) colorError: the action provided (" + str(action) + ") was invalid in state: " + str(state) + ".") 234 | quit() 235 | 236 | if not isinstance(state, ColorState): 237 | print("(simple_rl) colorError: the given state (" + str(state) + ") was not of the correct class.") 238 | quit() 239 | 240 | 241 | def make_grid_world_from_file(file_name, randomize=False, num_goals=1, name=None, goal_num=None, slip_prob=0.0): 242 | ''' 243 | Args: 244 | file_name (str) 245 | randomize (bool): If true, chooses a random agent location and goal location. 246 | num_goals (int) 247 | name (str) 248 | 249 | Returns: 250 | (ColorMDP) 251 | 252 | Summary: 253 | Builds a ColorMDP from a file: 254 | 'w' --> wall 255 | 'a' --> agent 256 | 'g' --> goal 257 | '-' --> empty 258 | ''' 259 | 260 | if name is None: 261 | name = file_name.split(".")[0] 262 | 263 | grid_path = os.path.dirname(os.path.realpath(__file__)) 264 | wall_file = open(os.path.join(grid_path, "txt_grids", file_name)) 265 | wall_lines = wall_file.readlines() 266 | 267 | # Get walls, agent, goal loc. 
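    # Parsing sketch on a hypothetical 3x3 grid file:
    #     g-w
    #     ---
    #     a--
    # yields the agent at (1, 1), a goal at (1, 3) and a wall at (3, 3): column j maps to
    # x = j + 1 and row i maps to y = num_rows - i, so the bottom row of the file is y = 1.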
268 | num_rows = len(wall_lines) 269 | num_cols = len(wall_lines[0].strip()) 270 | empty_cells = [] 271 | agent_x, agent_y = 1, 1 272 | walls = [] 273 | goal_locs = [] 274 | for i, line in enumerate(wall_lines): 275 | line = line.strip() 276 | for j, ch in enumerate(line): 277 | if ch == "w": 278 | walls.append((j + 1, num_rows - i)) 279 | elif ch == "g": 280 | goal_locs.append((j + 1, num_rows - i)) 281 | elif ch == "a": 282 | agent_x, agent_y = j + 1, num_rows - i 283 | elif ch == "-": 284 | empty_cells.append((j + 1, num_rows - i)) 285 | 286 | if goal_num is not None: 287 | goal_locs = [goal_locs[goal_num % len(goal_locs)]] 288 | 289 | if randomize: 290 | agent_x, agent_y = random.choice(empty_cells) 291 | if len(goal_locs) == 0: 292 | # Sample @num_goals random goal locations. 293 | goal_locs = random.sample(empty_cells, num_goals) 294 | else: 295 | goal_locs = random.sample(goal_locs, num_goals) 296 | 297 | if len(goal_locs) == 0: 298 | goal_locs = [(num_cols, num_rows)] 299 | 300 | return ColorMDP(width=num_cols, height=num_rows, init_loc=(agent_x, agent_y), goal_locs=goal_locs, walls=walls, name=name, slip_prob=slip_prob) 301 | 302 | def reset(self): 303 | self.cur_state = copy.deepcopy(self.init_state) 304 | 305 | 306 | 307 | def main(): 308 | grid_world = ColorMDP(5, 10, (1, 1), (6, 7)) 309 | 310 | grid_world.visualize() 311 | 312 | if __name__ == "__main__": 313 | main() 314 | -------------------------------------------------------------------------------- /utils/ColorStateClass.py: -------------------------------------------------------------------------------- 1 | ''' GridWorldStateClass.py: Contains the GridWorldState class. ''' 2 | 3 | # Other imports. 4 | from simple_rl.mdp.StateClass import State 5 | 6 | class ColorState(State): 7 | ''' Class for Grid World States ''' 8 | 9 | def __init__(self, x, y, color): 10 | self.color = color 11 | State.__init__(self, data=[x, y, color]) 12 | self.x = round(x, 3) 13 | self.y = round(y, 3) 14 | 15 | def __hash__(self): 16 | return hash(tuple(self.data)) 17 | 18 | def __str__(self): 19 | return "s: (" + str(self.x) + "," + str(self.y) + " c: " + str(self.color) + ")" 20 | 21 | def __eq__(self, other): 22 | return isinstance(other, ColorState) and self.x == other.x and self.y == other.y and self.color == other.color 23 | -------------------------------------------------------------------------------- /utils/StochasticSAPolicyClass.py: -------------------------------------------------------------------------------- 1 | # Python imports 2 | import numpy as np 3 | from collections import defaultdict 4 | 5 | # Other imports. 6 | from simple_rl.planning import ValueIteration 7 | 8 | class StochasticSAPolicy(object): 9 | 10 | def __init__(self, state_abstr, mdp): 11 | self.state_abstr = state_abstr 12 | self.mdp = mdp 13 | self.vi = ValueIteration(mdp) 14 | self.vi.run_vi() 15 | 16 | def policy(self, state): 17 | ''' 18 | Args: 19 | (simple_rl.State) 20 | 21 | Returns: 22 | (str): An action 23 | 24 | Summary: 25 | Chooses an action among the optimal actions in the cluster. 
That is, roughly: 26 | 27 | \pi(a \mid s_a) \sim Pr_{s_g \in s_a} (a = a^*(s_a)) 28 | ''' 29 | 30 | abstr_state = self.state_abstr.phi(state) 31 | ground_states = self.state_abstr.get_ground_states_in_abs_state(abstr_state) 32 | 33 | action_distr = defaultdict(float) 34 | for s in ground_states: 35 | a = self.vi.policy(s) 36 | action_distr[a] += 1.0 / len(ground_states) 37 | 38 | sampled_distr = np.random.multinomial(1, action_distr.values()).tolist() 39 | indices = [i for i, x in enumerate(sampled_distr) if x > 0] 40 | 41 | return action_distr.keys()[indices[0]] 42 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/david-abel/rl_abstraction/cfe628abbee2614ff873713dceb466b293aa7329/utils/__init__.py -------------------------------------------------------------------------------- /utils/hierarch_rooms.txt: -------------------------------------------------------------------------------- 1 | gw-------w-w----------g 2 | ---w---w---w-w--w------ 3 | -w---w---w---w--w-w---- 4 | ---w---w---w-w--w------ 5 | -w---w---w-w-w--w-w---- 6 | ---w---w---w-w--w-w---- 7 | gw---w---w-w----w-----g 8 | wwww-wwwwwwwwwwww-wwwww 9 | ----------gw----------- 10 | ----w---w--wwwwwwwwwww- 11 | --w-w-www--w---------w- 12 | --w-w-www-------g----w- 13 | --w-w-www--w---------w- 14 | ----w------w-wwwww-www- 15 | a----------w----------g -------------------------------------------------------------------------------- /utils/make_mdp.py: -------------------------------------------------------------------------------- 1 | ''' 2 | make_mdp.py 3 | 4 | Utility for making MDP instances or distributions. 5 | ''' 6 | 7 | # Python imports. 8 | import itertools 9 | import random 10 | from collections import defaultdict 11 | 12 | # Other imports. 13 | from simple_rl.tasks import ChainMDP, GridWorldMDP, TaxiOOMDP, RandomMDP, FourRoomMDP, HanoiMDP, TrenchOOMDP 14 | from simple_rl.tasks.grid_world.GridWorldMDPClass import make_grid_world_from_file 15 | from simple_rl.mdp import MDPDistribution 16 | from ColorMDPClass import ColorMDP 17 | 18 | def make_mdp(mdp_class="grid", grid_dim=7): 19 | ''' 20 | Returns: 21 | (MDP) 22 | ''' 23 | # Grid/Hallway stuff. 24 | width, height = grid_dim, grid_dim 25 | upworld_goal_locs = [(i, width) for i in range(1, height+1)] 26 | 27 | four_room_goal_locs = [(width, height)] #, (width, 1), (1, height)] # (1, height - 2), (width - 2, height - 2), (width - 1, height - 1), (width - 2, 1)] 28 | four_room_goal_loc = four_room_goal_locs[0] 29 | 30 | # Taxi stuff. 31 | agent = {"x":1, "y":1, "has_passenger":0} 32 | passengers = [{"x":grid_dim / 2, "y":grid_dim / 2, "dest_x":grid_dim-2, "dest_y":2, "in_taxi":0}] 33 | walls = [] 34 | 35 | # Trench stuff 36 | tr_agent = {"x": 1, "y": 1, "dx": 1, "dy": 0, "dest_x": grid_dim, "dest_y": grid_dim, "has_block": 0} 37 | blocks = [{"x": grid_dim, "y": 1}] 38 | lavas = [{"x": x, "y": y} for x, y in map(lambda z: (z + 1, (grid_dim + 1) / 2), range(grid_dim))] 39 | 40 | # Do grids separately to avoid making error-prone domains. 
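    # Minimal usage sketch (hypothetical parameter values; the experiment scripts call both
    # helpers this way):
    #   mdp = make_mdp(mdp_class="four_room", grid_dim=9)
    #   mdp_distr = make_mdp_distr(mdp_class="four_room", grid_dim=9)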
41 | if mdp_class == "four_room": 42 | mdp = FourRoomMDP(width=width, height=height, goal_locs=[four_room_goal_loc]) 43 | else: 44 | mdp = {"upworld":GridWorldMDP(width=width, height=height, init_loc=(1, 1), goal_locs=upworld_goal_locs), 45 | "chain":ChainMDP(num_states=grid_dim), 46 | "random":RandomMDP(num_states=50, num_rand_trans=2), 47 | "hanoi":HanoiMDP(num_pegs=grid_dim, num_discs=3), 48 | "taxi":TaxiOOMDP(width=grid_dim, height=grid_dim, agent=agent, walls=walls, passengers=passengers), 49 | "trench":TrenchOOMDP(width=grid_dim, height=3, agent=tr_agent, blocks=blocks, lavas=lavas)}[mdp_class] 50 | 51 | return mdp 52 | 53 | def make_mdp_distr(mdp_class="grid", grid_dim=9, horizon=0, step_cost=0, gamma=0.99): 54 | ''' 55 | Args: 56 | mdp_class (str): one of {"grid", "random"} 57 | horizon (int) 58 | step_cost (float) 59 | gamma (float) 60 | 61 | Returns: 62 | (MDPDistribution) 63 | ''' 64 | mdp_dist_dict = {} 65 | height, width = grid_dim, grid_dim 66 | 67 | # Define goal locations. 68 | 69 | # Corridor. 70 | corr_width = 20 71 | corr_goal_magnitude = 1 #random.randint(1, 5) 72 | corr_goal_cols = [i for i in xrange(1, corr_goal_magnitude + 1)] + [j for j in xrange(corr_width-corr_goal_magnitude + 1, corr_width + 1)] 73 | corr_goal_locs = list(itertools.product(corr_goal_cols, [1])) 74 | 75 | # Grid World 76 | tl_grid_world_rows, tl_grid_world_cols = [i for i in xrange(width - 4, width)], [j for j in xrange(height - 4, height)] 77 | tl_grid_goal_locs = list(itertools.product(tl_grid_world_rows, tl_grid_world_cols)) 78 | tr_grid_world_rows, tr_grid_world_cols = [i for i in xrange(1, 4)], [j for j in xrange(height - 4, height)] 79 | tr_grid_goal_locs = list(itertools.product(tr_grid_world_rows, tr_grid_world_cols)) 80 | grid_goal_locs = tl_grid_goal_locs + tr_grid_goal_locs 81 | 82 | # Hallway. 83 | upworld_goal_locs = [(i, height) for i in xrange(1, 30)] 84 | 85 | # Four room. 86 | four_room_goal_locs = [(width, height), (width, 1), (1, height), (1, height - 2), (width - 2, height - 2), (width - 2, 1)] 87 | 88 | print four_room_goal_locs 89 | 90 | tight_four_room_goal_locs = [(width, height), (width, height-1), (width-1, height), (width, height - 2), (width - 2, height), (width-1, height-1)] 91 | 92 | # Taxi. 93 | agent = {"x":1, "y":1, "has_passenger":0} 94 | walls = [] 95 | 96 | goal_loc_dict = {"four_room":four_room_goal_locs, 97 | "color":four_room_goal_locs, 98 | "upworld":upworld_goal_locs, 99 | "grid":grid_goal_locs, 100 | "corridor":corr_goal_locs, 101 | "tight_four_room":tight_four_room_goal_locs, 102 | } 103 | 104 | # MDP Probability. 
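    # Worked example: goal_loc_dict["four_room"] holds 6 goal locations, so for that class
    # num_mdps = 6 below and each MDP in the distribution gets mdp_prob = 1/6; "octo" is
    # special-cased to 12 tasks (one per goal cell in octogrid.txt), each with prob 1/12.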
105 | num_mdps = 10 if mdp_class not in goal_loc_dict.keys() else len(goal_loc_dict[mdp_class]) 106 | if mdp_class == "octo": 107 | num_mdps = 12 108 | mdp_prob = 1.0 / num_mdps 109 | 110 | for i in xrange(num_mdps): 111 | 112 | new_mdp = {"hrooms":make_grid_world_from_file("hierarch_rooms.txt", num_goals=7, randomize=False), 113 | "octo":make_grid_world_from_file("octogrid.txt", num_goals=12, randomize=False, goal_num=i), 114 | "upworld":GridWorldMDP(width=30, height=height, rand_init=False, goal_locs=goal_loc_dict["upworld"], name="upworld", is_goal_terminal=True), 115 | "corridor":GridWorldMDP(width=20, height=1, init_loc=(10, 1), goal_locs=[goal_loc_dict["corridor"][i % len(goal_loc_dict["corridor"])]], is_goal_terminal=True, name="corridor"), 116 | "grid":GridWorldMDP(width=width, height=height, rand_init=True, goal_locs=[goal_loc_dict["grid"][i % len(goal_loc_dict["grid"])]], is_goal_terminal=True), 117 | "four_room":FourRoomMDP(width=width, height=height, goal_locs=[goal_loc_dict["four_room"][i % len(goal_loc_dict["four_room"])]], is_goal_terminal=True), 118 | "color":ColorMDP(width=width, height=height, num_colors=4, goal_locs=[goal_loc_dict["four_room"][i % len(goal_loc_dict["four_room"])]], is_goal_terminal=True), 119 | "tight_four_room":FourRoomMDP(width=width, height=height, goal_locs=[goal_loc_dict["tight_four_room"][i % len(goal_loc_dict["tight_four_room"])]], is_goal_terminal=True, name="tight_four_room")}[mdp_class] 120 | 121 | new_mdp.set_step_cost(step_cost) 122 | new_mdp.set_gamma(gamma) 123 | 124 | mdp_dist_dict[new_mdp] = mdp_prob 125 | 126 | return MDPDistribution(mdp_dist_dict, horizon=horizon) 127 | -------------------------------------------------------------------------------- /utils/octogrid.txt: -------------------------------------------------------------------------------- 1 | wwwwgwgwgwwww 2 | wwww-w-w-wwww 3 | wwww-w-w-wwww 4 | wwww-w-w-wwww 5 | g-----------g 6 | wwww-----wwww 7 | g-----a-----g 8 | wwww-----wwww 9 | g-----------g 10 | wwww-w-w-wwww 11 | wwww-w-w-wwww 12 | wwww-w-w-wwww 13 | wwwwgwgwgwwww -------------------------------------------------------------------------------- /utils/pblocks_grid.txt: -------------------------------------------------------------------------------- 1 | --w-----w---w----- 2 | --------w--------- 3 | --w-----w---w----- 4 | --w-----w---w----- 5 | wwwww-wwwwwwwww-ww 6 | ---w----w----w---- 7 | ---w---------w---- 8 | --------w--------- 9 | wwwwwwwww--------- 10 | w-------wwwwwww-ww 11 | --w-----w---w----- 12 | --------w--------- 13 | --w---------w----- 14 | --w-----w---w----- 15 | wwwww-wwwwwwwww-ww 16 | ---w-----w---w---- 17 | ---w-----w---w---- 18 | ---------w-------- -------------------------------------------------------------------------------- /utils/planning_experiments.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Python imports. 4 | import random 5 | from collections import defaultdict 6 | import os 7 | import time 8 | 9 | # Other imports. 
10 | import make_mdp 11 | from simple_rl.agents import RandomAgent, RMaxAgent, QLearningAgent, FixedPolicyAgent 12 | from simple_rl.run_experiments import run_agents_lifelong 13 | from simple_rl.tasks import FourRoomMDP 14 | from simple_rl.planning import ValueIteration 15 | from AbstractValueIterationClass import AbstractValueIteration 16 | from state_abs.StateAbstractionClass import StateAbstraction 17 | from action_abs.ActionAbstractionClass import ActionAbstraction 18 | from state_abs import indicator_funcs as ind_funcs 19 | from abstraction_experiments import get_sa, get_directed_option_sa_pair 20 | 21 | def clear_files(dir_name): 22 | ''' 23 | Args: 24 | dir_name (str) 25 | 26 | Summary: 27 | Removes all csv files in @dir_name. 28 | ''' 29 | for extension in ["iters", "times"]: 30 | dir_w_extension = os.path.join(dir_name, extension) #, mdp_type) + ".csv" 31 | if not os.path.exists(dir_w_extension): 32 | os.makedirs(dir_w_extension) 33 | 34 | for mdp_type in ["vi", "vi-$\phi_{Q_d^*}$"]: 35 | if os.path.exists(os.path.join(dir_w_extension, mdp_type) + ".csv"): 36 | os.remove(os.path.join(dir_w_extension, mdp_type) + ".csv") 37 | 38 | def write_datum(file_name, datum): 39 | ''' 40 | Args: 41 | file_name (str) 42 | datum (object) 43 | ''' 44 | out_file = open(file_name, "a+") 45 | out_file.write(str(datum) + ",") 46 | out_file.close() 47 | 48 | def main(): 49 | 50 | # Grab experiment params. 51 | mdp_class = "upworld" 52 | gamma = 0.95 53 | vanilla_file = "vi.csv" 54 | sa_file = "vi-$\phi_{Q_d^*}.csv" 55 | file_prefix = "results/planning-" + mdp_class + "/" 56 | clear_files(dir_name=file_prefix) 57 | 58 | for grid_dim in xrange(3,20): 59 | # ====================== 60 | # == Make Environment == 61 | # ====================== 62 | environment = make_mdp.make_mdp(mdp_class=mdp_class, grid_dim=grid_dim) 63 | environment.set_gamma(gamma) 64 | 65 | # ======================= 66 | # == Make Abstractions == 67 | # ======================= 68 | sa_qds = get_sa(environment, indic_func=ind_funcs._q_disc_approx_indicator, epsilon=0.01) 69 | 70 | # ============ 71 | # == Run VI == 72 | # ============ 73 | vanilla_vi = ValueIteration(environment, delta=0.0001, sample_rate=5) 74 | sa_vi = AbstractValueIteration(ground_mdp=environment, state_abstr=sa_qds) 75 | 76 | print "Running VIs." 77 | start_time = time.clock() 78 | vanilla_iters, vanilla_val = vanilla_vi.run_vi() 79 | vanilla_time = round(time.clock() - start_time, 2) 80 | 81 | start_time = time.clock() 82 | sa_iters, sa_val = sa_vi.run_vi() 83 | sa_time = round(time.clock() - start_time, 2) 84 | 85 | print "vanilla", vanilla_iters, vanilla_val, vanilla_time 86 | print "sa:", sa_iters, sa_val, sa_time 87 | 88 | write_datum(file_prefix + "iters/" + vanilla_file, vanilla_iters) 89 | write_datum(file_prefix + "iters/" + sa_file, sa_iters) 90 | 91 | write_datum(file_prefix + "times/" + vanilla_file, vanilla_time) 92 | write_datum(file_prefix + "times/" + sa_file, sa_time) 93 | 94 | if __name__ == "__main__": 95 | main() 96 | -------------------------------------------------------------------------------- /utils/run_abstr_combo_experiments.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Python imports. 4 | import subprocess 5 | 6 | # Other imports. 7 | import abstraction_experiments 8 | from simple_rl.agents import RMaxAgent, QLearningAgent 9 | 10 | # Global params. 
11 | track_options = False 12 | agent_class = "ql" # one of 'ql' or 'rmax' 13 | episodic = True 14 | 15 | def spawn_subproc(task, samples, steps, episodes=1, grid_dim=11, max_options=20): 16 | ''' 17 | Args: 18 | task (str) 19 | samples (int) 20 | steps (int) 21 | steps (int) 22 | grid_dim (int) 23 | 24 | Summary: 25 | Spawns a child subprocess to run the experiment. 26 | ''' 27 | cmd = ['./abstraction_experiments.py', \ 28 | '-task=' + str(task), \ 29 | '-samples=' + str(samples), \ 30 | '-episodes=' + str(episodes), 31 | '-steps=' + str(steps), \ 32 | '-grid_dim=' + str(grid_dim), \ 33 | '-agent=' + agent_class, \ 34 | '-max_options=' + str(max_options), 35 | '-exp_type=combo'] 36 | if track_options: 37 | cmd += ['-track_options=True'] 38 | 39 | subprocess.Popen(cmd) 40 | 41 | def main(): 42 | 43 | episodes = 1 44 | step_multipler = 2 45 | if episodic: 46 | episodes = 100 47 | step_multipler = 1 48 | 49 | # Hall. 50 | # spawn_subproc(task="hall", samples=500, episodes=episodes, steps=25 * step_multipler, grid_dim=15, max_options=32) 51 | 52 | # spawn_subproc(task="grid", samples=5, episodes=episodes, steps=50 * step_multipler, grid_dim=7, max_options=300) 53 | 54 | # Octo Grid. 55 | # spawn_subproc(task="octo", samples=500, episodes=episodes, steps=50 * step_multipler, max_options=50) 56 | 57 | spawn_subproc(task="rock_climb", samples=100, episodes=episodes, steps=50 * step_multipler, max_options=50) 58 | 59 | # Four rooms. 60 | # spawn_subproc(task="four_room", samples=300, episodes=episodes, steps=50 * step_multipler, grid_dim=15) 61 | 62 | # Grid random init. 63 | # spawn_subproc(task="whirlpool", samples=150, episodes=episodes, steps=50 * step_multipler, grid_dim=21, max_options=200) 64 | 65 | # # Ice Rink. 66 | # spawn_subproc(task="icerink", samples=300, episodes=episodes, steps=100 * step_multipler, grid_dim=10, max_options=40) 67 | 68 | 69 | if __name__ == "__main__": 70 | main() -------------------------------------------------------------------------------- /utils/run_dir_opt_core_experiments.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Python imports. 4 | import subprocess 5 | 6 | # Other imports. 7 | import abstraction_experiments 8 | from simple_rl.agents import RMaxAgent, QLearningAgent 9 | 10 | # Global params. 11 | track_options = False 12 | agent_class = "ql" # one of 'ql' or 'rmax' 13 | 14 | def spawn_subproc(task, samples, steps, grid_dim=11, max_options=20): 15 | ''' 16 | Args: 17 | task (str) 18 | samples (int) 19 | steps (int) 20 | grid_dim (int) 21 | 22 | Summary: 23 | Spawns a child subprocess to run the experiment. 
24 | ''' 25 | cmd = ['./abstraction_experiments.py', \ 26 | '-task=' + str(task), \ 27 | '-samples=' + str(samples), \ 28 | '-steps=' + str(steps), '-grid_dim=' + str(grid_dim), \ 29 | '-agent=' + agent_class, \ 30 | '-max_options=' + str(max_options)] 31 | if track_options: 32 | cmd += ['-track_options=True'] 33 | 34 | subprocess.Popen(cmd) 35 | 36 | def main(): 37 | 38 | # Octo Grid 39 | spawn_subproc(task="octo", samples=1, steps=200) 40 | 41 | # Grid 42 | # spawn_subproc(task="grid", samples=500, steps=50, grid_dim=10) 43 | 44 | # Four rooms 45 | # spawn_subproc(task="four_room", samples=2000, steps=200, grid_dim=15) 46 | 47 | # Taxi 48 | # spawn_subproc(task="taxi", samples=1, steps=5, max_options=101) 49 | 50 | # Pblocks grid 51 | # spawn_subproc(task="pblocks_grid", samples=500, steps=100000, max_options=50) 52 | 53 | if __name__ == "__main__": 54 | main() -------------------------------------------------------------------------------- /utils/visualize_abstractions.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Python imports. 4 | import sys 5 | import random 6 | import argparse 7 | from collections import defaultdict 8 | import os, inspect 9 | 10 | # Pygame setup. 11 | try: 12 | import pygame 13 | from pygame.locals import * 14 | 15 | except ImportError: 16 | print "Error: pygame not installed (needed for visuals)." 17 | quit() 18 | 19 | # Other imports. 20 | from simple_rl.utils import mdp_visualizer 21 | from simple_rl.agents import RandomAgent 22 | from simple_rl.mdp import MDPDistribution 23 | from simple_rl.abstraction.state_abs import indicator_funcs as ind_funcs 24 | from simple_rl.abstraction.state_abs.sa_helpers import visualize_state_abstr_grid 25 | currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) 26 | parentdir = os.path.dirname(currentdir) 27 | sys.path.insert(0, parentdir) 28 | from abstraction_experiments import * 29 | from hierarch import hierarchy_helpers 30 | import make_mdp 31 | 32 | colors = [[240, 163, 255], [153, 63, 0],\ 33 | [113, 113, 198],\ 34 | [85, 85, 85], [198, 113, 113],\ 35 | [142, 56, 142], [125, 158, 192],\ 36 | [184, 221, 255],[197, 193, 170],[142, 142, 56],\ 37 | [56, 142, 142], [113, 198, 113],[245, 228, 199]] 38 | 39 | 40 | def visualize_options_grid(grid_mdp, state_space, action_abstr, scr_width=720, scr_height=720): 41 | ''' 42 | Args: 43 | grid_mdp (GridWorldMDP) 44 | state_space (list of State) 45 | action_abstr (ActionAbstraction) 46 | ''' 47 | pygame.init() 48 | title_font = pygame.font.SysFont("CMU Serif", 32) 49 | small_font = pygame.font.SysFont("CMU Serif", 22) 50 | 51 | if len(action_abstr.get_actions()) == 0: 52 | print "Options Error: 0 options found. Can't visualize." 53 | sys.exit(0) 54 | 55 | if isinstance(grid_mdp, MDPDistribution): 56 | goal_locs = set([]) 57 | for m in grid_mdp.get_all_mdps(): 58 | for g in m.get_goal_locs(): 59 | goal_locs.add(g) 60 | grid_mdp = grid_mdp.sample() 61 | else: 62 | goal_locs = grid_mdp.get_goal_locs() 63 | 64 | # Pygame init. 65 | screen = pygame.display.set_mode((scr_width, scr_height)) 66 | pygame.init() 67 | screen.fill((255, 255, 255)) 68 | pygame.display.update() 69 | mdp_visualizer._draw_title_text(grid_mdp, screen) 70 | option_text_point = scr_width / 2.0 - (14*7), 18*scr_height / 20.0 71 | 72 | # Setup states to compute option init/term funcs. 
73 | state_dict = defaultdict(lambda : defaultdict(None)) 74 | for s in state_space: 75 | state_dict[s.x][s.y] = s 76 | 77 | # Draw inital option. 78 | option_index = 0 79 | opt_str = "Option " + str(option_index + 1) + " of " + str(len(action_abstr.get_actions())) # + ":" + str(next_option) 80 | option_text = title_font.render(opt_str, True, (46, 49, 49)) 81 | screen.blit(option_text, option_text_point) 82 | next_option = action_abstr.get_actions()[option_index] 83 | visualize_option(screen, grid_mdp, state_dict, option=next_option) 84 | 85 | # Initiation rect and text. 86 | option_text = small_font.render("Init: ", True, (46, 49, 49)) 87 | screen.blit(option_text, (40, option_text_point[1])) 88 | pygame.draw.rect(screen, colors[-1], (90, option_text_point[1]) + (24, 24)) 89 | 90 | # Terminal rect and text. 91 | option_text = small_font.render("Term: ", True, (46, 49, 49)) 92 | screen.blit(option_text, (scr_width - 150, option_text_point[1])) 93 | pygame.draw.rect(screen, colors[1], (scr_width - 80, option_text_point[1]) + (24, 24)) 94 | pygame.display.flip() 95 | 96 | # Keep updating options every space press. 97 | done = False 98 | while not done: 99 | # Check for key presses. 100 | for event in pygame.event.get(): 101 | if event.type == QUIT or (event.type == KEYDOWN and event.key == K_ESCAPE): 102 | # Quit. 103 | pygame.quit() 104 | sys.exit() 105 | if event.type == KEYDOWN and event.key == K_RIGHT: 106 | # Toggle to the next option. 107 | option_index = (option_index + 1) % len(action_abstr.get_actions()) 108 | elif event.type == KEYDOWN and event.key == K_LEFT: 109 | # Go to the previous option. 110 | option_index = (option_index - 1) % len(action_abstr.get_actions()) 111 | if option_index < 0: 112 | option_index = len(action_abstr.get_actions()) - 1 113 | 114 | next_option = action_abstr.get_actions()[option_index] 115 | visualize_option(screen, grid_mdp, state_dict, option=next_option, goal_locs=goal_locs) 116 | pygame.draw.rect(screen, (255, 255, 255), (130, option_text_point[1]) + (scr_width-290 , 50)) 117 | opt_str = "Option " + str(option_index + 1) + " of " + str(len(action_abstr.get_actions())) # + ":" + str(next_option) 118 | option_text = title_font.render(opt_str, True, (46, 49, 49)) 119 | screen.blit(option_text, option_text_point) 120 | 121 | 122 | def visualize_option(screen, grid_mdp, state_dict, option=None, goal_locs=[]): 123 | ''' 124 | Args: 125 | screen (pygame.Surface) 126 | grid_mdp (GridWorldMDP) 127 | state_dict: 128 | Key: int 129 | Val: dict 130 | Key: int 131 | Val: state 132 | 133 | ''' 134 | pygame.init() 135 | title_font = pygame.font.SysFont("CMU Serif", 32) 136 | small_font = pygame.font.SysFont("CMU Serif", 22) 137 | 138 | 139 | # Action char mapping. 140 | action_char_dict = { 141 | "up":u"\u2191", 142 | "down":u"\u2193", 143 | "left":u"\u2190", 144 | "right":u"\u2192" 145 | } 146 | 147 | # Prep some dimensions/fonts to make drawing easier. 148 | scr_width, scr_height = screen.get_width(), screen.get_height() 149 | width_buffer = scr_width / 10.0 150 | height_buffer = 30 + (scr_height / 10.0) # Add 30 for title. 
151 | cell_width = (scr_width - width_buffer * 2) / grid_mdp.width 152 | cell_height = (scr_height - height_buffer * 2) / grid_mdp.height 153 | font_size = int(min(cell_width, cell_height) / 4.0) 154 | reg_font = pygame.font.SysFont("CMU Serif", font_size) 155 | cc_font = pygame.font.SysFont("Courier", font_size*2 + 2) 156 | 157 | # For each row: 158 | for i in range(grid_mdp.width): 159 | # For each column: 160 | for j in range(grid_mdp.height): 161 | 162 | # Default square per state. 163 | top_left_point = width_buffer + cell_width*i, height_buffer + cell_height*j 164 | r = pygame.draw.rect(screen, (46, 49, 49), top_left_point + (cell_width, cell_height), 3) 165 | 166 | if grid_mdp.is_wall(i+1, grid_mdp.height - j): 167 | # Draw the walls. 168 | top_left_point = width_buffer + cell_width*i + 5, height_buffer + cell_height*j + 5 169 | r = pygame.draw.rect(screen, (94, 99, 99), top_left_point + (cell_width-10, cell_height-10), 0) 170 | 171 | if (i+1,grid_mdp.height - j) in goal_locs: 172 | # Draw goal. 173 | circle_center = int(top_left_point[0] + cell_width/2.0), int(top_left_point[1] + cell_height/2.0) 174 | circler_color = (154, 195, 157) 175 | pygame.draw.circle(screen, circler_color, circle_center, int(min(cell_width, cell_height) / 3.0)) 176 | 177 | # Goal text. 178 | text = reg_font.render("Goal", True, (46, 49, 49)) 179 | offset = int(min(cell_width, cell_height) / 3.0) 180 | goal_text_point = circle_center[0] - font_size, circle_center[1] - font_size/1.5 181 | screen.blit(text, goal_text_point) 182 | 183 | if option is not None: 184 | # Visualize option. 185 | 186 | if i+1 not in state_dict.keys() or grid_mdp.height - j not in state_dict[i+1].keys(): 187 | # At a wall, don't draw Option stuff. 188 | continue 189 | 190 | s = state_dict[i+1][grid_mdp.height - j] 191 | 192 | if option.is_init_true(s): 193 | # Init. 194 | r = pygame.draw.rect(screen, colors[-1], (top_left_point[0] + 5, top_left_point[1] + 5) + (cell_width - 10, cell_height - 10), 0) 195 | elif option.is_term_true(s): 196 | # Term. 197 | r = pygame.draw.rect(screen, colors[1], (top_left_point[0] + 5, top_left_point[1] + 5) + (cell_width - 10, cell_height - 10), 0) 198 | else: 199 | # White out old option. 200 | r = pygame.draw.rect(screen, (255, 255, 255), (top_left_point[0] + 5, top_left_point[1] + 5) + (cell_width - 10, cell_height - 10), 0) 201 | 202 | # Draw option policy. 203 | a = option.policy(s) 204 | if a not in action_char_dict: 205 | text_a = a 206 | else: 207 | text_a = action_char_dict[a] 208 | text_center_point = int(top_left_point[0] + cell_width/2.0 - 10), int(top_left_point[1] + cell_height/4.0) 209 | text_rendered_a = cc_font.render(text_a, True, (46, 49, 49)) 210 | screen.blit(text_rendered_a, text_center_point) 211 | 212 | if (i+1,grid_mdp.height - j) in goal_locs: 213 | # Draw goal. 214 | circle_center = int(top_left_point[0] + cell_width/2.0), int(top_left_point[1] + cell_height/2.0) 215 | circler_color = (154, 195, 157) 216 | pygame.draw.circle(screen, circler_color, circle_center, int(min(cell_width, cell_height) / 3.0)) 217 | 218 | # Goal text. 
219 | text = reg_font.render("Goal", True, (46, 49, 49)) 220 | offset = int(min(cell_width, cell_height) / 3.0) 221 | goal_text_point = circle_center[0] - font_size, circle_center[1] - font_size/1.5 222 | screen.blit(text, goal_text_point) 223 | 224 | pygame.display.flip() 225 | 226 | def parse_args(): 227 | ''' 228 | Summary: 229 | Parse all arguments 230 | ''' 231 | parser = argparse.ArgumentParser() 232 | parser.add_argument("-abs", type = str, default="sa", nargs = '?', help = "Choose the abstraction type (one of {sa, aa}).") 233 | args = parser.parse_args() 234 | 235 | abs_type = args.abs 236 | 237 | return abs_type 238 | 239 | def main(): 240 | 241 | # MDP Setting. 242 | lifelong = True 243 | mdp_class = "four_room" 244 | grid_dim = 11 245 | is_sa = parse_args() 246 | 247 | # Make single/multi task environment. 248 | environment = make_mdp.make_mdp_distr(mdp_class=mdp_class, grid_dim=grid_dim, gamma=0.9) if lifelong else make_mdp.make_mdp(mdp_class=mdp_class, grid_dim=grid_dim) 249 | actions = environment.get_actions() 250 | abs_type = parse_args() 251 | 252 | if abs_type == "sa": 253 | # (DEFAULT) Visualize State Abstractions. 254 | # hand_sa, hand_aa = get_abstractions(environment, ind_funcs._four_rooms, directed=True) 255 | # sa = get_sa(environment, indic_func=ind_funcs._q_disc_approx_indicator, epsilon=0.2) 256 | sa = get_sa(environment, indic_func=ind_funcs._q_eps_approx_indicator, epsilon=0.2) 257 | visualize_state_abstr_grid(environment, sa) 258 | elif abs_type == "aa": 259 | # Visualize Action Abstractions. 260 | visualize_options_grid(environment, directed_sa.get_ground_states(), directed_aa) 261 | elif abs_type == "hand" and mdp_class == "four_room": 262 | hand_sa, hand_aa = get_abstractions(environment, ind_funcs._four_rooms, directed=True) 263 | visualize_options_grid(environment, hand_sa.get_ground_states(), hand_aa) 264 | elif abs_type == "pblocks": 265 | default_sa, pblocks_aa = get_sa(environment, default=True), action_abs.aa_baselines.get_policy_blocks_aa(environment) 266 | visualize_options_grid(environment, default_sa.get_ground_states(), pblocks_aa) 267 | elif abs_type == "michael": 268 | michael_thing_sa = directed_sa #state_abs.sa_helpers.make_multitask_sa(environment) 269 | michael_thing_aa = action_abs.aa_baselines.get_aa_single_act(environment, michael_thing_sa) 270 | visualize_options_grid(environment, michael_thing_sa.get_ground_states(), michael_thing_aa) 271 | elif abs_type == "opt": 272 | opt_action_sa = directed_sa 273 | opt_action_aa = action_abs.aa_baselines.get_aa_opt_only_single_act(environment, opt_action_sa) 274 | visualize_options_grid(environment, opt_action_sa.get_ground_states(), opt_action_aa) 275 | elif abs_type == "pr": 276 | hi_pr_opt_action_aa = action_abs.aa_baselines.get_aa_high_prob_opt_single_act(environment, directed_sa, delta=0.3) 277 | visualize_options_grid(environment, directed_sa.get_ground_states(), hi_pr_opt_action_aa) 278 | elif abs_type == "hierarch": 279 | sa_stack, aa_stack = hierarchy_helpers.make_hierarchy(environment, num_levels=2) 280 | visualize_options_grid(environment, sa_stack.get_ground_states(), aa_stack) 281 | else: 282 | print "Error: abs type not recognized (" + abs_type + "). Options include {aa, sa}." 283 | quit() 284 | 285 | raw_input("Press any key to quit ") 286 | quit() 287 | 288 | if __name__ == "__main__": 289 | main() 290 | --------------------------------------------------------------------------------
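Usage note (a minimal sketch, not part of the repository): the utils/ scripts are written for Python 2 and assume simple_rl and pygame are installed; visualize_abstractions.py also inserts the repository root onto sys.path before importing abstraction_experiments. Under those assumptions, the default state-abstraction visualization can be reached either by running "python visualize_abstractions.py -abs sa" from inside utils/, or programmatically by mirroring the "sa" branch of its main(), as in the hypothetical sketch below (the script name and layout are illustrative; every call it makes appears verbatim in the files above):

    # sketch_visualize_sa.py -- hypothetical standalone sketch, run from utils/.
    # Mirrors the "sa" branch of utils/visualize_abstractions.py, assuming
    # Python 2, simple_rl, and pygame are available.
    import os, sys
    sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))  # repo root

    from simple_rl.abstraction.state_abs import indicator_funcs as ind_funcs
    from simple_rl.abstraction.state_abs.sa_helpers import visualize_state_abstr_grid

    import make_mdp                              # utils/make_mdp.py
    from abstraction_experiments import get_sa   # from the repository root

    # Multi-task four-room distribution, as built in main() when lifelong is True.
    environment = make_mdp.make_mdp_distr(mdp_class="four_room", grid_dim=11, gamma=0.9)

    # Epsilon-approximate Q*-similarity state abstraction, drawn over the grid world.
    sa = get_sa(environment, indic_func=ind_funcs._q_eps_approx_indicator, epsilon=0.2)
    visualize_state_abstr_grid(environment, sa)

The same pattern (build an MDP or MDPDistribution with make_mdp, derive an abstraction with get_sa or get_abstractions, then hand both to a visualizer or planner) is what planning_experiments.py and the run_*_experiments.py drivers above automate over grid sizes and tasks.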