├── README.md
├── imgs
│   ├── 1.PNG
│   ├── 2.PNG
│   └── sutton_barto_example.PNG
└── source code
    ├── 0-MDP Environment (Chapter 3)
    │   ├── .ipynb_checkpoints
    │   │   └── 1-introduction-to-gridworld-environment-checkpoint.ipynb
    │   ├── 1-introduction-to-gridworld-environment.ipynb
    │   ├── __pycache__
    │   │   └── gridWorldEnvironment.cpython-36.pyc
    │   ├── gridWorldEnvironment.py
    │   └── gridworld.txt
    ├── 1-Dynamic Programming (Chapter 4)
    │   ├── .ipynb_checkpoints
    │   │   ├── 1-policy-evaluation-and-improvement-checkpoint.ipynb
    │   │   ├── 2-policy-iteration-checkpoint.ipynb
    │   │   └── 3-value-iteration-checkpoint.ipynb
    │   ├── 1-policy-evaluation-and-improvement.ipynb
    │   ├── 2-policy-iteration.ipynb
    │   ├── 3-value-iteration.ipynb
    │   ├── gridWorldEnvironment.py
    │   └── gridworld.txt
    ├── 2-Monte Carlo Methods (Chapter 5)
    │   ├── .ipynb_checkpoints
    │   │   ├── 1-monte-carlo-prediction-checkpoint.ipynb
    │   │   ├── 2-monte-carlo-exploring-starts-checkpoint.ipynb
    │   │   ├── 3-on-policy-monte-carlo-checkpoint.ipynb
    │   │   └── 4-off-policy-monte-carlo-checkpoint.ipynb
    │   ├── 1-monte-carlo-prediction.ipynb
    │   ├── 2-monte-carlo-exploring-starts.ipynb
    │   ├── 3-on-policy-monte-carlo.ipynb
    │   ├── 4-off-policy-monte-carlo.ipynb
    │   ├── __pycache__
    │   │   └── gridWorldEnvironment.cpython-36.pyc
    │   ├── gridWorldEnvironment.py
    │   └── gridworld.txt
    ├── 3-Temporal Difference Learning (Chapter 6)
    │   ├── .ipynb_checkpoints
    │   │   ├── 1-td-prediction-checkpoint.ipynb
    │   │   ├── 2-SARSA-on-policy control-checkpoint.ipynb
    │   │   ├── 3-Q-learning-off-policy-control-checkpoint.ipynb
    │   │   └── 4-double-Q-learning-off-policy-control-checkpoint.ipynb
    │   ├── 1-td-prediction.ipynb
    │   ├── 2-SARSA-on-policy control.ipynb
    │   ├── 3-Q-learning-off-policy-control.ipynb
    │   ├── 4-double-Q-learning-off-policy-control.ipynb
    │   ├── __pycache__
    │   │   └── gridWorldEnvironment.cpython-36.pyc
    │   ├── gridWorldEnvironment.py
    │   └── gridworld.txt
    └── 4-n-step Bootstrapping (Chapter 7)
        ├── .ipynb_checkpoints
        │   ├── 1-n-step-td-prediction-checkpoint.ipynb
        │   ├── 2-n-step-SARSA-on-policy-control-checkpoint.ipynb
        │   ├── 3-n-step-off-policy-learning-by-importance-sampling-checkpoint.ipynb
        │   └── 4-n-step-off-policy-learning-wo-importance-sampling-checkpoint.ipynb
        ├── 1-n-step-td-prediction.ipynb
        ├── 2-n-step-SARSA-on-policy-control.ipynb
        ├── 3-n-step-off-policy-learning-by-importance-sampling.ipynb
        ├── 4-n-step-off-policy-learning-wo-importance-sampling.ipynb
        ├── __pycache__
        │   └── gridWorldEnvironment.cpython-36.pyc
        ├── gridWorldEnvironment.py
        └── gridworld.txt
/README.md:
--------------------------------------------------------------------------------
1 | # Tabular Reinforcement Learning Algorithms with Python
2 | Python implementations of the tabular RL algorithms in Sutton & Barto 2017 (*Reinforcement Learning: An Introduction*)
3 | Both the environment and the algorithms are built using only NumPy and basic Python data structures (list, tuple, set, and dictionary)
4 |
5 | ### Algorithms learning from the *4x4 Grid World Environment (Sutton & Barto 2017, p. 61)*
6 |
7 | ![Sutton & Barto gridworld example](imgs/sutton_barto_example.PNG)
8 |
9 | ### Tabular Reinforcement Learning Algorithms with *NumPy*
10 |
11 | 
12 |
13 | ### Visualizations with *Seaborn* (Policy & Value function)
14 |
15 | 
16 |
17 |
18 | ## Contents
19 |
20 | ### 0. MDP Environment (Chapter 3, Sutton & Barto 2017)
21 | 1. Introduction to gridworld environment
22 |
23 | ### 1. Dynamic Programming (Chapter 4, Sutton & Barto 2017)
24 | 1. Policy Evaluation and improvement
25 | 2. Policy Iteration
26 | 3. Value Iteration
27 |
28 | ### 2. Monte Carlo Methods (Chapter 5, Sutton & Barto 2017)
29 | 1. Monte Carlo Prediction
30 | 2. Monte Carlo Exploring Starts
31 | 3. On Policy Monte Carlo
32 | 4. Off Policy Monte Carlo
33 |
34 | ### 3. Temporal Difference Learning (Chapter 6, Sutton & Barto 2017)
35 | 1. TD Prediction
36 | 2. SARSA - On-policy Control
37 | 3. Q-learning - Off-policy Control
38 | 4. Double Q-learning - Off-policy Control
39 |
40 | ### 4. n-step Bootstrapping (Chapter 7, Sutton & Barto 2017)
41 | 1. n-step TD Prediction
42 | 2. n-step SARSA - On-policy Control
43 | 3. n-step Off-policy learning by Importance Sampling
44 | 4. n-step Off-policy learning without Importance Sampling
45 |
46 |
47 |
--------------------------------------------------------------------------------
/imgs/1.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buomsoo-kim/Tabular-RL-with-Python/0157c05126821524dc9d744613b9aa8dd5d47232/imgs/1.PNG
--------------------------------------------------------------------------------
/imgs/2.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buomsoo-kim/Tabular-RL-with-Python/0157c05126821524dc9d744613b9aa8dd5d47232/imgs/2.PNG
--------------------------------------------------------------------------------
/imgs/sutton_barto_example.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buomsoo-kim/Tabular-RL-with-Python/0157c05126821524dc9d744613b9aa8dd5d47232/imgs/sutton_barto_example.PNG
--------------------------------------------------------------------------------
/source code/0-MDP Environment (Chapter 3)/1-introduction-to-gridworld-environment.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# GridWorld Environment"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 1,
13 | "metadata": {
14 | "collapsed": true
15 | },
16 | "outputs": [],
17 | "source": [
18 | "import matplotlib.pyplot as plt\n",
19 | "import pandas as pd\n",
20 | "from gridWorldEnvironment import GridWorld"
21 | ]
22 | },
23 | {
24 | "cell_type": "code",
25 | "execution_count": 2,
26 | "metadata": {},
27 | "outputs": [
28 | {
29 | "name": "stdout",
30 | "output_type": "stream",
31 | "text": [
32 | "Actions: ('U', 'D', 'L', 'R')\n",
33 | "States: [ 1 2 3 4 5 6 7 8 9 10 11 12 13 14]\n"
34 | ]
35 | }
36 | ],
37 | "source": [
38 | "# creating gridworld environment\n",
39 | "gw = GridWorld()\n",
40 | "\n",
41 | "print(\"Actions: \", gw.actions)\n",
42 | "print(\"States: \", gw.states)"
43 | ]
44 | },
45 | {
46 | "cell_type": "markdown",
47 | "metadata": {},
48 | "source": [
49 | "### State Transitions\n",
50 | "- All possible state transitions in the deterministic gridworld\n",
51 | "    - Each transition is a quadruple of (```state```, ```action```, ```next state```, ```reward```)\n",
52 | "    - For instance, the first row indicates that if the agent performs action ```U``` (up) in state ```1```, it stays in state ```1``` (does not move) and receives a reward of -1"
53 | ]
54 | },
55 | {
56 | "cell_type": "code",
57 | "execution_count": 12,
58 | "metadata": {
59 | "scrolled": true
60 | },
61 | "outputs": [
62 | {
63 | "data": {
129 | "text/plain": [
130 | " State Action Next State Reward\n",
131 | "0 1 U 1 -1\n",
132 | "1 1 D 5 -1\n",
133 | "2 1 R 2 -1\n",
134 | "3 1 L 0 -1\n",
135 | "4 2 U 2 -1"
136 | ]
137 | },
138 | "execution_count": 12,
139 | "metadata": {},
140 | "output_type": "execute_result"
141 | }
142 | ],
143 | "source": [
144 | "pd.DataFrame(gw.transitions, columns = [\"State\", \"Action\", \"Next State\", \"Reward\"]).head()"
145 | ]
146 | },
147 | {
148 | "cell_type": "code",
149 | "execution_count": 14,
150 | "metadata": {},
151 | "outputs": [
152 | {
153 | "name": "stdout",
154 | "output_type": "stream",
155 | "text": [
156 | "(1, -1)\n",
157 | "(2, -1)\n"
158 | ]
159 | }
160 | ],
161 | "source": [
162 | "print(gw.state_transition(1, \"U\"))\n",
163 | "print(gw.state_transition(3, \"L\"))"
164 | ]
165 | },
166 | {
167 | "cell_type": "markdown",
168 | "metadata": {},
169 | "source": [
170 | "### Show environment\n",
171 | "- The environment is visualized as a table\n",
172 | "- Note that the terminal states at ```(0,0)``` and ```(3,3)``` are included"
173 | ]
174 | },
175 | {
176 | "cell_type": "code",
177 | "execution_count": 3,
178 | "metadata": {},
179 | "outputs": [
180 | {
181 | "data": {
182 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAdwAAAFJCAYAAAAxCJwFAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAFnJJREFUeJzt3X1U1vX9x/HXBah58BZDdIlnuKnnWi7NFqftZB2HpukY\njtyRpugUjmflTR3TTYRgNgw967i1xKWgTZdJDQ2ciU1mudYOO9XKHRy5pZtmGunxXrkTvvvDE+dX\nB+m6OL/r/eEaz8dfXpeni1d9gCff78U5+TzP8wQAAEIqwvUAAAC6AoILAIABggsAgAGCCwCAAYIL\nAIABggsAgIGoUL643+9XcXFxKD8EQiQzM1M1NTWuZ6CD/H4/5xem/H6/JHF+Yaq9rz2ucAEAMEBw\nAQAwQHABADBAcAEAMEBwAQAwQHABADBAcAEAMEBwAQAwQHABADBAcAEAMEBwAQAwQHABADBAcAEA\nMEBwAQAwQHABADBAcAEAMEBwAQAwQHABADBAcAEAMEBwAQAwQHABADBAcAEAMEBwAQAwQHABADBA\ncAEAMEBwAQAwQHABADBAcAEAMEBwAQAwQHABADBAcAEAMEBwAQAwEOV6QDg6ePCgduzYoaamJsXH\nx2vu3Lnq2bOn61kIgud5ysrK0vDhw5WRkeF6DgJUXl6uTZs2yefzqWfPnsrOztbXv/5117MQoOef\nf17bt2+Xz+dTfHy88vPzNWDAANezzHCFG6SLFy9q8+bNWrBggQoKChQbG6vS0lLXsxCEI0eOaM6c\nOaqoqHA9BUE4evSofv7zn6u4uFjl5eV66KGHtGjRItezEKDq6mpt3rxZJSUl2r17t7785S/r6aef\ndj3LVMDBbWlpCeWOsHHo0CElJCQoLi5OkjR+/HhVVVXJ8zzHyxCobdu2KTU1Vffff7/rKQhC9+7d\nlZ+fr4EDB0qSRo0apTNnzqixsdHxMgRi1KhRevXVV9W7d281NDSotrZW/fr1cz3LVLu3lD/88EMV\nFBSourpaUVFRamlp0YgRI5SVlaWEhASrjZ3K2bNnFRMT0/q4f//+qqurU319PbeVw0Rubq4kqaqq\nyvESBGPIkCEaMmSIpOtvCRQUFOjb3/62unfv7ngZAtWtWzdVVlYqOztb3bt31+LFi11PMtVucLOz\ns/XYY49p9OjRrc+99957ysrKUklJScjHdUY3upKNiODuPGDh6tWrWr58uT7++GMVFxe7noMgTZgw\nQRMmTNBLL72kjIwM7du3r8t8/2z337KxsfEzsZWkMWPGhHRQZzdgwACdP3++9fG5c+cUHR2tHj16\nOFwFdA0nT55UWlqaIiMjtXXrVvXp08f1JATo2LFjevvtt1sfP/DAAzp58qQuXLjgcJWtdq9wR44c\nqaysLI0bN069e/fWlStXdODAAY0cOdJqX6dz66236sUXX1Rtba3i4uL0+uuvd/kfQgAL58+f16xZ\ns5SamqqFCxe6noMgnT59WkuWLFFZWZliYmL0+9//XsOHD1f//v1dTzPTbnB/+tOfqrKyUu+8844u\nX76sXr16afz48Zo4caLVvk6nT58+mjdvngoLC9Xc3KzY2FhlZma6ngX8z9u+fbtOnTqlffv2ad++\nfa3P/+Y3v+lS37TD1Te+8Q396Ec/0uzZsxUZGamBAweqsLDQ9SxTPi+Ev17r9/t5jyVMZWZmqqam\nxvUMdJDf7+f8wpTf75ckzi9Mtfe11zXeqQYAwDGCCwCAAYILAIABggsAgAGCCwCAAYILAIABggsA\ngAGCCwCAAYILAIABggsAgAGCCwCAAYILAIABggsAgAGCCwCAAYILAIABggsAgAGCCwCAAYILAIAB\nggsAgAGCCwCAAYILAIABggsAgAGCCwCAAYILAIABggsAgAGCCwCAAYILAIABggsAgAGCCwCAAYIL\nAIABggsAgAGCCwCAAZ/neV6oXtzv94
fqpQEA6JRqamrafD7K1QdG5+b3+zm7MMb5ha9PL1Q4v/DU\n3oUmt5QBADBAcAEAMEBwAQAwQHABADBAcAEAMEBwAQAwQHABADBAcAEAMEBwAQAwQHABADBAcAEA\nMEBwAQAwQHABADBAcAEAMEBwAQAwQHABADBAcAEAMEBwAQAwQHABADBAcAEAMEBwAQAwQHABADBA\ncAEAMEBwAQAwQHABADBAcAEAMEBwAQAwQHABADBAcAEAMEBwAQAwQHABADAQ5XpAOKusrNSPf/xj\n/e1vf3M9BUFYvXq19u7dq759+0qSEhIS9Mtf/tLxKgTi8OHDys/P16VLlxQREaEnnnhCo0aNcj0L\nASgrK9Nzzz3X+vjSpUuqra3VgQMHdPPNNztcZofgdtB//vMfrVmzRp7nuZ6CIL377rtau3atxo4d\n63oKglBXV6eMjAytWrVK9957ryorK7V06VLt3bvX9TQEYNq0aZo2bZokqampSbNmzdL8+fO7TGwl\nbil3SF1dnZYtW6bly5e7noIgNTY26h//+Ic2b96s7373u1q0aJFOnjzpehYC8Oabbyo+Pl733nuv\nJCkpKYk7E2GqqKhIMTExSktLcz3FFMHtgNzcXM2YMUMjR450PQVBqq2t1V133aUlS5aovLxco0eP\n1sMPP8ydijDw73//W7GxsVqxYoVSU1M1d+5cNTc3u56FIJ09e1bPPfecVqxY4XqKOYIbpG3btikq\nKkrTp093PQUdEB8fr6KiIg0bNkw+n08ZGRk6fvy4Tpw44XoavsC1a9d04MABzZgxQzt37my9JdnY\n2Oh6GoLw0ksvKSkpSfHx8a6nmGv3Pdz09HQ1NTV95jnP8+Tz+VRSUhLSYZ3Vyy+/rPr6eqWkpKip\nqan1zxs3blRcXJzrefgC77//vt5///3W95Kk65/T3bp1c7gKgRg4cKCGDRum0aNHS5ImTJignJwc\nffjhh/rKV77ieB0CtWfPHuXk5Lie4US7wV26dKlycnJUWFioyMhIq02dWmlpaeufT5w4oeTkZJWX\nlztchGBERERo1apVuuOOOxQfH68XXnhBI0eO1KBBg1xPwxe45557tGbNGlVXV2vUqFF666235PP5\nNGTIENfTEKALFy7o+PHjuv32211PcaLd4I4ePVopKSk6fPiwJk6caLUJCJkRI0YoJydHDz30kJqb\nmzVo0CCtXbvW9SwEIDY2VoWFhVq5cqXq6urUvXt3PfPMM+rRo4fraQjQsWPHFBsb22XvKPm8EP62\niN/vV01NTaheHiHE2YU3zi98+f1+SeL8wlR7X3v80hQAAAYILgAABgguAAAGCC4AAAYILgAABggu\nAAAGCC4AAAYILgAABgguAAAGCC4AAAYILgAABgguAAAGCC4AAAYILgAABgguAAAGCC4AAAYILgAA\nBgguAAAGCC4AAAYILgAABgguAAAGCC4AAAYILgAABgguAAAGCC4AAAYILgAABgguAAAGCC4AAAYI\nLgAABgguAAAGCC4AAAYILgAABgguAAAGfJ7neaF6cb/fH6qXBgCgU6qpqWnz+ShXHxidm9/v5+zC\nGOcXvj69UOH8wlN7F5rcUgYAwADBBQDAAMEFAMAAwQUAwADBBQDAAMEFAMAAwQUAwADBBQDAAMEF\nAMAAwQUAwADBBQDAAMEFAMAAwQUAwADBBQDAAMEFAMAAwQUAwADBBQDAAMEFAMAAwQUAwADBBQDA\nAMEFAMAAwQUAwADBBQDAAMEFAMAAwQUAwADBBQDAAMEFAMAAwQUAwADBBQDAAMEFAMAAwQUAwADB\n7YB9+/YpOTlZKSkpSk9P1/Hjx11PQhB++9vfatKkSUpJSdGSJUt0/vx515PwBTzP0/Lly7Vp0yZJ\nUnNzs/Lz8zV58mRNnDhR27dvd7wQN/L5s/vUqVOnNG7cOJ09e9bRMnsEN0j19fVatmyZ1q1bp/Ly\nci
UlJSk/P9/1LASoqqpKRUVF2rJli8rLy3XPPfcoNzfX9Sy048iRI5ozZ44qKipanyspKdGxY8e0\ne/dulZaWasuWLfr73//ucCXa0tbZSVJZWZlmzpypTz75xNEyN4IObmNjYyh2hI3m5mZ5nqdLly5J\nkq5cuaIePXo4XoVAHTp0SN/61rc0aNAgSdJ9992n/fv3d/nP685s27ZtSk1N1f3339/6XGVlpVJT\nUxUVFaW+fftq6tSp2rVrl8OVaEtbZ1dbW6vKykpt3LjR4TI3bhjc/fv3a/z48Zo4caL27NnT+nxm\nZqbJsM4qOjpaK1euVFpamu6++25t27ZNS5cudT0LAbrttttUVVWljz76SJK0c+dONTU1cVu5E8vN\nzdW0adM+89ypU6c0ePDg1seDBg3Sxx9/bD0NX6Cts4uLi9O6dev01a9+1dEqd6Ju9BfPPvusysrK\n1NLSokceeUQNDQ363ve+J8/zLPd1OocPH1ZhYaH27NmjoUOHauvWrVq0aJHKy8vl8/lcz8MXuPPO\nO7VgwQItXLhQPp9PDzzwgPr166du3bq5noYgtPV9KCKCd8jQud3wM7Rbt27q27ev+vfvr/Xr1+v5\n559XVVVVl4/Kn//8Z40dO1ZDhw6VJM2cOVP/+te/dO7cOcfLEIjLly8rMTFRL7/8snbu3KlJkyZJ\nkvr16+d4GYIxePBgnT59uvVxbW1t69sEQGd1w+DecsstKigo0NWrV9WrVy+tW7dOTzzxhI4ePWq5\nr9P52te+prfeektnzpyRdP29pCFDhigmJsbxMgTik08+UXp6ui5fvixJWr9+vaZOndrlf5AMN0lJ\nSdqxY4euXbumixcv6pVXXtGECRNczwLadcNbyk8++aR27drV+o1o8ODB2rp1qzZs2GA2rjP65je/\nqYyMDKWnp7feBVi/fr3rWQjQsGHDNH/+fH3/+99XS0uL7rjjDn5LOQw9+OCDOn78uFJSUtTU1KQZ\nM2YoMTHR9SygXT4vhG/K+v1+1dTUhOrlEUKcXXjj/MKX3++XJM4vTLX3tcdvGQAAYIDgAgBggOAC\nAGCA4AIAYIDgAgBggOACAGCA4AIAYIDgAgBggOACAGCA4AIAYIDgAgBggOACAGCA4AIAYIDgAgBg\ngOACAGCA4AIAYIDgAgBggOACAGCA4AIAYIDgAgBggOACAGCA4AIAYIDgAgBggOACAGCA4AIAYIDg\nAgBggOACAGCA4AIAYIDgAgBggOACAGCA4AIAYIDgAgBgwOd5nheqF/f7/aF6aQAAOqWampo2n+cK\nFwAAA1Gh/gA3Kj06N7/fz9mFMc4vfH16Z7C4uNjxEnREZmbmDf+OK1wAAAwQXAAADBBcAAAMEFwA\nAAwQXAAADBBcAAAMEFwAAAwQXAAADBBcAAAMEFwAAAwQXAAADBBcAAAMEFwAAAwQXAAADBBcAAAM\nEFwAAAwQXAAADBBcAAAMEFwAAAwQXAAADBBcAAAMEFwAAAwQXAAADBBcAAAMEFwAAAwQXAAADBBc\nAAAMEFwAAAwQXAAADBBcAAAMEFwAAAxEuR4QTjzPU1ZWloYPH66MjAzV19dr5cqVqq6uVktLi267\n7Tbl5eXppptucj0Vn/P5s7t06ZKys7N19OhRtbS0aNq0aZo/f77rmbiBz5/f/7Vw4UINHDhQubm5\njtYhUAcPHtSOHTvU1NSk+Ph4zZ07Vz179nQ9ywxXuAE6cuSI5syZo4qKitbnfv3rX6u5uVnl5eXa\ntWuXGhoatGHDBocr0Za2zu7pp59WXFycdu/erdLSUpWUlOjdd991uBI30tb5faqoqEhvv/22g1UI\n1sWLF7V582YtWLBABQUFio2NVWlpqetZpoK6wq2vr1dERIS6d+8eqj2d1rZt25SamqovfelLrc/d\neeeduuWWWxQRcf3nFr/frw8++MDVRNxAW2eXnZ2t5uZmSdLp06fV
2Nio3r17u5qIdrR1fpJUVVWl\nN954Q2lpabp48aKjdQjUoUOHlJCQoLi4OEnS+PHjlZeXp1mzZsnn8zleZ6PdK9wPPvhADz/8sLKy\nsvSXv/xFU6ZM0ZQpU/Taa69Z7es0cnNzNW3atM88d/fddyshIUGS9NFHH2nLli2aPHmyi3loR1tn\n5/P5FBUVpaVLl+o73/mOEhMTW88SnUtb51dbW6tVq1bpqaeeUmRkpKNlCMbZs2cVExPT+rh///6q\nq6tTfX29w1W22g1uXl6efvjDHyoxMVGLFy/W7373O5WVlXHb9HOqq6s1c+ZMzZo1S+PHj3c9B0F4\n6qmnVFVVpQsXLqiwsND1HASgqalJS5Ys0YoVKzRw4EDXcxAgz/PafP7TO4RdQbu3lFtaWpSYmChJ\n+utf/6oBAwZc/4ei+F2rT73yyitauXKlHn/8cSUnJ7uegwC98cYbGjFihOLi4hQdHa2pU6fqD3/4\ng+tZCEB1dbVOnDih1atXS5LOnDmj5uZmNTQ0aNWqVY7X4UYGDBigo0ePtj4+d+6coqOj1aNHD4er\nbLX7o0VCQoKys7PV0tLS+sm9ceNG3XzzzSbjOru9e/cqPz9fmzZtIrZhpqKiQoWFhfI8T42Njaqo\nqNBdd93lehYCcPvtt+vAgQMqLy9XeXm50tLSNGXKFGLbyd166606evSoamtrJUmvv/66xowZ43iV\nrXYvVfPz87V///7PXPLHxcUpPT095MPCwdq1a+V5nnJyclqfGzt2rPLy8hyuQiCWL1+uvLw8JScn\ny+fzKSkpSbNnz3Y9C/if1adPH82bN0+FhYVqbm5WbGysMjMzXc8y5fNudGP9/4Hf71dNTU2oXh4h\nxNmFN84vfPn9fklScXGx4yXoiMzMzBt+7XWdd6sBAHCI4AIAYIDgAgBggOACAGCA4AIAYIDgAgBg\ngOACAGCA4AIAYIDgAgBggOACAGCA4AIAYIDgAgBggOACAGCA4AIAYIDgAgBggOACAGCA4AIAYIDg\nAgBggOACAGCA4AIAYIDgAgBggOACAGCA4AIAYIDgAgBggOACAGCA4AIAYIDgAgBggOACAGCA4AIA\nYIDgAgBggOACAGCA4AIAYMDneZ4Xqhf3+/2hemkAADqlmpqaNp8PaXABAMB13FIGAMAAwQUAwADB\nBQDAAMEFAMAAwQUAwADBBQDAAMHtgJaWFuXm5mrGjBlKT0/XsWPHXE9CkA4ePKj09HTXMxCkpqYm\nLVu2TD/4wQ80ffp0/fGPf3Q9CQFqbm5WVlaW0tLS9OCDD+qf//yn60nmCG4HVFZWqrGxUS+++KIe\ne+wxrV692vUkBKGoqEg5OTlqaGhwPQVB2rVrl/r166cXXnhBxcXF+tnPfuZ6EgL02muvSZJKSkr0\n6KOP6he/+IXjRfYIbge88847GjdunCRpzJgxqq6udrwIwRg6dKieeeYZ1zPQAZMnT9YjjzwiSfI8\nT5GRkY4XIVATJkxo/QHp5MmT6tOnj+NF9qJcDwhHly9fVq9evVofR0ZG6tq1a4qK4j9nOJg0aZJO\nnDjhegY6IDo6WtL1r8HFixfr0UcfdbwIwYiKitJPfvIT7du3T7/61a9czzHHFW4H9OrVS1euXGl9\n3NLSQmwBI6dOndLs2bOVkpKi5ORk13MQpDVr1ujVV1/V448/rqtXr7qeY4rgdsDYsWP1pz/9SZL0\n3nvvacSIEY4XAV3DmTNnNG/ePC1btkzTp093PQdBKCsr04YNGyRJPXv2lM/nU0RE10oQl2UdMHHi\nRL355ptKS0uT53l68sknXU8CuoRnn31WFy9e1Pr167V+/XpJ138J7qabbnK8DF/kvvvuU1ZWlmbO\nnKlr165pxYoVXe7c+L8FAQBgoGtdzwMA4AjBBQDAAMEFAMAAwQUAwADBBQDAAMEFAMAAwQUAwADB\nBQDAwH8BCoTS1WbM11MAAAAA
SUVORK5CYII=\n",
183 | "text/plain": [
184 | ""
185 | ]
186 | },
187 | "metadata": {},
188 | "output_type": "display_data"
189 | }
190 | ],
191 | "source": [
192 | "gw.show_environment()\n",
193 | "plt.show()"
194 | ]
195 | }
196 | ],
197 | "metadata": {
198 | "kernelspec": {
199 | "display_name": "Python 3",
200 | "language": "python",
201 | "name": "python3"
202 | },
203 | "language_info": {
204 | "codemirror_mode": {
205 | "name": "ipython",
206 | "version": 3
207 | },
208 | "file_extension": ".py",
209 | "mimetype": "text/x-python",
210 | "name": "python",
211 | "nbconvert_exporter": "python",
212 | "pygments_lexer": "ipython3",
213 | "version": "3.6.1"
214 | }
215 | },
216 | "nbformat": 4,
217 | "nbformat_minor": 2
218 | }
219 |
--------------------------------------------------------------------------------
/source code/0-MDP Environment (Chapter 3)/__pycache__/gridWorldEnvironment.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buomsoo-kim/Tabular-RL-with-Python/0157c05126821524dc9d744613b9aa8dd5d47232/source code/0-MDP Environment (Chapter 3)/__pycache__/gridWorldEnvironment.cpython-36.pyc
--------------------------------------------------------------------------------
/source code/0-MDP Environment (Chapter 3)/gridWorldEnvironment.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import seaborn
4 | from matplotlib.colors import ListedColormap
5 |
6 | class GridWorld:
7 |     def __init__(self, gamma = 1.0, theta = 0.5):
8 |         self.actions = ("U", "D", "L", "R")
9 |         self.states = np.arange(1, 15)
10 |         self.transitions = pd.read_csv("gridworld.txt", header = None, sep = "\t").values
11 |         self.gamma = gamma
12 |         self.theta = theta
13 |
14 |     def state_transition(self, state, action):
15 |         # look up (next state, reward) for the given state-action pair
16 |         next_state, reward = None, None
17 |         for tr in self.transitions:
18 |             if tr[0] == state and tr[1] == action:
19 |                 next_state, reward = tr[2], tr[3]
20 |                 break
21 |         return next_state, reward
22 |
23 |     def show_environment(self):
24 |         # 4x4 board: terminal corners (both labeled 0) plus states 1-14
25 |         all_states = np.concatenate(([0], self.states, [0])).reshape(4, 4)
26 |         colors = []
27 |         for i in range(len(self.states) + 1):
28 |             if i == 0:
29 |                 colors.append("#c4c4c4")   # gray for the terminal states
30 |             else:
31 |                 colors.append("#ffffff")   # white for nonterminal states
32 |
33 |         cmap = ListedColormap(seaborn.color_palette(colors).as_hex())
34 |         ax = seaborn.heatmap(all_states, cmap = cmap, \
35 |                              annot = True, linecolor = "#282828", linewidths = 0.2, \
36 |                              cbar = False)
--------------------------------------------------------------------------------
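The lookup in `state_transition` above can be sketched standalone. The snippet below inlines a few rows of `gridworld.txt` as plain tuples (an illustrative subset, so no file I/O or pandas is needed) and mirrors the same linear scan:

```python
# A few rows of gridworld.txt, inlined as (state, action, next_state, reward)
transitions = [
    (1, "U", 1, -1),
    (1, "D", 5, -1),
    (1, "R", 2, -1),
    (1, "L", 0, -1),
]

def state_transition(state, action):
    # Linear scan over the table, mirroring GridWorld.state_transition
    for s, a, s_next, r in transitions:
        if s == state and a == action:
            return s_next, r
    return None, None  # unknown state-action pair

print(state_transition(1, "U"))  # -> (1, -1), as in the notebook output
print(state_transition(1, "L"))  # -> (0, -1): moving left reaches the terminal
```

Because every transition is deterministic, a dictionary keyed by `(state, action)` would give the same result in O(1), but the scan keeps the sketch close to the class above.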
/source code/0-MDP Environment (Chapter 3)/gridworld.txt:
--------------------------------------------------------------------------------
1 | 1 U 1 -1
2 | 1 D 5 -1
3 | 1 R 2 -1
4 | 1 L 0 -1
5 | 2 U 2 -1
6 | 2 D 6 -1
7 | 2 R 3 -1
8 | 2 L 1 -1
9 | 3 U 3 -1
10 | 3 D 7 -1
11 | 3 R 3 -1
12 | 3 L 2 -1
13 | 4 U 0 -1
14 | 4 D 8 -1
15 | 4 R 5 -1
16 | 4 L 4 -1
17 | 5 U 1 -1
18 | 5 D 9 -1
19 | 5 R 6 -1
20 | 5 L 4 -1
21 | 6 U 2 -1
22 | 6 D 10 -1
23 | 6 R 7 -1
24 | 6 L 5 -1
25 | 7 U 3 -1
26 | 7 D 11 -1
27 | 7 R 7 -1
28 | 7 L 6 -1
29 | 8 U 4 -1
30 | 8 D 12 -1
31 | 8 R 9 -1
32 | 8 L 8 -1
33 | 9 U 5 -1
34 | 9 D 13 -1
35 | 9 R 10 -1
36 | 9 L 8 -1
37 | 10 U 6 -1
38 | 10 D 14 -1
39 | 10 R 11 -1
40 | 10 L 9 -1
41 | 11 U 7 -1
42 | 11 D 0 -1
43 | 11 R 11 -1
44 | 11 L 10 -1
45 | 12 U 8 -1
46 | 12 D 12 -1
47 | 12 R 13 -1
48 | 12 L 12 -1
49 | 13 U 9 -1
50 | 13 D 13 -1
51 | 13 R 14 -1
52 | 13 L 12 -1
53 | 14 U 10 -1
54 | 14 D 14 -1
55 | 14 R 0 -1
56 | 14 L 13 -1
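The file above fully specifies the deterministic gridworld: one tab-separated row per (state, action) pair, every step rewarded with -1, and state 0 standing in for the two terminal corners. A minimal parsing sketch over a hard-coded excerpt (variable names are illustrative):

```python
# a hard-coded excerpt of gridworld.txt
# (tab-separated columns: state, action, next_state, reward)
rows = "1\tU\t1\t-1\n1\tD\t5\t-1\n1\tR\t2\t-1\n1\tL\t0\t-1"

transitions = [
    (int(s), a, int(ns), int(r))
    for s, a, ns, r in (line.split("\t") for line in rows.splitlines())
]

# each state contributes one row per action, and every step costs -1
assert {a for _, a, _, _ in transitions} == {"U", "D", "L", "R"}
assert all(r == -1 for _, _, _, r in transitions)
print(transitions[1])  # (1, 'D', 5, -1)
```

The `GridWorld` class loads the full file with `pd.read_csv("gridworld.txt", header=None, sep="\t")`, which yields the same tuples via `.values`.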
--------------------------------------------------------------------------------
/source code/1-Dynamic Programming (Chapter 4)/gridWorldEnvironment.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import seaborn
4 | from matplotlib.colors import ListedColormap
5 |
6 | class GridWorld:
7 | def __init__(self, gamma = 1.0, theta = 0.5):
8 | self.actions = ("U", "D", "L", "R")
9 | self.states = np.arange(1, 15)
10 | self.transitions = pd.read_csv("gridworld.txt", header = None, sep = "\t").values
11 | self.gamma = gamma
12 | self.theta = theta
13 |
14 | def state_transition(self, state, action):
15 | next_state, reward = None, None
16 | for tr in self.transitions:
17 | if tr[0] == state and tr[1] == action:
18 | next_state = tr[2]
19 | reward = tr[3]
20 | return next_state, reward
21 |
22 | def show_environment(self):
23 | all_states = np.concatenate(([0], self.states, [0])).reshape(4,4)
24 | colors = []
25 | # colors = ["#ffffff"]
26 | for i in range(len(self.states) + 1):
27 | if i == 0:
28 | colors.append("#c4c4c4")
29 | else:
30 | colors.append("#ffffff")
31 |
32 | cmap = ListedColormap(seaborn.color_palette(colors).as_hex())
33 | ax = seaborn.heatmap(all_states, cmap = cmap, \
34 | annot = True, linecolor = "#282828", linewidths = 0.2, \
35 | cbar = False)
--------------------------------------------------------------------------------
/source code/1-Dynamic Programming (Chapter 4)/gridworld.txt:
--------------------------------------------------------------------------------
1 | 1 U 1 -1
2 | 1 D 5 -1
3 | 1 R 2 -1
4 | 1 L 0 -1
5 | 2 U 2 -1
6 | 2 D 6 -1
7 | 2 R 3 -1
8 | 2 L 1 -1
9 | 3 U 3 -1
10 | 3 D 7 -1
11 | 3 R 3 -1
12 | 3 L 2 -1
13 | 4 U 0 -1
14 | 4 D 8 -1
15 | 4 R 5 -1
16 | 4 L 4 -1
17 | 5 U 1 -1
18 | 5 D 9 -1
19 | 5 R 6 -1
20 | 5 L 4 -1
21 | 6 U 2 -1
22 | 6 D 10 -1
23 | 6 R 7 -1
24 | 6 L 5 -1
25 | 7 U 3 -1
26 | 7 D 11 -1
27 | 7 R 7 -1
28 | 7 L 6 -1
29 | 8 U 4 -1
30 | 8 D 12 -1
31 | 8 R 9 -1
32 | 8 L 8 -1
33 | 9 U 5 -1
34 | 9 D 13 -1
35 | 9 R 10 -1
36 | 9 L 8 -1
37 | 10 U 6 -1
38 | 10 D 14 -1
39 | 10 R 11 -1
40 | 10 L 9 -1
41 | 11 U 7 -1
42 | 11 D 0 -1
43 | 11 R 11 -1
44 | 11 L 10 -1
45 | 12 U 8 -1
46 | 12 D 12 -1
47 | 12 R 13 -1
48 | 12 L 12 -1
49 | 13 U 9 -1
50 | 13 D 13 -1
51 | 13 R 14 -1
52 | 13 L 12 -1
53 | 14 U 10 -1
54 | 14 D 14 -1
55 | 14 R 0 -1
56 | 14 L 13 -1
--------------------------------------------------------------------------------
/source code/2-Monte Carlo Methods (Chapter 5)/.ipynb_checkpoints/3-on-policy-monte-carlo-checkpoint.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# On-policy Monte Carlo\n",
8 | "- Algorithms from ```pp. 82 - 84``` in Sutton & Barto 2017\n",
9 | "- Estimates optimal policy $\\pi \\approx \\pi_*$"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": 34,
15 | "metadata": {
16 | "collapsed": true
17 | },
18 | "outputs": [],
19 | "source": [
20 | "import matplotlib.pyplot as plt\n",
21 | "import pandas as pd\n",
22 | "import numpy as np\n",
23 | "import seaborn\n",
24 | "\n",
25 | "from gridWorldEnvironment import GridWorld"
26 | ]
27 | },
28 | {
29 | "cell_type": "code",
30 | "execution_count": 35,
31 | "metadata": {
32 | "collapsed": true
33 | },
34 | "outputs": [],
35 | "source": [
36 | "# creating gridworld environment\n",
37 | "gw = GridWorld(gamma = .9)"
38 | ]
39 | },
40 | {
41 | "cell_type": "markdown",
42 | "metadata": {},
43 | "source": [
44 | "### Generate policy\n",
45 | "- We start with an equiprobable random policy"
46 | ]
47 | },
48 | {
49 | "cell_type": "code",
50 | "execution_count": 36,
51 | "metadata": {
52 | "collapsed": true
53 | },
54 | "outputs": [],
55 | "source": [
56 | "def generate_random_policy(env):\n",
57 | " pi = dict()\n",
58 | " for state in env.states:\n",
59 | " actions = []\n",
60 | " prob = []\n",
61 | " for action in env.actions:\n",
62 | " actions.append(action)\n",
63 | " prob.append(0.25)\n",
64 | " pi[state] = (actions, prob)\n",
65 | " return pi"
66 | ]
67 | },
68 | {
69 | "cell_type": "markdown",
70 | "metadata": {},
71 | "source": [
72 | "### Create Action Values\n",
73 | "- Initialize all state-action values with 0"
74 | ]
75 | },
76 | {
77 | "cell_type": "code",
78 | "execution_count": 37,
79 | "metadata": {
80 | "collapsed": true
81 | },
82 | "outputs": [],
83 | "source": [
84 | "def state_action_value(env):\n",
85 | " q = dict()\n",
86 | " for state, action, next_state, reward in env.transitions:\n",
87 | " q[(state, action)] = 0\n",
88 | " return q"
89 | ]
90 | },
91 | {
92 | "cell_type": "code",
93 | "execution_count": 38,
94 | "metadata": {},
95 | "outputs": [
96 | {
97 | "name": "stdout",
98 | "output_type": "stream",
99 | "text": [
100 | "0\n",
101 | "0\n",
102 | "0\n"
103 | ]
104 | }
105 | ],
106 | "source": [
107 | "q = state_action_value(gw)\n",
108 | "\n",
109 | "print(q[(2, \"U\")])\n",
110 | "print(q[(4, \"L\")])\n",
111 | "print(q[(10, \"R\")])"
112 | ]
113 | },
114 | {
115 | "cell_type": "code",
116 | "execution_count": 39,
117 | "metadata": {
118 | "collapsed": true,
119 | "scrolled": true
120 | },
121 | "outputs": [],
122 | "source": [
123 | "pi = generate_random_policy(gw)"
124 | ]
125 | },
126 | {
127 | "cell_type": "markdown",
128 | "metadata": {},
129 | "source": [
130 | "### Generate episode\n",
131 | "- Generate episode based on current policy ($\\pi$)"
132 | ]
133 | },
134 | {
135 | "cell_type": "code",
136 | "execution_count": 40,
137 | "metadata": {
138 | "collapsed": true,
139 | "scrolled": true
140 | },
141 | "outputs": [],
142 | "source": [
143 | "def generate_episode(env, policy):\n",
144 | " episode = []\n",
145 | " done = False\n",
146 | " current_state = np.random.choice(env.states)\n",
147 | " action = np.random.choice(policy[current_state][0], p = policy[current_state][1])\n",
148 | " episode.append((current_state, action, -1))\n",
149 | " \n",
150 | "    while not done:\n",
151 | "        next_state, reward = env.state_transition(current_state, action)\n",
152 | "        if next_state == 0:\n",
153 | "            done = True\n",
154 | "        else:\n",
155 | "            # draw the next action from the policy of the state just entered\n",
156 | "            action = np.random.choice(policy[next_state][0], p = policy[next_state][1])\n",
157 | "        episode.append((next_state, action, reward))\n",
158 | "        current_state = next_state\n",
159 | " return episode"
160 | ]
161 | },
162 | {
163 | "cell_type": "code",
164 | "execution_count": 41,
165 | "metadata": {},
166 | "outputs": [
167 | {
168 | "data": {
169 | "text/plain": [
170 | "[(6, 'U', -1),\n",
171 | " (2, 'D', -1),\n",
172 | " (6, 'R', -1),\n",
173 | " (7, 'U', -1),\n",
174 | " (3, 'D', -1),\n",
175 | " (7, 'D', -1),\n",
176 | " (11, 'U', -1),\n",
177 | " (7, 'U', -1),\n",
178 | " (3, 'L', -1),\n",
179 | " (2, 'D', -1),\n",
180 | " (6, 'D', -1),\n",
181 | " (10, 'D', -1),\n",
182 | " (14, 'R', -1),\n",
183 | " (0, 'R', -1)]"
184 | ]
185 | },
186 | "execution_count": 41,
187 | "metadata": {},
188 | "output_type": "execute_result"
189 | }
190 | ],
191 | "source": [
192 | "generate_episode(gw, pi)"
193 | ]
194 | },
195 | {
196 | "cell_type": "markdown",
197 | "metadata": {},
198 | "source": [
199 | "### On-policy Monte Carlo Control "
200 | ]
201 | },
202 | {
203 | "cell_type": "code",
204 | "execution_count": 42,
205 | "metadata": {
206 | "collapsed": true
207 | },
208 | "outputs": [],
209 | "source": [
210 | "# first-visit MC\n",
211 | "def on_policy_mc(env, pi, e, num_iter):\n",
212 | " Q = state_action_value(env)\n",
213 | " returns = dict()\n",
214 | " for s, a in Q:\n",
215 | " returns[(s,a)] = []\n",
216 | " \n",
217 | " for i in range(num_iter):\n",
218 | " episode = generate_episode(env, pi)\n",
219 | " already_visited = set({0})\n",
220 | " for s, a, r in episode:\n",
221 | " if s not in already_visited:\n",
222 | " already_visited.add(s)\n",
223 | " idx = episode.index((s, a, r))\n",
224 | " G = 0\n",
225 | " j = 1\n",
226 | " while j + idx < len(episode):\n",
227 | "                        G += env.gamma ** (j - 1) * episode[j + idx][-1]  # reward j steps ahead, discounted by gamma^(j-1)\n",
228 | " j += 1\n",
229 | " returns[(s,a)].append(G)\n",
230 | " Q[(s,a)] = np.mean(returns[(s,a)])\n",
231 | " for s, _, _ in episode:\n",
232 | " if s != 0:\n",
233 | " actions = []\n",
234 | " action_values = []\n",
235 | " prob = []\n",
236 | "\n",
237 | " for a in env.actions:\n",
238 | " actions.append(a)\n",
239 | " action_values.append(Q[s,a]) \n",
240 | " for i in range(len(action_values)):\n",
241 | " if i == np.argmax(action_values):\n",
242 | " prob.append(1 - e + e/len(actions))\n",
243 | " else:\n",
244 | " prob.append(e/len(actions)) \n",
245 | " pi[s] = (actions, prob)\n",
246 | " return Q, pi"
247 | ]
248 | },
249 | {
250 | "cell_type": "code",
251 | "execution_count": 43,
252 | "metadata": {},
253 | "outputs": [
254 | {
255 | "name": "stdout",
256 | "output_type": "stream",
257 | "text": [
258 | "Wall time: 1min 8s\n"
259 | ]
260 | }
261 | ],
262 | "source": [
263 | "%%time\n# obtained estimates for Q & pi after 100,000 iterations\n",
264 | "Q_hat, pi_hat = on_policy_mc(gw, generate_random_policy(gw), 0.2, 100000)"
265 | ]
266 | },
267 | {
268 | "cell_type": "code",
269 | "execution_count": 44,
270 | "metadata": {},
271 | "outputs": [
272 | {
273 | "data": {
274 | "text/plain": [
275 | "{1: (['U', 'D', 'L', 'R'], [0.05, 0.05, 0.8500000000000001, 0.05]),\n",
276 | " 2: (['U', 'D', 'L', 'R'], [0.05, 0.05, 0.8500000000000001, 0.05]),\n",
277 | " 3: (['U', 'D', 'L', 'R'], [0.05, 0.8500000000000001, 0.05, 0.05]),\n",
278 | " 4: (['U', 'D', 'L', 'R'], [0.8500000000000001, 0.05, 0.05, 0.05]),\n",
279 | " 5: (['U', 'D', 'L', 'R'], [0.8500000000000001, 0.05, 0.05, 0.05]),\n",
280 | " 6: (['U', 'D', 'L', 'R'], [0.05, 0.05, 0.8500000000000001, 0.05]),\n",
281 | " 7: (['U', 'D', 'L', 'R'], [0.05, 0.8500000000000001, 0.05, 0.05]),\n",
282 | " 8: (['U', 'D', 'L', 'R'], [0.8500000000000001, 0.05, 0.05, 0.05]),\n",
283 | " 9: (['U', 'D', 'L', 'R'], [0.05, 0.05, 0.05, 0.8500000000000001]),\n",
284 | " 10: (['U', 'D', 'L', 'R'], [0.05, 0.05, 0.05, 0.8500000000000001]),\n",
285 | " 11: (['U', 'D', 'L', 'R'], [0.05, 0.8500000000000001, 0.05, 0.05]),\n",
286 | " 12: (['U', 'D', 'L', 'R'], [0.05, 0.05, 0.05, 0.8500000000000001]),\n",
287 | " 13: (['U', 'D', 'L', 'R'], [0.05, 0.05, 0.05, 0.8500000000000001]),\n",
288 | " 14: (['U', 'D', 'L', 'R'], [0.05, 0.05, 0.05, 0.8500000000000001])}"
289 | ]
290 | },
291 | "execution_count": 44,
292 | "metadata": {},
293 | "output_type": "execute_result"
294 | }
295 | ],
296 | "source": [
297 | "# final policy obtained\n",
298 | "pi_hat"
299 | ]
300 | },
301 | {
302 | "cell_type": "markdown",
303 | "metadata": {},
304 | "source": [
305 | "### Visualizing policy"
306 | ]
307 | },
308 | {
309 | "cell_type": "code",
310 | "execution_count": 45,
311 | "metadata": {
312 | "collapsed": true
313 | },
314 | "outputs": [],
315 | "source": [
316 | "def show_policy(pi, env):\n",
317 | " temp = np.zeros(len(env.states) + 2)\n",
318 | " for s in env.states:\n",
319 | "        a = pi[s][0][np.argmax(pi[s][1])]\n",
320 | " if a == \"U\":\n",
321 | " temp[s] = 0.25\n",
322 | " elif a == \"D\":\n",
323 | " temp[s] = 0.5\n",
324 | " elif a == \"R\":\n",
325 | " temp[s] = 0.75\n",
326 | " else:\n",
327 | " temp[s] = 1.0\n",
328 | " \n",
329 | " temp = temp.reshape(4,4)\n",
330 | " ax = seaborn.heatmap(temp, cmap = \"prism\", linecolor=\"#282828\", cbar = False, linewidths = 0.1)\n",
331 | " plt.show()"
332 | ]
333 | },
334 | {
335 | "cell_type": "code",
336 | "execution_count": 46,
337 | "metadata": {},
338 | "outputs": [
339 | {
340 | "data": {
341 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAdwAAAFJCAYAAAAxCJwFAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAACsRJREFUeJzt3UtoXAX/xvFfmhEV4wUsSBGEbpoMvNDiwl0XgcZLQUKl\nYKtEQbJy0RRqkZh6wUtbV4KXULUrF75xZXEhFGuLgoKLQguBiSKCeNu4krbYtM68C6F//NPkNXmZ\nZ8z089nlHHLmoYfy7TkkdKDT6XQKAOiqdb0eAADXAsEFgADBBYAAwQWAAMEFgADBBYCARjcv3mw2\nq7Ww0M2PoEuaIyP1r5Z7t1bNN92/tWq+OVJVVa0P3b+1qLljpFqt1lXPecIFgADBBYAAwQWAAMEF\ngADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWA\nAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAA\nwQWAAMEFgIC/Hdx2u93NHQDQ1xrLnfzhhx/q0KFDNT8/X41Go9rtdm3atKmmp6dr48aNqY0AsOYt\nG9yZmZnat29fbd68+cqxM2fO1PT0dM3NzXV9HAD0i2VfKS8uLv4ltlVVW7Zs6eogAOhHyz7hDg8P\n1/T0dG3durVuvvnmOn/+fH322Wc1PDyc2gcAfWHZ4L7wwgt14sSJOn36dJ07d66GhoZqdHS0xsbG\nUvsAoC8sG9yBgYEaGxsTWAD4H/k9XAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQ\nXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBc\nAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBjodDqdbl282Wx269IA8I/UarWuerzR\n7Q9e2Hy02x9BF4ycnXTv1rCRs5P1r9ZCr2ewCvPNkaqqan3o/q1FzR0jS57zShkAAgQXAAIEFwAC\nBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIE\nFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQX\nAAIEFwACBBcAAgQXAAIEFwACGsudnJiYqEuXLv3lWKfTqYGBgZqbm+vqMADoJ8sG96mnnqoDBw7U\nW2+9VYODg6lNANB3lg3u5s2ba3x8vL7++usaGxtLbQKAvrNscKuqJicnEzsAoK/5oSkACBBcAAgQ\nXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBc\nAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwA\nCBBcAAgQXAAIGOh0Op1uXbzZbHbr0gDwj9Rqta56vNHtD17YfLTbH0EXjJydrJmFf/d6Bqv0yshu\n92+NemVkd1VVtT5c6PESVqO5Y2TJc14pA0CA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOAC\nQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJA\ngOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQMCKg7u4uNiNHQDQ15YM7smT\nJ2t0dLTGxsbq448/vnJ8cnIyMgwA+kljqRNHjhypY8eOVbvdrqmpqbp48WLt2LGjOp1Och8A9IUl\ng3vdddfVrbfeWlVVs7Oz9fjjj9eGDR
tqYGAgNg4A+sWSr5TvvPPOOnToUF24cKGGhobqzTffrBdf\nfLG+++675D4A6AtLBvfgwYM1PDx85Yl2w4YN9d5779UDDzwQGwcA/WLJV8qNRqMeeuihvxxbv359\nzczMdH0UAPQbv4cLAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsA\nAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwAB\nggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAQOdTqfTrYs3m81uXRoA/pFardZVjze6/cEzC//u\n9kfQBa+M7Hbv1jD3b+16ZWR3VVW1FhZ6vITVaI6MLHnOK2UACBBcAAgQXAAIEFwACBBcAAgQXAAI\nEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQ\nXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIWFFwf//9\n91pcXOzWFgDoW8sG99tvv60nn3yypqen68svv6zt27fX9u3b69SpU6l9ANAXGsudfP7552tqaqp+\n+umn2rNnTx0/fryuv/76mpycrNHR0dRGAFjzlg1uu92ue+65p6qqvvrqq7r99tv//KbGst8GAPw/\ny75S3rhxY83MzFS73a7Dhw9XVdU777xT69evj4wDgH6x7KPqyy+/XCdPnqx16/6vy3fccUdNTEx0\nfRgA9JNlg7tu3bratm3bX46Nj493dRAA9CO/hwsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYIL\nAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsA\nAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAEDnU6n062LN5vNbl0aAP6R\nWq3WVY93NbgAwJ+8UgaAAMEFgADBBYAAwQWAAMEFgADBBYAAwV2Fdrtdzz33XD388MM1MTFR33//\nfa8nsUJnz56tiYmJXs9ghS5dulT79++vRx55pHbu3FmffvppryfxN/3xxx81PT1du3btqt27d9c3\n33zT60lxgrsKJ06cqMXFxfrggw9q3759dfjw4V5PYgXefffdOnDgQF28eLHXU1ihjz76qG677bZ6\n//336+jRo/XSSy/1ehJ/06lTp6qqam5urvbu3VuvvfZajxflCe4qnD59urZu3VpVVVu2bKn5+fke\nL2Il7rrrrnrjjTd6PYNVuP/++2tqaqqqqjqdTg0ODvZ4EX/Xtm3brvwD6eeff65bbrmlx4vyGr0e\nsBadO3euhoaGrnw9ODhYly9frkbDH+dacN9999WPP/7Y6xmswk033VRVf/4d3LNnT+3du7fHi1iJ\nRqNRTz/9dH3yySf1+uuv93pOnCfcVRgaGqrz589f+brdbosthPzyyy/12GOP1fj4eD344IO9nsMK\nvfrqq3X8+PF69tln68KFC72eEyW4q3D33XfX559/XlVVZ86cqU2bNvV4EVwbfv3113riiSdq//79\ntXPnzl7PYQWOHTtWb7/9dlVV3XjjjTUwMFDr1l1bCfJYtgpjY2P1xRdf1K5du6rT6dTBgwd7PQmu\nCUeOHKnffvutZmdna3Z2tqr+/CG4G264ocfL+G/uvffemp6erkcffbQuX75czzzzzDV33/xvQQAQ\ncG09zwNAjwguAAQILgAECC4ABAguAAQILgAECC4ABAguAAT8B0u/pa2vpAxwAAAAAElFTkSuQmCC\n",
342 | "text/plain": [
343 | ""
344 | ]
345 | },
346 | "metadata": {},
347 | "output_type": "display_data"
348 | }
349 | ],
350 | "source": [
351 | "### RED = TERMINAL (0)\n",
352 | "### GREEN = LEFT\n",
353 | "### BLUE = UP\n",
354 | "### PURPLE = RIGHT\n",
355 | "### ORANGE = DOWN\n",
356 | "\n",
357 | "show_policy(pi_hat, gw)"
358 | ]
359 | }
360 | ],
361 | "metadata": {
362 | "kernelspec": {
363 | "display_name": "Python 3",
364 | "language": "python",
365 | "name": "python3"
366 | },
367 | "language_info": {
368 | "codemirror_mode": {
369 | "name": "ipython",
370 | "version": 3
371 | },
372 | "file_extension": ".py",
373 | "mimetype": "text/x-python",
374 | "name": "python",
375 | "nbconvert_exporter": "python",
376 | "pygments_lexer": "ipython3",
377 | "version": "3.6.1"
378 | }
379 | },
380 | "nbformat": 4,
381 | "nbformat_minor": 2
382 | }
383 |
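The 0.85 / 0.05 split in `pi_hat` above falls straight out of the ε-greedy improvement step with e = 0.2: the greedy action receives 1 - e + e/|A| and every other action e/|A|. A standalone sketch of that step, assuming a plain list of action values (`epsilon_greedy_probs` is a hypothetical helper, not defined in the notebook):

```python
def epsilon_greedy_probs(q_values, e=0.2):
    # the greedy action keeps most of the probability mass;
    # the remainder e is spread uniformly over all actions
    n = len(q_values)
    probs = [e / n] * n
    greedy = q_values.index(max(q_values))  # np.argmax in the notebook
    probs[greedy] = 1 - e + e / n
    return probs

probs = epsilon_greedy_probs([-3.0, -1.2, -7.5, -2.9])
print(probs)  # greedy action (index 1) gets ~0.85, the rest 0.05 each
```

With four actions and e = 0.2 this gives e/|A| = 0.05 and 1 - e + e/|A| ≈ 0.85, hence the `0.8500000000000001` values (floating-point) in the notebook output above.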
--------------------------------------------------------------------------------
/source code/2-Monte Carlo Methods (Chapter 5)/.ipynb_checkpoints/4-off-policy-monte-carlo-checkpoint.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Off-policy Monte Carlo\n",
8 | "- Algorithms from ```pp. 84 - 92``` in Sutton & Barto 2017\n",
9 | " - Off-policy prediction via importance sampling\n",
10 | " - Off-policy MC control"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": 1,
16 | "metadata": {
17 | "collapsed": true
18 | },
19 | "outputs": [],
20 | "source": [
21 | "import matplotlib.pyplot as plt\n",
22 | "import pandas as pd\n",
23 | "import numpy as np\n",
24 | "import seaborn\n",
25 | "\n",
26 | "from gridWorldEnvironment import GridWorld"
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "execution_count": 2,
32 | "metadata": {
33 | "collapsed": true
34 | },
35 | "outputs": [],
36 | "source": [
37 | "# creating gridworld environment\n",
38 | "gw = GridWorld(gamma = .9)"
39 | ]
40 | },
41 | {
42 | "cell_type": "markdown",
43 | "metadata": {},
44 | "source": [
45 | "### Functions for generating policies"
46 | ]
47 | },
48 | {
49 | "cell_type": "code",
50 | "execution_count": 3,
51 | "metadata": {
52 | "collapsed": true
53 | },
54 | "outputs": [],
55 | "source": [
56 | "def generate_any_policy(env):\n",
57 | " pi = dict()\n",
58 | " for state in env.states:\n",
59 | " r = sorted(np.random.sample(3))\n",
60 | " actions = env.actions\n",
61 | " prob = [r[0], r[1] - r[0], r[2] - r[1], 1-r[2]]\n",
62 | " pi[state] = (actions, prob)\n",
63 | " return pi "
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "execution_count": 4,
69 | "metadata": {
70 | "collapsed": true
71 | },
72 | "outputs": [],
73 | "source": [
74 | "def generate_random_policy(env):\n",
75 | " pi = dict()\n",
76 | " for state in env.states:\n",
77 | " actions = []\n",
78 | " prob = []\n",
79 | " for action in env.actions:\n",
80 | " actions.append(action)\n",
81 | " prob.append(0.25)\n",
82 | " pi[state] = (actions, prob)\n",
83 | " return pi"
84 | ]
85 | },
86 | {
87 | "cell_type": "code",
88 | "execution_count": 5,
89 | "metadata": {
90 | "collapsed": true
91 | },
92 | "outputs": [],
93 | "source": [
94 | "def generate_greedy_policy(env, Q):\n",
95 | " pi = dict()\n",
96 | " for state in env.states:\n",
97 | " actions = []\n",
98 | " q_values = []\n",
99 | " prob = []\n",
100 | " \n",
101 | " for a in env.actions:\n",
102 | " actions.append(a)\n",
103 | " q_values.append(Q[state,a]) \n",
104 | " for i in range(len(q_values)):\n",
105 | " if i == np.argmax(q_values):\n",
106 | " prob.append(1)\n",
107 | " else:\n",
108 | " prob.append(0) \n",
109 | " \n",
110 | " pi[state] = (actions, prob)\n",
111 | " return pi"
112 | ]
113 | },
114 | {
115 | "cell_type": "markdown",
116 | "metadata": {},
117 | "source": [
118 | "### Create Action Values\n",
119 | "- Initialize all state-action values with 0"
120 | ]
121 | },
122 | {
123 | "cell_type": "code",
124 | "execution_count": 6,
125 | "metadata": {
126 | "collapsed": true
127 | },
128 | "outputs": [],
129 | "source": [
130 | "def state_action_value(env):\n",
131 | " q = dict()\n",
132 | " for state, action, next_state, reward in env.transitions:\n",
133 | " q[(state, action)] = np.random.random()\n",
134 | " return q"
135 | ]
136 | },
137 | {
138 | "cell_type": "code",
139 | "execution_count": 7,
140 | "metadata": {
141 | "collapsed": true
142 | },
143 | "outputs": [],
144 | "source": [
145 | "def weight_cum_sum(env):\n",
146 | " c = dict()\n",
147 | " for state, action, next_state, reward in env.transitions:\n",
148 | " c[(state, action)] = 0\n",
149 | " return c"
150 | ]
151 | },
152 | {
153 | "cell_type": "markdown",
154 | "metadata": {},
155 | "source": [
156 | "### Generate episode\n",
157 | "- Generate episode based on current policy ($\\pi$)"
158 | ]
159 | },
160 | {
161 | "cell_type": "code",
162 | "execution_count": 8,
163 | "metadata": {
164 | "collapsed": true,
165 | "scrolled": true
166 | },
167 | "outputs": [],
168 | "source": [
169 | "def generate_episode(env, policy):\n",
170 | " episode = []\n",
171 | " done = False\n",
172 | " current_state = np.random.choice(env.states)\n",
173 | " action = np.random.choice(policy[current_state][0], p = policy[current_state][1])\n",
174 | " episode.append((current_state, action, -1))\n",
175 | " \n",
176 | "    while not done:\n",
177 | "        next_state, reward = env.state_transition(current_state, action)\n",
178 | "        if next_state == 0:\n",
179 | "            done = True\n",
180 | "        else:\n",
181 | "            # draw the next action from the policy of the state just entered\n",
182 | "            action = np.random.choice(policy[next_state][0], p = policy[next_state][1])\n",
183 | "        episode.append((next_state, action, reward))\n",
184 | "        current_state = next_state\n",
185 | " return episode"
186 | ]
187 | },
188 | {
189 | "cell_type": "code",
190 | "execution_count": 9,
191 | "metadata": {
192 | "collapsed": true,
193 | "scrolled": true
194 | },
195 | "outputs": [],
196 | "source": [
197 | "pi = generate_random_policy(gw)"
198 | ]
199 | },
200 | {
201 | "cell_type": "markdown",
202 | "metadata": {},
203 | "source": [
204 | "### Off-policy MC Prediction\n",
205 | "- Estimates $Q$ values"
206 | ]
207 | },
208 | {
209 | "cell_type": "code",
210 | "execution_count": 10,
211 | "metadata": {
212 | "collapsed": true
213 | },
214 | "outputs": [],
215 | "source": [
216 | "def off_policy_mc_prediction(env, pi, num_iter):\n",
217 | " Q = state_action_value(env)\n",
218 | " C = weight_cum_sum(env)\n",
219 | " \n",
220 | " for _ in range(num_iter):\n",
221 | " b = generate_any_policy(env)\n",
222 | "        episode = generate_episode(env, b)\n",
223 | " G = 0\n",
224 | " W = 1\n",
225 | " for i in range(len(episode)-1, -1, -1):\n",
226 | " s, a, r = episode[i]\n",
227 | " if s != 0:\n",
228 | " G = env.gamma * G + r\n",
229 | " C[s,a] += W\n",
230 | " Q[s,a] += (W / C[s,a]) * (G - Q[s,a])\n",
231 | " W *= pi[s][1][pi[s][0].index(a)] / b[s][1][b[s][0].index(a)]\n",
232 | " if W == 0:\n",
233 | " break\n",
234 | " \n",
235 | " return Q"
236 | ]
237 | },
238 | {
239 | "cell_type": "markdown",
240 | "metadata": {},
241 | "source": [
242 | "### Off-policy MC Control\n",
243 | "- Finds the optimal policy $\\pi \\approx \\pi_*$"
244 | ]
245 | },
246 | {
247 | "cell_type": "code",
248 | "execution_count": 12,
249 | "metadata": {
250 | "collapsed": true
251 | },
252 | "outputs": [],
253 | "source": [
254 | "def off_policy_mc_control(env, pi, num_iter):\n",
255 | " Q = state_action_value(env)\n",
256 | " C = weight_cum_sum(env)\n",
257 | " pi = generate_greedy_policy(env, Q)\n",
258 | " \n",
259 | " for _ in range(num_iter):\n",
260 | " b = generate_any_policy(env)\n",
261 | "        episode = generate_episode(env, b)\n",
262 | " G = 0\n",
263 | " W = 1\n",
264 | " for i in range(len(episode)-1, -1, -1):\n",
265 | " s, a, r = episode[i]\n",
266 | " if s != 0:\n",
267 | " G = env.gamma * G + r\n",
268 | " C[s,a] += W\n",
269 | " Q[s,a] += (W / C[s,a]) * (G - Q[s,a])\n",
270 | " pi = generate_greedy_policy(env, Q)\n",
271 | " if a == pi[s][0][np.argmax(pi[s][1])]:\n",
272 | " break\n",
273 | " W *= 1 / b[s][1][b[s][0].index(a)]\n",
274 | "\n",
275 | " return Q, pi"
276 | ]
277 | },
278 | {
279 | "cell_type": "code",
280 | "execution_count": 21,
281 | "metadata": {},
282 | "outputs": [
283 | {
284 | "name": "stdout",
285 | "output_type": "stream",
286 | "text": [
287 | "Wall time: 1.51 s\n"
288 | ]
289 | }
290 | ],
291 | "source": [
292 | "%%time\n",
293 | "Q_hat, pi_hat = off_policy_mc_control(gw, generate_random_policy(gw), 1000)"
294 | ]
295 | },
296 | {
297 | "cell_type": "code",
298 | "execution_count": 22,
299 | "metadata": {},
300 | "outputs": [
301 | {
302 | "data": {
303 | "text/plain": [
304 | "{1: (['U', 'D', 'L', 'R'], [0, 0, 1, 0]),\n",
305 | " 2: (['U', 'D', 'L', 'R'], [0, 1, 0, 0]),\n",
306 | " 3: (['U', 'D', 'L', 'R'], [0, 1, 0, 0]),\n",
307 | " 4: (['U', 'D', 'L', 'R'], [1, 0, 0, 0]),\n",
308 | " 5: (['U', 'D', 'L', 'R'], [1, 0, 0, 0]),\n",
309 | " 6: (['U', 'D', 'L', 'R'], [0, 0, 0, 1]),\n",
310 | " 7: (['U', 'D', 'L', 'R'], [0, 0, 0, 1]),\n",
311 | " 8: (['U', 'D', 'L', 'R'], [0, 0, 0, 1]),\n",
312 | " 9: (['U', 'D', 'L', 'R'], [0, 0, 1, 0]),\n",
313 | " 10: (['U', 'D', 'L', 'R'], [0, 0, 0, 1]),\n",
314 | " 11: (['U', 'D', 'L', 'R'], [0, 1, 0, 0]),\n",
315 | " 12: (['U', 'D', 'L', 'R'], [0, 0, 1, 0]),\n",
316 | " 13: (['U', 'D', 'L', 'R'], [0, 0, 1, 0]),\n",
317 | " 14: (['U', 'D', 'L', 'R'], [0, 0, 0, 1])}"
318 | ]
319 | },
320 | "execution_count": 22,
321 | "metadata": {},
322 | "output_type": "execute_result"
323 | }
324 | ],
325 | "source": [
326 | "# final policy obtained\n",
327 | "pi_hat"
328 | ]
329 | },
330 | {
331 | "cell_type": "markdown",
332 | "metadata": {},
333 | "source": [
334 | "### Visualizing policy"
335 | ]
336 | },
337 | {
338 | "cell_type": "code",
339 | "execution_count": 23,
340 | "metadata": {
341 | "collapsed": true
342 | },
343 | "outputs": [],
344 | "source": [
345 | "def show_policy(pi, env):\n",
346 | " temp = np.zeros(len(env.states) + 2)\n",
347 | " for s in env.states:\n",
348 | "        a = pi[s][0][np.argmax(pi[s][1])]\n",
349 | " if a == \"U\":\n",
350 | " temp[s] = 0.25\n",
351 | " elif a == \"D\":\n",
352 | " temp[s] = 0.5\n",
353 | " elif a == \"R\":\n",
354 | " temp[s] = 0.75\n",
355 | " else:\n",
356 | " temp[s] = 1.0\n",
357 | " \n",
358 | " temp = temp.reshape(4,4)\n",
359 | " ax = seaborn.heatmap(temp, cmap = \"prism\", linecolor=\"#282828\", cbar = False, linewidths = 0.1)\n",
360 | " plt.show()"
361 | ]
362 | },
363 | {
364 | "cell_type": "code",
365 | "execution_count": 24,
366 | "metadata": {},
367 | "outputs": [
368 | {
369 | "data": {
370 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAdwAAAFJCAYAAAAxCJwFAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAACsZJREFUeJzt3UtoXIXfxvFfmhEV4wUsSBGEbtoMCC0u3HURaLwUJFQK\ntkoUJCsXTaEWGVMveGnrSvAS6mXlQuPK4kIo1hYFBReFFgITRQTxtnElbbFpnfkv5N8XX9q8Ji/z\njJl+PrucQ8489DB8O4eEDHW73W4BAD21pt8DAOBqILgAECC4ABAguAAQILgAECC4ABDQ6OXFm81m\ntRcWevkS9EhzdLTubLt3q9V8c7TaH7l/q1Fz+2hVlfu3SjW3j1a73b7sOZ9wASBAcAEgQHABIEBw\nASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHAB\nIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEg\nQHABIEBwASDgHwe30+n0cgcADLTGUid//PHHOnjwYM3Pz1ej0ahOp1MbNmyoVqtV69evT20EgFVv\nyeDOzMzU3r17a9OmTZeOnTp1qlqtVs3NzfV8HAAMiiUfKS8uLv4ttlVVmzdv7ukgABhES37C3bhx\nY7VardqyZUvdeOONdfbs2fr8889r48aNqX0AMBCWDO7zzz9fx44dq5MnT9aZM2dqZGSkxsbGanx8\nPLUPAAbCksEdGhqq8fFxgQWA/ye/hwsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwAB\nggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGC\nCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAFD3W6326uLN5vNXl0aAP6V2u32ZY83\nev3CC5ve7fVL0AOjp6fcu1Vs9PRUzSx80O8ZrMDLo7uqqty/Veq/9+9yPFIGgADBBYAAwQWAAMEF\ngADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWA\nAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAA\nwQWAAMEFgADBBYAAwQWAgMZSJycnJ+vChQt/O9btdmtoaKjm5uZ6OgwABsmSwX3yySdr//799eab\nb9bw8HBqEwAMnCWDu2nTppqYmKhvvvmmxsfHU5sAYOAsGdyqqqmpqcQOABhofmgKAAIEFwACBBcA\nAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwAC\nBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIE\nFwACBBcAAoa63W63VxdvNpu9ujQA/Cu12+3LHm/0+oVnFj7o9UvQAy+P7qo72wv9nsEKzTdHvfdW\nqZdHd1VVVfsj77/VqLl99IrnPFIGgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEF\ngADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWA\nAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAgGUHd3FxsRc7AGCgXTG4x48fr7Gx\nsRofH69PPvnk0vGpqanIMAAYJI0rnTh8+HAdOXKkOp1OTU9P1/nz52v79u3V7XaT+wBgIFwxuNdc\nc03dfPPNVVU1Oztbjz32WK1bt66Gho
Zi4wBgUFzxkfLtt99eBw8erHPnztXIyEi98cYb9cILL9T3\n33+f3AcAA+GKwT1w4EBt3Ljx0ifadevW1XvvvVf3339/bBwADIorPlJuNBr14IMP/u3Y2rVra2Zm\npuejAGDQ+D1cAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBc\nAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwA\nCBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBjqdrvdXl282Wz26tIA8K/Ubrcve7zR6xe+s73Q65eg\nB+abo+7dKjbfHK2ZhQ/6PYMVeHl0V1VVtRe8/1aj5ujoFc95pAwAAYILAAGCCwABggsAAYILAAGC\nCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYIL\nAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABywru\nH3/8UYuLi73aAgADa8ngfvfdd/XEE09Uq9Wqr776qrZt21bbtm2rEydOpPYBwEBoLHXyueeeq+np\n6fr5559r9+7ddfTo0br22mtramqqxsbGUhsBYNVbMridTqfuvvvuqqr6+uuv69Zbb/3rmxpLfhsA\n8L8s+Uh5/fr1NTMzU51Opw4dOlRVVW+//XatXbs2Mg4ABsWSH1VfeumlOn78eK1Z8z9dvu2222py\ncrLnwwBgkCwZ3DVr1tTWrVv/dmxiYqKngwBgEPk9XAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAI\nEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQ\nXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBjqdrvdXl282Wz26tIA\n8K/Ubrcve7ynwQUA/uKRMgAECC4ABAguAAQILgAECC4ABAguAAQI7gp0Op169tln66GHHqrJycn6\n4Ycf+j2JZTp9+nRNTk72ewbLdOHChdq3b189/PDDtWPHjvrss8/6PYl/6M8//6xWq1U7d+6sXbt2\n1bffftvvSXGCuwLHjh2rxcXF+vDDD2vv3r116NChfk9iGd55553av39/nT9/vt9TWKaPP/64brnl\nlnr//ffr3XffrRdffLHfk/iHTpw4UVVVc3NztWfPnnr11Vf7vChPcFfg5MmTtWXLlqqq2rx5c83P\nz/d5Ectxxx131Ouvv97vGazAfffdV9PT01VV1e12a3h4uM+L+Ke2bt166T9Iv/zyS9100019XpTX\n6PeA1ejMmTM1MjJy6evh4eG6ePFiNRr+OVeDe++9t3766ad+z2AFbrjhhqr66z24e/fu2rNnT58X\nsRyNRqOeeuqp+vTTT+u1117r95w4n3BXYGRkpM6ePXvp606nI7YQ8uuvv9ajjz5aExMT9cADD/R7\nDsv0yiuv1NGjR+uZZ56pc+fO9XtOlOCuwF133VVffPFFVVWdOnWqNmzY0OdFcHX47bff6vHHH699\n+/bVjh07+j2HZThy5Ei99dZbVVV1/fXX19DQUK1Zc3UlyMeyFRgfH68vv/yydu7cWd1utw4cONDv\nSXBVOHz4cP3+++81Oztbs7OzVfXXD8Fdd911fV7G/+Wee+6pVqtVjzzySF28eLGefvrpq+6++WtB\nABBwdX2eB4A+EVwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACPgPqwelrdFfLd8AAAAASUVORK5C\nYII=\n",
371 | "text/plain": [
372 | ""
373 | ]
374 | },
375 | "metadata": {},
376 | "output_type": "display_data"
377 | }
378 | ],
379 | "source": [
380 | "### RED = TERMINAL (0)\n",
381 | "### GREEN = LEFT\n",
382 | "### BLUE = UP\n",
383 | "### PURPLE = RIGHT\n",
384 | "### ORANGE = DOWN\n",
385 | "\n",
386 | "show_policy(pi_hat, gw)"
387 | ]
388 | }
389 | ],
390 | "metadata": {
391 | "kernelspec": {
392 | "display_name": "Python 3",
393 | "language": "python",
394 | "name": "python3"
395 | },
396 | "language_info": {
397 | "codemirror_mode": {
398 | "name": "ipython",
399 | "version": 3
400 | },
401 | "file_extension": ".py",
402 | "mimetype": "text/x-python",
403 | "name": "python",
404 | "nbconvert_exporter": "python",
405 | "pygments_lexer": "ipython3",
406 | "version": "3.6.1"
407 | }
408 | },
409 | "nbformat": 4,
410 | "nbformat_minor": 2
411 | }
412 |
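The inner loop of `off_policy_mc_prediction` and `off_policy_mc_control` is the incremental weighted-importance-sampling update from Sutton & Barto: C accumulates the weights W, and Q moves toward the return G by a step of W/C, which reproduces the weighted average of returns exactly. A toy sketch of the recurrence (`weighted_is_update` is a hypothetical name, not the notebook's code):

```python
def weighted_is_update(samples):
    # samples: list of (G, W) pairs, i.e. returns with importance weights;
    # Q converges to the weighted mean sum(W*G)/sum(W) via the incremental rule
    Q, C = 0.0, 0.0
    for G, W in samples:
        C += W                    # C[s,a] += W in the notebook
        Q += (W / C) * (G - Q)    # Q[s,a] += (W/C[s,a]) * (G - Q[s,a])
    return Q

# incremental result equals the explicit weighted mean (10*1 + 20*3) / 4 = 17.5
print(weighted_is_update([(10.0, 1.0), (20.0, 3.0)]))  # 17.5
```

Because the behavior policy b has nonzero probability on every action while the target policy is greedy, W (and hence C) stays finite, and the control variant breaks out of the episode as soon as the behavior action disagrees with the greedy one.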
--------------------------------------------------------------------------------
/source code/2-Monte Carlo Methods (Chapter 5)/3-on-policy-monte-carlo.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# On-policy Monte Carlo\n",
8 | "- Algorithms from ```pp. 82 - 84``` in Sutton & Barto 2017\n",
9 | "- Estimates optimal policy $\\pi \\approx \\pi_*$"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": 34,
15 | "metadata": {
16 | "collapsed": true
17 | },
18 | "outputs": [],
19 | "source": [
20 | "import matplotlib.pyplot as plt\n",
21 | "import pandas as pd\n",
22 | "import numpy as np\n",
23 | "import seaborn\n",
24 | "\n",
25 | "from gridWorldEnvironment import GridWorld"
26 | ]
27 | },
28 | {
29 | "cell_type": "code",
30 | "execution_count": 35,
31 | "metadata": {
32 | "collapsed": true
33 | },
34 | "outputs": [],
35 | "source": [
36 | "# creating gridworld environment\n",
37 | "gw = GridWorld(gamma = .9)"
38 | ]
39 | },
40 | {
41 | "cell_type": "markdown",
42 | "metadata": {},
43 | "source": [
44 | "### Generate policy\n",
 45 |     "- We start with an equiprobable random policy"
46 | ]
47 | },
48 | {
49 | "cell_type": "code",
50 | "execution_count": 36,
51 | "metadata": {
52 | "collapsed": true
53 | },
54 | "outputs": [],
55 | "source": [
56 | "def generate_random_policy(env):\n",
57 | " pi = dict()\n",
58 | " for state in env.states:\n",
59 | " actions = []\n",
60 | " prob = []\n",
61 | " for action in env.actions:\n",
62 | " actions.append(action)\n",
63 | " prob.append(0.25)\n",
64 | " pi[state] = (actions, prob)\n",
65 | " return pi"
66 | ]
67 | },
68 | {
69 | "cell_type": "markdown",
70 | "metadata": {},
71 | "source": [
72 | "### Create Action Values\n",
73 | "- Initialize all state-action values with 0"
74 | ]
75 | },
76 | {
77 | "cell_type": "code",
78 | "execution_count": 37,
79 | "metadata": {
80 | "collapsed": true
81 | },
82 | "outputs": [],
83 | "source": [
84 | "def state_action_value(env):\n",
85 | " q = dict()\n",
86 | " for state, action, next_state, reward in env.transitions:\n",
87 | " q[(state, action)] = 0\n",
88 | " return q"
89 | ]
90 | },
91 | {
92 | "cell_type": "code",
93 | "execution_count": 38,
94 | "metadata": {},
95 | "outputs": [
96 | {
97 | "name": "stdout",
98 | "output_type": "stream",
99 | "text": [
100 | "0\n",
101 | "0\n",
102 | "0\n"
103 | ]
104 | }
105 | ],
106 | "source": [
107 | "q = state_action_value(gw)\n",
108 | "\n",
109 | "print(q[(2, \"U\")])\n",
110 | "print(q[(4, \"L\")])\n",
111 | "print(q[(10, \"R\")])"
112 | ]
113 | },
114 | {
115 | "cell_type": "code",
116 | "execution_count": 39,
117 | "metadata": {
118 | "collapsed": true,
119 | "scrolled": true
120 | },
121 | "outputs": [],
122 | "source": [
123 | "pi = generate_random_policy(gw)"
124 | ]
125 | },
126 | {
127 | "cell_type": "markdown",
128 | "metadata": {},
129 | "source": [
130 | "### Generate episode\n",
131 | "- Generate episode based on current policy ($\\pi$)"
132 | ]
133 | },
134 | {
135 | "cell_type": "code",
136 | "execution_count": 40,
137 | "metadata": {
138 | "collapsed": true,
139 | "scrolled": true
140 | },
141 | "outputs": [],
142 | "source": [
143 | "def generate_episode(env, policy):\n",
144 | " episode = []\n",
145 | " done = False\n",
146 | " current_state = np.random.choice(env.states)\n",
147 | " action = np.random.choice(policy[current_state][0], p = policy[current_state][1])\n",
148 | " episode.append((current_state, action, -1))\n",
149 | " \n",
150 |     "    while not done:\n",
151 |     "        next_state, reward = env.state_transition(current_state, action)\n",
152 |     "        if next_state == 0:\n",
153 |     "            episode.append((next_state, action, reward))\n",
154 |     "            done = True\n",
155 |     "        else:\n",
156 |     "            action = np.random.choice(policy[next_state][0], p = policy[next_state][1])\n",
157 |     "            episode.append((next_state, action, reward))\n",
158 |     "            current_state = next_state\n",
159 | " return episode"
160 | ]
161 | },
162 | {
163 | "cell_type": "code",
164 | "execution_count": 41,
165 | "metadata": {},
166 | "outputs": [
167 | {
168 | "data": {
169 | "text/plain": [
170 | "[(6, 'U', -1),\n",
171 | " (2, 'D', -1),\n",
172 | " (6, 'R', -1),\n",
173 | " (7, 'U', -1),\n",
174 | " (3, 'D', -1),\n",
175 | " (7, 'D', -1),\n",
176 | " (11, 'U', -1),\n",
177 | " (7, 'U', -1),\n",
178 | " (3, 'L', -1),\n",
179 | " (2, 'D', -1),\n",
180 | " (6, 'D', -1),\n",
181 | " (10, 'D', -1),\n",
182 | " (14, 'R', -1),\n",
183 | " (0, 'R', -1)]"
184 | ]
185 | },
186 | "execution_count": 41,
187 | "metadata": {},
188 | "output_type": "execute_result"
189 | }
190 | ],
191 | "source": [
192 | "generate_episode(gw, pi)"
193 | ]
194 | },
195 | {
196 | "cell_type": "markdown",
197 | "metadata": {},
198 | "source": [
199 | "### On-policy Monte Carlo Control "
200 | ]
201 | },
202 | {
203 | "cell_type": "code",
204 | "execution_count": 42,
205 | "metadata": {
206 | "collapsed": true
207 | },
208 | "outputs": [],
209 | "source": [
210 | "# first-visit MC\n",
211 | "def on_policy_mc(env, pi, e, num_iter):\n",
212 | " Q = state_action_value(env)\n",
213 | " returns = dict()\n",
214 | " for s, a in Q:\n",
215 | " returns[(s,a)] = []\n",
216 | " \n",
217 | " for i in range(num_iter):\n",
218 | " episode = generate_episode(env, pi)\n",
219 | " already_visited = set({0})\n",
220 | " for s, a, r in episode:\n",
221 | " if s not in already_visited:\n",
222 | " already_visited.add(s)\n",
223 | " idx = episode.index((s, a, r))\n",
224 | " G = 0\n",
225 | " j = 1\n",
226 | " while j + idx < len(episode):\n",
227 |     "                    G += env.gamma ** (j - 1) * episode[j + idx][-1]\n",
228 | " j += 1\n",
229 | " returns[(s,a)].append(G)\n",
230 | " Q[(s,a)] = np.mean(returns[(s,a)])\n",
231 | " for s, _, _ in episode:\n",
232 | " if s != 0:\n",
233 | " actions = []\n",
234 | " action_values = []\n",
235 | " prob = []\n",
236 | "\n",
237 | " for a in env.actions:\n",
238 | " actions.append(a)\n",
239 | " action_values.append(Q[s,a]) \n",
240 | " for i in range(len(action_values)):\n",
241 | " if i == np.argmax(action_values):\n",
242 | " prob.append(1 - e + e/len(actions))\n",
243 | " else:\n",
244 | " prob.append(e/len(actions)) \n",
245 | " pi[s] = (actions, prob)\n",
246 | " return Q, pi"
247 | ]
248 | },
249 | {
250 | "cell_type": "code",
251 | "execution_count": 43,
252 | "metadata": {},
253 | "outputs": [
254 | {
255 | "name": "stdout",
256 | "output_type": "stream",
257 | "text": [
258 | "Wall time: 1min 8s\n"
259 | ]
260 | }
261 | ],
262 | "source": [
263 |     "%%time\n",
264 |     "# estimates of Q and pi after 100,000 iterations\n",
265 |     "Q_hat, pi_hat = on_policy_mc(gw, generate_random_policy(gw), 0.2, 100000)"
265 | ]
266 | },
267 | {
268 | "cell_type": "code",
269 | "execution_count": 44,
270 | "metadata": {},
271 | "outputs": [
272 | {
273 | "data": {
274 | "text/plain": [
275 | "{1: (['U', 'D', 'L', 'R'], [0.05, 0.05, 0.8500000000000001, 0.05]),\n",
276 | " 2: (['U', 'D', 'L', 'R'], [0.05, 0.05, 0.8500000000000001, 0.05]),\n",
277 | " 3: (['U', 'D', 'L', 'R'], [0.05, 0.8500000000000001, 0.05, 0.05]),\n",
278 | " 4: (['U', 'D', 'L', 'R'], [0.8500000000000001, 0.05, 0.05, 0.05]),\n",
279 | " 5: (['U', 'D', 'L', 'R'], [0.8500000000000001, 0.05, 0.05, 0.05]),\n",
280 | " 6: (['U', 'D', 'L', 'R'], [0.05, 0.05, 0.8500000000000001, 0.05]),\n",
281 | " 7: (['U', 'D', 'L', 'R'], [0.05, 0.8500000000000001, 0.05, 0.05]),\n",
282 | " 8: (['U', 'D', 'L', 'R'], [0.8500000000000001, 0.05, 0.05, 0.05]),\n",
283 | " 9: (['U', 'D', 'L', 'R'], [0.05, 0.05, 0.05, 0.8500000000000001]),\n",
284 | " 10: (['U', 'D', 'L', 'R'], [0.05, 0.05, 0.05, 0.8500000000000001]),\n",
285 | " 11: (['U', 'D', 'L', 'R'], [0.05, 0.8500000000000001, 0.05, 0.05]),\n",
286 | " 12: (['U', 'D', 'L', 'R'], [0.05, 0.05, 0.05, 0.8500000000000001]),\n",
287 | " 13: (['U', 'D', 'L', 'R'], [0.05, 0.05, 0.05, 0.8500000000000001]),\n",
288 | " 14: (['U', 'D', 'L', 'R'], [0.05, 0.05, 0.05, 0.8500000000000001])}"
289 | ]
290 | },
291 | "execution_count": 44,
292 | "metadata": {},
293 | "output_type": "execute_result"
294 | }
295 | ],
296 | "source": [
297 | "# final policy obtained\n",
298 | "pi_hat"
299 | ]
300 | },
301 | {
302 | "cell_type": "markdown",
303 | "metadata": {},
304 | "source": [
305 | "### Visualizing policy"
306 | ]
307 | },
308 | {
309 | "cell_type": "code",
310 | "execution_count": 45,
311 | "metadata": {
312 | "collapsed": true
313 | },
314 | "outputs": [],
315 | "source": [
316 | "def show_policy(pi, env):\n",
317 | " temp = np.zeros(len(env.states) + 2)\n",
318 | " for s in env.states:\n",
319 |     "        a = pi[s][0][np.argmax(pi[s][1])]\n",
320 | " if a == \"U\":\n",
321 | " temp[s] = 0.25\n",
322 | " elif a == \"D\":\n",
323 | " temp[s] = 0.5\n",
324 | " elif a == \"R\":\n",
325 | " temp[s] = 0.75\n",
326 | " else:\n",
327 | " temp[s] = 1.0\n",
328 | " \n",
329 | " temp = temp.reshape(4,4)\n",
330 | " ax = seaborn.heatmap(temp, cmap = \"prism\", linecolor=\"#282828\", cbar = False, linewidths = 0.1)\n",
331 | " plt.show()"
332 | ]
333 | },
334 | {
335 | "cell_type": "code",
336 | "execution_count": 46,
337 | "metadata": {},
338 | "outputs": [
339 | {
340 | "data": {
341 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAdwAAAFJCAYAAAAxCJwFAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAACsRJREFUeJzt3UtoXAX/xvFfmhEV4wUsSBGEbpoMvNDiwl0XgcZLQUKl\nYKtEQbJy0RRqkZh6wUtbV4KXULUrF75xZXEhFGuLgoKLQguBiSKCeNu4krbYtM68C6F//NPkNXmZ\nZ8z089nlHHLmoYfy7TkkdKDT6XQKAOiqdb0eAADXAsEFgADBBYAAwQWAAMEFgADBBYCARjcv3mw2\nq7Ww0M2PoEuaIyP1r5Z7t1bNN92/tWq+OVJVVa0P3b+1qLljpFqt1lXPecIFgADBBYAAwQWAAMEF\ngADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWA\nAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAA\nwQWAAMEFgIC/Hdx2u93NHQDQ1xrLnfzhhx/q0KFDNT8/X41Go9rtdm3atKmmp6dr48aNqY0AsOYt\nG9yZmZnat29fbd68+cqxM2fO1PT0dM3NzXV9HAD0i2VfKS8uLv4ltlVVW7Zs6eogAOhHyz7hDg8P\n1/T0dG3durVuvvnmOn/+fH322Wc1PDyc2gcAfWHZ4L7wwgt14sSJOn36dJ07d66GhoZqdHS0xsbG\nUvsAoC8sG9yBgYEaGxsTWAD4H/k9XAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQ\nXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBc\nAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBjodDqdbl282Wx269IA8I/UarWuerzR\n7Q9e2Hy02x9BF4ycnXTv1rCRs5P1r9ZCr2ewCvPNkaqqan3o/q1FzR0jS57zShkAAgQXAAIEFwAC\nBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIE\nFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQX\nAAIEFwACBBcAAgQXAAIEFwACGsudnJiYqEuXLv3lWKfTqYGBgZqbm+vqMADoJ8sG96mnnqoDBw7U\nW2+9VYODg6lNANB3lg3u5s2ba3x8vL7++usaGxtLbQKAvrNscKuqJicnEzsAoK/5oSkACBBcAAgQ\nXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBc\nAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwA\nCBBcAAgQXAAIGOh0Op1uXbzZbHbr0gDwj9Rqta56vNHtD17YfLTbH0EXjJydrJmFf/d6Bqv0yshu\n92+NemVkd1VVtT5c6PESVqO5Y2TJc14pA0CA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOAC\nQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJA\ngOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQMCKg7u4uNiNHQDQ15YM7smT\nJ2t0dLTGxsbq448/vnJ8cnIyMgwA+kljqRNHjhypY8eOVbvdrqmpqbp48WLt2LGjOp1Och8A9IUl\ng3vdddfVrbfeWlVVs7Oz9fjjj9eGDR
tqYGAgNg4A+sWSr5TvvPPOOnToUF24cKGGhobqzTffrBdf\nfLG+++675D4A6AtLBvfgwYM1PDx85Yl2w4YN9d5779UDDzwQGwcA/WLJV8qNRqMeeuihvxxbv359\nzczMdH0UAPQbv4cLAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsA\nAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwAB\nggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAQOdTqfTrYs3m81uXRoA/pFardZVjze6/cEzC//u\n9kfQBa+M7Hbv1jD3b+16ZWR3VVW1FhZ6vITVaI6MLHnOK2UACBBcAAgQXAAIEFwACBBcAAgQXAAI\nEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQ\nXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIWFFwf//9\n91pcXOzWFgDoW8sG99tvv60nn3yypqen68svv6zt27fX9u3b69SpU6l9ANAXGsudfP7552tqaqp+\n+umn2rNnTx0/fryuv/76mpycrNHR0dRGAFjzlg1uu92ue+65p6qqvvrqq7r99tv//KbGst8GAPw/\ny75S3rhxY83MzFS73a7Dhw9XVdU777xT69evj4wDgH6x7KPqyy+/XCdPnqx16/6vy3fccUdNTEx0\nfRgA9JNlg7tu3bratm3bX46Nj493dRAA9CO/hwsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYIL\nAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsA\nAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAEDnU6n062LN5vNbl0aAP6R\nWq3WVY93NbgAwJ+8UgaAAMEFgADBBYAAwQWAAMEFgADBBYAAwV2Fdrtdzz33XD388MM1MTFR33//\nfa8nsUJnz56tiYmJXs9ghS5dulT79++vRx55pHbu3FmffvppryfxN/3xxx81PT1du3btqt27d9c3\n33zT60lxgrsKJ06cqMXFxfrggw9q3759dfjw4V5PYgXefffdOnDgQF28eLHXU1ihjz76qG677bZ6\n//336+jRo/XSSy/1ehJ/06lTp6qqam5urvbu3VuvvfZajxflCe4qnD59urZu3VpVVVu2bKn5+fke\nL2Il7rrrrnrjjTd6PYNVuP/++2tqaqqqqjqdTg0ODvZ4EX/Xtm3brvwD6eeff65bbrmlx4vyGr0e\nsBadO3euhoaGrnw9ODhYly9frkbDH+dacN9999WPP/7Y6xmswk033VRVf/4d3LNnT+3du7fHi1iJ\nRqNRTz/9dH3yySf1+uuv93pOnCfcVRgaGqrz589f+brdbosthPzyyy/12GOP1fj4eD344IO9nsMK\nvfrqq3X8+PF69tln68KFC72eEyW4q3D33XfX559/XlVVZ86cqU2bNvV4EVwbfv3113riiSdq//79\ntXPnzl7PYQWOHTtWb7/9dlVV3XjjjTUwMFDr1l1bCfJYtgpjY2P1xRdf1K5du6rT6dTBgwd7PQmu\nCUeOHKnffvutZmdna3Z2tqr+/CG4G264ocfL+G/uvffemp6erkcffbQuX75czzzzzDV33/xvQQAQ\ncG09zwNAjwguAAQILgAECC4ABAguAAQILgAECC4ABAguAAT8B0u/pa2vpAxwAAAAAElFTkSuQmCC\n",
342 | "text/plain": [
343 | ""
344 | ]
345 | },
346 | "metadata": {},
347 | "output_type": "display_data"
348 | }
349 | ],
350 | "source": [
351 | "### RED = TERMINAL (0)\n",
352 | "### GREEN = LEFT\n",
353 | "### BLUE = UP\n",
354 | "### PURPLE = RIGHT\n",
355 | "### ORANGE = DOWN\n",
356 | "\n",
357 | "show_policy(pi_hat, gw)"
358 | ]
359 | }
360 | ],
361 | "metadata": {
362 | "kernelspec": {
363 | "display_name": "Python 3",
364 | "language": "python",
365 | "name": "python3"
366 | },
367 | "language_info": {
368 | "codemirror_mode": {
369 | "name": "ipython",
370 | "version": 3
371 | },
372 | "file_extension": ".py",
373 | "mimetype": "text/x-python",
374 | "name": "python",
375 | "nbconvert_exporter": "python",
376 | "pygments_lexer": "ipython3",
377 | "version": "3.6.1"
378 | }
379 | },
380 | "nbformat": 4,
381 | "nbformat_minor": 2
382 | }
383 |
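The ε-greedy improvement step in `on_policy_mc` above assigns probability `1 - e + e/len(actions)` to the greedy action and `e/len(actions)` to every other action. A minimal standalone sketch of that assignment (the function name `epsilon_greedy_probs` is ours, not the notebook's):

```python
import numpy as np

def epsilon_greedy_probs(action_values, e):
    """Return e-greedy probabilities over actions, given their Q-values."""
    n = len(action_values)
    probs = [e / n] * n  # every action gets the exploration share e/|A|
    probs[int(np.argmax(action_values))] += 1 - e  # greedy action gets the rest
    return probs

# with e = 0.2 and 4 actions: greedy action ~0.85, the others 0.05,
# matching the probabilities shown in pi_hat above
print(epsilon_greedy_probs([0.1, -0.4, 0.9, 0.0], 0.2))
```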
--------------------------------------------------------------------------------
/source code/2-Monte Carlo Methods (Chapter 5)/4-off-policy-monte-carlo.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Off-policy Monte Carlo\n",
8 | "- Algorithms from ```pp. 84 - 92``` in Sutton & Barto 2017\n",
9 | " - Off-policy prediction via importance sampling\n",
10 | " - Off-policy MC control"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": 1,
16 | "metadata": {
17 | "collapsed": true
18 | },
19 | "outputs": [],
20 | "source": [
21 | "import matplotlib.pyplot as plt\n",
22 | "import pandas as pd\n",
23 | "import numpy as np\n",
24 | "import seaborn\n",
25 | "\n",
26 | "from gridWorldEnvironment import GridWorld"
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "execution_count": 2,
32 | "metadata": {
33 | "collapsed": true
34 | },
35 | "outputs": [],
36 | "source": [
37 | "# creating gridworld environment\n",
38 | "gw = GridWorld(gamma = .9)"
39 | ]
40 | },
41 | {
42 | "cell_type": "markdown",
43 | "metadata": {},
44 | "source": [
 45 |     "### Define policy-generating functions"
46 | ]
47 | },
48 | {
49 | "cell_type": "code",
50 | "execution_count": 3,
51 | "metadata": {
52 | "collapsed": true
53 | },
54 | "outputs": [],
55 | "source": [
56 | "def generate_any_policy(env):\n",
57 | " pi = dict()\n",
58 | " for state in env.states:\n",
59 | " r = sorted(np.random.sample(3))\n",
60 | " actions = env.actions\n",
61 | " prob = [r[0], r[1] - r[0], r[2] - r[1], 1-r[2]]\n",
62 | " pi[state] = (actions, prob)\n",
63 | " return pi "
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "execution_count": 4,
69 | "metadata": {
70 | "collapsed": true
71 | },
72 | "outputs": [],
73 | "source": [
74 | "def generate_random_policy(env):\n",
75 | " pi = dict()\n",
76 | " for state in env.states:\n",
77 | " actions = []\n",
78 | " prob = []\n",
79 | " for action in env.actions:\n",
80 | " actions.append(action)\n",
81 | " prob.append(0.25)\n",
82 | " pi[state] = (actions, prob)\n",
83 | " return pi"
84 | ]
85 | },
86 | {
87 | "cell_type": "code",
88 | "execution_count": 5,
89 | "metadata": {
90 | "collapsed": true
91 | },
92 | "outputs": [],
93 | "source": [
94 | "def generate_greedy_policy(env, Q):\n",
95 | " pi = dict()\n",
96 | " for state in env.states:\n",
97 | " actions = []\n",
98 | " q_values = []\n",
99 | " prob = []\n",
100 | " \n",
101 | " for a in env.actions:\n",
102 | " actions.append(a)\n",
103 | " q_values.append(Q[state,a]) \n",
104 | " for i in range(len(q_values)):\n",
105 | " if i == np.argmax(q_values):\n",
106 | " prob.append(1)\n",
107 | " else:\n",
108 | " prob.append(0) \n",
109 | " \n",
110 | " pi[state] = (actions, prob)\n",
111 | " return pi"
112 | ]
113 | },
114 | {
115 | "cell_type": "markdown",
116 | "metadata": {},
117 | "source": [
118 | "### Create Action Values\n",
119 |     "- Initialize all state-action values at random"
120 | ]
121 | },
122 | {
123 | "cell_type": "code",
124 | "execution_count": 6,
125 | "metadata": {
126 | "collapsed": true
127 | },
128 | "outputs": [],
129 | "source": [
130 | "def state_action_value(env):\n",
131 | " q = dict()\n",
132 | " for state, action, next_state, reward in env.transitions:\n",
133 | " q[(state, action)] = np.random.random()\n",
134 | " return q"
135 | ]
136 | },
137 | {
138 | "cell_type": "code",
139 | "execution_count": 7,
140 | "metadata": {
141 | "collapsed": true
142 | },
143 | "outputs": [],
144 | "source": [
145 | "def weight_cum_sum(env):\n",
146 | " c = dict()\n",
147 | " for state, action, next_state, reward in env.transitions:\n",
148 | " c[(state, action)] = 0\n",
149 | " return c"
150 | ]
151 | },
152 | {
153 | "cell_type": "markdown",
154 | "metadata": {},
155 | "source": [
156 | "### Generate episode\n",
157 | "- Generate episode based on current policy ($\\pi$)"
158 | ]
159 | },
160 | {
161 | "cell_type": "code",
162 | "execution_count": 8,
163 | "metadata": {
164 | "collapsed": true,
165 | "scrolled": true
166 | },
167 | "outputs": [],
168 | "source": [
169 | "def generate_episode(env, policy):\n",
170 | " episode = []\n",
171 | " done = False\n",
172 | " current_state = np.random.choice(env.states)\n",
173 | " action = np.random.choice(policy[current_state][0], p = policy[current_state][1])\n",
174 | " episode.append((current_state, action, -1))\n",
175 | " \n",
176 |     "    while not done:\n",
177 |     "        next_state, reward = env.state_transition(current_state, action)\n",
178 |     "        if next_state == 0:\n",
179 |     "            episode.append((next_state, action, reward))\n",
180 |     "            done = True\n",
181 |     "        else:\n",
182 |     "            action = np.random.choice(policy[next_state][0], p = policy[next_state][1])\n",
183 |     "            episode.append((next_state, action, reward))\n",
184 |     "            current_state = next_state\n",
185 | " return episode"
186 | ]
187 | },
188 | {
189 | "cell_type": "code",
190 | "execution_count": 9,
191 | "metadata": {
192 | "collapsed": true,
193 | "scrolled": true
194 | },
195 | "outputs": [],
196 | "source": [
197 | "pi = generate_random_policy(gw)"
198 | ]
199 | },
200 | {
201 | "cell_type": "markdown",
202 | "metadata": {},
203 | "source": [
204 | "### Off-policy MC Prediction\n",
205 | "- Estimates $Q$ values"
206 | ]
207 | },
208 | {
209 | "cell_type": "code",
210 | "execution_count": 10,
211 | "metadata": {
212 | "collapsed": true
213 | },
214 | "outputs": [],
215 | "source": [
216 | "def off_policy_mc_prediction(env, pi, num_iter):\n",
217 | " Q = state_action_value(env)\n",
218 | " C = weight_cum_sum(env)\n",
219 | " \n",
220 | " for _ in range(num_iter):\n",
221 | " b = generate_any_policy(env)\n",
222 |     "        episode = generate_episode(env, b)\n",
223 | " G = 0\n",
224 | " W = 1\n",
225 | " for i in range(len(episode)-1, -1, -1):\n",
226 | " s, a, r = episode[i]\n",
227 | " if s != 0:\n",
228 | " G = env.gamma * G + r\n",
229 | " C[s,a] += W\n",
230 | " Q[s,a] += (W / C[s,a]) * (G - Q[s,a])\n",
231 | " W *= pi[s][1][pi[s][0].index(a)] / b[s][1][b[s][0].index(a)]\n",
232 | " if W == 0:\n",
233 | " break\n",
234 | " \n",
235 | " return Q"
236 | ]
237 | },
238 | {
239 | "cell_type": "markdown",
240 | "metadata": {},
241 | "source": [
242 | "### Off-policy MC Control\n",
243 |     "- Finds optimal policy $\\pi \\approx \\pi_*$"
244 | ]
245 | },
246 | {
247 | "cell_type": "code",
248 | "execution_count": 12,
249 | "metadata": {
250 | "collapsed": true
251 | },
252 | "outputs": [],
253 | "source": [
254 | "def off_policy_mc_control(env, pi, num_iter):\n",
255 | " Q = state_action_value(env)\n",
256 | " C = weight_cum_sum(env)\n",
257 | " pi = generate_greedy_policy(env, Q)\n",
258 | " \n",
259 | " for _ in range(num_iter):\n",
260 | " b = generate_any_policy(env)\n",
261 |     "        episode = generate_episode(env, b)\n",
262 | " G = 0\n",
263 | " W = 1\n",
264 | " for i in range(len(episode)-1, -1, -1):\n",
265 | " s, a, r = episode[i]\n",
266 | " if s != 0:\n",
267 | " G = env.gamma * G + r\n",
268 | " C[s,a] += W\n",
269 | " Q[s,a] += (W / C[s,a]) * (G - Q[s,a])\n",
270 | " pi = generate_greedy_policy(env, Q)\n",
271 | " if a == pi[s][0][np.argmax(pi[s][1])]:\n",
272 | " break\n",
273 | " W *= 1 / b[s][1][b[s][0].index(a)]\n",
274 | "\n",
275 | " return Q, pi"
276 | ]
277 | },
278 | {
279 | "cell_type": "code",
280 | "execution_count": 21,
281 | "metadata": {},
282 | "outputs": [
283 | {
284 | "name": "stdout",
285 | "output_type": "stream",
286 | "text": [
287 | "Wall time: 1.51 s\n"
288 | ]
289 | }
290 | ],
291 | "source": [
292 | "%%time\n",
293 | "Q_hat, pi_hat = off_policy_mc_control(gw, generate_random_policy(gw), 1000)"
294 | ]
295 | },
296 | {
297 | "cell_type": "code",
298 | "execution_count": 22,
299 | "metadata": {},
300 | "outputs": [
301 | {
302 | "data": {
303 | "text/plain": [
304 | "{1: (['U', 'D', 'L', 'R'], [0, 0, 1, 0]),\n",
305 | " 2: (['U', 'D', 'L', 'R'], [0, 1, 0, 0]),\n",
306 | " 3: (['U', 'D', 'L', 'R'], [0, 1, 0, 0]),\n",
307 | " 4: (['U', 'D', 'L', 'R'], [1, 0, 0, 0]),\n",
308 | " 5: (['U', 'D', 'L', 'R'], [1, 0, 0, 0]),\n",
309 | " 6: (['U', 'D', 'L', 'R'], [0, 0, 0, 1]),\n",
310 | " 7: (['U', 'D', 'L', 'R'], [0, 0, 0, 1]),\n",
311 | " 8: (['U', 'D', 'L', 'R'], [0, 0, 0, 1]),\n",
312 | " 9: (['U', 'D', 'L', 'R'], [0, 0, 1, 0]),\n",
313 | " 10: (['U', 'D', 'L', 'R'], [0, 0, 0, 1]),\n",
314 | " 11: (['U', 'D', 'L', 'R'], [0, 1, 0, 0]),\n",
315 | " 12: (['U', 'D', 'L', 'R'], [0, 0, 1, 0]),\n",
316 | " 13: (['U', 'D', 'L', 'R'], [0, 0, 1, 0]),\n",
317 | " 14: (['U', 'D', 'L', 'R'], [0, 0, 0, 1])}"
318 | ]
319 | },
320 | "execution_count": 22,
321 | "metadata": {},
322 | "output_type": "execute_result"
323 | }
324 | ],
325 | "source": [
326 | "# final policy obtained\n",
327 | "pi_hat"
328 | ]
329 | },
330 | {
331 | "cell_type": "markdown",
332 | "metadata": {},
333 | "source": [
334 | "### Visualizing policy"
335 | ]
336 | },
337 | {
338 | "cell_type": "code",
339 | "execution_count": 23,
340 | "metadata": {
341 | "collapsed": true
342 | },
343 | "outputs": [],
344 | "source": [
345 | "def show_policy(pi, env):\n",
346 | " temp = np.zeros(len(env.states) + 2)\n",
347 | " for s in env.states:\n",
348 |     "        a = pi[s][0][np.argmax(pi[s][1])]\n",
349 | " if a == \"U\":\n",
350 | " temp[s] = 0.25\n",
351 | " elif a == \"D\":\n",
352 | " temp[s] = 0.5\n",
353 | " elif a == \"R\":\n",
354 | " temp[s] = 0.75\n",
355 | " else:\n",
356 | " temp[s] = 1.0\n",
357 | " \n",
358 | " temp = temp.reshape(4,4)\n",
359 | " ax = seaborn.heatmap(temp, cmap = \"prism\", linecolor=\"#282828\", cbar = False, linewidths = 0.1)\n",
360 | " plt.show()"
361 | ]
362 | },
363 | {
364 | "cell_type": "code",
365 | "execution_count": 24,
366 | "metadata": {},
367 | "outputs": [
368 | {
369 | "data": {
370 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAdwAAAFJCAYAAAAxCJwFAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAACsZJREFUeJzt3UtoXIXfxvFfmhEV4wUsSBGEbtoMCC0u3HURaLwUJFQK\ntkoUJCsXTaEWGVMveGnrSvAS6mXlQuPK4kIo1hYFBReFFgITRQTxtnElbbFpnfkv5N8XX9q8Ji/z\njJl+PrucQ8489DB8O4eEDHW73W4BAD21pt8DAOBqILgAECC4ABAguAAQILgAECC4ABDQ6OXFm81m\ntRcWevkS9EhzdLTubLt3q9V8c7TaH7l/q1Fz+2hVlfu3SjW3j1a73b7sOZ9wASBAcAEgQHABIEBw\nASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHAB\nIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEg\nQHABIEBwASDgHwe30+n0cgcADLTGUid//PHHOnjwYM3Pz1ej0ahOp1MbNmyoVqtV69evT20EgFVv\nyeDOzMzU3r17a9OmTZeOnTp1qlqtVs3NzfV8HAAMiiUfKS8uLv4ttlVVmzdv7ukgABhES37C3bhx\nY7VardqyZUvdeOONdfbs2fr8889r48aNqX0AMBCWDO7zzz9fx44dq5MnT9aZM2dqZGSkxsbGanx8\nPLUPAAbCksEdGhqq8fFxgQWA/ye/hwsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwAB\nggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGC\nCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAFD3W6326uLN5vNXl0aAP6V2u32ZY83\nev3CC5ve7fVL0AOjp6fcu1Vs9PRUzSx80O8ZrMDLo7uqqty/Veq/9+9yPFIGgADBBYAAwQWAAMEF\ngADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWA\nAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAA\nwQWAAMEFgADBBYAAwQWAgMZSJycnJ+vChQt/O9btdmtoaKjm5uZ6OgwABsmSwX3yySdr//799eab\nb9bw8HBqEwAMnCWDu2nTppqYmKhvvvmmxsfHU5sAYOAsGdyqqqmpqcQOABhofmgKAAIEFwACBBcA\nAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwAC\nBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIE\nFwACBBcAAoa63W63VxdvNpu9ujQA/Cu12+3LHm/0+oVnFj7o9UvQAy+P7qo72wv9nsEKzTdHvfdW\nqZdHd1VVVfsj77/VqLl99IrnPFIGgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEF\ngADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWA\nAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAgGUHd3FxsRc7AGCgXTG4x48fr7Gx\nsRofH69PPvnk0vGpqanIMAAYJI0rnTh8+HAdOXKkOp1OTU9P1/nz52v79u3V7XaT+wBgIFwxuNdc\nc03dfPPNVVU1Oztbjz32WK1bt66Gho
Zi4wBgUFzxkfLtt99eBw8erHPnztXIyEi98cYb9cILL9T3\n33+f3AcAA+GKwT1w4EBt3Ljx0ifadevW1XvvvVf3339/bBwADIorPlJuNBr14IMP/u3Y2rVra2Zm\npuejAGDQ+D1cAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBc\nAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwA\nCBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBjqdrvdXl282Wz26tIA8K/Ubrcve7zR6xe+s73Q65eg\nB+abo+7dKjbfHK2ZhQ/6PYMVeHl0V1VVtRe8/1aj5ujoFc95pAwAAYILAAGCCwABggsAAYILAAGC\nCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYIL\nAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABywru\nH3/8UYuLi73aAgADa8ngfvfdd/XEE09Uq9Wqr776qrZt21bbtm2rEydOpPYBwEBoLHXyueeeq+np\n6fr5559r9+7ddfTo0br22mtramqqxsbGUhsBYNVbMridTqfuvvvuqqr6+uuv69Zbb/3rmxpLfhsA\n8L8s+Uh5/fr1NTMzU51Opw4dOlRVVW+//XatXbs2Mg4ABsWSH1VfeumlOn78eK1Z8z9dvu2222py\ncrLnwwBgkCwZ3DVr1tTWrVv/dmxiYqKngwBgEPk9XAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAI\nEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQ\nXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBjqdrvdXl282Wz26tIA\n8K/Ubrcve7ynwQUA/uKRMgAECC4ABAguAAQILgAECC4ABAguAAQI7gp0Op169tln66GHHqrJycn6\n4Ycf+j2JZTp9+nRNTk72ewbLdOHChdq3b189/PDDtWPHjvrss8/6PYl/6M8//6xWq1U7d+6sXbt2\n1bffftvvSXGCuwLHjh2rxcXF+vDDD2vv3r116NChfk9iGd55553av39/nT9/vt9TWKaPP/64brnl\nlnr//ffr3XffrRdffLHfk/iHTpw4UVVVc3NztWfPnnr11Vf7vChPcFfg5MmTtWXLlqqq2rx5c83P\nz/d5Ectxxx131Ouvv97vGazAfffdV9PT01VV1e12a3h4uM+L+Ke2bt166T9Iv/zyS9100019XpTX\n6PeA1ejMmTM1MjJy6evh4eG6ePFiNRr+OVeDe++9t3766ad+z2AFbrjhhqr66z24e/fu2rNnT58X\nsRyNRqOeeuqp+vTTT+u1117r95w4n3BXYGRkpM6ePXvp606nI7YQ8uuvv9ajjz5aExMT9cADD/R7\nDsv0yiuv1NGjR+uZZ56pc+fO9XtOlOCuwF133VVffPFFVVWdOnWqNmzY0OdFcHX47bff6vHHH699\n+/bVjh07+j2HZThy5Ei99dZbVVV1/fXX19DQUK1Zc3UlyMeyFRgfH68vv/yydu7cWd1utw4cONDv\nSXBVOHz4cP3+++81Oztbs7OzVfXXD8Fdd911fV7G/+Wee+6pVqtVjzzySF28eLGefvrpq+6++WtB\nABBwdX2eB4A+EVwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACPgPqwelrdFfLd8AAAAASUVORK5C\nYII=\n",
371 | "text/plain": [
372 | ""
373 | ]
374 | },
375 | "metadata": {},
376 | "output_type": "display_data"
377 | }
378 | ],
379 | "source": [
380 | "### RED = TERMINAL (0)\n",
381 | "### GREEN = LEFT\n",
382 | "### BLUE = UP\n",
383 | "### PURPLE = RIGHT\n",
384 | "### ORANGE = DOWN\n",
385 | "\n",
386 | "show_policy(pi_hat, gw)"
387 | ]
388 | }
389 | ],
390 | "metadata": {
391 | "kernelspec": {
392 | "display_name": "Python 3",
393 | "language": "python",
394 | "name": "python3"
395 | },
396 | "language_info": {
397 | "codemirror_mode": {
398 | "name": "ipython",
399 | "version": 3
400 | },
401 | "file_extension": ".py",
402 | "mimetype": "text/x-python",
403 | "name": "python",
404 | "nbconvert_exporter": "python",
405 | "pygments_lexer": "ipython3",
406 | "version": "3.6.1"
407 | }
408 | },
409 | "nbformat": 4,
410 | "nbformat_minor": 2
411 | }
412 |
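The inner update in both off-policy functions above is the incremental weighted-importance-sampling average: `C += W` followed by `Q += (W / C) * (G - Q)`. The sketch below checks, on made-up `(G, W)` samples, that this incremental form reproduces the closed-form weighted average:

```python
def weighted_is_estimate(samples):
    """Incremental weighted average of returns G with weights W:
    C += W;  Q += (W / C) * (G - Q)."""
    Q, C = 0.0, 0.0
    for G, W in samples:
        C += W
        Q += (W / C) * (G - Q)
    return Q

# hypothetical (return, importance weight) samples
samples = [(-3.0, 1.0), (-1.0, 4.0), (-2.0, 1.0)]
direct = sum(G * W for G, W in samples) / sum(W for _, W in samples)
print(weighted_is_estimate(samples), direct)  # both are approx. -1.5
```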
--------------------------------------------------------------------------------
/source code/2-Monte Carlo Methods (Chapter 5)/__pycache__/gridWorldEnvironment.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buomsoo-kim/Tabular-RL-with-Python/0157c05126821524dc9d744613b9aa8dd5d47232/source code/2-Monte Carlo Methods (Chapter 5)/__pycache__/gridWorldEnvironment.cpython-36.pyc
--------------------------------------------------------------------------------
/source code/2-Monte Carlo Methods (Chapter 5)/gridWorldEnvironment.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import seaborn
4 | from matplotlib.colors import ListedColormap
5 |
6 | class GridWorld:
7 | def __init__(self, gamma = 1.0, theta = 0.5):
8 | self.actions = ("U", "D", "L", "R")
9 | self.states = np.arange(1, 15)
10 | self.transitions = pd.read_csv("gridworld.txt", header = None, sep = "\t").values
11 | self.gamma = gamma
12 | self.theta = theta
13 |
14 | def state_transition(self, state, action):
15 | next_state, reward = None, None
16 | for tr in self.transitions:
17 | if tr[0] == state and tr[1] == action:
18 | next_state = tr[2]
19 | reward = tr[3]
20 | return next_state, reward
21 |
22 | def show_environment(self):
23 | all_states = np.concatenate(([0], self.states, [0])).reshape(4,4)
24 | colors = []
25 | # colors = ["#ffffff"]
26 | for i in range(len(self.states) + 1):
27 | if i == 0:
28 | colors.append("#c4c4c4")
29 | else:
30 | colors.append("#ffffff")
31 |
32 | cmap = ListedColormap(seaborn.color_palette(colors).as_hex())
33 | ax = seaborn.heatmap(all_states, cmap = cmap, \
34 | annot = True, linecolor = "#282828", linewidths = 0.2, \
35 | cbar = False)
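`state_transition` above scans the whole transition table on every call. An equivalent O(1) variant precomputes a dictionary once; the sketch below uses a two-row illustrative table rather than the full `gridworld.txt`:

```python
# Sketch: precompute a {(state, action): (next_state, reward)} dict so that
# the transition lookup is O(1) instead of a scan over all rows.
# The two rows below are illustrative, not the full gridworld.txt.
transitions = [
    (1, "U", 1, -1),
    (1, "L", 0, -1),
]
lookup = {(s, a): (s2, r) for s, a, s2, r in transitions}

def state_transition(state, action):
    return lookup[(state, action)]

print(state_transition(1, "L"))  # -> (0, -1)
```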
--------------------------------------------------------------------------------
/source code/2-Monte Carlo Methods (Chapter 5)/gridworld.txt:
--------------------------------------------------------------------------------
1 | 1 U 1 -1
2 | 1 D 5 -1
3 | 1 R 2 -1
4 | 1 L 0 -1
5 | 2 U 2 -1
6 | 2 D 6 -1
7 | 2 R 3 -1
8 | 2 L 1 -1
9 | 3 U 3 -1
10 | 3 D 7 -1
11 | 3 R 3 -1
12 | 3 L 2 -1
13 | 4 U 0 -1
14 | 4 D 8 -1
15 | 4 R 5 -1
16 | 4 L 4 -1
17 | 5 U 1 -1
18 | 5 D 9 -1
19 | 5 R 6 -1
20 | 5 L 4 -1
21 | 6 U 2 -1
22 | 6 D 10 -1
23 | 6 R 7 -1
24 | 6 L 5 -1
25 | 7 U 3 -1
26 | 7 D 11 -1
27 | 7 R 7 -1
28 | 7 L 6 -1
29 | 8 U 4 -1
30 | 8 D 12 -1
31 | 8 R 9 -1
32 | 8 L 8 -1
33 | 9 U 5 -1
34 | 9 D 13 -1
35 | 9 R 10 -1
36 | 9 L 8 -1
37 | 10 U 6 -1
38 | 10 D 14 -1
39 | 10 R 11 -1
40 | 10 L 9 -1
41 | 11 U 7 -1
42 | 11 D 0 -1
43 | 11 R 11 -1
44 | 11 L 10 -1
45 | 12 U 8 -1
46 | 12 D 12 -1
47 | 12 R 13 -1
48 | 12 L 12 -1
49 | 13 U 9 -1
50 | 13 D 13 -1
51 | 13 R 14 -1
52 | 13 L 12 -1
53 | 14 U 10 -1
54 | 14 D 14 -1
55 | 14 R 0 -1
56 | 14 L 13 -1
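Each of the 14 non-terminal states above contributes one row per action in `{U, D, L, R}`, always with reward -1. A small format-check sketch, run here on the first four rows inlined as a string (point `pd.read_csv` at `gridworld.txt` instead to check the full table):

```python
import io
import pandas as pd

# first four rows of the table, inlined for a self-contained check
snippet = "1\tU\t1\t-1\n1\tD\t5\t-1\n1\tR\t2\t-1\n1\tL\t0\t-1\n"
t = pd.read_csv(io.StringIO(snippet), header=None, sep="\t",
                names=["state", "action", "next_state", "reward"])

# every state should define all four actions exactly once, each with reward -1
assert (t.groupby("state")["action"].nunique() == 4).all()
assert (t["reward"] == -1).all()
print("format OK")
```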
--------------------------------------------------------------------------------
/source code/3-Temporal Difference Learning (Chapter 6)/.ipynb_checkpoints/2-SARSA-on-policy control-checkpoint.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# SARSA\n",
8 | "- Algorithms from ```pp. 105 - 107``` in Sutton & Barto 2017"
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": 1,
14 | "metadata": {
15 | "collapsed": true
16 | },
17 | "outputs": [],
18 | "source": [
19 | "import matplotlib.pyplot as plt\n",
20 | "import pandas as pd\n",
21 | "import numpy as np\n",
22 | "import seaborn, random\n",
23 | "\n",
24 | "from gridWorldEnvironment import GridWorld"
25 | ]
26 | },
27 | {
28 | "cell_type": "code",
29 | "execution_count": 2,
30 | "metadata": {
31 | "collapsed": true
32 | },
33 | "outputs": [],
34 | "source": [
35 | "# creating gridworld environment\n",
36 | "gw = GridWorld(gamma = .9, theta = .5)"
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "execution_count": 3,
42 | "metadata": {
43 | "collapsed": true
44 | },
45 | "outputs": [],
46 | "source": [
47 | "def state_action_value(env):\n",
48 | " q = dict()\n",
49 | " for state, action, next_state, reward in env.transitions:\n",
50 | " q[(state, action)] = np.random.normal()\n",
51 | " for action in env.actions:\n",
52 | " q[0, action] = 0\n",
53 | " return q"
54 | ]
55 | },
56 | {
57 | "cell_type": "code",
58 | "execution_count": 4,
59 | "metadata": {
60 | "scrolled": true
61 | },
62 | "outputs": [
63 | {
64 | "data": {
65 | "text/plain": [
66 | "{(0, 'D'): 0,\n",
67 | " (0, 'L'): 0,\n",
68 | " (0, 'R'): 0,\n",
69 | " (0, 'U'): 0,\n",
70 | " (1, 'D'): -1.4931958690513218,\n",
71 | " (1, 'L'): 0.45253247253698986,\n",
72 | " (1, 'R'): -1.789094792083647,\n",
73 | " (1, 'U'): 1.9103660029206884,\n",
74 | " (2, 'D'): -0.3082624959814856,\n",
75 | " (2, 'L'): -0.7483834716798741,\n",
76 | " (2, 'R'): 0.9839358952573672,\n",
77 | " (2, 'U'): -0.2724328270166447,\n",
78 | " (3, 'D'): -0.6283488057971491,\n",
79 | " (3, 'L'): -1.3156943242567214,\n",
80 | " (3, 'R'): 0.6211123056414489,\n",
81 | " (3, 'U'): 0.7976038544679848,\n",
82 | " (4, 'D'): -1.050452706273533,\n",
83 | " (4, 'L'): -0.1741895951805081,\n",
84 | " (4, 'R'): 1.8182867323493,\n",
85 | " (4, 'U'): 1.801387714646261,\n",
86 | " (5, 'D'): 0.10021107202221527,\n",
87 | " (5, 'L'): -0.47465610710346245,\n",
88 | " (5, 'R'): -0.6918872493076375,\n",
89 | " (5, 'U'): -3.2003225504816597,\n",
90 | " (6, 'D'): 0.6278525454978494,\n",
91 | " (6, 'L'): -0.24213713052505115,\n",
92 | " (6, 'R'): -2.630728473016553,\n",
93 | " (6, 'U'): 1.2604497963698,\n",
94 | " (7, 'D'): 0.7220353661454653,\n",
95 | " (7, 'L'): -0.046864386969943536,\n",
96 | " (7, 'R'): 0.3648650012222875,\n",
97 | " (7, 'U'): -0.13955201414490984,\n",
98 | " (8, 'D'): -0.73401777416155,\n",
99 | " (8, 'L'): 0.32247567696601837,\n",
100 | " (8, 'R'): 0.5780991299024739,\n",
101 | " (8, 'U'): -0.44090209956778376,\n",
102 | " (9, 'D'): 1.4691269385027337,\n",
103 | " (9, 'L'): 0.052954400141560456,\n",
104 | " (9, 'R'): 0.13195379460282597,\n",
105 | " (9, 'U'): -0.3923944749627829,\n",
106 | " (10, 'D'): 1.7199147774865016,\n",
107 | " (10, 'L'): -1.9247987278054801,\n",
108 | " (10, 'R'): -0.143510086697551,\n",
109 | " (10, 'U'): -0.5647971071775687,\n",
110 | " (11, 'D'): -0.5741454748190717,\n",
111 | " (11, 'L'): -1.5720736584714539,\n",
112 | " (11, 'R'): -1.4601134792863597,\n",
113 | " (11, 'U'): -1.257386467083216,\n",
114 | " (12, 'D'): -0.7472490605228015,\n",
115 | " (12, 'L'): -0.22563599050611213,\n",
116 | " (12, 'R'): 0.4395772978492674,\n",
117 | " (12, 'U'): -0.7210157758438556,\n",
118 | " (13, 'D'): 0.6291351870645969,\n",
119 | " (13, 'L'): -1.1313917732858065,\n",
120 | " (13, 'R'): 1.7365108020870084,\n",
121 | " (13, 'U'): 0.1339793058824657,\n",
122 | " (14, 'D'): 1.028938708862189,\n",
123 | " (14, 'L'): 0.050033944728027975,\n",
124 | " (14, 'R'): 1.3081957299962321,\n",
125 | " (14, 'U'): 0.05849306927859837}"
126 | ]
127 | },
128 | "execution_count": 4,
129 | "metadata": {},
130 | "output_type": "execute_result"
131 | }
132 | ],
133 | "source": [
134 | "state_action_value(gw)"
135 | ]
136 | },
137 | {
138 | "cell_type": "code",
139 | "execution_count": 5,
140 | "metadata": {
141 | "collapsed": true
142 | },
143 | "outputs": [],
144 | "source": [
145 | "def generate_greedy_policy(env, Q):\n",
146 | " pi = dict()\n",
147 | " for state in env.states:\n",
148 | " actions = []\n",
149 | " q_values = []\n",
150 | " prob = []\n",
151 | " \n",
152 | " for a in env.actions:\n",
153 | " actions.append(a)\n",
154 | " q_values.append(Q[state,a]) \n",
155 | " for i in range(len(q_values)):\n",
156 | " if i == np.argmax(q_values):\n",
157 | " prob.append(1)\n",
158 | " else:\n",
159 | " prob.append(0) \n",
160 | " \n",
161 | " pi[state] = (actions, prob)\n",
162 | " return pi"
163 | ]
164 | },
165 | {
166 | "cell_type": "code",
167 | "execution_count": 6,
168 | "metadata": {
169 | "collapsed": true
170 | },
171 | "outputs": [],
172 | "source": [
173 | "# selects action epsilon-greedily, given current state\n",
174 | "def e_greedy(env, e, q, state):\n",
175 | " actions = env.actions\n",
176 | " action_values = []\n",
177 | " prob = []\n",
178 | " for action in actions:\n",
179 | " action_values.append(q[(state, action)])\n",
180 | " for i in range(len(action_values)):\n",
181 | " if i == np.argmax(action_values):\n",
182 | " prob.append(1 - e + e/len(action_values))\n",
183 | " else:\n",
184 | " prob.append(e/len(action_values))\n",
185 | " return np.random.choice(actions, p = prob)"
186 | ]
187 | },
188 | {
189 | "cell_type": "markdown",
190 | "metadata": {
191 | "collapsed": true
192 | },
193 | "source": [
194 | "### SARSA: On-policy TD Control\n",
195 | "- Evaluates action-value function ($Q$)"
196 | ]
197 | },
198 | {
199 | "cell_type": "code",
200 | "execution_count": 7,
201 | "metadata": {
202 | "collapsed": true
203 | },
204 | "outputs": [],
205 | "source": [
206 | "def sarsa(env, epsilon, alpha, num_iter):\n",
207 | " Q = state_action_value(env)\n",
208 | " for _ in range(num_iter):\n",
209 | " current_state = np.random.choice(env.states)\n",
210 | " current_action = e_greedy(env, epsilon, Q, current_state)\n",
211 | " while current_state != 0:\n",
212 | " next_state, reward = env.state_transition(current_state, current_action)\n",
213 | " next_action = e_greedy(env, epsilon, Q, next_state)\n",
214 | " Q[current_state, current_action] += alpha * (reward + env.gamma * Q[next_state, next_action] - Q[current_state, current_action])\n",
215 | " current_state, current_action = next_state, next_action\n",
216 | " return Q"
217 | ]
218 | },
219 | {
220 | "cell_type": "code",
221 | "execution_count": 8,
222 | "metadata": {
223 | "collapsed": true
224 | },
225 | "outputs": [],
226 | "source": [
227 | "Q = sarsa(gw, 0.1, 0.5, 10000)"
228 | ]
229 | },
230 | {
231 | "cell_type": "code",
232 | "execution_count": 9,
233 | "metadata": {
234 | "collapsed": true
235 | },
236 | "outputs": [],
237 | "source": [
238 | "pi_hat = generate_greedy_policy(gw, Q)"
239 | ]
240 | },
241 | {
242 | "cell_type": "markdown",
243 | "metadata": {},
244 | "source": [
245 | "### Visualizing policy"
246 | ]
247 | },
248 | {
249 | "cell_type": "code",
250 | "execution_count": 10,
251 | "metadata": {
252 | "collapsed": true
253 | },
254 | "outputs": [],
255 | "source": [
256 | "def show_policy(pi, env):\n",
257 | " temp = np.zeros(len(env.states) + 2)\n",
258 | " for s in env.states:\n",
259 | " a = pi[s][0][np.argmax(pi[s][1])]\n",
260 | " if a == \"U\":\n",
261 | " temp[s] = 0.25\n",
262 | " elif a == \"D\":\n",
263 | " temp[s] = 0.5\n",
264 | " elif a == \"R\":\n",
265 | " temp[s] = 0.75\n",
266 | " else:\n",
267 | " temp[s] = 1.0\n",
268 | " \n",
269 | " temp = temp.reshape(4,4)\n",
270 | " ax = seaborn.heatmap(temp, cmap = \"prism\", linecolor=\"#282828\", cbar = False, linewidths = 0.1)\n",
271 | " plt.show()"
272 | ]
273 | },
274 | {
275 | "cell_type": "code",
276 | "execution_count": 11,
277 | "metadata": {},
278 | "outputs": [
279 | {
280 | "data": {
281 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAdwAAAFJCAYAAAAxCJwFAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAACs5JREFUeJzt3UtoXIXfxvFfmhEV4wUsSBGEbpoMCCku3HURaLwUJFQK\ntkoUJCsXbaEWiakXvLR1JXgJVbty4T+uLC6EYm1RUHBRaCEwUUQQbxtX0hab1pn/QuiLL01ek5d5\nxkw/n13OoWceegjfnkNCBzqdTqcAgK5a1+sBAHAtEFwACBBcAAgQXAAIEFwACBBcAAhodPPizWaz\nWgsL3fwIuqQ5MlJ3t9y7tWq+6f6tVfPNkaqqan3k/q1Fze0j1Wq1rnrOEy4ABAguAAQILgAECC4A\nBAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAE\nCC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQI\nLgAECC4ABPzj4Lbb7W7uAIC+1lju5I8//liHDh2q+fn5ajQa1W63a9OmTTU9PV0bN25MbQSANW/Z\n4M7MzNS+fftqdHT0yrEzZ87U9PR0zc3NdX0cAPSLZV8pLy4u/i22VVWbN2/u6iAA6EfLPuEODw/X\n9PR0bdmypW6++eY6f/58ff755zU8PJzaBwB9Ydngvvjii3XixIk6ffp0nTt3roaGhmpsbKzGx8dT\n+wCgLywb3IGBgRofHxdYAPh/8nu4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4\nABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgA\nECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQMNDpdDrduniz2ezWpQHgX6nVal31eKPb\nH7wwerTbH0EXjJydqrtbC72ewSrNN0fcvzVqvjlSVVWtj9y/tai5fWTJc14pA0CA4AJAgOACQIDg\nAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOAC\nQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJA\ngOACQIDgAkCA4AJAgOACQEBjuZOTk5N16dKlvx3rdDo1MDBQc3NzXR0GAP1k2eA+/fTTdeDAgXr7\n7bdrcHAwtQkA+s6ywR0dHa2JiYn65ptvanx8PLUJAPrOssGtqpqamkrsAIC+5oemACBAcAEgQHAB\nIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEg\nQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBA\ncAEgQHABIGCg0+l0unXxZrPZrUsDwL9Sq9W66vFGtz94YfRotz+CLhg5O1V3txZ6PYNVmm+O1MzC\nf3o9g1V4dWRXVVW1PvL9txY1t48sec4rZQAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBc\nAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwA\nCBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAhYcXAXFxe7sQMA+tqSwT15\n8mSNjY3V+Ph4ffLJJ1eOT01NRYYBQD9pLHXiyJEjdezYsWq327Vnz566ePFibd++vTqdTnIfAPSF\nJYN73XXX1a233lpVVbOzs/XEE0/Uhg
0bamBgIDYOAPrFkq+U77zzzjp06FBduHChhoaG6q233qqX\nXnqpvv/+++Q+AOgLSwb34MGDNTw8fOWJdsOGDfX+++/Xgw8+GBsHAP1iyVfKjUajHn744b8dW79+\nfc3MzHR9FAD0G7+HCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYIL\nAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsA\nAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAEDnU6n062LN5vNbl0aAP6VWq3WVY83uv3BC6NH\nu/0RdMHI2amaWfhPr2ewSq+O7HL/1qhXR3ZVVVVrYaHHS1iN5sjIkue8UgaAAMEFgADBBYAAwQWA\nAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAA\nwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADB\nBYCAFQX3jz/+qMXFxW5tAYC+tWxwv/vuu3rqqadqenq6vvrqq9q2bVtt27atTp06ldoHAH2hsdzJ\nF154ofbs2VM///xz7d69u44fP17XX399TU1N1djYWGojAKx5ywa33W7XvffeW1VVX3/9dd1+++1/\n/aHGsn8MAPhfln2lvHHjxpqZmal2u12HDx+uqqp333231q9fHxkHAP1i2UfVV155pU6ePFnr1v1P\nl++4446anJzs+jAA6CfLBnfdunW1devWvx2bmJjo6iAA6Ed+DxcAAgQXAAIEFwACBBcAAgQXAAIE\nFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQX\nAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIGOp1Op1sX\nbzab3bo0APwrtVqtqx7vanABgL94pQwAAYILAAGCCwABggsAAYILAAGCCwABgrsK7Xa7nn/++Xrk\nkUdqcnKyfvjhh15PYoXOnj1bk5OTvZ7BCl26dKn2799fjz76aO3YsaM+++yzXk/iH/rzzz9renq6\ndu7cWbt27apvv/2215PiBHcVTpw4UYuLi/Xhhx/Wvn376vDhw72exAq89957deDAgbp48WKvp7BC\nH3/8cd122231wQcf1NGjR+vll1/u9ST+oVOnTlVV1dzcXO3du7def/31Hi/KE9xVOH36dG3ZsqWq\nqjZv3lzz8/M9XsRK3HXXXfXmm2/2egar8MADD9SePXuqqqrT6dTg4GCPF/FPbd269co/kH755Ze6\n5ZZberwor9HrAWvRuXPnamho6MrXg4ODdfny5Wo0/HWuBffff3/99NNPvZ7BKtx0001V9df34O7d\nu2vv3r09XsRKNBqNeuaZZ+rTTz+tN954o9dz4jzhrsLQ0FCdP3/+ytftdltsIeTXX3+txx9/vCYm\nJuqhhx7q9RxW6LXXXqvjx4/Xc889VxcuXOj1nCjBXYV77rmnvvjii6qqOnPmTG3atKnHi+Da8Ntv\nv9WTTz5Z+/fvrx07dvR6Ditw7Nixeuedd6qq6sYbb6yBgYFat+7aSpDHslUYHx+vL7/8snbu3Fmd\nTqcOHjzY60lwTThy5Ej9/vvvNTs7W7Ozs1X11w/B3XDDDT1exv/lvvvuq+np6Xrsscfq8uXL9eyz\nz15z983/FgQAAdfW8zwA9IjgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkDAfwHhhKWt2U7SnQAA\nAABJRU5ErkJggg==\n",
282 | "text/plain": [
283 | ""
284 | ]
285 | },
286 | "metadata": {},
287 | "output_type": "display_data"
288 | }
289 | ],
290 | "source": [
291 | "### RED = TERMINAL (0)\n",
292 | "### GREEN = LEFT\n",
293 | "### BLUE = UP\n",
294 | "### PURPLE = RIGHT\n",
295 | "### ORANGE = DOWN\n",
296 | "\n",
297 | "show_policy(pi_hat, gw)"
298 | ]
299 | }
300 | ],
301 | "metadata": {
302 | "kernelspec": {
303 | "display_name": "Python 3",
304 | "language": "python",
305 | "name": "python3"
306 | },
307 | "language_info": {
308 | "codemirror_mode": {
309 | "name": "ipython",
310 | "version": 3
311 | },
312 | "file_extension": ".py",
313 | "mimetype": "text/x-python",
314 | "name": "python",
315 | "nbconvert_exporter": "python",
316 | "pygments_lexer": "ipython3",
317 | "version": "3.6.1"
318 | }
319 | },
320 | "nbformat": 4,
321 | "nbformat_minor": 2
322 | }
323 |
--------------------------------------------------------------------------------
/source code/3-Temporal Difference Learning (Chapter 6)/.ipynb_checkpoints/4-double-Q-learning-off-policy-control-checkpoint.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Double Q-Learning\n",
8 | "- Algorithms from ```pp. 110 - 111``` in Sutton & Barto 2017\n",
9 | "- The double Q-learning algorithm employs two action-value functions (i.e., $Q_1$ and $Q_2$) to avoid maximization bias"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": 15,
15 | "metadata": {
16 | "collapsed": true
17 | },
18 | "outputs": [],
19 | "source": [
20 | "import matplotlib.pyplot as plt\n",
21 | "import pandas as pd\n",
22 | "import numpy as np\n",
23 | "import seaborn, random\n",
24 | "\n",
25 | "from gridWorldEnvironment import GridWorld"
26 | ]
27 | },
28 | {
29 | "cell_type": "code",
30 | "execution_count": 16,
31 | "metadata": {
32 | "collapsed": true
33 | },
34 | "outputs": [],
35 | "source": [
36 | "# creating gridworld environment\n",
37 | "gw = GridWorld(gamma = .9, theta = .5)"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": 17,
43 | "metadata": {
44 | "collapsed": true
45 | },
46 | "outputs": [],
47 | "source": [
48 | "def state_action_value(env):\n",
49 | " q = dict()\n",
50 | " for state, action, next_state, reward in env.transitions:\n",
51 | " q[(state, action)] = np.random.normal()\n",
52 | " for action in env.actions:\n",
53 | " q[0, action] = 0\n",
54 | " return q"
55 | ]
56 | },
57 | {
58 | "cell_type": "code",
59 | "execution_count": 18,
60 | "metadata": {
61 | "scrolled": true
62 | },
63 | "outputs": [
64 | {
65 | "data": {
66 | "text/plain": [
67 | "{(0, 'D'): 0,\n",
68 | " (0, 'L'): 0,\n",
69 | " (0, 'R'): 0,\n",
70 | " (0, 'U'): 0,\n",
71 | " (1, 'D'): 1.0159263422385194,\n",
72 | " (1, 'L'): -0.9015277224413166,\n",
73 | " (1, 'R'): -0.8401929147381398,\n",
74 | " (1, 'U'): -0.3866748964867951,\n",
75 | " (2, 'D'): -1.6171831004555488,\n",
76 | " (2, 'L'): -0.757413177831847,\n",
77 | " (2, 'R'): 1.1948020656778442,\n",
78 | " (2, 'U'): -0.6814167904466197,\n",
79 | " (3, 'D'): 0.8048046240684537,\n",
80 | " (3, 'L'): 0.5058075927359411,\n",
81 | " (3, 'R'): -0.2396998941031068,\n",
82 | " (3, 'U'): 0.2155229655857796,\n",
83 | " (4, 'D'): -1.6147931728590075,\n",
84 | " (4, 'L'): -0.7788455232241043,\n",
85 | " (4, 'R'): -1.4834265811813951,\n",
86 | " (4, 'U'): -0.6831094436413707,\n",
87 | " (5, 'D'): 0.553787569878218,\n",
88 | " (5, 'L'): -0.7535460982740707,\n",
89 | " (5, 'R'): -0.5446399864366045,\n",
90 | " (5, 'U'): 0.3275946472542217,\n",
91 | " (6, 'D'): 0.4967878126487371,\n",
92 | " (6, 'L'): -0.08765553689024165,\n",
93 | " (6, 'R'): 2.561671991266108,\n",
94 | " (6, 'U'): -0.6251803976764682,\n",
95 | " (7, 'D'): -1.4239033515729715,\n",
96 | " (7, 'L'): 0.8961390571157126,\n",
97 | " (7, 'R'): -1.752510423129062,\n",
98 | " (7, 'U'): 0.186255694272231,\n",
99 | " (8, 'D'): 2.0464864623182057,\n",
100 | " (8, 'L'): -1.4145441882793333,\n",
101 | " (8, 'R'): -0.01479790948318322,\n",
102 | " (8, 'U'): -0.5827888737714622,\n",
103 | " (9, 'D'): -1.99334821662137,\n",
104 | " (9, 'L'): -0.2143702145160287,\n",
105 | " (9, 'R'): 0.3093816164828963,\n",
106 | " (9, 'U'): -0.13810316479560683,\n",
107 | " (10, 'D'): -0.4286970274384264,\n",
108 | " (10, 'L'): -1.0313201895326287,\n",
109 | " (10, 'R'): 0.674733204171753,\n",
110 | " (10, 'U'): -0.06485579239030625,\n",
111 | " (11, 'D'): -1.08761491406436,\n",
112 | " (11, 'L'): 0.37568985747762257,\n",
113 | " (11, 'R'): 0.6549082083274831,\n",
114 | " (11, 'U'): -0.044872797471142374,\n",
115 | " (12, 'D'): -0.8474472601122482,\n",
116 | " (12, 'L'): -0.6624032631781939,\n",
117 | " (12, 'R'): 0.4720691318622512,\n",
118 | " (12, 'U'): -1.2333806138625314,\n",
119 | " (13, 'D'): -0.38599325593593187,\n",
120 | " (13, 'L'): -1.677564721893633,\n",
121 | " (13, 'R'): -1.1425774085316547,\n",
122 | " (13, 'U'): -1.3606173055898687,\n",
123 | " (14, 'D'): -1.1641372208339686,\n",
124 | " (14, 'L'): 1.1457975665987994,\n",
125 | " (14, 'R'): -1.012618770180745,\n",
126 | " (14, 'U'): -0.06091464445082383}"
127 | ]
128 | },
129 | "execution_count": 18,
130 | "metadata": {},
131 | "output_type": "execute_result"
132 | }
133 | ],
134 | "source": [
135 | "state_action_value(gw)"
136 | ]
137 | },
138 | {
139 | "cell_type": "code",
140 | "execution_count": 19,
141 | "metadata": {
142 | "collapsed": true
143 | },
144 | "outputs": [],
145 | "source": [
146 | "def generate_greedy_policy(env, Q):\n",
147 | " pi = dict()\n",
148 | " for state in env.states:\n",
149 | " actions = []\n",
150 | " q_values = []\n",
151 | " prob = []\n",
152 | " \n",
153 | " for a in env.actions:\n",
154 | " actions.append(a)\n",
155 | " q_values.append(Q[state,a]) \n",
156 | " for i in range(len(q_values)):\n",
157 | " if i == np.argmax(q_values):\n",
158 | " prob.append(1)\n",
159 | " else:\n",
160 | " prob.append(0) \n",
161 | " \n",
162 | " pi[state] = (actions, prob)\n",
163 | " return pi"
164 | ]
165 | },
166 | {
167 | "cell_type": "code",
168 | "execution_count": 20,
169 | "metadata": {
170 | "collapsed": true
171 | },
172 | "outputs": [],
173 | "source": [
174 | "def e_greedy(env, e, q, state):\n",
175 | " actions = env.actions\n",
176 | " action_values = []\n",
177 | " prob = []\n",
178 | " for action in actions:\n",
179 | " action_values.append(q[(state, action)])\n",
180 | " for i in range(len(action_values)):\n",
181 | " if i == np.argmax(action_values):\n",
182 | " prob.append(1 - e + e/len(action_values))\n",
183 | " else:\n",
184 | " prob.append(e/len(action_values))\n",
185 | " return np.random.choice(actions, p = prob)"
186 | ]
187 | },
188 | {
189 | "cell_type": "code",
190 | "execution_count": 21,
191 | "metadata": {
192 | "collapsed": true
193 | },
194 | "outputs": [],
195 | "source": [
196 | "def greedy(env, q, state):\n",
197 | " actions = env.actions\n",
198 | " action_values = []\n",
199 | " for action in actions:\n",
200 | " action_values.append(q[state, action])\n",
201 | " return actions[np.argmax(action_values)]"
202 | ]
203 | },
204 | {
205 | "cell_type": "markdown",
206 | "metadata": {},
207 | "source": [
208 | "### Double Q-learning"
209 | ]
210 | },
211 | {
212 | "cell_type": "code",
213 | "execution_count": 34,
214 | "metadata": {
215 | "collapsed": true
216 | },
217 | "outputs": [],
218 | "source": [
219 | "def double_q_learning(env, epsilon, alpha, num_iter):\n",
220 | " Q1, Q2 = state_action_value(env), state_action_value(env)\n",
221 | " for _ in range(num_iter):\n",
222 | " current_state = np.random.choice(env.states)\n",
223 | " while current_state != 0:\n",
224 | " Q = dict()\n",
225 | " for key in Q1.keys():\n",
226 | " Q[key] = Q1[key] + Q2[key]\n",
227 | " current_action = e_greedy(env, epsilon, Q, current_state)\n",
228 | " next_state, reward = env.state_transition(current_state, current_action)\n",
229 | " \n",
230 | " # choose Q1 or Q2 with equal probabilities (0.5)\n",
231 | " chosen_Q = np.random.choice([\"Q1\", \"Q2\"])\n",
232 | " if chosen_Q == \"Q1\": # when Q1 is chosen\n",
233 | " best_action = greedy(env, Q1, next_state)\n",
234 | " Q1[current_state, current_action] += alpha * \\\n",
235 | " (reward + env.gamma * Q2[next_state, best_action] - Q1[current_state, current_action])\n",
236 | " else: # when Q2 is chosen\n",
237 | " best_action = greedy(env, Q2, next_state)\n",
238 | " Q2[current_state, current_action] += alpha * \\\n",
239 | " (reward + env.gamma * Q1[next_state, best_action] - Q2[current_state, current_action])\n",
240 | " \n",
241 | " current_state = next_state\n",
242 | " return Q1, Q2"
243 | ]
244 | },
245 | {
246 | "cell_type": "code",
247 | "execution_count": 38,
248 | "metadata": {
249 | "collapsed": true
250 | },
251 | "outputs": [],
252 | "source": [
253 | "Q1, Q2 = double_q_learning(gw, 0.2, 0.5, 5000)\n",
254 | "\n",
255 | "# sum Q1 & Q2 elementwise to obtain final Q-values\n",
256 | "Q = dict()\n",
257 | "for key in Q1.keys():\n",
258 | " Q[key] = Q1[key] + Q2[key]"
259 | ]
260 | },
261 | {
262 | "cell_type": "code",
263 | "execution_count": 39,
264 | "metadata": {
265 | "collapsed": true
266 | },
267 | "outputs": [],
268 | "source": [
269 | "pi_hat = generate_greedy_policy(gw, Q)"
270 | ]
271 | },
272 | {
273 | "cell_type": "code",
274 | "execution_count": 32,
275 | "metadata": {
276 | "collapsed": true
277 | },
278 | "outputs": [],
279 | "source": [
280 | "def show_policy(pi, env):\n",
281 | " temp = np.zeros(len(env.states) + 2)\n",
282 | " for s in env.states:\n",
283 | " a = pi[s][0][np.argmax(pi[s][1])]\n",
284 | " if a == \"U\":\n",
285 | " temp[s] = 0.25\n",
286 | " elif a == \"D\":\n",
287 | " temp[s] = 0.5\n",
288 | " elif a == \"R\":\n",
289 | " temp[s] = 0.75\n",
290 | " else:\n",
291 | " temp[s] = 1.0\n",
292 | " \n",
293 | " temp = temp.reshape(4,4)\n",
294 | " ax = seaborn.heatmap(temp, cmap = \"prism\", linecolor=\"#282828\", cbar = False, linewidths = 0.1)\n",
295 | " plt.show()"
296 | ]
297 | },
298 | {
299 | "cell_type": "code",
300 | "execution_count": 40,
301 | "metadata": {},
302 | "outputs": [
303 | {
304 | "data": {
305 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAdwAAAFJCAYAAAAxCJwFAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAACrhJREFUeJzt3U2IlYUex/H/OCcqml4gISQI3Dhz4ILSop2LAacXIQZD\nSIspiFm1cASTOI29UF21VdDLYOWqRXdaXWkRSKYUFLQQFAaORQTR26ZVqORo59xFXC9ddG4zl/M7\nzfHz2c3zMM/54YN8fR5mcKjb7XYLAOipNf0eAADXAsEFgADBBYAAwQWAAMEFgADBBYCARi8v3mw2\nq33mTC8/gh5pjo3V39ru3Wq10HT/VquF5lhVlfu3Si00x6rdbl/xnCdcAAgQXAAIEFwACBBcAAgQ\nXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBc\nAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwA\nCBBcAAj408HtdDq93AEAA62x1MnvvvuuDhw4UAsLC9VoNKrT6dSGDRuq1WrV+vXrUxsBYNVbMriz\ns7O1Z8+e2rhx4+Vjp06dqlarVfPz8z0fBwCDYslXyouLi3+IbVXVpk2bejoIAAbRkk+4o6Oj1Wq1\navPmzXXzzTfXuXPn6pNPPqnR0dHUPgAYCEsG94UXXqhjx47VyZMn6+zZszUyMlLj4+M1MTGR2gcA\nA2HJ4A4NDdXExITAAsD/ye/hAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA\n4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDg\nAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAwFC32+326uLNZrNXlwaAv6R2u33F441ef/CZ\njYd7/RH0wNjpafduFXP/Vq+x09NVVdX+55k+L2ElmtvGrnrOK2UACBBcAAgQXAAIEFwACBBcAAgQ\nXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBc\nAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwA\nCBBcAAgQXAAIaCx1cmpqqi5evPiHY91ut4aGhmp+fr6nwwBgkCwZ3Keeeqr27dtXb775Zg0PD6c2\nAcDAWTK4GzdurMnJyfryyy9rYmIitQkABs6Swa2qmp6eTuwAgIHmh6YAIEBwASBAcAEgQHABIEBw\nASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHAB\nIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEg\nYKjb7XZ7dfFms9mrSwPAX1K73b7i8UavP/jMxsO9/gh6YOz0dP2tfabfM1ihheZYtf/p/q1GzW1j\nVVXu3yr17/t3JV4pA0CA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOAC\nQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJA\ngOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQMCyg7u4uNiLHQAw0K4a3OPHj9f4+HhNTEzUhx9+\nePn49PR0ZBgADJLG1U4cOnSojhw5Up1Op2ZmZurChQu1bdu26na7yX0AMBCuGtzrrruubr311qqq\nmpubq8cff7zWrVtXQ0NDsXEAMCiu+k
r5zjvvrAMHDtT58+drZGSk3njjjXrxxRfrm2++Se4DgIFw\n1eDu37+/RkdHLz/Rrlu3rt5999164IEHYuMAYFBc9ZVyo9Gohx566A/H1q5dW7Ozsz0fBQCDxu/h\nAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOAC\nQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJA\ngOACQIDgAkCA4AJAgOACQMBQt9vt9urizWazV5cGgL+kdrt9xeONXn/w7Jl/9Poj6IG/j+1071Yx\n92/1+vvYzqqqap850+clrERzbOyq57xSBoAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEF\ngADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWA\nAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgIBlBffXX3+txcXFXm0BgIG1\nZHC//vrrevLJJ6vVatXnn39eW7dura1bt9aJEydS+wBgIDSWOvn888/XzMxM/fDDD7Vr1646evRo\nXX/99TU9PV3j4+OpjQCw6i0Z3E6nU/fcc09VVX3xxRd1++23//5NjSW/DQD4L0u+Ul6/fn3Nzs5W\np9OpgwcPVlXV22+/XWvXro2MA4BBseSj6ssvv1zHjx+vNWv+0+U77rijpqamej4MAAbJksFds2ZN\nbdmy5Q/HJicnezoIAAaR38MFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADB\nBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEF\ngADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYCAoW632+3VxZvNZq8uDQB/Se12+4rHexpcAOB3\nXikDQIDgAkCA4AJAgOACQIDgAkCA4AJAgOCuQKfTqeeee64efvjhmpqaqm+//bbfk1im06dP19TU\nVL9nsEwXL16svXv31iOPPFLbt2+vjz/+uN+T+JN+++23arVatWPHjtq5c2d99dVX/Z4UJ7grcOzY\nsVpcXKz333+/9uzZUwcPHuz3JJbhnXfeqX379tWFCxf6PYVl+uCDD+q2226r9957rw4fPlwvvfRS\nvyfxJ504caKqqubn52v37t316quv9nlRnuCuwMmTJ2vz5s1VVbVp06ZaWFjo8yKW46677qrXX3+9\n3zNYgfvvv79mZmaqqqrb7dbw8HCfF/Fnbdmy5fI/kH788ce65ZZb+rwor9HvAavR2bNna2Rk5PLX\nw8PDdenSpWo0/HGuBvfdd199//33/Z7BCtx0001V9fvfwV27dtXu3bv7vIjlaDQa9fTTT9dHH31U\nr732Wr/nxHnCXYGRkZE6d+7c5a87nY7YQshPP/1Ujz32WE1OTtaDDz7Y7zks0yuvvFJHjx6tZ599\nts6fP9/vOVGCuwJ33313ffrpp1VVderUqdqwYUOfF8G14eeff64nnnii9u7dW9u3b+/3HJbhyJEj\n9dZbb1VV1Y033lhDQ0O1Zs21lSCPZSswMTFRn332We3YsaO63W7t37+/35PgmnDo0KH65Zdfam5u\nrubm5qrq9x+Cu+GGG/q8jP/l3nvvrVarVY8++mhdunSpnnnmmWvuvvnfggAg4Np6ngeAPhFcAAgQ\nXAAIEFwACBBcAAgQXAAIEFwACBBcAAj4F4Bkpa2zV4RVAAAAAElFTkSuQmCC\n",
306 | "text/plain": [
307 | ""
308 | ]
309 | },
310 | "metadata": {},
311 | "output_type": "display_data"
312 | }
313 | ],
314 | "source": [
315 | "### RED = TERMINAL (0)\n",
316 | "### GREEN = LEFT\n",
317 | "### BLUE = UP\n",
318 | "### PURPLE = RIGHT\n",
319 | "### ORANGE = DOWN\n",
320 | "\n",
321 | "show_policy(pi_hat, gw)"
322 | ]
323 | }
324 | ],
325 | "metadata": {
326 | "kernelspec": {
327 | "display_name": "Python 3",
328 | "language": "python",
329 | "name": "python3"
330 | },
331 | "language_info": {
332 | "codemirror_mode": {
333 | "name": "ipython",
334 | "version": 3
335 | },
336 | "file_extension": ".py",
337 | "mimetype": "text/x-python",
338 | "name": "python",
339 | "nbconvert_exporter": "python",
340 | "pygments_lexer": "ipython3",
341 | "version": "3.6.1"
342 | }
343 | },
344 | "nbformat": 4,
345 | "nbformat_minor": 2
346 | }
347 |
--------------------------------------------------------------------------------
/source code/3-Temporal Difference Learning (Chapter 6)/2-SARSA-on-policy control.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# SARSA\n",
8 | "- Algorithms from ```pp. 105 - 107``` in Sutton & Barto 2017"
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": 1,
14 | "metadata": {
15 | "collapsed": true
16 | },
17 | "outputs": [],
18 | "source": [
19 | "import matplotlib.pyplot as plt\n",
20 | "import pandas as pd\n",
21 | "import numpy as np\n",
22 | "import seaborn, random\n",
23 | "\n",
24 | "from gridWorldEnvironment import GridWorld"
25 | ]
26 | },
27 | {
28 | "cell_type": "code",
29 | "execution_count": 2,
30 | "metadata": {
31 | "collapsed": true
32 | },
33 | "outputs": [],
34 | "source": [
35 | "# creating gridworld environment\n",
36 | "gw = GridWorld(gamma = .9, theta = .5)"
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "execution_count": 3,
42 | "metadata": {
43 | "collapsed": true
44 | },
45 | "outputs": [],
46 | "source": [
47 | "def state_action_value(env):\n",
48 | " q = dict()\n",
49 | " for state, action, next_state, reward in env.transitions:\n",
50 | " q[(state, action)] = np.random.normal()\n",
51 | " for action in env.actions:\n",
52 | " q[0, action] = 0\n",
53 | " return q"
54 | ]
55 | },
56 | {
57 | "cell_type": "code",
58 | "execution_count": 4,
59 | "metadata": {
60 | "scrolled": true
61 | },
62 | "outputs": [
63 | {
64 | "data": {
65 | "text/plain": [
66 | "{(0, 'D'): 0,\n",
67 | " (0, 'L'): 0,\n",
68 | " (0, 'R'): 0,\n",
69 | " (0, 'U'): 0,\n",
70 | " (1, 'D'): -1.4931958690513218,\n",
71 | " (1, 'L'): 0.45253247253698986,\n",
72 | " (1, 'R'): -1.789094792083647,\n",
73 | " (1, 'U'): 1.9103660029206884,\n",
74 | " (2, 'D'): -0.3082624959814856,\n",
75 | " (2, 'L'): -0.7483834716798741,\n",
76 | " (2, 'R'): 0.9839358952573672,\n",
77 | " (2, 'U'): -0.2724328270166447,\n",
78 | " (3, 'D'): -0.6283488057971491,\n",
79 | " (3, 'L'): -1.3156943242567214,\n",
80 | " (3, 'R'): 0.6211123056414489,\n",
81 | " (3, 'U'): 0.7976038544679848,\n",
82 | " (4, 'D'): -1.050452706273533,\n",
83 | " (4, 'L'): -0.1741895951805081,\n",
84 | " (4, 'R'): 1.8182867323493,\n",
85 | " (4, 'U'): 1.801387714646261,\n",
86 | " (5, 'D'): 0.10021107202221527,\n",
87 | " (5, 'L'): -0.47465610710346245,\n",
88 | " (5, 'R'): -0.6918872493076375,\n",
89 | " (5, 'U'): -3.2003225504816597,\n",
90 | " (6, 'D'): 0.6278525454978494,\n",
91 | " (6, 'L'): -0.24213713052505115,\n",
92 | " (6, 'R'): -2.630728473016553,\n",
93 | " (6, 'U'): 1.2604497963698,\n",
94 | " (7, 'D'): 0.7220353661454653,\n",
95 | " (7, 'L'): -0.046864386969943536,\n",
96 | " (7, 'R'): 0.3648650012222875,\n",
97 | " (7, 'U'): -0.13955201414490984,\n",
98 | " (8, 'D'): -0.73401777416155,\n",
99 | " (8, 'L'): 0.32247567696601837,\n",
100 | " (8, 'R'): 0.5780991299024739,\n",
101 | " (8, 'U'): -0.44090209956778376,\n",
102 | " (9, 'D'): 1.4691269385027337,\n",
103 | " (9, 'L'): 0.052954400141560456,\n",
104 | " (9, 'R'): 0.13195379460282597,\n",
105 | " (9, 'U'): -0.3923944749627829,\n",
106 | " (10, 'D'): 1.7199147774865016,\n",
107 | " (10, 'L'): -1.9247987278054801,\n",
108 | " (10, 'R'): -0.143510086697551,\n",
109 | " (10, 'U'): -0.5647971071775687,\n",
110 | " (11, 'D'): -0.5741454748190717,\n",
111 | " (11, 'L'): -1.5720736584714539,\n",
112 | " (11, 'R'): -1.4601134792863597,\n",
113 | " (11, 'U'): -1.257386467083216,\n",
114 | " (12, 'D'): -0.7472490605228015,\n",
115 | " (12, 'L'): -0.22563599050611213,\n",
116 | " (12, 'R'): 0.4395772978492674,\n",
117 | " (12, 'U'): -0.7210157758438556,\n",
118 | " (13, 'D'): 0.6291351870645969,\n",
119 | " (13, 'L'): -1.1313917732858065,\n",
120 | " (13, 'R'): 1.7365108020870084,\n",
121 | " (13, 'U'): 0.1339793058824657,\n",
122 | " (14, 'D'): 1.028938708862189,\n",
123 | " (14, 'L'): 0.050033944728027975,\n",
124 | " (14, 'R'): 1.3081957299962321,\n",
125 | " (14, 'U'): 0.05849306927859837}"
126 | ]
127 | },
128 | "execution_count": 4,
129 | "metadata": {},
130 | "output_type": "execute_result"
131 | }
132 | ],
133 | "source": [
134 | "state_action_value(gw)"
135 | ]
136 | },
137 | {
138 | "cell_type": "code",
139 | "execution_count": 5,
140 | "metadata": {
141 | "collapsed": true
142 | },
143 | "outputs": [],
144 | "source": [
145 | "def generate_greedy_policy(env, Q):\n",
146 | " pi = dict()\n",
147 | " for state in env.states:\n",
148 | " actions = []\n",
149 | " q_values = []\n",
150 | " prob = []\n",
151 | " \n",
152 | " for a in env.actions:\n",
153 | " actions.append(a)\n",
154 | " q_values.append(Q[state,a]) \n",
155 | " for i in range(len(q_values)):\n",
156 | " if i == np.argmax(q_values):\n",
157 | " prob.append(1)\n",
158 | " else:\n",
159 | " prob.append(0) \n",
160 | " \n",
161 | " pi[state] = (actions, prob)\n",
162 | " return pi"
163 | ]
164 | },
165 | {
166 | "cell_type": "code",
167 | "execution_count": 6,
168 | "metadata": {
169 | "collapsed": true
170 | },
171 | "outputs": [],
172 | "source": [
173 | "# selects action epsilon-greedily, given current state\n",
174 | "def e_greedy(env, e, q, state):\n",
175 | " actions = env.actions\n",
176 | " action_values = []\n",
177 | " prob = []\n",
178 | " for action in actions:\n",
179 | " action_values.append(q[(state, action)])\n",
180 | " for i in range(len(action_values)):\n",
181 | " if i == np.argmax(action_values):\n",
182 | " prob.append(1 - e + e/len(action_values))\n",
183 | " else:\n",
184 | " prob.append(e/len(action_values))\n",
185 | " return np.random.choice(actions, p = prob)"
186 | ]
187 | },
188 | {
189 | "cell_type": "markdown",
190 | "metadata": {
191 | "collapsed": true
192 | },
193 | "source": [
194 | "### SARSA: On-policy TD Control\n",
195 | "- Evaluates the action-value function ($Q$) of the policy being followed"
196 | ]
197 | },
198 | {
199 | "cell_type": "code",
200 | "execution_count": 7,
201 | "metadata": {
202 | "collapsed": true
203 | },
204 | "outputs": [],
205 | "source": [
206 | "def sarsa(env, epsilon, alpha, num_iter):\n",
207 | " Q = state_action_value(env)\n",
208 | " for _ in range(num_iter):\n",
209 | " current_state = np.random.choice(env.states)\n",
210 | " current_action = e_greedy(env, epsilon, Q, current_state)\n",
211 | " while current_state != 0:\n",
212 | " next_state, reward = env.state_transition(current_state, current_action)\n",
213 | " next_action = e_greedy(env, epsilon, Q, next_state)\n",
214 | " Q[current_state, current_action] += alpha * (reward + env.gamma * Q[next_state, next_action] - Q[current_state, current_action])\n",
215 | " current_state, current_action = next_state, next_action\n",
216 | " return Q"
217 | ]
218 | },
219 | {
220 | "cell_type": "code",
221 | "execution_count": 8,
222 | "metadata": {
223 | "collapsed": true
224 | },
225 | "outputs": [],
226 | "source": [
227 | "Q = sarsa(gw, 0.1, 0.5, 10000)"
228 | ]
229 | },
230 | {
231 | "cell_type": "code",
232 | "execution_count": 9,
233 | "metadata": {
234 | "collapsed": true
235 | },
236 | "outputs": [],
237 | "source": [
238 | "pi_hat = generate_greedy_policy(gw, Q)"
239 | ]
240 | },
241 | {
242 | "cell_type": "markdown",
243 | "metadata": {},
244 | "source": [
245 | "### Visualizing policy"
246 | ]
247 | },
248 | {
249 | "cell_type": "code",
250 | "execution_count": 10,
251 | "metadata": {
252 | "collapsed": true
253 | },
254 | "outputs": [],
255 | "source": [
256 | "def show_policy(pi, env):\n",
257 | " temp = np.zeros(len(env.states) + 2)\n",
258 | " for s in env.states:\n",
259 | " a = pi[s][0][np.argmax(pi[s][1])]\n",
260 | " if a == \"U\":\n",
261 | " temp[s] = 0.25\n",
262 | " elif a == \"D\":\n",
263 | " temp[s] = 0.5\n",
264 | " elif a == \"R\":\n",
265 | " temp[s] = 0.75\n",
266 | " else:\n",
267 | " temp[s] = 1.0\n",
268 | " \n",
269 | " temp = temp.reshape(4,4)\n",
270 | " ax = seaborn.heatmap(temp, cmap = \"prism\", linecolor=\"#282828\", cbar = False, linewidths = 0.1)\n",
271 | " plt.show()"
272 | ]
273 | },
274 | {
275 | "cell_type": "code",
276 | "execution_count": 11,
277 | "metadata": {},
278 | "outputs": [
279 | {
280 | "data": {
281 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAdwAAAFJCAYAAAAxCJwFAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAACs5JREFUeJzt3UtoXIXfxvFfmhEV4wUsSBGEbpoMCCku3HURaLwUJFQK\ntkoUJCsXbaEWiakXvLR1JXgJVbty4T+uLC6EYm1RUHBRaCEwUUQQbxtX0hab1pn/QuiLL01ek5d5\nxkw/n13OoWceegjfnkNCBzqdTqcAgK5a1+sBAHAtEFwACBBcAAgQXAAIEFwACBBcAAhodPPizWaz\nWgsL3fwIuqQ5MlJ3t9y7tWq+6f6tVfPNkaqqan3k/q1Fze0j1Wq1rnrOEy4ABAguAAQILgAECC4A\nBAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAE\nCC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQI\nLgAECC4ABPzj4Lbb7W7uAIC+1lju5I8//liHDh2q+fn5ajQa1W63a9OmTTU9PV0bN25MbQSANW/Z\n4M7MzNS+fftqdHT0yrEzZ87U9PR0zc3NdX0cAPSLZV8pLy4u/i22VVWbN2/u6iAA6EfLPuEODw/X\n9PR0bdmypW6++eY6f/58ff755zU8PJzaBwB9Ydngvvjii3XixIk6ffp0nTt3roaGhmpsbKzGx8dT\n+wCgLywb3IGBgRofHxdYAPh/8nu4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4\nABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgA\nECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQMNDpdDrduniz2ezWpQHgX6nVal31eKPb\nH7wwerTbH0EXjJydqrtbC72ewSrNN0fcvzVqvjlSVVWtj9y/tai5fWTJc14pA0CA4AJAgOACQIDg\nAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOAC\nQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJA\ngOACQIDgAkCA4AJAgOACQEBjuZOTk5N16dKlvx3rdDo1MDBQc3NzXR0GAP1k2eA+/fTTdeDAgXr7\n7bdrcHAwtQkA+s6ywR0dHa2JiYn65ptvanx8PLUJAPrOssGtqpqamkrsAIC+5oemACBAcAEgQHAB\nIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEg\nQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBA\ncAEgQHABIGCg0+l0unXxZrPZrUsDwL9Sq9W66vFGtz94YfRotz+CLhg5O1V3txZ6PYNVmm+O1MzC\nf3o9g1V4dWRXVVW1PvL9txY1t48sec4rZQAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBc\nAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwA\nCBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAhYcXAXFxe7sQMA+tqSwT15\n8mSNjY3V+Ph4ffLJJ1eOT01NRYYBQD9pLHXiyJEjdezYsWq327Vnz566ePFibd++vTqdTnIfAPSF\nJYN73XXX1a233lpVVbOzs/XEE0/Uhg
0bamBgIDYOAPrFkq+U77zzzjp06FBduHChhoaG6q233qqX\nXnqpvv/+++Q+AOgLSwb34MGDNTw8fOWJdsOGDfX+++/Xgw8+GBsHAP1iyVfKjUajHn744b8dW79+\nfc3MzHR9FAD0G7+HCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYIL\nAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsA\nAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAEDnU6n062LN5vNbl0aAP6VWq3WVY83uv3BC6NH\nu/0RdMHI2amaWfhPr2ewSq+O7HL/1qhXR3ZVVVVrYaHHS1iN5sjIkue8UgaAAMEFgADBBYAAwQWA\nAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAA\nwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADB\nBYCAFQX3jz/+qMXFxW5tAYC+tWxwv/vuu3rqqadqenq6vvrqq9q2bVtt27atTp06ldoHAH2hsdzJ\nF154ofbs2VM///xz7d69u44fP17XX399TU1N1djYWGojAKx5ywa33W7XvffeW1VVX3/9dd1+++1/\n/aHGsn8MAPhfln2lvHHjxpqZmal2u12HDx+uqqp333231q9fHxkHAP1i2UfVV155pU6ePFnr1v1P\nl++4446anJzs+jAA6CfLBnfdunW1devWvx2bmJjo6iAA6Ed+DxcAAgQXAAIEFwACBBcAAgQXAAIE\nFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQX\nAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIGOp1Op1sX\nbzab3bo0APwrtVqtqx7vanABgL94pQwAAYILAAGCCwABggsAAYILAAGCCwABgrsK7Xa7nn/++Xrk\nkUdqcnKyfvjhh15PYoXOnj1bk5OTvZ7BCl26dKn2799fjz76aO3YsaM+++yzXk/iH/rzzz9renq6\ndu7cWbt27apvv/2215PiBHcVTpw4UYuLi/Xhhx/Wvn376vDhw72exAq89957deDAgbp48WKvp7BC\nH3/8cd122231wQcf1NGjR+vll1/u9ST+oVOnTlVV1dzcXO3du7def/31Hi/KE9xVOH36dG3ZsqWq\nqjZv3lzz8/M9XsRK3HXXXfXmm2/2egar8MADD9SePXuqqqrT6dTg4GCPF/FPbd269co/kH755Ze6\n5ZZberwor9HrAWvRuXPnamho6MrXg4ODdfny5Wo0/HWuBffff3/99NNPvZ7BKtx0001V9df34O7d\nu2vv3r09XsRKNBqNeuaZZ+rTTz+tN954o9dz4jzhrsLQ0FCdP3/+ytftdltsIeTXX3+txx9/vCYm\nJuqhhx7q9RxW6LXXXqvjx4/Xc889VxcuXOj1nCjBXYV77rmnvvjii6qqOnPmTG3atKnHi+Da8Ntv\nv9WTTz5Z+/fvrx07dvR6Ditw7Nixeuedd6qq6sYbb6yBgYFat+7aSpDHslUYHx+vL7/8snbu3Fmd\nTqcOHjzY60lwTThy5Ej9/vvvNTs7W7Ozs1X11w/B3XDDDT1exv/lvvvuq+np6Xrsscfq8uXL9eyz\nz15z983/FgQAAdfW8zwA9IjgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkDAfwHhhKWt2U7SnQAA\nAABJRU5ErkJggg==\n",
282 | "text/plain": [
283 | ""
284 | ]
285 | },
286 | "metadata": {},
287 | "output_type": "display_data"
288 | }
289 | ],
290 | "source": [
291 | "### RED = TERMINAL (0)\n",
292 | "### GREEN = LEFT\n",
293 | "### BLUE = UP\n",
294 | "### PURPLE = RIGHT\n",
295 | "### ORANGE = DOWN\n",
296 | "\n",
297 | "show_policy(pi_hat, gw)"
298 | ]
299 | }
300 | ],
301 | "metadata": {
302 | "kernelspec": {
303 | "display_name": "Python 3",
304 | "language": "python",
305 | "name": "python3"
306 | },
307 | "language_info": {
308 | "codemirror_mode": {
309 | "name": "ipython",
310 | "version": 3
311 | },
312 | "file_extension": ".py",
313 | "mimetype": "text/x-python",
314 | "name": "python",
315 | "nbconvert_exporter": "python",
316 | "pygments_lexer": "ipython3",
317 | "version": "3.6.1"
318 | }
319 | },
320 | "nbformat": 4,
321 | "nbformat_minor": 2
322 | }
323 |
--------------------------------------------------------------------------------
/source code/3-Temporal Difference Learning (Chapter 6)/3-Q-learning-off-policy-control.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Q-Learning\n",
8 | "- Algorithms from ```pp. 107 - 109``` in Sutton & Barto 2017"
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": 15,
14 | "metadata": {
15 | "collapsed": true
16 | },
17 | "outputs": [],
18 | "source": [
19 | "import matplotlib.pyplot as plt\n",
20 | "import pandas as pd\n",
21 | "import numpy as np\n",
22 | "import seaborn, random\n",
23 | "\n",
24 | "from gridWorldEnvironment import GridWorld"
25 | ]
26 | },
27 | {
28 | "cell_type": "code",
29 | "execution_count": 16,
30 | "metadata": {
31 | "collapsed": true
32 | },
33 | "outputs": [],
34 | "source": [
35 | "# creating gridworld environment\n",
36 | "gw = GridWorld(gamma = .9, theta = .5)"
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "execution_count": 17,
42 | "metadata": {
43 | "collapsed": true
44 | },
45 | "outputs": [],
46 | "source": [
47 | "def state_action_value(env):\n",
48 | " q = dict()\n",
49 | " for state, action, next_state, reward in env.transitions:\n",
50 | " q[(state, action)] = np.random.normal()\n",
51 | " for action in env.actions:\n",
52 | " q[0, action] = 0\n",
53 | " return q"
54 | ]
55 | },
56 | {
57 | "cell_type": "code",
58 | "execution_count": 18,
59 | "metadata": {
60 | "scrolled": true
61 | },
62 | "outputs": [
63 | {
64 | "data": {
65 | "text/plain": [
66 | "{(0, 'D'): 0,\n",
67 | " (0, 'L'): 0,\n",
68 | " (0, 'R'): 0,\n",
69 | " (0, 'U'): 0,\n",
70 | " (1, 'D'): 1.0159263422385194,\n",
71 | " (1, 'L'): -0.9015277224413166,\n",
72 | " (1, 'R'): -0.8401929147381398,\n",
73 | " (1, 'U'): -0.3866748964867951,\n",
74 | " (2, 'D'): -1.6171831004555488,\n",
75 | " (2, 'L'): -0.757413177831847,\n",
76 | " (2, 'R'): 1.1948020656778442,\n",
77 | " (2, 'U'): -0.6814167904466197,\n",
78 | " (3, 'D'): 0.8048046240684537,\n",
79 | " (3, 'L'): 0.5058075927359411,\n",
80 | " (3, 'R'): -0.2396998941031068,\n",
81 | " (3, 'U'): 0.2155229655857796,\n",
82 | " (4, 'D'): -1.6147931728590075,\n",
83 | " (4, 'L'): -0.7788455232241043,\n",
84 | " (4, 'R'): -1.4834265811813951,\n",
85 | " (4, 'U'): -0.6831094436413707,\n",
86 | " (5, 'D'): 0.553787569878218,\n",
87 | " (5, 'L'): -0.7535460982740707,\n",
88 | " (5, 'R'): -0.5446399864366045,\n",
89 | " (5, 'U'): 0.3275946472542217,\n",
90 | " (6, 'D'): 0.4967878126487371,\n",
91 | " (6, 'L'): -0.08765553689024165,\n",
92 | " (6, 'R'): 2.561671991266108,\n",
93 | " (6, 'U'): -0.6251803976764682,\n",
94 | " (7, 'D'): -1.4239033515729715,\n",
95 | " (7, 'L'): 0.8961390571157126,\n",
96 | " (7, 'R'): -1.752510423129062,\n",
97 | " (7, 'U'): 0.186255694272231,\n",
98 | " (8, 'D'): 2.0464864623182057,\n",
99 | " (8, 'L'): -1.4145441882793333,\n",
100 | " (8, 'R'): -0.01479790948318322,\n",
101 | " (8, 'U'): -0.5827888737714622,\n",
102 | " (9, 'D'): -1.99334821662137,\n",
103 | " (9, 'L'): -0.2143702145160287,\n",
104 | " (9, 'R'): 0.3093816164828963,\n",
105 | " (9, 'U'): -0.13810316479560683,\n",
106 | " (10, 'D'): -0.4286970274384264,\n",
107 | " (10, 'L'): -1.0313201895326287,\n",
108 | " (10, 'R'): 0.674733204171753,\n",
109 | " (10, 'U'): -0.06485579239030625,\n",
110 | " (11, 'D'): -1.08761491406436,\n",
111 | " (11, 'L'): 0.37568985747762257,\n",
112 | " (11, 'R'): 0.6549082083274831,\n",
113 | " (11, 'U'): -0.044872797471142374,\n",
114 | " (12, 'D'): -0.8474472601122482,\n",
115 | " (12, 'L'): -0.6624032631781939,\n",
116 | " (12, 'R'): 0.4720691318622512,\n",
117 | " (12, 'U'): -1.2333806138625314,\n",
118 | " (13, 'D'): -0.38599325593593187,\n",
119 | " (13, 'L'): -1.677564721893633,\n",
120 | " (13, 'R'): -1.1425774085316547,\n",
121 | " (13, 'U'): -1.3606173055898687,\n",
122 | " (14, 'D'): -1.1641372208339686,\n",
123 | " (14, 'L'): 1.1457975665987994,\n",
124 | " (14, 'R'): -1.012618770180745,\n",
125 | " (14, 'U'): -0.06091464445082383}"
126 | ]
127 | },
128 | "execution_count": 18,
129 | "metadata": {},
130 | "output_type": "execute_result"
131 | }
132 | ],
133 | "source": [
134 | "state_action_value(gw)"
135 | ]
136 | },
137 | {
138 | "cell_type": "code",
139 | "execution_count": 19,
140 | "metadata": {
141 | "collapsed": true
142 | },
143 | "outputs": [],
144 | "source": [
145 | "def generate_greedy_policy(env, Q):\n",
146 | " pi = dict()\n",
147 | " for state in env.states:\n",
148 | " actions = []\n",
149 | " q_values = []\n",
150 | " prob = []\n",
151 | " \n",
152 | " for a in env.actions:\n",
153 | " actions.append(a)\n",
154 | " q_values.append(Q[state,a]) \n",
155 | " for i in range(len(q_values)):\n",
156 | " if i == np.argmax(q_values):\n",
157 | " prob.append(1)\n",
158 | " else:\n",
159 | " prob.append(0) \n",
160 | " \n",
161 | " pi[state] = (actions, prob)\n",
162 | " return pi"
163 | ]
164 | },
165 | {
166 | "cell_type": "code",
167 | "execution_count": 20,
168 | "metadata": {
169 | "collapsed": true
170 | },
171 | "outputs": [],
172 | "source": [
173 | "def e_greedy(env, e, q, state):\n",
174 | " actions = env.actions\n",
175 | " action_values = []\n",
176 | " prob = []\n",
177 | " for action in actions:\n",
178 | " action_values.append(q[(state, action)])\n",
179 | " for i in range(len(action_values)):\n",
180 | " if i == np.argmax(action_values):\n",
181 | " prob.append(1 - e + e/len(action_values))\n",
182 | " else:\n",
183 | " prob.append(e/len(action_values))\n",
184 | " return np.random.choice(actions, p = prob)"
185 | ]
186 | },
187 | {
188 | "cell_type": "code",
189 | "execution_count": 21,
190 | "metadata": {
191 | "collapsed": true
192 | },
193 | "outputs": [],
194 | "source": [
195 | "def greedy(env, q, state):\n",
196 | " actions = env.actions\n",
197 | " action_values = []\n",
198 | " for action in actions:\n",
199 | " action_values.append(q[state, action])\n",
200 | " return actions[np.argmax(action_values)]"
201 | ]
202 | },
203 | {
204 | "cell_type": "markdown",
205 | "metadata": {
206 | "collapsed": true
207 | },
208 | "source": [
209 | "### Q-Learning: Off-policy TD Control"
210 | ]
211 | },
212 | {
213 | "cell_type": "code",
214 | "execution_count": 22,
215 | "metadata": {
216 | "collapsed": true
217 | },
218 | "outputs": [],
219 | "source": [
220 | "def q_learning(env, epsilon, alpha, num_iter):\n",
221 | " Q = state_action_value(env)\n",
222 | " \n",
223 | " for _ in range(num_iter):\n",
224 | " current_state = np.random.choice(env.states)\n",
225 | " while current_state != 0:\n",
226 | " current_action = e_greedy(env, epsilon, Q, current_state)\n",
227 | " next_state, reward = env.state_transition(current_state, current_action)\n",
228 | " best_action = greedy(env, Q, next_state)\n",
229 | " Q[current_state, current_action] += alpha * (reward + env.gamma * Q[next_state, best_action] - Q[current_state, current_action])\n",
230 | " current_state = next_state\n",
231 | " return Q"
232 | ]
233 | },
234 | {
235 | "cell_type": "code",
236 | "execution_count": 29,
237 | "metadata": {
238 | "collapsed": true
239 | },
240 | "outputs": [],
241 | "source": [
242 | "Q = q_learning(gw, 0.2, 1.0, 10000)"
243 | ]
244 | },
245 | {
246 | "cell_type": "code",
247 | "execution_count": 30,
248 | "metadata": {
249 | "scrolled": true
250 | },
251 | "outputs": [
252 | {
253 | "data": {
254 | "text/plain": [
255 | "{(0, 'D'): 0,\n",
256 | " (0, 'L'): 0,\n",
257 | " (0, 'R'): 0,\n",
258 | " (0, 'U'): 0,\n",
259 | " (1, 'D'): -2.71,\n",
260 | " (1, 'L'): -1.0,\n",
261 | " (1, 'R'): -2.71,\n",
262 | " (1, 'U'): -1.9,\n",
263 | " (2, 'D'): -3.439,\n",
264 | " (2, 'L'): -1.9,\n",
265 | " (2, 'R'): -3.439,\n",
266 | " (2, 'U'): -2.71,\n",
267 | " (3, 'D'): -2.71,\n",
268 | " (3, 'L'): -2.71,\n",
269 | " (3, 'R'): -3.439,\n",
270 | " (3, 'U'): -3.439,\n",
271 | " (4, 'D'): -2.71,\n",
272 | " (4, 'L'): -1.9,\n",
273 | " (4, 'R'): -2.71,\n",
274 | " (4, 'U'): -1.0,\n",
275 | " (5, 'D'): -3.439,\n",
276 | " (5, 'L'): -1.9,\n",
277 | " (5, 'R'): -3.439,\n",
278 | " (5, 'U'): -1.9,\n",
279 | " (6, 'D'): -2.71,\n",
280 | " (6, 'L'): -2.71,\n",
281 | " (6, 'R'): -2.71,\n",
282 | " (6, 'U'): -2.71,\n",
283 | " (7, 'D'): -1.9,\n",
284 | " (7, 'L'): -3.439,\n",
285 | " (7, 'R'): -2.71,\n",
286 | " (7, 'U'): -3.439,\n",
287 | " (8, 'D'): -3.439,\n",
288 | " (8, 'L'): -2.71,\n",
289 | " (8, 'R'): -3.439,\n",
290 | " (8, 'U'): -1.9,\n",
291 | " (9, 'D'): -2.71,\n",
292 | " (9, 'L'): -2.71,\n",
293 | " (9, 'R'): -2.71,\n",
294 | " (9, 'U'): -2.71,\n",
295 | " (10, 'D'): -1.9,\n",
296 | " (10, 'L'): -3.439,\n",
297 | " (10, 'R'): -1.9,\n",
298 | " (10, 'U'): -3.439,\n",
299 | " (11, 'D'): -1.0,\n",
300 | " (11, 'L'): -2.71,\n",
301 | " (11, 'R'): -1.9,\n",
302 | " (11, 'U'): -2.71,\n",
303 | " (12, 'D'): -3.439,\n",
304 | " (12, 'L'): -3.439,\n",
305 | " (12, 'R'): -2.71,\n",
306 | " (12, 'U'): -2.71,\n",
307 | " (13, 'D'): -2.71,\n",
308 | " (13, 'L'): -3.439,\n",
309 | " (13, 'R'): -1.9,\n",
310 | " (13, 'U'): -3.439,\n",
311 | " (14, 'D'): -1.9,\n",
312 | " (14, 'L'): -2.71,\n",
313 | " (14, 'R'): -1.0,\n",
314 | " (14, 'U'): -2.71}"
315 | ]
316 | },
317 | "execution_count": 30,
318 | "metadata": {},
319 | "output_type": "execute_result"
320 | }
321 | ],
322 | "source": [
323 | "Q"
324 | ]
325 | },
326 | {
327 | "cell_type": "code",
328 | "execution_count": 31,
329 | "metadata": {
330 | "collapsed": true
331 | },
332 | "outputs": [],
333 | "source": [
334 | "pi_hat = generate_greedy_policy(gw, Q)"
335 | ]
336 | },
337 | {
338 | "cell_type": "markdown",
339 | "metadata": {},
340 | "source": [
341 | "### Visualizing policy"
342 | ]
343 | },
344 | {
345 | "cell_type": "code",
346 | "execution_count": 32,
347 | "metadata": {
348 | "collapsed": true
349 | },
350 | "outputs": [],
351 | "source": [
352 | "def show_policy(pi, env):\n",
353 | " temp = np.zeros(len(env.states) + 2)\n",
354 | " for s in env.states:\n",
355 | " a = pi[s][0][np.argmax(pi[s][1])]\n",
356 | " if a == \"U\":\n",
357 | " temp[s] = 0.25\n",
358 | " elif a == \"D\":\n",
359 | " temp[s] = 0.5\n",
360 | " elif a == \"R\":\n",
361 | " temp[s] = 0.75\n",
362 | " else:\n",
363 | " temp[s] = 1.0\n",
364 | " \n",
365 | " temp = temp.reshape(4,4)\n",
366 | " ax = seaborn.heatmap(temp, cmap = \"prism\", linecolor=\"#282828\", cbar = False, linewidths = 0.1)\n",
367 | " plt.show()"
368 | ]
369 | },
370 | {
371 | "cell_type": "code",
372 | "execution_count": 33,
373 | "metadata": {},
374 | "outputs": [
375 | {
376 | "data": {
377 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAdwAAAFJCAYAAAAxCJwFAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAACsFJREFUeJzt3U+IlYX+x/HvOCcqmv5AQkgQuHHmwAWlRTsXA05/hBgM\nIS2mIGbVwhFM4jT2h+qqrYL+DFauWnSnVdIikEwpKGghKAhnigiif5tWoZKjnfNbxM9LF53bzOV8\nTnN8vXbzPMxzPvggb5+HGRzqdrvdAgB6ak2/BwDAtUBwASBAcAEgQHABIEBwASBAcAEgoNHLizeb\nzWovLPTyI+iR5thY/aPt3q1WZ5ru32p1pjlWVVXtD9y/1ai5baza7fYVz3nCBYAAwQWAAMEFgADB\nBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEF\ngADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWA\nAMEFgADBBYCAvxzcTqfTyx0AMNAaS538/vvv68CBA3XmzJlqNBrV6XRqw4YN1Wq1av369amNALDq\nLRnc2dnZ2rNnT23cuPHysVOnTlWr1ar5+fmejwOAQbHkK+XFxcU/xbaqatOmTT0dBACDaMkn3NHR\n0Wq1WrV58+a6+eab69y5c/Xpp5/W6Ohoah8ADIQlg/vCCy/UsWPH6uTJk3X27NkaGRmp8fHxmpiY\nSO0DgIGwZHCHhoZqYmJCYAHgf+T3cAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBA\ncAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBw\nASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIGCo2+12e3XxZrPZq0sDwN9Su92+4vFG\nrz94YePhXn8EPTB2etq9W8Xcv9Vr7PR0VVW1P1jo8xJWorlt7KrnvFIGgADBBYAAwQWAAMEFgADB\nBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEF\ngADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWA\nAMEFgADBBYAAwQWAgMZSJ6empurixYt/OtbtdmtoaKjm5+d7OgwABsmSwX3qqadq37599eabb9bw\n8HBqEwAMnCWDu3HjxpqcnKyvvvqqJiYmUpsAYOAsGdyqqunp6cQOABhofmgKAAIEFwACBBcAAgQX\nAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcA\nAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwAC\nBBcAAoa63W63VxdvNpu9ujQA/C212+0rHm/0+oMXNh7u9UfQA2Onp927VWzs9HS1P1jo9wxWoLlt\nrKrK/Vul/v/+XYlXygAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4\nABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgA\nECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABCw7OAuLi72YgcADLSrBvf48eM1Pj5eExMT9dFH\nH10+Pj09HRkGAIOkcbUThw4dqiNHjlSn06mZmZm6cOFCbdu2rbrdbnIfAAyEqwb3uuuuq1tvvbWq\nqubm5urxxx+vdevW1dDQUGwcAAyKq7
5SvvPOO+vAgQN1/vz5GhkZqTfeeKNefPHF+vbbb5P7AGAg\nXDW4+/fvr9HR0ctPtOvWrat33323Hnjggdg4ABgUV32l3Gg06qGHHvrTsbVr19bs7GzPRwHAoPF7\nuAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4\nABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgA\nECC4ABAguAAQILgAECC4ABAw1O12u726eLPZ7NWlAeBvqd1uX/F4o9cfvLDxcK8/gh4YOz1dswv/\n6vcMVuifYzvdv1Xqn2M7q6qqvbDQ5yWsRHNs7KrnvFIGgADBBYAAwQWAAMEFgADBBYAAwQWAAMEF\ngADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWA\nAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAgGUF97fffqvF\nxcVebQGAgbVkcL/55pt68sknq9Vq1RdffFFbt26trVu31okTJ1L7AGAgNJY6+fzzz9fMzEz9+OOP\ntWvXrjp69Ghdf/31NT09XePj46mNALDqLRncTqdT99xzT1VVffnll3X77bf/8U2NJb8NAPgPS75S\nXr9+fc3Ozlan06mDBw9WVdXbb79da9eujYwDgEGx5KPqyy+/XMePH681a/7d5TvuuKOmpqZ6PgwA\nBsmSwV2zZk1t2bLlT8cmJyd7OggABpHfwwWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADB\nBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEF\ngADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgIChbrfb7dXFm81mry4NAH9L7Xb7\nisd7GlwA4A9eKQNAgOACQIDgAkCA4AJAgOACQIDgAkCA4K5Ap9Op5557rh5++OGampqq7777rt+T\nWKbTp0/X1NRUv2ewTBcvXqy9e/fWI488Utu3b69PPvmk35P4i37//fdqtVq1Y8eO2rlzZ3399df9\nnhQnuCtw7NixWlxcrPfff7/27NlTBw8e7PckluGdd96pffv21YULF/o9hWX68MMP67bbbqv33nuv\nDh8+XC+99FK/J/EXnThxoqqq5ufna/fu3fXqq6/2eVGe4K7AyZMna/PmzVVVtWnTpjpz5kyfF7Ec\nd911V73++uv9nsEK3H///TUzM1NVVd1ut4aHh/u8iL9qy5Ytl/+B9NNPP9Utt9zS50V5jX4PWI3O\nnj1bIyMjl78eHh6uS5cuVaPhj3M1uO++++qHH37o9wxW4KabbqqqP/4O7tq1q3bv3t3nRSxHo9Go\np59+uj7++ON67bXX+j0nzhPuCoyMjNS5c+cuf93pdMQWQn7++ed67LHHanJysh588MF+z2GZXnnl\nlTp69Gg9++yzdf78+X7PiRLcFbj77rvrs88+q6qqU6dO1YYNG/q8CK4Nv/zySz3xxBO1d+/e2r59\ne7/nsAxHjhypt956q6qqbrzxxhoaGqo1a66tBHksW4GJiYn6/PPPa8eOHdXtdmv//v39ngTXhEOH\nDtWvv/5ac3NzNTc3V1V//BDcDTfc0Odl/Df33ntvtVqtevTRR+vSpUv1zDPPXHP3zf8WBAAB19bz\nPAD0ieACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQMD/AWrkpa0qFuevAAAAAElFTkSuQmCC\n",
378 | "text/plain": [
379 | ""
380 | ]
381 | },
382 | "metadata": {},
383 | "output_type": "display_data"
384 | }
385 | ],
386 | "source": [
387 | "### RED = TERMINAL (0)\n",
388 | "### GREEN = LEFT\n",
389 | "### BLUE = UP\n",
390 | "### PURPLE = RIGHT\n",
391 | "### ORANGE = DOWN\n",
392 | "\n",
393 | "show_policy(pi_hat, gw)"
394 | ]
395 | }
396 | ],
397 | "metadata": {
398 | "kernelspec": {
399 | "display_name": "Python 3",
400 | "language": "python",
401 | "name": "python3"
402 | },
403 | "language_info": {
404 | "codemirror_mode": {
405 | "name": "ipython",
406 | "version": 3
407 | },
408 | "file_extension": ".py",
409 | "mimetype": "text/x-python",
410 | "name": "python",
411 | "nbconvert_exporter": "python",
412 | "pygments_lexer": "ipython3",
413 | "version": "3.6.1"
414 | }
415 | },
416 | "nbformat": 4,
417 | "nbformat_minor": 2
418 | }
419 |
--------------------------------------------------------------------------------
/source code/3-Temporal Difference Learning (Chapter 6)/4-double-Q-learning-off-policy-control.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Double Q-Learning\n",
8 | "- Algorithms from ```pp. 110 - 111``` in Sutton & Barto 2017\n",
9 | "- The double Q-learning algorithm employs two action-value functions (i.e., $Q_1$ and $Q_2$) to avoid maximization bias"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": 15,
15 | "metadata": {
16 | "collapsed": true
17 | },
18 | "outputs": [],
19 | "source": [
20 | "import matplotlib.pyplot as plt\n",
21 | "import pandas as pd\n",
22 | "import numpy as np\n",
23 | "import seaborn, random\n",
24 | "\n",
25 | "from gridWorldEnvironment import GridWorld"
26 | ]
27 | },
28 | {
29 | "cell_type": "code",
30 | "execution_count": 16,
31 | "metadata": {
32 | "collapsed": true
33 | },
34 | "outputs": [],
35 | "source": [
36 | "# creating gridworld environment\n",
37 | "gw = GridWorld(gamma = .9, theta = .5)"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": 17,
43 | "metadata": {
44 | "collapsed": true
45 | },
46 | "outputs": [],
47 | "source": [
48 | "def state_action_value(env):\n",
49 | " q = dict()\n",
50 | " for state, action, next_state, reward in env.transitions:\n",
51 | " q[(state, action)] = np.random.normal()\n",
52 | " for action in env.actions:\n",
53 | " q[0, action] = 0\n",
54 | " return q"
55 | ]
56 | },
57 | {
58 | "cell_type": "code",
59 | "execution_count": 18,
60 | "metadata": {
61 | "scrolled": true
62 | },
63 | "outputs": [
64 | {
65 | "data": {
66 | "text/plain": [
67 | "{(0, 'D'): 0,\n",
68 | " (0, 'L'): 0,\n",
69 | " (0, 'R'): 0,\n",
70 | " (0, 'U'): 0,\n",
71 | " (1, 'D'): 1.0159263422385194,\n",
72 | " (1, 'L'): -0.9015277224413166,\n",
73 | " (1, 'R'): -0.8401929147381398,\n",
74 | " (1, 'U'): -0.3866748964867951,\n",
75 | " (2, 'D'): -1.6171831004555488,\n",
76 | " (2, 'L'): -0.757413177831847,\n",
77 | " (2, 'R'): 1.1948020656778442,\n",
78 | " (2, 'U'): -0.6814167904466197,\n",
79 | " (3, 'D'): 0.8048046240684537,\n",
80 | " (3, 'L'): 0.5058075927359411,\n",
81 | " (3, 'R'): -0.2396998941031068,\n",
82 | " (3, 'U'): 0.2155229655857796,\n",
83 | " (4, 'D'): -1.6147931728590075,\n",
84 | " (4, 'L'): -0.7788455232241043,\n",
85 | " (4, 'R'): -1.4834265811813951,\n",
86 | " (4, 'U'): -0.6831094436413707,\n",
87 | " (5, 'D'): 0.553787569878218,\n",
88 | " (5, 'L'): -0.7535460982740707,\n",
89 | " (5, 'R'): -0.5446399864366045,\n",
90 | " (5, 'U'): 0.3275946472542217,\n",
91 | " (6, 'D'): 0.4967878126487371,\n",
92 | " (6, 'L'): -0.08765553689024165,\n",
93 | " (6, 'R'): 2.561671991266108,\n",
94 | " (6, 'U'): -0.6251803976764682,\n",
95 | " (7, 'D'): -1.4239033515729715,\n",
96 | " (7, 'L'): 0.8961390571157126,\n",
97 | " (7, 'R'): -1.752510423129062,\n",
98 | " (7, 'U'): 0.186255694272231,\n",
99 | " (8, 'D'): 2.0464864623182057,\n",
100 | " (8, 'L'): -1.4145441882793333,\n",
101 | " (8, 'R'): -0.01479790948318322,\n",
102 | " (8, 'U'): -0.5827888737714622,\n",
103 | " (9, 'D'): -1.99334821662137,\n",
104 | " (9, 'L'): -0.2143702145160287,\n",
105 | " (9, 'R'): 0.3093816164828963,\n",
106 | " (9, 'U'): -0.13810316479560683,\n",
107 | " (10, 'D'): -0.4286970274384264,\n",
108 | " (10, 'L'): -1.0313201895326287,\n",
109 | " (10, 'R'): 0.674733204171753,\n",
110 | " (10, 'U'): -0.06485579239030625,\n",
111 | " (11, 'D'): -1.08761491406436,\n",
112 | " (11, 'L'): 0.37568985747762257,\n",
113 | " (11, 'R'): 0.6549082083274831,\n",
114 | " (11, 'U'): -0.044872797471142374,\n",
115 | " (12, 'D'): -0.8474472601122482,\n",
116 | " (12, 'L'): -0.6624032631781939,\n",
117 | " (12, 'R'): 0.4720691318622512,\n",
118 | " (12, 'U'): -1.2333806138625314,\n",
119 | " (13, 'D'): -0.38599325593593187,\n",
120 | " (13, 'L'): -1.677564721893633,\n",
121 | " (13, 'R'): -1.1425774085316547,\n",
122 | " (13, 'U'): -1.3606173055898687,\n",
123 | " (14, 'D'): -1.1641372208339686,\n",
124 | " (14, 'L'): 1.1457975665987994,\n",
125 | " (14, 'R'): -1.012618770180745,\n",
126 | " (14, 'U'): -0.06091464445082383}"
127 | ]
128 | },
129 | "execution_count": 18,
130 | "metadata": {},
131 | "output_type": "execute_result"
132 | }
133 | ],
134 | "source": [
135 | "state_action_value(gw)"
136 | ]
137 | },
138 | {
139 | "cell_type": "code",
140 | "execution_count": 19,
141 | "metadata": {
142 | "collapsed": true
143 | },
144 | "outputs": [],
145 | "source": [
146 | "def generate_greedy_policy(env, Q):\n",
147 | " pi = dict()\n",
148 | " for state in env.states:\n",
149 | " actions = []\n",
150 | " q_values = []\n",
151 | " prob = []\n",
152 | " \n",
153 | " for a in env.actions:\n",
154 | " actions.append(a)\n",
155 | " q_values.append(Q[state,a]) \n",
156 | " for i in range(len(q_values)):\n",
157 | " if i == np.argmax(q_values):\n",
158 | " prob.append(1)\n",
159 | " else:\n",
160 | " prob.append(0) \n",
161 | " \n",
162 | " pi[state] = (actions, prob)\n",
163 | " return pi"
164 | ]
165 | },
166 | {
167 | "cell_type": "code",
168 | "execution_count": 20,
169 | "metadata": {
170 | "collapsed": true
171 | },
172 | "outputs": [],
173 | "source": [
174 | "def e_greedy(env, e, q, state):\n",
175 | " actions = env.actions\n",
176 | " action_values = []\n",
177 | " prob = []\n",
178 | " for action in actions:\n",
179 | " action_values.append(q[(state, action)])\n",
180 | " for i in range(len(action_values)):\n",
181 | " if i == np.argmax(action_values):\n",
182 | " prob.append(1 - e + e/len(action_values))\n",
183 | " else:\n",
184 | " prob.append(e/len(action_values))\n",
185 | " return np.random.choice(actions, p = prob)"
186 | ]
187 | },
188 | {
189 | "cell_type": "code",
190 | "execution_count": 21,
191 | "metadata": {
192 | "collapsed": true
193 | },
194 | "outputs": [],
195 | "source": [
196 | "def greedy(env, q, state):\n",
197 | " actions = env.actions\n",
198 | " action_values = []\n",
199 | " for action in actions:\n",
200 | " action_values.append(q[state, action])\n",
201 | " return actions[np.argmax(action_values)]"
202 | ]
203 | },
204 | {
205 | "cell_type": "markdown",
206 | "metadata": {},
207 | "source": [
208 | "### Double Q-learning"
209 | ]
210 | },
211 | {
212 | "cell_type": "code",
213 | "execution_count": 34,
214 | "metadata": {
215 | "collapsed": true
216 | },
217 | "outputs": [],
218 | "source": [
219 | "def double_q_learning(env, epsilon, alpha, num_iter):\n",
220 | " Q1, Q2 = state_action_value(env), state_action_value(env)\n",
221 | " for _ in range(num_iter):\n",
222 | " current_state = np.random.choice(env.states)\n",
223 | " while current_state != 0:\n",
224 | " Q = dict()\n",
225 | " for key in Q1.keys():\n",
226 | " Q[key] = Q1[key] + Q2[key]\n",
227 | " current_action = e_greedy(env, epsilon, Q, current_state)\n",
228 | " next_state, reward = env.state_transition(current_state, current_action)\n",
229 | " \n",
230 | " # choose Q1 or Q2 with equal probabilities (0.5)\n",
231 | " chosen_Q = np.random.choice([\"Q1\", \"Q2\"])\n",
232 | " if chosen_Q == \"Q1\": # when Q1 is chosen\n",
233 | " best_action = greedy(env, Q1, next_state)\n",
234 | " Q1[current_state, current_action] += alpha * \\\n",
235 | " (reward + env.gamma * Q2[next_state, best_action] - Q1[current_state, current_action])\n",
236 | " else: # when Q2 is chosen\n",
237 | " best_action = greedy(env, Q2, next_state)\n",
238 | " Q2[current_state, current_action] += alpha * \\\n",
239 | " (reward + env.gamma * Q1[next_state, best_action] - Q2[current_state, current_action])\n",
240 | " \n",
241 | " current_state = next_state\n",
242 | " return Q1, Q2"
243 | ]
244 | },
245 | {
246 | "cell_type": "code",
247 | "execution_count": 38,
248 | "metadata": {
249 | "collapsed": true
250 | },
251 | "outputs": [],
252 | "source": [
253 | "Q1, Q2 = double_q_learning(gw, 0.2, 0.5, 5000)\n",
254 | "\n",
255 | "# sum Q1 & Q2 elementwise to obtain final Q-values\n",
256 | "Q = dict()\n",
257 | "for key in Q1.keys():\n",
258 | " Q[key] = Q1[key] + Q2[key]"
259 | ]
260 | },
261 | {
262 | "cell_type": "code",
263 | "execution_count": 39,
264 | "metadata": {
265 | "collapsed": true
266 | },
267 | "outputs": [],
268 | "source": [
269 | "pi_hat = generate_greedy_policy(gw, Q)"
270 | ]
271 | },
272 | {
273 | "cell_type": "code",
274 | "execution_count": 32,
275 | "metadata": {
276 | "collapsed": true
277 | },
278 | "outputs": [],
279 | "source": [
280 | "def show_policy(pi, env):\n",
281 | " temp = np.zeros(len(env.states) + 2)\n",
282 | " for s in env.states:\n",
283 | " a = pi[s][0][np.argmax(pi[s][1])]\n",
284 | " if a == \"U\":\n",
285 | " temp[s] = 0.25\n",
286 | " elif a == \"D\":\n",
287 | " temp[s] = 0.5\n",
288 | " elif a == \"R\":\n",
289 | " temp[s] = 0.75\n",
290 | " else:\n",
291 | " temp[s] = 1.0\n",
292 | " \n",
293 | " temp = temp.reshape(4,4)\n",
294 | " ax = seaborn.heatmap(temp, cmap = \"prism\", linecolor=\"#282828\", cbar = False, linewidths = 0.1)\n",
295 | " plt.show()"
296 | ]
297 | },
298 | {
299 | "cell_type": "code",
300 | "execution_count": 40,
301 | "metadata": {},
302 | "outputs": [
303 | {
304 | "data": {
305 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAdwAAAFJCAYAAAAxCJwFAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAACrhJREFUeJzt3U2IlYUex/H/OCcqml4gISQI3Dhz4ILSop2LAacXIQZD\nSIspiFm1cASTOI29UF21VdDLYOWqRXdaXWkRSKYUFLQQFAaORQTR26ZVqORo59xFXC9ddG4zl/M7\nzfHz2c3zMM/54YN8fR5mcKjb7XYLAOipNf0eAADXAsEFgADBBYAAwQWAAMEFgADBBYCARi8v3mw2\nq33mTC8/gh5pjo3V39ru3Wq10HT/VquF5lhVlfu3Si00x6rdbl/xnCdcAAgQXAAIEFwACBBcAAgQ\nXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBc\nAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwA\nCBBcAAj408HtdDq93AEAA62x1MnvvvuuDhw4UAsLC9VoNKrT6dSGDRuq1WrV+vXrUxsBYNVbMriz\ns7O1Z8+e2rhx4+Vjp06dqlarVfPz8z0fBwCDYslXyouLi3+IbVXVpk2bejoIAAbRkk+4o6Oj1Wq1\navPmzXXzzTfXuXPn6pNPPqnR0dHUPgAYCEsG94UXXqhjx47VyZMn6+zZszUyMlLj4+M1MTGR2gcA\nA2HJ4A4NDdXExITAAsD/ye/hAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA\n4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDg\nAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAwFC32+326uLNZrNXlwaAv6R2u33F441ef/CZ\njYd7/RH0wNjpafduFXP/Vq+x09NVVdX+55k+L2ElmtvGrnrOK2UACBBcAAgQXAAIEFwACBBcAAgQ\nXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBc\nAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwA\nCBBcAAgQXAAIaCx1cmpqqi5evPiHY91ut4aGhmp+fr6nwwBgkCwZ3Keeeqr27dtXb775Zg0PD6c2\nAcDAWTK4GzdurMnJyfryyy9rYmIitQkABs6Swa2qmp6eTuwAgIHmh6YAIEBwASBAcAEgQHABIEBw\nASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHAB\nIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEg\nYKjb7XZ7dfFms9mrSwPAX1K73b7i8UavP/jMxsO9/gh6YOz0dP2tfabfM1ihheZYtf/p/q1GzW1j\nVVXu3yr17/t3JV4pA0CA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOAC\nQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJA\ngOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQMCyg7u4uNiLHQAw0K4a3OPHj9f4+HhNTEzUhx9+\nePn49PR0ZBgADJLG1U4cOnSojhw5Up1Op2ZmZurChQu1bdu26na7yX0AMBCuGtzrrruubr311qqq\nmpubq8cff7zWrVtXQ0NDsXEAMCiu+k
r5zjvvrAMHDtT58+drZGSk3njjjXrxxRfrm2++Se4DgIFw\n1eDu37+/RkdHLz/Rrlu3rt5999164IEHYuMAYFBc9ZVyo9Gohx566A/H1q5dW7Ozsz0fBQCDxu/h\nAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOAC\nQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJA\ngOACQIDgAkCA4AJAgOACQMBQt9vt9urizWazV5cGgL+kdrt9xeONXn/w7Jl/9Poj6IG/j+1071Yx\n92/1+vvYzqqqap850+clrERzbOyq57xSBoAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEF\ngADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWA\nAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgIBlBffXX3+txcXFXm0BgIG1\nZHC//vrrevLJJ6vVatXnn39eW7dura1bt9aJEydS+wBgIDSWOvn888/XzMxM/fDDD7Vr1646evRo\nXX/99TU9PV3j4+OpjQCw6i0Z3E6nU/fcc09VVX3xxRd1++23//5NjSW/DQD4L0u+Ul6/fn3Nzs5W\np9OpgwcPVlXV22+/XWvXro2MA4BBseSj6ssvv1zHjx+vNWv+0+U77rijpqamej4MAAbJksFds2ZN\nbdmy5Q/HJicnezoIAAaR38MFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADB\nBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEF\ngADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYCAoW632+3VxZvNZq8uDQB/Se12+4rHexpcAOB3\nXikDQIDgAkCA4AJAgOACQIDgAkCA4AJAgOCuQKfTqeeee64efvjhmpqaqm+//bbfk1im06dP19TU\nVL9nsEwXL16svXv31iOPPFLbt2+vjz/+uN+T+JN+++23arVatWPHjtq5c2d99dVX/Z4UJ7grcOzY\nsVpcXKz333+/9uzZUwcPHuz3JJbhnXfeqX379tWFCxf6PYVl+uCDD+q2226r9957rw4fPlwvvfRS\nvyfxJ504caKqqubn52v37t316quv9nlRnuCuwMmTJ2vz5s1VVbVp06ZaWFjo8yKW46677qrXX3+9\n3zNYgfvvv79mZmaqqqrb7dbw8HCfF/Fnbdmy5fI/kH788ce65ZZb+rwor9HvAavR2bNna2Rk5PLX\nw8PDdenSpWo0/HGuBvfdd199//33/Z7BCtx0001V9fvfwV27dtXu3bv7vIjlaDQa9fTTT9dHH31U\nr732Wr/nxHnCXYGRkZE6d+7c5a87nY7YQshPP/1Ujz32WE1OTtaDDz7Y7zks0yuvvFJHjx6tZ599\nts6fP9/vOVGCuwJ33313ffrpp1VVderUqdqwYUOfF8G14eeff64nnnii9u7dW9u3b+/3HJbhyJEj\n9dZbb1VV1Y033lhDQ0O1Zs21lSCPZSswMTFRn332We3YsaO63W7t37+/35PgmnDo0KH65Zdfam5u\nrubm5qrq9x+Cu+GGG/q8jP/l3nvvrVarVY8++mhdunSpnnnmmWvuvvnfggAg4Np6ngeAPhFcAAgQ\nXAAIEFwACBBcAAgQXAAIEFwACBBcAAj4F4Bkpa2zV4RVAAAAAElFTkSuQmCC\n",
306 | "text/plain": [
307 | ""
308 | ]
309 | },
310 | "metadata": {},
311 | "output_type": "display_data"
312 | }
313 | ],
314 | "source": [
315 | "### RED = TERMINAL (0)\n",
316 | "### GREEN = LEFT\n",
317 | "### BLUE = UP\n",
318 | "### PURPLE = RIGHT\n",
319 | "### ORANGE = DOWN\n",
320 | "\n",
321 | "show_policy(pi_hat, gw)"
322 | ]
323 | }
324 | ],
325 | "metadata": {
326 | "kernelspec": {
327 | "display_name": "Python 3",
328 | "language": "python",
329 | "name": "python3"
330 | },
331 | "language_info": {
332 | "codemirror_mode": {
333 | "name": "ipython",
334 | "version": 3
335 | },
336 | "file_extension": ".py",
337 | "mimetype": "text/x-python",
338 | "name": "python",
339 | "nbconvert_exporter": "python",
340 | "pygments_lexer": "ipython3",
341 | "version": "3.6.1"
342 | }
343 | },
344 | "nbformat": 4,
345 | "nbformat_minor": 2
346 | }
347 |
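The `show_policy` call above encodes each state's greedy action as a number and renders the 4x4 grid as a heatmap, which is what the RED/GREEN/BLUE/PURPLE/ORANGE legend refers to. A minimal, dependency-free sketch of that encoding; the mapping values mirror the helper used in these notebooks, while `encode_policy` and its inputs are illustrative:

```python
# Sketch of the action-to-number encoding behind show_policy (illustrative).
# Mapping mirrors the notebooks: U -> 0.25, D -> 0.5, R -> 0.75, L -> 1.0,
# and terminal cells stay 0.
ACTION_VALUE = {"U": 0.25, "D": 0.5, "R": 0.75, "L": 1.0}

def encode_policy(greedy_actions, n_states=14):
    # greedy_actions: {state: action}; cells 0 and 15 are the terminal corners
    temp = [0.0] * (n_states + 2)
    for s, a in greedy_actions.items():
        temp[s] = ACTION_VALUE[a]
    return [temp[i:i + 4] for i in range(0, n_states + 2, 4)]  # 4x4 grid

grid = encode_policy({1: "L", 2: "L", 3: "D"})
print(grid[0])  # [0.0, 1.0, 1.0, 0.5]
```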
--------------------------------------------------------------------------------
/source code/3-Temporal Difference Learning (Chapter 6)/__pycache__/gridWorldEnvironment.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buomsoo-kim/Tabular-RL-with-Python/0157c05126821524dc9d744613b9aa8dd5d47232/source code/3-Temporal Difference Learning (Chapter 6)/__pycache__/gridWorldEnvironment.cpython-36.pyc
--------------------------------------------------------------------------------
/source code/3-Temporal Difference Learning (Chapter 6)/gridWorldEnvironment.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import seaborn
4 | from matplotlib.colors import ListedColormap
5 |
6 | class GridWorld:
7 | def __init__(self, gamma = 1.0, theta = 0.5):
8 | self.actions = ("U", "D", "L", "R")
9 | self.states = np.arange(1, 15)
10 | self.transitions = pd.read_csv("gridworld.txt", header = None, sep = "\t").values
11 | self.gamma = gamma
12 | self.theta = theta
13 |
14 |     def state_transition(self, state, action):  # returns (next_state, reward) for a given pair
15 | next_state, reward = None, None
16 | for tr in self.transitions:
17 | if tr[0] == state and tr[1] == action:
18 | next_state = tr[2]
19 | reward = tr[3]
20 | return next_state, reward
21 |
22 | def show_environment(self):
23 | all_states = np.concatenate(([0], self.states, [0])).reshape(4,4)
24 | colors = []
26 | for i in range(len(self.states) + 1):
27 | if i == 0:
28 | colors.append("#c4c4c4")
29 | else:
30 | colors.append("#ffffff")
31 |
32 | cmap = ListedColormap(seaborn.color_palette(colors).as_hex())
33 | ax = seaborn.heatmap(all_states, cmap = cmap, \
34 | annot = True, linecolor = "#282828", linewidths = 0.2, \
35 | cbar = False)
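As a hedged usage sketch: `state_transition` above is a linear scan over the `(state, action, next_state, reward)` rows loaded from `gridworld.txt`. The snippet below reproduces that lookup on a hand-written four-row subset of the table, so it runs without pandas, seaborn, or the data file:

```python
# Minimal in-memory version of GridWorld.state_transition (illustrative):
# each row is (state, action, next_state, reward), as in gridworld.txt.
transitions = [
    (1, "U", 1, -1),
    (1, "D", 5, -1),
    (1, "R", 2, -1),
    (1, "L", 0, -1),
]

def state_transition(state, action):
    # linear scan over the table, mirroring the method above
    for s, a, next_state, reward in transitions:
        if s == state and a == action:
            return next_state, reward
    return None, None  # pair not found in the table

print(state_transition(1, "L"))  # (0, -1): moving left from state 1 terminates
```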
--------------------------------------------------------------------------------
/source code/3-Temporal Difference Learning (Chapter 6)/gridworld.txt:
--------------------------------------------------------------------------------
1 | 1 U 1 -1
2 | 1 D 5 -1
3 | 1 R 2 -1
4 | 1 L 0 -1
5 | 2 U 2 -1
6 | 2 D 6 -1
7 | 2 R 3 -1
8 | 2 L 1 -1
9 | 3 U 3 -1
10 | 3 D 7 -1
11 | 3 R 3 -1
12 | 3 L 2 -1
13 | 4 U 0 -1
14 | 4 D 8 -1
15 | 4 R 5 -1
16 | 4 L 4 -1
17 | 5 U 1 -1
18 | 5 D 9 -1
19 | 5 R 6 -1
20 | 5 L 4 -1
21 | 6 U 2 -1
22 | 6 D 10 -1
23 | 6 R 7 -1
24 | 6 L 5 -1
25 | 7 U 3 -1
26 | 7 D 11 -1
27 | 7 R 7 -1
28 | 7 L 6 -1
29 | 8 U 4 -1
30 | 8 D 12 -1
31 | 8 R 9 -1
32 | 8 L 8 -1
33 | 9 U 5 -1
34 | 9 D 13 -1
35 | 9 R 10 -1
36 | 9 L 8 -1
37 | 10 U 6 -1
38 | 10 D 14 -1
39 | 10 R 11 -1
40 | 10 L 9 -1
41 | 11 U 7 -1
42 | 11 D 0 -1
43 | 11 R 11 -1
44 | 11 L 10 -1
45 | 12 U 8 -1
46 | 12 D 12 -1
47 | 12 R 13 -1
48 | 12 L 12 -1
49 | 13 U 9 -1
50 | 13 D 13 -1
51 | 13 R 14 -1
52 | 13 L 12 -1
53 | 14 U 10 -1
54 | 14 D 14 -1
55 | 14 R 0 -1
56 | 14 L 13 -1
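The file above is the whole MDP: one whitespace-separated row per `(state, action, next_state, reward)` tuple, with state `0` standing for both terminal corners and every move costing `-1`. A minimal parsing sketch without pandas; the inline `sample` reuses the first four rows:

```python
# Parsing sketch for the format above: whitespace-separated columns
# (state, action, next_state, reward). Uses an inline sample instead of
# reading gridworld.txt, so no pandas is needed.
sample = """\
1\tU\t1\t-1
1\tD\t5\t-1
1\tR\t2\t-1
1\tL\t0\t-1"""

transitions = []
for line in sample.splitlines():
    state, action, next_state, reward = line.split()
    transitions.append((int(state), action, int(next_state), int(reward)))

print(transitions[0])  # (1, 'U', 1, -1)
```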
--------------------------------------------------------------------------------
/source code/4-n-step Bootstrapping (Chapter 7)/.ipynb_checkpoints/4-n-step-off-policy-learning-wo-importance-sampling-checkpoint.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# $n$-step off-policy learning without importance sampling\n",
8 | "- Algorithms from ```pp. 124 - 125``` in Sutton & Barto 2017\n",
9 | " - $n$-step Tree Backup Algorithm"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": 1,
15 | "metadata": {
16 | "collapsed": true
17 | },
18 | "outputs": [],
19 | "source": [
20 | "import matplotlib.pyplot as plt\n",
21 | "import pandas as pd\n",
22 | "import numpy as np\n",
23 | "import seaborn, random\n",
24 | "\n",
25 | "from gridWorldEnvironment import GridWorld"
26 | ]
27 | },
28 | {
29 | "cell_type": "code",
30 | "execution_count": 2,
31 | "metadata": {
32 | "collapsed": true
33 | },
34 | "outputs": [],
35 | "source": [
36 | "# creating gridworld environment\n",
37 | "gw = GridWorld(gamma = .9)"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": 7,
43 | "metadata": {
44 | "collapsed": true
45 | },
46 | "outputs": [],
47 | "source": [
48 | "def state_action_value(env):\n",
49 | " q = dict()\n",
50 | " for state, action, next_state, reward in env.transitions:\n",
51 | " q[(state, action)] = np.random.normal()\n",
52 | " return q"
53 | ]
54 | },
55 | {
56 | "cell_type": "code",
57 | "execution_count": 9,
58 | "metadata": {
59 | "collapsed": true
60 | },
61 | "outputs": [],
62 | "source": [
63 | "def e_greedy(env, e, q, state):\n",
64 | " actions = env.actions\n",
65 | " action_values = []\n",
66 | " prob = []\n",
67 | " for action in actions:\n",
68 | " action_values.append(q[(state, action)])\n",
69 | " for i in range(len(action_values)):\n",
70 | " if i == np.argmax(action_values):\n",
71 | " prob.append(1 - e + e/len(action_values))\n",
72 | " else:\n",
73 | " prob.append(e/len(action_values))\n",
74 | " return actions, prob"
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "execution_count": 10,
80 | "metadata": {
81 | "collapsed": true
82 | },
83 | "outputs": [],
84 | "source": [
85 | "def generate_e_greedy_policy(env, e, Q):\n",
86 | " pi = dict()\n",
87 | " for state in env.states:\n",
88 | " pi[state] = e_greedy(env, e, Q, state)\n",
89 | " return pi"
90 | ]
91 | },
92 | {
93 | "cell_type": "code",
94 | "execution_count": 11,
95 | "metadata": {
96 | "collapsed": true
97 | },
98 | "outputs": [],
99 | "source": [
100 | "def generate_random_policy(env):\n",
101 | " pi = dict()\n",
102 | " for state in env.states:\n",
103 | " actions = []\n",
104 | " prob = []\n",
105 | " for action in env.actions:\n",
106 | " actions.append(action)\n",
107 | " prob.append(0.25)\n",
108 | " pi[state] = (actions, prob)\n",
109 | " return pi"
110 | ]
111 | },
112 | {
113 | "cell_type": "code",
114 | "execution_count": 19,
115 | "metadata": {
116 | "collapsed": true
117 | },
118 | "outputs": [],
119 | "source": [
120 | "# function for tree backup algorithm\n",
121 | "def avg_over_actions(pi, Q, state):\n",
122 | " actions, probs = pi[state]\n",
123 | " q_values = np.zeros(4)\n",
124 | " for s, a in Q.keys():\n",
125 | " if s == state:\n",
126 | " q_values[actions.index(a)] = Q[s,a]\n",
127 | " return np.dot(q_values, probs)"
128 | ]
129 | },
130 | {
131 | "cell_type": "markdown",
132 | "metadata": {},
133 | "source": [
134 | "### $n$-step off-policy learning without importance sampling\n",
135 | "- The update target also includes the estimated values of the dangling action nodes hanging off the sides of the tree, at all levels"
136 | ]
137 | },
138 | {
139 | "cell_type": "code",
140 | "execution_count": 32,
141 | "metadata": {
142 | "collapsed": true
143 | },
144 | "outputs": [],
145 | "source": [
146 | "def n_step_tree_backup(env, epsilon, alpha, n, num_iter, learn_pi = True):\n",
147 | " Q = state_action_value(env)\n",
148 | "    Q_, pi_, delta = dict(), dict(), dict()  # per-step caches: Q_[t] = Q(S_t, A_t), pi_[t] = pi(A_t | S_t), delta[t] = TD error\n",
149 | " pi = generate_e_greedy_policy(env, epsilon, Q) \n",
150 | "\n",
151 | " for _ in range(num_iter):\n",
152 | " current_state = np.random.choice(env.states)\n",
153 | "        action = np.random.choice(pi[current_state][0], p = pi[current_state][1])  # behavior policy; here actions are drawn from pi itself\n",
154 | " state_trace, action_trace, reward_trace = [current_state], [action], [0]\n",
155 | " Q_[0] = Q[current_state, action]\n",
156 | " t, T = 0, 10000\n",
157 | " while True:\n",
158 | " if t < T: \n",
159 | " next_state, reward = env.state_transition(current_state, action)\n",
160 | " state_trace.append(next_state)\n",
161 | " reward_trace.append(reward)\n",
162 | " if next_state == 0:\n",
163 | " T = t + 1\n",
164 | " delta[t] = reward - Q_[t]\n",
165 | " else: \n",
166 | " delta[t] = reward + env.gamma * avg_over_actions(pi, Q, next_state) - Q_[t]\n",
167 | " action = np.random.choice(pi[next_state][0], p = pi[next_state][1])\n",
168 | " action_trace.append(action)\n",
169 | " Q_[t+1] = Q[next_state, action]\n",
170 | " pi_[t+1] = pi[next_state][1][pi[next_state][0].index(action)]\n",
171 | " \n",
172 | " tau = t - n + 1\n",
173 | " if tau >= 0:\n",
174 | "                Z = 1\n",
175 | "                G = Q_[tau]\n",
176 | "                for i in range(tau, min([tau + n - 1, T - 1]) + 1):  # upper bound is inclusive\n",
177 | "                    G += Z * delta[i]\n",
178 | "                    if i + 1 < T: Z *= env.gamma * pi_[i+1]  # no action is chosen at the terminal step\n",
179 | " Q[state_trace[tau], action_trace[tau]] += alpha * (G - Q[state_trace[tau], action_trace[tau]])\n",
180 | " if learn_pi:\n",
181 | " pi[state_trace[tau]] = e_greedy(env, epsilon, Q, state_trace[tau])\n",
182 | " current_state = next_state \n",
184 | " \n",
185 | " if tau == (T-1):\n",
186 | " break\n",
187 | " t += 1\n",
188 | " \n",
189 | " return pi, Q"
190 | ]
191 | },
192 | {
193 | "cell_type": "code",
194 | "execution_count": 39,
195 | "metadata": {
196 | "collapsed": true
197 | },
198 | "outputs": [],
199 | "source": [
200 | "pi, Q = n_step_tree_backup(gw, 0.2, 0.5, 1, 10000)"
201 | ]
202 | },
203 | {
204 | "cell_type": "code",
205 | "execution_count": 40,
206 | "metadata": {},
207 | "outputs": [
208 | {
209 | "data": {
210 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAdwAAAFJCAYAAAAxCJwFAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAACtdJREFUeJzt3UuIlYX/x/Hv6ImKpgskhASBG2cOBCMt2rkYcLoIMRhC\nWkxBzKqFI5jENHahi9oq6DJYuWpR0yppEUimFBS0EBQGzhQRRLdNq1DJ0c75LQT/9Efn18yP8znN\n8fXazfPg83xwOLx9HmZwoNPpdAoA6Ko1vR4AANcCwQWAAMEFgADBBYAAwQWAAMEFgIBGNy/ebDar\ntbDQzVvQJc3h4VoYOdzrGazQ8OnJan3ss7caNbcNV1XVzMKHPV7CSrw6vLNardYVz3nCBYAAwQWA\nAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAA\nwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADB\nBYAAwQWAAMEFgADBBYCAfxzcdrvdzR0A0NcaS5386aef6sCBAzU/P1+NRqPa7XZt3Lixpqena8OG\nDamNALDqLRncmZmZ2rNnT42MjFw+durUqZqenq65ubmujwOAfrHkK+XFxcW/xbaqatOmTV0dBAD9\naMkn3KGhoZqenq7NmzfXzTffXGfPnq0vvviihoaGUvsAoC8sGdwXX3yxjh07VidPnqwzZ87U4OBg\njY6O1tjYWGofAPSFJYM7MDBQY2NjAgsA/yO/hwsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYIL\nAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsA\nAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAEDnU6n062LN5vNbl0aAP6V\nWq3WFY83un3jmYUPu30LuuDV4Z11d2uh1zNYofnmcC2MHO71DFZg+PRkVVW1Pvb5W42a24aves4r\nZQAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBc\nAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwA\nCBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAhoLHVyYmKiLly48LdjnU6nBgYGam5urqvDAKCf\nLBncp59+uvbt21dvv/12rV27NrUJAPrOksEdGRmp8fHx+vbbb2tsbCy1CQD6zpLBraqanJxM7ACA\nvuaHpgAgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHAB\nIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEg\nQHABIEBwASBAcAEgQHABIEBwASBgoNPpdLp18Waz2a1LA8C/UqvVuuLxRrdvfHdrodu3oAvmm8M1\ns/Bhr2ewQq8O76zWxz57q1Fz23BVVS2MHO7xElZi+PTkVc95pQwAAYILAAGCCwABggsAAYILAAGC\nCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYIL\nAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwAByw7u\n4uJiN3YAQF+7anCPHz9eo6OjNTY2Vp9++unl45OTk5FhANBPGlc7cejQoTpy5Ei12+2ampqq8+fP\n17Zt26rT6ST3AUBfuGpwr7vuurr11l
urqmp2draeeOKJWr9+fQ0MDMTGAUC/uOor5TvvvLMOHDhQ\n586dq8HBwXrrrbfqpZdeqh9++CG5DwD6wlWDu3///hoaGrr8RLt+/fp6//3368EHH4yNA4B+cdVX\nyo1Gox5++OG/HVu3bl3NzMx0fRQA9Bu/hwsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGC\nCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYIL\nAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABA51Op9OtizebzW5dGgD+\nlVqt1hWPN7p947tbC92+BV0w3xz2vVvF5pvDtTByuNczWIHh05NVVdVa8PlbjZrDw1c955UyAAQI\nLgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAgu\nAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4A\nBAguAAQILgAECC4ABCwruH/++WctLi52awsA9K0lg/v999/XU089VdPT0/X111/X1q1ba+vWrXXi\nxInUPgDoC42lTr7wwgs1NTVVv/zyS+3atauOHj1a119/fU1OTtbo6GhqIwCseksGt91u17333ltV\nVd98803dfvvtl/5QY8k/BgD8P0u+Ut6wYUPNzMxUu92ugwcPVlXVu+++W+vWrYuMA4B+seSj6iuv\nvFLHjx+vNWv+r8t33HFHTUxMdH0YAPSTJYO7Zs2a2rJly9+OjY+Pd3UQAPQjv4cLAAGCCwABggsA\nAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwAB\nggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGC\nCwABA51Op9OtizebzW5dGgD+lVqt1hWPdzW4AMAlXikDQIDgAkCA4AJAgOACQIDgAkCA4AJAgOCu\nQLvdrueff74eeeSRmpiYqB9//LHXk1im06dP18TERK9nsEwXLlyovXv31qOPPlrbt2+vzz//vNeT\n+If++uuvmp6erh07dtTOnTvru+++6/WkOMFdgWPHjtXi4mJ99NFHtWfPnjp48GCvJ7EM7733Xu3b\nt6/Onz/f6yks0yeffFK33XZbffDBB3X48OF6+eWXez2Jf+jEiRNVVTU3N1e7d++u119/vceL8gR3\nBU6ePFmbN2+uqqpNmzbV/Px8jxexHHfddVe9+eabvZ7BCjzwwAM1NTVVVVWdTqfWrl3b40X8U1u2\nbLn8D6Rff/21brnllh4vymv0esBqdObMmRocHLz89dq1a+vixYvVaPjrXA3uv//++vnnn3s9gxW4\n6aabqurSZ3DXrl21e/fuHi9iORqNRj3zzDP12Wef1RtvvNHrOXGecFdgcHCwzp49e/nrdrstthDy\n22+/1eOPP17j4+P10EMP9XoOy/Taa6/V0aNH67nnnqtz5871ek6U4K7APffcU19++WVVVZ06dao2\nbtzY40Vwbfj999/rySefrL1799b27dt7PYdlOHLkSL3zzjtVVXXjjTfWwMBArVlzbSXIY9kKjI2N\n1VdffVU7duyoTqdT+/fv7/UkuCYcOnSo/vjjj5qdna3Z2dmquvRDcDfccEOPl/Hf3HfffTU9PV2P\nPfZYXbx4sZ599tlr7vvmfwsCgIBr63keAHpEcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEg4D+A\nZKWtiwAl/gAAAABJRU5ErkJggg==\n",
211 | "text/plain": [
212 | ""
213 | ]
214 | },
215 | "metadata": {},
216 | "output_type": "display_data"
217 | }
218 | ],
219 | "source": [
220 | "### RED = TERMINAL (0)\n",
221 | "### GREEN = LEFT\n",
222 | "### BLUE = UP\n",
223 | "### PURPLE = RIGHT\n",
224 | "### ORANGE = DOWN\n",
225 | "\n",
226 | "show_policy(pi, gw)"
227 | ]
228 | }
229 | ],
230 | "metadata": {
231 | "kernelspec": {
232 | "display_name": "Python 3",
233 | "language": "python",
234 | "name": "python3"
235 | },
236 | "language_info": {
237 | "codemirror_mode": {
238 | "name": "ipython",
239 | "version": 3
240 | },
241 | "file_extension": ".py",
242 | "mimetype": "text/x-python",
243 | "name": "python",
244 | "nbconvert_exporter": "python",
245 | "pygments_lexer": "ipython3",
246 | "version": "3.6.1"
247 | }
248 | },
249 | "nbformat": 4,
250 | "nbformat_minor": 2
251 | }
252 |
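The bootstrap term of the tree-backup target above is `avg_over_actions`, the expectation of `Q(s, .)` under the current policy's action probabilities. A self-contained sketch of that expectation; the policy probabilities and Q values below are made up for illustration:

```python
# Expected action value under the policy, as used for the tree-backup
# bootstrap term. The policy probabilities and Q values are made up.
def avg_over_actions(pi, Q, state):
    # pi[state] = (actions, probabilities); returns sum_a pi(a|s) * Q(s, a)
    actions, probs = pi[state]
    return sum(p * Q[(state, a)] for a, p in zip(actions, probs))

pi = {5: (["U", "D", "L", "R"], [0.85, 0.05, 0.05, 0.05])}
Q = {(5, "U"): 1.0, (5, "D"): 0.0, (5, "L"): 0.0, (5, "R"): 0.0}

print(avg_over_actions(pi, Q, 5))  # 0.85
```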
--------------------------------------------------------------------------------
/source code/4-n-step Bootstrapping (Chapter 7)/2-n-step-SARSA-on-policy-control.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# $n$-step SARSA\n",
8 | "- Algorithms from ```pp. 119 - 120``` in Sutton & Barto 2017"
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": 1,
14 | "metadata": {
15 | "collapsed": true
16 | },
17 | "outputs": [],
18 | "source": [
19 | "import matplotlib.pyplot as plt\n",
20 | "import pandas as pd\n",
21 | "import numpy as np\n",
22 | "import seaborn, random\n",
23 | "\n",
24 | "from gridWorldEnvironment import GridWorld"
25 | ]
26 | },
27 | {
28 | "cell_type": "code",
29 | "execution_count": 3,
30 | "metadata": {
31 | "collapsed": true
32 | },
33 | "outputs": [],
34 | "source": [
35 | "# creating gridworld environment\n",
36 | "gw = GridWorld(gamma = .9)"
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "execution_count": 59,
42 | "metadata": {
43 | "collapsed": true
44 | },
45 | "outputs": [],
46 | "source": [
47 | "def state_action_value(env):\n",
48 | " q = dict()\n",
49 | " for state, action, next_state, reward in env.transitions:\n",
50 | " q[(state, action)] = np.random.normal()\n",
51 | " return q"
52 | ]
53 | },
54 | {
55 | "cell_type": "code",
56 | "execution_count": 60,
57 | "metadata": {
58 | "scrolled": true
59 | },
60 | "outputs": [
61 | {
62 | "data": {
63 | "text/plain": [
64 | "{(1, 'D'): -0.7030356916846013,\n",
65 | " (1, 'L'): -0.324895205360978,\n",
66 | " (1, 'R'): -1.1734677062639605,\n",
67 | " (1, 'U'): -0.16629963855724197,\n",
68 | " (2, 'D'): -0.23584377637184736,\n",
69 | " (2, 'L'): 1.4901073639401272,\n",
70 | " (2, 'R'): 0.6934460377797941,\n",
71 | " (2, 'U'): 0.8636095279329232,\n",
72 | " (3, 'D'): -1.609655563626268,\n",
73 | " (3, 'L'): 0.9311445594943784,\n",
74 | " (3, 'R'): 0.8985239654154085,\n",
75 | " (3, 'U'): -0.4080380875477237,\n",
76 | " (4, 'D'): -0.17538315728764214,\n",
77 | " (4, 'L'): -1.2559883445535147,\n",
78 | " (4, 'R'): -0.8940561790577116,\n",
79 | " (4, 'U'): 0.21171585977065896,\n",
80 | " (5, 'D'): 0.22188493818420746,\n",
81 | " (5, 'L'): 0.2624753401534311,\n",
82 | " (5, 'R'): -0.47795119564774396,\n",
83 | " (5, 'U'): 1.4805795529303916,\n",
84 | " (6, 'D'): -0.9419847476859503,\n",
85 | " (6, 'L'): -1.2209175101970824,\n",
86 | " (6, 'R'): 0.39723157885321236,\n",
87 | " (6, 'U'): 0.9834147564834003,\n",
88 | " (7, 'D'): 0.6846683809603875,\n",
89 | " (7, 'L'): -3.065126625005005,\n",
90 | " (7, 'R'): -0.9160309917151935,\n",
91 | " (7, 'U'): 0.6802182059638696,\n",
92 | " (8, 'D'): -2.079437344850202,\n",
93 | " (8, 'L'): -0.7871927916410061,\n",
94 | " (8, 'R'): 0.05422301115908965,\n",
95 | " (8, 'U'): -0.3719669376829563,\n",
96 | " (9, 'D'): 0.7168977139674119,\n",
97 | " (9, 'L'): -0.9203614384252929,\n",
98 | " (9, 'R'): -0.02171286359518462,\n",
99 | " (9, 'U'): 0.7715779552124056,\n",
100 | " (10, 'D'): -0.3829133127865197,\n",
101 | " (10, 'L'): 0.37456870550391924,\n",
102 | " (10, 'R'): 1.9331952589125025,\n",
103 | " (10, 'U'): -0.8416217832093862,\n",
104 | " (11, 'D'): -0.6215265359549603,\n",
105 | " (11, 'L'): 0.2577485614406314,\n",
106 | " (11, 'R'): 1.6237960742580893,\n",
107 | " (11, 'U'): -2.215032302441923,\n",
108 | " (12, 'D'): -0.12021109028974186,\n",
109 | " (12, 'L'): 0.1799993340884875,\n",
110 | " (12, 'R'): -0.19649725248693356,\n",
111 | " (12, 'U'): -0.873336849248908,\n",
112 | " (13, 'D'): 0.3872626481962563,\n",
113 | " (13, 'L'): 1.477848159613166,\n",
114 | " (13, 'R'): -0.5006331850401944,\n",
115 | " (13, 'U'): 0.42855578341556333,\n",
116 | " (14, 'D'): 0.5288272721071394,\n",
117 | " (14, 'L'): 1.0695144596153465,\n",
118 | " (14, 'R'): 2.0397741046686506,\n",
119 | " (14, 'U'): 0.6012823125125923}"
120 | ]
121 | },
122 | "execution_count": 60,
123 | "metadata": {},
124 | "output_type": "execute_result"
125 | }
126 | ],
127 | "source": [
128 | "state_action_value(gw)"
129 | ]
130 | },
131 | {
132 | "cell_type": "code",
133 | "execution_count": 5,
134 | "metadata": {
135 | "collapsed": true
136 | },
137 | "outputs": [],
138 | "source": [
139 | "def generate_greedy_policy(env, Q):\n",
140 | " pi = dict()\n",
141 | " for state in env.states:\n",
142 | " actions = []\n",
143 | " q_values = []\n",
144 | " prob = []\n",
145 | " \n",
146 | " for a in env.actions:\n",
147 | " actions.append(a)\n",
148 | " q_values.append(Q[state,a]) \n",
149 | " for i in range(len(q_values)):\n",
150 | " if i == np.argmax(q_values):\n",
151 | " prob.append(1)\n",
152 | " else:\n",
153 | " prob.append(0) \n",
154 | " \n",
155 | " pi[state] = (actions, prob)\n",
156 | " return pi"
157 | ]
158 | },
159 | {
160 | "cell_type": "code",
161 | "execution_count": 6,
162 | "metadata": {
163 | "collapsed": true
164 | },
165 | "outputs": [],
166 | "source": [
167 | "def e_greedy(env, e, q, state):\n",
168 | " actions = env.actions\n",
169 | " action_values = []\n",
170 | " prob = []\n",
171 | " for action in actions:\n",
172 | " action_values.append(q[(state, action)])\n",
173 | " for i in range(len(action_values)):\n",
174 | " if i == np.argmax(action_values):\n",
175 | " prob.append(1 - e + e/len(action_values))\n",
176 | " else:\n",
177 | " prob.append(e/len(action_values))\n",
178 | " return actions, prob"
179 | ]
180 | },
181 | {
182 | "cell_type": "code",
183 | "execution_count": 7,
184 | "metadata": {
185 | "collapsed": true
186 | },
187 | "outputs": [],
188 | "source": [
189 | "# e-greedy policy is an extension of e_greedy()\n",
190 | "def generate_e_greedy_policy(env, e, Q):\n",
191 | " pi = dict()\n",
192 | " for state in env.states:\n",
193 | " pi[state] = e_greedy(env, e, Q, state)\n",
194 | " return pi"
195 | ]
196 | },
197 | {
198 | "cell_type": "code",
199 | "execution_count": 8,
200 | "metadata": {},
201 | "outputs": [
202 | {
203 | "data": {
204 | "text/plain": [
205 | "{1: (('U', 'D', 'L', 'R'), [0.025, 0.025, 0.925, 0.025]),\n",
206 | " 2: (('U', 'D', 'L', 'R'), [0.025, 0.025, 0.925, 0.025]),\n",
207 | " 3: (('U', 'D', 'L', 'R'), [0.025, 0.025, 0.925, 0.025]),\n",
208 | " 4: (('U', 'D', 'L', 'R'), [0.025, 0.025, 0.025, 0.925]),\n",
209 | " 5: (('U', 'D', 'L', 'R'), [0.025, 0.025, 0.925, 0.025]),\n",
210 | " 6: (('U', 'D', 'L', 'R'), [0.025, 0.925, 0.025, 0.025]),\n",
211 | " 7: (('U', 'D', 'L', 'R'), [0.025, 0.025, 0.925, 0.025]),\n",
212 | " 8: (('U', 'D', 'L', 'R'), [0.025, 0.925, 0.025, 0.025]),\n",
213 | " 9: (('U', 'D', 'L', 'R'), [0.025, 0.925, 0.025, 0.025]),\n",
214 | " 10: (('U', 'D', 'L', 'R'), [0.025, 0.025, 0.025, 0.925]),\n",
215 | " 11: (('U', 'D', 'L', 'R'), [0.925, 0.025, 0.025, 0.025]),\n",
216 | " 12: (('U', 'D', 'L', 'R'), [0.025, 0.925, 0.025, 0.025]),\n",
217 | " 13: (('U', 'D', 'L', 'R'), [0.025, 0.025, 0.925, 0.025]),\n",
218 | " 14: (('U', 'D', 'L', 'R'), [0.025, 0.025, 0.025, 0.925])}"
219 | ]
220 | },
221 | "execution_count": 8,
222 | "metadata": {},
223 | "output_type": "execute_result"
224 | }
225 | ],
226 | "source": [
227 | "generate_e_greedy_policy(gw, 0.1, state_action_value(gw))"
228 | ]
229 | },
230 | {
231 | "cell_type": "markdown",
232 | "metadata": {
233 | "collapsed": true
234 | },
235 | "source": [
236 | "### $n$-step SARSA: On-policy TD Control"
237 | ]
238 | },
239 | {
240 | "cell_type": "code",
241 | "execution_count": 61,
242 | "metadata": {
243 | "collapsed": true
244 | },
245 | "outputs": [],
246 | "source": [
247 | "def n_step_sarsa(env, epsilon, alpha, n, num_iter, learn_pi = True):\n",
248 | " Q = state_action_value(env)\n",
249 | " pi = generate_e_greedy_policy(env, epsilon, Q)\n",
250 | " for _ in range(num_iter):\n",
251 | " current_state = np.random.choice(env.states)\n",
252 | " action = np.random.choice(pi[current_state][0], p = pi[current_state][1])\n",
253 | " state_trace, action_trace, reward_trace = [current_state], [action], [0]\n",
254 | " t, T = 0, 10000\n",
255 | " while True:\n",
256 | " if t < T: \n",
257 | " next_state, reward = env.state_transition(current_state, action)\n",
258 | " state_trace.append(next_state)\n",
259 | " reward_trace.append(reward)\n",
260 | " if next_state == 0:\n",
261 | " T = t + 1\n",
262 | " else: \n",
263 | " action = np.random.choice(pi[next_state][0], p = pi[next_state][1])\n",
264 | " action_trace.append(action)\n",
265 | " \n",
266 | "            tau = t - n + 1  # tau is the time step whose estimate is being updated\n",
267 | " if tau >= 0:\n",
268 | " G = 0\n",
269 | " for i in range(tau + 1, min([tau + n, T]) + 1):\n",
270 | "                    G += (env.gamma ** (i - tau - 1)) * reward_trace[i]  # reward_trace[i] holds R_i (index 0 is a dummy)\n",
271 | " if tau + n < T:\n",
272 | " G += (env.gamma ** n) * Q[state_trace[tau + n], action_trace[tau + n]]\n",
273 | " Q[state_trace[tau], action_trace[tau]] += alpha * (G - Q[state_trace[tau], action_trace[tau]])\n",
274 | " \n",
275 | " # current policy, pi, can be learned each step, or not\n",
276 | " if learn_pi:\n",
277 | " pi[state_trace[tau]] = e_greedy(env, epsilon, Q, state_trace[tau])\n",
278 | " current_state = next_state \n",
279 | " \n",
280 | " if tau == (T-1):\n",
281 | " break\n",
282 | " t += 1\n",
283 | " \n",
284 | " return pi, Q"
285 | ]
286 | },
287 | {
288 | "cell_type": "code",
289 | "execution_count": 62,
290 | "metadata": {
291 | "collapsed": true
292 | },
293 | "outputs": [],
294 | "source": [
295 | "pi, Q = n_step_sarsa(gw, 0.2, 0.5, 3, 1000)"
296 | ]
297 | },
298 | {
299 | "cell_type": "code",
300 | "execution_count": 63,
301 | "metadata": {},
302 | "outputs": [
303 | {
304 | "data": {
305 | "text/plain": [
306 | "{1: (('U', 'D', 'L', 'R'), [0.05, 0.05, 0.8500000000000001, 0.05]),\n",
307 | " 2: (('U', 'D', 'L', 'R'), [0.05, 0.05, 0.8500000000000001, 0.05]),\n",
308 | " 3: (('U', 'D', 'L', 'R'), [0.05, 0.8500000000000001, 0.05, 0.05]),\n",
309 | " 4: (('U', 'D', 'L', 'R'), [0.8500000000000001, 0.05, 0.05, 0.05]),\n",
310 | " 5: (('U', 'D', 'L', 'R'), [0.8500000000000001, 0.05, 0.05, 0.05]),\n",
311 | " 6: (('U', 'D', 'L', 'R'), [0.05, 0.8500000000000001, 0.05, 0.05]),\n",
312 | " 7: (('U', 'D', 'L', 'R'), [0.05, 0.8500000000000001, 0.05, 0.05]),\n",
313 | " 8: (('U', 'D', 'L', 'R'), [0.8500000000000001, 0.05, 0.05, 0.05]),\n",
314 | " 9: (('U', 'D', 'L', 'R'), [0.05, 0.8500000000000001, 0.05, 0.05]),\n",
315 | " 10: (('U', 'D', 'L', 'R'), [0.05, 0.05, 0.05, 0.8500000000000001]),\n",
316 | " 11: (('U', 'D', 'L', 'R'), [0.05, 0.8500000000000001, 0.05, 0.05]),\n",
317 | " 12: (('U', 'D', 'L', 'R'), [0.05, 0.05, 0.05, 0.8500000000000001]),\n",
318 | " 13: (('U', 'D', 'L', 'R'), [0.05, 0.05, 0.05, 0.8500000000000001]),\n",
319 | " 14: (('U', 'D', 'L', 'R'), [0.05, 0.05, 0.05, 0.8500000000000001])}"
320 | ]
321 | },
322 | "execution_count": 63,
323 | "metadata": {},
324 | "output_type": "execute_result"
325 | }
326 | ],
327 | "source": [
328 | "pi"
329 | ]
330 | },
331 | {
332 | "cell_type": "markdown",
333 | "metadata": {},
334 | "source": [
335 | "### Visualizing policy"
336 | ]
337 | },
338 | {
339 | "cell_type": "code",
340 | "execution_count": 65,
341 | "metadata": {
342 | "collapsed": true
343 | },
344 | "outputs": [],
345 | "source": [
346 | "def show_policy(pi, env):\n",
347 | " temp = np.zeros(len(env.states) + 2)\n",
348 | " for s in env.states:\n",
349 | " a = pi[s][0][np.argmax(pi[s][1])]\n",
350 | " if a == \"U\":\n",
351 | " temp[s] = 0.25\n",
352 | " elif a == \"D\":\n",
353 | " temp[s] = 0.5\n",
354 | " elif a == \"R\":\n",
355 | " temp[s] = 0.75\n",
356 | " else:\n",
357 | " temp[s] = 1.0\n",
358 | " \n",
359 | " temp = temp.reshape(4,4)\n",
360 | " ax = seaborn.heatmap(temp, cmap = \"prism\", linecolor=\"#282828\", cbar = False, linewidths = 0.1)\n",
361 | " plt.show()"
362 | ]
363 | },
364 | {
365 | "cell_type": "code",
366 | "execution_count": 66,
367 | "metadata": {},
368 | "outputs": [
369 | {
370 | "data": {
371 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAdwAAAFJCAYAAAAxCJwFAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAACsFJREFUeJzt3U2IlQXfx/H/OCcqml4gISQI3Dhz4AGlRTsXA04vQgyG\nkBZTELNq4QgmcRp7oRe1VdDLYOWqRfe0SloEkikFBS0EBeFMEUH0tmkVKjnaOfciHh+60bmbeTi/\n0xw/n91cF17nhxfD1+tiBoe63W63AICeWtPvAQBwLRBcAAgQXAAIEFwACBBcAAgQXAAIaPTy4s1m\ns9oLC738CHqkOTZW/9N271arM033b7U60xyrqqr2h+7fatTcNlbtdvuK5zzhAkCA4AJAgOACQIDg\nAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOAC\nQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJA\ngOACQIDgAkDA3w5up9Pp5Q4AGGiNpU7+8MMPdeDAgTpz5kw1Go3qdDq1YcOGarVatX79+tRGAFj1\nlgzu7Oxs7dmzpzZu3Hj52KlTp6rVatX8/HzPxwHAoFjylfLi4uJfYltVtWnTpp4OAoBBtOQT7ujo\naLVardq8eXPdfPPNde7cufrss89qdHQ0tQ8ABsKSwX3hhRfq2LFjdfLkyTp79myNjIzU+Ph4TUxM\npPYBwEBYMrhDQ0M1MTEhsADw/+T3cAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBA\ncAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBw\nASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIGCo2+12e3XxZrPZq0sDwD9Su92+4vFG\nrz94YePhXn8EPTB2etq9W8XGTk9X+8OFfs9gBZrbxqqq3L9V6n/v35V4pQwAAYILAAGCCwABggsA\nAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwAB\nggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGC\nCwABggsAAYILAAGCCwABjaVOTk1N1cWLF/9yrNvt1tDQUM3Pz/d0GAAMkiWD+9RTT9W+ffvqrbfe\nquHh4dQmABg4SwZ348aNNTk5WV9//XVNTEykNgHAwFkyuFVV09PTiR0AMND80BQABAguAAQILgAE\nCC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQI\nLgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAgu\nAAQILgAEDHW73W6vLt5sNnt1aQD4R2q321c83uj1By9sPNzrj6AHxk5PV/vDhX7PYIWa28ZqduFf\n/Z7BCrwytrOqyvffKtXcNnbVc14pA0CA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDg\nAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOAC\nQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQMCyg7u4uNiLHQAw0K4a3OPHj9f4\n+HhNTEzUxx9/fPn49PR0ZBgADJLG1U4cOnSojhw5Up1Op2ZmZurChQu1bdu26na7yX0AMBCuGtzr\nrruubr311qqqmpubq8cff7zWrVtXQ0
NDsXEAMCiu+kr5zjvvrAMHDtT58+drZGSk3nzzzXrxxRfr\nu+++S+4DgIFw1eDu37+/RkdHLz/Rrlu3rt5777164IEHYuMAYFBc9ZVyo9Gohx566C/H1q5dW7Oz\nsz0fBQCDxu/hAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDg\nAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOAC\nQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQMBQt9vt9urizWazV5cGgH+kdrt9xeONXn/w7MK/ev0R\n9MArYzvdu1XM/Vu9XhnbWVVV7YWFPi9hJZpjY1c955UyAAQILgAECC4ABAguAAQILgAECC4ABAgu\nAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4A\nBAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABCwruL///nst\nLi72agsADKwlg/vtt9/Wk08+Wa1Wq7788svaunVrbd26tU6cOJHaBwADobHUyeeff75mZmbqp59+\nql27dtXRo0fr+uuvr+np6RofH09tBIBVb8ngdjqduueee6qq6quvvqrbb7/9zz/UWPKPAQD/YclX\nyuvXr6/Z2dnqdDp18ODBqqp65513au3atZFxADAolnxUffnll+v48eO1Zs3/dfmOO+6oqampng8D\ngEGyZHDXrFlTW7Zs+cuxycnJng4CgEHk93ABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBA\ncAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBw\nASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBgqNvtdnt18Waz2atLA8A/Urvd\nvuLxngYXAPiTV8oAECC4ABAguAAQILgAECC4ABAguAAQILgr0Ol06rnnnquHH364pqam6vvvv+/3\nJJbp9OnTNTU11e8ZLNPFixdr79699cgjj9T27dvr008/7fck/qY//vijWq1W7dixo3bu3FnffPNN\nvyfFCe4KHDt2rBYXF+uDDz6oPXv21MGDB/s9iWV49913a9++fXXhwoV+T2GZPvroo7rtttvq/fff\nr8OHD9dLL73U70n8TSdOnKiqqvn5+dq9e3e99tprfV6UJ7grcPLkydq8eXNVVW3atKnOnDnT50Us\nx1133VVvvPFGv2ewAvfff3/NzMxUVVW3263h4eE+L+Lv2rJly+V/IP388891yy239HlRXqPfA1aj\ns2fP1sjIyOWvh4eH69KlS9Vo+OtcDe6777768ccf+z2DFbjpppuq6s/vwV27dtXu3bv7vIjlaDQa\n9fTTT9cnn3xSr7/+er/nxHnCXYGRkZE6d+7c5a87nY7YQsgvv/xSjz32WE1OTtaDDz7Y7zks06uv\nvlpHjx6tZ599ts6fP9/vOVGCuwJ33313ff7551VVderUqdqwYUOfF8G14ddff60nnnii9u7dW9u3\nb+/3HJbhyJEj9fbbb1dV1Y033lhDQ0O1Zs21lSCPZSswMTFRX3zxRe3YsaO63W7t37+/35PgmnDo\n0KH67bffam5urubm5qrqzx+Cu+GGG/q8jP/m3nvvrVarVY8++mhdunSpnnnmmWvuvvnfggAg4Np6\nngeAPhFcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAj4N9pmpa1W3XhOAAAAAElFTkSuQmCC\n",
372 | "text/plain": [
373 | ""
374 | ]
375 | },
376 | "metadata": {},
377 | "output_type": "display_data"
378 | }
379 | ],
380 | "source": [
381 | "### RED = TERMINAL (0)\n",
382 | "### GREEN = LEFT\n",
383 | "### BLUE = UP\n",
384 | "### PURPLE = RIGHT\n",
385 | "### ORANGE = DOWN\n",
386 | "\n",
387 | "show_policy(pi, gw)"
388 | ]
389 | }
390 | ],
391 | "metadata": {
392 | "kernelspec": {
393 | "display_name": "Python 3",
394 | "language": "python",
395 | "name": "python3"
396 | },
397 | "language_info": {
398 | "codemirror_mode": {
399 | "name": "ipython",
400 | "version": 3
401 | },
402 | "file_extension": ".py",
403 | "mimetype": "text/x-python",
404 | "name": "python",
405 | "nbconvert_exporter": "python",
406 | "pygments_lexer": "ipython3",
407 | "version": "3.6.1"
408 | }
409 | },
410 | "nbformat": 4,
411 | "nbformat_minor": 2
412 | }
413 |
--------------------------------------------------------------------------------
/source code/4-n-step Bootstrapping (Chapter 7)/4-n-step-off-policy-learning-wo-importance-sampling.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# $n$-step off-policy learning without importance sampling\n",
8 | "- Algorithms from ```pp. 124 - 125``` in Sutton & Barto 2017\n",
9 | " - $n$-step Tree Backup Algorithm"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": 1,
15 | "metadata": {
16 | "collapsed": true
17 | },
18 | "outputs": [],
19 | "source": [
20 | "import matplotlib.pyplot as plt\n",
21 | "import pandas as pd\n",
22 | "import numpy as np\n",
23 | "import seaborn, random\n",
24 | "\n",
25 | "from gridWorldEnvironment import GridWorld"
26 | ]
27 | },
28 | {
29 | "cell_type": "code",
30 | "execution_count": 2,
31 | "metadata": {
32 | "collapsed": true
33 | },
34 | "outputs": [],
35 | "source": [
36 | "# creating gridworld environment\n",
37 | "gw = GridWorld(gamma = .9)"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": 7,
43 | "metadata": {
44 | "collapsed": true
45 | },
46 | "outputs": [],
47 | "source": [
48 | "def state_action_value(env):\n",
49 |     "    q = dict()  # Q(s, a) initialized from a standard normal for every observed (state, action)\n",
50 | " for state, action, next_state, reward in env.transitions:\n",
51 | " q[(state, action)] = np.random.normal()\n",
52 | " return q"
53 | ]
54 | },
55 | {
56 | "cell_type": "code",
57 | "execution_count": 9,
58 | "metadata": {
59 | "collapsed": true
60 | },
61 | "outputs": [],
62 | "source": [
63 | "def e_greedy(env, e, q, state):\n",
64 | " actions = env.actions\n",
65 | " action_values = []\n",
66 | " prob = []\n",
67 | " for action in actions:\n",
68 | " action_values.append(q[(state, action)])\n",
69 | " for i in range(len(action_values)):\n",
70 | " if i == np.argmax(action_values):\n",
71 | " prob.append(1 - e + e/len(action_values))\n",
72 | " else:\n",
73 | " prob.append(e/len(action_values))\n",
74 | " return actions, prob"
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "execution_count": 10,
80 | "metadata": {
81 | "collapsed": true
82 | },
83 | "outputs": [],
84 | "source": [
85 | "def generate_e_greedy_policy(env, e, Q):\n",
86 | " pi = dict()\n",
87 | " for state in env.states:\n",
88 | " pi[state] = e_greedy(env, e, Q, state)\n",
89 | " return pi"
90 | ]
91 | },
92 | {
93 | "cell_type": "code",
94 | "execution_count": 11,
95 | "metadata": {
96 | "collapsed": true
97 | },
98 | "outputs": [],
99 | "source": [
100 | "def generate_random_policy(env):\n",
101 | " pi = dict()\n",
102 | " for state in env.states:\n",
103 | " actions = []\n",
104 | " prob = []\n",
105 | " for action in env.actions:\n",
106 | " actions.append(action)\n",
107 |     "        prob.append(1 / len(env.actions))\n",
108 | " pi[state] = (actions, prob)\n",
109 | " return pi"
110 | ]
111 | },
112 | {
113 | "cell_type": "code",
114 | "execution_count": 19,
115 | "metadata": {
116 | "collapsed": true
117 | },
118 | "outputs": [],
119 | "source": [
120 |     "# for the tree backup algorithm: expected action value sum_a pi(a|s) * Q(s, a),\n",
121 | "def avg_over_actions(pi, Q, state):\n",
122 | " actions, probs = pi[state]\n",
123 |     "    q_values = np.zeros(len(actions))\n",
124 | " for s, a in Q.keys():\n",
125 | " if s == state:\n",
126 | " q_values[actions.index(a)] = Q[s,a]\n",
127 | " return np.dot(q_values, probs)"
128 | ]
129 | },
130 | {
131 | "cell_type": "markdown",
132 | "metadata": {},
133 | "source": [
134 | "### $n$-step off-policy learning without importance sampling\n",
135 |     "- The update target also includes the estimated values of the dangling action nodes hanging off the sides of the backup tree, at every level"
136 | ]
137 | },
138 | {
139 | "cell_type": "code",
140 | "execution_count": 32,
141 | "metadata": {
142 | "collapsed": true
143 | },
144 | "outputs": [],
145 | "source": [
146 | "def n_step_tree_backup(env, epsilon, alpha, n, num_iter, learn_pi = True):\n",
147 | " Q = state_action_value(env)\n",
148 | " Q_, pi_, delta = dict(), dict(), dict() \n",
149 | " pi = generate_e_greedy_policy(env, epsilon, Q) \n",
150 | "\n",
151 | " for _ in range(num_iter):\n",
152 | " current_state = np.random.choice(env.states)\n",
153 |     "        action = np.random.choice(pi[current_state][0], p = pi[current_state][1])\n",
154 | " state_trace, action_trace, reward_trace = [current_state], [action], [0]\n",
155 | " Q_[0] = Q[current_state, action]\n",
156 | " t, T = 0, 10000\n",
157 | " while True:\n",
158 | " if t < T: \n",
159 | " next_state, reward = env.state_transition(current_state, action)\n",
160 | " state_trace.append(next_state)\n",
161 | " reward_trace.append(reward)\n",
162 | " if next_state == 0:\n",
163 | " T = t + 1\n",
164 | " delta[t] = reward - Q_[t]\n",
165 | " else: \n",
166 | " delta[t] = reward + env.gamma * avg_over_actions(pi, Q, next_state) - Q_[t]\n",
167 | " action = np.random.choice(pi[next_state][0], p = pi[next_state][1])\n",
168 | " action_trace.append(action)\n",
169 | " Q_[t+1] = Q[next_state, action]\n",
170 | " pi_[t+1] = pi[next_state][1][pi[next_state][0].index(action)]\n",
171 | " \n",
172 | " tau = t - n + 1\n",
173 | " if tau >= 0:\n",
174 | " Z = 1\n",
175 | " G = Q_[tau]\n",
176 |     "                for i in range(tau, min([tau + n, T])):\n",
177 |     "                    G += Z * delta[i]\n",
178 |     "                    if i + 1 < T: Z *= env.gamma * pi_[i+1]\n",
179 | " Q[state_trace[tau], action_trace[tau]] += alpha * (G - Q[state_trace[tau], action_trace[tau]])\n",
180 | " if learn_pi:\n",
181 | " pi[state_trace[tau]] = e_greedy(env, epsilon, Q, state_trace[tau])\n",
182 | " current_state = next_state \n",
183 | "# print(state_trace, action_trace, reward_trace)\n",
184 | " \n",
185 | " if tau == (T-1):\n",
186 | " break\n",
187 | " t += 1\n",
188 | " \n",
189 | " return pi, Q"
190 | ]
191 | },
192 | {
193 | "cell_type": "code",
194 | "execution_count": 39,
195 | "metadata": {
196 | "collapsed": true
197 | },
198 | "outputs": [],
199 | "source": [
200 | "pi, Q = n_step_tree_backup(gw, 0.2, 0.5, 1, 10000)"
201 | ]
202 | },
203 | {
204 | "cell_type": "code",
205 | "execution_count": 40,
206 | "metadata": {},
207 | "outputs": [
208 | {
209 | "data": {
210 |     "image/png": "<base64 PNG omitted: gridworld policy plot>\n",
211 | "text/plain": [
212 | ""
213 | ]
214 | },
215 | "metadata": {},
216 | "output_type": "display_data"
217 | }
218 | ],
219 | "source": [
220 | "### RED = TERMINAL (0)\n",
221 | "### GREEN = LEFT\n",
222 | "### BLUE = UP\n",
223 | "### PURPLE = RIGHT\n",
224 | "### ORANGE = DOWN\n",
225 | "\n",
226 | "show_policy(pi, gw)"
227 | ]
228 | }
229 | ],
230 | "metadata": {
231 | "kernelspec": {
232 | "display_name": "Python 3",
233 | "language": "python",
234 | "name": "python3"
235 | },
236 | "language_info": {
237 | "codemirror_mode": {
238 | "name": "ipython",
239 | "version": 3
240 | },
241 | "file_extension": ".py",
242 | "mimetype": "text/x-python",
243 | "name": "python",
244 | "nbconvert_exporter": "python",
245 | "pygments_lexer": "ipython3",
246 | "version": "3.6.1"
247 | }
248 | },
249 | "nbformat": 4,
250 | "nbformat_minor": 2
251 | }
252 |
--------------------------------------------------------------------------------
/source code/4-n-step Bootstrapping (Chapter 7)/__pycache__/gridWorldEnvironment.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buomsoo-kim/Tabular-RL-with-Python/0157c05126821524dc9d744613b9aa8dd5d47232/source code/4-n-step Bootstrapping (Chapter 7)/__pycache__/gridWorldEnvironment.cpython-36.pyc
--------------------------------------------------------------------------------
/source code/4-n-step Bootstrapping (Chapter 7)/gridWorldEnvironment.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import seaborn
4 | from matplotlib.colors import ListedColormap
5 |
6 | class GridWorld:
7 | def __init__(self, gamma = 1.0, theta = 0.5):
8 | self.actions = ("U", "D", "L", "R")
9 | self.states = np.arange(1, 15)
10 | self.transitions = pd.read_csv("gridworld.txt", header = None, sep = "\t").values
11 | self.gamma = gamma
12 | self.theta = theta
13 |
14 | def state_transition(self, state, action):
15 | next_state, reward = None, None
16 | for tr in self.transitions:
17 | if tr[0] == state and tr[1] == action:
18 | next_state = tr[2]
19 | reward = tr[3]
20 | return next_state, reward
21 |
22 | def show_environment(self):
23 | all_states = np.concatenate(([0], self.states, [0])).reshape(4,4)
24 | colors = []
25 | # colors = ["#ffffff"]
26 | for i in range(len(self.states) + 1):
27 | if i == 0:
28 | colors.append("#c4c4c4")
29 | else:
30 | colors.append("#ffffff")
31 |
32 | cmap = ListedColormap(seaborn.color_palette(colors).as_hex())
33 | ax = seaborn.heatmap(all_states, cmap = cmap, \
34 | annot = True, linecolor = "#282828", linewidths = 0.2, \
35 | cbar = False)
--------------------------------------------------------------------------------
/source code/4-n-step Bootstrapping (Chapter 7)/gridworld.txt:
--------------------------------------------------------------------------------
1 | 1 U 1 -1
2 | 1 D 5 -1
3 | 1 R 2 -1
4 | 1 L 0 -1
5 | 2 U 2 -1
6 | 2 D 6 -1
7 | 2 R 3 -1
8 | 2 L 1 -1
9 | 3 U 3 -1
10 | 3 D 7 -1
11 | 3 R 3 -1
12 | 3 L 2 -1
13 | 4 U 0 -1
14 | 4 D 8 -1
15 | 4 R 5 -1
16 | 4 L 4 -1
17 | 5 U 1 -1
18 | 5 D 9 -1
19 | 5 R 6 -1
20 | 5 L 4 -1
21 | 6 U 2 -1
22 | 6 D 10 -1
23 | 6 R 7 -1
24 | 6 L 5 -1
25 | 7 U 3 -1
26 | 7 D 11 -1
27 | 7 R 7 -1
28 | 7 L 6 -1
29 | 8 U 4 -1
30 | 8 D 12 -1
31 | 8 R 9 -1
32 | 8 L 8 -1
33 | 9 U 5 -1
34 | 9 D 13 -1
35 | 9 R 10 -1
36 | 9 L 8 -1
37 | 10 U 6 -1
38 | 10 D 14 -1
39 | 10 R 11 -1
40 | 10 L 9 -1
41 | 11 U 7 -1
42 | 11 D 0 -1
43 | 11 R 11 -1
44 | 11 L 10 -1
45 | 12 U 8 -1
46 | 12 D 12 -1
47 | 12 R 13 -1
48 | 12 L 12 -1
49 | 13 U 9 -1
50 | 13 D 13 -1
51 | 13 R 14 -1
52 | 13 L 12 -1
53 | 14 U 10 -1
54 | 14 D 14 -1
55 | 14 R 0 -1
56 | 14 L 13 -1
--------------------------------------------------------------------------------