├── LICENSE
├── README.md
├── example.json
├── main.py
├── mdp.py
├── ui.py
└── view_controller.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 Mohit Deshpande

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Overview

This is the code for [this](https://youtu.be/fRmZck1Dakc) video on YouTube by Siraj Raval, part of the Move 37 course at School of AI. GridWorld is a test-bed environment for learning about Markov Decision Processes (MDPs), and this project is meant to be a tool for anyone getting started with them.

## Prerequisites
* Python 3
* `tkinter` for the UI (bundled with most Python 3 installations; a separate package such as `python3-tk` on some Linux distributions)

## Getting Started
The main program to run is `main.py`. It must be supplied with a JSON configuration file to build your GridWorld. See `example.json` for a rudimentary configuration.

Example Usage:
```bash
python3 main.py example.json
```

Click the buttons to step through value iteration or policy iteration one step at a time, in increments of 100 steps, or slowly and continuously. The current policy is shown as arrows at each cell.

## Building Your Own GridWorld
The map descriptor for GridWorld is in the JSON configuration file given to `main.py` on startup. The JSON configuration defines the following parameters:

| Parameter | Type | Description |
| --------- | ---- | ----------- |
| `width` | `int` > 0 | width of GridWorld in cells |
| `height` | `int` > 0 | height of GridWorld in cells |
| `initial_value` | `float` | initial value to populate all cells with |
| `discount` | 0 ≤ `float` ≤ 1 | discount factor for computing future expected rewards |
| `living_cost` | `float` ≤ 0 | per-step reward for every non-terminal state (non-positive, so it acts as a cost of living) |
| `obstacles` | `list` of `[row, col]` indices (`list`s) | cells that we consider to be non-traversable obstacles |
| `transition_distribution` | `dict` with `float`s for `"forward"`, `"left"`, `"right"`, `"backward"` | transition function over adjacent tiles only; the four probabilities must sum to 1! |
| `terminals` | `list` of objects with `"state"` (a `[row, col]` index `list`) and `"reward"` (`float`) | terminal states and their associated rewards |

__There is no input validation, so unexpected behavior will occur for unreasonable values, e.g., a negative width or height. YOU HAVE BEEN WARNED!__
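
If you want to guard against bad configurations yourself, a few `assert`s on the parsed JSON can save debugging time. Here is a minimal sketch (the `validate_config` helper is hypothetical, not part of this repo):

```python
def validate_config(cfg):
    """Illustrative sanity checks for a GridWorld JSON config."""
    assert cfg['width'] > 0 and cfg['height'] > 0, 'width and height must be positive'
    assert 0.0 <= cfg['discount'] <= 1.0, 'discount must be in [0, 1]'
    assert cfg['living_cost'] <= 0.0, 'living_cost must be <= 0'
    dist = cfg['transition_distribution']
    assert abs(sum(dist.values()) - 1.0) < 1e-9, 'transition probabilities must sum to 1'
```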

## License
This project is licensed under the MIT License. In other words, do whatever you want with it! Just remember to give appropriate credit :-)

## Credits
Credits for this code go to [mohitd](https://github.com/mohitd/gridworld). I've merely created a wrapper to get people started.
--------------------------------------------------------------------------------
/example.json:
--------------------------------------------------------------------------------
{
    "width": 4,
    "height": 3,
    "initial_value": 0.0,
    "discount": 0.9,
    "living_cost": 0.0,
    "obstacles": [[1, 1]],
    "transition_distribution": {
        "forward": 0.8,
        "left": 0.1,
        "right": 0.1,
        "backward": 0.0
    },
    "terminals": [
        {
            "state": [0, 3],
            "reward": 1
        },
        {
            "state": [1, 3],
            "reward": -1
        }
    ]
}
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import os
import sys
import json

from view_controller import ViewController

def main():
    if len(sys.argv) < 2:
        print('Must supply a config file!')
        sys.exit(1)
    if not os.path.exists(sys.argv[1]):
        print('Config file must exist!')
        sys.exit(1)

    with open(sys.argv[1]) as f:
        metadata = json.load(f)

    view_controller = ViewController(metadata=metadata)
    view_controller.run()

if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/mdp.py:
--------------------------------------------------------------------------------
import random
from operator import add

MIN_DELTA = 1e-4

class GridMDP(object):
    def __init__(self, metadata):
        self.width = metadata['width']
        self.height = metadata['height']
        self.initial_value = metadata['initial_value']
        self.obstacles = metadata['obstacles']
        self.living_cost = metadata['living_cost']

        self.discount = metadata['discount']
        self.transition_distribution = metadata['transition_distribution']
        self.rewards = {tuple(terminal['state']): terminal['reward'] for terminal in metadata['terminals']}
        self.terminals = list(self.rewards.keys())

        self._init_grid()

        # enumerate state space
        self.states = set()
        for row in range(self.height):
            for col in range(self.width):
                if self.grid[row][col] is not None:
                    self.states.add((row, col))

        # move one tile at a time; (row, col) offsets for down, right, up, left
        self.actions = [(1, 0), (0, 1), (-1, 0), (0, -1)]
        self.num_actions = len(self.actions)

        # initialize values and policy; terminal states take no action
        self.policy = {}
        self.values = {}
        for state in self.states:
            self.values[state] = self.initial_value
            self.policy[state] = None if state in self.terminals else random.choice(self.actions)

    def R(self, state):
        if state in self.terminals:
            return self.rewards[state]
        else:
            # living cost
            return self.living_cost

    def _init_grid(self):
        self.grid = [[self.initial_value for col in range(self.width)] for row in range(self.height)]
        # apply obstacles
        for obstacle in self.obstacles:
            self.grid[obstacle[0]][obstacle[1]] = None

    def _move_forward(self, state, action):
        new_state = tuple(map(add, state, action))
        return new_state if new_state in self.states else state

    def _move_backward(self, state, action):
        # the opposite direction sits two slots away in self.actions
        new_action = self.actions[(self.actions.index(action) + 2) % self.num_actions]
        new_state = tuple(map(add, state, new_action))
        return new_state if new_state in self.states else state

    def _move_left(self, state, action):
        new_action = self.actions[(self.actions.index(action) - 1) % self.num_actions]
        new_state = tuple(map(add, state, new_action))
        return new_state if new_state in self.states else state

    def _move_right(self, state, action):
        new_action = self.actions[(self.actions.index(action) + 1) % self.num_actions]
        new_state = tuple(map(add, state, new_action))
        return new_state if new_state in self.states else state

    def allowed_actions(self, state):
        if state in self.terminals:
            return [None]
        else:
            return self.actions

    def next_state_distribution(self, state, action):
        if action is None:
            # terminal states have no successors; the zero probability makes
            # the expected future value come out to 0
            return [(0.0, state)]
        else:
            return [(self.transition_distribution['forward'], self._move_forward(state, action)),
                    (self.transition_distribution['left'], self._move_left(state, action)),
                    (self.transition_distribution['right'], self._move_right(state, action)),
                    (self.transition_distribution['backward'], self._move_backward(state, action))]

    def update_values(self, values):
        self.values = values

    def update_policy(self, policy):
        self.policy = policy

    def clear(self):
        self._init_grid()
        for state in self.states:
            self.values[state] = self.initial_value
            self.policy[state] = None if state in self.terminals else random.choice(self.actions)

def _expected_value(state, action, values, mdp):
    return sum([prob * values[new_state] for prob, new_state in mdp.next_state_distribution(state, action)])

def values_converged(new_values, old_values):
    sum_abs_diff = sum([abs(new_values[state] - old_values[state]) for state in new_values.keys()])
    return sum_abs_diff < MIN_DELTA

def policy_converged(new_policy, old_policy):
    same_action_for_state = [new_policy[state] == old_policy[state] for state in new_policy.keys()]
    return all(same_action_for_state)
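
# Value iteration repeatedly applies the Bellman optimality backup:
#
#     V_{k+1}(s) = R(s) + discount * max_a sum_{s'} P(s' | s, a) * V_k(s')
#
# where P(s' | s, a) is given by next_state_distribution and the max runs over
# allowed_actions(s). The loop below stops early once successive value
# functions differ by less than MIN_DELTA in total absolute difference.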
def value_iteration(initial_values, mdp, num_iter=100):
    # initialize values
    values = initial_values

    for _ in range(num_iter):
        # We're making a copy so newly updated values don't affect each other
        # within an iteration. In practice, the values converge to the same
        # thing, but the copy is here in case you want to step through the
        # values iteration-by-iteration.
        new_values = dict(values)
        for state in mdp.states:
            new_values[state] = mdp.R(state) + mdp.discount * max([_expected_value(state, action, values, mdp) for action in mdp.allowed_actions(state)])

        if values_converged(new_values, values):
            break

        # update values for next iteration
        values = new_values

    return values

def policy_extraction(values, mdp):
    policy = {}
    for state in mdp.states:
        # we don't need to compute the full mdp.R(state) + mdp.discount * ...
        # since mdp.R(state) and mdp.discount are constant given a state
        expected_values = [_expected_value(state, action, values, mdp) for action in mdp.allowed_actions(state)]
        action_idx, _ = max(enumerate(expected_values), key=lambda ev: ev[1])
        # index into allowed_actions(state) rather than mdp.actions so that
        # terminal states map to None instead of an arbitrary action
        policy[state] = mdp.allowed_actions(state)[action_idx]
    return policy

def policy_evaluation(policy, values, mdp, num_iter=50):
    for _ in range(num_iter):
        for state in mdp.states:
            values[state] = mdp.R(state) + mdp.discount * _expected_value(state, policy[state], values, mdp)

    return values

def policy_iteration(initial_policy, mdp, num_iter=100):
    policy = initial_policy
    values = {state: 0 for state in mdp.states}

    for _ in range(num_iter):
        new_policy = dict(policy)

        values = policy_evaluation(policy, values, mdp)
        unchanged_policy = True
        for state in mdp.states:
            expected_values = [_expected_value(state, action, values, mdp) for action in mdp.allowed_actions(state)]
            action_idx, _ = max(enumerate(expected_values), key=lambda ev: ev[1])
            action = mdp.allowed_actions(state)[action_idx]
            if action != new_policy[state]:
                new_policy[state] = action
                unchanged_policy = False

        policy = new_policy

        if unchanged_policy:
            break

    return policy, values
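
# A minimal sketch of driving this module directly, without the UI. The inline
# metadata mirrors example.json; run `python3 mdp.py` to try it.
if __name__ == '__main__':
    metadata = {
        'width': 4, 'height': 3, 'initial_value': 0.0, 'discount': 0.9,
        'living_cost': 0.0, 'obstacles': [[1, 1]],
        'transition_distribution': {'forward': 0.8, 'left': 0.1, 'right': 0.1, 'backward': 0.0},
        'terminals': [{'state': [0, 3], 'reward': 1}, {'state': [1, 3], 'reward': -1}],
    }
    mdp = GridMDP(metadata)
    values = value_iteration(mdp.values, mdp)
    policy = policy_extraction(values, mdp)
    for state in sorted(mdp.states):
        print(state, round(values[state], 2), policy[state])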
--------------------------------------------------------------------------------
/ui.py:
--------------------------------------------------------------------------------
from tkinter import *
from tkinter import messagebox
import math

CELL_SIZE = 80
CELL_PADDING = 10
ARROW_LENGTH = CELL_SIZE / 2

class GridWorldWindow(object):
    """Manages all of the UI
    """
    def __init__(self, metadata):
        self.window = Tk()
        self.window.title('GridWorld')
        self.window.geometry('{}x{}'.format(1080, 720))

        # extract data from the JSON
        self.grid_width = metadata['width']
        self.grid_height = metadata['height']
        self.obstacles = [tuple(obstacle) for obstacle in metadata['obstacles']]
        self.terminals = [tuple(terminal['state']) for terminal in metadata['terminals']]

        self.canvas_width = metadata['width'] * CELL_SIZE
        self.canvas_height = metadata['height'] * CELL_SIZE

        # create the tkinter IDs for all of the modifiable UI
        self.ids_text = [[0 for col in range(self.grid_width)] for row in range(self.grid_height)]
        self.ids_rect = [[0 for col in range(self.grid_width)] for row in range(self.grid_height)]
        self.ids_arrow = [[0 for col in range(self.grid_width)] for row in range(self.grid_height)]

        self._create_buttons()

        self.canvas = Canvas(self.window, width=self.canvas_width, height=self.canvas_height, bg='black')
        self.canvas.pack(padx=10, pady=10)

        self._create_grid()

    def _create_buttons(self):
        self.frame_value_buttons = Frame(self.window)
        self.frame_value_buttons.pack(padx=5, pady=5)

        self.frame_policy_buttons = Frame(self.window)
        self.frame_policy_buttons.pack(padx=5, pady=5)

        self.frame_reset_buttons = Frame(self.window)
        self.frame_reset_buttons.pack(padx=5, pady=5)

        self.btn_value_iteration_1_step = Button(self.frame_value_buttons, text='1-Step Value Iteration', anchor=W)
        self.btn_value_iteration_1_step.pack(side=LEFT)

        self.btn_value_iteration_100_steps = Button(self.frame_value_buttons, text='100-Step Value Iteration', anchor=E)
        self.btn_value_iteration_100_steps.pack(side=LEFT)

        self.btn_value_iteration_slow = Button(self.frame_value_buttons, text='Slow Value Iteration', anchor=E)
        self.btn_value_iteration_slow.pack(side=LEFT)

        self.btn_policy_iteration_1_step = Button(self.frame_policy_buttons, text='1-Step Policy Iteration', anchor=E)
        self.btn_policy_iteration_1_step.pack(side=LEFT)

        self.btn_policy_iteration_100_steps = Button(self.frame_policy_buttons, text='100-Step Policy Iteration', anchor=E)
        self.btn_policy_iteration_100_steps.pack(side=LEFT)

        self.btn_policy_iteration_slow = Button(self.frame_policy_buttons, text='Slow Policy Iteration', anchor=E)
        self.btn_policy_iteration_slow.pack(side=LEFT)

        self.btn_reset = Button(self.frame_reset_buttons, text='Reset', anchor=E)
        self.btn_reset.pack(side=LEFT)

    def _create_grid(self):
        for row in range(self.grid_height):
            for col in range(self.grid_width):
                if (row, col) in self.obstacles:
                    fill = 'grey'
                    text = None
                else:
                    fill = None
                    text = '0.00'

                self.ids_rect[row][col] = self.canvas.create_rectangle(col * CELL_SIZE, row * CELL_SIZE, (col+1) * CELL_SIZE, (row+1) * CELL_SIZE, fill=fill, outline='white')
                # terminal states get an inner rectangle to mark them
                if (row, col) in self.terminals:
                    self.canvas.create_rectangle(col * CELL_SIZE + CELL_PADDING, row * CELL_SIZE + CELL_PADDING, (col+1) * CELL_SIZE - CELL_PADDING, (row+1) * CELL_SIZE - CELL_PADDING, fill=fill, outline='white')

                self.ids_text[row][col] = self.canvas.create_text(col * CELL_SIZE + CELL_SIZE/2, row * CELL_SIZE + CELL_SIZE/2, text=text, fill='white')
                self.ids_arrow[row][col] = self.canvas.create_line(0, 0, 0, 0, width=2, arrow=LAST, fill='white')

    def _compute_color(self, value):
        # negative values are redder while positive values are greener
        if value == 0:
            return '#000000'
        elif value > 0:
            g = math.floor(255 if value >= 1.0 else value * 256)
            return '#{:02x}{:02x}{:02x}'.format(0, g, 0)
        else:
            r = math.floor(255 if -value >= 1.0 else -value * 256)
            return '#{:02x}{:02x}{:02x}'.format(r, 0, 0)
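
    # A quick sanity check of the color mapping above:
    #   _compute_color(0.5)  -> '#008000'  since math.floor(0.5 * 256) == 128 == 0x80
    #   _compute_color(-1.0) -> '#ff0000'  since magnitudes >= 1 clamp to full intensity
    #   _compute_color(0)    -> '#000000'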

    def show_dialog(self, text):
        messagebox.showinfo('Info', text)

    def update_grid(self, values, policy):
        for state, value in values.items():
            rect_id = self.ids_rect[state[0]][state[1]]
            text_id = self.ids_text[state[0]][state[1]]
            arrow_id = self.ids_arrow[state[0]][state[1]]

            self.canvas.itemconfig(rect_id, fill=self._compute_color(value))
            self.canvas.itemconfig(text_id, text='{:.2f}'.format(value))

            if state not in self.terminals:
                self.canvas.coords(arrow_id,
                    state[1] * CELL_SIZE + CELL_SIZE/2 + policy[state][1] * ARROW_LENGTH - policy[state][1],
                    state[0] * CELL_SIZE + CELL_SIZE/2 + policy[state][0] * ARROW_LENGTH - policy[state][0],
                    state[1] * CELL_SIZE + CELL_SIZE/2 + policy[state][1] * ARROW_LENGTH,
                    state[0] * CELL_SIZE + CELL_SIZE/2 + policy[state][0] * ARROW_LENGTH)

    def clear(self):
        for row in range(self.grid_height):
            for col in range(self.grid_width):
                rect_id = self.ids_rect[row][col]
                text_id = self.ids_text[row][col]
                arrow_id = self.ids_arrow[row][col]

                if (row, col) in self.obstacles:
                    fill = 'grey'
                    text = None
                else:
                    fill = self._compute_color(0)
                    text = '0.00'
                self.canvas.itemconfig(rect_id, fill=fill)
                self.canvas.itemconfig(text_id, text=text)
                self.canvas.coords(arrow_id, 0, 0, 0, 0)

    def run(self):
        # run the UI loop
        mainloop()
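
# A minimal smoke test for the UI in isolation (nothing hooked up to the
# buttons). The inline metadata is illustrative and mirrors example.json.
if __name__ == '__main__':
    metadata = {
        'width': 4, 'height': 3,
        'obstacles': [[1, 1]],
        'terminals': [{'state': [0, 3]}, {'state': [1, 3]}],
    }
    GridWorldWindow(metadata).run()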
--------------------------------------------------------------------------------
/view_controller.py:
--------------------------------------------------------------------------------
import time

from ui import GridWorldWindow
from mdp import GridMDP, value_iteration, policy_extraction, policy_evaluation, policy_iteration, values_converged, policy_converged

class ViewController(object):
    def __init__(self, metadata):
        self.gridworld = GridWorldWindow(metadata=metadata)
        self.mdp = GridMDP(metadata=metadata)

        # bind buttons
        self.gridworld.btn_value_iteration_1_step.configure(command=self._value_iteration_1_step)
        self.gridworld.btn_value_iteration_100_steps.configure(command=self._value_iteration_100_steps)
        self.gridworld.btn_value_iteration_slow.configure(command=self._value_iteration_slow)
        self.gridworld.btn_policy_iteration_1_step.configure(command=self._policy_iteration_1_step)
        self.gridworld.btn_policy_iteration_100_steps.configure(command=self._policy_iteration_100_steps)
        self.gridworld.btn_policy_iteration_slow.configure(command=self._policy_iteration_slow)

        self.gridworld.btn_reset.configure(command=self._reset_grid)

    def _value_iteration_1_step(self):
        values = value_iteration(self.mdp.values, self.mdp, num_iter=1)
        policy = policy_extraction(values, self.mdp)
        self.gridworld.update_grid(values, policy)
        self.mdp.update_values(values)
        self.mdp.update_policy(policy)

    def _value_iteration_100_steps(self):
        values = value_iteration(self.mdp.values, self.mdp, num_iter=100)
        policy = policy_extraction(values, self.mdp)
        self.gridworld.update_grid(values, policy)
        self.mdp.update_values(values)
        self.mdp.update_policy(policy)

    def _value_iteration_slow(self):
        # run one iteration of value iteration at a time
        old_values = dict(self.mdp.values)
        for i in range(100):
            values = value_iteration(self.mdp.values, self.mdp, num_iter=1)
            policy = policy_extraction(values, self.mdp)
            self.gridworld.update_grid(values, policy)
            self.mdp.update_values(values)
            self.mdp.update_policy(policy)

            self.gridworld.window.update()
            time.sleep(0.25)
            self.gridworld.window.update()

            new_values = dict(values)
            if values_converged(new_values, old_values):
                break

            old_values = new_values
        self.gridworld.show_dialog('Value Iteration has converged in {} steps!'.format(i+1))

    def _policy_iteration_1_step(self):
        policy, values = policy_iteration(self.mdp.policy, self.mdp, num_iter=1)
        self.gridworld.update_grid(values, policy)
        self.mdp.update_values(values)
        self.mdp.update_policy(policy)

    def _policy_iteration_100_steps(self):
        policy, values = policy_iteration(self.mdp.policy, self.mdp, num_iter=100)
        self.gridworld.update_grid(values, policy)
        self.mdp.update_values(values)
        self.mdp.update_policy(policy)

    def _policy_iteration_slow(self):
        # run one iteration of policy iteration at a time
        old_policy = dict(self.mdp.policy)
        for i in range(100):
            policy, values = policy_iteration(self.mdp.policy, self.mdp, num_iter=1)
            self.gridworld.update_grid(values, policy)
            self.mdp.update_values(values)
            self.mdp.update_policy(policy)
            self.gridworld.window.update()
            time.sleep(0.25)
            self.gridworld.window.update()

            new_policy = dict(self.mdp.policy)
            if policy_converged(new_policy, old_policy):
                break

            old_policy = new_policy
        self.gridworld.show_dialog('Policy Iteration has converged in {} steps!'.format(i+1))

    def _reset_grid(self):
        self.mdp.clear()
        self.gridworld.clear()

    def run(self):
        # main UI loop
        self.gridworld.run()
--------------------------------------------------------------------------------