├── .gitignore
├── web
│   ├── images
│   │   ├── finish.png
│   │   ├── ghost.png
│   │   └── player.png
│   ├── index.html
│   ├── sarsa.css
│   └── sarsa.js
├── eligibility_traces.py
├── q_values.py
├── states_actions.py
├── web_server.py
├── strategy.py
├── main.py
├── README.md
└── environment.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.iml
3 | .idea
4 | sarsa.json
5 |
--------------------------------------------------------------------------------
/web/images/finish.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codebox/sarsa-lambda/HEAD/web/images/finish.png
--------------------------------------------------------------------------------
/web/images/ghost.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codebox/sarsa-lambda/HEAD/web/images/ghost.png
--------------------------------------------------------------------------------
/web/images/player.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codebox/sarsa-lambda/HEAD/web/images/player.png
--------------------------------------------------------------------------------
/web/index.html:
--------------------------------------------------------------------------------
 1 | <!DOCTYPE html>
 2 | <html>
 3 | <head>
 4 |     <meta charset="utf-8">
 5 |     <title>SARSA λ</title>
 6 |     <link rel="stylesheet" href="sarsa.css">
 7 | </head>
 8 | <body>
 9 |     <h1>SARSA λ</h1>
10 |     <div id="grid"></div>
11 |     <script src="sarsa.js"></script>
12 | </body>
13 | </html>
--------------------------------------------------------------------------------
/eligibility_traces.py:
--------------------------------------------------------------------------------
1 | from states_actions import StatesAndActions
2 |
3 |
4 | class EligibilityTraces:
5 | def __init__(self, decay_rate):
6 | self.decay_rate = decay_rate
7 | self.values = StatesAndActions()
8 |
 9 |     def decay(self, state, action):
10 |         self.values.update(state, action, lambda v: v * self.decay_rate)  # traces fade multiplicatively
11 | 
12 |     def increment(self, state, action):
13 |         self.values.update(state, action, lambda v: v + 1, 1)  # a visit bumps the trace by 1, creating it at 1 if new
14 |
15 | def get(self, state, action):
16 | return self.values.get(state, action)
17 |
--------------------------------------------------------------------------------
/web/sarsa.css:
--------------------------------------------------------------------------------
 1 | #grid {
2 | font-family: monospace;
3 | border-collapse: collapse;
4 | border-spacing: 0;
5 | }
6 | #grid td {
7 | width: 1em;
8 | background-color: #ebebeb;
9 | margin: 0;
10 | height: 1em;
11 | }
12 | .cell웃 {
13 | background-image: url('images/player.png');
14 | background-size: 100%;
15 | }
16 | .cellM {
17 | background-image: url('images/ghost.png');
18 | background-size: 100%;
19 | }
20 | .cellX {
21 | background-image: url('images/finish.png');
22 | background-size: 100%;
23 | }
24 | .cell█ {
25 |     background-color: black !important;
26 | }
--------------------------------------------------------------------------------
/web/sarsa.js:
--------------------------------------------------------------------------------
1 | "use strict";
2 |
3 | const DELAY_MILLIS = 500;
4 |
5 | function fetchJson(url, callback) {
6 | fetch(url).then(response => response.json()).then(obj => {
7 | setTimeout(() => callback(obj), DELAY_MILLIS);
8 | });
9 | }
10 |
11 | const grid = document.getElementById('grid');
12 |
13 | function drawEnvironment(env) {
14 |     const table = ['<table>'];
15 | 
16 |     env.forEach(row => {
17 |         table.push('<tr>');
18 |         row.forEach(cell => {
19 |             table.push(`<td class="cell${cell}"></td>`);
20 |         });
21 |         table.push('</tr>');
22 |     });
23 |     table.push('</table>');
24 | 
25 |     grid.innerHTML = table.join('');
26 | }
27 |
28 | function init(){
29 | fetchJson('/init', env => {
30 | drawEnvironment(env);
31 | function move() {
32 | fetchJson(`/move/${JSON.stringify(env)}`, response => {
33 |                 env = response.env;
34 | drawEnvironment(env);
35 | if (response.terminal) {
36 | init();
37 | } else {
38 | move();
39 | }
40 | });
41 | }
42 | move();
43 | });
44 | }
45 | init();
46 |
--------------------------------------------------------------------------------
/q_values.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from random import random, choice
4 | from states_actions import StatesAndActions
5 |
6 |
7 | class QValues:
8 | def __init__(self, all_actions):
9 | self.all_actions = all_actions
10 | self.values = StatesAndActions()
11 |
12 | def get_expected_reward(self, state, action):
13 | return self.values.get(state, action)
14 |
15 | def set_expected_reward(self, state, action, reward):
16 | self.values.set(state, action, reward)
17 |
18 | def ensure_exists(self, state, action):
19 | if not self.values.has(state, action):
20 | self.values.set(state, action)
21 |
22 |     def get_greedy_action(self, state, ε=0):  # ε-greedy: explore with probability ε, else pick a best-valued action
23 | if random() < ε:
24 | actions = self.all_actions
25 |
26 | else:
27 | actions_for_state = self.values.get_all_for_state(state, self.all_actions)
28 | max_val = max(actions_for_state.values())
29 | actions = [action for action, value in actions_for_state.items() if value == max_val]
30 |
31 | return choice(actions)
32 |
33 | def set_all_values(self, values):
34 | self.values.set_all(values)
35 |
36 | def get_all_values(self):
37 | return self.values.get_all()
38 |
39 | def for_each(self, fn):
40 | self.values.for_each(fn)
--------------------------------------------------------------------------------
/states_actions.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | DEFAULT_VALUE = 0
4 |
5 |
6 | class StatesAndActions:
7 | def __init__(self):
8 | self.values = {}
9 |
10 | def get(self, state, action):
11 | if state in self.values:
12 | return self.values[state].get(action, DEFAULT_VALUE)
13 |
14 | return DEFAULT_VALUE
15 |
16 |     def get_all_for_state(self, state, all_actions):
17 |         return {action: self.get(state, action) for action in all_actions}
18 |
19 | def get_all(self):
20 | return copy.deepcopy(self.values)
21 |
22 | def set_all(self, values):
23 | self.values = copy.deepcopy(values)
24 |
25 | def set(self, state, action, value=DEFAULT_VALUE):
26 | if state not in self.values:
27 | self.values[state] = {}
28 |
29 | self.values[state][action] = value
30 |
31 | def has(self, state, action):
32 | return state in self.values and action in self.values[state]
33 |
34 | def update(self, state, action, update_fn, value_to_set=None):
35 | if self.has(state, action):
36 | old_value = self.get(state, action)
37 | new_value = update_fn(old_value)
38 | self.set(state, action, new_value)
39 |
40 | elif value_to_set is not None:
41 | self.set(state, action, value_to_set)
42 |
43 | def for_each(self, fn):
44 | for state, actions in self.values.items():
45 | for action in actions:
46 | fn(state, action)
--------------------------------------------------------------------------------
/web_server.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | from main import INIT_ENVIRONMENT, build_strategy, load_from_file
4 | from environment import Environment
5 | from http.server import SimpleHTTPRequestHandler
6 | from socketserver import TCPServer
7 | from urllib.parse import unquote
8 |
9 | TCPServer.allow_reuse_address = True
10 |
11 | class MyRequestHandler(SimpleHTTPRequestHandler):
12 | def do_GET(self):
13 | MOVE_PREFIX = '/move/'
14 |
15 | if self.path == '/init':
16 | environment = Environment(INIT_ENVIRONMENT)
17 | # return self.__send_json({'env' : environment.get(), 'actor' : environment.get_actor_state()})
18 | return self.__send_json(environment.get())
19 |
20 | elif self.path == '/':
21 | self.path = '/web/index.html'
22 |
23 | elif self.path.startswith(MOVE_PREFIX):
24 | env = json.loads(unquote(self.path[len(MOVE_PREFIX):]))
25 | env_text = '\n'.join(map(lambda r: ' '.join(r), env))
26 | environment = Environment(env_text)
27 |
28 | state = environment.get_actor_state()
29 |
30 | strategy = build_strategy()
31 | load_from_file(strategy)
32 | action = strategy.next_action(state, 0)
33 | environment.perform_action(action)
34 |
35 | response = {
36 | 'env' : environment.get(),
37 | 'terminal' : environment.actor_in_terminal_state,
38 | 'stats' : {
39 | 'ε' : strategy.ε,
40 | 'scores' : strategy.scores,
41 | 'episode' : strategy.episode
42 | }
43 | }
44 |
45 | return self.__send_json(response)
46 |
47 | return SimpleHTTPRequestHandler.do_GET(self)
48 |
49 |     def __send_json(self, obj):
50 |         payload = bytes(json.dumps(obj), 'utf-8')
51 |         self.protocol_version = 'HTTP/1.1'
52 |         self.send_response(200, 'OK')
53 |         self.send_header('Content-Type', 'application/json')
54 |         self.send_header('Content-Length', str(len(payload)))  # lets keep-alive clients find the end of the body
55 |         self.end_headers()
56 |         self.wfile.write(payload)
57 | 
58 | 
59 | Handler = MyRequestHandler
60 | server = TCPServer(('0.0.0.0', 8080), Handler)
61 | 
62 | server.serve_forever()
--------------------------------------------------------------------------------
/strategy.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from eligibility_traces import EligibilityTraces
4 | from q_values import QValues
5 |
6 |
7 | class Strategy:
8 | def __init__(self, γ, α, λ, ε, ε_decay, actions):
9 | self.γ = γ
10 | self.α = α
11 | self.λ = λ
12 | self.ε = ε
13 | self.ε_decay = ε_decay
14 | self.actions = actions
15 | self.eligibility_traces = None
16 | self.q_values = QValues(actions)
17 | self.scores = [] # TODO
18 | self.episode = 0
19 | self.episode_reward = 0
20 | self.episode_reward_total = 0 # TODO
21 |
22 | def new_episode(self):
23 |         self.eligibility_traces = EligibilityTraces(1 - self.γ * self.λ)  # traces fade by a factor of (1 - γλ) per update
24 | self.ε *= self.ε_decay
25 | self.episode += 1
26 | self.episode_reward = 0
27 |
28 | def next_action(self, state, ε=None):
29 | return self.q_values.get_greedy_action(state, self.ε if ε is None else ε)
30 |
31 | def update(self, state_before, action, reward, state_after):
32 | expected_reward = self.q_values.get_expected_reward(state_before, action)
33 | next_action = self.q_values.get_greedy_action(state_after, self.ε)
34 | next_expected_reward = self.q_values.get_expected_reward(state_after, next_action)
35 |
36 |         td_error = reward + self.γ * next_expected_reward - expected_reward  # δ = r + γ·Q(s',a') - Q(s,a)
37 |
38 | self.eligibility_traces.increment(state_before, action)
39 | self.q_values.ensure_exists(state_before, action)
40 |
41 |         def update_q_values(state, action):  # apply δ to each known pair, weighted by its eligibility trace
42 | old_expected_reward = self.q_values.get_expected_reward(state, action)
43 | new_expected_reward = old_expected_reward + self.α * td_error * self.eligibility_traces.get(state, action)
44 | self.q_values.set_expected_reward(state, action, new_expected_reward)
45 | self.eligibility_traces.decay(state, action)
46 |
47 | self.q_values.for_each(update_q_values)
48 | self.episode_reward += reward
49 |
50 | def load(self, values):
51 | self.q_values.set_all_values(values['q'])
52 | self.ε = values['ε']
53 | self.scores = values['scores']
54 | self.episode = values['episode']
55 |
56 | def dump(self):
57 | return {
58 | 'q' : self.q_values.get_all_values(),
59 | 'ε' : self.ε,
60 | 'scores' : self.scores,
61 | 'episode' : self.episode
62 | }
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import signal, sys, json
4 |
5 | from environment import Environment
6 | from strategy import Strategy
7 |
 8 | EPISODE_COUNT = 1000 * 1000
 9 | SAVE_INTERVAL = 100
10 | MAX_EPISODE_STEPS = 100000
11 | ENVIRONMENT_HEIGHT = 10
12 | ENVIRONMENT_WIDTH = 10
13 | SAVE_FILE = 'sarsa.json'
14 |
15 | INIT_ENVIRONMENT="""
16 | X . █ . █ . . . . .
17 | . . █ . █ █ █ █ █ 웃
18 | . . █ . . . . . . .
19 | . █ █ . █ . █ █ █ .
20 | M . . . █ . █ . M .
21 | . █ █ . █ . █ █ █ .
22 | . . █ . . . . . . .
23 | █ █ █ █ █ █ █ █ █ .
24 | . . . . . . . . . .
25 | . . . . . . . . █ .
26 | """
27 |
28 | def build_environment():
29 | return Environment(INIT_ENVIRONMENT)
30 |
31 | def build_strategy():
32 | γ = 0.99
33 | α = 0.1
34 | λ = 0.1
35 | ε = 0.1
36 | ε_decay = 1
37 | return Strategy(γ, α, λ, ε, ε_decay, Environment.actions)
38 |
39 |
40 | def load_from_file(strategy):
41 | try:
42 | with open(SAVE_FILE) as f:
43 | strategy.load(json.load(f))
44 |
45 | print('Loaded', SAVE_FILE)
46 |
47 |     except Exception:
48 |         pass  # no valid save file yet - start learning from scratch
49 |
50 | def save_to_file(strategy):
51 | try:
52 | with open(SAVE_FILE, 'w') as f:
53 | json.dump(strategy.dump(), f)
54 |
55 | # print('Saved', SAVE_FILE)
56 |
57 |     except Exception:
58 |         pass  # a failed save is not fatal - keep training
59 |
60 |
61 | def run_episode(strategy):
62 | environment = build_environment()
63 | steps = 0
64 | total_reward = 0
65 |
66 | strategy.new_episode()
67 |
68 | while not environment.actor_in_terminal_state and steps < MAX_EPISODE_STEPS:
69 | state_before = environment.get_actor_state()
70 | action = strategy.next_action(state_before)
71 | reward = environment.perform_action(action)
72 | state_after = environment.get_actor_state()
73 | strategy.update(state_before, action, reward, state_after)
74 | total_reward += reward
75 | steps += 1
76 |
77 | return steps, total_reward
78 |
79 |
80 | def save_and_exit(_1,_2):
81 | save_to_file(strategy)
82 | sys.exit(0)
83 |
84 | if __name__ == '__main__':
85 | signal.signal(signal.SIGINT, save_and_exit) # handle ctrl-c
86 |
87 | strategy = build_strategy()
88 | load_from_file(strategy)
89 |
90 | for episode_index in range(EPISODE_COUNT):
91 | run_episode(strategy)
92 | if episode_index > 0 and episode_index % SAVE_INTERVAL == 0:
93 | save_to_file(strategy)
94 | print(episode_index)
95 |
96 | save_to_file(strategy)
97 |
98 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # sarsa-lambda
2 |
 3 | This is a Python implementation of the SARSA λ reinforcement learning algorithm. The algorithm is used to guide a player through a user-defined 'grid world' environment,
 4 | inhabited by hungry ghosts. Progress can be monitored via the built-in web interface, which continuously runs games
 5 | using the latest strategy learnt by the algorithm.
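
In outline, after every move the algorithm computes a temporal-difference error (how much better or worse the move turned out than expected) and feeds it back to every state/action pair seen recently, in proportion to that pair's eligibility trace. Below is a minimal sketch of one update step, with plain dicts standing in for this project's `QValues` and `EligibilityTraces` classes (names and numbers are illustrative):

```
gamma, alpha, trace_decay = 0.99, 0.1, 0.9
q = {('s1', 'R'): 0.0, ('s2', 'R'): 5.0}    # expected reward per (state, action)
trace = {}                                   # eligibility trace per (state, action)

def sarsa_lambda_step(s1, a1, reward, s2, a2):
    # δ = r + γ·Q(s',a') - Q(s,a): how much better the move went than expected
    td_error = reward + gamma * q.get((s2, a2), 0.0) - q.get((s1, a1), 0.0)
    trace[(s1, a1)] = trace.get((s1, a1), 0.0) + 1   # mark the visited pair as eligible
    for sa in q:                                     # propagate δ to every known pair,
        q[sa] += alpha * td_error * trace.get(sa, 0.0)   # weighted by its trace,
        trace[sa] = trace.get(sa, 0.0) * trace_decay     # then let the trace fade

sarsa_lambda_step('s1', 'R', -1, 's2', 'R')
```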
6 |
7 | The algorithm's objective is to obtain the highest possible score for the player. The player's score is increased by discovering the exit from the environment, and is decreased slightly with each move
8 | that is made. A large negative penalty is applied if the player is caught by one of the ghosts before escaping. The game finishes when
9 | the player reaches an exit, or is caught by a ghost. The rewards/penalties associated with each of these
10 | events are [easily configurable](https://github.com/codebox/sarsa-lambda/blob/master/environment.py#L16).
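
The defaults are defined at the top of `environment.py`:

```
REWARD_MOVEMENT = -1     # small cost for each ordinary move
REWARD_BAD_MOVE = -5     # penalty for walking into a wall or off the grid
REWARD_MONSTER = -100    # caught by a ghost
REWARD_EXIT = 100        # reached the exit
```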
11 |
12 | The video below shows the algorithm's progress in learning a very basic (ghost-free) environment. During the first few games the
13 | player's moves are essentially random; however, after about 100 games the player begins to take a reasonably direct route
14 | to the exit. After 1000 games the algorithm has discovered an optimal route.
15 |
16 | [](https://codebox.net/assets/video/reinforcement-learning-sarsa-lambda/sarsa_blank.webm)
17 |
18 | As would be expected, when tested against more complex environments the algorithm takes much longer to discover the best strategy
19 | (tens or hundreds of thousands of games). In some cases quite ingenious tactics are employed to evade the ghosts, for example
20 | waiting in one location to draw the ghosts down a particular path before taking a different route towards the exit:
21 |
22 | [](https://codebox.net/assets/video/reinforcement-learning-sarsa-lambda/sarsa_ghosts.webm)
23 |
24 | To run the code for yourself, just clone the project,
25 | draw your own map [in the main.py file](https://github.com/codebox/sarsa-lambda/blob/master/main.py#L15) (the map format is described below), and
26 | use the following command to let the algorithm start learning:
27 |
28 | ```
29 | python main.py
30 | ```
31 |
32 | Progress is saved/resumed automatically. Use Ctrl-C to stop the application; the next time the code is run it will
33 | continue from where it left off. To monitor progress, start the web interface like this:
34 |
35 | ```
36 | python web_server.py
37 | ```
38 |
39 | and then [watch the games in your web browser](http://localhost:8080) as they run.
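
Learning state is persisted to `sarsa.json`, whose layout mirrors `Strategy.dump()`: `q` maps each state (the player's `x,y` position) to the expected reward for each of the four actions `U`, `R`, `D` and `L`, alongside the current ε, scores and episode count. The values below are illustrative:

```
{
    "q": {"9,1": {"U": -3.2, "R": -5.1, "D": -2.8, "L": -4.0}},
    "ε": 0.1,
    "scores": [],
    "episode": 4200
}
```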
40 |
41 |
--------------------------------------------------------------------------------
/environment.py:
--------------------------------------------------------------------------------
1 | import re, random
2 |
3 | ACTION_UP = 'U'
4 | ACTION_RIGHT = 'R'
5 | ACTION_DOWN = 'D'
6 | ACTION_LEFT = 'L'
7 |
8 | ACTIONS = (ACTION_UP, ACTION_RIGHT, ACTION_DOWN, ACTION_LEFT)
9 |
10 | STATE_ACTOR = '웃'
11 | STATE_EXIT = 'X'
12 | STATE_BLOCK = '█'
13 | STATE_MONSTER = 'M'
14 | STATE_EMPTY = '.'
15 |
16 | REWARD_MOVEMENT = -1
17 | REWARD_BAD_MOVE = -5
18 | REWARD_MONSTER = -100
19 | REWARD_EXIT = 100
20 |
21 | MONSTER_RANDOMNESS = 0.1
22 |
23 | class Environment:
24 | actions = ACTIONS
25 |
26 | def __init__(self, grid_text):
27 | self.grid = self.__parse_grid_text(grid_text)
28 | self.height = len(self.grid)
29 | self.width = len(self.grid[0])
30 | self.actor_in_terminal_state = False
31 | self.monsters = []
32 |
33 | for y in range(self.height):
34 | for x in range(self.width):
35 | content = self.grid[y][x]
36 |
37 | if content == STATE_ACTOR:
38 | self.actor_pos = Position(x, y)
39 |
40 | elif content == STATE_MONSTER:
41 | self.monsters.append(Position(x, y))
42 |
43 |     def __parse_grid_text(self, grid_text):
44 |         rows = re.split(r"\s*\n\s*", grid_text.strip())
45 |         return [row.split(' ') for row in rows]
46 |
47 | def get_actor_state(self):
48 | return str(self.actor_pos)
49 |
50 | def get(self):
51 | return self.grid
52 |
53 | def __position_on_grid(self, pos):
54 | return (0 <= pos.x < self.width) and (0 <= pos.y < self.height)
55 |
56 | def __get_valid_monster_moves(self, current_x, current_y):
57 | compass_directions = [
58 | Position(current_x + 1, current_y),
59 | Position(current_x - 1, current_y),
60 | Position(current_x, current_y + 1),
61 | Position(current_x, current_y - 1)
62 | ]
63 |
64 |         random.shuffle(compass_directions)  # so the stable sort below breaks ties randomly
65 |
66 | def can_move(pos):
67 | return self.grid[pos.y][pos.x] in [STATE_EMPTY, STATE_ACTOR]
68 |
69 | possible_moves = list(filter(lambda pos: self.__position_on_grid(pos) and can_move(pos), compass_directions))
70 |
71 |         possible_moves.sort(key=lambda pos: pos.dist_sq(self.actor_pos))  # closest to the player first, so moves[0] chases the actor
72 |
73 | return possible_moves
74 |
75 | def __move_monsters(self):
76 | for monster_position in self.monsters:
77 | current_x = monster_position.x
78 | current_y = monster_position.y
79 |
80 | valid_moves = self.__get_valid_monster_moves(current_x, current_y)
81 |
82 |             if valid_moves:
83 | move_randomly = random.random() < MONSTER_RANDOMNESS
84 | if move_randomly:
85 | new_pos = random.choice(valid_moves)
86 | else:
87 | new_pos = valid_moves[0]
88 |
89 | self.grid[current_y][current_x] = STATE_EMPTY
90 | self.grid[new_pos.y][new_pos.x] = STATE_MONSTER
91 |
92 | monster_position.x = new_pos.x
93 | monster_position.y = new_pos.y
94 |
95 | if new_pos == self.actor_pos:
96 | self.actor_in_terminal_state = True
97 |
98 | def __update_environment(self):
99 | self.__move_monsters()
100 |
101 | def perform_action(self, action):
102 | reward = 0
103 |
104 | actor_requested_pos = self.actor_pos.copy()
105 |
106 | if action == ACTION_UP:
107 | actor_requested_pos.up()
108 |
109 | elif action == ACTION_RIGHT:
110 | actor_requested_pos.right()
111 |
112 | elif action == ACTION_DOWN:
113 | actor_requested_pos.down()
114 |
115 | elif action == ACTION_LEFT:
116 | actor_requested_pos.left()
117 |
118 | else:
119 | assert False, 'action=' + str(action)
120 |
121 | if self.__position_on_grid(actor_requested_pos):
122 | requested_location_contents = self.grid[actor_requested_pos.y][actor_requested_pos.x]
123 | else:
124 | requested_location_contents = STATE_BLOCK
125 |
126 | def move_actor_to_requested_location():
127 | self.grid[self.actor_pos.y][self.actor_pos.x] = STATE_EMPTY
128 | self.actor_pos = actor_requested_pos
129 | self.grid[self.actor_pos.y][self.actor_pos.x] = STATE_ACTOR
130 |
131 | if requested_location_contents == STATE_BLOCK:
132 | reward += REWARD_BAD_MOVE
133 |
134 | elif requested_location_contents == STATE_EMPTY:
135 | reward += REWARD_MOVEMENT
136 | move_actor_to_requested_location()
137 |
138 | elif requested_location_contents == STATE_EXIT:
139 | reward += REWARD_MOVEMENT + REWARD_EXIT
140 | move_actor_to_requested_location()
141 | self.actor_in_terminal_state = True
142 | print("SUCCESS") # TODO
143 |
144 | elif requested_location_contents == STATE_MONSTER:
145 | reward += REWARD_MOVEMENT + REWARD_MONSTER
146 | move_actor_to_requested_location()
147 | self.actor_in_terminal_state = True
148 |
149 | else:
150 | assert False, 'requested_location_contents=' + str(requested_location_contents)
151 |
152 | self.__update_environment()
153 |
154 | return reward
155 |
156 |
157 | class Position:
158 | def __init__(self, x, y):
159 | self.x = x
160 | self.y = y
161 |
162 | def dist_sq(self, other):
163 | return (self.x - other.x) ** 2 + (self.y - other.y) ** 2
164 |
165 | def copy(self):
166 | return Position(self.x, self.y)
167 |
168 | def up(self):
169 | self.y -= 1
170 |
171 | def down(self):
172 | self.y += 1
173 |
174 | def left(self):
175 | self.x -= 1
176 |
177 | def right(self):
178 | self.x += 1
179 |
180 | def __eq__(self, other):
181 | return self.x == other.x and self.y == other.y
182 |
183 | def __repr__(self):
184 | return '{},{}'.format(self.x, self.y)
--------------------------------------------------------------------------------