├── .gitignore
├── README.md
├── eligibility_traces.py
├── environment.py
├── main.py
├── q_values.py
├── states_actions.py
├── strategy.py
├── web
    ├── images
    │   ├── finish.png
    │   ├── ghost.png
    │   └── player.png
    ├── index.html
    ├── sarsa.css
    └── sarsa.js
└── web_server.py

/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.iml
3 | .idea
4 | sarsa.json
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # sarsa-lambda
2 |
3 | This is a Python implementation of the SARSA(λ) reinforcement learning algorithm. The algorithm is used to guide a player through a user-defined 'grid world' environment,
4 | inhabited by Hungry Ghosts. Progress can be monitored via the built-in web interface, which continuously runs games
5 | using the latest strategy learnt by the algorithm.
6 |
7 | The algorithm's objective is to obtain the highest possible score for the player. The player's score is increased by discovering the exit from the environment, and is decreased slightly with each move
8 | that is made. A large penalty is applied if the player is caught by one of the ghosts before escaping. The game finishes when
9 | the player reaches an exit or is caught by a ghost. The rewards/penalties associated with each of these
10 | events are [easily configurable](https://github.com/codebox/sarsa-lambda/blob/master/environment.py#L16).
11 |
12 | The video below shows the algorithm's progress while learning a very basic (ghost-free) environment. During the first few games the
13 | player's moves are essentially random; however, after playing about 100 games the player begins to take a reasonably direct route
14 | to the exit. After 1000 games the algorithm has discovered an optimal route.
15 |
16 | [Video: learning a ghost-free environment](https://codebox.net/assets/video/reinforcement-learning-sarsa-lambda/sarsa_blank.webm)
17 |
18 | As would be expected, when tested against more complex environments the algorithm takes much longer to discover the best strategy
19 | (tens or hundreds of thousands of games). In some cases quite ingenious tactics are employed to evade the ghosts, for example
20 | waiting in one location to draw the ghosts down a particular path before taking a different route towards the exit:
21 |
22 | [Video: evading the ghosts](https://codebox.net/assets/video/reinforcement-learning-sarsa-lambda/sarsa_ghosts.webm)
23 |
24 | To run the code for yourself, just clone the project,
25 | draw your own map [in the main.py file](https://github.com/codebox/sarsa-lambda/blob/master/main.py#L15), and
26 | use the following command to let the algorithm start learning:
27 |
28 | ```
29 | python main.py
30 | ```
31 |
32 | Progress is saved and resumed automatically: use Ctrl-C to stop the application, and the next time the code is run it will
33 | continue from where it left off. To monitor progress you can start the web interface like this:
34 |
35 | ```
36 | python web_server.py
37 | ```
38 |
39 | and then [watch the games in your web browser](http://localhost:8080) as they run.
40 |
41 |
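For reference, the update that `strategy.py` applies after every step is, in outline, the SARSA(λ) rule with accumulating eligibility traces (the hyperparameters γ, α, λ and ε are set in `main.py`):

$$\delta = r + \gamma\,Q(s', a') - Q(s, a)$$

$$e(s, a) \leftarrow e(s, a) + 1,\qquad Q(\tilde s, \tilde a) \leftarrow Q(\tilde s, \tilde a) + \alpha\,\delta\,e(\tilde s, \tilde a)\ \text{ for every known pair } (\tilde s, \tilde a)$$

Here r is the reward returned by the environment, a' is the ε-greedy action chosen for the next state, and every trace is scaled down by a decay factor (derived from γ and λ in `new_episode()`) once the Q-values have been adjusted.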
--------------------------------------------------------------------------------
/eligibility_traces.py:
--------------------------------------------------------------------------------
1 | from states_actions import StatesAndActions
2 |
3 |
4 | class EligibilityTraces:
5 |     def __init__(self, decay_rate):
6 |         self.decay_rate = decay_rate
7 |         self.values = StatesAndActions()
8 |
9 |     def decay(self, state, action):
10 |         self.values.update(state, action, lambda v: v * self.decay_rate)
11 |
12 |     def increment(self, state, action):
13 |         self.values.update(state, action, lambda v: v + 1, 1)
14 |
15 |     def get(self, state, action):
16 |         return self.values.get(state, action)
--------------------------------------------------------------------------------
/environment.py:
--------------------------------------------------------------------------------
1 | import re, random
2 |
3 | ACTION_UP = 'U'
4 | ACTION_RIGHT = 'R'
5 | ACTION_DOWN = 'D'
6 | ACTION_LEFT = 'L'
7 |
8 | ACTIONS = (ACTION_UP, ACTION_RIGHT, ACTION_DOWN, ACTION_LEFT)
9 |
10 | STATE_ACTOR = '웃'
11 | STATE_EXIT = 'X'
12 | STATE_BLOCK = '█'
13 | STATE_MONSTER = 'M'
14 | STATE_EMPTY = '.'
15 |
16 | REWARD_MOVEMENT = -1
17 | REWARD_BAD_MOVE = -5
18 | REWARD_MONSTER = -100
19 | REWARD_EXIT = 100
20 |
21 | MONSTER_RANDOMNESS = 0.1
22 |
23 | class Environment:
24 |     actions = ACTIONS
25 |
26 |     def __init__(self, grid_text):
27 |         self.grid = self.__parse_grid_text(grid_text)
28 |         self.height = len(self.grid)
29 |         self.width = len(self.grid[0])
30 |         self.actor_in_terminal_state = False
31 |         self.monsters = []
32 |
33 |         for y in range(self.height):
34 |             for x in range(self.width):
35 |                 content = self.grid[y][x]
36 |
37 |                 if content == STATE_ACTOR:
38 |                     self.actor_pos = Position(x, y)
39 |
40 |                 elif content == STATE_MONSTER:
41 |                     self.monsters.append(Position(x, y))
42 |
43 |     def __parse_grid_text(self, grid_text):
44 |         rows = re.split(r"\s*\n\s*", grid_text.strip())
45 |         return list(map(lambda row: row.split(' '), rows))
46 |
47 |     def get_actor_state(self):
48 |         return str(self.actor_pos)
49 |
50 |     def get(self):
51 |         return self.grid
52 |
53 |     def __position_on_grid(self, pos):
54 |         return (0 <= pos.x < self.width) and (0 <= pos.y < self.height)
55 |
56 |     def __get_valid_monster_moves(self, current_x, current_y):
57 |         compass_directions = [
58 |             Position(current_x + 1, current_y),
59 |             Position(current_x - 1, current_y),
60 |             Position(current_x, current_y + 1),
61 |             Position(current_x, current_y - 1)
62 |         ]
63 |
64 |         random.shuffle(compass_directions)
65 |
66 |         def can_move(pos):
67 |             return self.grid[pos.y][pos.x] in [STATE_EMPTY, STATE_ACTOR]
68 |
69 |         possible_moves = list(filter(lambda pos: self.__position_on_grid(pos) and can_move(pos), compass_directions))
70 |
71 |         possible_moves.sort(key=lambda pos: pos.dist_sq(self.actor_pos))
72 |
73 |         return possible_moves
74 |
75 |     def __move_monsters(self):
76 |         for monster_position in self.monsters:
77 |             current_x = monster_position.x
78 |             current_y = monster_position.y
79 |
80 |             valid_moves = self.__get_valid_monster_moves(current_x, current_y)
81 |
82 |             if len(valid_moves):
83 |                 move_randomly = random.random() < MONSTER_RANDOMNESS
84 |                 if move_randomly:
85 |                     new_pos = random.choice(valid_moves)
86 |                 else:
87 |                     new_pos = valid_moves[0]
88 |
89 |                 self.grid[current_y][current_x] = STATE_EMPTY
90 |                 self.grid[new_pos.y][new_pos.x] = STATE_MONSTER
91 |
92 |                 monster_position.x = new_pos.x
93 |                 monster_position.y = new_pos.y
94 |
95 |                 if new_pos == self.actor_pos:
96 |                     self.actor_in_terminal_state = True
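The grid parsing and reward logic above can be exercised directly; the snippet below is a minimal sketch (not part of the repository) that drives Environment with a small made-up 3x3 map:

```
# Illustrative sketch only: drives Environment directly with a tiny made-up grid.
from environment import Environment, ACTION_LEFT, ACTION_DOWN

grid = """
X . .
. █ .
. . 웃
"""

env = Environment(grid)
print(env.get_actor_state())            # '2,2' - the actor's "x,y" position is the state
print(env.perform_action(ACTION_LEFT))  # -1, REWARD_MOVEMENT: stepped onto an empty cell
print(env.perform_action(ACTION_DOWN))  # -5, REWARD_BAD_MOVE: tried to walk off the grid
print(env.actor_in_terminal_state)      # False - episodes only end at an exit or a monster
```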
97 |
98 |     def __update_environment(self):
99 |         self.__move_monsters()
100 |
101 |     def perform_action(self, action):
102 |         reward = 0
103 |
104 |         actor_requested_pos = self.actor_pos.copy()
105 |
106 |         if action == ACTION_UP:
107 |             actor_requested_pos.up()
108 |
109 |         elif action == ACTION_RIGHT:
110 |             actor_requested_pos.right()
111 |
112 |         elif action == ACTION_DOWN:
113 |             actor_requested_pos.down()
114 |
115 |         elif action == ACTION_LEFT:
116 |             actor_requested_pos.left()
117 |
118 |         else:
119 |             assert False, 'action=' + str(action)
120 |
121 |         if self.__position_on_grid(actor_requested_pos):
122 |             requested_location_contents = self.grid[actor_requested_pos.y][actor_requested_pos.x]
123 |         else:
124 |             requested_location_contents = STATE_BLOCK
125 |
126 |         def move_actor_to_requested_location():
127 |             self.grid[self.actor_pos.y][self.actor_pos.x] = STATE_EMPTY
128 |             self.actor_pos = actor_requested_pos
129 |             self.grid[self.actor_pos.y][self.actor_pos.x] = STATE_ACTOR
130 |
131 |         if requested_location_contents == STATE_BLOCK:
132 |             reward += REWARD_BAD_MOVE
133 |
134 |         elif requested_location_contents == STATE_EMPTY:
135 |             reward += REWARD_MOVEMENT
136 |             move_actor_to_requested_location()
137 |
138 |         elif requested_location_contents == STATE_EXIT:
139 |             reward += REWARD_MOVEMENT + REWARD_EXIT
140 |             move_actor_to_requested_location()
141 |             self.actor_in_terminal_state = True
142 |             print("SUCCESS")  # TODO
143 |
144 |         elif requested_location_contents == STATE_MONSTER:
145 |             reward += REWARD_MOVEMENT + REWARD_MONSTER
146 |             move_actor_to_requested_location()
147 |             self.actor_in_terminal_state = True
148 |
149 |         else:
150 |             assert False, 'requested_location_contents=' + str(requested_location_contents)
151 |
152 |         self.__update_environment()
153 |
154 |         return reward
155 |
156 |
157 | class Position:
158 |     def __init__(self, x, y):
159 |         self.x = x
160 |         self.y = y
161 |
162 |     def dist_sq(self, other):
163 |         return (self.x - other.x) ** 2 + (self.y - other.y) ** 2
164 |
165 |     def copy(self):
166 |         return Position(self.x, self.y)
167 |
168 |     def up(self):
169 |         self.y -= 1
170 |
171 |     def down(self):
172 |         self.y += 1
173 |
174 |     def left(self):
175 |         self.x -= 1
176 |
177 |     def right(self):
178 |         self.x += 1
179 |
180 |     def __eq__(self, other):
181 |         return self.x == other.x and self.y == other.y
182 |
183 |     def __repr__(self):
184 |         return '{},{}'.format(self.x, self.y)
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import signal, sys, json
4 |
5 | from environment import Environment
6 | from strategy import Strategy
7 |
8 | EPISODE_COUNT = 1000 * 1000
9 | SAVE_INTERVAL = 100
10 | MAX_EPISODE_STEPS = 100000
11 | ENVIRONMENT_HEIGHT = 10
12 | ENVIRONMENT_WIDTH = 10
13 | SAVE_FILE = 'sarsa.json'
14 |
15 | INIT_ENVIRONMENT = """
16 | X . █ . █ . . . . .
17 | . . █ . █ █ █ █ █ 웃
18 | . . █ . . . . . . .
19 | . █ █ . █ . █ █ █ .
20 | M . . . █ . █ . M .
21 | . █ █ . █ . █ █ █ .
22 | . . █ . . . . . . .
23 | █ █ █ █ █ █ █ █ █ .
24 | . . . . . . . . . .
25 | . . . . . . . . █ .
26 | """
27 |
28 | def build_environment():
29 |     return Environment(INIT_ENVIRONMENT)
30 |
31 | def build_strategy():
32 |     γ = 0.99
33 |     α = 0.1
34 |     λ = 0.1
35 |     ε = 0.1
36 |     ε_decay = 1
37 |     return Strategy(γ, α, λ, ε, ε_decay, Environment.actions)
38 |
39 |
40 | def load_from_file(strategy):
41 |     try:
42 |         with open(SAVE_FILE) as f:
43 |             strategy.load(json.load(f))
44 |
45 |         print('Loaded', SAVE_FILE)
46 |
47 |     except Exception:
48 |         pass
49 |
50 | def save_to_file(strategy):
51 |     try:
52 |         with open(SAVE_FILE, 'w') as f:
53 |             json.dump(strategy.dump(), f)
54 |
55 |         # print('Saved', SAVE_FILE)
56 |
57 |     except Exception:
58 |         pass
59 |
60 |
61 | def run_episode(strategy):
62 |     environment = build_environment()
63 |     steps = 0
64 |     total_reward = 0
65 |
66 |     strategy.new_episode()
67 |
68 |     while not environment.actor_in_terminal_state and steps < MAX_EPISODE_STEPS:
69 |         state_before = environment.get_actor_state()
70 |         action = strategy.next_action(state_before)
71 |         reward = environment.perform_action(action)
72 |         state_after = environment.get_actor_state()
73 |         strategy.update(state_before, action, reward, state_after)
74 |         total_reward += reward
75 |         steps += 1
76 |
77 |     return steps, total_reward
78 |
79 |
80 | def save_and_exit(_1, _2):
81 |     save_to_file(strategy)
82 |     sys.exit(0)
83 |
84 | if __name__ == '__main__':
85 |     signal.signal(signal.SIGINT, save_and_exit)  # handle ctrl-c
86 |
87 |     strategy = build_strategy()
88 |     load_from_file(strategy)
89 |
90 |     for episode_index in range(EPISODE_COUNT):
91 |         run_episode(strategy)
92 |         if episode_index > 0 and episode_index % SAVE_INTERVAL == 0:
93 |             save_to_file(strategy)
94 |             print(episode_index)
95 |
96 |     save_to_file(strategy)
97 |
98 |
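As the README notes, the map lives in INIT_ENVIRONMENT. It is just a whitespace-separated grid built from the symbols defined at the top of environment.py (웃 player, X exit, █ wall, M ghost, . empty), with exactly one player cell. A smaller custom layout, purely for illustration, could look like this:

```
INIT_ENVIRONMENT = """
웃 . . █ .
. █ . █ .
. █ . . .
. █ █ █ .
M . . . X
"""
```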
--------------------------------------------------------------------------------
/q_values.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from random import random, choice
4 | from states_actions import StatesAndActions
5 |
6 |
7 | class QValues:
8 |     def __init__(self, all_actions):
9 |         self.all_actions = all_actions
10 |         self.values = StatesAndActions()
11 |
12 |     def get_expected_reward(self, state, action):
13 |         return self.values.get(state, action)
14 |
15 |     def set_expected_reward(self, state, action, reward):
16 |         self.values.set(state, action, reward)
17 |
18 |     def ensure_exists(self, state, action):
19 |         if not self.values.has(state, action):
20 |             self.values.set(state, action)
21 |
22 |     def get_greedy_action(self, state, ε=0):
23 |         if random() < ε:
24 |             actions = self.all_actions
25 |
26 |         else:
27 |             actions_for_state = self.values.get_all_for_state(state, self.all_actions)
28 |             max_val = max(actions_for_state.values())
29 |             actions = [action for action, value in actions_for_state.items() if value == max_val]
30 |
31 |         return choice(actions)
32 |
33 |     def set_all_values(self, values):
34 |         self.values.set_all(values)
35 |
36 |     def get_all_values(self):
37 |         return self.values.get_all()
38 |
39 |     def for_each(self, fn):
40 |         self.values.for_each(fn)
--------------------------------------------------------------------------------
/states_actions.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | DEFAULT_VALUE = 0
4 |
5 |
6 | class StatesAndActions:
7 |     def __init__(self):
8 |         self.values = {}
9 |
10 |     def get(self, state, action):
11 |         if state in self.values:
12 |             return self.values[state].get(action, DEFAULT_VALUE)
13 |
14 |         return DEFAULT_VALUE
15 |
16 |     def get_all_for_state(self, state, all_actions):
17 |         return {action: self.get(state, action) for action in all_actions}
18 |
19 |     def get_all(self):
20 |         return copy.deepcopy(self.values)
21 |
22 |     def set_all(self, values):
23 |         self.values = copy.deepcopy(values)
24 |
25 |     def set(self, state, action, value=DEFAULT_VALUE):
26 |         if state not in self.values:
27 |             self.values[state] = {}
28 |
29 |         self.values[state][action] = value
30 |
31 |     def has(self, state, action):
32 |         return state in self.values and action in self.values[state]
33 |
34 |     def update(self, state, action, update_fn, value_to_set=None):
35 |         if self.has(state, action):
36 |             old_value = self.get(state, action)
37 |             new_value = update_fn(old_value)
38 |             self.set(state, action, new_value)
39 |
40 |         elif value_to_set is not None:
41 |             self.set(state, action, value_to_set)
42 |
43 |     def for_each(self, fn):
44 |         for state, actions in self.values.items():
45 |             for action in actions:
46 |                 fn(state, action)
--------------------------------------------------------------------------------
/strategy.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from eligibility_traces import EligibilityTraces
4 | from q_values import QValues
5 |
6 |
7 | class Strategy:
8 |     def __init__(self, γ, α, λ, ε, ε_decay, actions):
9 |         self.γ = γ
10 |         self.α = α
11 |         self.λ = λ
12 |         self.ε = ε
13 |         self.ε_decay = ε_decay
14 |         self.actions = actions
15 |         self.eligibility_traces = None
16 |         self.q_values = QValues(actions)
17 |         self.scores = []  # TODO
18 |         self.episode = 0
19 |         self.episode_reward = 0
20 |         self.episode_reward_total = 0  # TODO
21 |
22 |     def new_episode(self):
23 |         self.eligibility_traces = EligibilityTraces(1 - self.γ * self.λ)
24 |         self.ε *= self.ε_decay
25 |         self.episode += 1
26 |         self.episode_reward = 0
27 |
28 |     def next_action(self, state, ε=None):
29 |         return self.q_values.get_greedy_action(state, self.ε if ε is None else ε)
30 |
31 |     def update(self, state_before, action, reward, state_after):
32 |         expected_reward = self.q_values.get_expected_reward(state_before, action)
33 |         next_action = self.q_values.get_greedy_action(state_after, self.ε)
34 |         next_expected_reward = self.q_values.get_expected_reward(state_after, next_action)
35 |
36 |         td_error = reward + self.γ * next_expected_reward - expected_reward
37 |
38 |         self.eligibility_traces.increment(state_before, action)
39 |         self.q_values.ensure_exists(state_before, action)
40 |
41 |         def update_q_values(state, action):
42 |             old_expected_reward = self.q_values.get_expected_reward(state, action)
43 |             new_expected_reward = old_expected_reward + self.α * td_error * self.eligibility_traces.get(state, action)
44 |             self.q_values.set_expected_reward(state, action, new_expected_reward)
45 |             self.eligibility_traces.decay(state, action)
46 |
47 |         self.q_values.for_each(update_q_values)
48 |         self.episode_reward += reward
49 |
50 |     def load(self, values):
51 |         self.q_values.set_all_values(values['q'])
52 |         self.ε = values['ε']
53 |         self.scores = values['scores']
54 |         self.episode = values['episode']
55 |
56 |     def dump(self):
57 |         return {
58 |             'q': self.q_values.get_all_values(),
59 |             'ε': self.ε,
60 |             'scores': self.scores,
61 |             'episode': self.episode
62 |         }
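The sarsa.json file mentioned in the README and .gitignore is simply dump() passed through json.dump(): a 'q' table keyed by state and then action, plus ε, the scores list and the episode counter. A minimal sketch (not part of the repository) for inspecting a previously saved run:

```
# Illustrative sketch only: loads an existing sarsa.json and prints a few summary figures.
import json

from environment import Environment
from strategy import Strategy

strategy = Strategy(0.99, 0.1, 0.1, 0.1, 1, Environment.actions)  # same hyperparameters as main.py
with open('sarsa.json') as f:
    strategy.load(json.load(f))

q = strategy.q_values.get_all_values()
print('episodes completed:', strategy.episode)
print('states with learned Q-values:', len(q))
```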
--------------------------------------------------------------------------------
/web/images/finish.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codebox/sarsa-lambda/952cc5f7f463fed0485fa6da5a5304f4c3e423fc/web/images/finish.png
--------------------------------------------------------------------------------
/web/images/ghost.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codebox/sarsa-lambda/952cc5f7f463fed0485fa6da5a5304f4c3e423fc/web/images/ghost.png
--------------------------------------------------------------------------------
/web/images/player.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codebox/sarsa-lambda/952cc5f7f463fed0485fa6da5a5304f4c3e423fc/web/images/player.png
--------------------------------------------------------------------------------
/web/index.html:
--------------------------------------------------------------------------------
[markup not preserved in this dump; only the page title "SARSA-λ", a "SARSA-λ" heading and a "Score:" label survive]
--------------------------------------------------------------------------------
/web/sarsa.css:
--------------------------------------------------------------------------------
1 | body {
2 |     font-family: monospace;
3 | }
4 | table {
5 |     display: inline-block;
6 |     border: 1px solid black;
7 |     margin: 1em;
8 | }
9 | td {
10 |     width: 30px;
11 |     height: 30px;
12 |     text-align: center;
13 | }
14 | td img {
15 |     width: 100%;
16 | }
--------------------------------------------------------------------------------
/web/sarsa.js:
--------------------------------------------------------------------------------
1 | function el(id) {
2 |     return document.getElementById(id);
3 | }
4 |
5 | const REFRESH_INTERVAL_MILLIS = 1000,
6 |     UPDATE_STATE_INTERVAL_MILLIS = 100;
7 |
8 | let gameInterval;
9 |
10 | function updateUi(state) {
11 |     const container = el('grid'),
12 |         score = el('score');
13 |
14 |
15 |     function buildTable(container, data) {
[lines 16 onwards not preserved in this dump; the surviving fragments show buildTable() assembling an HTML table from the grid data]
--------------------------------------------------------------------------------
/web_server.py:
--------------------------------------------------------------------------------