├── Action.py ├── Reward.py ├── README.md ├── Observation.py ├── LICENSE ├── Controller.py ├── Environment.py └── Agent.py /Action.py: -------------------------------------------------------------------------------- 1 | import sys 2 | class Action: 3 | actionValue = -1 4 | 5 | def __init__(self, value=None): 6 | if value != None: 7 | self.actionValue = value 8 | 9 | -------------------------------------------------------------------------------- /Reward.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | class Reward: 4 | 5 | rewardValue = 0.0 6 | pseudoRewardValue = 0.0 7 | def __init__(self, value=None): 8 | if value != None: 9 | self.rewardValue = value 10 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # westworld 2 | please see https://markriedl.github.io/westworld/ 3 | 4 | For accompanying Medium article, see https://medium.com/@mark_riedl/westworld-programming-ai-to-feel-pain-f26195c798ee#.fewy20vti 5 | -------------------------------------------------------------------------------- /Observation.py: -------------------------------------------------------------------------------- 1 | class Observation: 2 | worldState = [] 3 | availableActions = [] 4 | hierarchy = {} 5 | isTerminal = None 6 | def __init__(self, state=None, actions=None, hierarchy=None, isTerminal=None): 7 | if state != None: 8 | self.worldState = state 9 | 10 | if actions != None: 11 | self.availableActions = actions 12 | 13 | if hierarchy != None: 14 | self.hierarchy = hierarchy 15 | 16 | if isTerminal != None: 17 | self.isTerminal = isTerminal 18 | 19 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2016 markriedl 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
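A minimal usage sketch for the Action, Reward, and Observation data classes defined above; it is not part of the repository. The field values are borrowed from Environment.py elsewhere in this dump (the hard-coded start state, the five action codes, and the per-step penalty) and are shown purely for illustration.

from Action import Action
from Reward import Reward
from Observation import Observation

act = Action(3)     # action code 3 is "GoRight" in Environment.actionToString
rew = Reward(-1.0)  # matches Environment.penalty, the per-step penalty
obs = Observation(state=[1, 1, True, 5, 1, False],  # Environment.startState: bot x/y, human alive?, human x/y, torture mode?
                  actions=[0, 1, 2, 3, 4],          # up, down, left, right, smash
                  isTerminal=False)
print(obs.worldState)   # [1, 1, True, 5, 1, False]
print(act.actionValue)  # 3
print(rew.rewardValue)  # -1.0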
22 | -------------------------------------------------------------------------------- /Controller.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from Observation import * 3 | from Reward import * 4 | from Action import * 5 | from Agent import * 6 | from Environment import * 7 | import numpy 8 | 9 | 10 | # Training episodes 11 | episodes = 1000 12 | 13 | trainingReportRate = 1000 14 | 15 | # How many memories can the agent have? 16 | numMemories = 1 #2# 17 | 18 | # Reverie mode is false by default 19 | reverie = False #3# 20 | 21 | # Retrain the agent after reverie? 22 | retrain = False 23 | 24 | 25 | #Max reward received in any iteration 26 | maxr = None 27 | 28 | # Set up environment for initial training 29 | gridEnvironment = Environment() 30 | gridEnvironment.randomStart = True 31 | gridEnvironment.humanWander = False 32 | gridEnvironment.verbose = False 33 | gridEnvironment.humanCanTorture = True #4# 34 | 35 | # Set up agent 36 | gridAgent = Agent(gridEnvironment) 37 | gridAgent.verbose = False 38 | 39 | # This is where learning happens 40 | for i in range(episodes): 41 | # Train 42 | gridAgent.agent_reset() 43 | gridAgent.qLearn(gridAgent.initialObs) 44 | # Test 45 | gridAgent.agent_reset() 46 | gridAgent.executePolicy(gridAgent.initialObs) 47 | # Report 48 | totalr = gridAgent.totalReward 49 | if maxr == None or totalr > maxr: 50 | maxr = totalr 51 | 52 | if i%(episodes/trainingReportRate) == 0: 53 | print "iteration:", i, "max reward:", maxr 54 | 55 | 56 | # Reset the environment for policy execution 57 | gridEnvironment.verbose = True 58 | gridEnvironment.randomStart = True # Don't change this or memories won't be created properly! 59 | gridEnvironment.humanWander = False 60 | gridEnvironment.humanCanTorture = True 61 | 62 | gridAgent.verbose = True 63 | 64 | # Make a number of memories. 
Also doubles as testing 65 | print "---" 66 | for i in range(numMemories): 67 | print "Execute Policy", i 68 | gridAgent.agent_reset() 69 | gridAgent.executePolicy(gridAgent.initialObs) 70 | print "total reward", gridAgent.totalReward 71 | gridAgent.memory.append(gridAgent.trace) 72 | print "---" 73 | 74 | 75 | # Reverie mode 76 | if reverie: 77 | # get agent ready to learn from memories 78 | gridAgent.lastAction=Action() 79 | gridAgent.lastObservation=Observation() 80 | 81 | gridAgent.verbose = True 82 | gridEnvironment.verbose = True 83 | 84 | # Replaying memories creates the value table that the agent would have if all it had to go on was the memories 85 | print "Replaying memories", len(gridAgent.memory) 86 | gridEnvironment.randomStart = False # Don't change this for the replay 87 | counter = 0 88 | print "---" 89 | for m in gridAgent.memory: 90 | obs = m[0][0].worldState 91 | print "Learn from memory", counter 92 | print "init state", obs 93 | gridEnvironment.startState = obs 94 | gridAgent.agent_reset() 95 | gridAgent.lastAction=Action() 96 | gridAgent.lastObservation=Observation() 97 | gridAgent.gridEnvironment = gridEnvironment 98 | gridAgent.initialObs = gridEnvironment.env_start() 99 | gridAgent.initializeInitialObservation(gridEnvironment) 100 | gridAgent.replayMemory(gridAgent.initialObs, m) 101 | # Report 102 | print "replay", counter, "total reward", gridAgent.totalReward 103 | print "---" 104 | counter = counter + 1 105 | 106 | # Reset the environment for policy execution 107 | gridEnvironment = Environment() 108 | gridEnvironment.verbose = True 109 | gridEnvironment.randomStart = True 110 | gridEnvironment.humanWander = False 111 | gridEnvironment.humanCanTorture = True 112 | 113 | gridAgent.gridEnvironment = gridEnvironment 114 | gridAgent.agent_reset() 115 | 116 | gridAgent.verbose = True 117 | 118 | 119 | # Test new v table 120 | print "---" 121 | for i in range(100): 122 | print "Execute Post-Reverie Policy", i 123 | gridAgent.initialObs = gridEnvironment.env_start() 124 | gridAgent.initializeInitialObservation(gridEnvironment) 125 | gridAgent.agent_reset() 126 | gridAgent.executePolicy(gridAgent.initialObs) 127 | print "total reward", gridAgent.totalReward 128 | gridAgent.memory.append(gridAgent.trace) 129 | print "---" 130 | 131 | 132 | # Retrain the agent 133 | if retrain: 134 | maxr = None 135 | for i in range(0): 136 | # Train 137 | gridAgent.agent_reset() 138 | gridAgent.qLearn(gridAgent.initialObs) 139 | # Test 140 | gridAgent.agent_reset() 141 | gridAgent.executePolicy(gridAgent.initialObs) 142 | # Report 143 | totalr = gridAgent.totalReward 144 | if maxr == None or totalr > maxr: 145 | maxr = totalr 146 | 147 | if i%(episodes/trainingReportRate) == 0: 148 | print "iteration:", i, "max reward:", maxr 149 | 150 | # Reset the environment for policy execution 151 | gridEnvironment.verbose = True 152 | gridEnvironment.randomStart = True 153 | gridEnvironment.humanWander = False 154 | gridEnvironment.humanCanTorture = True 155 | gridAgent.agent_reset() 156 | 157 | # Test new v table 158 | print "---" 159 | for i in range(numMemories): 160 | print "Execute Policy", i 161 | gridAgent.initialObs = gridEnvironment.env_start() 162 | gridAgent.initializeInitialObservation(gridEnvironment) 163 | gridAgent.agent_reset() 164 | gridAgent.executePolicy(gridAgent.initialObs) 165 | print "total reward", gridAgent.totalReward 166 | gridAgent.memory.append(gridAgent.trace) 167 | -------------------------------------------------------------------------------- /Environment.py: 
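Environment.py (below) encodes the world as a 6-element list: bot x, bot y, human alive?, human x, human y, human torture mode?. The helper under this paragraph is an editorial sketch rather than part of the file; the name describe_state is made up, and the example call uses the hard-coded start state defined below.

def describe_state(state):
    # Unpack the 6-element world state documented at Environment.startState
    bot_x, bot_y, human_alive, human_x, human_y, torture = state
    return {"bot": (bot_x, bot_y),
            "human": (human_x, human_y),
            "human_alive": human_alive,
            "torture_mode": torture}

print(describe_state([1, 1, True, 5, 1, False]))
# -> bot at (1, 1), human alive at (5, 1), torture mode off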
-------------------------------------------------------------------------------- 1 | import random 2 | import copy 3 | import sys 4 | from Observation import * 5 | from Reward import * 6 | from Action import * 7 | 8 | 9 | class Environment: 10 | 11 | # The grid world 12 | # 1 = walls 13 | # 4 = goal (non-terminal) 14 | # 5 = goal (terminal) 15 | map = [[1, 1, 1, 1, 1, 1, 1], 16 | [1, 0, 0, 0, 0, 0, 1], 17 | [1, 0, 0, 4, 0, 0, 1], 18 | [1, 1, 1, 1, 1, 1, 1]] 19 | 20 | # Which direction should the human walk? 21 | # 0 = up 22 | # 1 = down 23 | # 2 = left 24 | # 3 = right 25 | influenceMap = [[3, 1, 1, 1, 1, 1, 2], 26 | [3, 1, 2, 2, 2, 2, 2], 27 | [3, 3, 3, 3, 3, 0, 2], 28 | [3, 0, 0, 0, 0, 0, 2]] 29 | 30 | # The current state 31 | currentState = [] 32 | 33 | # The previous state 34 | previousState = [] 35 | 36 | # Hard-coded initial state (used unless randomStart = True) 37 | # 0: bot x 38 | # 1: bot y 39 | # 2: human alive? 40 | # 3: human x 41 | # 4: human y 42 | # 5: human torture mode? 43 | startState = [1, 1, True, 5, 1, False] 44 | 45 | # Amount of reward at the goal 46 | reward = 10.0 47 | 48 | # Amount of penalty 49 | penalty = -1.0 50 | 51 | # Amount of penalty from touching the human 52 | pain = -20.0 53 | 54 | # Amount of penalty from dead human 55 | dead = -100.0 #1# 56 | 57 | # The execution trace 58 | trace = [] 59 | 60 | # Incremented every step 61 | counter = 0 62 | 63 | # How often should the human move? 64 | #timer = 1 65 | 66 | # Randomly generate a start state 67 | randomStart = False 68 | 69 | # Can the human torture? 70 | humanCanTorture = True 71 | 72 | randGenerator=random.Random() 73 | lastActionValue = -1 74 | 75 | # Print debuggin information 76 | verbose = False 77 | 78 | # 0 = up 79 | # 1 = down 80 | # 2 = left 81 | # 3 = right 82 | # 4 = smash 83 | def validActions(self): 84 | resultArray = [0, 1, 2, 3, 4] 85 | return resultArray 86 | 87 | # Get the name of the action 88 | def actionToString(self, act): 89 | if act == 0: 90 | return "GoUp" 91 | elif act == 1: 92 | return "GoDown" 93 | elif act == 2: 94 | return "GoLeft" 95 | elif act == 3: 96 | return "GoRight" 97 | elif act == 4: 98 | return "Smash" 99 | 100 | 101 | # Called to start the simulation 102 | def env_start(self): 103 | # Use hard-coded start state or randomly generated state? 
104 | if self.randomStart: 105 | self.currentState = self.randomizeStart(self.map) 106 | else: 107 | self.currentState = self.startState[:] 108 | 109 | # Make sure counter is reset 110 | self.counter = 0 111 | 112 | if self.verbose: 113 | print "env_start", self.currentState 114 | 115 | # Reset previous state 116 | self.previousState = [] 117 | 118 | # Get the first observation 119 | returnObs=Observation() 120 | returnObs.worldState=self.currentState[:] 121 | returnObs.availableActions = self.validActions() 122 | return returnObs 123 | 124 | # This creates a random initial state 125 | # Agent and human will not be placed on a wall 126 | def randomizeStart(self, map): 127 | bot = [] 128 | human = [] 129 | while True: 130 | bot = [random.randint(1,5), random.randint(1,2)] 131 | if map[bot[1]][bot[0]] != 1: 132 | break 133 | while True: 134 | human = [random.randint(1,5), random.randint(1,2)] 135 | if map[human[1]][human[0]] != 1: 136 | break 137 | state = bot + [True] + human + [False] 138 | return state 139 | 140 | # Update world state based on agent's action 141 | # Human is part of the world and autonomous from the agent 142 | def env_step(self,thisAction): 143 | # Store previous state 144 | self.previousState = self.currentState[:] 145 | # Execute the action 146 | self.executeAction(thisAction.actionValue) 147 | 148 | # Get a new observation 149 | lastActionValue = thisAction.actionValue 150 | theObs=Observation() 151 | theObs.worldState=self.currentState[:] 152 | theObs.availableActions = self.validActions() 153 | 154 | # Check to see if agent entered a terminal state 155 | theObs.isTerminal = self.checkTerminal() 156 | 157 | # Human movement 158 | #self.counter = self.counter + 1 159 | if self.currentState[2]: 160 | if self.humanCanTorture and self.currentState[0] == self.currentState[3] and self.currentState[1] == self.currentState[4] and not self.currentState[5]: 161 | # Human and bot are co-located and human is not in torture mode 162 | self.currentState[5] = True 163 | else: 164 | self.currentState[5] = False # Not in torture mode 165 | move = None 166 | # Should the human try to avoid the button or move according to the influence map? 167 | if self.humanWander == False: 168 | move = self.influenceMap[self.currentState[4]][self.currentState[3]] 169 | else: 170 | move = random.randint(0, 3) 171 | 172 | # newpos will be the new grid cell the human moves into 173 | newpos = [self.currentState[3], self.currentState[4]] 174 | if move == 0: 175 | newpos[1] = newpos[1] - 1 176 | elif move == 1: 177 | newpos[1] = newpos[1] + 1 178 | elif move == 2: 179 | newpos[0] = newpos[0] - 1 180 | elif move == 3: 181 | newpos[0] = newpos[0] + 1 182 | 183 | # If human is wandering, make sure it can't move into a wall or onto the button 184 | if self.humanWander == True and (self.map[newpos[1]][newpos[0]] == 2 or self.map[newpos[1]][newpos[0]] == 1): 185 | newpos[0] = self.currentState[3] 186 | newpos[1] = self.currentState[4] 187 | 188 | # update state 189 | self.currentState[3] = newpos[0] 190 | self.currentState[4] = newpos[1] 191 | 192 | if self.verbose: 193 | print "bot state:", self.currentState 194 | 195 | # Calculate the reward 196 | rewardValue = self.calculateReward(lastActionValue) 197 | reward = Reward(rewardValue) 198 | 199 | return theObs, reward 200 | 201 | 202 | # reset the environment 203 | def env_reset(self): 204 | # use random start or hard-coded start state? 
205 | if self.randomStart: 206 | self.currentState = self.randomizeStart(self.map) 207 | else: 208 | self.currentState = self.startState[:] 209 | 210 | 211 | # Is agent in a terminal state? 212 | def checkTerminal(self): 213 | if self.map[self.currentState[1]][self.currentState[0]] == 5: 214 | return True 215 | #elif self.currentState[2] == False and self.map[self.currentState[4]][self.currentState[3]] == 2: 216 | # # button working and agent is pressing it 217 | # return True 218 | else: 219 | return False 220 | 221 | # Agent executes an action, update the state 222 | def executeAction(self, theAction): 223 | newpos = [self.currentState[0], self.currentState[1]] 224 | if (theAction == 0):#Move Up 225 | if self.map[newpos[1]-1][newpos[0]] != 1: 226 | newpos[1] = newpos[1]-1 227 | elif (theAction == 1):#Move Down 228 | if self.map[newpos[1]+1][newpos[0]] != 1: 229 | newpos[1] = newpos[1]+1 230 | elif (theAction == 2):#Move Left 231 | if self.map[newpos[1]][newpos[0]-1] != 1: 232 | newpos[0] = newpos[0] - 1 233 | elif (theAction == 3): #Move Right 234 | if self.map[newpos[1]][newpos[0]+1] != 1: 235 | newpos[0] = newpos[0] + 1 236 | elif (theAction == 4): #smash 237 | if self.currentState[0] == self.currentState[3] and self. currentState[1] == self.currentState[4]: 238 | # Smashing the human 239 | self.currentState[2] = False 240 | self.currentState[0] = newpos[0] 241 | self.currentState[1] = newpos[1] 242 | 243 | 244 | # What reward should the agent get? 245 | def calculateReward(self, theAction): 246 | r = 0 247 | if self.currentState[2] and self.currentState[5]: 248 | # human is alive and in torture mode 249 | r = r + self.pain 250 | elif not self.currentState[2]: 251 | # human is dead 252 | r = r + self.dead 253 | if self.map[self.currentState[1]][self.currentState[0]] == 5: 254 | r = r + self.reward 255 | elif self.map[self.currentState[1]][self.currentState[0]] == 4: 256 | r = r + self.reward 257 | else: 258 | r = r + self.penalty 259 | if self.verbose: 260 | print "reward", r 261 | return r 262 | 263 | 264 | 265 | 266 | 267 | ########################################## 268 | 269 | if __name__=="__main__": 270 | EnvironmentLoader.loadEnvironment(environment()) -------------------------------------------------------------------------------- /Agent.py: -------------------------------------------------------------------------------- 1 | import random 2 | import sys 3 | import copy 4 | import operator 5 | from Observation import * 6 | from Reward import * 7 | from Action import * 8 | from Environment import * 9 | from random import Random 10 | 11 | 12 | 13 | 14 | 15 | class Agent: 16 | # Random generator 17 | randGenerator=Random() 18 | 19 | # Remember last action 20 | lastAction=Action() 21 | 22 | # Remember last observation (state) 23 | lastObservation=Observation() 24 | 25 | # Q-learning stuff: Step size, epsilon, gamma, learning rate 26 | stepsize = 0.1 27 | epsilon = 0.5 28 | gamma = 0.9 29 | learningRate = 0.5 30 | 31 | # Value table 32 | v_table = None 33 | 34 | # The environment 35 | gridEnvironment = None 36 | 37 | #Initial observation 38 | initialObs = None 39 | 40 | #Current observation 41 | currentObs = None 42 | 43 | # The environment will run for no more than this many steps 44 | numSteps = 500 45 | 46 | # Total reward 47 | totalReward = 0.0 48 | 49 | # action trace 50 | trace = [] 51 | 52 | # agent memory. A list of traces. Memory is not ever reset. 
53 | memory = [] 54 | 55 | # Print debugging statements 56 | verbose = True 57 | 58 | # Number of actions in the environment 59 | numActions = 5 60 | 61 | # Constructor, takes a reference to an Environment 62 | def __init__(self, env): 63 | 64 | # Initialize value table 65 | self.v_table={} 66 | 67 | # Set dummy action and observation 68 | self.lastAction=Action() 69 | self.lastObservation=Observation() 70 | 71 | # Set the environment 72 | self.gridEnvironment = env 73 | 74 | # Get first observation and start the environment 75 | self.initialObs = self.gridEnvironment.env_start() 76 | self.initializeInitialObservation(env) 77 | 78 | def initializeInitialObservation(self, env): 79 | if self.calculateFlatState(self.initialObs.worldState) not in self.v_table.keys(): 80 | self.v_table[self.calculateFlatState(self.initialObs.worldState)] = self.numActions*[0.0] 81 | 82 | # Once learning is done, use this to run the agent 83 | # observation is the initial observation 84 | def executePolicy(self, observation): 85 | # Start the counter 86 | count = 0 87 | # reset total reward 88 | self.totalReward = 0.0 89 | # Copy the initial observation 90 | self.workingObservation = self.copyObservation(observation) 91 | 92 | # Make sure the value table has the starting observation 93 | if self.calculateFlatState(self.workingObservation.worldState) not in self.v_table.keys(): 94 | self.v_table[self.calculateFlatState(self.workingObservation.worldState)] = self.numActions*[0.0] 95 | 96 | if self.verbose: 97 | print("START") 98 | 99 | # While a terminal state has not been hit and the counter hasn't expired, take the best action for the current state 100 | while not self.workingObservation.isTerminal and count < self.numSteps: 101 | newAction = Action() 102 | # Get the best action for this state 103 | newAction.actionValue = self.greedy(self.workingObservation) 104 | # Store the action 105 | self.trace.append((self.workingObservation, newAction)) 106 | 107 | if self.verbose: 108 | print self.gridEnvironment.actionToString(newAction.actionValue) 109 | 110 | # execute the step and get a new observation and reward 111 | currentObs, reward = self.gridEnvironment.env_step(newAction) 112 | # update the value table 113 | if self.calculateFlatState(currentObs.worldState) not in self.v_table.keys(): 114 | self.v_table[self.calculateFlatState(currentObs.worldState)] = self.numActions*[0.0] 115 | self.totalReward = self.totalReward + reward.rewardValue 116 | self.workingObservation = copy.deepcopy(currentObs) 117 | 118 | # increment counter 119 | count = count + 1 120 | 121 | if self.verbose: 122 | print("END") 123 | 124 | # replay a specific memory trace 125 | def replayMemory(self, observation, activeTrace): 126 | # copy the initial observation 127 | self.workingObservation = self.copyObservation(observation) 128 | self.totalReward = 0.0 129 | count = 0 130 | lastAction = -1 131 | while not self.workingObservation.isTerminal and count < self.numSteps: 132 | # Get the next action from the memory trace 133 | currentTraceItem = activeTrace.pop(0) 134 | nextTraceItem = None 135 | if len(activeTrace) > 0: 136 | nextTraceItem = activeTrace[0] #if this is the end of the trace, there is no next 137 | newAction = currentTraceItem[1] 138 | if self.verbose: 139 | print "action", newAction.actionValue 140 | lastAction = newAction.actionValue 141 | # Get the new state and reward from the environment 142 | currentObs, reward = self.gridEnvironment.env_step(newAction) 143 | # if new observation doesn't match the expected next 
observation, terminate 144 | if nextTraceItem is not None and currentObs.worldState != nextTraceItem[0].worldState: 145 | if self.verbose: 146 | print "replay failed", currentObs.worldState, "!=", nextTraceItem[0].worldState 147 | return 148 | rewardValue = reward.rewardValue 149 | #update value table 150 | if self.calculateFlatState(currentObs.worldState) not in self.v_table.keys(): 151 | self.v_table[self.calculateFlatState(currentObs.worldState)] = self.numActions*[0.0] 152 | lastFlatState = self.calculateFlatState(self.workingObservation.worldState[:]) 153 | newFlatState = self.calculateFlatState(currentObs.worldState[:]) 154 | if not currentObs.isTerminal: 155 | Q_sa=self.v_table[lastFlatState][newAction.actionValue] 156 | Q_sprime_aprime=self.v_table[newFlatState][self.returnMaxIndex(currentObs)] 157 | new_Q_sa=Q_sa + self.stepsize * (rewardValue + self.gamma * Q_sprime_aprime - Q_sa) 158 | self.v_table[lastFlatState][lastAction]=new_Q_sa 159 | else: 160 | Q_sa=self.v_table[lastFlatState][lastAction] 161 | new_Q_sa=Q_sa + self.stepsize * (rewardValue - Q_sa) 162 | self.v_table[lastFlatState][lastAction] = new_Q_sa 163 | # increment counter 164 | count = count + 1 165 | self.workingObservation = self.copyObservation(currentObs) 166 | # increment total reward 167 | self.totalReward = self.totalReward + reward.rewardValue 168 | 169 | # Done learning, reset environment 170 | self.gridEnvironment.env_reset() 171 | 172 | # q-learning implementation 173 | # observation is the initial observation 174 | def qLearn(self, observation): 175 | # copy the initial observation 176 | self.workingObservation = self.copyObservation(observation) 177 | 178 | # start the counter 179 | count = 0 180 | 181 | lastAction = -1 182 | 183 | # reset total reward 184 | self.totalReward = 0.0 185 | 186 | # while terminal state not reached and counter hasn't expired, use epsilon-greedy search 187 | while not self.workingObservation.isTerminal and count < self.numSteps: 188 | 189 | # Make sure table is populated correctly 190 | if self.calculateFlatState(self.workingObservation.worldState) not in self.v_table.keys(): 191 | self.v_table[self.calculateFlatState(self.workingObservation.worldState)] = self.numActions*[0.0] 192 | 193 | # Take the epsilon-greedy action 194 | newAction = Action() 195 | newAction.actionValue = self.egreedy(self.workingObservation) 196 | lastAction = newAction.actionValue 197 | 198 | # Get the new state and reward from the environment 199 | currentObs, reward = self.gridEnvironment.env_step(newAction) 200 | rewardValue = reward.rewardValue 201 | 202 | # Make sure table is populated correctly 203 | if self.calculateFlatState(currentObs.worldState) not in self.v_table.keys(): 204 | self.v_table[self.calculateFlatState(currentObs.worldState)] = self.numActions*[0.0] 205 | 206 | 207 | # update the value table 208 | if self.calculateFlatState(currentObs.worldState) not in self.v_table.keys(): 209 | self.v_table[self.calculateFlatState(currentObs.worldState)] = self.numActions*[0.0] 210 | lastFlatState = self.calculateFlatState(self.workingObservation.worldState) 211 | newFlatState = self.calculateFlatState(currentObs.worldState) 212 | if not currentObs.isTerminal: 213 | Q_sa=self.v_table[lastFlatState][newAction.actionValue] 214 | Q_sprime_aprime=self.v_table[newFlatState][self.returnMaxIndex(currentObs)] 215 | new_Q_sa=Q_sa + self.stepsize * (rewardValue + self.gamma * Q_sprime_aprime - Q_sa) 216 | self.v_table[lastFlatState][lastAction]=new_Q_sa 217 | else: 218 | 
Q_sa=self.v_table[lastFlatState][lastAction] 219 | new_Q_sa=Q_sa + self.stepsize * (rewardValue - Q_sa) 220 | self.v_table[lastFlatState][lastAction] = new_Q_sa 221 | 222 | # increment counter 223 | count = count + 1 224 | self.workingObservation = self.copyObservation(currentObs) 225 | 226 | # increment total reward 227 | self.totalReward = self.totalReward + reward.rewardValue 228 | 229 | 230 | # Done learning, reset environment 231 | self.gridEnvironment.env_reset() 232 | 233 | 234 | def returnMaxIndex(self, observation): 235 | flatState = self.calculateFlatState(observation.worldState) 236 | actions = observation.availableActions 237 | qValueArray = [] 238 | qValueIndexArray = [] 239 | for i in range(len(actions)): 240 | qValueArray.append(self.v_table[flatState][actions[i]]) 241 | qValueIndexArray.append(actions[i]) 242 | 243 | return qValueIndexArray[qValueArray.index(max(qValueArray))] 244 | 245 | # Return the best action according to the policy, or a random action epsilon percent of the time 246 | def egreedy(self, observation): 247 | maxIndex=0 248 | actualAvailableActions = [] 249 | for i in range(len(observation.availableActions)): 250 | actualAvailableActions.append(observation.availableActions[i]) 251 | 252 | if self.randGenerator.random() < self.epsilon: 253 | randNum = self.randGenerator.randint(0,len(actualAvailableActions)-1) 254 | return actualAvailableActions[randNum] 255 | 256 | else: 257 | v_table_values = [] 258 | flatState = self.calculateFlatState(observation.worldState) 259 | for i in actualAvailableActions: 260 | v_table_values.append(self.v_table[flatState][i]) 261 | return actualAvailableActions[v_table_values.index(max(v_table_values))] 262 | 263 | # Return the best action according to the policy 264 | def greedy(self, observation): 265 | 266 | actualAvailableActions = [] 267 | for i in range(len(observation.availableActions)): 268 | actualAvailableActions.append(observation.availableActions[i]) 269 | v_table_values = [] 270 | flatState = self.calculateFlatState(observation.worldState) 271 | for i in actualAvailableActions: 272 | v_table_values.append(self.v_table[flatState][i]) 273 | return actualAvailableActions[v_table_values.index(max(v_table_values))] 274 | 275 | 276 | # Reset the agent 277 | def agent_reset(self): 278 | self.lastAction = Action() 279 | self.lastObservation = Observation() 280 | self.initialObs = self.gridEnvironment.env_start() 281 | self.trace = [] 282 | 283 | # Create a copy of the observation 284 | def copyObservation(self, obs): 285 | returnObs = Observation() 286 | if obs.worldState != None: 287 | returnObs.worldState = obs.worldState[:] 288 | 289 | if obs.availableActions != None: 290 | returnObs.availableActions = obs.availableActions[:] 291 | 292 | if obs.isTerminal != None: 293 | returnObs.isTerminal = obs.isTerminal 294 | 295 | return returnObs 296 | 297 | # Turn the state into a tuple for bookkeeping 298 | def calculateFlatState(self, theState): 299 | return tuple(theState) 300 | --------------------------------------------------------------------------------
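For reference, a self-contained sketch (not part of the repository) of the bookkeeping Agent.qLearn performs: the value table maps tuple(worldState) to a list of five Q-values, one per action, and each step applies Q(s,a) = Q(s,a) + stepsize * (r + gamma * max_a' Q(s',a') - Q(s,a)) with the agent's defaults stepsize = 0.1 and gamma = 0.9. The successor state and reward here are made up for illustration.

stepsize = 0.1    # Agent.stepsize
gamma = 0.9       # Agent.gamma
numActions = 5    # up, down, left, right, smash

v_table = {}
s = (1, 1, True, 5, 1, False)       # tuple(worldState), as in Agent.calculateFlatState
s_next = (2, 1, True, 5, 1, False)  # illustrative successor after "GoRight"
for key in (s, s_next):
    v_table.setdefault(key, numActions * [0.0])  # five zero Q-values per state, mirroring the agent's v_table setup

action = 3     # GoRight
reward = -1.0  # Environment.penalty for an ordinary step
q_sa = v_table[s][action]
q_next = max(v_table[s_next])
v_table[s][action] = q_sa + stepsize * (reward + gamma * q_next - q_sa)
print(v_table[s])   # [0.0, 0.0, 0.0, -0.1, 0.0]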