├── SR-LLRL
├── agents
│ └── __init__.py
├── simple_rl
│ ├── pomdp
│ │ ├── __init__.py
│ │ ├── BeliefStateClass.py
│ │ ├── POMDPClass.py
│ │ ├── BeliefMDPClass.py
│ │ └── BeliefUpdaterClass.py
│ ├── utils
│ │ ├── __init__.py
│ │ ├── octogrid.txt
│ │ ├── save.py
│ │ └── make_mdp.py
│ ├── mdp
│ │ ├── oomdp
│ │ │ ├── __init__.py
│ │ │ ├── OOMDPObjectClass.py
│ │ │ ├── OOMDPClass.py
│ │ │ └── OOMDPStateClass.py
│ │ ├── markov_game
│ │ │ ├── __init__.py
│ │ │ └── MarkovGameMDPClass.py
│ │ ├── __init__.py
│ │ ├── StateClass.py
│ │ ├── MDPClass.py
│ │ └── MDPDistributionClass.py
│ ├── tasks
│ │ ├── bandit
│ │ │ ├── __init__.py
│ │ │ └── BanditMDPClass.py
│ │ ├── chain
│ │ │ ├── __init__.py
│ │ │ ├── ChainStateClass.py
│ │ │ └── ChainMDPClass.py
│ │ ├── cleanup
│ │ │ ├── __init__.py
│ │ │ ├── cleanup_door.py
│ │ │ ├── cleanup_task.py
│ │ │ ├── cleanup_block.py
│ │ │ ├── cleanup_room.py
│ │ │ └── cleanup_state.py
│ │ ├── gather
│ │ │ ├── __init__.py
│ │ │ └── GatherStateClass.py
│ │ ├── gym
│ │ │ ├── __init__.py
│ │ │ ├── GymStateClass.py
│ │ │ └── GymMDPClass.py
│ │ ├── hallway
│ │ │ ├── __init__.py
│ │ │ └── Grid1DClass.py
│ │ ├── hanoi
│ │ │ ├── __init__.py
│ │ │ └── HanoiMDPClass.py
│ │ ├── maze_1d
│ │ │ ├── __init__.py
│ │ │ ├── Maze1DStateClass.py
│ │ │ └── Maze1DPOMDPClass.py
│ │ ├── puddle
│ │ │ ├── __init__.py
│ │ │ └── PuddleMDPClass.py
│ │ ├── random
│ │ │ ├── __init__.py
│ │ │ ├── RandomStateClass.py
│ │ │ └── RandomMDPClass.py
│ │ ├── taxi
│ │ │ ├── __init__.py
│ │ │ ├── TaxiStateClass.py
│ │ │ ├── taxi_helpers.py
│ │ │ └── taxi_visualizer.py
│ │ ├── trench
│ │ │ ├── __init__.py
│ │ │ └── TrenchOOMDPState.py
│ │ ├── combo_lock
│ │ │ ├── __init__.py
│ │ │ └── ComboLockMDPClass.py
│ │ ├── four_room
│ │ │ ├── __init__.py
│ │ │ └── FourRoomMDPClass.py
│ │ ├── grid_game
│ │ │ ├── __init__.py
│ │ │ ├── GridGameStateClass.py
│ │ │ └── GridGameMDPClass.py
│ │ ├── grid_world
│ │ │ ├── __init__.py
│ │ │ ├── GridWorldStateClass.py
│ │ │ └── grid_visualizer.py
│ │ ├── navigation
│ │ │ └── __init__.py
│ │ ├── prisoners
│ │ │ ├── __init__.py
│ │ │ └── PrisonersDilemmaMDPClass.py
│ │ ├── dev_rock_sample
│ │ │ ├── __init__.py
│ │ │ └── RockSampleMDPClass.py
│ │ ├── rock_paper_scissors
│ │ │ ├── __init__.py
│ │ │ └── RockPaperScissorsMDPClass.py
│ │ └── __init__.py
│ ├── _version.py
│ ├── agents
│ │ ├── bandits
│ │ │ ├── __init__.py
│ │ │ └── LinUCBAgentClass.py
│ │ ├── func_approx
│ │ │ ├── __init__.py
│ │ │ ├── tile_coding.py
│ │ │ ├── LinearQAgentClass.py
│ │ │ └── GradientBoostingAgentClass.py
│ │ ├── RandomAgentClass.py
│ │ ├── BeliefAgentClass.py
│ │ ├── FixedPolicyAgentClass.py
│ │ ├── __init__.py
│ │ ├── AgentClass.py
│ │ └── DoubleQAgentClass.py
│ ├── abstraction
│ │ ├── abstr_mdp
│ │ │ ├── __init__.py
│ │ │ ├── MDPHierarchy.py
│ │ │ ├── RewardFuncClass.py
│ │ │ ├── TransitionFuncClass.py
│ │ │ └── abstr_mdp_funcs.py
│ │ ├── action_abs
│ │ │ ├── __init__.py
│ │ │ ├── InListPredicateClass.py
│ │ │ ├── PredicateClass.py
│ │ │ ├── PolicyClass.py
│ │ │ ├── PolicyFromDictClass.py
│ │ │ ├── aa_helpers.py
│ │ │ ├── OptionClass.py
│ │ │ └── ActionAbstractionClass.py
│ │ ├── state_abs
│ │ │ ├── __init__.py
│ │ │ ├── ProbStateAbstractionClass.py
│ │ │ ├── indicator_funcs.py
│ │ │ └── StateAbstractionClass.py
│ │ ├── hierarchical_planning.py
│ │ ├── __init__.py
│ │ ├── AbstractionWrapperClass.py
│ │ └── AbstractValueIterationClass.py
│ ├── experiments
│ │ ├── __init__.py
│ │ └── ExperimentParametersClass.py
│ ├── planning
│ │ ├── __init__.py
│ │ ├── PlannerClass.py
│ │ ├── MCTSClass.py
│ │ └── BeliefSparseSamplingClass.py
│ └── __init__.py
├── IEEE_SMC_2021_Plots
│ ├── raw_plots
│ │ ├── Lava.pdf
│ │ ├── Maze.pdf
│ │ ├── FourRoom.pdf
│ │ ├── Q-Lava-Task.pdf
│ │ ├── Q-Maze-Task.pdf
│ │ ├── Q-Lava-Episode.pdf
│ │ ├── Q-Maze-Episode.pdf
│ │ ├── Q-FourRoom-Task.pdf
│ │ ├── DelayedQ-Lava-Task.pdf
│ │ ├── DelayedQ-Maze-Task.pdf
│ │ ├── Q-FourRoom-Episode.pdf
│ │ ├── DelayedQ-Lava-Episode.pdf
│ │ ├── DelayedQ-Maze-Episode.pdf
│ │ ├── DelayedQ-FourRoom-Task.pdf
│ │ └── DelayedQ-FourRoom-Episode.pdf
│ └── figures
│ │ ├── Result_1.png
│ │ ├── Result_2.png
│ │ └── Environments.png
├── .vscode
│ ├── settings.json
│ └── launch.json
├── result_show_task.py
└── result_show_episode.py
├── .gitignore
└── README.md
/SR-LLRL/agents/__init__.py: -------------------------------------------------------------------------------- 1 | --------------------------------------------------------------------------------
/SR-LLRL/simple_rl/pomdp/__init__.py: -------------------------------------------------------------------------------- 1 | --------------------------------------------------------------------------------
/SR-LLRL/simple_rl/utils/__init__.py: -------------------------------------------------------------------------------- 1 | --------------------------------------------------------------------------------
/SR-LLRL/simple_rl/mdp/oomdp/__init__.py: -------------------------------------------------------------------------------- 1 | --------------------------------------------------------------------------------
/SR-LLRL/simple_rl/tasks/bandit/__init__.py: -------------------------------------------------------------------------------- 1 | --------------------------------------------------------------------------------
/SR-LLRL/simple_rl/tasks/chain/__init__.py: -------------------------------------------------------------------------------- 1 | --------------------------------------------------------------------------------
/SR-LLRL/simple_rl/tasks/cleanup/__init__.py: -------------------------------------------------------------------------------- 1 | --------------------------------------------------------------------------------
/SR-LLRL/simple_rl/tasks/gather/__init__.py: -------------------------------------------------------------------------------- 1 | --------------------------------------------------------------------------------
/SR-LLRL/simple_rl/tasks/gym/__init__.py: -------------------------------------------------------------------------------- 1 | --------------------------------------------------------------------------------
/SR-LLRL/simple_rl/tasks/hallway/__init__.py: -------------------------------------------------------------------------------- 1 | --------------------------------------------------------------------------------
/SR-LLRL/simple_rl/tasks/hanoi/__init__.py: -------------------------------------------------------------------------------- 1 | --------------------------------------------------------------------------------
/SR-LLRL/simple_rl/tasks/maze_1d/__init__.py: -------------------------------------------------------------------------------- 1 | --------------------------------------------------------------------------------
/SR-LLRL/simple_rl/tasks/puddle/__init__.py: -------------------------------------------------------------------------------- 1 | --------------------------------------------------------------------------------
/SR-LLRL/simple_rl/tasks/random/__init__.py: -------------------------------------------------------------------------------- 1 |
-------------------------------------------------------------------------------- /SR-LLRL/simple_rl/tasks/taxi/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/tasks/trench/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/_version.py: -------------------------------------------------------------------------------- 1 | __version__ = 0.8 -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/agents/bandits/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/agents/func_approx/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/mdp/markov_game/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/tasks/combo_lock/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/tasks/four_room/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/tasks/grid_game/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/tasks/grid_world/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/tasks/hallway/Grid1DClass.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/tasks/navigation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/tasks/prisoners/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/abstraction/abstr_mdp/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/abstraction/action_abs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/abstraction/state_abs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 
/SR-LLRL/simple_rl/tasks/dev_rock_sample/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/tasks/rock_paper_scissors/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/experiments/__init__.py: -------------------------------------------------------------------------------- 1 | from simple_rl.experiments.ExperimentClass import Experiment 2 | -------------------------------------------------------------------------------- /SR-LLRL/IEEE_SMC_2021_Plots/raw_plots/Lava.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kchu/LifelongRL/HEAD/SR-LLRL/IEEE_SMC_2021_Plots/raw_plots/Lava.pdf -------------------------------------------------------------------------------- /SR-LLRL/IEEE_SMC_2021_Plots/raw_plots/Maze.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kchu/LifelongRL/HEAD/SR-LLRL/IEEE_SMC_2021_Plots/raw_plots/Maze.pdf -------------------------------------------------------------------------------- /SR-LLRL/IEEE_SMC_2021_Plots/figures/Result_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kchu/LifelongRL/HEAD/SR-LLRL/IEEE_SMC_2021_Plots/figures/Result_1.png -------------------------------------------------------------------------------- /SR-LLRL/IEEE_SMC_2021_Plots/figures/Result_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kchu/LifelongRL/HEAD/SR-LLRL/IEEE_SMC_2021_Plots/figures/Result_2.png -------------------------------------------------------------------------------- /SR-LLRL/IEEE_SMC_2021_Plots/raw_plots/FourRoom.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kchu/LifelongRL/HEAD/SR-LLRL/IEEE_SMC_2021_Plots/raw_plots/FourRoom.pdf -------------------------------------------------------------------------------- /SR-LLRL/IEEE_SMC_2021_Plots/figures/Environments.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kchu/LifelongRL/HEAD/SR-LLRL/IEEE_SMC_2021_Plots/figures/Environments.png -------------------------------------------------------------------------------- /SR-LLRL/IEEE_SMC_2021_Plots/raw_plots/Q-Lava-Task.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kchu/LifelongRL/HEAD/SR-LLRL/IEEE_SMC_2021_Plots/raw_plots/Q-Lava-Task.pdf -------------------------------------------------------------------------------- /SR-LLRL/IEEE_SMC_2021_Plots/raw_plots/Q-Maze-Task.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kchu/LifelongRL/HEAD/SR-LLRL/IEEE_SMC_2021_Plots/raw_plots/Q-Maze-Task.pdf -------------------------------------------------------------------------------- /SR-LLRL/IEEE_SMC_2021_Plots/raw_plots/Q-Lava-Episode.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kchu/LifelongRL/HEAD/SR-LLRL/IEEE_SMC_2021_Plots/raw_plots/Q-Lava-Episode.pdf 
-------------------------------------------------------------------------------- /SR-LLRL/IEEE_SMC_2021_Plots/raw_plots/Q-Maze-Episode.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kchu/LifelongRL/HEAD/SR-LLRL/IEEE_SMC_2021_Plots/raw_plots/Q-Maze-Episode.pdf -------------------------------------------------------------------------------- /SR-LLRL/IEEE_SMC_2021_Plots/raw_plots/Q-FourRoom-Task.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kchu/LifelongRL/HEAD/SR-LLRL/IEEE_SMC_2021_Plots/raw_plots/Q-FourRoom-Task.pdf -------------------------------------------------------------------------------- /SR-LLRL/IEEE_SMC_2021_Plots/raw_plots/DelayedQ-Lava-Task.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kchu/LifelongRL/HEAD/SR-LLRL/IEEE_SMC_2021_Plots/raw_plots/DelayedQ-Lava-Task.pdf -------------------------------------------------------------------------------- /SR-LLRL/IEEE_SMC_2021_Plots/raw_plots/DelayedQ-Maze-Task.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kchu/LifelongRL/HEAD/SR-LLRL/IEEE_SMC_2021_Plots/raw_plots/DelayedQ-Maze-Task.pdf -------------------------------------------------------------------------------- /SR-LLRL/IEEE_SMC_2021_Plots/raw_plots/Q-FourRoom-Episode.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kchu/LifelongRL/HEAD/SR-LLRL/IEEE_SMC_2021_Plots/raw_plots/Q-FourRoom-Episode.pdf -------------------------------------------------------------------------------- /SR-LLRL/IEEE_SMC_2021_Plots/raw_plots/DelayedQ-Lava-Episode.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kchu/LifelongRL/HEAD/SR-LLRL/IEEE_SMC_2021_Plots/raw_plots/DelayedQ-Lava-Episode.pdf -------------------------------------------------------------------------------- /SR-LLRL/IEEE_SMC_2021_Plots/raw_plots/DelayedQ-Maze-Episode.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kchu/LifelongRL/HEAD/SR-LLRL/IEEE_SMC_2021_Plots/raw_plots/DelayedQ-Maze-Episode.pdf -------------------------------------------------------------------------------- /SR-LLRL/IEEE_SMC_2021_Plots/raw_plots/DelayedQ-FourRoom-Task.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kchu/LifelongRL/HEAD/SR-LLRL/IEEE_SMC_2021_Plots/raw_plots/DelayedQ-FourRoom-Task.pdf -------------------------------------------------------------------------------- /SR-LLRL/IEEE_SMC_2021_Plots/raw_plots/DelayedQ-FourRoom-Episode.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kchu/LifelongRL/HEAD/SR-LLRL/IEEE_SMC_2021_Plots/raw_plots/DelayedQ-FourRoom-Episode.pdf -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/abstraction/action_abs/InListPredicateClass.py: -------------------------------------------------------------------------------- 1 | class InListPredicate(object): 2 | 3 | def __init__(self, ls): 4 | self.ls = ls 5 | 6 | def is_true(self, x): 7 | return x in self.ls 8 | -------------------------------------------------------------------------------- 
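A minimal usage sketch for the InListPredicate class shown above (hypothetical state values; assuming the package is importable as laid out in the tree):

# InListPredicate wraps a simple membership test: is_true(x) holds iff x is in the stored list.
from simple_rl.abstraction.action_abs.InListPredicateClass import InListPredicate

goal_states = ["s1", "s2"]              # any hashable objects work, e.g. simple_rl State instances
pred = InListPredicate(ls=goal_states)
print(pred.is_true("s1"))               # True
print(pred.is_true("s3"))               # False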
/SR-LLRL/simple_rl/abstraction/action_abs/PredicateClass.py: -------------------------------------------------------------------------------- 1 | class Predicate(object): 2 | 3 | def __init__(self, func): 4 | self.func = func 5 | 6 | def is_true(self, x): 7 | return self.func(x) 8 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/abstraction/action_abs/PolicyClass.py: -------------------------------------------------------------------------------- 1 | class Policy(object): 2 | 3 | def __init__(self, policy_lambda): 4 | self.policy_lambda = policy_lambda 5 | 6 | def get_action(self, state): 7 | return self.policy_lambda(state) 8 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/utils/octogrid.txt: -------------------------------------------------------------------------------- 1 | wwwwgwgwgwwww 2 | wwww-w-w-wwww 3 | wwww-w-w-wwww 4 | wwww-w-w-wwww 5 | g-----------g 6 | wwww-----wwww 7 | g-----a-----g 8 | wwww-----wwww 9 | g-----------g 10 | wwww-w-w-wwww 11 | wwww-w-w-wwww 12 | wwww-w-w-wwww 13 | wwwwgwgwgwwww -------------------------------------------------------------------------------- /SR-LLRL/.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.pythonPath": "D:\\ProgramData\\Anaconda3\\python.exe", 3 | "python.workspaceSymbols.exclusionPatterns": [ 4 | "**/site-packages/**", 5 | "**D:\\ProgramData\\Anaconda3\\Lib\\site-packages**", 6 | ], 7 | "git.ignoreLimitWarning": true 8 | } -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/mdp/__init__.py: -------------------------------------------------------------------------------- 1 | from simple_rl.mdp.markov_game.MarkovGameMDPClass import MarkovGameMDP 2 | from simple_rl.mdp.oomdp.OOMDPClass import OOMDP 3 | from simple_rl.mdp.MDPDistributionClass import MDPDistribution 4 | from simple_rl.mdp.MDPClass import MDP 5 | from simple_rl.mdp.StateClass import State 6 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/planning/__init__.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Implementations of standard planning algorithms: 3 | 4 | PlannerClass: Abstract class for a planner 5 | ValueIterationClass: Value Iteration. 6 | MCTSClass: Monte Carlo Tree Search. 7 | ''' 8 | 9 | # Grab classes. 10 | from simple_rl.planning.PlannerClass import Planner 11 | from simple_rl.planning.ValueIterationClass import ValueIteration 12 | from simple_rl.planning.MCTSClass import MCTS -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/abstraction/abstr_mdp/MDPHierarchy.py: -------------------------------------------------------------------------------- 1 | # simple_rl imports. 
2 | from simple_rl.mdp import MDP 3 | 4 | class MDPHierarchy(MDP): 5 | 6 | def __init__(self, mdp, sa_stack, aa_stack, sample_rate=10): 7 | ''' 8 | Args: 9 | mdp 10 | sa_stack 11 | aa_stack 12 | sample_rate 13 | ''' 14 | self.mdp = mdp 15 | self.sa_stack = sa_stack 16 | self.aa_stack = aa_stack 17 | self.sample_rate = sample_rate 18 | 19 | --------------------------------------------------------------------------------
/SR-LLRL/.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Python: Current File", 9 | "type": "python", 10 | "request": "launch", 11 | "program": "${file}", 12 | "console": "integratedTerminal" 13 | } 14 | ] 15 | } --------------------------------------------------------------------------------
/SR-LLRL/simple_rl/tasks/cleanup/cleanup_door.py: -------------------------------------------------------------------------------- 1 | class CleanUpDoor: 2 | def __init__(self, x, y): 3 | self.x = x 4 | self.y = y 5 | 6 | def __hash__(self): 7 | return hash(tuple([self.x, self.y])) 8 | 9 | def copy(self): 10 | return CleanUpDoor(self.x, self.y) 11 | 12 | def __eq__(self, other): 13 | return isinstance(other, CleanUpDoor) and self.x == other.x and self.y == other.y 14 | 15 | def __str__(self): 16 | return str((self.x, self.y)) --------------------------------------------------------------------------------
/SR-LLRL/simple_rl/abstraction/hierarchical_planning.py: -------------------------------------------------------------------------------- 1 | # simple_rl imports. 2 | from simple_rl.planning.PlannerClass import Planner 3 | 4 | class HierarchicalPlanner(Planner): 5 | 6 | def __init__(self, mdp_hierarchy, planner): 7 | self.mdp_hierarchy = mdp_hierarchy 8 | self.planner = planner 9 | 10 | def plan(self, low_level_start_state): 11 | ''' 12 | Args: 13 | low_level_start_state (simple_rl.State) 14 | 15 | Returns: 16 | (list) 17 | ''' 18 | # Planning over the hierarchy is not implemented in this file. 19 | raise NotImplementedError 20 | 21 | 22 | def make_mdp_hierarchy(mdp, state_abs): 23 | # Hierarchy construction is not implemented in this file. 24 | raise NotImplementedError --------------------------------------------------------------------------------
/SR-LLRL/simple_rl/tasks/maze_1d/Maze1DStateClass.py: -------------------------------------------------------------------------------- 1 | from simple_rl.mdp.StateClass import State 2 | 3 | class Maze1DState(State): 4 | ''' Class for 1D Maze POMDP States ''' 5 | 6 | def __init__(self, name): 7 | self.name = name 8 | is_terminal = name == 'goal' 9 | State.__init__(self, data=name, is_terminal=is_terminal) 10 | 11 | def __hash__(self): 12 | return hash(tuple(self.data)) 13 | 14 | def __str__(self): 15 | return '1DMazeState::{}'.format(self.data) 16 | 17 | def __repr__(self): 18 | return self.__str__() 19 | 20 | def __eq__(self, other): 21 | return isinstance(other, Maze1DState) and self.data == other.data 22 | --------------------------------------------------------------------------------
/SR-LLRL/simple_rl/tasks/gym/GymStateClass.py: -------------------------------------------------------------------------------- 1 | # Python imports 2 | import numpy as np 3 | 4 | # Local imports 5 | from simple_rl.mdp.StateClass import State 6 | 7 | ''' GymStateClass.py: Contains a State class for Gym.
''' 8 | 9 | class GymState(State): 10 | ''' Gym State class ''' 11 | 12 | def __init__(self, data=[], is_terminal=False): 13 | self.data = data 14 | State.__init__(self, data=data, is_terminal=is_terminal) 15 | 16 | def to_rgb(self, x_dim, y_dim): 17 | # 3 by x_length by y_length array with values 0 (0) --> 1 (255) 18 | board = np.zeros(shape=[3, x_dim, y_dim]) 19 | # print self.data, self.data.shape, x_dim, y_dim 20 | return self.data 21 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/abstraction/action_abs/PolicyFromDictClass.py: -------------------------------------------------------------------------------- 1 | # Python imports. 2 | from __future__ import print_function 3 | import random 4 | from collections import defaultdict 5 | 6 | # Other imports. 7 | from simple_rl.abstraction.action_abs.PolicyClass import Policy 8 | 9 | class PolicyFromDict(Policy): 10 | 11 | def __init__(self, policy_dict={}): 12 | self.policy_dict = policy_dict 13 | 14 | def get_action(self, state): 15 | if state not in self.policy_dict.keys(): 16 | print("(PolicyFromDict) Warning: unseen state (" + str(state) + "). Acting randomly.") 17 | return random.choice(list(set(self.policy_dict.values()))) 18 | else: 19 | return self.policy_dict[state] 20 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/tasks/grid_world/GridWorldStateClass.py: -------------------------------------------------------------------------------- 1 | ''' GridWorldStateClass.py: Contains the GridWorldState class. ''' 2 | 3 | # Other imports. 4 | from simple_rl.mdp.StateClass import State 5 | 6 | class GridWorldState(State): 7 | ''' Class for Grid World States ''' 8 | 9 | def __init__(self, x, y): 10 | State.__init__(self, data=[x, y]) 11 | self.x = round(x, 5) 12 | self.y = round(y, 5) 13 | 14 | def __hash__(self): 15 | return hash(tuple(self.data)) 16 | 17 | def __str__(self): 18 | return "s: (" + str(self.x) + "," + str(self.y) + ")" 19 | 20 | def __eq__(self, other): 21 | return isinstance(other, GridWorldState) and self.x == other.x and self.y == other.y 22 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/planning/PlannerClass.py: -------------------------------------------------------------------------------- 1 | class Planner(object): 2 | ''' Abstract class for a Planner. ''' 3 | 4 | def __init__(self, mdp, name="planner"): 5 | 6 | self.name = name 7 | 8 | # MDP components. 9 | self.mdp = mdp 10 | self.init_state = self.mdp.get_init_state() 11 | self.states = set([]) 12 | self.actions = mdp.get_actions() 13 | self.reward_func = mdp.get_reward_func() 14 | self.transition_func = mdp.get_transition_func() 15 | self.gamma = mdp.gamma 16 | self.has_planned = False 17 | 18 | def plan(self, state): 19 | pass 20 | 21 | def policy(self, state): 22 | pass 23 | 24 | def __str__(self): 25 | return self.name -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/abstraction/abstr_mdp/RewardFuncClass.py: -------------------------------------------------------------------------------- 1 | # Python imports. 
2 | import random 3 | from collections import defaultdict 4 | 5 | class RewardFunc(object): 6 | 7 | def __init__(self, reward_func_lambda, state_space, action_space): 8 | self.reward_dict = make_dict_from_lambda(reward_func_lambda, state_space, action_space) 9 | 10 | def reward_func(self, state, action): 11 | return self.reward_dict[state][action] 12 | 13 | def make_dict_from_lambda(reward_func_lambda, state_space, action_space, sample_rate=1): 14 | reward_dict = defaultdict(lambda:defaultdict(float)) 15 | for s in state_space: 16 | for a in action_space: 17 | for i in range(sample_rate): 18 | reward_dict[s][a] += reward_func_lambda(s, a) / sample_rate 19 | 20 | return reward_dict 21 | --------------------------------------------------------------------------------
/SR-LLRL/simple_rl/agents/RandomAgentClass.py: -------------------------------------------------------------------------------- 1 | ''' RandomAgentClass.py: Class for a randomly acting RL Agent ''' 2 | 3 | # Python imports. 4 | import random 5 | from collections import defaultdict 6 | 7 | # Other imports 8 | from simple_rl.agents.AgentClass import Agent 9 | 10 | class RandomAgent(Agent): 11 | ''' Class for a random decision maker. ''' 12 | 13 | def __init__(self, actions, name=""): 14 | name = "Random" if name == "" else name 15 | Agent.__init__(self, name=name, actions=actions) 16 | self.count_sa = defaultdict(lambda : defaultdict(lambda: 0)) 17 | self.count_s = defaultdict(lambda : 0) 18 | self.episode_count = defaultdict(lambda : defaultdict(lambda: defaultdict(lambda: 0))) 19 | self.episode_reward = defaultdict(lambda: 0) 20 | 21 | def act(self, state, reward): 22 | return random.choice(self.actions) 23 | --------------------------------------------------------------------------------
/SR-LLRL/simple_rl/tasks/chain/ChainStateClass.py: -------------------------------------------------------------------------------- 1 | ''' ChainStateClass.py: Contains the ChainStateClass class. ''' 2 | 3 | # Other imports. 4 | from simple_rl.mdp.StateClass import State 5 | 6 | class ChainState(State): 7 | ''' Class for Chain MDP States ''' 8 | 9 | def __init__(self, num): 10 | State.__init__(self, data=num) 11 | self.num = num 12 | 13 | def __hash__(self): 14 | return self.num 15 | 16 | def __add__(self, val): 17 | return ChainState(self.num + val) 18 | 19 | def __lt__(self, val): 20 | return self.num < val 21 | 22 | def __str__(self): 23 | return "s." + str(self.num) 24 | 25 | def __eq__(self, other): 26 | ''' 27 | Summary: 28 | Chain states are equal when their num is the same 29 | ''' 30 | return isinstance(other, ChainState) and self.num == other.num 31 | --------------------------------------------------------------------------------
/SR-LLRL/simple_rl/tasks/random/RandomStateClass.py: -------------------------------------------------------------------------------- 1 | ''' RandomStateClass.py: Contains the RandomStateClass class. ''' 2 | 3 | # Other imports 4 | from simple_rl.mdp.StateClass import State 5 | 6 | class RandomState(State): 7 | ''' Class for Random MDP States ''' 8 | 9 | def __init__(self, num): 10 | State.__init__(self, data=num) 11 | self.num = num 12 | 13 | def __hash__(self): 14 | return self.num 15 | 16 | def __add__(self, val): 17 | return RandomState(self.num + val) 18 | 19 | def __lt__(self, val): 20 | return self.num < val 21 | 22 | def __str__(self): 23 | return "s."
+ str(self.num) 24 | 25 | def __eq__(self, other): 26 | ''' 27 | Summary: 28 | Random states are equal when their num is the same 29 | ''' 30 | return isinstance(other, RandomState) and self.num == other.num 31 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/tasks/taxi/TaxiStateClass.py: -------------------------------------------------------------------------------- 1 | ''' TaxiStateClass.py: Contains the TaxiState class. ''' 2 | 3 | # Other imports 4 | from simple_rl.mdp.oomdp.OOMDPStateClass import OOMDPState 5 | 6 | class TaxiState(OOMDPState): 7 | ''' Class for Taxi World States ''' 8 | 9 | def __init__(self, objects): 10 | OOMDPState.__init__(self, objects=objects) 11 | 12 | def get_agent_x(self): 13 | return self.objects["agent"][0]["x"] 14 | 15 | def get_agent_y(self): 16 | return self.objects["agent"][0]["y"] 17 | 18 | def __hash__(self): 19 | 20 | state_hash = str(self.get_agent_x()) + str(self.get_agent_y()) + "00" 21 | 22 | for p in self.objects["passenger"]: 23 | state_hash += str(p["x"]) + str(p["y"]) + str(p["in_taxi"]) 24 | 25 | return int(state_hash) 26 | 27 | def __eq__(self, other_taxi_state): 28 | return hash(self) == hash(other_taxi_state) 29 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/agents/BeliefAgentClass.py: -------------------------------------------------------------------------------- 1 | # Other imports. 2 | from simple_rl.agents.AgentClass import Agent 3 | from simple_rl.pomdp.BeliefStateClass import BeliefState 4 | 5 | class BeliefAgent(Agent): 6 | def __init__(self, name, actions, gamma=0.99): 7 | ''' 8 | Args: 9 | name (str) 10 | actions (list) 11 | gamma (float 12 | ''' 13 | Agent.__init__(self, name, actions, gamma) 14 | 15 | def act(self, belief_state, reward): 16 | ''' 17 | 18 | Args: 19 | belief_state (BeliefState) 20 | reward (float) 21 | 22 | Returns: 23 | action (str) 24 | ''' 25 | pass 26 | 27 | def policy(self, belief_state): 28 | ''' 29 | Args: 30 | belief_state (BeliefState) 31 | 32 | Returns: 33 | action (str) 34 | ''' 35 | return self.act(belief_state, 0) 36 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/experiments/ExperimentParametersClass.py: -------------------------------------------------------------------------------- 1 | ''' 2 | ExperimentParametersClass.py: Contains the ExperimentParameters Class. 3 | 4 | Purpose: Bundles all relevant parameters into an object that can be written to a file. 5 | ''' 6 | 7 | # Python imports. 8 | from collections import defaultdict 9 | 10 | class ExperimentParameters(object): 11 | ''' 12 | Parameters object given to @ExperimentClass instances. 13 | Used for storing all relevant experiment info for reproducibility. 14 | ''' 15 | 16 | def __init__(self, params=defaultdict(lambda: None)): 17 | self.params = params 18 | 19 | def __str__(self): 20 | ''' 21 | Summary: 22 | Creates a str where each key-value (parameterName-value) 23 | appears on a line. 24 | ''' 25 | result = "" 26 | for item in ["\n\t"+ str(key) + " : " + str(value) for key, value in self.params.items()]: 27 | result += item 28 | return result 29 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/agents/FixedPolicyAgentClass.py: -------------------------------------------------------------------------------- 1 | ''' FixedPolicyAgentClass.py: Class for a basic RL Agent ''' 2 | 3 | # Python imports. 
4 | from simple_rl.agents.AgentClass import Agent 5 | 6 | class FixedPolicyAgent(Agent): 7 | ''' Agent Class with a fixed policy. ''' 8 | 9 | NAME = "fixed-policy" 10 | 11 | def __init__(self, policy, name=NAME): 12 | ''' 13 | Args: 14 | policy (func: S ---> A) 15 | ''' 16 | Agent.__init__(self, name=name, actions=[]) 17 | self.policy = policy 18 | 19 | def act(self, state, reward): 20 | ''' 21 | Args: 22 | state (State): see StateClass.py 23 | reward (float): the reward associated with arriving in state @state. 24 | 25 | Returns: 26 | (str): action. 27 | ''' 28 | return self.policy(state) 29 | 30 | def set_policy(self, new_policy): 31 | self.policy = new_policy 32 | 33 | def __str__(self): 34 | return str(self.name) 35 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/tasks/cleanup/cleanup_task.py: -------------------------------------------------------------------------------- 1 | class CleanUpTask: 2 | def __init__(self, block_color, goal_room_color, block_name=None, goal_room_name=None): 3 | ''' 4 | You can choose which attributes you would like to have represent the blocks and the rooms 5 | ''' 6 | self.goal_room_name = goal_room_name 7 | self.block_color = block_color 8 | self.goal_room_color = goal_room_color 9 | self.block_name = block_name 10 | 11 | def __str__(self): 12 | if self.goal_room_name is None and self.block_name is None: 13 | return self.block_color + " to the " + self.goal_room_color + " room" 14 | elif self.block_name is None: 15 | return self.block_color + " to the room named " + self.goal_room_name 16 | elif self.goal_room_name is None: 17 | return "The block named " + self.block_name + " to the " + self.goal_room_color + " room" 18 | else: 19 | return "The block named " + self.block_name + " to the room named " + self.goal_room_name 20 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/abstraction/__init__.py: -------------------------------------------------------------------------------- 1 | # Classes. 2 | from simple_rl.abstraction.AbstractionWrapperClass import AbstractionWrapper 3 | from simple_rl.abstraction.AbstractValueIterationClass import AbstractValueIteration 4 | from simple_rl.abstraction.state_abs.StateAbstractionClass import StateAbstraction 5 | from simple_rl.abstraction.state_abs.ProbStateAbstractionClass import ProbStateAbstraction 6 | from simple_rl.abstraction.action_abs.ActionAbstractionClass import ActionAbstraction 7 | from simple_rl.abstraction.action_abs.InListPredicateClass import InListPredicate 8 | from simple_rl.abstraction.action_abs.OptionClass import Option 9 | from simple_rl.abstraction.action_abs.PolicyClass import Policy 10 | from simple_rl.abstraction.action_abs.PolicyFromDictClass import PolicyFromDict 11 | from simple_rl.abstraction.action_abs.PredicateClass import Predicate 12 | 13 | # Scripts. 
14 | from simple_rl.abstraction.state_abs import sa_helpers, indicator_funcs 15 | from simple_rl.abstraction.action_abs import aa_helpers 16 | from simple_rl.abstraction.abstr_mdp import abstr_mdp_funcs -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/tasks/cleanup/cleanup_block.py: -------------------------------------------------------------------------------- 1 | class CleanUpBlock: 2 | 3 | def __init__(self, name, x=0, y=0, color=""): 4 | self.name = name 5 | self.x = x 6 | self.y = y 7 | self.color = color 8 | 9 | @staticmethod 10 | def class_name(): 11 | return "block" 12 | 13 | def name(self): 14 | return self.name 15 | 16 | def __eq__(self, other): 17 | return isinstance(other, CleanUpBlock) and self.x == other.x and self.y == other.y and self.name == other.name \ 18 | and self.color == other.color 19 | 20 | def __hash__(self): 21 | return hash(tuple([self.name, self.x, self.y, self.color])) 22 | 23 | def copy_with_name(self, new_name): 24 | return CleanUpBlock(new_name, x=self.x, y=self.y, color=self.color) 25 | 26 | def copy(self): 27 | return CleanUpBlock(name=self.name, x=self.x, y=self.y, color=self.color) 28 | 29 | def __str__(self): 30 | return "BLOCK. Name: " + self.name + ", (x,y): (" + str(self.x) + "," + str(self.y) + "), Color: " + self.color 31 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/tasks/cleanup/cleanup_room.py: -------------------------------------------------------------------------------- 1 | from simple_rl.tasks.cleanup.cleanup_state import CleanUpState 2 | 3 | class CleanUpRoom: 4 | def __init__(self, name, points_in_room=[(x + 1, y + 1) for x in range(24) for y in range(24)], color="blue"): 5 | self.name = name 6 | self.points_in_room = points_in_room 7 | self.color = color 8 | 9 | def contains(self, block): 10 | return (block.x, block.y) in self.points_in_room 11 | 12 | def copy(self): 13 | return CleanUpRoom(self.name, self.points_in_room[:], color=self.color) 14 | 15 | def __hash__(self): 16 | return hash(tuple([self.name, self.color, tuple(self.points_in_room)])) 17 | 18 | def __eq__(self, other): 19 | if not isinstance(other, CleanUpRoom): 20 | return False 21 | 22 | return self.name == other.name and self.color == other.color and \ 23 | CleanUpState.list_eq(self.points_in_room, other.points_in_room) 24 | 25 | def __str__(self): 26 | return "color: " + self.color + ", points: " + " ".join( 27 | str(tup) for tup in self.points_in_room) 28 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/agents/__init__.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Implementations of standard RL agents: 3 | 4 | AgentClass: Contains the basic skeleton of an RL Agent. 5 | QLearningAgentClass: Q-Learning. 6 | LinearQAgentClass: Q-Learning with a Linear Approximator. 7 | RandomAgentClass: Random actor. 8 | RMaxAgentClass: R-Max. 9 | LinUCBAgentClass: Contextual Bandit Algorithm. 10 | ''' 11 | 12 | # Grab agent classes. 
13 | from simple_rl.agents.AgentClass import Agent 14 | from simple_rl.agents.FixedPolicyAgentClass import FixedPolicyAgent 15 | from simple_rl.agents.QLearningAgentClass import QLearningAgent 16 | from simple_rl.agents.DoubleQAgentClass import DoubleQAgent 17 | from simple_rl.agents.DelayedQAgentClass import DelayedQAgent 18 | from simple_rl.agents.RandomAgentClass import RandomAgent 19 | from simple_rl.agents.RMaxAgentClass import RMaxAgent 20 | from simple_rl.agents.func_approx.LinearQAgentClass import LinearQAgent 21 | try: 22 | from simple_rl.agents.func_approx.DQNAgentClass import DQNAgent 23 | except ImportError: 24 | print("Warning: Tensorflow not installed.") 25 | pass 26 | 27 | from simple_rl.agents.bandits.LinUCBAgentClass import LinUCBAgent -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/mdp/oomdp/OOMDPObjectClass.py: -------------------------------------------------------------------------------- 1 | ''' OOMDPObjectClass.py: Contains the OOMDP Object Class. ''' 2 | 3 | class OOMDPObject(object): 4 | ''' Abstract OOMDP Object class ''' 5 | 6 | def __init__(self, attributes, name="OOMDP-Object"): 7 | ''' 8 | Args: 9 | attributes (dict): {key=attr_name, val=int} 10 | ''' 11 | self.name = name 12 | self.attributes = attributes 13 | 14 | def set_attribute(self, attr, val): 15 | self.attributes[attr] = val 16 | 17 | def get_attribute(self, attr): 18 | return self.attributes[attr] 19 | 20 | def get_obj_state(self): 21 | return self.attributes.values() 22 | 23 | def get_attributes(self): 24 | return self.attributes 25 | 26 | def __getitem__(self, key): 27 | return self.attributes[key] 28 | 29 | def __setitem__(self, key, item): 30 | self.attributes[key] = item 31 | 32 | def __str__(self): 33 | result = "o:" + self.name + " [" 34 | for attr in self.attributes: 35 | result += "a:" + str(attr) + " = " + str(self.attributes[attr]) + ", " 36 | return result + "]" 37 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/mdp/markov_game/MarkovGameMDPClass.py: -------------------------------------------------------------------------------- 1 | ''' MarkovGameMDP.py: Contains implementation for simple Markov Games. ''' 2 | 3 | # Python imports. 4 | from __future__ import print_function 5 | 6 | # Other imports. 7 | from simple_rl.mdp.MDPClass import MDP 8 | 9 | class MarkovGameMDP(MDP): 10 | ''' Abstract class for a Markov Decision Process. ''' 11 | def __init__(self, actions, transition_func, reward_func, init_state, gamma=0.99, num_agents=2): 12 | MDP.__init__(self, actions, transition_func, reward_func, init_state=init_state, gamma=gamma) 13 | self.num_agents = num_agents 14 | 15 | def execute_agent_action(self, action_dict): 16 | ''' 17 | Args: 18 | actions (dict): an action for each agent. 19 | ''' 20 | if len(action_dict.keys()) != self.num_agents: 21 | raise ValueError("Error: only", len(action_dict.keys()), "action(s) was/were provided, but there are", self.num_agents, "agents.") 22 | 23 | reward_dict = self.reward_func(self.cur_state, action_dict) 24 | next_state = self.transition_func(self.cur_state, action_dict) 25 | self.cur_state = next_state 26 | 27 | return reward_dict, next_state 28 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/abstraction/abstr_mdp/TransitionFuncClass.py: -------------------------------------------------------------------------------- 1 | # Python imports. 
2 | import random 3 | import numpy as np 4 | from collections import defaultdict 5 | 6 | class TransitionFunc(object): 7 | 8 | def __init__(self, transition_func_lambda, state_space, action_space, sample_rate=1): 9 | self.transition_dict = make_dict_from_lambda(transition_func_lambda, state_space, action_space, sample_rate) 10 | 11 | def transition_func(self, state, action): 12 | next_states, probs = list(self.transition_dict[state][action].keys()), list(self.transition_dict[state][action].values()) 13 | if len(next_states) == 0: 14 | return state 15 | return next_states[np.random.multinomial(1, probs).tolist().index(1)] 16 | 17 | def make_dict_from_lambda(transition_func_lambda, state_space, action_space, sample_rate=1): 18 | transition_dict = defaultdict(lambda:defaultdict(lambda:defaultdict(int))) 19 | 20 | for s in list(state_space)[:]: 21 | for a in action_space: 22 | for i in range(sample_rate): 23 | s_prime = transition_func_lambda(s, a) 24 | 25 | transition_dict[s][a][s_prime] += (1.0 / sample_rate) 26 | 27 | return transition_dict 28 | --------------------------------------------------------------------------------
/SR-LLRL/simple_rl/tasks/trench/TrenchOOMDPState.py: -------------------------------------------------------------------------------- 1 | from simple_rl.mdp.oomdp.OOMDPStateClass import OOMDPState 2 | 3 | class TrenchOOMDPState(OOMDPState): 4 | ''' Class for Trench World States ''' 5 | 6 | def __init__(self, objects): 7 | OOMDPState.__init__(self, objects=objects) 8 | 9 | def get_agent_x(self): 10 | return self.objects["agent"][0]["x"] 11 | 12 | def get_agent_y(self): 13 | return self.objects["agent"][0]["y"] 14 | 15 | def __hash__(self): 16 | state_hash = str(self.get_agent_x()) + str(self.get_agent_y()) + str(self.objects["agent"][0]["dx"] + 1)\ 17 | + str(self.objects["agent"][0]["dy"] + 1) + str(self.objects["agent"][0]["dest_x"])\ 18 | + str(self.objects["agent"][0]["dest_x"]) + str(self.objects["agent"][0]["dest_y"]) + \ 19 | str(self.objects["agent"][0]["has_block"]) + "00" 20 | 21 | for b in self.objects["block"]: 22 | state_hash += str(b["x"]) + str(b["y"]) 23 | 24 | state_hash += "00" 25 | 26 | for l in self.objects["lava"]: 27 | state_hash += str(l["x"]) + str(l["y"]) 28 | 29 | return int(state_hash) 30 | 31 | def __eq__(self, other_trench_state): 32 | return hash(self) == hash(other_trench_state) 33 | --------------------------------------------------------------------------------
/SR-LLRL/simple_rl/mdp/oomdp/OOMDPClass.py: -------------------------------------------------------------------------------- 1 | ''' 2 | OOMDPClass.py: Implementation for Object-Oriented MDPs. 3 | 4 | From: 5 | Diuk, Carlos, Andre Cohen, and Michael L. Littman. 6 | "An object-oriented representation for efficient reinforcement learning." 7 | Proceedings of the 25th international conference on Machine learning. ACM, 2008. 8 | ''' 9 | 10 | # Other imports. 11 | from simple_rl.mdp.MDPClass import MDP 12 | from simple_rl.mdp.oomdp.OOMDPObjectClass import OOMDPObject 13 | 14 | class OOMDP(MDP): 15 | ''' Abstract class for an Object Oriented Markov Decision Process. ''' 16 | 17 | def __init__(self, actions, transition_func, reward_func, init_state, gamma=0.99): 18 | MDP.__init__(self, actions, transition_func, reward_func, init_state=init_state, gamma=gamma) 19 | 20 | def _make_oomdp_objs_from_list_of_dict(self, list_of_attr_dicts, name): 21 | ''' 22 | Args: 23 | list_of_attr_dicts (list of dict) 24 | name (str): Class of the object.
25 | 26 | Returns: 27 | (list of OOMDPObject) 28 | ''' 29 | objects = [] 30 | 31 | for attr_dict in list_of_attr_dicts: 32 | next_obj = OOMDPObject(attributes=attr_dict, name=name) 33 | objects.append(next_obj) 34 | 35 | return objects -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/pomdp/BeliefStateClass.py: -------------------------------------------------------------------------------- 1 | # Python imports. 2 | from collections import defaultdict 3 | 4 | # Other imports. 5 | from simple_rl.mdp.StateClass import State 6 | 7 | class BeliefState(State): 8 | ''' 9 | Abstract class defining a belief state, i.e a probability distribution over states. 10 | ''' 11 | def __init__(self, belief_distribution): 12 | ''' 13 | Args: 14 | belief_distribution (defaultdict) 15 | ''' 16 | self.distribution = belief_distribution 17 | State.__init__(self, data=belief_distribution.values()) 18 | 19 | def __repr__(self): 20 | return self.__str__() 21 | 22 | def __str__(self): 23 | return 'BeliefState::' + str(self.distribution) 24 | 25 | def belief(self, state): 26 | ''' 27 | Args: 28 | state (State) 29 | Returns: 30 | belief[state] (float): probability that agent is in state 31 | ''' 32 | return self.distribution[state] 33 | 34 | def sample(self, sampling_method='max'): 35 | ''' 36 | Returns: 37 | sampled_state (State) 38 | ''' 39 | if sampling_method == 'max': 40 | return max(self.distribution, key=self.distribution.get) 41 | raise NotImplementedError('Sampling method {} not implemented yet'.format(sampling_method)) 42 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/__init__.py: -------------------------------------------------------------------------------- 1 | ''' 2 | simple_rl 3 | abstraction/ 4 | action_abs/ 5 | state_abs/ 6 | ... 7 | agents/ 8 | AgentClass.py 9 | QLearningAgentClass.py 10 | RandomAgentClass.py 11 | RMaxAgentClass.py 12 | ... 13 | experiments/ 14 | ExperimentClass.py 15 | ExperimentParameters.py 16 | mdp/ 17 | MDPClass.py 18 | StateClass.py 19 | planning/ 20 | BeliefSparseSamplingClass.py 21 | MCTSClass.py 22 | PlannerClass.py 23 | ValueIterationClass.py 24 | pomdp/ 25 | BeliefMDPClass.py 26 | BeliefStateClass.py 27 | BeliefUpdaterClass.py 28 | POMDPClass.py 29 | tasks/ 30 | chain/ 31 | ChainMDPClass.py 32 | ChainStateClass.py 33 | grid_world/ 34 | GridWorldMPDClass.py 35 | GridWorldStateClass.py 36 | ... 37 | utils/ 38 | chart_utils.py 39 | make_mdp.py 40 | run_experiments.py 41 | 42 | Author and Maintainer: David Abel (cs.brown.edu/~dabel/) 43 | Last Updated: April 23rd, 2018 44 | Contact: dabel@cs.brown.edu 45 | License: Apache 46 | ''' 47 | # Fix xrange to cooperate with python 2 and 3. 48 | try: 49 | xrange 50 | except NameError: 51 | xrange = range 52 | 53 | # Fix input to cooperate with python 2 and 3. 54 | try: 55 | input = raw_input 56 | except NameError: 57 | pass 58 | 59 | # Imports. 60 | import simple_rl.abstraction, simple_rl.agents, simple_rl.experiments, simple_rl.mdp, simple_rl.planning, simple_rl.tasks, simple_rl.utils 61 | import simple_rl.run_experiments 62 | 63 | from simple_rl._version import __version__ -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/abstraction/state_abs/ProbStateAbstractionClass.py: -------------------------------------------------------------------------------- 1 | # Python imports. 2 | import numpy as np 3 | 4 | # Other imports. 
5 | from simple_rl.mdp.StateClass import State 6 | from simple_rl.abstraction.state_abs.StateAbstractionClass import StateAbstraction 7 | 8 | class ProbStateAbstraction(StateAbstraction): 9 | 10 | def __init__(self, abstr_dist): 11 | ''' 12 | Args: 13 | abstr_dist (dict): Represents Pr(s_phi | phi) 14 | Key: state 15 | Val: dict 16 | Key: s_phi (simple_rl.State) 17 | Val: probability (float) 18 | ''' 19 | self.abstr_dist = abstr_dist 20 | 21 | def phi(self, state): 22 | ''' 23 | Args: 24 | state (State) 25 | 26 | Returns: 27 | state (State) 28 | ''' 29 | 30 | sampled_s_phi_index = np.random.multinomial(1, list(self.abstr_dist[state].values())).tolist().index(1) 31 | abstr_state = list(self.abstr_dist[state].keys())[sampled_s_phi_index] 32 | 33 | return abstr_state 34 | 35 | def convert_prob_sa_to_sa(prob_sa): 36 | ''' 37 | Args: 38 | prob_sa (simple_rl.state_abs.ProbStateAbstraction) 39 | 40 | Returns: 41 | (simple_rl.state_abs.StateAbstraction) 42 | ''' 43 | new_phi = {} 44 | 45 | for s_g in prob_sa.abstr_dist.keys(): 46 | new_phi[s_g] = max(prob_sa.abstr_dist[s_g], key=prob_sa.abstr_dist[s_g].get) 47 | 48 | return StateAbstraction(new_phi) 49 | --------------------------------------------------------------------------------
/SR-LLRL/simple_rl/agents/AgentClass.py: -------------------------------------------------------------------------------- 1 | ''' AgentClass.py: Class for a basic RL Agent ''' 2 | 3 | # Python imports. 4 | from collections import defaultdict 5 | 6 | class Agent(object): 7 | ''' Abstract Agent class. ''' 8 | 9 | def __init__(self, name, actions, gamma=0.99): 10 | self.name = name 11 | self.actions = list(actions) # Just in case we're given a numpy array (like from Atari). 12 | self.gamma = gamma 13 | self.episode_number = 0 14 | self.prev_state = None 15 | self.prev_action = None 16 | 17 | def act(self, state, reward): 18 | ''' 19 | Args: 20 | state (State): see StateClass.py 21 | reward (float): the reward associated with arriving in state @state. 22 | 23 | Returns: 24 | (str): action. 25 | ''' 26 | pass 27 | 28 | def policy(self, state): 29 | return self.act(state, 0) 30 | 31 | def reset(self): 32 | ''' 33 | Summary: 34 | Resets the agent back to its tabula rasa config. 35 | ''' 36 | self.prev_state = None 37 | self.prev_action = None 38 | self.step_number = 0 39 | 40 | def end_of_episode(self): 41 | ''' 42 | Summary: 43 | Resets the agent's prior pointers. 44 | ''' 45 | self.prev_state = None 46 | self.prev_action = None 47 | self.episode_number += 1 48 | 49 | def set_name(self, name): 50 | self.name = name 51 | 52 | def get_name(self): 53 | return self.name 54 | 55 | def __str__(self): 56 | return str(self.name) 57 | --------------------------------------------------------------------------------
/SR-LLRL/simple_rl/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | # Python imports. 2 | from __future__ import print_function 3 | 4 | # Grab classes.
5 | from simple_rl.tasks.bandit.BanditMDPClass import BanditMDP 6 | from simple_rl.tasks.chain.ChainMDPClass import ChainMDP 7 | from simple_rl.tasks.chain.ChainStateClass import ChainState 8 | from simple_rl.tasks.combo_lock.ComboLockMDPClass import ComboLockMDP 9 | from simple_rl.tasks.four_room.FourRoomMDPClass import FourRoomMDP 10 | from simple_rl.tasks.gather.GatherMDPClass import GatherMDP 11 | from simple_rl.tasks.gather.GatherStateClass import GatherState 12 | from simple_rl.tasks.grid_game.GridGameMDPClass import GridGameMDP 13 | from simple_rl.tasks.grid_world.GridWorldMDPClass import GridWorldMDP 14 | from simple_rl.tasks.grid_world.GridWorldStateClass import GridWorldState 15 | from simple_rl.tasks.hanoi.HanoiMDPClass import HanoiMDP 16 | from simple_rl.tasks.navigation.NavigationMDP import NavigationMDP 17 | from simple_rl.tasks.prisoners.PrisonersDilemmaMDPClass import PrisonersDilemmaMDP 18 | from simple_rl.tasks.puddle.PuddleMDPClass import PuddleMDP 19 | from simple_rl.tasks.random.RandomMDPClass import RandomMDP 20 | from simple_rl.tasks.random.RandomStateClass import RandomState 21 | from simple_rl.tasks.taxi.TaxiOOMDPClass import TaxiOOMDP 22 | from simple_rl.tasks.taxi.TaxiStateClass import TaxiState 23 | from simple_rl.tasks.trench.TrenchOOMDPClass import TrenchOOMDP 24 | from simple_rl.tasks.rock_paper_scissors.RockPaperScissorsMDPClass import RockPaperScissorsMDP 25 | try: 26 | from simple_rl.tasks.gym.GymMDPClass import GymMDP 27 | except ImportError: 28 | print("Warning: OpenAI gym not installed.") 29 | pass 30 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/mdp/StateClass.py: -------------------------------------------------------------------------------- 1 | # Python imports 2 | import numpy as np 3 | 4 | ''' StateClass.py: Contains the State Class. ''' 5 | 6 | class State(object): 7 | ''' Abstract State class ''' 8 | 9 | def __init__(self, data=[], is_terminal=False): 10 | self.data = data 11 | self._is_terminal = is_terminal 12 | 13 | def features(self): 14 | ''' 15 | Summary 16 | Used by function approximators to represent the state. 17 | Override this method in State subclasses to have functiona 18 | approximators use a different set of features. 19 | Returns: 20 | (iterable) 21 | ''' 22 | return np.array(self.data).flatten() 23 | 24 | def get_data(self): 25 | return self.data 26 | 27 | def get_num_feats(self): 28 | return len(self.features()) 29 | 30 | def is_terminal(self): 31 | return self._is_terminal 32 | 33 | def set_terminal(self, is_term=True): 34 | self._is_terminal = is_term 35 | 36 | def __hash__(self): 37 | if type(self.data).__module__ == np.__name__: 38 | # Numpy arrays 39 | return hash(str(self.data)) 40 | elif self.data.__hash__ is None: 41 | return hash(tuple(self.data)) 42 | else: 43 | return hash(self.data) 44 | 45 | def __str__(self): 46 | return "s." + str(self.data) 47 | 48 | def __eq__(self, other): 49 | return self.data == other.data 50 | 51 | def __getitem__(self, index): 52 | return self.data[index] 53 | 54 | def __len__(self): 55 | return len(self.data) 56 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/tasks/gym/GymMDPClass.py: -------------------------------------------------------------------------------- 1 | ''' 2 | GymMDPClass.py: Contains implementation for MDPs of the Gym Environments. 3 | ''' 4 | 5 | # Python imports. 6 | import random 7 | import sys 8 | import os 9 | import random 10 | 11 | # Other imports. 
12 | import gym 13 | from simple_rl.mdp.MDPClass import MDP 14 | from simple_rl.tasks.gym.GymStateClass import GymState 15 | 16 | class GymMDP(MDP): 17 | ''' Class for Gym MDPs ''' 18 | 19 | def __init__(self, env_name='CartPole-v0', render=False): 20 | ''' 21 | Args: 22 | env_name (str) 23 | ''' 24 | self.env_name = env_name 25 | self.env = gym.make(env_name) 26 | self.render = render 27 | MDP.__init__(self, range(self.env.action_space.n), self._transition_func, self._reward_func, init_state=GymState(self.env.reset())) 28 | 29 | def _reward_func(self, state, action): 30 | ''' 31 | Args: 32 | state (AtariState) 33 | action (str) 34 | 35 | Returns 36 | (float) 37 | ''' 38 | obs, reward, is_terminal, info = self.env.step(action) 39 | 40 | if self.render: 41 | self.env.render() 42 | 43 | self.next_state = GymState(obs, is_terminal=is_terminal) 44 | 45 | return reward 46 | 47 | def _transition_func(self, state, action): 48 | ''' 49 | Args: 50 | state (AtariState) 51 | action (str) 52 | 53 | Returns 54 | (State) 55 | ''' 56 | return self.next_state 57 | 58 | def reset(self): 59 | self.env.reset() 60 | 61 | def __str__(self): 62 | return "gym-" + str(self.env_name) 63 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/mdp/oomdp/OOMDPStateClass.py: -------------------------------------------------------------------------------- 1 | ''' OOMDPStateClass.py: Contains the OOMDP State Class. ''' 2 | 3 | # Python imports. 4 | from __future__ import print_function 5 | 6 | # Other imports. 7 | from simple_rl.mdp.StateClass import State 8 | 9 | class OOMDPState(State): 10 | ''' OOMDP State class ''' 11 | 12 | def __init__(self, objects): 13 | ''' 14 | Args: 15 | objects (dict of OOMDPObject instances): {key=object class (str):val = object instances} 16 | ''' 17 | self.objects = objects 18 | self.update() 19 | 20 | State.__init__(self, data=self.data) 21 | 22 | def get_objects(self): 23 | return self.objects 24 | 25 | def get_objects_of_class(self, obj_class): 26 | try: 27 | return self.objects[obj_class] 28 | except KeyError: 29 | raise ValueError("Error: given object class (" + str(obj_class) + ") not found in state.\n\t Known classes are: ", self.objects.keys()) 30 | 31 | def get_first_obj_of_class(self, obj_class): 32 | return self.get_objects_of_class(obj_class)[0] 33 | 34 | def update(self): 35 | ''' 36 | Summary: 37 | Turn object attributes into a feature list. 38 | ''' 39 | state_vec = [] 40 | for obj_class in self.objects.keys(): 41 | for obj in self.objects[obj_class]: 42 | state_vec += obj.get_obj_state() 43 | 44 | self.data = tuple(state_vec) 45 | 46 | def __str__(self): 47 | result = "" 48 | for obj_class in self.objects.keys(): 49 | for obj in self.objects[obj_class]: 50 | result += "\t" + str(obj) 51 | result += "\n" 52 | return result 53 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/tasks/taxi/taxi_helpers.py: -------------------------------------------------------------------------------- 1 | ''' Helper functions for executing actions in the Taxi Problem ''' 2 | 3 | # Other imports. 4 | from simple_rl.mdp.oomdp.OOMDPObjectClass import OOMDPObject 5 | 6 | def _is_wall_in_the_way(state, dx=0, dy=0): 7 | ''' 8 | Args: 9 | state (TaxiState) 10 | dx (int) [optional] 11 | dy (int) [optional] 12 | 13 | Returns: 14 | (bool): true iff the new loc of the agent is occupied by a wall. 
15 | ''' 16 | for wall in state.objects["wall"]: 17 | if wall["x"] == state.objects["agent"][0]["x"] + dx and \ 18 | wall["y"] == state.objects["agent"][0]["y"] + dy: 19 | return True 20 | return False 21 | 22 | def _move_pass_in_taxi(state, dx=0, dy=0): 23 | ''' 24 | Args: 25 | state (TaxiState) 26 | x (int) [optional] 27 | y (int) [optional] 28 | 29 | Returns: 30 | (list of dict): List of new passenger attributes. 31 | 32 | ''' 33 | passenger_attr_dict_ls = state.get_objects_of_class("passenger") 34 | for i, passenger in enumerate(passenger_attr_dict_ls): 35 | if passenger["in_taxi"] == 1: 36 | passenger_attr_dict_ls[i]["x"] += dx 37 | passenger_attr_dict_ls[i]["y"] += dy 38 | 39 | def is_taxi_terminal_state(state): 40 | ''' 41 | Args: 42 | state (OOMDPState) 43 | 44 | Returns: 45 | (bool): True iff all passengers at at their destinations, not in the taxi. 46 | ''' 47 | for p in state.get_objects_of_class("passenger"): 48 | if p.get_attribute("in_taxi") == 1 or p.get_attribute("x") != p.get_attribute("dest_x") or \ 49 | p.get_attribute("y") != p.get_attribute("dest_y"): 50 | return False 51 | return True 52 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/tasks/grid_game/GridGameStateClass.py: -------------------------------------------------------------------------------- 1 | ''' GridGameStateClass.py: Contains the GridGameState class. ''' 2 | 3 | # Other imports. 4 | from simple_rl.mdp.StateClass import State 5 | 6 | class GridGameState(State): 7 | ''' Class for two player Grid Game States ''' 8 | 9 | def __init__(self, a_x, a_y, b_x, b_y): 10 | State.__init__(self, data=[a_x, a_y, b_x, b_y]) 11 | self.a_x = a_x 12 | self.a_y = a_y 13 | self.b_x = b_x 14 | self.b_y = b_y 15 | 16 | def __hash__(self): 17 | # The X coordinate takes the first three digits. 18 | if len(str(self.a_x)) < 3: 19 | a_x_str = str(self.a_x) 20 | while len(a_x_str) < 3: 21 | a_x_str = "0" + a_x_str 22 | 23 | # The Y coordinate takes the next three digits. 24 | if len(str(self.a_y)) < 3: 25 | a_y_str = str(self.a_y) 26 | while len(a_y_str) < 3: 27 | a_y_str = "0" + a_y_str 28 | 29 | # The X coordinate takes the first three digits. 30 | if len(str(self.b_x)) < 3: 31 | b_x_str = str(self.b_x) 32 | while len(b_x_str) < 3: 33 | b_x_str = "0" + b_x_str 34 | 35 | # The Y coordinate takes the next three digits. 36 | if len(str(self.b_y)) < 3: 37 | b_y_str = str(self.b_y) 38 | while len(b_y_str) < 3: 39 | b_y_str = "0" + b_y_str 40 | 41 | # Concatenate and return. 42 | return int(a_x_str + a_y_str + "0" + b_x_str + b_y_str) 43 | 44 | def __str__(self): 45 | return "s: (" + str(self.a_x) + "," + str(self.a_y) + ")_a (" + str(self.b_x) + "," + str(self.b_y) + ")_b" 46 | 47 | def __eq__(self, other): 48 | return isinstance(other, GridGameState) and self.a_x == other.a_x and self.a_y == other.a_y and \ 49 | self.b_x == other.b_x and self.b_y == other.b_y 50 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/tasks/chain/ChainMDPClass.py: -------------------------------------------------------------------------------- 1 | ''' ChainMDPClass.py: Contains the ChainMDPClass class. ''' 2 | 3 | # Python imports. 4 | from __future__ import print_function 5 | 6 | # Other imports. 
7 | from simple_rl.mdp.MDPClass import MDP 8 | from simple_rl.tasks.chain.ChainStateClass import ChainState 9 | 10 | class ChainMDP(MDP): 11 | ''' Implementation for a standard Chain MDP ''' 12 | 13 | ACTIONS = ["forward", "reset"] 14 | 15 | def __init__(self, num_states=5, reset_val=0.01, gamma=0.99): 16 | ''' 17 | Args: 18 | num_states (int) [optional]: Number of states in the chain. 19 | ''' 20 | MDP.__init__(self, ChainMDP.ACTIONS, self._transition_func, self._reward_func, init_state=ChainState(1), gamma=gamma) 21 | self.num_states = num_states 22 | self.reset_val = reset_val 23 | 24 | def _reward_func(self, state, action): 25 | ''' 26 | Args: 27 | state (State) 28 | action (str) 29 | statePrime 30 | 31 | Returns 32 | (float) 33 | ''' 34 | if action == "forward" and state.num == self.num_states: 35 | return 1 36 | elif action == "reset": 37 | return self.reset_val 38 | else: 39 | return 0 40 | 41 | def _transition_func(self, state, action): 42 | ''' 43 | Args: 44 | state (State) 45 | action (str) 46 | 47 | Returns 48 | (State) 49 | ''' 50 | if action == "forward": 51 | if state < self.num_states: 52 | return state + 1 53 | else: 54 | return state 55 | elif action == "reset": 56 | return ChainState(1) 57 | else: 58 | raise ValueError("(simple_rl Error): Unrecognized action! (" + action + ")") 59 | 60 | def __str__(self): 61 | return "chain-" + str(self.num_states) 62 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/tasks/prisoners/PrisonersDilemmaMDPClass.py: -------------------------------------------------------------------------------- 1 | ''' PrisonersDilemmaMDPClass.py: Contains an implementation of PrisonersDilemma. ''' 2 | 3 | # Python imports. 4 | import random 5 | 6 | # Other imports. 7 | from simple_rl.mdp.markov_game.MarkovGameMDPClass import MarkovGameMDP 8 | from simple_rl.mdp.StateClass import State 9 | 10 | class PrisonersDilemmaMDP(MarkovGameMDP): 11 | ''' Class for a Grid World MDP ''' 12 | 13 | # Static constants. 
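    # Payoff matrix implemented in _reward_func below; entries are (reward_a, reward_b):
    #                  b: cooperate   b: defect
    #   a: cooperate      (2, 2)        (0, 3)
    #   a: defect         (3, 0)        (1, 1)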
14 | ACTIONS = ["defect", "cooperate"] 15 | 16 | def __init__(self): 17 | MarkovGameMDP.__init__(self, PrisonersDilemmaMDP.ACTIONS, self._transition_func, self._reward_func, init_state=State()) 18 | 19 | def _reward_func(self, state, action_dict): 20 | ''' 21 | Args: 22 | state (State) 23 | action (dict of actions) 24 | 25 | Returns 26 | (float) 27 | ''' 28 | agent_a, agent_b = action_dict.keys()[0], action_dict.keys()[1] 29 | action_a, action_b = action_dict[agent_a], action_dict[agent_b] 30 | 31 | reward_dict = {} 32 | 33 | if action_a == action_b == "cooperate": 34 | reward_dict[agent_a], reward_dict[agent_b] = 2, 2 35 | elif action_a == action_b == "defect": 36 | reward_dict[agent_a], reward_dict[agent_b] = 1, 1 37 | elif action_a == "cooperate" and action_b == "defect": 38 | reward_dict[agent_a] = 0 39 | reward_dict[agent_b] = 3 40 | elif action_a == "defect" and action_b == "cooperate": 41 | reward_dict[agent_a] = 3 42 | reward_dict[agent_b] = 0 43 | 44 | return reward_dict 45 | 46 | 47 | 48 | def _transition_func(self, state, action): 49 | ''' 50 | Args: 51 | state (State) 52 | action_dict (str) 53 | 54 | Returns 55 | (State) 56 | ''' 57 | return state 58 | 59 | def __str__(self): 60 | return "prisoners_dilemma" 61 | 62 | 63 | def main(): 64 | grid_world = GridWorldMDP(5, 10, (1, 1), (6, 7)) 65 | 66 | if __name__ == "__main__": 67 | main() 68 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/tasks/rock_paper_scissors/RockPaperScissorsMDPClass.py: -------------------------------------------------------------------------------- 1 | ''' RockPaperScissorsMDP.py: Contains an implementation of a two player Rock Paper Scissors game. ''' 2 | 3 | # Python imports. 4 | import random 5 | 6 | # Other imports 7 | from simple_rl.mdp.markov_game.MarkovGameMDPClass import MarkovGameMDP 8 | from simple_rl.mdp.StateClass import State 9 | 10 | class RockPaperScissorsMDP(MarkovGameMDP): 11 | ''' Class for a Rock Paper Scissors Game ''' 12 | 13 | # Static constants. 14 | ACTIONS = ["rock", "paper", "scissors"] 15 | 16 | def __init__(self): 17 | MarkovGameMDP.__init__(self, RockPaperScissorsMDP.ACTIONS, self._transition_func, self._reward_func, init_state=State()) 18 | 19 | def _reward_func(self, state, action_dict): 20 | ''' 21 | Args: 22 | state (State) 23 | action (dict of actions) 24 | 25 | Returns 26 | (float) 27 | ''' 28 | agent_a, agent_b = action_dict.keys()[0], action_dict.keys()[1] 29 | action_a, action_b = action_dict[agent_a], action_dict[agent_b] 30 | 31 | reward_dict = {} 32 | 33 | # Win conditions. 
34 | a_win = (action_a == "rock" and action_b == "scissors") or \ 35 | (action_a == "paper" and action_b == "rock") or \ 36 | (action_a == "scissors" and action_b == "paper") 37 | 38 | if action_a == action_b: 39 | reward_dict[agent_a], reward_dict[agent_b] = 0, 0 40 | elif a_win: 41 | reward_dict[agent_a], reward_dict[agent_b] = 1, -1 42 | else: 43 | reward_dict[agent_a], reward_dict[agent_b] = -1, 1 44 | 45 | return reward_dict 46 | 47 | def _transition_func(self, state, action): 48 | ''' 49 | Args: 50 | state (State) 51 | action_dict (str) 52 | 53 | Returns 54 | (State) 55 | ''' 56 | return state 57 | 58 | def __str__(self): 59 | return "rock_paper_scissors" 60 | 61 | 62 | def main(): 63 | grid_world = RockPaperScissorsMDP() 64 | 65 | if __name__ == "__main__": 66 | main() 67 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/utils/save.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import csv 3 | from shutil import copyfile 4 | 5 | 6 | def csv_path_from_agent(root_path, agent): 7 | """ 8 | Get the saving path from agent object and root path. 9 | :param root_path: (str) 10 | :param agent: (object) 11 | :return: (str) 12 | """ 13 | return root_path + '/results-' + agent.get_name() + '.csv' 14 | 15 | 16 | def lifelong_save(init, path, agent, data=None, instance_number=None): 17 | """ 18 | Save according to a specific data structure designed for lifelong RL experiments. 19 | :param init: (bool) 20 | :param path: (str) 21 | :param agent: agent object 22 | :param data: (dictionary) 23 | :param instance_number: (int) 24 | :return: None 25 | """ 26 | full_path = csv_path_from_agent(path, agent) 27 | if init: 28 | names = ['instance', 'task', 'episode', 'return', 'discounted_return'] 29 | csv_write(names, full_path, 'w') 30 | else: 31 | assert data is not None 32 | assert instance_number is not None 33 | n_tasks = len(data['returns_per_tasks']) 34 | n_episodes = len(data['returns_per_tasks'][0]) 35 | 36 | for i in range(n_tasks): 37 | for j in range(n_episodes): 38 | row = [str(instance_number), str(i + 1), str(j + 1), data['returns_per_tasks'][i][j], 39 | data['discounted_returns_per_tasks'][i][j]] 40 | csv_write(row, full_path, 'a') 41 | 42 | 43 | def csv_write(row, path, mode): 44 | """ 45 | Write a row into a csv. 46 | :param row: (array-like) written row, array-like whose elements are separated in the output file. 47 | :param path: (str) path to the edited csv 48 | :param mode: (str) mode for writing: 'w' override, 'a' append 49 | :return: None 50 | """ 51 | with open(path, mode) as csv_file: 52 | w = csv.writer(csv_file, delimiter=',', quotechar='|', quoting=csv.QUOTE_MINIMAL) 53 | w.writerow(row) 54 | 55 | def save_script(path, script_name='original_script.py'): 56 | if path[-1] != '/': 57 | path = path + '/' 58 | copyfile(sys.argv[0], path + script_name) -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/abstraction/action_abs/aa_helpers.py: -------------------------------------------------------------------------------- 1 | # Other imports. 
2 | from simple_rl.planning.ValueIterationClass import ValueIteration 3 | from simple_rl.tasks import GridWorldMDP 4 | from simple_rl.abstraction.action_abs.PredicateClass import Predicate 5 | from simple_rl.abstraction.action_abs.InListPredicateClass import InListPredicate 6 | from simple_rl.abstraction.action_abs.OptionClass import Option 7 | from simple_rl.abstraction.action_abs.PolicyFromDictClass import PolicyFromDict 8 | 9 | # ------------------------ 10 | # -- Goal Based Options -- 11 | # ------------------------ 12 | def make_goal_based_options(mdp_distr): 13 | ''' 14 | Args: 15 | mdp_distr (MDPDistribution) 16 | 17 | Returns: 18 | (list): Contains Option instances. 19 | ''' 20 | 21 | goal_list = set([]) 22 | for mdp in mdp_distr.get_all_mdps(): 23 | vi = ValueIteration(mdp) 24 | state_space = vi.get_states() 25 | for s in state_space: 26 | if s.is_terminal(): 27 | goal_list.add(s) 28 | 29 | options = set([]) 30 | for mdp in mdp_distr.get_all_mdps(): 31 | 32 | init_predicate = Predicate(func=lambda x: True) 33 | term_predicate = InListPredicate(ls=goal_list) 34 | o = Option(init_predicate=init_predicate, 35 | term_predicate=term_predicate, 36 | policy=_make_mini_mdp_option_policy(mdp), 37 | term_prob=0.0) 38 | options.add(o) 39 | 40 | return options 41 | 42 | def _make_mini_mdp_option_policy(mini_mdp): 43 | ''' 44 | Args: 45 | mini_mdp (MDP) 46 | 47 | Returns: 48 | Policy 49 | ''' 50 | # Solve the MDP defined by the terminal abstract state. 51 | mini_mdp_vi = ValueIteration(mini_mdp, delta=0.001, max_iterations=1000, sample_rate=10) 52 | iters, val = mini_mdp_vi.run_vi() 53 | 54 | o_policy_dict = make_dict_from_lambda(mini_mdp_vi.policy, mini_mdp_vi.get_states()) 55 | o_policy = PolicyFromDict(o_policy_dict) 56 | 57 | return o_policy.get_action 58 | 59 | def make_dict_from_lambda(policy_func, state_list): 60 | policy_dict = {} 61 | for s in state_list: 62 | policy_dict[s] = policy_func(s) 63 | 64 | return policy_dict 65 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/mdp/MDPClass.py: -------------------------------------------------------------------------------- 1 | ''' MDPClass.py: Contains the MDP Class. ''' 2 | 3 | # Python imports. 4 | import copy 5 | 6 | class MDP(object): 7 | ''' Abstract class for a Markov Decision Process. 
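    An MDP instance is essentially the tuple (A, T, R, s_0, gamma, step_cost): `actions` is A,
    `transition_func` is T(s, a) -> s', `reward_func` is R(s, a) -> float, `init_state` is s_0,
    and `gamma` is the discount factor. The state space itself is left implicit in T.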
''' 8 | 9 | def __init__(self, actions, transition_func, reward_func, init_state, gamma=0.99, step_cost=0): 10 | self.actions = actions 11 | self.transition_func = transition_func 12 | self.reward_func = reward_func 13 | self.gamma = gamma 14 | self.init_state = copy.deepcopy(init_state) 15 | self.cur_state = init_state 16 | self.step_cost = step_cost 17 | 18 | # --------------- 19 | # -- Accessors -- 20 | # --------------- 21 | 22 | def get_init_state(self): 23 | return self.init_state 24 | 25 | def get_curr_state(self): 26 | return self.cur_state 27 | 28 | def get_actions(self): 29 | return self.actions 30 | 31 | def get_gamma(self): 32 | return self.gamma 33 | 34 | def get_reward_func(self): 35 | return self.reward_func 36 | 37 | def get_transition_func(self): 38 | return self.transition_func 39 | 40 | def get_num_state_feats(self): 41 | return self.init_state.get_num_feats() 42 | 43 | # -------------- 44 | # -- Mutators -- 45 | # -------------- 46 | 47 | def set_gamma(self, new_gamma): 48 | self.gamma = new_gamma 49 | 50 | def set_step_cost(self, new_step_cost): 51 | self.step_cost = new_step_cost 52 | 53 | # ---------- 54 | # -- Core -- 55 | # ---------- 56 | 57 | def execute_agent_action(self, action): 58 | ''' 59 | Args: 60 | action (str) 61 | 62 | Returns: 63 | (tuple: ): reward, State 64 | 65 | Summary: 66 | Core method of all of simple_rl. Facilitates interaction 67 | between the MDP and an agent. 68 | ''' 69 | reward = self.reward_func(self.cur_state, action) 70 | next_state = self.transition_func(self.cur_state, action) 71 | self.cur_state = next_state 72 | 73 | return reward, next_state 74 | 75 | def reset(self): 76 | self.cur_state = copy.deepcopy(self.init_state) 77 | 78 | def end_of_instance(self): 79 | pass 80 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/abstraction/AbstractionWrapperClass.py: -------------------------------------------------------------------------------- 1 | # Python imports. 2 | import os 3 | 4 | # Other imports. 5 | from simple_rl.agents import Agent, RMaxAgent, FixedPolicyAgent 6 | from simple_rl.abstraction.state_abs.StateAbstractionClass import StateAbstraction 7 | from simple_rl.abstraction.action_abs.ActionAbstractionClass import ActionAbstraction 8 | 9 | class AbstractionWrapper(Agent): 10 | 11 | def __init__(self, 12 | SubAgentClass, 13 | agent_params={}, 14 | state_abstr=None, 15 | action_abstr=None, 16 | name_ext="-abstr"): 17 | ''' 18 | Args: 19 | SubAgentClass (simple_rl.AgentClass) 20 | agent_params (dict): A dictionary with key=param_name, val=param_value, 21 | to be given to the constructor for the instance of @SubAgentClass. 22 | state_abstr (StateAbstraction) 23 | state_abstr (ActionAbstraction) 24 | name_ext (str) 25 | ''' 26 | 27 | # Setup the abstracted agent. 
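        # The wrapped sub-agent only ever sees abstract states: act() maps each ground state
        # through state_abstr.phi() (when a state abstraction is given) before delegating to
        # the sub-agent, and grounds any option chosen via action_abstr to a primitive action.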
28 | self.agent = SubAgentClass(**agent_params) 29 | self.action_abstr = action_abstr 30 | self.state_abstr = state_abstr 31 | all_actions = self.action_abstr.get_actions() if self.action_abstr is not None else self.agent.actions 32 | 33 | Agent.__init__(self, name=self.agent.name + name_ext, actions=all_actions) 34 | 35 | def act(self, ground_state, reward): 36 | ''' 37 | Args: 38 | ground_state (State) 39 | reward (float) 40 | 41 | Return: 42 | (str) 43 | ''' 44 | 45 | if self.state_abstr is not None: 46 | abstr_state = self.state_abstr.phi(ground_state) 47 | else: 48 | abstr_state = ground_state 49 | 50 | 51 | if self.action_abstr is not None: 52 | ground_action = self.action_abstr.act(self.agent, abstr_state, ground_state, reward) 53 | else: 54 | ground_action = self.agent.act(abstr_state, reward) 55 | 56 | return ground_action 57 | 58 | def reset(self): 59 | # Write data. 60 | self.agent.reset() 61 | 62 | if self.action_abstr is not None: 63 | self.action_abstr.reset() 64 | 65 | def end_of_episode(self): 66 | self.agent.end_of_episode() 67 | if self.action_abstr is not None: 68 | self.action_abstr.end_of_episode() 69 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/tasks/combo_lock/ComboLockMDPClass.py: -------------------------------------------------------------------------------- 1 | ''' ChainMDPClass.py: Contains the ChainMDPClass class. ''' 2 | 3 | # Python imports. 4 | from __future__ import print_function 5 | 6 | # Other imports. 7 | from simple_rl.mdp.MDPClass import MDP 8 | from simple_rl.tasks.chain.ChainStateClass import ChainState 9 | 10 | class ComboLockMDP(MDP): 11 | ''' Imeplementation for a standard Chain MDP ''' 12 | 13 | ACTIONS = [] 14 | 15 | def __init__(self, combo, num_actions=3, num_states=None, reset_val=0.01, gamma=0.99): 16 | ''' 17 | Args: 18 | num_states (int) [optional]: Number of states in the chain. 19 | ''' 20 | ComboLockMDP.ACTIONS = [str(i) for i in range(1, num_actions + 1)] 21 | self.num_states = len(combo) if num_states is None else num_states 22 | self.num_actions = num_actions 23 | self.combo = combo 24 | 25 | if len(combo) != self.num_states: 26 | raise ValueError("(simple_rl.ComboLockMDP Error): Combo length (" + str(len(combo)) + ") must be the same as num_states (" + str(self.num_states) + ").") 27 | elif max(combo) > num_actions: 28 | raise ValueError("(simple_rl.ComboLockMDP Error): Combo (" + str(combo) + ") must only contain values less than or equal to @num_actions (" + str(num_actions) +").") 29 | 30 | MDP.__init__(self, ComboLockMDP.ACTIONS, self._transition_func, self._reward_func, init_state=ChainState(1), gamma=gamma) 31 | 32 | def _reward_func(self, state, action): 33 | ''' 34 | Args: 35 | state (State) 36 | action (str) 37 | statePrime 38 | 39 | Returns 40 | (float) 41 | ''' 42 | if state.num == self.num_states and int(action) == self.combo[state.num - 1]: 43 | return 1 44 | else: 45 | return 0 46 | 47 | def _transition_func(self, state, action): 48 | ''' 49 | Args: 50 | state (State) 51 | action (str) 52 | 53 | Returns 54 | (State) 55 | ''' 56 | # print(state.num, self.num_states, action, self.combo[state.num]) 57 | if int(action) == self.combo[state.num - 1]: 58 | if state < self.num_states: 59 | return state + 1 60 | else: 61 | # At end of chain. 62 | return state 63 | else: 64 | return ChainState(1) 65 | 66 | def __str__(self): 67 | return "combolock-" + str(self.num_states) 68 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/tasks/random/RandomMDPClass.py: -------------------------------------------------------------------------------- 1 | ''' RandomMDPClass.py: Contains the RandomMDPClass class. ''' 2 | 3 | # Python imports. 4 | import random 5 | import numpy as np 6 | from collections import defaultdict 7 | 8 | # Other imports. 
9 | from simple_rl.mdp.MDPClass import MDP 10 | from simple_rl.tasks.random.RandomStateClass import RandomState 11 | 12 | class RandomMDP(MDP): 13 | ''' Imeplementation for a standard Random MDP ''' 14 | 15 | ACTIONS = [str(i) for i in range(3)] 16 | 17 | def __init__(self, num_states=5, num_rand_trans=5, gamma=0.99): 18 | ''' 19 | Args: 20 | num_states (int) [optional]: Number of states in the Random MDP. 21 | num_rand_trans (int) [optional]: Number of possible next states. 22 | 23 | Summary: 24 | Each state-action pair picks @num_rand_trans possible states and has a uniform distribution 25 | over them for transitions. Rewards are also chosen randomly. 26 | ''' 27 | MDP.__init__(self, RandomMDP.ACTIONS, self._transition_func, self._reward_func, init_state=RandomState(1), gamma=gamma) 28 | # assert(num_rand_trans <= num_states) 29 | self.num_rand_trans = num_rand_trans 30 | self.num_states = num_states 31 | self._reward_s_a = (random.choice(range(self.num_states)), random.choice(RandomMDP.ACTIONS)) 32 | self._transitions = defaultdict(lambda: defaultdict(str)) 33 | 34 | def _reward_func(self, state, action): 35 | ''' 36 | Args: 37 | state (State) 38 | action (str) 39 | statePrime 40 | 41 | Returns 42 | (float) 43 | ''' 44 | if (state.data, action) == self._reward_s_a: 45 | return 1.0 46 | else: 47 | return 0.0 48 | 49 | def _transition_func(self, state, action): 50 | ''' 51 | Args: 52 | state (State) 53 | action (str) 54 | 55 | Returns 56 | (State) 57 | ''' 58 | if self.num_states == 1: 59 | return state 60 | 61 | if (state, action) not in self._transitions: 62 | # Chooses @self.num_rand_trans from range(self.num_states) 63 | self._transitions[state][action] = np.random.choice(self.num_states, self.num_rand_trans, replace=False) 64 | 65 | state_id = np.random.choice(self._transitions[state][action]) 66 | return RandomState(state_id) 67 | 68 | def __str__(self): 69 | return "RandomMDP-" + str(self.num_states) 70 | 71 | 72 | 73 | def main(): 74 | _gen_random_distr() 75 | 76 | if __name__ == "__main__": 77 | main() -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/abstraction/AbstractValueIterationClass.py: -------------------------------------------------------------------------------- 1 | # Python imports. 2 | import random 3 | from collections import defaultdict 4 | 5 | # Other imports. 6 | from simple_rl.utils import make_mdp 7 | from simple_rl.abstraction.action_abs.ActionAbstractionClass import ActionAbstraction 8 | from simple_rl.abstraction.state_abs.StateAbstractionClass import StateAbstraction 9 | from simple_rl.abstraction.abstr_mdp import abstr_mdp_funcs 10 | from simple_rl.planning.PlannerClass import Planner 11 | from simple_rl.planning.ValueIterationClass import ValueIteration 12 | 13 | class AbstractValueIteration(ValueIteration): 14 | ''' AbstractValueIteration: Runs ValueIteration on an abstract MDP induced by the given state and action abstraction ''' 15 | 16 | def __init__(self, ground_mdp, state_abstr=None, action_abstr=None, vi_sample_rate=5, max_iterations=1000, amdp_sample_rate=5, delta=0.001): 17 | ''' 18 | Args: 19 | ground_mdp (simple_rl.MDP) 20 | state_abstr (simple_rl.StateAbstraction) 21 | action_abstr (simple_rl.ActionAbstraction) 22 | vi_sample_rate (int): Num samples per transition for running VI. 23 | max_iterations (int): Usual VI # Iteration bound. 24 | amdp_sample_rate (int): Num samples per abstract transition to use for computing R_abstract, T_abstract. 
25 | ''' 26 | self.ground_mdp = ground_mdp 27 | 28 | # Grab ground state space. 29 | vi = ValueIteration(self.ground_mdp, delta=0.001, max_iterations=1000, sample_rate=5) 30 | state_space = vi.get_states() 31 | 32 | # Make the abstract MDP. 33 | self.state_abstr = state_abstr if state_abstr is not None else StateAbstraction(ground_state_space=state_space) 34 | self.action_abstr = action_abstr if action_abstr is not None else ActionAbstraction(prim_actions=ground_mdp.get_actions()) 35 | abstr_mdp = abstr_mdp_funcs.make_abstr_mdp(ground_mdp, self.state_abstr, self.action_abstr, step_cost=0.0, sample_rate=amdp_sample_rate) 36 | 37 | # Create VI with the abstract MDP. 38 | ValueIteration.__init__(self, abstr_mdp, vi_sample_rate, delta, max_iterations) 39 | 40 | def policy(self, state): 41 | ''' 42 | Args: 43 | state (State) 44 | 45 | Returns: 46 | (str): Action 47 | 48 | Summary: 49 | For use in a FixedPolicyAgent. 50 | 51 | # TODO: 52 | Doesn't account for options terminating (policy is over options, currently just grounds them). 53 | ''' 54 | option = self._get_max_q_action(self.state_abstr.phi(state)) 55 | return option.act(state) 56 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/tasks/maze_1d/Maze1DPOMDPClass.py: -------------------------------------------------------------------------------- 1 | # Python imports. 2 | from collections import defaultdict 3 | import random 4 | 5 | # Other imports. 6 | from simple_rl.pomdp.POMDPClass import POMDP 7 | from simple_rl.tasks.maze_1d.Maze1DStateClass import Maze1DState 8 | 9 | class Maze1DPOMDP(POMDP): 10 | ''' Class for a 1D Maze POMDP ''' 11 | 12 | ACTIONS = ['west', 'east'] 13 | OBSERVATIONS = ['nothing', 'goal'] 14 | 15 | def __init__(self): 16 | self._states = [Maze1DState('left'), Maze1DState('middle'), Maze1DState('right'), Maze1DState('goal')] 17 | 18 | # Initial belief is a uniform distribution over states 19 | b0 = defaultdict() 20 | for state in self._states: b0[state] = 0.25 21 | 22 | POMDP.__init__(self, Maze1DPOMDP.ACTIONS, Maze1DPOMDP.OBSERVATIONS, self._transition_func, self._reward_func, self._observation_func, b0) 23 | 24 | def _transition_func(self, state, action): 25 | ''' 26 | Args: 27 | state (Maze1DState) 28 | action (str) 29 | 30 | Returns: 31 | next_state (Maze1DState) 32 | ''' 33 | if action == 'west': 34 | if state.name == 'left': 35 | return Maze1DState('left') 36 | if state.name == 'middle': 37 | return Maze1DState('left') 38 | if state.name == 'right': 39 | return Maze1DState('goal') 40 | if state.name == 'goal': 41 | return Maze1DState(random.choice(['left', 'middle', 'right'])) 42 | if action == 'east': 43 | if state.name == 'left': 44 | return Maze1DState('middle') 45 | if state.name == 'middle': 46 | return Maze1DState('goal') 47 | if state.name == 'right': 48 | return Maze1DState('right') 49 | if state.name == 'goal': 50 | return Maze1DState(random.choice(['left', 'middle', 'right'])) 51 | raise ValueError('Invalid state: {} action: {} in 1DMaze'.format(state, action)) 52 | 53 | def _observation_func(self, state, action): 54 | next_state = self._transition_func(state, action) 55 | return 'goal' if next_state.name == 'goal' else 'nothing' 56 | 57 | def _reward_func(self, state, action): 58 | next_state = self._transition_func(state, action) 59 | observation = self._observation_func(state, action) 60 | return (1. - self.step_cost) if (next_state.name == observation == 'goal') else (0. 
- self.step_cost) 61 | 62 | def is_in_goal_state(self): 63 | return self.cur_state.name == 'goal' 64 | 65 | if __name__ == '__main__': 66 | maze_pomdp = Maze1DPOMDP() 67 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/tasks/four_room/FourRoomMDPClass.py: -------------------------------------------------------------------------------- 1 | ''' FourRoomMDPClass.py: Contains the FourRoom class. ''' 2 | 3 | # Python imports. 4 | import math 5 | 6 | # Other imports 7 | from simple_rl.mdp.MDPClass import MDP 8 | from simple_rl.tasks.grid_world.GridWorldMDPClass import GridWorldMDP 9 | from simple_rl.tasks.grid_world.GridWorldStateClass import GridWorldState 10 | 11 | class FourRoomMDP(GridWorldMDP): 12 | ''' Class for a FourRoom ''' 13 | 14 | def __init__(self, width=9, height=9, init_loc=(1,1), goal_locs=[(9,9)], gamma=0.99, slip_prob=0.00, name="four_room", is_goal_terminal=True, rand_init=False, step_cost=0.0): 15 | ''' 16 | Args: 17 | height (int) 18 | width (int) 19 | init_loc (tuple: (int, int)) 20 | goal_locs (list of tuples: [(int, int)...]) 21 | ''' 22 | GridWorldMDP.__init__(self, width, height, init_loc, goal_locs=goal_locs, walls=self._compute_walls(width, height), gamma=gamma, slip_prob=slip_prob, name=name, is_goal_terminal=is_goal_terminal, rand_init=rand_init, step_cost=step_cost) 23 | 24 | def _compute_walls(self, width, height): 25 | ''' 26 | Args: 27 | width (int) 28 | height (int) 29 | 30 | Returns: 31 | (list): Contains (x,y) pairs that define wall locations. 32 | ''' 33 | walls = [] 34 | 35 | half_width = math.ceil(width / 2.0) 36 | half_height = math.ceil(height / 2.0) 37 | 38 | 39 | # # Wall from left to middle. 40 | for i in range(1, width + 1): 41 | if i == half_width: 42 | half_height -= 1 43 | if i == (width + 1) / 3 or i == math.ceil(2 * (width + 1) / 3.0): 44 | continue 45 | 46 | walls.append((i, half_height)) 47 | 48 | # Wall from bottom to top. 
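        # As with the horizontal wall above, the `continue` statements leave gaps at roughly
        # one third and two thirds of the wall's length; these gaps act as the doorways
        # connecting the four rooms.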
49 | for j in range(1, height + 1): 50 | if j == (height + 1) / 3 or j == math.ceil(2 * (height + 1) / 3.0): 51 | continue 52 | walls.append((half_width, j)) 53 | # print(walls) 54 | # walls = [(1,6), (3,6), (4,6), (5,6), (6,6), (6,7), (6,8), (6,10), (6,11),\ 55 | # (6,5), (6,4), (6,3), (6,1), (7,5), (8,5), (10,5), (11,5)] 56 | # walls = [(1,8),(2,8),(3,8),(4,8),(6,8),(7,8),(8,8),(9,8),(11,8),(12,8),(13,8),(14,8),\ 57 | # (8,1),(8,2),(8,3),(8,4),(8,6),(8,7),(8,8),(8,9),(8,11),(8,12),(8,13),(8,14)] 58 | # walls = [(1,9),(2,9),(3,9),(4,9),(5,9),(7,9),(8,9),(9,9),(10,9),(11,9),(13,8),(14,8),(15,8),(16,8),(17,8),\ 59 | # (9,1),(9,2),(9,3),(9,4),(9,5),(9,7),(9,8),(9,9),(9,10),(9,11),(9,13),(9,14),(9,15),(9,16),(9,17)] 60 | return walls 61 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/tasks/cleanup/cleanup_state.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import random 3 | 4 | from simple_rl.mdp.StateClass import State 5 | 6 | from simple_rl.tasks.cleanup.CleanUpMDPClass import CleanUpMDP 7 | 8 | 9 | class CleanUpState(State): 10 | def __init__(self, task, x, y, blocks=[], doors=[], rooms=[]): 11 | ''' 12 | :param task: The given CleanUpTask 13 | :param x: Agent x coordinate 14 | :param y: Agent y coordinate 15 | :param blocks: List of blocks 16 | :param doors: List of doors 17 | :param rooms: List of rooms 18 | ''' 19 | self.x = x 20 | self.y = y 21 | self.blocks = blocks 22 | self.doors = doors 23 | self.rooms = rooms 24 | self.task = task 25 | State.__init__(self, data=[task, (x, y), blocks, doors, rooms]) 26 | 27 | def __hash__(self): 28 | alod = [tuple(self.data[i]) for i in range(1, len(self.data))] 29 | alod.append(self.data[0]) 30 | return hash(tuple(alod)) 31 | 32 | def __str__(self): 33 | str_builder = "(" + str(self.x) + ", " + str(self.y) + ")\n" 34 | str_builder += "\nBLOCKS:\n" 35 | for block in self.blocks: 36 | str_builder += str(block) + "\n" 37 | str_builder += "\nDOORS:\n" 38 | for door in self.doors: 39 | str_builder += str(door) + "\n" 40 | str_builder += "\nROOMS:\n" 41 | for room in self.rooms: 42 | str_builder += str(room) + "\n" 43 | return str_builder 44 | 45 | @staticmethod 46 | def list_eq(alod1, alod2): 47 | ''' 48 | :param alod1: First list 49 | :param alod2: Second list 50 | :return: A boolean indicating whether or not the lists are the same 51 | ''' 52 | if len(alod1) != len(alod2): 53 | return False 54 | sa = set(alod2) 55 | for item in alod1: 56 | if item not in sa: 57 | return False 58 | 59 | return True 60 | 61 | def __eq__(self, other): 62 | return isinstance(other, CleanUpState) and self.x == other.x and self.y == other.y and \ 63 | self.list_eq(other.rooms, self.rooms) and self.list_eq(other.doors, self.doors) and \ 64 | self.list_eq(other.blocks, self.blocks) 65 | 66 | def is_terminal(self): 67 | return CleanUpMDP.is_terminal(self.task, next_state=self) 68 | 69 | def copy(self): 70 | new_blocks = [block.copy() for block in self.blocks] 71 | new_rooms = [room.copy() for room in self.rooms] 72 | new_doors = [door.copy() for door in self.doors] 73 | return CleanUpState(self.task, self.x, self.y, new_blocks, new_doors, new_rooms) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Python implementation of Lifelong Reinforcement Learning (Lifelong RL/LLRL). 
2 | 
3 | # SR-LLRL
4 | Shaping Rewards for Lifelong Reinforcement Learning
5 | 
6 | ## Brief Introduction
7 | Code for experimenting with the approaches to lifelong RL proposed in our 2021 IEEE SMC paper "Accelerating Lifelong Reinforcement Learning via Reshaping Rewards".
8 | 
9 | Authors: Kun Chu, Xianchao Zhu, William Zhu.
10 | 
11 | If you use this code, please **cite our paper**:
12 | 
13 | K. Chu, X. Zhu and W. Zhu, "[Accelerating Lifelong Reinforcement Learning via Reshaping Rewards](https://ieeexplore.ieee.org/document/9659064)," 2021 IEEE International Conference on Systems, Man, and Cybernetics (SMC), 2021, pp. 619-624, doi: 10.1109/SMC52423.2021.9659064.
14 | 
15 | BibTeX citation:
16 | 
17 | ```
18 | @INPROCEEDINGS{Chu2021SMC,
19 | author={Chu, Kun and Zhu, Xianchao and Zhu, William},
20 | booktitle={2021 IEEE International Conference on Systems, Man, and Cybernetics (SMC)},
21 | title={Accelerating Lifelong Reinforcement Learning via Reshaping Rewards},
22 | year={2021},
23 | pages={619-624},
24 | doi={10.1109/SMC52423.2021.9659064}
25 | }
26 | ```
27 | 
28 | ## Usage
29 | To generate experimental results, run `main.py`.
30 | 
31 | To draw all of our plots, run `result_show_task.py` and `result_show_episode.py`.
32 | 
33 | Note that you must select the learning algorithms and parameters inside these scripts before generating results/figures.
34 | 
35 | ## Important Note
36 | This code depends on several Python libraries, especially [simple_rl](https://github.com/david-abel/simple_rl) by [David Abel](https://github.com/david-abel). However, please note that I have made some improvements and changes to his code, so use the copy of simple_rl included in this repository rather than installing the package from PyPI.
37 | 
38 | ## Experimental Demonstration
39 | ![png1](https://github.com/Kchu/LifelongRL/blob/master/SR-LLRL/IEEE_SMC_2021_Plots/figures/Environments.png)
40 | ![png2](https://github.com/Kchu/LifelongRL/blob/master/SR-LLRL/IEEE_SMC_2021_Plots/figures/Result_1.png)
41 | ![png3](https://github.com/Kchu/LifelongRL/blob/master/SR-LLRL/IEEE_SMC_2021_Plots/figures/Result_2.png)
42 | 
43 | ## Acknowledgment
44 | 
45 | I would like to sincerely thank [David Abel](https://david-abel.github.io/), a great young scientist. He generously shared the source code of his paper on [GitHub](https://github.com/david-abel/transfer_rl_icml_2018) and gave detailed answers to my questions throughout this research. I admire his academic achievements and, more importantly, his enthusiastic help and scientific spirit.
46 | 
47 | # Contact
48 | 
49 | Feel free to contact me (kun_chu@outlook.com) with any questions.
-------------------------------------------------------------------------------- /SR-LLRL/simple_rl/abstraction/action_abs/OptionClass.py: --------------------------------------------------------------------------------
1 | # Python imports.
2 | from collections import defaultdict
3 | import random
4 | 
5 | # Other imports.
6 | from simple_rl.mdp.StateClass import State 7 | 8 | class Option(object): 9 | 10 | def __init__(self, init_predicate, term_predicate, policy, name="o", term_prob=0.01): 11 | ''' 12 | Args: 13 | init_func (S --> {0,1}) 14 | init_func (S --> {0,1}) 15 | policy (S --> A) 16 | ''' 17 | self.init_predicate = init_predicate 18 | self.term_predicate = term_predicate 19 | self.term_flag = False 20 | self.name = name 21 | self.term_prob = term_prob 22 | 23 | if type(policy) is defaultdict or type(policy) is dict: 24 | self.policy_dict = dict(policy) 25 | self.policy = self.policy_from_dict 26 | else: 27 | self.policy = policy 28 | 29 | def is_init_true(self, ground_state): 30 | return self.init_predicate.is_true(ground_state) 31 | 32 | def is_term_true(self, ground_state): 33 | return self.term_predicate.is_true(ground_state) or self.term_flag or self.term_prob > random.random() 34 | 35 | def act(self, ground_state): 36 | return self.policy(ground_state) 37 | 38 | def set_policy(self, policy): 39 | self.policy = policy 40 | 41 | def set_name(self, new_name): 42 | self.name = new_name 43 | 44 | def act_until_terminal(self, cur_state, transition_func): 45 | ''' 46 | Summary: 47 | Executes the option until termination. 48 | ''' 49 | if self.is_init_true(cur_state): 50 | cur_state = transition_func(cur_state, self.act(cur_state)) 51 | while not self.is_term_true(cur_state): 52 | cur_state = transition_func(cur_state, self.act(cur_state)) 53 | 54 | return cur_state 55 | 56 | def rollout(self, cur_state, reward_func, transition_func, step_cost=0): 57 | ''' 58 | Summary: 59 | Executes the option until termination. 60 | 61 | Returns: 62 | (tuple): 63 | 1. (State): state we landed in. 64 | 2. (float): Reward from the trajectory. 65 | ''' 66 | total_reward = 0 67 | if self.is_init_true(cur_state): 68 | # First step. 69 | total_reward += reward_func(cur_state, self.act(cur_state)) - step_cost 70 | cur_state = transition_func(cur_state, self.act(cur_state)) 71 | 72 | # Act until terminal. 73 | while not self.is_term_true(cur_state): 74 | cur_state = transition_func(cur_state, self.act(cur_state)) 75 | total_reward += reward_func(cur_state, self.act(cur_state)) - step_cost 76 | 77 | return cur_state, total_reward 78 | 79 | def policy_from_dict(self, state): 80 | if state not in self.policy_dict.keys(): 81 | self.term_flag = True 82 | return random.choice(list(set(self.policy_dict.values()))) 83 | else: 84 | self.term_flag = False 85 | return self.policy_dict[state] 86 | 87 | def term_func_from_list(self, state): 88 | return state in self.term_list 89 | 90 | def __str__(self): 91 | return "option." + str(self.name) -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/pomdp/POMDPClass.py: -------------------------------------------------------------------------------- 1 | # Python imports. 2 | from collections import defaultdict 3 | 4 | # Other imports. 5 | from simple_rl.pomdp.BeliefUpdaterClass import BeliefUpdater 6 | from simple_rl.mdp.MDPClass import MDP 7 | 8 | class POMDP(MDP): 9 | ''' Abstract class for a Partially Observable Markov Decision Process. ''' 10 | 11 | def __init__(self, actions, observations, transition_func, reward_func, observation_func, 12 | init_belief, belief_updater_type='discrete', gamma=0.99, step_cost=0): 13 | ''' 14 | In addition to the input parameters needed to define an MDP, the POMDP 15 | definition requires an observation function, a way to update the belief 16 | state and an initial belief. 
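        (For reference, the standard discrete Bayesian belief update is
        b'(s') proportional to O(s', a, z) * sum_s T(s, a, s') * b(s), normalized over s';
        see BeliefUpdaterClass for the update rule actually used here.)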
17 | Args: 18 | actions (list) 19 | observations (list) 20 | transition_func: T(s, a) -> s' 21 | reward_func: R(s, a) -> float 22 | observation_func: O(s, a) -> z 23 | init_belief (defaultdict): initial probability distribution over states 24 | belief_updater_type (str): discrete/kalman/particle 25 | gamma (float) 26 | step_cost (int) 27 | ''' 28 | self.observations = observations 29 | self.observation_func = observation_func 30 | self.init_belief = init_belief 31 | self.curr_belief = init_belief 32 | 33 | # init_belief_state = BeliefState(data=init_belief.values()) 34 | sampled_init_state = max(init_belief, key=init_belief.get) 35 | MDP.__init__(self, actions, transition_func, reward_func, sampled_init_state, gamma, step_cost) 36 | 37 | self.belief_updater = BeliefUpdater(self, transition_func, reward_func, observation_func, belief_updater_type) 38 | self.belief_updater_func = self.belief_updater.updater 39 | 40 | def get_curr_belief(self): 41 | return self.curr_belief 42 | 43 | def get_observation_func(self): 44 | ''' 45 | Returns: 46 | observation_function: O(s, a) -> o 47 | ''' 48 | return self.observation_func 49 | 50 | def get_observations(self): 51 | ''' 52 | Returns: 53 | observations (list): strings representing discrete set of observations 54 | ''' 55 | return self.observations 56 | 57 | def execute_agent_action(self, action): 58 | ''' 59 | Args: 60 | action (str) 61 | 62 | Returns: 63 | reward (float) 64 | next_belief (defaultdict) 65 | ''' 66 | observation = self.observation_func(self.cur_state, action) 67 | new_belief = self.belief_updater_func(self.curr_belief, action, observation) 68 | self.curr_belief = new_belief 69 | 70 | reward, next_state = super(POMDP, self).execute_agent_action(action) 71 | 72 | return reward, observation, new_belief 73 | -------------------------------------------------------------------------------- /SR-LLRL/result_show_task.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | ########################################################################################### 4 | # Implementation of illustrating results. (Average reward for each task) 5 | # Author for codes: Chu Kun(kun_chu@outlook.com), Abel 6 | # Reference: https://github.com/Kchu/LifelongRL 7 | ########################################################################################### 8 | 9 | # Python imports. 10 | import os 11 | from simple_rl.utils import chart_utils 12 | from simple_rl.plot_utils import lifelong_plot 13 | from simple_rl.agents.AgentClass import Agent 14 | 15 | def _get_MDP_name(data_dir): 16 | ''' 17 | Args: 18 | data_dir (str) 19 | 20 | Returns: 21 | (list) 22 | ''' 23 | try: 24 | params_file = open(os.path.join(data_dir, "parameters.txt"), "r") 25 | except IOError: 26 | # No param file. 27 | return [agent_file.replace(".csv", "") for agent_file in os.listdir(data_dir) if os.path.isfile(os.path.join(data_dir, agent_file)) and ".csv" in agent_file] 28 | 29 | MDP_name = [] 30 | 31 | for line in params_file.readlines(): 32 | if "lifelong-" in line: 33 | MDP_name = line.split(" ")[0].strip() 34 | break 35 | 36 | return MDP_name 37 | 38 | def main(): 39 | ''' 40 | Summary: 41 | For manual plotting. 42 | ''' 43 | # Parameter 44 | data_dir = [r'.\results\lifelong-four_room_h-11_w-11-q-learning-vs_task\\'] 45 | output_dir = r'.\plots\\' 46 | 47 | # Format data dir 48 | 49 | # Grab agents 50 | 51 | # Plot. 
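    # For each results directory: recover the agent names from the saved csv files, wrap them
    # in lightweight Agent handles, read the experiment settings, and pass everything to
    # lifelong_plot() to produce the per-task figures.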
52 | for index in range(len(data_dir)): 53 | print('Plotting ' + str(index+1) +'th figure.') 54 | agent_names = chart_utils._get_agent_names(data_dir[index]) 55 | agents = [] 56 | actions = [] 57 | if len(agent_names) == 0: 58 | raise ValueError("Error: no csv files found.") 59 | for i in agent_names: 60 | agent = Agent(i, actions) 61 | agents.append(agent) 62 | 63 | # Grab experiment settings 64 | episodic = chart_utils._is_episodic(data_dir[index]) 65 | track_disc_reward = chart_utils._is_disc_reward(data_dir[index]) 66 | mdp_name = _get_MDP_name(data_dir[index]) 67 | lifelong_plot( 68 | agents, 69 | data_dir[index], 70 | output_dir, 71 | n_tasks=40, 72 | n_episodes=100, 73 | confidence=0.95, 74 | open_plot=True, 75 | plot_title=True, 76 | plot_legend=True, 77 | legend_at_bottom=False, 78 | episodes_moving_average=False, 79 | episodes_ma_width=10, 80 | tasks_moving_average=False, 81 | tasks_ma_width=10, 82 | latex_rendering=False, 83 | figure_title=mdp_name) 84 | if __name__ == "__main__": 85 | main() -------------------------------------------------------------------------------- /SR-LLRL/result_show_episode.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | ########################################################################################### 4 | # Implementation of illustrating results. (Average reward for each episode) 5 | # Author for codes: Chu Kun(kun_chu@outlook.com) 6 | # Reference: https://github.com/Kchu/LifelongRL 7 | ########################################################################################### 8 | 9 | # Python imports. 10 | import os 11 | from simple_rl.utils import chart_utils 12 | 13 | def _get_MDP_name(data_dir): 14 | ''' 15 | Args: 16 | data_dir (str) 17 | 18 | Returns: 19 | (list) 20 | ''' 21 | try: 22 | params_file = open(os.path.join(data_dir, "parameters.txt"), "r") 23 | except IOError: 24 | # No param file. 25 | return [agent_file.replace(".csv", "") for agent_file in os.listdir(data_dir) if os.path.isfile(os.path.join(data_dir, agent_file)) and ".csv" in agent_file] 26 | 27 | MDP_name = [] 28 | 29 | for line in params_file.readlines(): 30 | if "lifelong-" in line: 31 | MDP_name = line.split(" ")[0].strip() 32 | break 33 | 34 | return MDP_name 35 | 36 | def main(): 37 | ''' 38 | Summary: 39 | For manual plotting. 40 | ''' 41 | # Parameter 42 | data_dir = ["D:\\MyPapers\\Results_vs_Episodes\\Q-FourRoom\\", "D:\\MyPapers\\Results_vs_Episodes\\Q-Lava\\", 43 | "D:\\MyPapers\\Results_vs_Episodes\\Q-Maze\\", "D:\\MyPapers\\Results_vs_Episodes\\DelayedQ-FourRoom\\", 44 | "D:\\MyPapers\\Results_vs_Episodes\\DelayedQ-Lava\\", 45 | "D:\\MyPapers\\Results_vs_Episodes\\DelayedQ-Maze\\"] 46 | output_dir = "D:\\MyPapers\\Plots\\" 47 | 48 | for index in range(len(data_dir)): 49 | cumulative = False 50 | 51 | # Format data dir 52 | # data_dir[index] = ''.join(data_dir[index]) 53 | # print(data_dir[index]) 54 | if data_dir[index][-1] != "\\": 55 | data_dir[index] = data_dir[index] + "\\" 56 | 57 | # Set output file name 58 | exp_dir_split_list = data_dir[index].split("\\") 59 | file_name = output_dir + exp_dir_split_list[-2] + '-Episode.pdf' 60 | # Grab agents. 
61 | agent_names = chart_utils._get_agent_names(data_dir[index]) 62 | if len(agent_names) == 0: 63 | raise ValueError("Error: no csv files found.") 64 | 65 | # Grab experiment settings 66 | episodic = chart_utils._is_episodic(data_dir[index]) 67 | track_disc_reward = chart_utils._is_disc_reward(data_dir[index]) 68 | mdp_name = _get_MDP_name(data_dir[index]) 69 | 70 | # Plot. 71 | chart_utils.make_plots(data_dir[index], agent_names, cumulative=cumulative, episodic=episodic, track_disc_reward=track_disc_reward, figure_title=mdp_name, plot_file_name=file_name) 72 | 73 | if __name__ == "__main__": 74 | main() -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/tasks/bandit/BanditMDPClass.py: -------------------------------------------------------------------------------- 1 | ''' BanditMDPClass.py: Contains the BanditMDPClass class. ''' 2 | 3 | # Python imports. 4 | from __future__ import print_function 5 | from collections import defaultdict 6 | import numpy as np 7 | 8 | # Other imports. 9 | from simple_rl.mdp.MDPClass import MDP 10 | from simple_rl.mdp.StateClass import State 11 | 12 | class BanditMDP(MDP): 13 | ''' Imeplementation for a standard Bandit MDP. 14 | 15 | Note: Assumes gaussians with randomly initialized mean and variance 16 | unless payout_distributions is set. 17 | ''' 18 | 19 | ACTIONS = [] 20 | 21 | def __init__(self, num_arms=10, distr_family=np.random.normal, distr_params=None): 22 | ''' 23 | Args: 24 | num_arms (int): Number of arms. 25 | distr_family (lambda): A function from numpy which, when given 26 | entities from @distr_params, samples from the distribution family. 27 | distr_params (dict): If None is given, default mu/sigma for normal 28 | distribution are initialized randomly. 29 | ''' 30 | BanditMDP.ACTIONS = [str(i) for i in range(1, num_arms + 1)] 31 | MDP.__init__(self, BanditMDP.ACTIONS, self._transition_func, self._reward_func, init_state=State(1), gamma=1.0) 32 | self.num_arms = num_arms 33 | self.distr_family = distr_family 34 | self.distr_params = self.init_distr_params() if distr_params is None else distr_params 35 | 36 | def init_distr_params(self): 37 | ''' 38 | Summary: 39 | Creates default distribution parameters for each of 40 | the @self.num_arms arms. Defaults to Gaussian bandits 41 | with each mu ~ Unif(-1,1) and sigma ~ Unif(0,2). 42 | 43 | Returns: 44 | (dict) 45 | ''' 46 | distr_params = defaultdict(lambda: defaultdict(list)) 47 | 48 | for i in range(self.num_arms): 49 | next_mu = np.random.uniform(-1.0, 1.0) 50 | next_sigma = np.random.uniform(0, 2.0) 51 | distr_params[str(i)] = [next_mu, next_sigma] 52 | 53 | return distr_params 54 | 55 | def _reward_func(self, state, action): 56 | ''' 57 | Args: 58 | state (State) 59 | action (str) 60 | statePrime 61 | 62 | Returns 63 | (float) 64 | ''' 65 | # Samples from the distribution associated with @action. 66 | return self.distr_family(*self.distr_params[action]) 67 | 68 | def _transition_func(self, state, action): 69 | ''' 70 | Args: 71 | state (State) 72 | action (str) 73 | 74 | Returns 75 | (State) 76 | 77 | Notes: 78 | Required to fit naturally with the rest of simple_rl, but obviously 79 | doesn't do anything. 
80 | ''' 81 | return state 82 | 83 | def __str__(self): 84 | return str(self.num_arms) + "_Armed_Bandit" 85 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/pomdp/BeliefMDPClass.py: -------------------------------------------------------------------------------- 1 | from simple_rl.mdp.MDPClass import MDP 2 | from simple_rl.pomdp.POMDPClass import POMDP 3 | from simple_rl.pomdp.BeliefStateClass import BeliefState 4 | 5 | class BeliefMDP(MDP): 6 | def __init__(self, pomdp): 7 | ''' 8 | Convert given POMDP to a Belief State MDP 9 | Args: 10 | pomdp (POMDP) 11 | ''' 12 | self.state_transition_func = pomdp.transition_func 13 | self.state_reward_func = pomdp.reward_func 14 | self.state_observation_func = pomdp.observation_func 15 | self.belief_updater_func = pomdp.belief_updater_func 16 | 17 | self.pomdp = pomdp 18 | 19 | MDP.__init__(self, pomdp.actions, self._belief_transition_function, self._belief_reward_function, 20 | BeliefState(pomdp.init_belief), pomdp.gamma, pomdp.step_cost) 21 | 22 | def _belief_transition_function(self, belief_state, action): 23 | ''' 24 | The belief MDP transition function T(b, a) --> b' is a generative function that given a belief state and an 25 | action taken from that belief state, returns the most likely next belief state 26 | Args: 27 | belief_state (BeliefState) 28 | action (str) 29 | 30 | Returns: 31 | new_belief (defaultdict) 32 | ''' 33 | observation = self._get_observation_from_environment(action) 34 | next_belief_distribution = self.belief_updater_func(belief_state.distribution, action, observation) 35 | return BeliefState(next_belief_distribution) 36 | 37 | def _belief_reward_function(self, belief_state, action): 38 | ''' 39 | The belief MDP reward function R(b, a) is the expected reward from the POMDP reward function 40 | over the belief state distribution. 41 | Args: 42 | belief_state (BeliefState) 43 | action (str) 44 | 45 | Returns: 46 | reward (float) 47 | ''' 48 | belief = belief_state.distribution 49 | reward = 0. 50 | for state in belief: 51 | reward += belief[state] * self.state_reward_func(state, action) 52 | return reward 53 | 54 | def _get_observation_from_environment(self, action): 55 | ''' 56 | Args: 57 | action (str) 58 | 59 | Returns: 60 | observation (str): retrieve observation from underlying unobserved state in the POMDP 61 | ''' 62 | return self.state_observation_func(self.pomdp.cur_state, action) 63 | 64 | def execute_agent_action(self, action): 65 | reward, next_state = super(BeliefMDP, self).execute_agent_action(action) 66 | self.pomdp.execute_agent_action(action) 67 | 68 | return reward, next_state 69 | 70 | def is_in_goal_state(self): 71 | return self.pomdp.is_in_goal_state() 72 | 73 | if __name__ == '__main__': 74 | from simple_rl.tasks.maze_1d.Maze1DPOMDPClass import Maze1DPOMDP 75 | maze_pomdp = Maze1DPOMDP() 76 | maze_belief_mdp = BeliefMDP(maze_pomdp) 77 | maze_belief_mdp.execute_agent_action('east') 78 | maze_belief_mdp.execute_agent_action('east') -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/tasks/hanoi/HanoiMDPClass.py: -------------------------------------------------------------------------------- 1 | ''' HanoiMDPClass.py: Contains a class for the classical planning/puzzle game Towers of Hanoi. 
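    States are lists of strings, one per peg, read bottom-to-top: 'a' denotes the largest
    disc and later letters are smaller, so the last character of a peg's string is the disc
    currently on top, and a single space denotes an empty peg.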
''' 2 | 3 | # Python imports 4 | import itertools 5 | 6 | # Other imports 7 | from simple_rl.mdp.MDPClass import MDP 8 | from simple_rl.mdp.StateClass import State 9 | 10 | class HanoiMDP(MDP): 11 | ''' Class for a Tower of Hanoi MDP ''' 12 | 13 | ACTIONS = ["01", "02", "10", "12", "20", "21"] 14 | 15 | def __init__(self, num_pegs=3, num_discs=3, gamma=0.95): 16 | ''' 17 | Args: 18 | num_pegs (int) 19 | num_discs (int) 20 | gamma (float) 21 | ''' 22 | self.num_pegs = num_pegs 23 | self.num_discs = num_discs 24 | HanoiMDP.ACTIONS = [str(x) + str(y) for x, y in itertools.product(range(self.num_pegs), range(self.num_pegs)) if x != y] 25 | 26 | # Setup init state. 27 | init_state = [" " for peg in range(num_pegs)] 28 | x = "" 29 | for i in range(num_discs): 30 | x += chr(97 + i) 31 | init_state[0] = x 32 | init_state = State(data=init_state) 33 | 34 | MDP.__init__(self, HanoiMDP.ACTIONS, self._transition_func, self._reward_func, init_state=init_state, gamma=gamma) 35 | 36 | def _reward_func(self, state, action): 37 | ''' 38 | Args: 39 | state (State) 40 | action (str) 41 | 42 | Returns 43 | (float) 44 | ''' 45 | source_index = int(action[0]) 46 | dest_index = int(action[1]) 47 | 48 | return int(self._transition_func(state, action).is_terminal()) 49 | 50 | def _transition_func(self, state, action): 51 | ''' 52 | Args: 53 | state (State) 54 | action (str) 55 | 56 | Returns 57 | (State) 58 | ''' 59 | 60 | # Grab top discs on source and dest pegs. 61 | source_index = int(action[0]) 62 | dest_index = int(action[1]) 63 | source_top = state[source_index][-1] 64 | dest_top = state[dest_index][-1] 65 | 66 | # Make new state. 67 | new_state_ls = state.get_data()[:] 68 | if dest_top < source_top: 69 | new_state_ls[source_index] = new_state_ls[source_index][:-1] 70 | if new_state_ls[source_index] == "": 71 | new_state_ls[source_index] = " " 72 | new_state_ls[dest_index] += source_top 73 | new_state_ls[dest_index] = new_state_ls[dest_index].replace(" ", "") 74 | new_state = State(new_state_ls) 75 | 76 | # Set terminal. 77 | if self._is_goal_state(state): # new_state[1] == "abc" or new_state[2] == "abc": 78 | new_state.set_terminal(True) 79 | 80 | return new_state 81 | 82 | def _is_goal_state(self, state): 83 | ''' 84 | Args: 85 | state (simple_rl.State) 86 | 87 | Returns: 88 | (bool) 89 | ''' 90 | for peg in state[1:]: 91 | if len(peg) == self.num_discs and sorted(peg) == list(peg): 92 | return True 93 | return False 94 | 95 | 96 | def __str__(self): 97 | return "hanoi" -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/abstraction/state_abs/indicator_funcs.py: -------------------------------------------------------------------------------- 1 | # Python imports. 2 | from __future__ import print_function 3 | import random 4 | from decimal import Decimal 5 | 6 | # Other imports. 7 | from simple_rl.tasks import FourRoomMDP 8 | 9 | def _four_rooms(state_x, state_y, vi, actions, epsilon=0.0): 10 | if not isinstance(vi.mdp, FourRoomMDP): 11 | raise ValueError("Abstraction Error: four_rooms SA only available for FourRoomMDP/Color. (" + str(vi.mdp) + "given)." 
) 12 | height, width = vi.mdp.width, vi.mdp.height 13 | 14 | if (state_x.x < width / 2.0) == (state_y.x < width / 2.0) \ 15 | and (state_x.y < height / 2.0) == (state_y.y < height / 2.0): 16 | return True 17 | return False 18 | 19 | def _random(state_x, state_y, vi, actions, epsilon=0.0): 20 | ''' 21 | Args: 22 | state_x (State) 23 | state_y (State) 24 | vi (ValueIteration) 25 | actions (list) 26 | 27 | Returns: 28 | (bool): true randomly. 29 | ''' 30 | cluster_prob = max(100.0 / vi.get_num_states(), 0.5) 31 | return random.random() > 0.3 32 | 33 | def _v_approx_indicator(state_x, state_y, vi, actions, epsilon=0.0): 34 | ''' 35 | Args: 36 | state_x (State) 37 | state_y (State) 38 | vi (ValueIteration) 39 | actions (list) 40 | 41 | Returns: 42 | (bool): true iff: 43 | max |V(state_x) - V(state_y)| <= epsilon 44 | ''' 45 | return abs(vi.get_value(state_x) - vi.get_value(state_y)) <= epsilon 46 | 47 | def _q_eps_approx_indicator(state_x, state_y, vi, actions, epsilon=0.0): 48 | ''' 49 | Args: 50 | state_x (State) 51 | state_y (State) 52 | vi (ValueIteration) 53 | actions (list) 54 | 55 | Returns: 56 | (bool): true iff: 57 | max |Q(state_x,a) - Q(state_y, a)| <= epsilon 58 | ''' 59 | for a in actions: 60 | q_x = vi.get_q_value(state_x, a) 61 | q_y = vi.get_q_value(state_y, a) 62 | 63 | if abs(q_x - q_y) > epsilon: 64 | return False 65 | 66 | return True 67 | 68 | def _q_disc_approx_indicator(state_x, state_y, vi, actions, epsilon=0.0): 69 | ''' 70 | Args: 71 | state_x (State) 72 | state_y (State) 73 | vi (ValueIteration) 74 | actions (list) 75 | 76 | Returns: 77 | (bool): true iff: 78 | ''' 79 | v_max = 1 #/ (1 - 0.95) 80 | 81 | if epsilon == 0.0: 82 | return _q_eps_approx_indicator(state_x, state_y, vi, actions, epsilon=0) 83 | 84 | for a in actions: 85 | 86 | q_x, q_y = vi.get_q_value(state_x, a), vi.get_q_value(state_y, a) 87 | 88 | bucket_x = int( (q_x * (v_max / epsilon))) 89 | bucket_y = int( (q_y * (v_max / epsilon))) 90 | 91 | if bucket_x != bucket_y: 92 | return False 93 | 94 | return True 95 | 96 | def _v_disc_approx_indicator(state_x, state_y, vi, actions, epsilon=0.0): 97 | ''' 98 | Args: 99 | state_x (State) 100 | state_y (State) 101 | vi (ValueIteration) 102 | actions (list) 103 | 104 | Returns: 105 | (bool): true iff: 106 | ''' 107 | v_max = 1 / (1 - 0.95) 108 | 109 | if epsilon == 0.0: 110 | return _v_approx_indicator(state_x, state_y, vi, actions, epsilon=0) 111 | 112 | v_x, v_y = vi.get_value(state_x), vi.get_value(state_y) 113 | 114 | bucket_x = int( (v_x / v_max) / epsilon) 115 | bucket_y = int( (v_y / v_max) / epsilon) 116 | 117 | return bucket_x == bucket_y 118 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/pomdp/BeliefUpdaterClass.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from simple_rl.planning.ValueIterationClass import ValueIteration 3 | 4 | class BeliefUpdater(object): 5 | ''' Wrapper class for different methods for belief state updates in POMDPs. 
''' 6 | 7 | def __init__(self, mdp, transition_func, reward_func, observation_func, updater_type='discrete'): 8 | ''' 9 | Args: 10 | mdp (POMDP) 11 | transition_func: T(s, a) --> s' 12 | reward_func: R(s, a) --> float 13 | observation_func: O(s, a) --> z 14 | updater_type (str) 15 | ''' 16 | self.reward_func = reward_func 17 | self.updater_type = updater_type 18 | 19 | # We use the ValueIteration class to construct the transition and observation probabilities 20 | self.vi = ValueIteration(mdp, sample_rate=500) 21 | 22 | self.transition_probs = self.construct_transition_matrix(transition_func) 23 | self.observation_probs = self.construct_observation_matrix(observation_func, transition_func) 24 | 25 | if updater_type == 'discrete': 26 | self.updater = self.discrete_filter_updater 27 | elif updater_type == 'kalman': 28 | self.updater = self.kalman_filter_updater 29 | elif updater_type == 'particle': 30 | self.updater = self.particle_filter_updater 31 | else: 32 | raise AttributeError('updater_type {} did not conform to expected type'.format(updater_type)) 33 | 34 | def discrete_filter_updater(self, belief, action, observation): 35 | def _compute_normalization_factor(bel): 36 | return sum(bel.values()) 37 | 38 | def _update_belief_for_state(b, sp, T, O, a, z): 39 | return O[sp][z] * sum([T[s][a][sp] * b[s] for s in b]) 40 | 41 | new_belief = defaultdict() 42 | for sprime in belief: 43 | new_belief[sprime] = _update_belief_for_state(belief, sprime, self.transition_probs, self.observation_probs, action, observation) 44 | 45 | normalization = _compute_normalization_factor(new_belief) 46 | 47 | for sprime in belief: 48 | if normalization > 0: new_belief[sprime] /= normalization 49 | 50 | return new_belief 51 | 52 | def kalman_filter_updater(self, belief, action, observation): 53 | pass 54 | 55 | def particle_filter_updater(self, belief, action, observation): 56 | pass 57 | 58 | def construct_transition_matrix(self, transition_func): 59 | ''' 60 | Create an MLE of the transition probabilities by sampling from the transition_func 61 | multiple times. 62 | Args: 63 | transition_func: T(s, a) -> s' 64 | 65 | Returns: 66 | transition_probabilities (defaultdict): T(s, a, s') --> float 67 | ''' 68 | self.vi._compute_matrix_from_trans_func() 69 | return self.vi.trans_dict 70 | 71 | def construct_observation_matrix(self, observation_func, transition_func): 72 | ''' 73 | Create an MLE of the observation probabilities by sampling from the observation_func 74 | multiple times. 75 | Args: 76 | observation_func: O(s) -> z 77 | transition_func: T(s, a) -> s' 78 | 79 | Returns: 80 | observation_probabilities (defaultdict): O(s, z) --> float 81 | ''' 82 | def normalize_probabilities(odict): 83 | norm_factor = sum(odict.values()) 84 | for obs in odict: 85 | odict[obs] /= norm_factor 86 | return odict 87 | 88 | obs_dict = defaultdict(lambda:defaultdict(float)) 89 | for state in self.vi.get_states(): 90 | for action in self.vi.mdp.actions: 91 | for sample in range(self.vi.sample_rate): 92 | observation = observation_func(state, action) 93 | next_state = transition_func(state, action) 94 | obs_dict[next_state][observation] += 1. 
/ self.vi.sample_rate 95 | for state in self.vi.get_states(): 96 | obs_dict[state] = normalize_probabilities(obs_dict[state]) 97 | return obs_dict 98 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/agents/func_approx/tile_coding.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tile Coding Software version 3.0beta 3 | by Rich Sutton 4 | based on a program created by Steph Schaeffer and others 5 | External documentation and recommendations on the use of this code is available in the 6 | reinforcement learning textbook by Sutton and Barto, and on the web. 7 | These need to be understood before this code is. 8 | 9 | This software is for Python 3 or more. 10 | 11 | This is an implementation of grid-style tile codings, based originally on 12 | the UNH CMAC code (see http://www.ece.unh.edu/robots/cmac.htm), but by now highly changed. 13 | Here we provide a function, "tiles", that maps floating and integer 14 | variables to a list of tiles, and a second function "tiles-wrap" that does the same while 15 | wrapping some floats to provided widths (the lower wrap value is always 0). 16 | 17 | The float variables will be gridded at unit intervals, so generalization 18 | will be by approximately 1 in each direction, and any scaling will have 19 | to be done externally before calling tiles. 20 | 21 | Num-tilings should be a power of 2, e.g., 16. To make the offsetting work properly, it should 22 | also be greater than or equal to four times the number of floats. 23 | 24 | The first argument is either an index hash table of a given size (created by (make-iht size)), 25 | an integer "size" (range of the indices from 0), or nil (for testing, indicating that the tile 26 | coordinates are to be returned without being converted to indices). 
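Typical usage (values illustrative): build an index hash table with iht = IHT(4096), then indices = tiles(iht, 8, [3.6, 7.21]) returns 8 tile indices for the 2-D point (3.6, 7.21); calling it again with the same point and the same iht returns the same indices.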
27 | """ 28 | 29 | basehash = hash 30 | 31 | class IHT(object): 32 | "Structure to handle collisions" 33 | def __init__(self, sizeval): 34 | self.size = sizeval 35 | self.overfullCount = 0 36 | self.dictionary = {} 37 | 38 | def __str__(self): 39 | "Prepares a string for printing whenever this object is printed" 40 | return "Collision table:" + \ 41 | " size:" + str(self.size) + \ 42 | " overfullCount:" + str(self.overfullCount) + \ 43 | " dictionary:" + str(len(self.dictionary)) + " items" 44 | 45 | def count (self): 46 | return len(self.dictionary) 47 | 48 | def fullp (self): 49 | return len(self.dictionary) >= self.size 50 | 51 | def getindex (self, obj, readonly=False): 52 | d = self.dictionary 53 | if obj in d: return d[obj] 54 | elif readonly: return None 55 | size = self.size 56 | count = self.count() 57 | if count >= size: 58 | if self.overfullCount==0: print('IHT full, starting to allow collisions') 59 | self.overfullCount += 1 60 | return basehash(obj) % self.size 61 | else: 62 | d[obj] = count 63 | return count 64 | 65 | def hashcoords(coordinates, m, readonly=False): 66 | if type(m)==IHT: return m.getindex(tuple(coordinates), readonly) 67 | if type(m)==int: return basehash(tuple(coordinates)) % m 68 | if m==None: return coordinates 69 | 70 | from math import floor, log 71 | from itertools import zip_longest 72 | 73 | def tiles (ihtORsize, numtilings, floats, ints=[], readonly=False): 74 | """returns num-tilings tile indices corresponding to the floats and ints""" 75 | qfloats = [floor(f*numtilings) for f in floats] 76 | Tiles = [] 77 | for tiling in range(numtilings): 78 | tilingX2 = tiling*2 79 | coords = [tiling] 80 | b = tiling 81 | for q in qfloats: 82 | coords.append( (q + b) // numtilings ) 83 | b += tilingX2 84 | coords.extend(ints) 85 | Tiles.append(hashcoords(coords, ihtORsize, readonly)) 86 | return Tiles 87 | 88 | def tileswrap (ihtORsize, numtilings, floats, wrawidths, ints=[], readonly=False): 89 | """returns num-tilings tile indices corresponding to the floats and ints, wrapping some floats""" 90 | qfloats = [floor(f*numtilings) for f in floats] 91 | Tiles = [] 92 | for tiling in range(numtilings): 93 | tilingX2 = tiling*2 94 | coords = [tiling] 95 | b = tiling 96 | for q, width in zip_longest(qfloats, wrapwidths): 97 | c = (q + b%numtilings) // numtilings 98 | coords.append(c%width if width else c) 99 | b += tilingX2 100 | coords.extend(ints) 101 | Tiles.append(hashcoords(coords, ihtORsize, readonly)) 102 | return Tiles -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/tasks/dev_rock_sample/RockSampleMDPClass.py: -------------------------------------------------------------------------------- 1 | ''' RockSampleMDPClass.py: Contains the RockSample class. ''' 2 | 3 | # Python imports. 4 | import random 5 | import math 6 | import copy 7 | 8 | # Other imports 9 | from simple_rl.mdp.MDPClass import MDP 10 | from simple_rl.tasks.grid_world.GridWorldMDPClass import GridWorldMDP 11 | from simple_rl.mdp.StateClass import State 12 | 13 | class RockSampleMDP(GridWorldMDP): 14 | ''' 15 | Class an MDP adaption of the RockSample POMDP from: 16 | 17 | Trey Smith and Reid Simmons: "Heuristic Search Value Iteration for POMDPs" UAI 2004. 
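State data is a flat list [agent_x, agent_y, rock1_x, rock1_y, rock1_good, rock2_x, ...]; the "sample" action is only rewarded when the agent stands on a rock, and moving right into the exit area past column 7 yields +10 and ends the episode.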
18 | ''' 19 | 20 | ACTIONS = ["up", "down", "left", "right", "sample"] 21 | 22 | def __init__(self, width=8, height=7, init_loc=(1,1), rocks=None, gamma=0.99, slip_prob=0.00, rock_rewards=[0.1, 1, 20], name="rocksample"): 23 | ''' 24 | Args: 25 | height (int) 26 | width (int) 27 | init_loc (tuple: (int, int)) 28 | goal_locs (list of tuples: [(int, int)...]) 29 | ''' 30 | if rocks is None: 31 | rocks = [[1,2,True], [5,4,True], [6,7,True], [1,3,True], [4,5,True], [2,7,False], [2,2,True], [7,4,False]] 32 | self.init_loc = init_loc 33 | self.init_rocks = rocks 34 | self.rock_rewards = rock_rewards 35 | self.name = name + "-" + str(len(rocks)) 36 | self.width = width 37 | self.height = height 38 | MDP.__init__(self, RockSampleMDP.ACTIONS, self._transition_func, self._reward_func, init_state=self.get_init_state(), gamma=gamma) 39 | 40 | def get_init_state(self): 41 | features = [self.init_loc[0], self.init_loc[1]] 42 | for rock in self.init_rocks: 43 | int_rock = [int(f) for f in rock] 44 | features += list(int_rock) 45 | 46 | return State(data=features) 47 | 48 | def _reward_func(self, state, action): 49 | if state[0] == 7 and action == "right": 50 | # Moved into exit area, receive 10 reward. 51 | return 10.0 52 | elif action == "sample": 53 | rock_index = self._get_rock_at_agent_loc(state) 54 | if rock_index != None: 55 | if state.data[rock_index + 2]: 56 | # Sampled good rock. 57 | return self.rock_rewards[rock_index % 3] 58 | else: 59 | # Sampled bad rock. 60 | return -self.rock_rewards[rock_index % 3] 61 | 62 | return 0 63 | 64 | def _transition_func(self, state, action): 65 | ''' 66 | Args: 67 | state (State) 68 | action (str) 69 | 70 | Returns 71 | (State) 72 | ''' 73 | if state.is_terminal(): 74 | return state 75 | 76 | if action == "sample": 77 | # Sample action. 78 | rock_index = self._get_rock_at_agent_loc(state) 79 | if rock_index != None: 80 | # Set to false. 81 | new_data = state.data[:] 82 | new_data[rock_index] = False 83 | next_state = State(data=new_data) 84 | else: 85 | next_state = State(data=state.data) 86 | 87 | elif action == "up" and state.data[1] < self.height: 88 | next_state = State(data=[state.data[0], state.data[1] + 1] + state.data[2:]) 89 | elif action == "down" and state.data[1] > 1: 90 | next_state = State(data=[state.data[0], state.data[1] - 1] + state.data[2:]) 91 | elif action == "right" and state.data[0] < self.width: 92 | next_state = State(data=[state.data[0] + 1, state.data[1]] + state.data[2:]) 93 | elif action == "left" and state.data[0] > 1: 94 | next_state = State(data=[state.data[0] - 1, state.data[1]] + state.data[2:]) 95 | else: 96 | next_state = State(data=state.data) 97 | 98 | if next_state[0] > 7: 99 | next_state.set_terminal(True) 100 | 101 | return next_state 102 | 103 | def _get_rock_at_agent_loc(self, state): 104 | result = None 105 | for i in range(2, len(state.data), 3): 106 | if state.data[i] == state[0] and state.data[i + 1] == state[1]: 107 | return i 108 | 109 | # No rock found. 110 | return None 111 | 112 | def __str__(self): 113 | return self.name 114 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/abstraction/state_abs/StateAbstractionClass.py: -------------------------------------------------------------------------------- 1 | # Python imports. 2 | from collections import defaultdict 3 | 4 | # Other imports. 
5 | from simple_rl.mdp.StateClass import State 6 | from simple_rl.mdp.MDPClass import MDP 7 | 8 | class StateAbstraction(object): 9 | 10 | def __init__(self, phi=None, ground_state_space=[]): 11 | ''' 12 | Args: 13 | phi (dict) 14 | ''' 15 | # key:state, val:int. (int represents an abstract state). 16 | self._phi = phi if phi is not None else {s_g: s_g for s_g in ground_state_space} 17 | 18 | def set_phi(self, new_phi): 19 | self._phi = new_phi 20 | 21 | def phi(self, state): 22 | ''' 23 | Args: 24 | state (State) 25 | 26 | Returns: 27 | state (State) 28 | ''' 29 | 30 | # Check types. 31 | if state not in self._phi.keys(): 32 | raise KeyError 33 | 34 | if not isinstance(self._phi[state], State): 35 | raise TypeError 36 | 37 | # Get abstract state. 38 | abstr_state = self._phi[state] 39 | abstr_state.set_terminal(state.is_terminal()) 40 | 41 | return abstr_state 42 | 43 | def make_cluster(self, list_of_ground_states): 44 | if len(list_of_ground_states) == 0: 45 | return 46 | 47 | abstract_value = 0 48 | if len(self._phi.values()) != 0: 49 | abstract_value = max(self._phi.values()) + 1 50 | 51 | for state in list_of_ground_states: 52 | self._phi[state] = abstract_value 53 | 54 | def get_ground_states_in_abs_state(self, abs_state): 55 | ''' 56 | Args: 57 | abs_state (State) 58 | 59 | Returns: 60 | (list): Contains all ground states in the cluster. 61 | ''' 62 | return [s_g for s_g in self.get_ground_states() if self.phi(s_g) == abs_state] 63 | 64 | def get_lower_states_in_abs_state(self, abs_state): 65 | ''' 66 | Args: 67 | abs_state (State) 68 | 69 | Returns: 70 | (list): Contains all ground states in the cluster. 71 | 72 | Notes: 73 | Here to simplify the state abstraction stack subclass. 74 | ''' 75 | return self.get_ground_states_in_abs_state(abs_state) 76 | 77 | def get_abs_states(self): 78 | # For each ground state, get its abstract state. 79 | return set([self.phi(val) for val in set(self._phi.keys())]) 80 | 81 | def get_abs_cluster_num(self, abs_state): 82 | return list(set(self._phi.values())).index(abs_state.data) 83 | 84 | def get_ground_states(self): 85 | return self._phi.keys() 86 | 87 | def get_lower_states(self): 88 | return self.get_ground_states() 89 | 90 | def get_num_abstr_states(self): 91 | return len(set(self._phi.values())) 92 | 93 | def get_num_ground_states(self): 94 | return len(set(self._phi.keys())) 95 | 96 | def reset(self): 97 | self._phi = {} 98 | 99 | def __add__(self, other_abs): 100 | ''' 101 | Args: 102 | other_abs 103 | ''' 104 | merged_state_abs = {} 105 | 106 | # Move the phi into a cluster dictionary. 107 | cluster_dict = defaultdict(list) 108 | for k, v in self._phi.iteritems(): 109 | # Cluster dict: v is abstract, key is ground. 110 | cluster_dict[v].append(k) 111 | 112 | # Move the phi into a cluster dictionary. 113 | other_cluster_dict = defaultdict(list) 114 | for k, v in other_abs._phi.iteritems(): 115 | other_cluster_dict[v].append(k) 116 | 117 | 118 | for ground_state in self._phi.keys(): 119 | 120 | 121 | # Get the two clusters associated with a state. 122 | states_cluster = self._phi[ground_state] 123 | if ground_state in other_abs._phi.keys(): 124 | # Only add if it's in both clusters. 125 | states_other_cluster = other_abs._phi[ground_state] 126 | else: 127 | continue 128 | 129 | for s_g in cluster_dict[states_cluster]: 130 | if s_g in other_cluster_dict[states_other_cluster]: 131 | # Every ground state that's in both clusters, merge. 
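# (Roughly speaking: a ground state keeps a shared cluster label in the merged abstraction only where the two abstractions agree.)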
132 | merged_state_abs[s_g] = states_cluster 133 | 134 | new_sa = StateAbstraction(phi=merged_state_abs) 135 | 136 | return new_sa 137 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/planning/MCTSClass.py: -------------------------------------------------------------------------------- 1 | ''' MCTSClass.py: Class for a basic Monte Carlo Tree Search Planner. ''' 2 | 3 | # Python imports. 4 | import math as m 5 | import random 6 | from collections import defaultdict 7 | 8 | # Other imports. 9 | from simple_rl.planning.PlannerClass import Planner 10 | 11 | class MCTS(Planner): 12 | 13 | def __init__(self, mdp, name="mcts", explore_param=m.sqrt(2), rollout_depth=20, num_rollouts_per_step=10): 14 | Planner.__init__(self, mdp, name=name) 15 | 16 | self.rollout_depth = rollout_depth 17 | self.num_rollouts_per_step = num_rollouts_per_step 18 | self.value_total = defaultdict(lambda : defaultdict(float)) 19 | self.explore_param = explore_param 20 | self.visitation_counts = defaultdict(lambda : defaultdict(lambda : 0)) 21 | 22 | def plan(self, cur_state, horizon=20): 23 | ''' 24 | Args: 25 | cur_state (State) 26 | horizon (int) 27 | 28 | Returns: 29 | (list): List of actions 30 | ''' 31 | action_seq = [] 32 | state_seq = [cur_state] 33 | steps = 0 34 | while not cur_state.is_terminal() and steps < horizon: 35 | action = self._next_action(cur_state) 36 | # Do the rollouts... 37 | cur_state = self.transition_func(cur_state, action) 38 | action_seq.append(action) 39 | state_seq.append(cur_state) 40 | steps += 1 41 | 42 | self.has_planned = True 43 | 44 | return action_seq, state_seq 45 | 46 | def policy(self, state): 47 | ''' 48 | Args: 49 | state (State) 50 | 51 | Returns: 52 | (str) 53 | ''' 54 | if not self.has_planned: 55 | self.plan(state) 56 | 57 | return self._next_action(state) 58 | 59 | def _next_action(self, state): 60 | ''' 61 | Args; 62 | state (State) 63 | 64 | Returns: 65 | (str) 66 | 67 | Summary: 68 | Performs a single step of the MCTS rollout. 69 | ''' 70 | best_action = self.actions[0] 71 | best_score = 0 72 | total_visits = [self.visitation_counts[state][a] for a in self.actions] 73 | 74 | print(total_visits) 75 | 76 | if 0 in total_visits: 77 | # Insufficient stats, return random. 78 | # Should choose randomly AMONG UNSAMPLED. 79 | unsampled_actions = [self.actions[i] for i, x in enumerate(total_visits) if x == 0] 80 | next_action = random.choice(unsampled_actions) 81 | self.visitation_counts[state][next_action] += 1 82 | return next_action 83 | 84 | total = sum(total_visits) 85 | 86 | # Else choose according to the UCT explore method. 87 | for cur_action in self.actions: 88 | s_a_value_tot = self.value_total[state][cur_action] 89 | s_a_visit = self.visitation_counts[state][cur_action] 90 | score = s_a_value_tot / s_a_visit + self.explore_param * m.sqrt(m.log(total) / s_a_visit) 91 | 92 | if score > best_score: 93 | best_action = cur_action 94 | best_score = score 95 | 96 | return best_action 97 | 98 | def _rollout(self, cur_state, action): 99 | ''' 100 | Args: 101 | cur_state (State) 102 | action (str) 103 | 104 | Returns: 105 | (float): Discounted reward from the rollout. 106 | ''' 107 | trajectory = [] 108 | total_discounted_reward = [] 109 | for i in range(self.rollout_depth): 110 | 111 | # Simulate next state. 112 | next_action = self._next_action(cur_state) 113 | cur_state = self.transition_func(cur_state, next_action) 114 | next_reward = self.reward_func(cur_state, next_action) 115 | 116 | # Track rewards and states. 
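# Reward collected at depth i is discounted by gamma**i before being credited back up the trajectory below.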
117 | total_discounted_reward.append(self.gamma**i * next_reward) 118 | trajectory.append((cur_state, next_action)) 119 | 120 | if cur_state.is_terminal(): 121 | # Break terminal. 122 | break 123 | 124 | # Update all visited nodes. 125 | for i, experience in enumerate(trajectory): 126 | s, a = experience 127 | self.visitation_counts[s][a] += 1 128 | self.value_total[s][a] += sum(total_discounted_reward[i:]) 129 | 130 | return total_discounted_reward 131 | 132 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/abstraction/action_abs/ActionAbstractionClass.py: -------------------------------------------------------------------------------- 1 | # Python imports. 2 | from __future__ import print_function 3 | from collections import defaultdict 4 | import random 5 | 6 | # Other imports. 7 | from simple_rl.abstraction.action_abs.OptionClass import Option 8 | from simple_rl.abstraction.action_abs.PredicateClass import Predicate 9 | 10 | class ActionAbstraction(object): 11 | 12 | def __init__(self, options=None, prim_actions=[], term_prob=0.0, prims_on_failure=False): 13 | self.options = options if options is not None else self._convert_to_options(prim_actions) 14 | self.is_cur_executing = False 15 | self.cur_option = None # The option we're executing currently. 16 | self.term_prob = term_prob 17 | self.prims_on_failure = prims_on_failure 18 | if self.prims_on_failure: 19 | self.prim_actions = prim_actions 20 | 21 | def act(self, agent, abstr_state, ground_state, reward): 22 | ''' 23 | Args: 24 | agent (Agent) 25 | abstr_state (State) 26 | ground_state (State) 27 | reward (float) 28 | 29 | Returns: 30 | (str) 31 | ''' 32 | 33 | if self.is_next_step_continuing_option(ground_state) and random.random() > self.term_prob: 34 | # We're in an option and not terminating. 35 | return self.get_next_ground_action(ground_state) 36 | else: 37 | # We're not in an option, check with agent. 38 | active_options = self.get_active_options(ground_state) 39 | 40 | if len(active_options) == 0: 41 | if self.prims_on_failure: 42 | # In a failure state, back off to primitives. 43 | agent.actions = self._convert_to_options(self.prim_actions) 44 | else: 45 | # No actions available. 46 | raise ValueError("(simple_rl) Error: no actions available in state " + str(ground_state) + ".") 47 | else: 48 | # Give agent available options. 49 | agent.actions = active_options 50 | 51 | abstr_action = agent.act(abstr_state, reward) 52 | self.set_option_executing(abstr_action) 53 | 54 | return self.abs_to_ground(ground_state, abstr_action) 55 | 56 | def get_active_options(self, state): 57 | ''' 58 | Args: 59 | state (State) 60 | 61 | Returns: 62 | (list): Contains all active options. 63 | ''' 64 | 65 | return [o for o in self.options if o.is_init_true(state)] 66 | 67 | def _convert_to_options(self, action_list): 68 | ''' 69 | Args: 70 | action_list (list) 71 | 72 | Returns: 73 | (list of Option) 74 | ''' 75 | options = [] 76 | for ground_action in action_list: 77 | o = ground_action 78 | if type(ground_action) is str: 79 | o = Option(init_predicate=Predicate(make_lambda(True)), 80 | term_predicate=Predicate(make_lambda(True)), 81 | policy=make_lambda(ground_action), 82 | name="prim." + ground_action) 83 | else: 84 | print(type(ground_action)) 85 | options.append(o) 86 | return options 87 | 88 | def is_next_step_continuing_option(self, ground_state): 89 | ''' 90 | Returns: 91 | (bool): True iff an option was executing and should continue next step. 
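(The additional random termination with probability term_prob is handled separately in act().)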
92 | ''' 93 | return self.is_cur_executing and not self.cur_option.is_term_true(ground_state) 94 | 95 | def set_option_executing(self, option): 96 | if option not in self.options and "prim" not in option.name: 97 | raise ValueError("(simple_rl) Error: agent chose a non-existent option (" + str(option) + ").") 98 | 99 | self.cur_option = option 100 | self.is_cur_executing = True 101 | 102 | def get_next_ground_action(self, ground_state): 103 | return self.cur_option.act(ground_state) 104 | 105 | def get_actions(self): 106 | return self.options 107 | 108 | def abs_to_ground(self, ground_state, abstr_action): 109 | return abstr_action.act(ground_state) 110 | 111 | def add_option(self, option): 112 | self.options += [option] 113 | 114 | def reset(self): 115 | self.is_cur_executing = False 116 | self.cur_option = None # The option we're executing currently. 117 | 118 | def end_of_episode(self): 119 | self.reset() 120 | 121 | 122 | def make_lambda(result): 123 | return lambda x : result 124 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/tasks/taxi/taxi_visualizer.py: -------------------------------------------------------------------------------- 1 | # Python imports. 2 | from __future__ import print_function 3 | try: 4 | import pygame 5 | except ImportError: 6 | print("Warning: pygame not installed (needed for visuals).") 7 | 8 | # Other imports. 9 | from simple_rl.utils.chart_utils import color_ls 10 | 11 | def _draw_state(screen, 12 | taxi_oomdp, 13 | state, 14 | draw_statics=False, 15 | agent_shape=None): 16 | ''' 17 | Args: 18 | screen (pygame.Surface) 19 | taxi_oomdp (TaxiOOMDP) 20 | state (State) 21 | agent_shape (pygame.rect) 22 | 23 | Returns: 24 | (pygame.Shape) 25 | ''' 26 | # Prep some dimensions to make drawing easier. 27 | scr_width, scr_height = screen.get_width(), screen.get_height() 28 | width_buffer = scr_width / 10.0 29 | height_buffer = 30 + (scr_height / 10.0) # Add 30 for title. 30 | cell_width = (scr_width - width_buffer * 2) / taxi_oomdp.width 31 | cell_height = (scr_height - height_buffer * 2) / taxi_oomdp.height 32 | objects = state.get_objects() 33 | agent_x, agent_y = objects["agent"][0]["x"], objects["agent"][0]["y"] 34 | 35 | if agent_shape is not None: 36 | # Clear the old shape. 37 | pygame.draw.rect(screen, (255,255,255), agent_shape) 38 | top_left_point = width_buffer + cell_width*(agent_x - 1), height_buffer + cell_height*(taxi_oomdp.height - agent_y) 39 | agent_center = int(top_left_point[0] + cell_width/2.0), int(top_left_point[1] + cell_height/2.0) 40 | 41 | # Draw new. 42 | agent_shape = _draw_agent(agent_center, screen, base_size=min(cell_width, cell_height)/2.5 - 4) 43 | else: 44 | top_left_point = width_buffer + cell_width*(agent_x - 1), height_buffer + cell_height*(taxi_oomdp.height - agent_y) 45 | agent_center = int(top_left_point[0] + cell_width/2.0), int(top_left_point[1] + cell_height/2.0) 46 | agent_shape = _draw_agent(agent_center, screen, base_size=min(cell_width, cell_height)/2.5 - 4) 47 | 48 | # Do passengers first so the agent wipe out will wipe passengers, too. 
49 | for i, p in enumerate(objects["passenger"]): 50 | # Passenger 51 | pass_x, pass_y = p["x"], p["y"] 52 | taxi_size = int(min(cell_width, cell_height) / 8.5) if p["in_taxi"] else int(min(cell_width, cell_height) / 5.0) 53 | top_left_point = int(width_buffer + cell_width*(pass_x - 1) + taxi_size + 38) , int(height_buffer + cell_height*(taxi_oomdp.height - pass_y) + taxi_size + 35) 54 | dest_col = (max(color_ls[-i-1][0]-30, 0), max(color_ls[-i-1][1]-30, 0), max(color_ls[-i-1][2]-30, 0)) 55 | pygame.draw.circle(screen, dest_col, top_left_point, taxi_size) 56 | 57 | # Statics 58 | if draw_statics: 59 | # For each row: 60 | for i in range(taxi_oomdp.width): 61 | # For each column: 62 | for j in range(taxi_oomdp.height): 63 | top_left_point = width_buffer + cell_width*i, height_buffer + cell_height*j 64 | r = pygame.draw.rect(screen, (46, 49, 49), top_left_point + (cell_width, cell_height), 3) 65 | 66 | # Draw walls. 67 | for w in objects["wall"]: 68 | # Passenger 69 | w_x, w_y = w["x"], w["y"] 70 | top_left_point = width_buffer + cell_width*(w_x -1) + 5, height_buffer + cell_height*(taxi_oomdp.height - w_y) + 5 71 | pygame.draw.rect(screen, (46, 49, 49), top_left_point + (cell_width - 10, cell_height - 10), 0) 72 | 73 | for i, p in enumerate(objects["passenger"]): 74 | # Dest. 75 | dest_x, dest_y = p["dest_x"], p["dest_y"] 76 | top_left_point = int(width_buffer + cell_width*(dest_x - 1) + 25), int(height_buffer + cell_height*(taxi_oomdp.height - dest_y) + 25) 77 | dest_col = (int(max(color_ls[-i-1][0]-30, 0)), int(max(color_ls[-i-1][1]-30, 0)), int(max(color_ls[-i-1][2]-30, 0))) 78 | pygame.draw.rect(screen, dest_col, top_left_point + (cell_width / 6, cell_height / 6), 0) 79 | 80 | pygame.display.flip() 81 | 82 | return agent_shape 83 | 84 | def _draw_agent(center_point, screen, base_size=30): 85 | ''' 86 | Args: 87 | center_point (tuple): (x,y) 88 | screen (pygame.Surface) 89 | 90 | Returns: 91 | (pygame.rect) 92 | ''' 93 | tri_bot_left = center_point[0] - base_size, center_point[1] + base_size 94 | tri_bot_right = center_point[0] + base_size, center_point[1] + base_size 95 | tri_top = center_point[0], center_point[1] - base_size 96 | tri = [tri_bot_left, tri_top, tri_bot_right] 97 | tri_color = (98, 140, 190) 98 | 99 | return pygame.draw.polygon(screen, tri_color, tri) 100 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/agents/func_approx/LinearQAgentClass.py: -------------------------------------------------------------------------------- 1 | ''' 2 | LinearQLearningAgentClass.py 3 | 4 | Contains implementation for a Q Learner with a Linear Function Approximator. 5 | ''' 6 | 7 | # Python imports. 8 | import numpy as np 9 | import math 10 | 11 | # Other imports. 12 | from simple_rl.agents import Agent, QLearningAgent 13 | 14 | class LinearQAgent(QLearningAgent): 15 | ''' 16 | QLearningAgent with a linear function approximator for the Q Function. 17 | ''' 18 | 19 | def __init__(self, actions, num_features, rand_init=True, name="Linear-Q", alpha=0.2, gamma=0.99, epsilon=0.2, explore="uniform", rbf=False, anneal=True): 20 | name = name + "-rbf" if rbf else name 21 | QLearningAgent.__init__(self, actions=list(actions), name=name, alpha=alpha, gamma=gamma, epsilon=epsilon, explore=explore, anneal=anneal) 22 | self.num_features = num_features 23 | # Add a basis feature. 
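# Weights are stored as a single flat vector of length num_features * |actions|; _phi activates only the block belonging to the chosen action.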
24 | if rand_init: 25 | self.weights = np.random.random(self.num_features*len(self.actions)) 26 | else: 27 | self.weights = np.zeros(self.num_features*len(self.actions)) 28 | 29 | self.rbf = rbf 30 | 31 | def update(self, state, action, reward, next_state): 32 | ''' 33 | Args: 34 | state (State) 35 | action (str) 36 | reward (float) 37 | next_state (State) 38 | 39 | Summary: 40 | Updates the internal Q Function according to the Bellman Equation. (Classic Q Learning update) 41 | ''' 42 | if state is None: 43 | # If this is the first state, initialize state-relevant data and return. 44 | self.prev_state = state 45 | return 46 | self._update_weights(reward, next_state) 47 | 48 | def _phi(self, state, action): 49 | ''' 50 | Args: 51 | state (State): The abstract state object. 52 | action (str): A string representing an action. 53 | 54 | Returns: 55 | (numpy array): A state-action feature vector representing the current State and action. 56 | 57 | Notes: 58 | The resulting feature vector multiplies the state vector by |A| (size of action space), and only the action passed in retains 59 | the original vector, all other values are set to 0. 60 | ''' 61 | result = np.zeros(self.num_features * len(self.actions)) 62 | act_index = self.actions.index(action) 63 | 64 | basis_feats = state.features() 65 | 66 | if self.rbf: 67 | basis_feats = [_rbf(f) for f in basis_feats] 68 | 69 | result[act_index*self.num_features:(act_index + 1)*self.num_features] = basis_feats 70 | 71 | return result 72 | 73 | def _update_weights(self, reward, cur_state): 74 | ''' 75 | Args: 76 | reward (float) 77 | cur_state (State) 78 | 79 | Summary: 80 | Updates according to: 81 | 82 | [Eq. 1] delta = r + gamma * max_b(Q(s_curr,b)) - Q(s_prev, a_prev) 83 | 84 | For each weight: 85 | w_i = w_i + alpha * phi(s,a)[i] * delta 86 | 87 | Where phi(s,a) maps the state action pair to a feature vector (see QLearningAgent._phi(s,a)) 88 | ''' 89 | 90 | # Compute temporal difference [Eq. 1] 91 | max_q_cur_state = self.get_max_q_value(cur_state) 92 | prev_q_val = self.get_q_value(self.prev_state, self.prev_action) 93 | self.most_recent_loss = reward + self.gamma * max_q_cur_state - prev_q_val 94 | 95 | # Update each weight 96 | phi = self._phi(self.prev_state, self.prev_action) 97 | active_feats_index = self.actions.index(self.prev_action) * self.num_features 98 | 99 | # Sparsely update the weights (only update weights associated with the action we used). 100 | for i in range(active_feats_index, active_feats_index + self.num_features): 101 | self.weights[i] = self.weights[i] + self.alpha * phi[i] * self.most_recent_loss 102 | 103 | def get_q_value(self, state, action): 104 | ''' 105 | Args: 106 | state (State): A State object containing the abstract state representation 107 | action (str): A string representing an action. See namespaceAIX. 108 | 109 | Returns: 110 | (float): denoting the q value of the (@state,@action) pair. 111 | ''' 112 | 113 | # Return linear approximation of Q value 114 | sa_feats = self._phi(state, action) 115 | 116 | return np.dot(self.weights, sa_feats) 117 | 118 | def reset(self): 119 | self.weights = np.zeros(self.num_features*len(self.actions)) 120 | QLearningAgent.reset(self) 121 | 122 | 123 | def _rbf(x): 124 | return math.exp(-(x)**2) 125 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/agents/bandits/LinUCBAgentClass.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Basic LinUCB implementation. 
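Each arm keeps a design matrix A_a (initialized to the identity) and a vector b_a; its parameters are theta_a = A_a^{-1} b_a, and the agent pulls the arm maximizing theta_a^T x + alpha * sqrt(x^T A_a^{-1} x) for context x.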
3 | ''' 4 | 5 | # Python imports. 6 | import numpy as np 7 | 8 | # Other imports. 9 | from simple_rl.agents.AgentClass import Agent 10 | 11 | class LinUCBAgent(Agent): 12 | ''' 13 | From: 14 | Lihong Li, et al. "A Contextual-Bandit Approach to Personalized 15 | News Article Recommendation." In Proceedings of the 19th 16 | International Conference on World Wide Web (WWW), 2010. 17 | ''' 18 | 19 | def __init__(self, actions, name="LinUCB", rand_init=True, context_size=1, alpha=1.5): 20 | ''' 21 | Args: 22 | actions (list): Contains a string for each action. 23 | name (str) 24 | context_size (int) 25 | alpha (float): Uncertainty parameter. 26 | ''' 27 | Agent.__init__(self, name, actions) 28 | self.alpha = alpha 29 | self.context_size = context_size 30 | self.prev_context = None 31 | self.step_number = 0 32 | self._init_action_model(rand_init) 33 | 34 | def _init_action_model(self, rand_init=True): 35 | ''' 36 | Summary: 37 | Initializes model parameters 38 | ''' 39 | self.model = {'act': {}, 'act_inv': {}, 'theta': {}, 'b': {}} 40 | for action_id in range(len(self.actions)): 41 | self.model['act'][action_id] = np.identity(self.context_size) 42 | self.model['act_inv'][action_id] = np.identity(self.context_size) 43 | if rand_init: 44 | self.model['theta'][action_id] = np.random.random((self.context_size, 1)) 45 | else: 46 | self.model['theta'][action_id] = np.zeros((self.context_size, 1)) 47 | self.model['b'][action_id] = np.zeros((self.context_size,1)) 48 | 49 | def _compute_score(self, context): 50 | ''' 51 | Args: 52 | context (list) 53 | 54 | Returns: 55 | (dict): 56 | K (str): action 57 | V (float): score 58 | ''' 59 | 60 | a_inv = self.model['act_inv'] 61 | theta = self.model['theta'] 62 | 63 | estimated_reward = {} 64 | uncertainty = {} 65 | score_dict = {} 66 | max_score = 0 67 | for action_id in range(len(self.actions)): 68 | action_context = np.reshape(context[action_id], (-1, 1)) 69 | estimated_reward[action_id] = float(theta[action_id].T.dot(action_context)) 70 | uncertainty[action_id] = float(self.alpha * np.sqrt(action_context.T.dot(a_inv[action_id]).dot(action_context))) 71 | score_dict[action_id] = estimated_reward[action_id] + uncertainty[action_id] 72 | 73 | return score_dict 74 | 75 | def update(self, reward): 76 | ''' 77 | Args: 78 | reward (float) 79 | 80 | Summary: 81 | Updates self.model according to self.prev_context, self.prev_action, @reward. 82 | ''' 83 | action_id = self.actions.index(self.prev_action) 84 | action_context = np.reshape(self.prev_context[action_id], (-1, 1)) 85 | self.model['act'][action_id] += action_context.dot(action_context.T) 86 | self.model['act_inv'][action_id] = np.linalg.inv(self.model['act'][action_id]) 87 | self.model['b'][action_id] += reward * action_context 88 | self.model['theta'][action_id] = self.model['act_inv'][action_id].dot(self.model['b'][action_id]) 89 | 90 | def act(self, context, reward): 91 | ''' 92 | Args: 93 | context (iterable) 94 | reward (float) 95 | 96 | Returns: 97 | (str): action. 98 | ''' 99 | 100 | # Update previous context-action pair. 101 | if self.prev_action is not None: 102 | self.update(reward) 103 | 104 | # Compute score. 105 | context = self._pre_process_context(context) 106 | score = self._compute_score(context) 107 | 108 | # Compute best action. 
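# Greedy arm selection: the random initial choice is only a fallback; the loop below replaces it with the arm whose UCB score is highest.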
109 | best_action = np.random.choice(self.actions) 110 | max_score = float("-inf") 111 | for action_id in range(len(self.actions)): 112 | if score[action_id] > max_score: 113 | max_score = score[action_id] 114 | best_action = self.actions[action_id] 115 | 116 | 117 | # Update prev pointers. 118 | self.prev_action = best_action 119 | self.prev_context = context 120 | self.step_number += 1 121 | 122 | return best_action 123 | 124 | def _pre_process_context(self, context): 125 | if context.get_num_feats() == 1: 126 | # If there's no context (that is, we're just in a regular bandit). 127 | context = context.features() 128 | 129 | if not hasattr(context[0], '__iter__'): 130 | # If we only have a single context. 131 | new_context = {} 132 | for action_id in range(len(self.actions)): 133 | new_context[action_id] = context 134 | context = new_context 135 | 136 | return context 137 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/tasks/puddle/PuddleMDPClass.py: -------------------------------------------------------------------------------- 1 | ''' 2 | PuddleMDPClass.py: Contains the Puddle class from: 3 | 4 | Boyan, Justin A., and Andrew W. Moore. "Generalization in reinforcement learning: 5 | Safely approximating the value function." NIPS 1995. 6 | ''' 7 | 8 | # Python imports. 9 | import math 10 | import numpy as np 11 | 12 | # Other imports 13 | from simple_rl.mdp.MDPClass import MDP 14 | from simple_rl.tasks.grid_world.GridWorldMDPClass import GridWorldMDP 15 | from simple_rl.tasks.grid_world.GridWorldStateClass import GridWorldState 16 | 17 | class PuddleMDP(GridWorldMDP): 18 | ''' Class for a Puddle MDP ''' 19 | 20 | def __init__(self, gamma=0.99, slip_prob=0.00, name="puddle", puddle_rects=[(0.1, 0.8, 0.5, 0.7), (0.4, 0.7, 0.5, 0.4)], is_goal_terminal=True, rand_init=False): 21 | ''' 22 | Args: 23 | gamma (float) 24 | slip_prob (float) 25 | name (str) 26 | puddle_rects (list): [(top_left_x, top_left_y), (bot_right_x, bot_right_y)] 27 | is_goal_terminal (bool) 28 | rand_init (bool) 29 | ''' 30 | self.delta = 0.05 31 | self.puddle_rects = puddle_rects 32 | GridWorldMDP.__init__(self, width=1.0, height=1.0, init_loc=[0.25, 0.6], goal_locs=[[1.0, 1.0]], gamma=gamma, name=name, is_goal_terminal=is_goal_terminal, rand_init=rand_init) 33 | 34 | def _reward_func(self, state, action): 35 | if self._is_goal_state_action(state, action): 36 | return 1.0 - self.step_cost 37 | elif self._is_puddle_state_action(state, action): 38 | return -1.0 39 | else: 40 | return 0 - self.step_cost 41 | 42 | def _is_puddle_state_action(self, state, action): 43 | ''' 44 | Args: 45 | state (simple_rl.State) 46 | action (str) 47 | 48 | Returns: 49 | (bool) 50 | ''' 51 | for puddle_rect in self.puddle_rects: 52 | x_1, y_1, x_2, y_2 = puddle_rect 53 | if state.x >= x_1 and state.x <= x_2 and \ 54 | state.y <= y_1 and state.y >= y_2: 55 | return True 56 | 57 | return False 58 | 59 | def _is_goal_state_action(self, state, action): 60 | ''' 61 | Args: 62 | state (State) 63 | action (str) 64 | 65 | Returns: 66 | (bool): True iff the state-action pair send the agent to the goal state. 67 | ''' 68 | for g in self.goal_locs: 69 | if _euclidean_distance(state.x, state.y, g[0], g[1]) <= self.delta * 2 and self.is_goal_terminal: 70 | # Already at terminal. 
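# The agent is already within the goal radius, so with terminal goals this step cannot earn the goal reward again.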
71 | return False 72 | 73 | if action == "left" and self.is_loc_within_radius_to_goal(state.x - self.delta, state.y): 74 | return True 75 | elif action == "right" and self.is_loc_within_radius_to_goal(state.x + self.delta, state.y): 76 | return True 77 | elif action == "down" and self.is_loc_within_radius_to_goal(state.x, state.y - self.delta): 78 | return True 79 | elif action == "up" and self.is_loc_within_radius_to_goal(state.x, state.y + self.delta): 80 | return True 81 | else: 82 | return False 83 | 84 | def is_loc_within_radius_to_goal(self, state_x, state_y): 85 | ''' 86 | Args: 87 | state_x (float) 88 | state_y (float) 89 | 90 | Returns: 91 | (bool) 92 | ''' 93 | for g in self.goal_locs: 94 | if _euclidean_distance(state_x, state_y, g[0], g[1]) <= self.delta * 2: 95 | return True 96 | return False 97 | 98 | def _transition_func(self, state, action): 99 | ''' 100 | Args: 101 | state (simple_rl.State) 102 | action (str) 103 | 104 | Returns: 105 | state (simple_rl.State) 106 | ''' 107 | if state.is_terminal(): 108 | return state 109 | 110 | noise = np.random.randn(1)[0] / 100.0 111 | to_move = self.delta + noise 112 | 113 | if action == "up": 114 | next_state = GridWorldState(state.x, min(state.y + to_move, 1)) 115 | elif action == "down": 116 | next_state = GridWorldState(state.x, max(state.y - to_move, 0)) 117 | elif action == "right": 118 | next_state = GridWorldState(min(state.x + to_move, 1), state.y) 119 | elif action == "left": 120 | next_state = GridWorldState(max(state.x - to_move, 0), state.y) 121 | else: 122 | next_state = GridWorldState(state.x, state.y) 123 | 124 | if self._is_goal_state_action(state, action) and self.is_goal_terminal: 125 | next_state.set_terminal(True) 126 | 127 | return next_state 128 | 129 | 130 | def _euclidean_distance(ax, ay, bx, by): 131 | ''' 132 | Args: 133 | ax (float) 134 | ay (float) 135 | bx (float) 136 | by (float) 137 | 138 | Returns: 139 | (float) 140 | ''' 141 | return np.linalg.norm(np.array([ax, ay]) - np.array([bx, by])) 142 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/tasks/grid_game/GridGameMDPClass.py: -------------------------------------------------------------------------------- 1 | ''' GridGameMDPClass.py: Contains an implementation of a two player grid game. ''' 2 | 3 | # Python imports. 4 | import random 5 | 6 | # Other imports. 7 | from simple_rl.mdp.markov_game.MarkovGameMDPClass import MarkovGameMDP 8 | from simple_rl.tasks.grid_game.GridGameStateClass import GridGameState 9 | 10 | class GridGameMDP(MarkovGameMDP): 11 | ''' Class for a Two Player Grid Game ''' 12 | 13 | # Static constants. 
14 | ACTIONS = ["up", "left", "down", "right"] 15 | 16 | def __init__(self, height=3, width=8, init_a_x=1, init_a_y=2, init_b_x=8, init_b_y=8): 17 | self.goal_a_x = init_b_x 18 | self.goal_a_y = init_b_y 19 | self.goal_b_x = init_a_x 20 | self.goal_b_y = init_a_y 21 | init_state = GridGameState(init_a_x, init_a_y, init_b_x, init_b_y) 22 | self.height = height 23 | self.width = width 24 | MarkovGameMDP.__init__(self, GridGameMDP.ACTIONS, self._transition_func, self._reward_func, init_state=init_state) 25 | 26 | def _reward_func(self, state, action_dict): 27 | ''' 28 | Args: 29 | state (State) 30 | action (dict of actions) 31 | 32 | Returns 33 | (float) 34 | ''' 35 | agent_a, agent_b = action_dict.keys()[0], action_dict.keys()[1] 36 | action_a, action_b = action_dict[agent_a], action_dict[agent_b] 37 | 38 | reward_dict = {} 39 | 40 | next_state = self._transition_func(state, action_dict) 41 | 42 | a_at_goal = (next_state.a_x == self.goal_a_x and next_state.a_y == self.goal_a_y) 43 | b_at_goal = (next_state.b_x == self.goal_b_x and next_state.b_y == self.goal_b_y) 44 | 45 | if a_at_goal and b_at_goal: 46 | reward_dict[agent_a] = 2.0 47 | reward_dict[agent_b] = 2.0 48 | elif a_at_goal and not b_at_goal: 49 | reward_dict[agent_a] = 1.0 50 | reward_dict[agent_b] = -1.0 51 | elif not a_at_goal and b_at_goal: 52 | reward_dict[agent_a] = -1.0 53 | reward_dict[agent_b] = 1.0 54 | else: 55 | reward_dict[agent_a] = 0.0 56 | reward_dict[agent_b] = 0.0 57 | 58 | return reward_dict 59 | 60 | def _transition_func(self, state, action_dict): 61 | ''' 62 | Args: 63 | state (State) 64 | action_dict (str) 65 | 66 | Returns 67 | (State) 68 | ''' 69 | 70 | agent_a, agent_b = action_dict.keys()[0], action_dict.keys()[1] 71 | action_a, action_b = action_dict[agent_a], action_dict[agent_b] 72 | 73 | next_state = self._move_agents(action_a, state.a_x, state.a_y, action_b, state.b_x, state.b_y) 74 | 75 | return next_state 76 | 77 | def _move_agents(self, action_a, a_x, a_y, action_b, b_x, b_y): 78 | ''' 79 | Args: 80 | action_a (str) 81 | a_x (int) 82 | a_y (int) 83 | action_b (str) 84 | b_x (int) 85 | b_y (int) 86 | 87 | Summary: 88 | Moves the two agents accounting for collisions with walls and each other. 89 | 90 | Returns: 91 | (GridGameState) 92 | ''' 93 | 94 | new_a_x, new_a_y = a_x, a_y 95 | new_b_x, new_b_y = b_x, b_y 96 | 97 | # Move agent a. 98 | if action_a == "up" and a_y < self.height: 99 | new_a_y += 1 100 | elif action_a == "down" and a_y > 1: 101 | new_a_y -= 1 102 | elif action_a == "right" and a_x < self.width: 103 | new_a_x += 1 104 | elif action_a == "left" and a_x > 1: 105 | new_a_x -= 1 106 | 107 | # Move agent b. 108 | if action_b == "up" and b_y < self.height: 109 | new_b_y += 1 110 | elif action_b == "down" and b_y > 1: 111 | new_b_y -= 1 112 | elif action_b == "right" and b_x < self.width: 113 | new_b_x += 1 114 | elif action_b == "left" and b_x > 1: 115 | new_b_x -= 1 116 | 117 | if new_a_x == new_b_x and new_a_y == new_b_y or \ 118 | (new_a_x == b_x and new_a_y == b_y and new_b_x == a_x and new_b_y == a_y): 119 | # If the agent's collided or traded places, reset them. 120 | new_a_x, new_a_y = a_x, a_y 121 | new_b_x, new_b_y = b_x, b_y 122 | 123 | next_state = GridGameState(new_a_x, new_a_y, new_b_x, new_b_y) 124 | 125 | # Check terminal. 
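# The game ends as soon as either agent reaches its goal, which is the other agent's starting cell.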
126 | if self._is_terminal_state(next_state): 127 | next_state.set_terminal(True) 128 | 129 | return next_state 130 | 131 | def _is_terminal_state(self, next_state): 132 | return (next_state.a_x == self.goal_a_x and next_state.a_y == self.goal_a_y) or \ 133 | (next_state.b_x == self.goal_b_x and next_state.b_y == self.goal_b_y) 134 | 135 | def __str__(self): 136 | return "grid_game-" + str(self.height) + "-" + str(self.width) 137 | 138 | def _manhattan_distance(a_x, a_y, b_x, b_y): 139 | return abs(a_x - b_x) + abs(a_y - b_y) 140 | 141 | def main(): 142 | grid_game = GridGameMDP() 143 | 144 | if __name__ == "__main__": 145 | main() 146 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/utils/make_mdp.py: -------------------------------------------------------------------------------- 1 | ''' 2 | make_mdp.py 3 | 4 | Utility for making MDP instances or distributions. 5 | ''' 6 | 7 | # Python imports. 8 | import itertools 9 | import random 10 | from collections import defaultdict 11 | 12 | # Other imports. 13 | from simple_rl.tasks import ChainMDP, GridWorldMDP, TaxiOOMDP, RandomMDP, FourRoomMDP, HanoiMDP 14 | from simple_rl.tasks.grid_world.GridWorldMDPClass import make_grid_world_from_file 15 | from simple_rl.mdp import MDPDistribution 16 | 17 | def make_markov_game(markov_game_class="grid_game"): 18 | return {"prison":PrisonersDilemmaMDP(), 19 | "rps":RockPaperScissorsMDP(), 20 | "grid_game":GridGameMDP()}[markov_game_class] 21 | 22 | def make_mdp(mdp_class="grid", grid_dim=7): 23 | ''' 24 | Returns: 25 | (MDP) 26 | ''' 27 | # Grid/Hallway stuff. 28 | width, height = grid_dim, grid_dim 29 | hall_goal_locs = [(i, width) for i in range(1, height+1)] 30 | 31 | four_room_goal_locs = [(width, height), (width, 1), (1, height), (1, height - 2), (width - 2, height - 2), (width - 2, 1)] 32 | # four_room_goal_loc = four_room_goal_locs[5] 33 | 34 | # Taxi stuff. 35 | agent = {"x":1, "y":1, "has_passenger":0} 36 | passengers = [{"x":grid_dim / 2, "y":grid_dim / 2, "dest_x":grid_dim-2, "dest_y":2, "in_taxi":0}] 37 | walls = [] 38 | 39 | mdp = {"hall":GridWorldMDP(width=width, height=height, init_loc=(1, 1), goal_locs=hall_goal_locs), 40 | "pblocks_grid":make_grid_world_from_file("pblocks_grid.txt", randomize=True), 41 | "grid":GridWorldMDP(width=width, height=height, init_loc=(1, 1), goal_locs=[(grid_dim, grid_dim)]), 42 | "four_room":FourRoomMDP(width=width, height=height, goal_locs=[four_room_goal_loc]), 43 | "chain":ChainMDP(num_states=grid_dim), 44 | "random":RandomMDP(num_states=50, num_rand_trans=2), 45 | "hanoi":HanoiMDP(num_pegs=grid_dim, num_discs=3), 46 | "taxi":TaxiOOMDP(width=grid_dim, height=grid_dim, slip_prob=0.0, agent=agent, walls=walls, passengers=passengers)}[mdp_class] 47 | 48 | return mdp 49 | 50 | def make_mdp_distr(mdp_class="grid", grid_dim=9, horizon=0, step_cost=0, gamma=0.99): 51 | ''' 52 | Args: 53 | mdp_class (str): one of {"grid", "random"} 54 | horizon (int) 55 | step_cost (float) 56 | gamma (float) 57 | 58 | Returns: 59 | (MDPDistribution) 60 | ''' 61 | mdp_dist_dict = {} 62 | height, width = grid_dim, grid_dim 63 | 64 | # Define goal locations. 65 | 66 | # Corridor. 
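# Goals occupy the first and last corr_goal_magnitude columns of a 20-wide, 1-high corridor whose agent starts near the middle.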
67 | corr_width = 20 68 | corr_goal_magnitude = 1 #random.randint(1, 5) 69 | corr_goal_cols = [i for i in range(1, corr_goal_magnitude + 1)] + [j for j in range(corr_width-corr_goal_magnitude + 1, corr_width + 1)] 70 | corr_goal_locs = list(itertools.product(corr_goal_cols, [1])) 71 | 72 | # Grid World 73 | tl_grid_world_rows, tl_grid_world_cols = [i for i in range(width - 4, width)], [j for j in range(height - 4, height)] 74 | tl_grid_goal_locs = list(itertools.product(tl_grid_world_rows, tl_grid_world_cols)) 75 | tr_grid_world_rows, tr_grid_world_cols = [i for i in range(1, 4)], [j for j in range(height - 4, height)] 76 | tr_grid_goal_locs = list(itertools.product(tr_grid_world_rows, tr_grid_world_cols)) 77 | grid_goal_locs = tl_grid_goal_locs + tr_grid_goal_locs 78 | 79 | # Hallway. 80 | hall_goal_locs = [(i, height) for i in range(1, 30)] 81 | 82 | # Four room. 83 | four_room_goal_locs = [(width, height), (width, 1), (1, height), (4,4)] 84 | 85 | # Taxi. 86 | agent = {"x":1, "y":1, "has_passenger":0} 87 | walls = [] 88 | 89 | goal_loc_dict = {"four_room":four_room_goal_locs, 90 | "hall":hall_goal_locs, 91 | "grid":grid_goal_locs, 92 | "corridor":corr_goal_locs, 93 | } 94 | 95 | # MDP Probability. 96 | num_mdps = 10 if mdp_class not in goal_loc_dict.keys() else len(goal_loc_dict[mdp_class]) 97 | mdp_prob = 1.0 / num_mdps 98 | 99 | for i in range(num_mdps): 100 | 101 | new_mdp = {"hall":GridWorldMDP(width=30, height=height, rand_init=False, goal_locs=goal_loc_dict["hall"], name="hallway", is_goal_terminal=True), 102 | "corridor":GridWorldMDP(width=20, height=1, init_loc=(10, 1), goal_locs=[goal_loc_dict["corridor"][i % len(goal_loc_dict["corridor"])]], is_goal_terminal=True, name="corridor"), 103 | "grid":GridWorldMDP(width=width, height=height, rand_init=True, goal_locs=[goal_loc_dict["grid"][i % len(goal_loc_dict["grid"])]], is_goal_terminal=True), 104 | "four_room":FourRoomMDP(width=width, height=height, goal_locs=[goal_loc_dict["four_room"][i % len(goal_loc_dict["four_room"])]], is_goal_terminal=True), 105 | "chain":ChainMDP(num_states=10, reset_val=random.choice([0, 0.01, 0.05, 0.1, 0.2, 0.5])), 106 | "random":RandomMDP(num_states=40, num_rand_trans=random.randint(1,10)), 107 | "taxi":TaxiOOMDP(3, 4, slip_prob=0.0, agent=agent, walls=walls, \ 108 | passengers=[{"x":2, "y":1, "dest_x":random.choice([2,3]), "dest_y":random.choice([2,3]), "in_taxi":0}, 109 | {"x":1, "y":2, "dest_x":random.choice([1,2]), "dest_y":random.choice([1,4]), "in_taxi":0}])}[mdp_class] 110 | 111 | new_mdp.set_step_cost(step_cost) 112 | new_mdp.set_gamma(gamma) 113 | 114 | mdp_dist_dict[new_mdp] = mdp_prob 115 | 116 | return MDPDistribution(mdp_dist_dict, horizon=horizon) 117 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/mdp/MDPDistributionClass.py: -------------------------------------------------------------------------------- 1 | ''' MDPDistributionClass.py: Contains the MDP Distribution Class. ''' 2 | 3 | # Python imports. 4 | from __future__ import print_function 5 | import numpy as np 6 | from collections import defaultdict 7 | 8 | class MDPDistribution(object): 9 | ''' Class for distributions over MDPs. ''' 10 | 11 | def __init__(self, mdp_prob_dict, horizon=0): 12 | ''' 13 | Args: 14 | mdp_prob_dict (dict): 15 | Key (MDP) 16 | Val (float): Represents the probability with which the MDP is sampled. 17 | 18 | Notes: 19 | @mdp_prob_dict can also be a list, in which case the uniform distribution is used. 
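Example (mdp_a and mdp_b denote hypothetical MDP instances): MDPDistribution({mdp_a: 0.25, mdp_b: 0.75}) samples mdp_b three times as often as mdp_a.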
20 | ''' 21 | if type(mdp_prob_dict) == list or len(mdp_prob_dict.values()) == 0: 22 | # Assume uniform if no probabilities are provided. 23 | mdp_prob = 1.0 / len(mdp_prob_dict.keys()) 24 | new_dict = defaultdict(float) 25 | for mdp in mdp_prob_dict: 26 | new_dict[mdp] = mdp_prob 27 | mdp_prob_dict = new_dict 28 | 29 | self.horizon = horizon 30 | self.mdp_prob_dict = mdp_prob_dict 31 | self.index = 1 32 | 33 | def remove_mdps(self, mdp_list): 34 | ''' 35 | Args: 36 | (list): Contains MDP instances. 37 | 38 | Summary: 39 | Removes each mdp in @mdp_list from self.mdp_prob_dict and recomputes the distribution. 40 | ''' 41 | for mdp in mdp_list: 42 | try: 43 | self.mdp_prob_dict.pop(mdp) 44 | except KeyError: 45 | raise ValueError("(simple-rl Error): Trying to remove MDP (" + str(mdp) + ") from MDP Distribution that doesn't contain it.") 46 | 47 | self._normalize() 48 | 49 | def remove_mdp(self, mdp): 50 | ''' 51 | Args: 52 | (MDP) 53 | 54 | Summary: 55 | Removes @mdp from self.mdp_prob_dict and recomputes the distribution. 56 | ''' 57 | try: 58 | self.mdp_prob_dict.pop(mdp) 59 | except KeyError: 60 | raise ValueError("(simple-rl Error): Trying to remove MDP (" + str(mdp) + ") from MDP Distribution that doesn't contain it.") 61 | 62 | self._normalize() 63 | 64 | def _normalize(self): 65 | total = sum(self.mdp_prob_dict.values()) 66 | for mdp in self.mdp_prob_dict.keys(): 67 | self.mdp_prob_dict[mdp] = self.mdp_prob_dict[mdp] / total 68 | 69 | def get_all_mdps(self, prob_threshold=0): 70 | ''' 71 | Args: 72 | prob_threshold (float) 73 | 74 | Returns: 75 | (list): Contains all mdps in the distribution with Pr. > @prob_threshold. 76 | ''' 77 | return [mdp for mdp in self.mdp_prob_dict.keys() if self.mdp_prob_dict[mdp] > prob_threshold] 78 | 79 | def get_horizon(self): 80 | return self.horizon 81 | 82 | def get_actions(self): 83 | return list(self.mdp_prob_dict.keys())[0].get_actions() 84 | 85 | def get_gamma(self): 86 | ''' 87 | Notes: 88 | Not all MDPs in the distribution are guaranteed to share gamma. 89 | ''' 90 | return list(self.mdp_prob_dict.keys())[0].get_gamma() 91 | 92 | def get_reward_func(self, avg=True): 93 | if avg: 94 | self.get_average_reward_func() 95 | else: 96 | self.get_all_mdps()[0].get_reward_func() 97 | 98 | def get_average_reward_func(self): 99 | def _avg_r_func(s, a): 100 | r = 0.0 101 | for m in self.mdp_prob_dict.keys(): 102 | r += m.reward_func(s, a) * self.mdp_prob_dict[m] 103 | return r 104 | return _avg_r_func 105 | 106 | def get_init_state(self): 107 | ''' 108 | Notes: 109 | Not all MDPs in the distribution are guaranteed to share init states. 110 | ''' 111 | return list(self.mdp_prob_dict.keys())[0].get_init_state() 112 | 113 | def get_num_mdps(self): 114 | return len(self.mdp_prob_dict.keys()) 115 | 116 | def get_mdps(self): 117 | return self.mdp_prob_dict.keys() 118 | 119 | def get_prob_of_mdp(self, mdp): 120 | if mdp in self.mdp_prob_dict.keys(): 121 | return self.mdp_prob_dict[mdp] 122 | else: 123 | return 0.0 124 | 125 | def set_gamma(self, new_gamma): 126 | for mdp in self.mdp_prob_dict.keys(): 127 | mdp.set_gamma(new_gamma) 128 | 129 | def sample(self, k=1): 130 | ''' 131 | Args: 132 | k (int) 133 | 134 | Returns: 135 | (List of MDP): Samples @k mdps without replacement. 
136 | ''' 137 | # num = len(self.get_all_mdps()) 138 | # i = self.index % num 139 | # print(i) 140 | # self.index +=1 141 | # return list(self.mdp_prob_dict.keys())[i] 142 | sampled_mdp_id_list = np.random.multinomial(k, list(self.mdp_prob_dict.values())).tolist() 143 | indices = [i for i, x in enumerate(sampled_mdp_id_list) if x > 0] 144 | 145 | if k == 1: 146 | return list(self.mdp_prob_dict.keys())[indices[0]] 147 | 148 | mdps_to_return = [] 149 | 150 | for i in indices: 151 | for copies in range(sampled_mdp_id_list[i]): 152 | mdps_to_return.append(list(self.mdp_prob_dict.keys())[i]) 153 | 154 | return mdps_to_return 155 | 156 | def __str__(self): 157 | ''' 158 | Notes: 159 | Not all MDPs are guaranteed to share a name (for instance, might include dimensions). 160 | ''' 161 | return "lifelong-" + str(list(self.mdp_prob_dict.keys())[0]) 162 | 163 | def main(): 164 | # Simple test code. 165 | from simple_rl.tasks import GridWorldMDP 166 | 167 | mdp_distr = {} 168 | height, width = 8, 8 169 | prob_list = [0.0, 0.1, 0.2, 0.3, 0.4] 170 | 171 | for i in range(len(prob_list)): 172 | next_mdp = GridWorldMDP(width=width, height=width, init_loc=(1, 1), goal_locs=r.sample(zip(range(1, width + 1), [height] * width), 2), is_goal_terminal=True) 173 | 174 | mdp_distr[next_mdp] = prob_list[i] 175 | 176 | m = MDPDistribution(mdp_distr) 177 | m.sample() 178 | 179 | if __name__ == "__main__": 180 | main() 181 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/abstraction/abstr_mdp/abstr_mdp_funcs.py: -------------------------------------------------------------------------------- 1 | # Python imports. 2 | from collections import defaultdict 3 | import numpy as np 4 | 5 | # Other imports. 6 | from simple_rl.planning import ValueIteration 7 | from simple_rl.mdp import MDP 8 | from simple_rl.mdp import MDPDistribution 9 | from simple_rl.abstraction.abstr_mdp.RewardFuncClass import RewardFunc 10 | from simple_rl.abstraction.abstr_mdp.TransitionFuncClass import TransitionFunc 11 | from simple_rl.abstraction.action_abs.ActionAbstractionClass import ActionAbstraction 12 | # ------------------ 13 | # -- Single Level -- 14 | # ------------------ 15 | 16 | def make_abstr_mdp(mdp, state_abstr, action_abstr=None, step_cost=0.0, sample_rate=5): 17 | ''' 18 | Args: 19 | mdp (MDP) 20 | state_abstr (StateAbstraction) 21 | action_abstr (ActionAbstraction) 22 | step_cost (float): Cost for a step in the lower MDP. 23 | sample_rate (int): Sample rate for computing the abstract R and T. 24 | 25 | Returns: 26 | (MDP) 27 | ''' 28 | 29 | if action_abstr is None: 30 | action_abstr = ActionAbstraction(prim_actions=mdp.get_actions()) 31 | 32 | # Make abstract reward and transition functions. 33 | def abstr_reward_lambda(abstr_state, abstr_action): 34 | if abstr_state.is_terminal(): 35 | return 0 36 | 37 | # Get relevant MDP components from the lower MDP. 38 | lower_states = state_abstr.get_lower_states_in_abs_state(abstr_state) 39 | lower_reward_func = mdp.get_reward_func() 40 | lower_trans_func = mdp.get_transition_func() 41 | 42 | # Compute reward. 43 | total_reward = 0 44 | for ground_s in lower_states: 45 | for sample in range(sample_rate): 46 | s_prime, reward = abstr_action.rollout(ground_s, lower_reward_func, lower_trans_func, step_cost=step_cost) 47 | total_reward += float(reward) / (len(lower_states) * sample_rate) # Add weighted reward. 
48 | 49 | return total_reward 50 | 51 | def abstr_transition_lambda(abstr_state, abstr_action): 52 | is_ground_terminal = False 53 | for s_g in state_abstr.get_lower_states_in_abs_state(abstr_state): 54 | if s_g.is_terminal(): 55 | is_ground_terminal = True 56 | break 57 | 58 | # Get relevant MDP components from the lower MDP. 59 | if abstr_state.is_terminal(): 60 | return abstr_state 61 | 62 | lower_states = state_abstr.get_lower_states_in_abs_state(abstr_state) 63 | lower_reward_func = mdp.get_reward_func() 64 | lower_trans_func = mdp.get_transition_func() 65 | 66 | 67 | # Compute next state distribution. 68 | s_prime_prob_dict = defaultdict(int) 69 | total_reward = 0 70 | for ground_s in lower_states: 71 | for sample in range(sample_rate): 72 | s_prime, reward = abstr_action.rollout(ground_s, lower_reward_func, lower_trans_func) 73 | s_prime_prob_dict[s_prime] += (1.0 / (len(lower_states) * sample_rate)) # Weighted average. 74 | 75 | # Form distribution and sample s_prime. 76 | next_state_sample_list = list(np.random.multinomial(1, list(s_prime_prob_dict.values())).tolist()) 77 | end_ground_state = list(s_prime_prob_dict.keys())[next_state_sample_list.index(1)] 78 | end_abstr_state = state_abstr.phi(end_ground_state) 79 | 80 | return end_abstr_state 81 | 82 | # Make the components of the Abstract MDP. 83 | abstr_init_state = state_abstr.phi(mdp.get_init_state()) 84 | abstr_action_space = action_abstr.get_actions() 85 | abstr_state_space = state_abstr.get_abs_states() 86 | abstr_reward_func = RewardFunc(abstr_reward_lambda, abstr_state_space, abstr_action_space) 87 | abstr_transition_func = TransitionFunc(abstr_transition_lambda, abstr_state_space, abstr_action_space, sample_rate=sample_rate) 88 | 89 | # Make the MDP. 90 | abstr_mdp = MDP(actions=abstr_action_space, 91 | init_state=abstr_init_state, 92 | reward_func=abstr_reward_func.reward_func, 93 | transition_func=abstr_transition_func.transition_func, 94 | gamma=mdp.get_gamma()) 95 | 96 | return abstr_mdp 97 | 98 | def make_abstr_mdp_distr(mdp_distr, state_abstr, action_abstr, step_cost=0.1): 99 | ''' 100 | Args: 101 | mdp_distr (MDPDistribution) 102 | state_abstr (StateAbstraction) 103 | action_abstr (ActionAbstraction) 104 | 105 | Returns: 106 | (MDPDistribution) 107 | ''' 108 | 109 | # Loop through old mdps and abstract. 110 | mdp_distr_dict = {} 111 | for mdp in mdp_distr.get_all_mdps(): 112 | abstr_mdp = make_abstr_mdp(mdp, state_abstr, action_abstr, step_cost=step_cost) 113 | prob_of_abstr_mdp = mdp_distr.get_prob_of_mdp(mdp) 114 | mdp_distr_dict[abstr_mdp] = prob_of_abstr_mdp 115 | 116 | return MDPDistribution(mdp_distr_dict) 117 | 118 | # ----------------- 119 | # -- Multi Level -- 120 | # ----------------- 121 | 122 | def make_abstr_mdp_multi_level(mdp, state_abstr_stack, action_abstr_stack, step_cost=0.1, sample_rate=5): 123 | ''' 124 | Args: 125 | mdp (MDP) 126 | state_abstr_stack (StateAbstractionStack) 127 | action_abstr_stack (ActionAbstractionStack) 128 | step_cost (float): Cost for a step in the lower MDP. 129 | sample_rate (int): Sample rate for computing the abstract R and T. 
130 | 131 | Returns: 132 | (MDP) 133 | ''' 134 | mdp_level = min(state_abstr_stack.get_num_levels(), action_abstr_stack.get_num_levels()) 135 | 136 | for i in range(1, mdp_level + 1): 137 | state_abstr_stack.set_level(i) 138 | action_abstr_stack.set_level(i) 139 | mdp = make_abstr_mdp(mdp, state_abstr_stack, action_abstr_stack, step_cost, sample_rate) 140 | 141 | return mdp 142 | 143 | def make_abstr_mdp_distr_multi_level(mdp_distr, state_abstr, action_abstr, step_cost=0.1): 144 | ''' 145 | Args: 146 | mdp_distr (MDPDistribution) 147 | state_abstr (StateAbstraction) 148 | action_abstr (ActionAbstraction) 149 | 150 | Returns: 151 | (MDPDistribution) 152 | ''' 153 | 154 | # Loop through old mdps and abstract. 155 | mdp_distr_dict = {} 156 | for mdp in mdp_distr.get_all_mdps(): 157 | abstr_mdp = make_abstr_mdp_multi_level(mdp, state_abstr, action_abstr, step_cost=step_cost) 158 | prob_of_abstr_mdp = mdp_distr.get_prob_of_mdp(mdp) 159 | mdp_distr_dict[abstr_mdp] = prob_of_abstr_mdp 160 | 161 | return MDPDistribution(mdp_distr_dict) 162 | 163 | def _rew_dict_from_lambda(input_lambda, state_space, action_space, sample_rate): 164 | result_dict = defaultdict(lambda:defaultdict(float)) 165 | for s in state_space: 166 | for a in action_space: 167 | for i in range(sample_rate): 168 | result_dict[s][a] = input_lambda(s,a) / sample_rate 169 | 170 | return result_dict 171 | 172 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/agents/func_approx/GradientBoostingAgentClass.py: -------------------------------------------------------------------------------- 1 | ''' 2 | GradientBoostingAgentClass.py 3 | 4 | Implementation for a Q Learner with Gradient Boosting for an approximator. 5 | 6 | From: 7 | Abel, D., Agarwal, A., Diaz, F., Krishnamurthy, A., & Schapire, R. E. 8 | (2016). Exploratory Gradient Boosting for Reinforcement Learning in Complex Domains. 9 | ICML Workshop on RL and Abstraction (2016). arXiv preprint arXiv:1603.04119. 10 | ''' 11 | 12 | # Python imports. 13 | import random 14 | import math 15 | import numpy as np 16 | import time 17 | try: 18 | from sklearn.ensemble import GradientBoostingRegressor 19 | except ImportError: 20 | raise ValueError("Error: sklearn not installed. See: http://scikit-learn.org/stable/install.html") 21 | 22 | # simple_rl classes. 23 | from simple_rl.agents.QLearningAgentClass import QLearningAgent 24 | from simple_rl.agents.AgentClass import Agent 25 | class GradientBoostingAgent(QLearningAgent): 26 | ''' 27 | QLearningAgent that uses gradient boosting with additive regression trees to approximate the Q Function. 28 | ''' 29 | 30 | def __init__(self, actions, name="grad_boost", gamma=0.99, explore="softmax", markov_window=20, update_window=500): 31 | name += "-m" if markov_window > 0 else "" 32 | QLearningAgent.__init__(self, actions=actions, name=name, gamma=gamma, explore=explore) 33 | self.weak_learners = [] 34 | self.model = [] 35 | self.most_recent_episode = [] 36 | self.max_state_features = 0 37 | self.max_depth = len(actions)*2 38 | self.markov_window = markov_window 39 | self.update_window = update_window 40 | 41 | def update(self, state, action, reward, next_state): 42 | ''' 43 | Args: 44 | state (State) 45 | action (str) 46 | reward (float) 47 | next_state (State) 48 | 49 | Summary: 50 | Updates the internal Q Function according to the Bellman Equation. (Classic Q Learning update) 51 | ''' 52 | 53 | # Update on a per step basis.
54 | if self.step_number > 0 and self.step_number % self.update_window == 0: 55 | self.add_new_weak_learner() 56 | self.most_recent_episode = [] 57 | 58 | if self.markov_window > 0: 59 | self.model = self.weak_learners[-self.markov_window:] 60 | else: 61 | self.model = self.weak_learners 62 | 63 | if None not in [state, action, reward, next_state]: 64 | if len(state.features()) > self.max_state_features: 65 | self.max_state_features = len(state.features()) 66 | self.most_recent_episode.append((state, action, reward, next_state)) 67 | 68 | def get_q_value(self, state, action): 69 | ''' 70 | Args: 71 | state (State): A State object containing the abstract state representation 72 | action (str): A string representing an action. See namespaceAIX. 73 | 74 | Summary: 75 | Retrieves the Q Value associated with this state/action pair. Computed via summing h functions. 76 | 77 | Returns: 78 | (float): denoting the q value of the (@state,@action) pair. 79 | ''' 80 | if len(self.weak_learners) == 0: 81 | # Default Q value. 82 | return 0 83 | 84 | features = self._pad_features_with_zeros(state, action) 85 | 86 | # Compute Q(s,a) 87 | predictions = [h.predict(features)[0] for h in self.model] 88 | result = float(sum(predictions)) # Cast since we'll normally get a numpy float. 89 | 90 | return result 91 | 92 | def _pad_features_with_zeros(self, state, action): 93 | ''' 94 | Args: 95 | features (iterable) 96 | 97 | Returns: 98 | (list): Of the same length as self.max_state_features 99 | ''' 100 | features = state.features() 101 | while len(features) < self.max_state_features: 102 | features = np.append(features, 0) 103 | 104 | # Reshape per update to cluster regression in sklearn 0.17. 105 | reshaped_features = np.append(features, [self.actions.index(action)]) 106 | reshaped_features = reshaped_features.reshape(1, -1) 107 | 108 | return reshaped_features 109 | 110 | def add_new_weak_learner(self): 111 | ''' 112 | Summary: 113 | Adds a new function, h, to self.weak_learners by solving for Eq. 1 using multiple additive regression trees: 114 | 115 | [Eq. 1] h = argmin_h (sum_i Q_A(s_i,a_i) + h(s_i, a_i) - (r_i + max_b Q_A(s'_i, b))) 116 | 117 | ''' 118 | if len(self.most_recent_episode) == 0: 119 | # If this episode contains no data, don't do anything. 120 | return 121 | 122 | # Build up data sets of features and loss terms 123 | data = np.zeros((len(self.most_recent_episode), self.max_state_features + 1)) 124 | total_loss = np.zeros(len(self.most_recent_episode)) 125 | 126 | for i, experience in enumerate(self.most_recent_episode): 127 | # Grab the experience. 128 | s, a, r, s_prime = experience 129 | 130 | # Pad in case the state features are too short (as in Atari sometimes). 131 | features = self._pad_features_with_zeros(s, a) 132 | loss = (r + self.gamma * self.get_max_q_value(s_prime) - self.get_q_value(s, a)) 133 | 134 | # Add to relevant lists. 135 | data[i] = features 136 | total_loss[i] = loss 137 | 138 | # Compute new regressor and add it to the weak learners. 139 | estimator = GradientBoostingRegressor(loss='ls', n_estimators=1, max_depth=self.max_depth) 140 | estimator.fit(data, total_loss) 141 | self.weak_learners.append(estimator) 142 | 143 | def end_of_episode(self): 144 | ''' 145 | Summary: 146 | Performs miscellaneous end of episode tasks (#printing out useful information, saving stuff, etc.) 
147 | ''' 148 | 149 | # self.model = self.weak_learners 150 | self.add_new_weak_learner() 151 | self.most_recent_episode = [] 152 | 153 | if self.markov_window > 0: 154 | # num_sampled_trees = int(math.ceil(len(self.weak_learners) / 10.0)) 155 | # self.model = random.sample(self.weak_learners, num_sampled_trees) 156 | self.model = self.weak_learners[-self.markov_window:] 157 | else: 158 | self.model = self.weak_learners 159 | 160 | Agent.end_of_episode(self) 161 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/agents/DoubleQAgentClass.py: -------------------------------------------------------------------------------- 1 | ''' 2 | DoubleQAgentClass.py: Class for an RL Agent acting according to Double Q Learning from: 3 | 4 | Hasselt, H. V. (2010). Double Q-learning. 5 | In Advances in Neural Information Processing Systems (pp. 2613-2621). 6 | 7 | Author: David Abel 8 | ''' 9 | 10 | # Python imports. 11 | import random 12 | from collections import defaultdict 13 | 14 | # Other imports 15 | from simple_rl.agents.QLearningAgentClass import QLearningAgent 16 | from simple_rl.agents.AgentClass import Agent 17 | 18 | class DoubleQAgent(QLearningAgent): 19 | ''' Class for an agent using Double Q Learning. ''' 20 | 21 | def __init__(self, actions, name="Double-Q", alpha=0.05, gamma=0.99, epsilon=0.1, explore="uniform", anneal=False): 22 | ''' 23 | Args: 24 | actions (list): Contains strings denoting the actions. 25 | name (str): Denotes the name of the agent. 26 | alpha (float): Learning rate. 27 | gamma (float): Discount factor. 28 | epsilon (float): Exploration term. 29 | explore (str): One of {softmax, uniform}. Denotes explore policy. 30 | ''' 31 | QLearningAgent.__init__(self, actions, name=name, alpha=alpha, gamma=gamma, epsilon=epsilon, explore=explore, anneal=anneal) 32 | 33 | # Make two q functions. 34 | self.q_funcs = {"A":defaultdict(lambda : defaultdict(lambda: self.default_q)), \ 35 | "B":defaultdict(lambda : defaultdict(lambda: self.default_q))} 36 | 37 | 38 | def act(self, state, reward): 39 | ''' 40 | Args: 41 | state (State) 42 | reward (float) 43 | 44 | Summary: 45 | The central method called during each time step. 46 | Retrieves the action according to the current policy 47 | and performs updates. 48 | ''' 49 | self.update(self.prev_state, self.prev_action, reward, state) 50 | 51 | if self.explore == "softmax": 52 | # Softmax exploration 53 | action = self.soft_max_policy(state) 54 | else: 55 | # Uniform exploration 56 | action = self.epsilon_greedy_q_policy(state) 57 | 58 | self.prev_state = state 59 | self.prev_action = action 60 | self.step_number += 1 61 | 62 | # Anneal params. 63 | if self.anneal: 64 | self._anneal() 65 | 66 | return action 67 | 68 | def update(self, state, action, reward, next_state): 69 | ''' 70 | Args: 71 | state (State) 72 | action (str) 73 | reward (float) 74 | next_state (State) 75 | 76 | Summary: 77 | Updates the internal Q Function according to the Double Q update: 78 | Q_A(s, a) <- Q_A(s, a) + alpha * (r + gamma * Q_B(s', argmax_a' Q_A(s', a')) - Q_A(s, a)), 79 | where the roles of Q_A and Q_B are chosen uniformly at random on each call. 80 | ''' 81 | # If this is the first state, just return. 82 | if state is None: 83 | self.prev_state = next_state 84 | return 85 | 86 | # Randomly choose either "A" or "B". 87 | which_q_func = "A" if bool(random.getrandbits(1)) else "B" 88 | other_q_func = "B" if which_q_func == "A" else "A" 89 | 90 | # Update the Q Function. 91 | 92 | # Get max q action of the chosen Q func.
93 | max_q_action = self.get_max_q_action(next_state, q_func_id=which_q_func) 94 | prev_q_val = self.get_q_value(state, action, q_func_id=which_q_func) 95 | 96 | # Update 97 | self.q_funcs[which_q_func][state][action] = (1 - self.alpha) * prev_q_val + self.alpha * (reward + self.gamma * self.get_q_value(next_state, max_q_action, q_func_id=other_q_func)) 98 | 99 | def get_max_q_action(self, state, q_func_id=None): 100 | ''' 101 | Args: 102 | state (State) 103 | q_func_id (str): either "A" or "B" 104 | 105 | Returns: 106 | (str): denoting the action with the max q value in the given @state. 107 | ''' 108 | return self._compute_max_qval_action_pair(state, q_func_id)[1] 109 | 110 | def get_max_q_value(self, state, q_func_id=None): 111 | ''' 112 | Args: 113 | state (State) 114 | q_func_id (str): either "A" or "B" 115 | 116 | Returns: 117 | (float): denoting the max q value in the given @state. 118 | ''' 119 | return self._compute_max_qval_action_pair(state, q_func_id)[0] 120 | 121 | def _compute_max_qval_action_pair(self, state, q_func_id=None): 122 | ''' 123 | Args: 124 | state (State) 125 | q_func_id (str): either "A", "B", or None. If None, computes avg of A and B. 126 | 127 | Returns: 128 | (tuple) --> (float, str): where the float is the Qval, str is the action. 129 | ''' 130 | # Grab random initial action in case all equal 131 | best_action = random.choice(self.actions) 132 | max_q_val = float("-inf") 133 | shuffled_action_list = self.actions[:] 134 | random.shuffle(shuffled_action_list) 135 | 136 | # Find best action (action w/ current max predicted Q value) 137 | for action in shuffled_action_list: 138 | q_s_a = self.get_q_value(state, action, q_func_id) 139 | if q_s_a > max_q_val: 140 | max_q_val = q_s_a 141 | best_action = action 142 | 143 | return max_q_val, best_action 144 | 145 | def get_q_value(self, state, action, q_func_id=None): 146 | ''' 147 | Args: 148 | state (State) 149 | action (str) 150 | q_func_id (str): either "A", "B", or defaults to taking the average. 151 | 152 | Returns: 153 | (float): denoting the q value of the (@state, @action) pair relative to 154 | the specified q function. 155 | ''' 156 | if q_func_id is None: 157 | return self.get_avg_q_value(state, action) 158 | else: 159 | return self.q_funcs[q_func_id][state][action] 160 | 161 | def reset(self): 162 | self.step_number = 0 163 | self.episode_number = 0 164 | self.q_funcs = {"A":defaultdict(lambda : defaultdict(lambda: self.default_q)), \ 165 | "B":defaultdict(lambda : defaultdict(lambda: self.default_q))} 166 | Agent.reset(self) 167 | 168 | # ---- DOUBLE Q NEW ---- 169 | 170 | def get_avg_q_value(self, state, action): 171 | ''' 172 | Args: 173 | state (State) 174 | action (str) 175 | 176 | Returns: 177 | (float): denoting the avg. q value of the (@state, @action) pair. 178 | ''' 179 | return (self.q_funcs["A"][state][action] + self.q_funcs["B"][state][action]) / 2.0 180 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/tasks/grid_world/grid_visualizer.py: -------------------------------------------------------------------------------- 1 | # Python imports. 2 | from __future__ import print_function 3 | from collections import defaultdict 4 | try: 5 | import pygame 6 | except ImportError: 7 | print("Warning: pygame not installed (needed for visuals).") 8 | import random 9 | import sys 10 | 11 | # Other imports. 
12 | from simple_rl.planning import ValueIteration 13 | from simple_rl.tasks import FourRoomMDP 14 | from simple_rl.utils import mdp_visualizer as mdpv 15 | 16 | 17 | def _draw_state(screen, 18 | grid_mdp, 19 | state, 20 | policy=None, 21 | action_char_dict={}, 22 | show_value=False, 23 | agent=None, 24 | draw_statics=False, 25 | agent_shape=None): 26 | ''' 27 | Args: 28 | screen (pygame.Surface) 29 | grid_mdp (MDP) 30 | state (State) 31 | show_value (bool) 32 | agent (Agent): Used to show value, by default uses VI. 33 | draw_statics (bool) 34 | agent_shape (pygame.rect) 35 | 36 | Returns: 37 | (pygame.Shape) 38 | ''' 39 | # Make value dict. 40 | val_text_dict = defaultdict(lambda : defaultdict(float)) 41 | if show_value: 42 | if agent is not None: 43 | # Use agent value estimates. 44 | for s in agent.q_func.keys(): 45 | val_text_dict[s.x][s.y] = agent.get_value(s) 46 | else: 47 | # Use Value Iteration to compute value. 48 | vi = ValueIteration(grid_mdp) 49 | vi.run_vi() 50 | for s in vi.get_states(): 51 | val_text_dict[s.x][s.y] = vi.get_value(s) 52 | 53 | # Make policy dict. 54 | policy_dict = defaultdict(lambda : defaultdict(str)) 55 | if policy: 56 | vi = ValueIteration(grid_mdp) 57 | vi.run_vi() 58 | for s in vi.get_states(): 59 | policy_dict[s.x][s.y] = policy(s) 60 | 61 | # Prep some dimensions to make drawing easier. 62 | scr_width, scr_height = screen.get_width(), screen.get_height() 63 | width_buffer = scr_width / 10.0 64 | height_buffer = 30 + (scr_height / 10.0) # Add 30 for title. 65 | cell_width = (scr_width - width_buffer * 2) / grid_mdp.width 66 | cell_height = (scr_height - height_buffer * 2) / grid_mdp.height 67 | goal_locs = grid_mdp.get_goal_locs() 68 | lava_locs = grid_mdp.get_lava_locs() 69 | font_size = int(min(cell_width, cell_height) / 4.0) 70 | reg_font = pygame.font.SysFont("CMU Serif", font_size) 71 | cc_font = pygame.font.SysFont("Courier", font_size*2 + 2) 72 | 73 | # Draw the static entities. 74 | if draw_statics: 75 | # For each row: 76 | for i in range(grid_mdp.width): 77 | # For each column: 78 | for j in range(grid_mdp.height): 79 | 80 | top_left_point = width_buffer + cell_width*i, height_buffer + cell_height*j 81 | r = pygame.draw.rect(screen, (46, 49, 49), top_left_point + (cell_width, cell_height), 3) 82 | 83 | if policy and not grid_mdp.is_wall(i+1, grid_mdp.height - j): 84 | a = policy_dict[i+1][grid_mdp.height - j] 85 | if a not in action_char_dict: 86 | text_a = a 87 | else: 88 | text_a = action_char_dict[a] 89 | text_center_point = int(top_left_point[0] + cell_width/2.0 - 10), int(top_left_point[1] + cell_height/3.0) 90 | text_rendered_a = cc_font.render(text_a, True, (46, 49, 49)) 91 | screen.blit(text_rendered_a, text_center_point) 92 | 93 | if show_value and not grid_mdp.is_wall(i+1, grid_mdp.height - j): 94 | # Draw the value. 95 | val = val_text_dict[i+1][grid_mdp.height - j] 96 | color = mdpv.val_to_color(val) 97 | pygame.draw.rect(screen, color, top_left_point + (cell_width, cell_height), 0) 98 | 99 | if grid_mdp.is_wall(i+1, grid_mdp.height - j): 100 | # Draw the walls. 101 | top_left_point = width_buffer + cell_width*i + 5, height_buffer + cell_height*j + 5 102 | r = pygame.draw.rect(screen, (94, 99, 99), top_left_point + (cell_width-10, cell_height-10), 0) 103 | 104 | if (i+1,grid_mdp.height - j) in goal_locs: 105 | # Draw goal. 
106 | circle_center = int(top_left_point[0] + cell_width/2.0), int(top_left_point[1] + cell_height/2.0) 107 | circler_color = (154, 195, 157) 108 | pygame.draw.circle(screen, circler_color, circle_center, int(min(cell_width, cell_height) / 3.0)) 109 | 110 | if (i+1,grid_mdp.height - j) in lava_locs: 111 | # Draw goal. 112 | circle_center = int(top_left_point[0] + cell_width/2.0), int(top_left_point[1] + cell_height/2.0) 113 | circler_color = (224, 145, 157) 114 | pygame.draw.circle(screen, circler_color, circle_center, int(min(cell_width, cell_height) / 4.0)) 115 | 116 | # Current state. 117 | if not show_value and (i+1,grid_mdp.height - j) == (state.x, state.y) and agent_shape is None: 118 | tri_center = int(top_left_point[0] + cell_width/2.0), int(top_left_point[1] + cell_height/2.0) 119 | agent_shape = _draw_agent(tri_center, screen, base_size=min(cell_width, cell_height)/2.5 - 8) 120 | 121 | if agent_shape is not None: 122 | # Clear the old shape. 123 | pygame.draw.rect(screen, (255,255,255), agent_shape) 124 | top_left_point = width_buffer + cell_width*(state.x - 1), height_buffer + cell_height*(grid_mdp.height - state.y) 125 | tri_center = int(top_left_point[0] + cell_width/2.0), int(top_left_point[1] + cell_height/2.0) 126 | 127 | # Draw new. 128 | agent_shape = _draw_agent(tri_center, screen, base_size=min(cell_width, cell_height)/2.5 - 8) 129 | 130 | pygame.display.flip() 131 | 132 | return agent_shape 133 | 134 | 135 | def _draw_agent(center_point, screen, base_size=20): 136 | ''' 137 | Args: 138 | center_point (tuple): (x,y) 139 | screen (pygame.Surface) 140 | 141 | Returns: 142 | (pygame.rect) 143 | ''' 144 | tri_bot_left = center_point[0] - base_size, center_point[1] + base_size 145 | tri_bot_right = center_point[0] + base_size, center_point[1] + base_size 146 | tri_top = center_point[0], center_point[1] - base_size 147 | tri = [tri_bot_left, tri_top, tri_bot_right] 148 | tri_color = (98, 140, 190) 149 | return pygame.draw.polygon(screen, tri_color, tri) 150 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/planning/BeliefSparseSamplingClass.py: -------------------------------------------------------------------------------- 1 | from math import log 2 | import numpy as np 3 | import copy 4 | from collections import defaultdict 5 | import random 6 | 7 | from simple_rl.pomdp.BeliefMDPClass import BeliefMDP 8 | 9 | 10 | class BeliefSparseSampling(object): 11 | ''' 12 | A Sparse Sampling Algorithm for Near-Optimal Planning in Large Markov Decision Processes (Kearns et al) 13 | 14 | Assuming that you don't have access to the underlying transition dynamics, but do have access to a naiive generative 15 | model of the underlying MDP, this algorithm performs on-line, near-optimal planning with a per-state running time 16 | that has no dependence on the number of states in the MDP. 
17 | ''' 18 | def __init__(self, gen_model, gamma, tol, max_reward, state, name="bss"): 19 | ''' 20 | Args: 21 | gen_model (BeliefMDP): Model of our MDP -- we tell it what action we are performing from some state s 22 | and it will return what our next state is 23 | gamma (float): MDP discount factor 24 | tol (float): Maximum allowed difference between the optimal and the computed value function 25 | max_reward (float): Upper bound on the reward you can get for any state, action 26 | state (State): This is the current state, and we need to output the action to take here 27 | ''' 28 | self.tol = tol 29 | self.gamma = gamma 30 | self.max_reward = max_reward 31 | self.gen_model = gen_model 32 | self.current_state = state 33 | self.horizon = self._horizon 34 | self.width = self._width 35 | 36 | print('BSS Horizon = {} \t Width = {}'.format(self.horizon, self.width)) 37 | 38 | self.name = name 39 | self.root_level_qvals = defaultdict() 40 | self.nodes_by_horizon = defaultdict(lambda: defaultdict(float)) 41 | 42 | @property 43 | def _horizon(self): 44 | ''' 45 | Returns: 46 | _horizon (int): The planning horizon; depth of the recursive tree created to determine the near-optimal 47 | action to take from a given state 48 | ''' 49 | return int(log((self._lam / self._vmax), self.gamma)) 50 | 51 | @property 52 | def _width(self): 53 | ''' 54 | The number of times we ask the generative model to give us a next_state sample for each state, action pair. 55 | Returns: 56 | _width (int) 57 | ''' 58 | part1 = (self._vmax ** 2) / (self._lam ** 2) 59 | part2 = 2 * self._horizon * log(self._horizon * (self._vmax ** 2) / (self._lam ** 2)) 60 | part3 = log(self.max_reward / self._lam) 61 | return int(part1 * (part2 + part3)) 62 | 63 | @property 64 | def _lam(self): 65 | return (self.tol * (1.0 - self.gamma) ** 2) / 4.0 66 | 67 | @property 68 | def _vmax(self): 69 | return float(self.max_reward) / (1 - self.gamma) 70 | 71 | def _get_width_at_height(self, height): 72 | ''' 73 | The branching factor of the tree is decayed according to this formula as suggested by the BSS paper.
74 | Args: 75 | height (int): the current depth in the MDP recursive tree measured from top 76 | Returns: 77 | width (int): the decayed branching factor for a state, action pair 78 | ''' 79 | c = int(self.width * (self.gamma ** (2 * height))) 80 | return c if c > 1 else 1 81 | 82 | def _estimate_qs(self, state, horizon): 83 | qvalues = np.zeros(len(self.gen_model.actions)) 84 | for action_idx, action in enumerate(self.gen_model.actions): 85 | if horizon <= 0: 86 | qvalues[action_idx] = 0.0 87 | else: 88 | qvalues[action_idx] = self._sampled_q_estimate(state, action, horizon) 89 | return qvalues 90 | 91 | def _sampled_q_estimate(self, state, action, horizon): 92 | ''' 93 | Args: 94 | state (State): current state in MDP 95 | action (str): action to take from `state` 96 | horizon (int): planning horizon / depth of recursive tree 97 | 98 | Returns: 99 | average_reward (float): measure of how good (s, a) would be 100 | ''' 101 | total = 0.0 102 | width = self._get_width_at_height(self.horizon - horizon) 103 | for _ in range(width): 104 | next_state = self.gen_model.transition_func(state, action) 105 | total += self.gen_model.reward_func(state, action) + (self.gamma * self._estimate_v(next_state, horizon-1)) 106 | return total / float(width) 107 | 108 | def _estimate_v(self, state, horizon): 109 | ''' 110 | Args: 111 | state (State): current state 112 | horizon (int): time steps in future you want to use to estimate V* 113 | 114 | Returns: 115 | V(s) (float) 116 | ''' 117 | if state in self.nodes_by_horizon: 118 | if horizon in self.nodes_by_horizon[state]: 119 | return self.nodes_by_horizon[state][horizon] 120 | 121 | if self.gen_model.is_in_goal_state(): 122 | self.nodes_by_horizon[state][horizon] = self.gen_model.reward_func(state, random.choice(self.gen_model.actions)) 123 | else: 124 | self.nodes_by_horizon[state][horizon] = np.max(self._estimate_qs(state, horizon)) 125 | 126 | return self.nodes_by_horizon[state][horizon] 127 | 128 | def plan_from_state(self, state): 129 | ''' 130 | Args: 131 | state (State): the current state in the MDP 132 | 133 | Returns: 134 | action (str): near-optimal action to perform from state 135 | ''' 136 | if state in self.root_level_qvals: 137 | qvalues = self.root_level_qvals[state] 138 | else: 139 | init_horizon = self.horizon 140 | qvalues = self._estimate_qs(state, init_horizon) 141 | action_idx = np.argmax(qvalues) 142 | self.root_level_qvals[state] = qvalues 143 | return self.gen_model.actions[action_idx] 144 | 145 | def run(self, verbose=True): 146 | discounted_sum_rewards = 0.0 147 | num_iter = 0 148 | self.gen_model.reset() 149 | state = self.gen_model.init_state 150 | policy = defaultdict() 151 | while not self.gen_model.is_in_goal_state(): 152 | action = self.plan_from_state(state) 153 | reward, next_state = self.gen_model.execute_agent_action(action) 154 | policy[state] = action 155 | discounted_sum_rewards += ((self.gamma ** num_iter) * reward) 156 | if verbose: print('({}, {}, {}) -> {} | {}'.format(state, action, next_state, reward, discounted_sum_rewards)) 157 | state = copy.deepcopy(next_state) 158 | num_iter += 1 159 | return discounted_sum_rewards, policy 160 | -------------------------------------------------------------------------------- /SR-LLRL/simple_rl/tasks/gather/GatherStateClass.py: -------------------------------------------------------------------------------- 1 | ''' GatherStateClass.py: Contains the GatherState class. ''' 2 | 3 | # Other imports.
4 | from simple_rl.mdp.StateClass import State 5 | import numpy as np 6 | import time 7 | import matplotlib 8 | matplotlib.use('TkAgg') 9 | import matplotlib.pyplot as plt # NOTE: for debugging 10 | 11 | COLORS = { 12 | 'agent1': (0, 34, 244), 13 | 'agent2': (236, 51, 35), 14 | 'orientation': (46, 47, 46), 15 | 'apple': (132, 249, 77), 16 | 'light': (140, 139, 42), 17 | 'walls': (138, 140, 137), 18 | } 19 | 20 | class GatherState(State): 21 | 22 | def __init__(self, agent1, agent2, apple_locations, apple_times, render_time=0.01): 23 | super(GatherState, self).__init__(data=[], is_terminal=False) 24 | 25 | # Locations of player 1 and player 2 26 | self.agent1, self.agent2 = agent1, agent2 27 | self.apple_locations = apple_locations 28 | # TODO: 1. Incorporate apple_times into hash, str, and eq 29 | self.apple_times = apple_times 30 | self.x_dim = apple_locations.shape[0] 31 | self.y_dim = apple_locations.shape[1] 32 | self.render_time = render_time 33 | 34 | def __hash__(self): 35 | return hash(tuple(str(self.agent1), str(self.agent2), str(self.apple_locations))) 36 | 37 | def __str__(self): 38 | stateString = [str(self.agent1), str(self.agent2), self.apple_locations.tostring()] 39 | return ''.join(stateString) 40 | 41 | def __eq__(self, other): 42 | if not isinstance(other, GatherState): 43 | return False 44 | return self.agent1 == other.agent1 and self.agent2 == other.agent2 and np.array_equal(self.apple_locations, other.apple_locations) 45 | 46 | def to_rgb(self): 47 | # 3 by x_length by y_length array with values 0 (0) --> 1 (255) 48 | board = np.zeros(shape=[3, self.x_dim, self.y_dim]) 49 | 50 | # Orientation (do this first so that more important things override) 51 | orientation = self.agent1.get_orientation() 52 | board[:, orientation[0], orientation[1]] = COLORS['orientation'] 53 | orientation = self.agent2.get_orientation() 54 | board[:, orientation[0], orientation[1]] = COLORS['orientation'] 55 | 56 | # Beams 57 | if self.agent1.is_shining: 58 | beam = self.agent1.get_beam(self.x_dim, self.y_dim) 59 | board[:, beam[0], beam[1]] = np.transpose(np.ones(shape=[beam[2], 1])*COLORS['light']) 60 | if self.agent2.is_shining: 61 | beam = self.agent2.get_beam(self.x_dim, self.y_dim) 62 | board[:, beam[0], beam[1]] = np.transpose(np.ones(shape=[beam[2], 1])*COLORS['light']) 63 | 64 | # Apples 65 | board[0, (self.apple_locations == 1)] = COLORS['apple'][0] 66 | board[1, (self.apple_locations == 1)] = COLORS['apple'][1] 67 | board[2, (self.apple_locations == 1)] = COLORS['apple'][2] 68 | 69 | # Agents 70 | board[:, self.agent1.x, self.agent1.y] = COLORS['agent1'] 71 | board[:, self.agent2.x, self.agent2.y] = COLORS['agent2'] 72 | 73 | # Walls 74 | board[:, np.arange(0, self.x_dim), 0] = np.transpose(np.ones(shape=[self.x_dim, 1])*COLORS['walls']) 75 | board[:, np.arange(0, self.x_dim), self.y_dim - 1] = np.transpose(np.ones(shape=[self.x_dim, 1])*COLORS['walls']) 76 | board[:, 0, np.arange(0, self.y_dim)] = np.transpose(np.ones(shape=[self.y_dim, 1])*COLORS['walls']) 77 | board[:, self.x_dim - 1, np.arange(0, self.y_dim)] = np.transpose(np.ones(shape=[self.y_dim, 1])*COLORS['walls']) 78 | board = board/(255.0) 79 | 80 | return np.transpose(board, axes=[2, 1, 0]) 81 | 82 | def generate_next_state(self): 83 | # assume that we are just copying the current apple locations 84 | # print self.apple_locations 85 | # new_apple_locations = np.copyto(np.empty_like(self.apple_locations), self.apple_locations) 86 | new_apple_locations = np.array(self.apple_locations) 87 | new_apple_times = {} 88 | for 
apple in self.apple_times.keys(): 89 | new_apple_times[apple] = self.apple_times[apple] 90 | return GatherState(self.agent1, self.agent2, new_apple_locations, new_apple_times) 91 | 92 | def show(self): 93 | rgb = self.to_rgb() 94 | plt.imshow(rgb) 95 | plt.pause(self.render_time) 96 | plt.draw() 97 | 98 | class GatherAgent(): 99 | 100 | def __init__(self, x, y, is_shining, orientation, hits, frozen_time_remaining): 101 | self.x, self.y, self.is_shining = x, y, is_shining 102 | self.orientation, self.hits = orientation, hits 103 | self.frozen_time_remaining = frozen_time_remaining 104 | 105 | def get_orientation(self): 106 | if self.orientation == 'NORTH': 107 | return self.x, self.y - 1 108 | if self.orientation == 'SOUTH': 109 | return self.x, self.y + 1 110 | if self.orientation == 'WEST': 111 | return self.x - 1, self.y 112 | if self.orientation == 'EAST': 113 | return self.x + 1, self.y 114 | assert False, 'Invalid direction.' 115 | 116 | def get_beam(self, x_dim, y_dim): 117 | assert self.is_shining, 'get_beam called when beam not shining' 118 | orientation = self.get_orientation() 119 | if self.orientation == 'NORTH': 120 | return orientation[0], np.arange(0, orientation[1] + 1), orientation[1] + 1 121 | if self.orientation == 'SOUTH': 122 | return orientation[0], np.arange(orientation[1], y_dim), y_dim - orientation[1] 123 | if self.orientation == 'WEST': 124 | return np.arange(0, orientation[0] + 1), orientation[1], orientation[0] + 1 125 | if self.orientation == 'EAST': 126 | return np.arange(orientation[0], x_dim), orientation[1], x_dim - orientation[0] 127 | assert False, 'Invalid direction.' 128 | 129 | def __hash__(self): 130 | return hash(str(self)) 131 | 132 | def __str__(self): 133 | agentString = ['{:02d}'.format(self.x), '{:02d}'.format(self.y), '1' if self.is_shining else '0', self.orientation, str(self.hits), str(self.frozen_time_remaining)] 134 | return ''.join(agentString) 135 | 136 | def __eq__(self, other): 137 | if not isinstance(other, GatherAgent): 138 | return False 139 | return str(self) == str(other) 140 | 141 | def clone(self): 142 | return GatherAgent(self.x, self.y, self.is_shining, self.orientation, self.hits, self.frozen_time_remaining) 143 | 144 | 145 | if __name__ == '__main__': 146 | agent1 = GatherAgent(32, 6, False, 'NORTH', 0, 0) 147 | agent2 = GatherAgent(31, 5, False, 'NORTH', 0, 0) 148 | agent3 = GatherAgent(5, 6, True, 'NORTH', None, None) 149 | agent4 = GatherAgent(1, 2, True, 'EAST', None, None) 150 | state1 = GatherState(agent1, agent2, np.zeros(shape=[35, 11], dtype=np.int32), {}) 151 | state2 = GatherState(agent3, agent4, np.zeros(shape=[21, 11], dtype=np.int32), {}) 152 | state3 = GatherState(agent3, agent4, np.zeros(shape=[21, 11], dtype=np.int32), {}) 153 | plt.imshow(state1.to_rgb()) 154 | plt.show() 155 | --------------------------------------------------------------------------------
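As a closing illustration, the pieces above compose into a simple lifelong-learning loop: make_mdp_distr builds an MDPDistribution, MDPDistribution.sample draws one task at a time, and a simple_rl agent is run on the sampled task via act and execute_agent_action. The sketch below is illustrative only; the choice of QLearningAgent and the task/step budgets are assumptions, not values taken from this repository's experiment scripts.

# Minimal lifelong-RL sketch (assumptions: QLearningAgent, 5 tasks, 100 steps per task).
from simple_rl.utils.make_mdp import make_mdp_distr
from simple_rl.agents.QLearningAgentClass import QLearningAgent

mdp_distr = make_mdp_distr(mdp_class="four_room", grid_dim=9, gamma=0.99)  # one MDP per goal location
agent = QLearningAgent(actions=mdp_distr.get_actions())

for task in range(5):
    mdp = mdp_distr.sample()                # draw a task according to the per-MDP probabilities
    state, reward = mdp.get_init_state(), 0
    for step in range(100):
        action = agent.act(state, reward)   # Q-learning update + action selection
        reward, state = mdp.execute_agent_action(action)
        if state.is_terminal():
            break
    agent.end_of_episode()                  # clears the agent's previous state/action bookkeeping
    mdp.reset()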