├── README.md
├── imgs
│   ├── 1.PNG
│   ├── 2.PNG
│   └── sutton_barto_example.PNG
└── source code
    ├── 0-MDP Environment (Chapter 3)
    │   ├── .ipynb_checkpoints
    │   │   └── 1-introduction-to-gridworld-environment-checkpoint.ipynb
    │   ├── 1-introduction-to-gridworld-environment.ipynb
    │   ├── __pycache__
    │   │   └── gridWorldEnvironment.cpython-36.pyc
    │   ├── gridWorldEnvironment.py
    │   └── gridworld.txt
    ├── 1-Dynamic Programming (Chapter 4)
    │   ├── .ipynb_checkpoints
    │   │   ├── 1-policy-evaluation-and-improvement-checkpoint.ipynb
    │   │   ├── 2-policy-iteration-checkpoint.ipynb
    │   │   └── 3-value-iteration-checkpoint.ipynb
    │   ├── 1-policy-evaluation-and-improvement.ipynb
    │   ├── 2-policy-iteration.ipynb
    │   ├── 3-value-iteration.ipynb
    │   ├── gridWorldEnvironment.py
    │   └── gridworld.txt
    ├── 2-Monte Carlo Methods (Chapter 5)
    │   ├── .ipynb_checkpoints
    │   │   ├── 1-monte-carlo-prediction-checkpoint.ipynb
    │   │   ├── 2-monte-carlo-exploring-starts-checkpoint.ipynb
    │   │   ├── 3-on-policy-monte-carlo-checkpoint.ipynb
    │   │   └── 4-off-policy-monte-carlo-checkpoint.ipynb
    │   ├── 1-monte-carlo-prediction.ipynb
    │   ├── 2-monte-carlo-exploring-starts.ipynb
    │   ├── 3-on-policy-monte-carlo.ipynb
    │   ├── 4-off-policy-monte-carlo.ipynb
    │   ├── __pycache__
    │   │   └── gridWorldEnvironment.cpython-36.pyc
    │   ├── gridWorldEnvironment.py
    │   └── gridworld.txt
    ├── 3-Temporal Difference Learning (Chapter 6)
    │   ├── .ipynb_checkpoints
    │   │   ├── 1-td-prediction-checkpoint.ipynb
    │   │   ├── 2-SARSA-on-policy control-checkpoint.ipynb
    │   │   ├── 3-Q-learning-off-policy-control-checkpoint.ipynb
    │   │   └── 4-double-Q-learning-off-policy-control-checkpoint.ipynb
    │   ├── 1-td-prediction.ipynb
    │   ├── 2-SARSA-on-policy control.ipynb
    │   ├── 3-Q-learning-off-policy-control.ipynb
    │   ├── 4-double-Q-learning-off-policy-control.ipynb
    │   ├── __pycache__
    │   │   └── gridWorldEnvironment.cpython-36.pyc
    │   ├── gridWorldEnvironment.py
    │   └── gridworld.txt
    └── 4-n-step Bootstrapping (Chapter 7)
        ├── .ipynb_checkpoints
        │   ├── 1-n-step-td-prediction-checkpoint.ipynb
        │   ├── 2-n-step-SARSA-on-policy-control-checkpoint.ipynb
        │   ├── 3-n-step-off-policy-learning-by-importance-sampling-checkpoint.ipynb
        │   └── 4-n-step-off-policy-learning-wo-importance-sampling-checkpoint.ipynb
        ├── 1-n-step-td-prediction.ipynb
        ├── 2-n-step-SARSA-on-policy-control.ipynb
        ├── 3-n-step-off-policy-learning-by-importance-sampling.ipynb
        ├── 4-n-step-off-policy-learning-wo-importance-sampling.ipynb
        ├── __pycache__
        │   └── gridWorldEnvironment.cpython-36.pyc
        ├── gridWorldEnvironment.py
        └── gridworld.txt


/README.md:
--------------------------------------------------------------------------------
# Tabular Reinforcement Learning Algorithms with Python
Python implementations of the tabular RL algorithms in Sutton & Barto 2017 (*Reinforcement Learning: An Introduction*). The environment and the algorithms are built mainly with NumPy and basic Python data structures (list, tuple, set, and dictionary); the environment also uses pandas to load its transition table and Seaborn for plotting.

### Algorithms learn from the *4x4 Grid World environment (Sutton & Barto 2017, p. 61)*

![4x4 gridworld example from Sutton & Barto 2017](/imgs/sutton_barto_example.PNG)
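
The gridworld pictured above is small enough to generate rather than type in. The sketch below is a hypothetical helper (`build_transitions` and `ACTIONS` are illustrative names, not part of `gridWorldEnvironment.py`) that rebuilds the same (state, action, next state, reward) quadruples stored in `gridworld.txt`, assuming states 1-14 are numbered row by row, both terminal corners are labeled 0, moves off the grid leave the state unchanged, and every move costs -1.

```python
# Hypothetical helper: rebuilds the 4x4 gridworld transition table in memory.
# Assumes states 1-14 are numbered row by row and both terminal corners are state 0.
ACTIONS = {"U": (-1, 0), "D": (1, 0), "L": (0, -1), "R": (0, 1)}

def build_transitions(size=4):
    transitions = []
    last = size * size - 1               # cell index of the bottom-right terminal corner
    for cell in range(1, last):          # non-terminal cells 1..14
        row, col = divmod(cell, size)
        for action, (dr, dc) in ACTIONS.items():
            r, c = row + dr, col + dc
            if not (0 <= r < size and 0 <= c < size):
                r, c = row, col          # moving off the grid leaves the state unchanged
            nxt = r * size + c
            if nxt in (0, last):
                nxt = 0                  # both terminal corners share the label 0
            transitions.append((cell, action, nxt, -1))
    return transitions

table = build_transitions()
# dictionary lookup keyed by (state, action), an alternative to scanning the table
lookup = {(s, a): (s2, r) for s, a, s2, r in table}
print(lookup[(1, "U")], lookup[(3, "L")])
```

The two lookups correspond to the `state_transition(1, "U")` and `state_transition(3, "L")` calls in the introduction notebook and should give `(1, -1)` and `(2, -1)`. Keying a dictionary by `(state, action)` makes each lookup constant time instead of a scan over all 56 rows.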

### Tabular Reinforcement Learning Algorithms with *NumPy*

![Tabular RL algorithms implemented with NumPy](/imgs/1.PNG)
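
To illustrate the NumPy-plus-basic-Python style, here is a minimal sketch of iterative policy evaluation (Chapter 4) for the equiprobable random policy on the same gridworld with discount factor 1. It is not code from the notebooks: the names `next_cell` and `evaluate_random_policy`, the 0..15 cell indexing (both corners kept as separate terminal cells), and the threshold `theta=1e-8` are assumptions made for this example.

```python
import numpy as np

# Hypothetical sketch: in-place iterative policy evaluation for the equiprobable
# random policy on the 4x4 gridworld. Cells are indexed 0..15 row by row;
# cells 0 and 15 are the terminal corners.
SIZE = 4
MOVES = [(-1, 0), (1, 0), (0, -1), (0, 1)]  # U, D, L, R
TERMINAL = {0, SIZE * SIZE - 1}

def next_cell(cell, move):
    row, col = divmod(cell, SIZE)
    r, c = row + move[0], col + move[1]
    if not (0 <= r < SIZE and 0 <= c < SIZE):
        return cell                      # bumping into the edge leaves the agent in place
    return r * SIZE + c

def evaluate_random_policy(gamma=1.0, theta=1e-8):
    V = np.zeros(SIZE * SIZE)
    while True:
        delta = 0.0
        for s in range(SIZE * SIZE):
            if s in TERMINAL:
                continue                 # terminal values stay 0
            # each action has probability 0.25 and reward -1
            v_new = sum(0.25 * (-1 + gamma * V[next_cell(s, m)]) for m in MOVES)
            delta = max(delta, abs(v_new - V[s]))
            V[s] = v_new                 # in-place sweep: later states see updated values
        if delta < theta:                # stop once no value changed by more than theta
            return V.reshape(SIZE, SIZE)

print(np.round(evaluate_random_policy()).astype(int))
```

Rounded, the result should match the converged values Sutton & Barto report for this example: -14 next to a terminal corner, -18 on the two inner diagonal cells, and -22 in the two far corners.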

### Visualizations with *Seaborn* (policy and value function)

![Policy and value function visualizations](/imgs/2.PNG)


## Contents

### 0. MDP Environment (Chapter 3, Sutton & Barto 2017)
1. Introduction to the gridworld environment

### 1. Dynamic Programming (Chapter 4, Sutton & Barto 2017)
1. Policy Evaluation and Improvement
2. Policy Iteration
3. Value Iteration

### 2. Monte Carlo Methods (Chapter 5, Sutton & Barto 2017)
1. Monte Carlo Prediction
2. Monte Carlo Exploring Starts
3. On-policy Monte Carlo
4. Off-policy Monte Carlo

### 3. Temporal Difference Learning (Chapter 6, Sutton & Barto 2017)
1. TD Prediction
2. SARSA - On-policy Control
3. Q-learning - Off-policy Control
4. Double Q-learning - Off-policy Control

### 4. n-step Bootstrapping (Chapter 7, Sutton & Barto 2017)
1. n-step TD Prediction
2. n-step SARSA - On-policy Control
3. n-step Off-policy Learning by Importance Sampling
4. n-step Off-policy Learning without Importance Sampling


--------------------------------------------------------------------------------
/imgs/1.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buomsoo-kim/Tabular-RL-with-Python/0157c05126821524dc9d744613b9aa8dd5d47232/imgs/1.PNG
--------------------------------------------------------------------------------
/imgs/2.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buomsoo-kim/Tabular-RL-with-Python/0157c05126821524dc9d744613b9aa8dd5d47232/imgs/2.PNG
--------------------------------------------------------------------------------
/imgs/sutton_barto_example.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buomsoo-kim/Tabular-RL-with-Python/0157c05126821524dc9d744613b9aa8dd5d47232/imgs/sutton_barto_example.PNG
--------------------------------------------------------------------------------
/source code/0-MDP Environment (Chapter 3)/1-introduction-to-gridworld-environment.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# GridWorld Environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "from gridWorldEnvironment import GridWorld"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Actions: ('U', 'D', 'L', 'R')\n",
      "States: [ 1 2 3 4 5 6 7 8 9 10 11 12 13 14]\n"
     ]
    }
   ],
   "source": [
    "# creating gridworld environment\n",
    "gw = GridWorld()\n",
    "\n",
    "print(\"Actions: \", gw.actions)\n",
    "print(\"States: \", gw.states)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### State Transitions\n",
    "- All possible state transitions in the deterministic gridworld\n",
    "    - Each transition is a quadruple of (```state```, ```action```, ```next state```, ```reward```)\n",
    "    - For instance, the first row means that if the agent performs action ```U``` (up) in state ```1```, it stays in state ```1``` (it cannot move off the grid) and receives a reward of -1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       " State Action Next State Reward\n",
       "0 1 U 1 -1\n",
       "1 1 D 5 -1\n",
       "2 1 R 2 -1\n",
       "3 1 L 0 -1\n",
       "4 2 U 2 -1"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pd.DataFrame(gw.transitions, columns = [\"State\", \"Action\", \"Next State\", \"Reward\"]).head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1, -1)\n",
      "(2, -1)\n"
     ]
    }
   ],
   "source": [
    "print(gw.state_transition(1, \"U\"))\n",
    "print(gw.state_transition(3, \"L\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Show environment\n",
    "- The environment is drawn as a 4x4 grid annotated with the state numbers\n",
    "- Note that the terminal states at positions ```(0,0)``` and ```(3,3)``` are added; both are labeled ```0```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAdwAAAFJCAYAAAAxCJwFAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAFnJJREFUeJzt3X1U1vX9x/HXBah58BZDdIlnuKnnWi7NFqftZB2HpukY\njtyRpugUjmflTR3TTYRgNgw967i1xKWgTZdJDQ2ciU1mudYOO9XKHRy5pZtmGunxXrkTvvvDE+dX\nB+m6OL/r/eEaz8dfXpeni1d9gCff78U5+TzP8wQAAEIqwvUAAAC6AoILAIABggsAgAGCCwCAAYIL\nAIABggsAgIGoUL643+9XcXFxKD8EQiQzM1M1NTWuZ6CD/H4/5xem/H6/JHF+Yaq9rz2ucAEAMEBw\nAQAwQHABADBAcAEAMEBwAQAwQHABADBAcAEAMEBwAQAwQHABADBAcAEAMEBwAQAwQHABADBAcAEA\nMEBwAQAwQHABADBAcAEAMEBwAQAwQHABADBAcAEAMEBwAQAwQHABADBAcAEAMEBwAQAwQHABADBA\ncAEAMEBwAQAwQHABADBAcAEAMEBwAQAwQHABADBAcAEAMEBwAQAwEOV6QDg6ePCgduzYoaamJsXH\nx2vu3Lnq2bOn61kIgud5ysrK0vDhw5WRkeF6DgJUXl6uTZs2yefzqWfPnsrOztbXv/5117MQoOef\nf17bt2+Xz+dTfHy88vPzNWDAANezzHCFG6SLFy9q8+bNWrBggQoKChQbG6vS0lLXsxCEI0eOaM6c\nOaqoqHA9BUE4evSofv7zn6u4uFjl5eV66KGHtGjRItezEKDq6mpt3rxZJSUl2r17t7785S/r6aef\ndj3LVMDBbWlpCeWOsHHo0CElJCQoLi5OkjR+/HhVVVXJ8zzHyxCobdu2KTU1Vffff7/rKQhC9+7d\nlZ+fr4EDB0qSRo0apTNnzqixsdHxMgRi1KhRevXVV9W7d281NDSotrZW/fr1cz3LVLu3lD/88EMV\nFBSourpaUVFRamlp0YgRI5SVlaWEhASrjZ3K2bNnFRMT0/q4f//+qqurU319PbeVw0Rubq4kqaqq\nyvESBGPIkCEaMmSIpOtvCRQUFOjb3/62unfv7ngZAtWtWzdVVlYqOztb3bt31+LFi11PMtVucLOz\ns/XYY49p9OjRrc+99957ysrKUklJScjHdUY3upKNiODuPGDh6tWrWr58uT7++GMVFxe7noMgTZgw\nQRMmTNBLL72kjIwM7du3r8t8/2z337KxsfEzsZWkMWPGhHRQZzdgwACdP3++9fG5c+cUHR2tHj16\nOFwFdA0nT55UWlqaIiMjtXXrVvXp08f1JATo2LFjevvtt1sfP/DAAzp58qQuXLjgcJWtdq9wR44c\nqaysLI0bN069e/fWlStXdODAAY0cOdJqX6dz66236sUXX1Rtba3i4uL0+uuvd/kfQgAL58+f16xZ\ns5SamqqFCxe6noMgnT59WkuWLFFZWZliYmL0+9//XsOHD1f//v1dTzPTbnB/+tOfqrKyUu+8844u\nX76sXr16afz48Zo4caLVvk6nT58+mjdvngoLC9Xc3KzY2FhlZma6ngX8z9u+fbtOnTqlffv2ad++\nfa3P/+Y3v+lS37TD1Te+8Q396Ec/0uzZsxUZGamBAweqsLDQ9SxTPi+Ev17r9/t5jyVMZWZmqqam\nxvUMdJDf7+f8wpTf75ckzi9Mtfe11zXeqQYAwDGCCwCAAYILAIABggsAgAGCCwCAAYILAIABggsA\ngAGCCwCAAYILAIABggsAgAGCCwCAAYILAIABggsAgAGCCwCAAYILAIABggsAgAGCCwCAAYILAIAB\nggsAgAGCCwCAAYILAIABggsAgAGCCwCAAYILAIABggsAgAGCCwCAAYILAIABggsAgAGCCwCAAYIL\nAIABggsAgAGCCwCAAZ/neV6oXtzv94fqpQEA6JRqamrafD7K1
QdG5+b3+zm7MMb5ha9PL1Q4v/DU\n3oUmt5QBADBAcAEAMEBwAQAwQHABADBAcAEAMEBwAQAwQHABADBAcAEAMEBwAQAwQHABADBAcAEA\nMEBwAQAwQHABADBAcAEAMEBwAQAwQHABADBAcAEAMEBwAQAwQHABADBAcAEAMEBwAQAwQHABADBA\ncAEAMEBwAQAwQHABADBAcAEAMEBwAQAwQHABADBAcAEAMEBwAQAwQHABADAQ5XpAOKusrNSPf/xj\n/e1vf3M9BUFYvXq19u7dq759+0qSEhIS9Mtf/tLxKgTi8OHDys/P16VLlxQREaEnnnhCo0aNcj0L\nASgrK9Nzzz3X+vjSpUuqra3VgQMHdPPNNztcZofgdtB//vMfrVmzRp7nuZ6CIL377rtau3atxo4d\n63oKglBXV6eMjAytWrVK9957ryorK7V06VLt3bvX9TQEYNq0aZo2bZokqampSbNmzdL8+fO7TGwl\nbil3SF1dnZYtW6bly5e7noIgNTY26h//+Ic2b96s7373u1q0aJFOnjzpehYC8Oabbyo+Pl733nuv\nJCkpKYk7E2GqqKhIMTExSktLcz3FFMHtgNzcXM2YMUMjR450PQVBqq2t1V133aUlS5aovLxco0eP\n1sMPP8ydijDw73//W7GxsVqxYoVSU1M1d+5cNTc3u56FIJ09e1bPPfecVqxY4XqKOYIbpG3btikq\nKkrTp093PQUdEB8fr6KiIg0bNkw+n08ZGRk6fvy4Tpw44XoavsC1a9d04MABzZgxQzt37my9JdnY\n2Oh6GoLw0ksvKSkpSfHx8a6nmGv3Pdz09HQ1NTV95jnP8+Tz+VRSUhLSYZ3Vyy+/rPr6eqWkpKip\nqan1zxs3blRcXJzrefgC77//vt5///3W95Kk65/T3bp1c7gKgRg4cKCGDRum0aNHS5ImTJignJwc\nffjhh/rKV77ieB0CtWfPHuXk5Lie4US7wV26dKlycnJUWFioyMhIq02dWmlpaeufT5w4oeTkZJWX\nlztchGBERERo1apVuuOOOxQfH68XXnhBI0eO1KBBg1xPwxe45557tGbNGlVXV2vUqFF666235PP5\nNGTIENfTEKALFy7o+PHjuv32211PcaLd4I4ePVopKSk6fPiwJk6caLUJCJkRI0YoJydHDz30kJqb\nmzVo0CCtXbvW9SwEIDY2VoWFhVq5cqXq6urUvXt3PfPMM+rRo4fraQjQsWPHFBsb22XvKPm8EP62\niN/vV01NTaheHiHE2YU3zi98+f1+SeL8wlR7X3v80hQAAAYILgAABgguAAAGCC4AAAYILgAABggu\nAAAGCC4AAAYILgAABgguAAAGCC4AAAYILgAABgguAAAGCC4AAAYILgAABgguAAAGCC4AAAYILgAA\nBgguAAAGCC4AAAYILgAABgguAAAGCC4AAAYILgAABgguAAAGCC4AAAYILgAABgguAAAGCC4AAAYI\nLgAABgguAAAGCC4AAAYILgAABgguAAAGfJ7neaF6cb/fH6qXBgCgU6qpqWnz+ShXHxidm9/v5+zC\nGOcXvj69UOH8wlN7F5rcUgYAwADBBQDAAMEFAMAAwQUAwADBBQDAAMEFAMAAwQUAwADBBQDAAMEF\nAMAAwQUAwADBBQDAAMEFAMAAwQUAwADBBQDAAMEFAMAAwQUAwADBBQDAAMEFAMAAwQUAwADBBQDA\nAMEFAMAAwQUAwADBBQDAAMEFAMAAwQUAwADBBQDAAMEFAMAAwQUAwADBBQDAAMEFAMAAwQUAwADB\n7YB9+/YpOTlZKSkpSk9P1/Hjx11PQhB++9vfatKkSUpJSdGSJUt0/vx515PwBTzP0/Lly7Vp0yZJ\nUnNzs/Lz8zV58mRNnDhR27dvd7wQN/L5s/vUqVOnNG7cOJ09e9bRMnsEN0j19fVatmyZ1q1bp/Ly\nciUlJSk/P9/1LASoqqpKR
UVF2rJli8rLy3XPPfcoNzfX9Sy048iRI5ozZ44qKipanyspKdGxY8e0\ne/dulZaWasuWLfr73//ucCXa0tbZSVJZWZlmzpypTz75xNEyN4IObmNjYyh2hI3m5mZ5nqdLly5J\nkq5cuaIePXo4XoVAHTp0SN/61rc0aNAgSdJ9992n/fv3d/nP685s27ZtSk1N1f3339/6XGVlpVJT\nUxUVFaW+fftq6tSp2rVrl8OVaEtbZ1dbW6vKykpt3LjR4TI3bhjc/fv3a/z48Zo4caL27NnT+nxm\nZqbJsM4qOjpaK1euVFpamu6++25t27ZNS5cudT0LAbrttttUVVWljz76SJK0c+dONTU1cVu5E8vN\nzdW0adM+89ypU6c0ePDg1seDBg3Sxx9/bD0NX6Cts4uLi9O6dev01a9+1dEqd6Ju9BfPPvusysrK\n1NLSokceeUQNDQ363ve+J8/zLPd1OocPH1ZhYaH27NmjoUOHauvWrVq0aJHKy8vl8/lcz8MXuPPO\nO7VgwQItXLhQPp9PDzzwgPr166du3bq5noYgtPV9KCKCd8jQud3wM7Rbt27q27ev+vfvr/Xr1+v5\n559XVVVVl4/Kn//8Z40dO1ZDhw6VJM2cOVP/+te/dO7cOcfLEIjLly8rMTFRL7/8snbu3KlJkyZJ\nkvr16+d4GYIxePBgnT59uvVxbW1t69sEQGd1w+DecsstKigo0NWrV9WrVy+tW7dOTzzxhI4ePWq5\nr9P52te+prfeektnzpyRdP29pCFDhigmJsbxMgTik08+UXp6ui5fvixJWr9+vaZOndrlf5AMN0lJ\nSdqxY4euXbumixcv6pVXXtGECRNczwLadcNbyk8++aR27drV+o1o8ODB2rp1qzZs2GA2rjP65je/\nqYyMDKWnp7feBVi/fr3rWQjQsGHDNH/+fH3/+99XS0uL7rjjDn5LOQw9+OCDOn78uFJSUtTU1KQZ\nM2YoMTHR9SygXT4vhG/K+v1+1dTUhOrlEUKcXXjj/MKX3++XJM4vTLX3tcdvGQAAYIDgAgBggOAC\nAGCA4AIAYIDgAgBggOACAGCA4AIAYIDgAgBggOACAGCA4AIAYIDgAgBggOACAGCA4AIAYIDgAgBg\ngOACAGCA4AIAYIDgAgBggOACAGCA4AIAYIDgAgBggOACAGCA4AIAYIDgAgBggOACAGCA4AIAYIDg\nAgBggOACAGCA4AIAYIDgAgBggOACAGCA4AIAYIDgAgBgwOd5nheqF/f7/aF6aQAAOqWampo2n+cK\nFwAAA1Gh/gA3Kj06N7/fz9mFMc4vfH16Z7C4uNjxEnREZmbmDf+OK1wAAAwQXAAADBBcAAAMEFwA\nAAwQXAAADBBcAAAMEFwAAAwQXAAADBBcAAAMEFwAAAwQXAAADBBcAAAMEFwAAAwQXAAADBBcAAAM\nEFwAAAwQXAAADBBcAAAMEFwAAAwQXAAADBBcAAAMEFwAAAwQXAAADBBcAAAMEFwAAAwQXAAADBBc\nAAAMEFwAAAwQXAAADBBcAAAMEFwAAAxEuR4QTjzPU1ZWloYPH66MjAzV19dr5cqVqq6uVktLi267\n7Tbl5eXppptucj0Vn/P5s7t06ZKys7N19OhRtbS0aNq0aZo/f77rmbiBz5/f/7Vw4UINHDhQubm5\njtYhUAcPHtSOHTvU1NSk+Ph4zZ07Vz179nQ9ywxXuAE6cuSI5syZo4qKitbnfv3rX6u5uVnl5eXa\ntWuXGhoatGHDBocr0Za2zu7pp59WXFycdu/erdLSUpWUlOjdd991uBI30tb5faqoqEhvv/22g1UI\n1sWLF7V582YtWLBABQUFio2NVWlpqetZpoK6wq2vr1dERIS6d+8eqj2d1rZt25SamqovfelLrc/d\neeeduuWWWxQRcf3nFr/frw8++MDVRNxAW2eXnZ2t5uZmSdLp06fV2Nio3r17u5qIdrR1fpJ
UVVWl\nN954Q2lpabp48aKjdQjUoUOHlJCQoLi4OEnS+PHjlZeXp1mzZsnn8zleZ6PdK9wPPvhADz/8sLKy\nsvSXv/xFU6ZM0ZQpU/Taa69Z7es0cnNzNW3atM88d/fddyshIUGS9NFHH2nLli2aPHmyi3loR1tn\n5/P5FBUVpaVLl+o73/mOEhMTW88SnUtb51dbW6tVq1bpqaeeUmRkpKNlCMbZs2cVExPT+rh///6q\nq6tTfX29w1W22g1uXl6efvjDHyoxMVGLFy/W7373O5WVlXHb9HOqq6s1c+ZMzZo1S+PHj3c9B0F4\n6qmnVFVVpQsXLqiwsND1HASgqalJS5Ys0YoVKzRw4EDXcxAgz/PafP7TO4RdQbu3lFtaWpSYmChJ\n+utf/6oBAwZc/4ei+F2rT73yyitauXKlHn/8cSUnJ7uegwC98cYbGjFihOLi4hQdHa2pU6fqD3/4\ng+tZCEB1dbVOnDih1atXS5LOnDmj5uZmNTQ0aNWqVY7X4UYGDBigo0ePtj4+d+6coqOj1aNHD4er\nbLX7o0VCQoKys7PV0tLS+sm9ceNG3XzzzSbjOru9e/cqPz9fmzZtIrZhpqKiQoWFhfI8T42Njaqo\nqNBdd93lehYCcPvtt+vAgQMqLy9XeXm50tLSNGXKFGLbyd166606evSoamtrJUmvv/66xowZ43iV\nrXYvVfPz87V///7PXPLHxcUpPT095MPCwdq1a+V5nnJyclqfGzt2rPLy8hyuQiCWL1+uvLw8JScn\ny+fzKSkpSbNnz3Y9C/if1adPH82bN0+FhYVqbm5WbGysMjMzXc8y5fNudGP9/4Hf71dNTU2oXh4h\nxNmFN84vfPn9fklScXGx4yXoiMzMzBt+7XWdd6sBAHCI4AIAYIDgAgBggOACAGCA4AIAYIDgAgBg\ngOACAGCA4AIAYIDgAgBggOACAGCA4AIAYIDgAgBggOACAGCA4AIAYIDgAgBggOACAGCA4AIAYIDg\nAgBggOACAGCA4AIAYIDgAgBggOACAGCA4AIAYIDgAgBggOACAGCA4AIAYIDgAgBggOACAGCA4AIA\nYIDgAgBggOACAGCA4AIAYMDneZ4Xqhf3+/2hemkAADqlmpqaNp8PaXABAMB13FIGAMAAwQUAwADB\nBQDAAMEFAMAAwQUAwADBBQDAAMHtgJaWFuXm5mrGjBlKT0/XsWPHXE9CkA4ePKj09HTXMxCkpqYm\nLVu2TD/4wQ80ffp0/fGPf3Q9CQFqbm5WVlaW0tLS9OCDD+qf//yn60nmCG4HVFZWqrGxUS+++KIe\ne+wxrV692vUkBKGoqEg5OTlqaGhwPQVB2rVrl/r166cXXnhBxcXF+tnPfuZ6EgL02muvSZJKSkr0\n6KOP6he/+IXjRfYIbge88847GjdunCRpzJgxqq6udrwIwRg6dKieeeYZ1zPQAZMnT9YjjzwiSfI8\nT5GRkY4XIVATJkxo/QHp5MmT6tOnj+NF9qJcDwhHly9fVq9evVofR0ZG6tq1a4qK4j9nOJg0aZJO\nnDjhegY6IDo6WtL1r8HFixfr0UcfdbwIwYiKitJPfvIT7du3T7/61a9czzHHFW4H9OrVS1euXGl9\n3NLSQmwBI6dOndLs2bOVkpKi5ORk13MQpDVr1ujVV1/V448/rqtXr7qeY4rgdsDYsWP1pz/9SZL0\n3nvvacSIEY4XAV3DmTNnNG/ePC1btkzTp093PQdBKCsr04YNGyRJPXv2lM/nU0RE10oQl2UdMHHi\nRL355ptKS0uT53l68sknXU8CuoRnn31WFy9e1Pr167V+/XpJ138J7qabbnK8DF/kvvvuU1ZWlmbO\nnKlr165pxYoVXe7c+L8FAQBgoGtdzwMA4AjBBQDAAMEFAMAAwQUAwADBBQDAAMEFAMAAwQUAwADB\nBQDAwH8BCoTS1WbM11MAAAAASUVORK5CYII=\n", 
183 | "text/plain": [ 184 | "" 185 | ] 186 | }, 187 | "metadata": {}, 188 | "output_type": "display_data" 189 | } 190 | ], 191 | "source": [ 192 | "gw.show_environment()\n", 193 | "plt.show()" 194 | ] 195 | } 196 | ], 197 | "metadata": { 198 | "kernelspec": { 199 | "display_name": "Python 3", 200 | "language": "python", 201 | "name": "python3" 202 | }, 203 | "language_info": { 204 | "codemirror_mode": { 205 | "name": "ipython", 206 | "version": 3 207 | }, 208 | "file_extension": ".py", 209 | "mimetype": "text/x-python", 210 | "name": "python", 211 | "nbconvert_exporter": "python", 212 | "pygments_lexer": "ipython3", 213 | "version": "3.6.1" 214 | } 215 | }, 216 | "nbformat": 4, 217 | "nbformat_minor": 2 218 | } 219 | -------------------------------------------------------------------------------- /source code/0-MDP Environment (Chapter 3)/__pycache__/gridWorldEnvironment.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buomsoo-kim/Tabular-RL-with-Python/0157c05126821524dc9d744613b9aa8dd5d47232/source code/0-MDP Environment (Chapter 3)/__pycache__/gridWorldEnvironment.cpython-36.pyc -------------------------------------------------------------------------------- /source code/0-MDP Environment (Chapter 3)/gridWorldEnvironment.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import seaborn 4 | from matplotlib.colors import ListedColormap 5 | 6 | class GridWorld: 7 | def __init__(self, gamma = 1.0, theta = 0.5): 8 | self.actions = ("U", "D", "L", "R") 9 | self.states = np.arange(1, 15) 10 | self.transitions = pd.read_csv("gridworld.txt", header = None, sep = "\t").values 11 | self.gamma = gamma 12 | self.theta = theta 13 | 14 | def state_transition(self, state, action): 15 | next_state, reward = None, None 16 | for tr in self.transitions: 17 | if tr[0] == state and tr[1] == action: 18 | next_state = 
tr[2] 19 | reward = tr[3] 20 | return next_state, reward 21 | 22 | def show_environment(self): 23 | all_states = np.concatenate(([0], self.states, [0])).reshape(4,4) 24 | colors = [] 25 | # colors = ["#ffffff"] 26 | for i in range(len(self.states) + 1): 27 | if i == 0: 28 | colors.append("#c4c4c4") 29 | else: 30 | colors.append("#ffffff") 31 | 32 | cmap = ListedColormap(seaborn.color_palette(colors).as_hex()) 33 | ax = seaborn.heatmap(all_states, cmap = cmap, \ 34 | annot = True, linecolor = "#282828", linewidths = 0.2, \ 35 | cbar = False) -------------------------------------------------------------------------------- /source code/0-MDP Environment (Chapter 3)/gridworld.txt: -------------------------------------------------------------------------------- 1 | 1 U 1 -1 2 | 1 D 5 -1 3 | 1 R 2 -1 4 | 1 L 0 -1 5 | 2 U 2 -1 6 | 2 D 6 -1 7 | 2 R 3 -1 8 | 2 L 1 -1 9 | 3 U 3 -1 10 | 3 D 7 -1 11 | 3 R 3 -1 12 | 3 L 2 -1 13 | 4 U 0 -1 14 | 4 D 8 -1 15 | 4 R 5 -1 16 | 4 L 4 -1 17 | 5 U 1 -1 18 | 5 D 9 -1 19 | 5 R 6 -1 20 | 5 L 4 -1 21 | 6 U 2 -1 22 | 6 D 10 -1 23 | 6 R 7 -1 24 | 6 L 5 -1 25 | 7 U 3 -1 26 | 7 D 11 -1 27 | 7 R 7 -1 28 | 7 L 6 -1 29 | 8 U 4 -1 30 | 8 D 12 -1 31 | 8 R 9 -1 32 | 8 L 8 -1 33 | 9 U 5 -1 34 | 9 D 13 -1 35 | 9 R 10 -1 36 | 9 L 8 -1 37 | 10 U 6 -1 38 | 10 D 14 -1 39 | 10 R 11 -1 40 | 10 L 9 -1 41 | 11 U 7 -1 42 | 11 D 0 -1 43 | 11 R 11 -1 44 | 11 L 10 -1 45 | 12 U 8 -1 46 | 12 D 12 -1 47 | 12 R 13 -1 48 | 12 L 12 -1 49 | 13 U 9 -1 50 | 13 D 13 -1 51 | 13 R 14 -1 52 | 13 L 12 -1 53 | 14 U 10 -1 54 | 14 D 14 -1 55 | 14 R 0 -1 56 | 14 L 13 -1 -------------------------------------------------------------------------------- /source code/1-Dynamic Programming (Chapter 4)/gridWorldEnvironment.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import seaborn 4 | from matplotlib.colors import ListedColormap 5 | 6 | class GridWorld: 7 | def __init__(self, gamma = 1.0, theta = 
0.5): 8 | self.actions = ("U", "D", "L", "R") 9 | self.states = np.arange(1, 15) 10 | self.transitions = pd.read_csv("gridworld.txt", header = None, sep = "\t").values 11 | self.gamma = gamma 12 | self.theta = theta 13 | 14 | def state_transition(self, state, action): 15 | next_state, reward = None, None 16 | for tr in self.transitions: 17 | if tr[0] == state and tr[1] == action: 18 | next_state = tr[2] 19 | reward = tr[3] 20 | return next_state, reward 21 | 22 | def show_environment(self): 23 | all_states = np.concatenate(([0], self.states, [0])).reshape(4,4) 24 | colors = [] 25 | # colors = ["#ffffff"] 26 | for i in range(len(self.states) + 1): 27 | if i == 0: 28 | colors.append("#c4c4c4") 29 | else: 30 | colors.append("#ffffff") 31 | 32 | cmap = ListedColormap(seaborn.color_palette(colors).as_hex()) 33 | ax = seaborn.heatmap(all_states, cmap = cmap, \ 34 | annot = True, linecolor = "#282828", linewidths = 0.2, \ 35 | cbar = False) -------------------------------------------------------------------------------- /source code/1-Dynamic Programming (Chapter 4)/gridworld.txt: -------------------------------------------------------------------------------- 1 | 1 U 1 -1 2 | 1 D 5 -1 3 | 1 R 2 -1 4 | 1 L 0 -1 5 | 2 U 2 -1 6 | 2 D 6 -1 7 | 2 R 3 -1 8 | 2 L 1 -1 9 | 3 U 3 -1 10 | 3 D 7 -1 11 | 3 R 3 -1 12 | 3 L 2 -1 13 | 4 U 0 -1 14 | 4 D 8 -1 15 | 4 R 5 -1 16 | 4 L 4 -1 17 | 5 U 1 -1 18 | 5 D 9 -1 19 | 5 R 6 -1 20 | 5 L 4 -1 21 | 6 U 2 -1 22 | 6 D 10 -1 23 | 6 R 7 -1 24 | 6 L 5 -1 25 | 7 U 3 -1 26 | 7 D 11 -1 27 | 7 R 7 -1 28 | 7 L 6 -1 29 | 8 U 4 -1 30 | 8 D 12 -1 31 | 8 R 9 -1 32 | 8 L 8 -1 33 | 9 U 5 -1 34 | 9 D 13 -1 35 | 9 R 10 -1 36 | 9 L 8 -1 37 | 10 U 6 -1 38 | 10 D 14 -1 39 | 10 R 11 -1 40 | 10 L 9 -1 41 | 11 U 7 -1 42 | 11 D 0 -1 43 | 11 R 11 -1 44 | 11 L 10 -1 45 | 12 U 8 -1 46 | 12 D 12 -1 47 | 12 R 13 -1 48 | 12 L 12 -1 49 | 13 U 9 -1 50 | 13 D 13 -1 51 | 13 R 14 -1 52 | 13 L 12 -1 53 | 14 U 10 -1 54 | 14 D 14 -1 55 | 14 R 0 -1 56 | 14 L 13 -1 
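The `gridworld.txt` table above lists every transition of the 4x4 grid world from Sutton & Barto. Non-terminal states 1 to 14 are numbered row by row, and both terminal corners are encoded as state 0. A move that would leave the grid keeps the agent in place, and every step earns a reward of -1. The sketch below is not part of the repository, and the helper `build_transitions` and the `MOVES` table are hypothetical names chosen for illustration. Under those rules, it regenerates the same 56 rows and stores them in a dictionary keyed by `(state, action)`. This is one possible way to avoid the linear scan over `self.transitions` inside `GridWorld.state_transition`.

```python
# Hypothetical helper (not in the repository): rebuilds the rows of
# gridworld.txt for the 4x4 grid. Cells are numbered 0..15 row by row;
# cells 0 and 15 are terminal and both are labelled 0.
MOVES = {"U": (-1, 0), "D": (1, 0), "R": (0, 1), "L": (0, -1)}


def build_transitions(size=4, reward=-1):
    """Return {(state, action): (next_state, reward)} for states 1..size*size-2."""
    last = size * size - 1
    table = {}
    for state in range(1, last):
        row, col = divmod(state, size)
        for action, (dr, dc) in MOVES.items():
            r, c = row + dr, col + dc
            # A move off the grid leaves the agent where it was.
            if not (0 <= r < size and 0 <= c < size):
                r, c = row, col
            nxt = r * size + c
            # The bottom-right terminal shares the label 0 with the top-left one.
            if nxt == last:
                nxt = 0
            table[(state, action)] = (nxt, reward)
    return table


transitions = build_transitions()
print(len(transitions))        # 56, one per row of gridworld.txt
print(transitions[(1, "L")])   # (0, -1): step into the top-left terminal
print(transitions[(11, "D")])  # (0, -1): step into the bottom-right terminal
print(transitions[(3, "R")])   # (3, -1): bump into the right wall
```

Keeping the rules in code makes the text file easy to check for typos. For example, each row of `pd.read_csv("gridworld.txt", header = None, sep = "\t").values` can be compared with the dictionary. Whether to replace the file-based loader in `GridWorld.__init__` is left as a design choice.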
-------------------------------------------------------------------------------- /source code/2-Monte Carlo Methods (Chapter 5)/.ipynb_checkpoints/3-on-policy-monte-carlo-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# On-policy Monte Carlo\n", 8 | "- Algorithms from ```pp. 82 - 84``` in Sutton & Barto 2017\n", 9 | "- Estimates optimal policy $\\pi \\approx \\pi_*$" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 34, 15 | "metadata": { 16 | "collapsed": true 17 | }, 18 | "outputs": [], 19 | "source": [ 20 | "import matplotlib.pyplot as plt\n", 21 | "import pandas as pd\n", 22 | "import numpy as np\n", 23 | "import seaborn\n", 24 | "\n", 25 | "from gridWorldEnvironment import GridWorld" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 35, 31 | "metadata": { 32 | "collapsed": true 33 | }, 34 | "outputs": [], 35 | "source": [ 36 | "# creating gridworld environment\n", 37 | "gw = GridWorld(gamma = .9)" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": {}, 43 | "source": [ 44 | "### Generate policy\n", 45 | "- We start with equiprobable random policy" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 36, 51 | "metadata": { 52 | "collapsed": true 53 | }, 54 | "outputs": [], 55 | "source": [ 56 | "def generate_random_policy(env):\n", 57 | " pi = dict()\n", 58 | " for state in env.states:\n", 59 | " actions = []\n", 60 | " prob = []\n", 61 | " for action in env.actions:\n", 62 | " actions.append(action)\n", 63 | " prob.append(0.25)\n", 64 | " pi[state] = (actions, prob)\n", 65 | " return pi" 66 | ] 67 | }, 68 | { 69 | "cell_type": "markdown", 70 | "metadata": {}, 71 | "source": [ 72 | "### Create Action Values\n", 73 | "- Initialize all state-action values with 0" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": 37, 
79 | "metadata": { 80 | "collapsed": true 81 | }, 82 | "outputs": [], 83 | "source": [ 84 | "def state_action_value(env):\n", 85 | " q = dict()\n", 86 | " for state, action, next_state, reward in env.transitions:\n", 87 | " q[(state, action)] = 0\n", 88 | " return q" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 38, 94 | "metadata": {}, 95 | "outputs": [ 96 | { 97 | "name": "stdout", 98 | "output_type": "stream", 99 | "text": [ 100 | "0\n", 101 | "0\n", 102 | "0\n" 103 | ] 104 | } 105 | ], 106 | "source": [ 107 | "q = state_action_value(gw)\n", 108 | "\n", 109 | "print(q[(2, \"U\")])\n", 110 | "print(q[(4, \"L\")])\n", 111 | "print(q[(10, \"R\")])" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": 39, 117 | "metadata": { 118 | "collapsed": true, 119 | "scrolled": true 120 | }, 121 | "outputs": [], 122 | "source": [ 123 | "pi = generate_random_policy(gw)" 124 | ] 125 | }, 126 | { 127 | "cell_type": "markdown", 128 | "metadata": {}, 129 | "source": [ 130 | "### Generate episode\n", 131 | "- Generate episode based on current policy ($\\pi$)" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": 40, 137 | "metadata": { 138 | "collapsed": true, 139 | "scrolled": true 140 | }, 141 | "outputs": [], 142 | "source": [ 143 | "def generate_episode(env, policy):\n", 144 | " episode = []\n", 145 | " done = False\n", 146 | " current_state = np.random.choice(env.states)\n", 147 | " action = np.random.choice(policy[current_state][0], p = policy[current_state][1])\n", 148 | " episode.append((current_state, action, -1))\n", 149 | " \n", 150 | " while not done:\n", 151 | " next_state, reward = env.state_transition(current_state, action)\n", 152 | " action = np.random.choice(policy[next_state][0], p = policy[next_state][1]) if next_state != 0 else None\n", 153 | " episode.append((next_state, action, reward))\n", 154 | " \n", 155 | " if next_state == 0:\n", 156 | " done = True\n", 157 | " current_state = next_state\n", 158 | " \n",
159 | " return episode" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": 41, 165 | "metadata": {}, 166 | "outputs": [ 167 | { 168 | "data": { 169 | "text/plain": [ 170 | "[(6, 'U', -1),\n", 171 | " (2, 'D', -1),\n", 172 | " (6, 'R', -1),\n", 173 | " (7, 'U', -1),\n", 174 | " (3, 'D', -1),\n", 175 | " (7, 'D', -1),\n", 176 | " (11, 'U', -1),\n", 177 | " (7, 'U', -1),\n", 178 | " (3, 'L', -1),\n", 179 | " (2, 'D', -1),\n", 180 | " (6, 'D', -1),\n", 181 | " (10, 'D', -1),\n", 182 | " (14, 'R', -1),\n", 183 | " (0, 'R', -1)]" 184 | ] 185 | }, 186 | "execution_count": 41, 187 | "metadata": {}, 188 | "output_type": "execute_result" 189 | } 190 | ], 191 | "source": [ 192 | "generate_episode(gw, pi)" 193 | ] 194 | }, 195 | { 196 | "cell_type": "markdown", 197 | "metadata": {}, 198 | "source": [ 199 | "### On-policy Monte Carlo Control " 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": 42, 205 | "metadata": { 206 | "collapsed": true 207 | }, 208 | "outputs": [], 209 | "source": [ 210 | "# first-visit MC\n", 211 | "def on_policy_mc(env, pi, e, num_iter):\n", 212 | " Q = state_action_value(env)\n", 213 | " returns = dict()\n", 214 | " for s, a in Q:\n", 215 | " returns[(s,a)] = []\n", 216 | " \n", 217 | " for i in range(num_iter):\n", 218 | " episode = generate_episode(env, pi)\n", 219 | " already_visited = set()\n", 220 | " for s, a, r in episode:\n", 221 | " if s != 0 and (s, a) not in already_visited:\n", 222 | " already_visited.add((s, a))\n", 223 | " idx = episode.index((s, a, r))\n", 224 | " G = 0\n", 225 | " j = 1\n", 226 | " while j + idx < len(episode):\n", 227 | " G += env.gamma ** (j - 1) * episode[j + idx][-1]\n", 228 | " j += 1\n", 229 | " returns[(s,a)].append(G)\n", 230 | " Q[(s,a)] = np.mean(returns[(s,a)])\n", 231 | " for s, _, _ in episode:\n", 232 | " if s != 0:\n", 233 | " actions = []\n", 234 | " action_values = []\n", 235 | " prob = []\n", 236 | "\n", 237 | " for a in env.actions:\n", 238 | " actions.append(a)\n", 239 | "
action_values.append(Q[s,a]) \n", 240 | " for i in range(len(action_values)):\n", 241 | " if i == np.argmax(action_values):\n", 242 | " prob.append(1 - e + e/len(actions))\n", 243 | " else:\n", 244 | " prob.append(e/len(actions)) \n", 245 | " pi[s] = (actions, prob)\n", 246 | " return Q, pi" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": 43, 252 | "metadata": {}, 253 | "outputs": [ 254 | { 255 | "name": "stdout", 256 | "output_type": "stream", 257 | "text": [ 258 | "Wall time: 1min 8s\n" 259 | ] 260 | } 261 | ], 262 | "source": [ 263 | "# Obtained Estimates for Q & pi after 100000 iterations\n", 264 | "Q_hat, pi_hat = on_policy_mc(gw, generate_random_policy(gw), 0.2, 100000)" 265 | ] 266 | }, 267 | { 268 | "cell_type": "code", 269 | "execution_count": 44, 270 | "metadata": {}, 271 | "outputs": [ 272 | { 273 | "data": { 274 | "text/plain": [ 275 | "{1: (['U', 'D', 'L', 'R'], [0.05, 0.05, 0.8500000000000001, 0.05]),\n", 276 | " 2: (['U', 'D', 'L', 'R'], [0.05, 0.05, 0.8500000000000001, 0.05]),\n", 277 | " 3: (['U', 'D', 'L', 'R'], [0.05, 0.8500000000000001, 0.05, 0.05]),\n", 278 | " 4: (['U', 'D', 'L', 'R'], [0.8500000000000001, 0.05, 0.05, 0.05]),\n", 279 | " 5: (['U', 'D', 'L', 'R'], [0.8500000000000001, 0.05, 0.05, 0.05]),\n", 280 | " 6: (['U', 'D', 'L', 'R'], [0.05, 0.05, 0.8500000000000001, 0.05]),\n", 281 | " 7: (['U', 'D', 'L', 'R'], [0.05, 0.8500000000000001, 0.05, 0.05]),\n", 282 | " 8: (['U', 'D', 'L', 'R'], [0.8500000000000001, 0.05, 0.05, 0.05]),\n", 283 | " 9: (['U', 'D', 'L', 'R'], [0.05, 0.05, 0.05, 0.8500000000000001]),\n", 284 | " 10: (['U', 'D', 'L', 'R'], [0.05, 0.05, 0.05, 0.8500000000000001]),\n", 285 | " 11: (['U', 'D', 'L', 'R'], [0.05, 0.8500000000000001, 0.05, 0.05]),\n", 286 | " 12: (['U', 'D', 'L', 'R'], [0.05, 0.05, 0.05, 0.8500000000000001]),\n", 287 | " 13: (['U', 'D', 'L', 'R'], [0.05, 0.05, 0.05, 0.8500000000000001]),\n", 288 | " 14: (['U', 'D', 'L', 'R'], [0.05, 0.05, 0.05, 0.8500000000000001])}" 289 | ] 
290 | }, 291 | "execution_count": 44, 292 | "metadata": {}, 293 | "output_type": "execute_result" 294 | } 295 | ], 296 | "source": [ 297 | "# final policy obtained\n", 298 | "pi_hat" 299 | ] 300 | }, 301 | { 302 | "cell_type": "markdown", 303 | "metadata": {}, 304 | "source": [ 305 | "### Visualizing policy" 306 | ] 307 | }, 308 | { 309 | "cell_type": "code", 310 | "execution_count": 45, 311 | "metadata": { 312 | "collapsed": true 313 | }, 314 | "outputs": [], 315 | "source": [ 316 | "def show_policy(pi, env):\n", 317 | " temp = np.zeros(len(env.states) + 2)\n", 318 | " for s in env.states:\n", 319 | " a = pi[s][0][np.argmax(pi[s][1])]\n", 320 | " if a == \"U\":\n", 321 | " temp[s] = 0.25\n", 322 | " elif a == \"D\":\n", 323 | " temp[s] = 0.5\n", 324 | " elif a == \"R\":\n", 325 | " temp[s] = 0.75\n", 326 | " else:\n", 327 | " temp[s] = 1.0\n", 328 | " \n", 329 | " temp = temp.reshape(4,4)\n", 330 | " ax = seaborn.heatmap(temp, cmap = \"prism\", linecolor=\"#282828\", cbar = False, linewidths = 0.1)\n", 331 | " plt.show()" 332 | ] 333 | }, 334 | { 335 | "cell_type": "code", 336 | "execution_count": 46, 337 | "metadata": {}, 338 | "outputs": [ 339 | { 340 | "data": { 341 | "image/png":
"iVBORw0KGgoAAAANSUhEUgAAAdwAAAFJCAYAAAAxCJwFAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAACsRJREFUeJzt3UtoXAX/xvFfmhEV4wUsSBGEbpoMvNDiwl0XgcZLQUKl\nYKtEQbJy0RRqkZh6wUtbV4KXULUrF75xZXEhFGuLgoKLQguBiSKCeNu4krbYtM68C6F//NPkNXmZ\nZ8z089nlHHLmoYfy7TkkdKDT6XQKAOiqdb0eAADXAsEFgADBBYAAwQWAAMEFgADBBYCARjcv3mw2\nq7Ww0M2PoEuaIyP1r5Z7t1bNN92/tWq+OVJVVa0P3b+1qLljpFqt1lXPecIFgADBBYAAwQWAAMEF\ngADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWA\nAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAA\nwQWAAMEFgIC/Hdx2u93NHQDQ1xrLnfzhhx/q0KFDNT8/X41Go9rtdm3atKmmp6dr48aNqY0AsOYt\nG9yZmZnat29fbd68+cqxM2fO1PT0dM3NzXV9HAD0i2VfKS8uLv4ltlVVW7Zs6eogAOhHyz7hDg8P\n1/T0dG3durVuvvnmOn/+fH322Wc1PDyc2gcAfWHZ4L7wwgt14sSJOn36dJ07d66GhoZqdHS0xsbG\nUvsAoC8sG9yBgYEaGxsTWAD4H/k9XAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQ\nXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBc\nAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBjodDqdbl282Wx269IA8I/UarWuerzR\n7Q9e2Hy02x9BF4ycnXTv1rCRs5P1r9ZCr2ewCvPNkaqqan3o/q1FzR0jS57zShkAAgQXAAIEFwAC\nBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIE\nFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQX\nAAIEFwACBBcAAgQXAAIEFwACGsudnJiYqEuXLv3lWKfTqYGBgZqbm+vqMADoJ8sG96mnnqoDBw7U\nW2+9VYODg6lNANB3lg3u5s2ba3x8vL7++usaGxtLbQKAvrNscKuqJicnEzsAoK/5oSkACBBcAAgQ\nXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBc\nAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwA\nCBBcAAgQXAAIGOh0Op1uXbzZbHbr0gDwj9Rqta56vNHtD17YfLTbH0EXjJydrJmFf/d6Bqv0yshu\n92+NemVkd1VVtT5c6PESVqO5Y2TJc14pA0CA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOAC\nQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJA\ngOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQMCKg7u4uNiNHQDQ15YM7smT\nJ2t0dLTGxsbq448/vnJ8cnIyMgwA+kljqRNHjhypY8eOVbvdrqmpqbp48WLt2LGjOp1Och8A9IUl\ng3vdddfVrbfeWlVVs7Oz9fjjj9eGDRtqYGAgNg4A+sWSr5Tvv
PPOOnToUF24cKGGhobqzTffrBdf\nfLG+++675D4A6AtLBvfgwYM1PDx85Yl2w4YN9d5779UDDzwQGwcA/WLJV8qNRqMeeuihvxxbv359\nzczMdH0UAPQbv4cLAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsA\nAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwAB\nggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAQOdTqfTrYs3m81uXRoA/pFardZVjze6/cEzC//u\n9kfQBa+M7Hbv1jD3b+16ZWR3VVW1FhZ6vITVaI6MLHnOK2UACBBcAAgQXAAIEFwACBBcAAgQXAAI\nEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQ\nXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIWFFwf//9\n91pcXOzWFgDoW8sG99tvv60nn3yypqen68svv6zt27fX9u3b69SpU6l9ANAXGsudfP7552tqaqp+\n+umn2rNnTx0/fryuv/76mpycrNHR0dRGAFjzlg1uu92ue+65p6qqvvrqq7r99tv//KbGst8GAPw/\ny75S3rhxY83MzFS73a7Dhw9XVdU777xT69evj4wDgH6x7KPqyy+/XCdPnqx16/6vy3fccUdNTEx0\nfRgA9JNlg7tu3bratm3bX46Nj493dRAA9CO/hwsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYIL\nAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsA\nAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAEDnU6n062LN5vNbl0aAP6R\nWq3WVY93NbgAwJ+8UgaAAMEFgADBBYAAwQWAAMEFgADBBYAAwV2Fdrtdzz33XD388MM1MTFR33//\nfa8nsUJnz56tiYmJXs9ghS5dulT79++vRx55pHbu3FmffvppryfxN/3xxx81PT1du3btqt27d9c3\n33zT60lxgrsKJ06cqMXFxfrggw9q3759dfjw4V5PYgXefffdOnDgQF28eLHXU1ihjz76qG677bZ6\n//336+jRo/XSSy/1ehJ/06lTp6qqam5urvbu3VuvvfZajxflCe4qnD59urZu3VpVVVu2bKn5+fke\nL2Il7rrrrnrjjTd6PYNVuP/++2tqaqqqqjqdTg0ODvZ4EX/Xtm3brvwD6eeff65bbrmlx4vyGr0e\nsBadO3euhoaGrnw9ODhYly9frkbDH+dacN9999WPP/7Y6xmswk033VRVf/4d3LNnT+3du7fHi1iJ\nRqNRTz/9dH3yySf1+uuv93pOnCfcVRgaGqrz589f+brdbosthPzyyy/12GOP1fj4eD344IO9nsMK\nvfrqq3X8+PF69tln68KFC72eEyW4q3D33XfX559/XlVVZ86cqU2bNvV4EVwbfv3113riiSdq//79\ntXPnzl7PYQWOHTtWb7/9dlVV3XjjjTUwMFDr1l1bCfJYtgpjY2P1xRdf1K5du6rT6dTBgwd7PQmu\nCUeOHKnffvutZmdna3Z2tqr+/CG4G264ocfL+G/uvffemp6erkcffbQuX75czzzzzDV33/xvQQAQ\ncG09zwNAjwguAAQILgAECC4ABAguAAQILgAECC4ABAguAAT8B0u/pa2vpAxwAAAAAElFTkSuQmCC\n", 342 | "text/plain": [ 343 | "" 344 | ] 345 | }, 346 | "metadata": {}, 347 | "output_type": 
"display_data" 348 | } 349 | ], 350 | "source": [ 351 | "### RED = TERMINAL (0)\n", 352 | "### GREEN = LEFT\n", 353 | "### BLUE = UP\n", 354 | "### PURPLE = RIGHT\n", 355 | "### ORANGE = DOWN\n", 356 | "\n", 357 | "show_policy(pi_hat, gw)" 358 | ] 359 | } 360 | ], 361 | "metadata": { 362 | "kernelspec": { 363 | "display_name": "Python 3", 364 | "language": "python", 365 | "name": "python3" 366 | }, 367 | "language_info": { 368 | "codemirror_mode": { 369 | "name": "ipython", 370 | "version": 3 371 | }, 372 | "file_extension": ".py", 373 | "mimetype": "text/x-python", 374 | "name": "python", 375 | "nbconvert_exporter": "python", 376 | "pygments_lexer": "ipython3", 377 | "version": "3.6.1" 378 | } 379 | }, 380 | "nbformat": 4, 381 | "nbformat_minor": 2 382 | } 383 | -------------------------------------------------------------------------------- /source code/2-Monte Carlo Methods (Chapter 5)/.ipynb_checkpoints/4-off-policy-monte-carlo-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Off-policy Monte Carlo\n", 8 | "- Algorithms from ```pp. 
84 - 92``` in Sutton & Barto 2017\n", 9 | " - Off-policy prediction via importance sampling\n", 10 | " - Off-policy MC control" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 1, 16 | "metadata": { 17 | "collapsed": true 18 | }, 19 | "outputs": [], 20 | "source": [ 21 | "import matplotlib.pyplot as plt\n", 22 | "import pandas as pd\n", 23 | "import numpy as np\n", 24 | "import seaborn\n", 25 | "\n", 26 | "from gridWorldEnvironment import GridWorld" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 2, 32 | "metadata": { 33 | "collapsed": true 34 | }, 35 | "outputs": [], 36 | "source": [ 37 | "# creating gridworld environment\n", 38 | "gw = GridWorld(gamma = .9)" 39 | ] 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "metadata": {}, 44 | "source": [ 45 | "### Functions for generating policies" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 3, 51 | "metadata": { 52 | "collapsed": true 53 | }, 54 | "outputs": [], 55 | "source": [ 56 | "def generate_any_policy(env):\n", 57 | " pi = dict()\n", 58 | " for state in env.states:\n", 59 | " r = sorted(np.random.sample(3))\n", 60 | " actions = env.actions\n", 61 | " prob = [r[0], r[1] - r[0], r[2] - r[1], 1-r[2]]\n", 62 | " pi[state] = (actions, prob)\n", 63 | " return pi " 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 4, 69 | "metadata": { 70 | "collapsed": true 71 | }, 72 | "outputs": [], 73 | "source": [ 74 | "def generate_random_policy(env):\n", 75 | " pi = dict()\n", 76 | " for state in env.states:\n", 77 | " actions = []\n", 78 | " prob = []\n", 79 | " for action in env.actions:\n", 80 | " actions.append(action)\n", 81 | " prob.append(0.25)\n", 82 | " pi[state] = (actions, prob)\n", 83 | " return pi" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 5, 89 | "metadata": { 90 | "collapsed": true 91 | }, 92 | "outputs": [], 93 | "source": [ 94 | "def generate_greedy_policy(env, Q):\n", 95 | " pi =
dict()\n", 96 | " for state in env.states:\n", 97 | " actions = []\n", 98 | " q_values = []\n", 99 | " prob = []\n", 100 | " \n", 101 | " for a in env.actions:\n", 102 | " actions.append(a)\n", 103 | " q_values.append(Q[state,a]) \n", 104 | " for i in range(len(q_values)):\n", 105 | " if i == np.argmax(q_values):\n", 106 | " prob.append(1)\n", 107 | " else:\n", 108 | " prob.append(0) \n", 109 | " \n", 110 | " pi[state] = (actions, prob)\n", 111 | " return pi" 112 | ] 113 | }, 114 | { 115 | "cell_type": "markdown", 116 | "metadata": {}, 117 | "source": [ 118 | "### Create Action Values\n", 119 | "- Initialize all state-action values arbitrarily (here: uniform random in $[0, 1)$) and all cumulative weights $C$ with 0" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": 6, 125 | "metadata": { 126 | "collapsed": true 127 | }, 128 | "outputs": [], 129 | "source": [ 130 | "def state_action_value(env):\n", 131 | " q = dict()\n", 132 | " for state, action, next_state, reward in env.transitions:\n", 133 | " q[(state, action)] = np.random.random()\n", 134 | " return q" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": 7, 140 | "metadata": { 141 | "collapsed": true 142 | }, 143 | "outputs": [], 144 | "source": [ 145 | "def weight_cum_sum(env):\n", 146 | " c = dict()\n", 147 | " for state, action, next_state, reward in env.transitions:\n", 148 | " c[(state, action)] = 0\n", 149 | " return c" 150 | ] 151 | }, 152 | { 153 | "cell_type": "markdown", 154 | "metadata": {}, 155 | "source": [ 156 | "### Generate episode\n", 157 | "- Generate episode based on current policy ($\\pi$)" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": 8, 163 | "metadata": { 164 | "collapsed": true, 165 | "scrolled": true 166 | }, 167 | "outputs": [], 168 | "source": [ 169 | "def generate_episode(env, policy):\n", 170 | " episode = []\n", 171 | " done = False\n", 172 | " current_state = np.random.choice(env.states)\n", 173 | " action = np.random.choice(policy[current_state][0], p = 
policy[current_state][1])\n", 174 | " episode.append((current_state, action, -1))\n", 175 | " \n", 176 | " while not done:\n", 177 | " next_state, reward = env.state_transition(current_state, action)\n", 178 | " action = np.random.choice(policy[next_state][0], p = policy[next_state][1]) if next_state != 0 else None\n", 179 | " episode.append((next_state, action, reward))\n", 180 | " \n", 181 | " if next_state == 0:\n", 182 | " done = True\n", 183 | " current_state = next_state\n", 184 | " \n", 185 | " return episode" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": 9, 191 | "metadata": { 192 | "collapsed": true, 193 | "scrolled": true 194 | }, 195 | "outputs": [], 196 | "source": [ 197 | "pi = generate_random_policy(gw)" 198 | ] 199 | }, 200 | { 201 | "cell_type": "markdown", 202 | "metadata": {}, 203 | "source": [ 204 | "### Off-policy MC Prediction\n", 205 | "- Estimates $Q$ values" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": 10, 211 | "metadata": { 212 | "collapsed": true 213 | }, 214 | "outputs": [], 215 | "source": [ 216 | "def off_policy_mc_prediction(env, pi, num_iter):\n", 217 | " Q = state_action_value(env)\n", 218 | " C = weight_cum_sum(env)\n", 219 | " \n", 220 | " for _ in range(num_iter):\n", 221 | " b = generate_any_policy(env)\n", 222 | " episode = generate_episode(env, b)\n", 223 | " G = 0\n", 224 | " W = 1\n", 225 | " for i in range(len(episode)-1, -1, -1):\n", 226 | " s, a, _ = episode[i]\n", 227 | " if s != 0:\n", 228 | " G = env.gamma * G + episode[i + 1][-1]\n", 229 | " C[s,a] += W\n", 230 | " Q[s,a] += (W / C[s,a]) * (G - Q[s,a])\n", 231 | " W *= pi[s][1][pi[s][0].index(a)] / b[s][1][b[s][0].index(a)]\n", 232 | " if W == 0:\n", 233 | " break\n", 234 | " \n", 235 | " return Q" 236 | ] 237 | }, 238 | { 239 | "cell_type": "markdown", 240 | "metadata": {}, 241 | "source": [ 242 | "### Off-policy MC Control\n", 243 | "- Finds optimal policy $\\pi \\approx \\pi_*$" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 |
"execution_count": 12, 249 | "metadata": { 250 | "collapsed": true 251 | }, 252 | "outputs": [], 253 | "source": [ 254 | "def off_policy_mc_control(env, pi, num_iter):\n", 255 | " Q = state_action_value(env)\n", 256 | " C = weight_cum_sum(env)\n", 257 | " pi = generate_greedy_policy(env, Q)\n", 258 | " \n", 259 | " for _ in range(num_iter):\n", 260 | " b = generate_any_policy(env)\n", 261 | " episode = generate_episode(env, b)\n", 262 | " G = 0\n", 263 | " W = 1\n", 264 | " for i in range(len(episode)-1, -1, -1):\n", 265 | " s, a, _ = episode[i]\n", 266 | " if s != 0:\n", 267 | " G = env.gamma * G + episode[i + 1][-1]\n", 268 | " C[s,a] += W\n", 269 | " Q[s,a] += (W / C[s,a]) * (G - Q[s,a])\n", 270 | " pi = generate_greedy_policy(env, Q)\n", 271 | " if a != pi[s][0][np.argmax(pi[s][1])]:\n", 272 | " break\n", 273 | " W *= 1 / b[s][1][b[s][0].index(a)]\n", 274 | "\n", 275 | " return Q, pi" 276 | ] 277 | }, 278 | { 279 | "cell_type": "code", 280 | "execution_count": 21, 281 | "metadata": {}, 282 | "outputs": [ 283 | { 284 | "name": "stdout", 285 | "output_type": "stream", 286 | "text": [ 287 | "Wall time: 1.51 s\n" 288 | ] 289 | } 290 | ], 291 | "source": [ 292 | "%%time\n", 293 | "Q_hat, pi_hat = off_policy_mc_control(gw, generate_random_policy(gw), 1000)" 294 | ] 295 | }, 296 | { 297 | "cell_type": "code", 298 | "execution_count": 22, 299 | "metadata": {}, 300 | "outputs": [ 301 | { 302 | "data": { 303 | "text/plain": [ 304 | "{1: (['U', 'D', 'L', 'R'], [0, 0, 1, 0]),\n", 305 | " 2: (['U', 'D', 'L', 'R'], [0, 1, 0, 0]),\n", 306 | " 3: (['U', 'D', 'L', 'R'], [0, 1, 0, 0]),\n", 307 | " 4: (['U', 'D', 'L', 'R'], [1, 0, 0, 0]),\n", 308 | " 5: (['U', 'D', 'L', 'R'], [1, 0, 0, 0]),\n", 309 | " 6: (['U', 'D', 'L', 'R'], [0, 0, 0, 1]),\n", 310 | " 7: (['U', 'D', 'L', 'R'], [0, 0, 0, 1]),\n", 311 | " 8: (['U', 'D', 'L', 'R'], [0, 0, 0, 1]),\n", 312 | " 9: (['U', 'D', 'L', 'R'], [0, 0, 1, 0]),\n", 313 | " 10: (['U', 'D', 'L', 'R'], [0, 0, 0, 1]),\n", 314 | " 11: (['U', 'D', 'L', 'R'], [0, 1,
0, 0]),\n", 315 | " 12: (['U', 'D', 'L', 'R'], [0, 0, 1, 0]),\n", 316 | " 13: (['U', 'D', 'L', 'R'], [0, 0, 1, 0]),\n", 317 | " 14: (['U', 'D', 'L', 'R'], [0, 0, 0, 1])}" 318 | ] 319 | }, 320 | "execution_count": 22, 321 | "metadata": {}, 322 | "output_type": "execute_result" 323 | } 324 | ], 325 | "source": [ 326 | "# final policy obtained\n", 327 | "pi_hat" 328 | ] 329 | }, 330 | { 331 | "cell_type": "markdown", 332 | "metadata": {}, 333 | "source": [ 334 | "### Visualizing policy" 335 | ] 336 | }, 337 | { 338 | "cell_type": "code", 339 | "execution_count": 23, 340 | "metadata": { 341 | "collapsed": true 342 | }, 343 | "outputs": [], 344 | "source": [ 345 | "def show_policy(pi, env):\n", 346 | " temp = np.zeros(len(env.states) + 2)\n", 347 | " for s in env.states:\n", 348 | " a = pi[s][0][np.argmax(pi[s][1])]\n", 349 | " if a == \"U\":\n", 350 | " temp[s] = 0.25\n", 351 | " elif a == \"D\":\n", 352 | " temp[s] = 0.5\n", 353 | " elif a == \"R\":\n", 354 | " temp[s] = 0.75\n", 355 | " else:\n", 356 | " temp[s] = 1.0\n", 357 | " \n", 358 | " temp = temp.reshape(4,4)\n", 359 | " ax = seaborn.heatmap(temp, cmap = \"prism\", linecolor=\"#282828\", cbar = False, linewidths = 0.1)\n", 360 | " plt.show()" 361 | ] 362 | }, 363 | { 364 | "cell_type": "code", 365 | "execution_count": 24, 366 | "metadata": {}, 367 | "outputs": [ 368 | { 369 | "data": { 370 | "image/png":
"iVBORw0KGgoAAAANSUhEUgAAAdwAAAFJCAYAAAAxCJwFAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAACsZJREFUeJzt3UtoXIXfxvFfmhEV4wUsSBGEbtoMCC0u3HURaLwUJFQK\ntkoUJCsXTaEWGVMveGnrSvAS6mXlQuPK4kIo1hYFBReFFgITRQTxtnElbbFpnfkv5N8XX9q8Ji/z\njJl+PrucQ8489DB8O4eEDHW73W4BAD21pt8DAOBqILgAECC4ABAguAAQILgAECC4ABDQ6OXFm81m\ntRcWevkS9EhzdLTubLt3q9V8c7TaH7l/q1Fz+2hVlfu3SjW3j1a73b7sOZ9wASBAcAEgQHABIEBw\nASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHAB\nIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEg\nQHABIEBwASDgHwe30+n0cgcADLTGUid//PHHOnjwYM3Pz1ej0ahOp1MbNmyoVqtV69evT20EgFVv\nyeDOzMzU3r17a9OmTZeOnTp1qlqtVs3NzfV8HAAMiiUfKS8uLv4ttlVVmzdv7ukgABhES37C3bhx\nY7VardqyZUvdeOONdfbs2fr8889r48aNqX0AMBCWDO7zzz9fx44dq5MnT9aZM2dqZGSkxsbGanx8\nPLUPAAbCksEdGhqq8fFxgQWA/ye/hwsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwAB\nggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGC\nCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAFD3W6326uLN5vNXl0aAP6V2u32ZY83\nev3CC5ve7fVL0AOjp6fcu1Vs9PRUzSx80O8ZrMDLo7uqqty/Veq/9+9yPFIGgADBBYAAwQWAAMEF\ngADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWA\nAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAA\nwQWAAMEFgADBBYAAwQWAgMZSJycnJ+vChQt/O9btdmtoaKjm5uZ6OgwABsmSwX3yySdr//799eab\nb9bw8HBqEwAMnCWDu2nTppqYmKhvvvmmxsfHU5sAYOAsGdyqqqmpqcQOABhofmgKAAIEFwACBBcA\nAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwAC\nBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIE\nFwACBBcAAoa63W63VxdvNpu9ujQA/Cu12+3LHm/0+oVnFj7o9UvQAy+P7qo72wv9nsEKzTdHvfdW\nqZdHd1VVVfsj77/VqLl99IrnPFIGgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEF\ngADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWA\nAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAgGUHd3FxsRc7AGCgXTG4x48fr7Gx\nsRofH69PPvnk0vGpqanIMAAYJI0rnTh8+HAdOXKkOp1OTU9P1/nz52v79u3V7XaT+wBgIFwxuNdc\nc03dfPPNVVU1Oztbjz32WK1bt66GhoZi4wBgUFzxkfLtt99eB
w8erHPnztXIyEi98cYb9cILL9T3\n33+f3AcAA+GKwT1w4EBt3Ljx0ifadevW1XvvvVf3339/bBwADIorPlJuNBr14IMP/u3Y2rVra2Zm\npuejAGDQ+D1cAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBc\nAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwA\nCBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBjqdrvdXl282Wz26tIA8K/Ubrcve7zR6xe+s73Q65eg\nB+abo+7dKjbfHK2ZhQ/6PYMVeHl0V1VVtRe8/1aj5ujoFc95pAwAAYILAAGCCwABggsAAYILAAGC\nCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYIL\nAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABywru\nH3/8UYuLi73aAgADa8ngfvfdd/XEE09Uq9Wqr776qrZt21bbtm2rEydOpPYBwEBoLHXyueeeq+np\n6fr5559r9+7ddfTo0br22mtramqqxsbGUhsBYNVbMridTqfuvvvuqqr6+uuv69Zbb/3rmxpLfhsA\n8L8s+Uh5/fr1NTMzU51Opw4dOlRVVW+//XatXbs2Mg4ABsWSH1VfeumlOn78eK1Z8z9dvu2222py\ncrLnwwBgkCwZ3DVr1tTWrVv/dmxiYqKngwBgEPk9XAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAI\nEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQ\nXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBjqdrvdXl282Wz26tIA\n8K/Ubrcve7ynwQUA/uKRMgAECC4ABAguAAQILgAECC4ABAguAAQI7gp0Op169tln66GHHqrJycn6\n4Ycf+j2JZTp9+nRNTk72ewbLdOHChdq3b189/PDDtWPHjvrss8/6PYl/6M8//6xWq1U7d+6sXbt2\n1bffftvvSXGCuwLHjh2rxcXF+vDDD2vv3r116NChfk9iGd55553av39/nT9/vt9TWKaPP/64brnl\nlnr//ffr3XffrRdffLHfk/iHTpw4UVVVc3NztWfPnnr11Vf7vChPcFfg5MmTtWXLlqqq2rx5c83P\nz/d5Ectxxx131Ouvv97vGazAfffdV9PT01VV1e12a3h4uM+L+Ke2bt166T9Iv/zyS9100019XpTX\n6PeA1ejMmTM1MjJy6evh4eG6ePFiNRr+OVeDe++9t3766ad+z2AFbrjhhqr66z24e/fu2rNnT58X\nsRyNRqOeeuqp+vTTT+u1117r95w4n3BXYGRkpM6ePXvp606nI7YQ8uuvv9ajjz5aExMT9cADD/R7\nDsv0yiuv1NGjR+uZZ56pc+fO9XtOlOCuwF133VVffPFFVVWdOnWqNmzY0OdFcHX47bff6vHHH699\n+/bVjh07+j2HZThy5Ei99dZbVVV1/fXX19DQUK1Zc3UlyMeyFRgfH68vv/yydu7cWd1utw4cONDv\nSXBVOHz4cP3+++81Oztbs7OzVfXXD8Fdd911fV7G/+Wee+6pVqtVjzzySF28eLGefvrpq+6++WtB\nABBwdX2eB4A+EVwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACPgPqwelrdFfLd8AAAAASUVORK5C\nYII=\n", 371 | "text/plain": [ 372 | "" 373 | ] 374 | }, 375 | "metadata": {}, 376 | 
"output_type": "display_data" 377 | } 378 | ], 379 | "source": [ 380 | "### RED = TERMINAL (0)\n", 381 | "### GREEN = LEFT\n", 382 | "### BLUE = UP\n", 383 | "### PURPLE = RIGHT\n", 384 | "### ORANGE = DOWN\n", 385 | "\n", 386 | "show_policy(pi_hat, gw)" 387 | ] 388 | } 389 | ], 390 | "metadata": { 391 | "kernelspec": { 392 | "display_name": "Python 3", 393 | "language": "python", 394 | "name": "python3" 395 | }, 396 | "language_info": { 397 | "codemirror_mode": { 398 | "name": "ipython", 399 | "version": 3 400 | }, 401 | "file_extension": ".py", 402 | "mimetype": "text/x-python", 403 | "name": "python", 404 | "nbconvert_exporter": "python", 405 | "pygments_lexer": "ipython3", 406 | "version": "3.6.1" 407 | } 408 | }, 409 | "nbformat": 4, 410 | "nbformat_minor": 2 411 | } 412 | -------------------------------------------------------------------------------- /source code/2-Monte Carlo Methods (Chapter 5)/3-on-policy-monte-carlo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# On-policy Monte Carlo\n", 8 | "- Algorithms from ```pp. 
82 - 84``` in Sutton & Barto 2017\n", 9 | "- Estimates optimal policy $\\pi \\approx \\pi_*$" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 34, 15 | "metadata": { 16 | "collapsed": true 17 | }, 18 | "outputs": [], 19 | "source": [ 20 | "import matplotlib.pyplot as plt\n", 21 | "import pandas as pd\n", 22 | "import numpy as np\n", 23 | "import seaborn\n", 24 | "\n", 25 | "from gridWorldEnvironment import GridWorld" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 35, 31 | "metadata": { 32 | "collapsed": true 33 | }, 34 | "outputs": [], 35 | "source": [ 36 | "# creating gridworld environment\n", 37 | "gw = GridWorld(gamma = .9)" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": {}, 43 | "source": [ 44 | "### Generate policy\n", 45 | "- We start with equiprobable random policy" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 36, 51 | "metadata": { 52 | "collapsed": true 53 | }, 54 | "outputs": [], 55 | "source": [ 56 | "def generate_random_policy(env):\n", 57 | " pi = dict()\n", 58 | " for state in env.states:\n", 59 | " actions = []\n", 60 | " prob = []\n", 61 | " for action in env.actions:\n", 62 | " actions.append(action)\n", 63 | " prob.append(0.25)\n", 64 | " pi[state] = (actions, prob)\n", 65 | " return pi" 66 | ] 67 | }, 68 | { 69 | "cell_type": "markdown", 70 | "metadata": {}, 71 | "source": [ 72 | "### Create Action Values\n", 73 | "- Initialize all state-action values with 0" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": 37, 79 | "metadata": { 80 | "collapsed": true 81 | }, 82 | "outputs": [], 83 | "source": [ 84 | "def state_action_value(env):\n", 85 | " q = dict()\n", 86 | " for state, action, next_state, reward in env.transitions:\n", 87 | " q[(state, action)] = 0\n", 88 | " return q" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 38, 94 | "metadata": {}, 95 | "outputs": [ 96 | { 97 | "name": "stdout", 98 | 
"output_type": "stream", 99 | "text": [ 100 | "0\n", 101 | "0\n", 102 | "0\n" 103 | ] 104 | } 105 | ], 106 | "source": [ 107 | "q = state_action_value(gw)\n", 108 | "\n", 109 | "print(q[(2, \"U\")])\n", 110 | "print(q[(4, \"L\")])\n", 111 | "print(q[(10, \"R\")])" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": 39, 117 | "metadata": { 118 | "collapsed": true, 119 | "scrolled": true 120 | }, 121 | "outputs": [], 122 | "source": [ 123 | "pi = generate_random_policy(gw)" 124 | ] 125 | }, 126 | { 127 | "cell_type": "markdown", 128 | "metadata": {}, 129 | "source": [ 130 | "### Generate episode\n", 131 | "- Generate an episode by following the current policy ($\\pi$). Each entry is `(state, action, reward received on entering that state)`; the start state gets a placeholder reward of -1" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": 40, 137 | "metadata": { 138 | "collapsed": true, 139 | "scrolled": true 140 | }, 141 | "outputs": [], 142 | "source": [ 143 | "def generate_episode(env, policy):\n", 144 | " episode = []\n", 145 | " done = False\n", 146 | " current_state = np.random.choice(env.states)\n", 147 | " action = np.random.choice(policy[current_state][0], p = policy[current_state][1])\n", 148 | " episode.append((current_state, action, -1)) # -1 is a placeholder reward for the start state\n", 149 | " \n", 150 | " while not done:\n", 151 | " next_state, reward = env.state_transition(current_state, action)\n", 152 | " done = next_state == 0 # state 0 is the terminal state\n", 153 | " # draw the next action from the policy at next_state\n", 154 | " # (the terminal state has no policy entry, so no action is taken there)\n", 155 | " action = None if done else np.random.choice(policy[next_state][0], p = policy[next_state][1])\n", 156 | " episode.append((next_state, action, reward))\n", 157 | " current_state = next_state\n", 158 | " \n", 159 | " return episode" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": 41, 165 | "metadata": {}, 166 | "outputs": [ 167 | { 168 | "data": { 169 | "text/plain": [ 170 | "[(6, 'U', -1),\n", 171 | " (2, 'D', -1),\n", 172 | " (6, 'R', -1),\n", 173 | " (7, 'U', -1),\n", 174 | " (3, 'D', -1),\n", 175 | " (7, 'D', -1),\n", 176 | " (11, 'U', -1),\n", 177 | " (7, 'U', -1),\n", 178 | " (3, 'L', 
-1),\n", 179 | " (2, 'D', -1),\n", 180 | " (6, 'D', -1),\n", 181 | " (10, 'D', -1),\n", 182 | " (14, 'R', -1),\n", 183 | " (0, 'R', -1)]" 184 | ] 185 | }, 186 | "execution_count": 41, 187 | "metadata": {}, 188 | "output_type": "execute_result" 189 | } 190 | ], 191 | "source": [ 192 | "generate_episode(gw, pi)" 193 | ] 194 | }, 195 | { 196 | "cell_type": "markdown", 197 | "metadata": {}, 198 | "source": [ 199 | "### On-policy Monte Carlo Control" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": 42, 205 | "metadata": { 206 | "collapsed": true 207 | }, 208 | "outputs": [], 209 | "source": [ 210 | "# first-visit MC control for epsilon-soft policies\n", 211 | "def on_policy_mc(env, pi, e, num_iter):\n", 212 | " Q = state_action_value(env)\n", 213 | " returns = dict()\n", 214 | " for s, a in Q:\n", 215 | " returns[(s,a)] = []\n", 216 | " \n", 217 | " for _ in range(num_iter):\n", 218 | " episode = generate_episode(env, pi)\n", 219 | " already_visited = set() # first visits are tracked per (state, action) pair\n", 220 | " for idx, (s, a, r) in enumerate(episode):\n", 221 | " if s != 0 and (s, a) not in already_visited:\n", 222 | " already_visited.add((s, a))\n", 223 | " # G = R_{t+1} + gamma * R_{t+2} + ... (rewards that follow this step)\n", 224 | " G = 0\n", 225 | " j = 1\n", 226 | " while j + idx < len(episode):\n", 227 | " G += env.gamma ** (j - 1) * episode[j + idx][-1]\n", 228 | " j += 1\n", 229 | " returns[(s,a)].append(G)\n", 230 | " Q[(s,a)] = np.mean(returns[(s,a)])\n", 231 | " for s, _, _ in episode: # epsilon-greedy policy improvement at visited states\n", 232 | " if s != 0:\n", 233 | " actions = []\n", 234 | " action_values = []\n", 235 | " prob = []\n", 236 | "\n", 237 | " for a in env.actions:\n", 238 | " actions.append(a)\n", 239 | " action_values.append(Q[s,a])\n", 240 | " for k in range(len(action_values)):\n", 241 | " if k == np.argmax(action_values):\n", 242 | " prob.append(1 - e + e/len(actions))\n", 243 | " else:\n", 244 | " prob.append(e/len(actions))\n", 245 | " pi[s] = (actions, prob)\n", 246 | " return Q, pi" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": 43, 252 | "metadata": {}, 253 | "outputs": [ 254 
| { 255 | "name": "stdout", 256 | "output_type": "stream", 257 | "text": [ 258 | "Wall time: 1min 8s\n" 259 | ] 260 | } 261 | ], 262 | "source": [ 263 | "# Obtain estimates of Q & pi from 100000 episodes\n", 264 | "Q_hat, pi_hat = on_policy_mc(gw, generate_random_policy(gw), 0.2, 100000)" 265 | ] 266 | }, 267 | { 268 | "cell_type": "code", 269 | "execution_count": 44, 270 | "metadata": {}, 271 | "outputs": [ 272 | { 273 | "data": { 274 | "text/plain": [ 275 | "{1: (['U', 'D', 'L', 'R'], [0.05, 0.05, 0.8500000000000001, 0.05]),\n", 276 | " 2: (['U', 'D', 'L', 'R'], [0.05, 0.05, 0.8500000000000001, 0.05]),\n", 277 | " 3: (['U', 'D', 'L', 'R'], [0.05, 0.8500000000000001, 0.05, 0.05]),\n", 278 | " 4: (['U', 'D', 'L', 'R'], [0.8500000000000001, 0.05, 0.05, 0.05]),\n", 279 | " 5: (['U', 'D', 'L', 'R'], [0.8500000000000001, 0.05, 0.05, 0.05]),\n", 280 | " 6: (['U', 'D', 'L', 'R'], [0.05, 0.05, 0.8500000000000001, 0.05]),\n", 281 | " 7: (['U', 'D', 'L', 'R'], [0.05, 0.8500000000000001, 0.05, 0.05]),\n", 282 | " 8: (['U', 'D', 'L', 'R'], [0.8500000000000001, 0.05, 0.05, 0.05]),\n", 283 | " 9: (['U', 'D', 'L', 'R'], [0.05, 0.05, 0.05, 0.8500000000000001]),\n", 284 | " 10: (['U', 'D', 'L', 'R'], [0.05, 0.05, 0.05, 0.8500000000000001]),\n", 285 | " 11: (['U', 'D', 'L', 'R'], [0.05, 0.8500000000000001, 0.05, 0.05]),\n", 286 | " 12: (['U', 'D', 'L', 'R'], [0.05, 0.05, 0.05, 0.8500000000000001]),\n", 287 | " 13: (['U', 'D', 'L', 'R'], [0.05, 0.05, 0.05, 0.8500000000000001]),\n", 288 | " 14: (['U', 'D', 'L', 'R'], [0.05, 0.05, 0.05, 0.8500000000000001])}" 289 | ] 290 | }, 291 | "execution_count": 44, 292 | "metadata": {}, 293 | "output_type": "execute_result" 294 | } 295 | ], 296 | "source": [ 297 | "# final policy obtained\n", 298 | "pi_hat" 299 | ] 300 | }, 301 | { 302 | "cell_type": "markdown", 303 | "metadata": {}, 304 | "source": [ 305 | "### Visualizing the policy" 306 | ] 307 | }, 308 | { 309 | "cell_type": "code", 310 | "execution_count": 45, 311 | "metadata": { 312 | 
"collapsed": true 313 | }, 314 | "outputs": [], 315 | "source": [ 316 | "def show_policy(pi, env):\n", 317 | " temp = np.zeros(len(env.states) + 2)\n", 318 | " for s in env.states:\n", 319 | " a = pi_hat[s][0][np.argmax(pi_hat[s][1])]\n", 320 | " if a == \"U\":\n", 321 | " temp[s] = 0.25\n", 322 | " elif a == \"D\":\n", 323 | " temp[s] = 0.5\n", 324 | " elif a == \"R\":\n", 325 | " temp[s] = 0.75\n", 326 | " else:\n", 327 | " temp[s] = 1.0\n", 328 | " \n", 329 | " temp = temp.reshape(4,4)\n", 330 | " ax = seaborn.heatmap(temp, cmap = \"prism\", linecolor=\"#282828\", cbar = False, linewidths = 0.1)\n", 331 | " plt.show()" 332 | ] 333 | }, 334 | { 335 | "cell_type": "code", 336 | "execution_count": 46, 337 | "metadata": {}, 338 | "outputs": [ 339 | { 340 | "data": { 341 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAdwAAAFJCAYAAAAxCJwFAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAACsRJREFUeJzt3UtoXAX/xvFfmhEV4wUsSBGEbpoMvNDiwl0XgcZLQUKl\nYKtEQbJy0RRqkZh6wUtbV4KXULUrF75xZXEhFGuLgoKLQguBiSKCeNu4krbYtM68C6F//NPkNXmZ\nZ8z089nlHHLmoYfy7TkkdKDT6XQKAOiqdb0eAADXAsEFgADBBYAAwQWAAMEFgADBBYCARjcv3mw2\nq7Ww0M2PoEuaIyP1r5Z7t1bNN92/tWq+OVJVVa0P3b+1qLljpFqt1lXPecIFgADBBYAAwQWAAMEF\ngADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWA\nAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAA\nwQWAAMEFgIC/Hdx2u93NHQDQ1xrLnfzhhx/q0KFDNT8/X41Go9rtdm3atKmmp6dr48aNqY0AsOYt\nG9yZmZnat29fbd68+cqxM2fO1PT0dM3NzXV9HAD0i2VfKS8uLv4ltlVVW7Zs6eogAOhHyz7hDg8P\n1/T0dG3durVuvvnmOn/+fH322Wc1PDyc2gcAfWHZ4L7wwgt14sSJOn36dJ07d66GhoZqdHS0xsbG\nUvsAoC8sG9yBgYEaGxsTWAD4H/k9XAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQ\nXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBc\nAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBjodDqdbl282Wx269IA8I/UarWuerzR\n7Q9e2Hy02x9BF4ycnXTv1rCRs5P1r9ZCr2ewCvPNkaqqan3o/q1FzR0jS57zShkAAgQXAAIEFwAC\nBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIE\nFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFw
ACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQX\nAAIEFwACBBcAAgQXAAIEFwACGsudnJiYqEuXLv3lWKfTqYGBgZqbm+vqMADoJ8sG96mnnqoDBw7U\nW2+9VYODg6lNANB3lg3u5s2ba3x8vL7++usaGxtLbQKAvrNscKuqJicnEzsAoK/5oSkACBBcAAgQ\nXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBc\nAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwA\nCBBcAAgQXAAIGOh0Op1uXbzZbHbr0gDwj9Rqta56vNHtD17YfLTbH0EXjJydrJmFf/d6Bqv0yshu\n92+NemVkd1VVtT5c6PESVqO5Y2TJc14pA0CA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOAC\nQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJA\ngOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQMCKg7u4uNiNHQDQ15YM7smT\nJ2t0dLTGxsbq448/vnJ8cnIyMgwA+kljqRNHjhypY8eOVbvdrqmpqbp48WLt2LGjOp1Och8A9IUl\ng3vdddfVrbfeWlVVs7Oz9fjjj9eGDRtqYGAgNg4A+sWSr5TvvPPOOnToUF24cKGGhobqzTffrBdf\nfLG+++675D4A6AtLBvfgwYM1PDx85Yl2w4YN9d5779UDDzwQGwcA/WLJV8qNRqMeeuihvxxbv359\nzczMdH0UAPQbv4cLAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsA\nAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwAB\nggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAQOdTqfTrYs3m81uXRoA/pFardZVjze6/cEzC//u\n9kfQBa+M7Hbv1jD3b+16ZWR3VVW1FhZ6vITVaI6MLHnOK2UACBBcAAgQXAAIEFwACBBcAAgQXAAI\nEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQ\nXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIWFFwf//9\n91pcXOzWFgDoW8sG99tvv60nn3yypqen68svv6zt27fX9u3b69SpU6l9ANAXGsudfP7552tqaqp+\n+umn2rNnTx0/fryuv/76mpycrNHR0dRGAFjzlg1uu92ue+65p6qqvvrqq7r99tv//KbGst8GAPw/\ny75S3rhxY83MzFS73a7Dhw9XVdU777xT69evj4wDgH6x7KPqyy+/XCdPnqx16/6vy3fccUdNTEx0\nfRgA9JNlg7tu3bratm3bX46Nj493dRAA9CO/hwsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYIL\nAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsA\nAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAEDnU6n062LN5vNbl0aAP6R\nWq3WVY93NbgAwJ+8UgaAAMEFgADBBYAAwQWAAMEFgADBBYAAwV2Fdrtdzz33XD388MM1MTFR33//\nfa8nsUJnz56tiYmJXs9ghS5dulT79++vRx55pHbu3FmffvppryfxN/3xxx81PT1du3btqt27d9c3\n33zT60
lxgrsKJ06cqMXFxfrggw9q3759dfjw4V5PYgXefffdOnDgQF28eLHXU1ihjz76qG677bZ6\n//336+jRo/XSSy/1ehJ/06lTp6qqam5urvbu3VuvvfZajxflCe4qnD59urZu3VpVVVu2bKn5+fke\nL2Il7rrrrnrjjTd6PYNVuP/++2tqaqqqqjqdTg0ODvZ4EX/Xtm3brvwD6eeff65bbrmlx4vyGr0e\nsBadO3euhoaGrnw9ODhYly9frkbDH+dacN9999WPP/7Y6xmswk033VRVf/4d3LNnT+3du7fHi1iJ\nRqNRTz/9dH3yySf1+uuv93pOnCfcVRgaGqrz589f+brdbosthPzyyy/12GOP1fj4eD344IO9nsMK\nvfrqq3X8+PF69tln68KFC72eEyW4q3D33XfX559/XlVVZ86cqU2bNvV4EVwbfv3113riiSdq//79\ntXPnzl7PYQWOHTtWb7/9dlVV3XjjjTUwMFDr1l1bCfJYtgpjY2P1xRdf1K5du6rT6dTBgwd7PQmu\nCUeOHKnffvutZmdna3Z2tqr+/CG4G264ocfL+G/uvffemp6erkcffbQuX75czzzzzDV33/xvQQAQ\ncG09zwNAjwguAAQILgAECC4ABAguAAQILgAECC4ABAguAAT8B0u/pa2vpAxwAAAAAElFTkSuQmCC\n", 342 | "text/plain": [ 343 | "" 344 | ] 345 | }, 346 | "metadata": {}, 347 | "output_type": "display_data" 348 | } 349 | ], 350 | "source": [ 351 | "### RED = TERMINAL (0)\n", 352 | "### GREEN = LEFT\n", 353 | "### BLUE = UP\n", 354 | "### PURPLE = RIGHT\n", 355 | "### ORANGE = DOWN\n", 356 | "\n", 357 | "show_policy(pi_hat, gw)" 358 | ] 359 | } 360 | ], 361 | "metadata": { 362 | "kernelspec": { 363 | "display_name": "Python 3", 364 | "language": "python", 365 | "name": "python3" 366 | }, 367 | "language_info": { 368 | "codemirror_mode": { 369 | "name": "ipython", 370 | "version": 3 371 | }, 372 | "file_extension": ".py", 373 | "mimetype": "text/x-python", 374 | "name": "python", 375 | "nbconvert_exporter": "python", 376 | "pygments_lexer": "ipython3", 377 | "version": "3.6.1" 378 | } 379 | }, 380 | "nbformat": 4, 381 | "nbformat_minor": 2 382 | } 383 | -------------------------------------------------------------------------------- /source code/2-Monte Carlo Methods (Chapter 5)/4-off-policy-monte-carlo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Off-policy Monte Carlo\n", 8 | "- Algorithms from ```pp. 
84 - 92``` in Sutton & Barto 2017\n", 9 | " - Off-policy prediction via importance sampling\n", 10 | " - Off-policy MC control" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 1, 16 | "metadata": { 17 | "collapsed": true 18 | }, 19 | "outputs": [], 20 | "source": [ 21 | "import matplotlib.pyplot as plt\n", 22 | "import pandas as pd\n", 23 | "import numpy as np\n", 24 | "import seaborn\n", 25 | "\n", 26 | "from gridWorldEnvironment import GridWorld" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 2, 32 | "metadata": { 33 | "collapsed": true 34 | }, 35 | "outputs": [], 36 | "source": [ 37 | "# creating gridworld environment\n", 38 | "gw = GridWorld(gamma = .9)" 39 | ] 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "metadata": {}, 44 | "source": [ 45 | "### Functions for generating policies" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 3, 51 | "metadata": { 52 | "collapsed": true 53 | }, 54 | "outputs": [], 55 | "source": [ 56 | "def generate_any_policy(env): # random behavior policy; every action almost surely gets nonzero probability\n", 57 | " pi = dict()\n", 58 | " for state in env.states:\n", 59 | " r = sorted(np.random.sample(3)) # three sorted cut points split [0, 1] into four action probabilities\n", 60 | " actions = env.actions\n", 61 | " prob = [r[0], r[1] - r[0], r[2] - r[1], 1-r[2]]\n", 62 | " pi[state] = (actions, prob)\n", 63 | " return pi" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 4, 69 | "metadata": { 70 | "collapsed": true 71 | }, 72 | "outputs": [], 73 | "source": [ 74 | "def generate_random_policy(env):\n", 75 | " pi = dict()\n", 76 | " for state in env.states:\n", 77 | " actions = []\n", 78 | " prob = []\n", 79 | " for action in env.actions:\n", 80 | " actions.append(action)\n", 81 | " prob.append(0.25)\n", 82 | " pi[state] = (actions, prob)\n", 83 | " return pi" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 5, 89 | "metadata": { 90 | "collapsed": true 91 | }, 92 | "outputs": [], 93 | "source": [ 94 | "def generate_greedy_policy(env, Q):\n", 95 | " pi = 
dict()\n", 96 | " for state in env.states:\n", 97 | " actions = []\n", 98 | " q_values = []\n", 99 | " prob = []\n", 100 | " \n", 101 | " for a in env.actions:\n", 102 | " actions.append(a)\n", 103 | " q_values.append(Q[state,a]) \n", 104 | " for i in range(len(q_values)):\n", 105 | " if i == np.argmax(q_values):\n", 106 | " prob.append(1)\n", 107 | " else:\n", 108 | " prob.append(0) \n", 109 | " \n", 110 | " pi[state] = (actions, prob)\n", 111 | " return pi" 112 | ] 113 | }, 114 | { 115 | "cell_type": "markdown", 116 | "metadata": {}, 117 | "source": [ 118 | "### Create Action Values\n", 119 | "- Initialize all state-action values with random numbers in $[0, 1)$ (the algorithms only require an arbitrary initial $Q$)" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": 6, 125 | "metadata": { 126 | "collapsed": true 127 | }, 128 | "outputs": [], 129 | "source": [ 130 | "def state_action_value(env):\n", 131 | " q = dict()\n", 132 | " for state, action, next_state, reward in env.transitions:\n", 133 | " q[(state, action)] = np.random.random()\n", 134 | " return q" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": 7, 140 | "metadata": { 141 | "collapsed": true 142 | }, 143 | "outputs": [], 144 | "source": [ 145 | "def weight_cum_sum(env): # C(s, a): cumulative sum of importance-sampling weights\n", 146 | " c = dict()\n", 147 | " for state, action, next_state, reward in env.transitions:\n", 148 | " c[(state, action)] = 0\n", 149 | " return c" 150 | ] 151 | }, 152 | { 153 | "cell_type": "markdown", 154 | "metadata": {}, 155 | "source": [ 156 | "### Generate episode\n", 157 | "- Generate an episode by following a given policy (the behavior policy $b$ in the algorithms below)" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": 8, 163 | "metadata": { 164 | "collapsed": true, 165 | "scrolled": true 166 | }, 167 | "outputs": [], 168 | "source": [ 169 | "def generate_episode(env, policy):\n", 170 | " episode = []\n", 171 | " done = False\n", 172 | " current_state = np.random.choice(env.states)\n", 173 | " action = np.random.choice(policy[current_state][0], p = 
policy[current_state][1])\n", 174 | " episode.append((current_state, action, -1)) # -1 is a placeholder reward for the start state\n", 175 | " \n", 176 | " while not done:\n", 177 | " next_state, reward = env.state_transition(current_state, action)\n", 178 | " done = next_state == 0 # state 0 is the terminal state\n", 179 | " # draw the next action from the policy at next_state\n", 180 | " # (the terminal state has no policy entry, so no action is taken there)\n", 181 | " action = None if done else np.random.choice(policy[next_state][0], p = policy[next_state][1])\n", 182 | " episode.append((next_state, action, reward))\n", 183 | " current_state = next_state\n", 184 | " \n", 185 | " return episode" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": 9, 191 | "metadata": { 192 | "collapsed": true, 193 | "scrolled": true 194 | }, 195 | "outputs": [], 196 | "source": [ 197 | "pi = generate_random_policy(gw)" 198 | ] 199 | }, 200 | { 201 | "cell_type": "markdown", 202 | "metadata": {}, 203 | "source": [ 204 | "### Off-policy MC Prediction\n", 205 | "- Estimates $Q \\approx q_\\pi$ for the target policy $\\pi$ with weighted importance sampling, using episodes generated by a random behavior policy $b$" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": 10, 211 | "metadata": { 212 | "collapsed": true 213 | }, 214 | "outputs": [], 215 | "source": [ 216 | "def off_policy_mc_prediction(env, pi, num_iter):\n", 217 | " Q = state_action_value(env)\n", 218 | " C = weight_cum_sum(env)\n", 219 | " \n", 220 | " for _ in range(num_iter):\n", 221 | " b = generate_any_policy(env)\n", 222 | " episode = generate_episode(env, b)\n", 223 | " G = 0\n", 224 | " W = 1\n", 225 | " for i in range(len(episode)-2, -1, -1): # the last entry is the terminal state\n", 226 | " s, a, _ = episode[i]\n", 227 | " r = episode[i + 1][-1] # R_{t+1}: reward received after taking a in s\n", 228 | " G = env.gamma * G + r\n", 229 | " C[s,a] += W\n", 230 | " Q[s,a] += (W / C[s,a]) * (G - Q[s,a])\n", 231 | " W *= pi[s][1][pi[s][0].index(a)] / b[s][1][b[s][0].index(a)]\n", 232 | " if W == 0:\n", 233 | " break\n", 234 | " \n", 235 | " return Q" 236 | ] 237 | }, 238 | { 239 | "cell_type": "markdown", 240 | "metadata": {}, 241 | "source": [ 242 | "### Off-policy MC Control\n", 243 | "- Finds optimal policy $\\pi \\approx \\pi_*$" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | 
"execution_count": 12, 249 | "metadata": { 250 | "collapsed": true 251 | }, 252 | "outputs": [], 253 | "source": [ 254 | "def off_policy_mc_control(env, pi, num_iter):\n", 255 | " Q = state_action_value(env)\n", 256 | " C = weight_cum_sum(env)\n", 257 | " pi = generate_greedy_policy(env, Q) # the target policy is greedy w.r.t. Q (the pi argument is not used)\n", 258 | " \n", 259 | " for _ in range(num_iter):\n", 260 | " b = generate_any_policy(env)\n", 261 | " episode = generate_episode(env, b)\n", 262 | " G = 0\n", 263 | " W = 1\n", 264 | " for i in range(len(episode)-2, -1, -1): # the last entry is the terminal state\n", 265 | " s, a, _ = episode[i]\n", 266 | " r = episode[i + 1][-1] # R_{t+1}: reward received after taking a in s\n", 267 | " G = env.gamma * G + r\n", 268 | " C[s,a] += W\n", 269 | " Q[s,a] += (W / C[s,a]) * (G - Q[s,a])\n", 270 | " pi = generate_greedy_policy(env, Q)\n", 271 | " if a != pi[s][0][np.argmax(pi[s][1])]: # b took a non-greedy action, so pi(a|s) = 0\n", 272 | " break\n", 273 | " W *= 1 / b[s][1][b[s][0].index(a)] # pi(a|s) = 1 for the greedy action\n", 274 | "\n", 275 | " return Q, pi" 276 | ] 277 | }, 278 | { 279 | "cell_type": "code", 280 | "execution_count": 21, 281 | "metadata": {}, 282 | "outputs": [ 283 | { 284 | "name": "stdout", 285 | "output_type": "stream", 286 | "text": [ 287 | "Wall time: 1.51 s\n" 288 | ] 289 | } 290 | ], 291 | "source": [ 292 | "%%time\n", 293 | "Q_hat, pi_hat = off_policy_mc_control(gw, generate_random_policy(gw), 1000)" 294 | ] 295 | }, 296 | { 297 | "cell_type": "code", 298 | "execution_count": 22, 299 | "metadata": {}, 300 | "outputs": [ 301 | { 302 | "data": { 303 | "text/plain": [ 304 | "{1: (['U', 'D', 'L', 'R'], [0, 0, 1, 0]),\n", 305 | " 2: (['U', 'D', 'L', 'R'], [0, 1, 0, 0]),\n", 306 | " 3: (['U', 'D', 'L', 'R'], [0, 1, 0, 0]),\n", 307 | " 4: (['U', 'D', 'L', 'R'], [1, 0, 0, 0]),\n", 308 | " 5: (['U', 'D', 'L', 'R'], [1, 0, 0, 0]),\n", 309 | " 6: (['U', 'D', 'L', 'R'], [0, 0, 0, 1]),\n", 310 | " 7: (['U', 'D', 'L', 'R'], [0, 0, 0, 1]),\n", 311 | " 8: (['U', 'D', 'L', 'R'], [0, 0, 0, 1]),\n", 312 | " 9: (['U', 'D', 'L', 'R'], [0, 0, 1, 0]),\n", 313 | " 10: (['U', 'D', 'L', 'R'], [0, 0, 0, 1]),\n", 314 | " 11: (['U', 'D', 'L', 'R'], [0, 1, 
0, 0]),\n", 315 | " 12: (['U', 'D', 'L', 'R'], [0, 0, 1, 0]),\n", 316 | " 13: (['U', 'D', 'L', 'R'], [0, 0, 1, 0]),\n", 317 | " 14: (['U', 'D', 'L', 'R'], [0, 0, 0, 1])}" 318 | ] 319 | }, 320 | "execution_count": 22, 321 | "metadata": {}, 322 | "output_type": "execute_result" 323 | } 324 | ], 325 | "source": [ 326 | "# final policy obtained\n", 327 | "pi_hat" 328 | ] 329 | }, 330 | { 331 | "cell_type": "markdown", 332 | "metadata": {}, 333 | "source": [ 334 | "### Visualizing the policy" 335 | ] 336 | }, 337 | { 338 | "cell_type": "code", 339 | "execution_count": 23, 340 | "metadata": { 341 | "collapsed": true 342 | }, 343 | "outputs": [], 344 | "source": [ 345 | "def show_policy(pi, env):\n", 346 | " temp = np.zeros(len(env.states) + 2) # 16 cells; the terminal corners (0 and 15) stay 0\n", 347 | " for s in env.states:\n", 348 | " a = pi[s][0][np.argmax(pi[s][1])] # greedy action under the policy passed in\n", 349 | " if a == \"U\":\n", 350 | " temp[s] = 0.25\n", 351 | " elif a == \"D\":\n", 352 | " temp[s] = 0.5\n", 353 | " elif a == \"R\":\n", 354 | " temp[s] = 0.75\n", 355 | " else:\n", 356 | " temp[s] = 1.0\n", 357 | " \n", 358 | " temp = temp.reshape(4,4)\n", 359 | " ax = seaborn.heatmap(temp, cmap = \"prism\", linecolor=\"#282828\", cbar = False, linewidths = 0.1)\n", 360 | " plt.show()" 361 | ] 362 | }, 363 | { 364 | "cell_type": "code", 365 | "execution_count": 24, 366 | "metadata": {}, 367 | "outputs": [ 368 | { 369 | "data": { 370 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAdwAAAFJCAYAAAAxCJwFAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAACsZJREFUeJzt3UtoXIXfxvFfmhEV4wUsSBGEbtoMCC0u3HURaLwUJFQK\ntkoUJCsXTaEWGVMveGnrSvAS6mXlQuPK4kIo1hYFBReFFgITRQTxtnElbbFpnfkv5N8XX9q8Ji/z\njJl+PrucQ8489DB8O4eEDHW73W4BAD21pt8DAOBqILgAECC4ABAguAAQILgAECC4ABDQ6OXFm81m\ntRcWevkS9EhzdLTubLt3q9V8c7TaH7l/q1Fz+2hVlfu3SjW3j1a73b7sOZ9wASBAcAEgQHABIEBw\nASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHAB\nIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEg\nQHABIEBwASDgHwe30+n0cgcADLTGUid//PHHOnjwYM3Pz1ej0ahOp1MbNmyoVqtV69evT20EgFVv\nyeDOzMzU3r17a9OmTZeOnTp1qlqtVs3NzfV8HAAMiiUfKS8uLv4ttlVVmzdv7ukgABhES37C3bhx\nY7VardqyZUvdeOONdfbs2fr8889r48aNqX0AMBCWDO7zzz9fx44dq5MnT9aZM2dqZGSkxsbGanx8\nPLUPAAbCksEdGhqq8fFxgQWA/ye/hwsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwAB\nggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGC\nCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAFD3W6326uLN5vNXl0aAP6V2u32ZY83\nev3CC5ve7fVL0AOjp6fcu1Vs9PRUzSx80O8ZrMDLo7uqqty/Veq/9+9yPFIGgADBBYAAwQWAAMEF\ngADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWA\nAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAA\nwQWAAMEFgADBBYAAwQWAgMZSJycnJ+vChQt/O9btdmtoaKjm5uZ6OgwABsmSwX3yySdr//799eab\nb9bw8HBqEwAMnCWDu2nTppqYmKhvvvmmxsfHU5sAYOAsGdyqqqmpqcQOABhofmgKAAIEFwACBBcA\nAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwAC\nBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIE\nFwACBBcAAoa63W63VxdvNpu9ujQA/Cu12+3LHm/0+oVnFj7o9UvQAy+P7qo72wv9nsEKzTdHvfdW\nqZdHd1VVVfsj77/VqLl99IrnPFIGgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEF\ngADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWA\nAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAgGUHd3FxsRc7AGCgXTG4x48fr7Gx\nsRofH69PPvnk0vGpqanIMAAYJI0rnTh8+HAdOXKkOp1OTU9P1/nz52v79u3V7XaT+wBgIFwxuNdc\nc03dfPPNVVU1Oztbjz32WK1bt66GhoZi4wBgUFzxkfLtt99eB
w8erHPnztXIyEi98cYb9cILL9T3\n33+f3AcAA+GKwT1w4EBt3Ljx0ifadevW1XvvvVf3339/bBwADIorPlJuNBr14IMP/u3Y2rVra2Zm\npuejAGDQ+D1cAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBc\nAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwA\nCBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBjqdrvdXl282Wz26tIA8K/Ubrcve7zR6xe+s73Q65eg\nB+abo+7dKjbfHK2ZhQ/6PYMVeHl0V1VVtRe8/1aj5ujoFc95pAwAAYILAAGCCwABggsAAYILAAGC\nCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYIL\nAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABywru\nH3/8UYuLi73aAgADa8ngfvfdd/XEE09Uq9Wqr776qrZt21bbtm2rEydOpPYBwEBoLHXyueeeq+np\n6fr5559r9+7ddfTo0br22mtramqqxsbGUhsBYNVbMridTqfuvvvuqqr6+uuv69Zbb/3rmxpLfhsA\n8L8s+Uh5/fr1NTMzU51Opw4dOlRVVW+//XatXbs2Mg4ABsWSH1VfeumlOn78eK1Z8z9dvu2222py\ncrLnwwBgkCwZ3DVr1tTWrVv/dmxiYqKngwBgEPk9XAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAI\nEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQ\nXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBjqdrvdXl282Wz26tIA\n8K/Ubrcve7ynwQUA/uKRMgAECC4ABAguAAQILgAECC4ABAguAAQI7gp0Op169tln66GHHqrJycn6\n4Ycf+j2JZTp9+nRNTk72ewbLdOHChdq3b189/PDDtWPHjvrss8/6PYl/6M8//6xWq1U7d+6sXbt2\n1bffftvvSXGCuwLHjh2rxcXF+vDDD2vv3r116NChfk9iGd55553av39/nT9/vt9TWKaPP/64brnl\nlnr//ffr3XffrRdffLHfk/iHTpw4UVVVc3NztWfPnnr11Vf7vChPcFfg5MmTtWXLlqqq2rx5c83P\nz/d5Ectxxx131Ouvv97vGazAfffdV9PT01VV1e12a3h4uM+L+Ke2bt166T9Iv/zyS9100019XpTX\n6PeA1ejMmTM1MjJy6evh4eG6ePFiNRr+OVeDe++9t3766ad+z2AFbrjhhqr66z24e/fu2rNnT58X\nsRyNRqOeeuqp+vTTT+u1117r95w4n3BXYGRkpM6ePXvp606nI7YQ8uuvv9ajjz5aExMT9cADD/R7\nDsv0yiuv1NGjR+uZZ56pc+fO9XtOlOCuwF133VVffPFFVVWdOnWqNmzY0OdFcHX47bff6vHHH699\n+/bVjh07+j2HZThy5Ei99dZbVVV1/fXX19DQUK1Zc3UlyMeyFRgfH68vv/yydu7cWd1utw4cONDv\nSXBVOHz4cP3+++81Oztbs7OzVfXXD8Fdd911fV7G/+Wee+6pVqtVjzzySF28eLGefvrpq+6++WtB\nABBwdX2eB4A+EVwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACPgPqwelrdFfLd8AAAAASUVORK5C\nYII=\n", 371 | "text/plain": [ 372 | "" 373 | ] 374 | }, 375 | "metadata": {}, 376 | 
"output_type": "display_data" 377 | } 378 | ], 379 | "source": [ 380 | "### RED = TERMINAL (0)\n", 381 | "### GREEN = LEFT\n", 382 | "### BLUE = UP\n", 383 | "### PURPLE = RIGHT\n", 384 | "### ORANGE = DOWN\n", 385 | "\n", 386 | "show_policy(pi_hat, gw)" 387 | ] 388 | } 389 | ], 390 | "metadata": { 391 | "kernelspec": { 392 | "display_name": "Python 3", 393 | "language": "python", 394 | "name": "python3" 395 | }, 396 | "language_info": { 397 | "codemirror_mode": { 398 | "name": "ipython", 399 | "version": 3 400 | }, 401 | "file_extension": ".py", 402 | "mimetype": "text/x-python", 403 | "name": "python", 404 | "nbconvert_exporter": "python", 405 | "pygments_lexer": "ipython3", 406 | "version": "3.6.1" 407 | } 408 | }, 409 | "nbformat": 4, 410 | "nbformat_minor": 2 411 | } 412 | -------------------------------------------------------------------------------- /source code/2-Monte Carlo Methods (Chapter 5)/__pycache__/gridWorldEnvironment.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buomsoo-kim/Tabular-RL-with-Python/0157c05126821524dc9d744613b9aa8dd5d47232/source code/2-Monte Carlo Methods (Chapter 5)/__pycache__/gridWorldEnvironment.cpython-36.pyc -------------------------------------------------------------------------------- /source code/2-Monte Carlo Methods (Chapter 5)/gridWorldEnvironment.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import seaborn 4 | from matplotlib.colors import ListedColormap 5 | 6 | class GridWorld: 7 | def __init__(self, gamma = 1.0, theta = 0.5): 8 | self.actions = ("U", "D", "L", "R") 9 | self.states = np.arange(1, 15) 10 | self.transitions = pd.read_csv("gridworld.txt", header = None, sep = "\t").values 11 | self.gamma = gamma 12 | self.theta = theta 13 | 14 | def state_transition(self, state, action): 15 | next_state, reward = None, None 16 | for tr in 
self.transitions: 17 | if tr[0] == state and tr[1] == action: 18 | next_state = tr[2] 19 | reward = tr[3] 20 | return next_state, reward 21 | 22 | def show_environment(self): 23 | all_states = np.concatenate(([0], self.states, [0])).reshape(4,4) 24 | colors = [] 25 | # colors = ["#ffffff"] 26 | for i in range(len(self.states) + 1): 27 | if i == 0: 28 | colors.append("#c4c4c4") 29 | else: 30 | colors.append("#ffffff") 31 | 32 | cmap = ListedColormap(seaborn.color_palette(colors).as_hex()) 33 | ax = seaborn.heatmap(all_states, cmap = cmap, \ 34 | annot = True, linecolor = "#282828", linewidths = 0.2, \ 35 | cbar = False) -------------------------------------------------------------------------------- /source code/2-Monte Carlo Methods (Chapter 5)/gridworld.txt: -------------------------------------------------------------------------------- 1 | 1 U 1 -1 2 | 1 D 5 -1 3 | 1 R 2 -1 4 | 1 L 0 -1 5 | 2 U 2 -1 6 | 2 D 6 -1 7 | 2 R 3 -1 8 | 2 L 1 -1 9 | 3 U 3 -1 10 | 3 D 7 -1 11 | 3 R 3 -1 12 | 3 L 2 -1 13 | 4 U 0 -1 14 | 4 D 8 -1 15 | 4 R 5 -1 16 | 4 L 4 -1 17 | 5 U 1 -1 18 | 5 D 9 -1 19 | 5 R 6 -1 20 | 5 L 4 -1 21 | 6 U 2 -1 22 | 6 D 10 -1 23 | 6 R 7 -1 24 | 6 L 5 -1 25 | 7 U 3 -1 26 | 7 D 11 -1 27 | 7 R 7 -1 28 | 7 L 6 -1 29 | 8 U 4 -1 30 | 8 D 12 -1 31 | 8 R 9 -1 32 | 8 L 8 -1 33 | 9 U 5 -1 34 | 9 D 13 -1 35 | 9 R 10 -1 36 | 9 L 8 -1 37 | 10 U 6 -1 38 | 10 D 14 -1 39 | 10 R 11 -1 40 | 10 L 9 -1 41 | 11 U 7 -1 42 | 11 D 0 -1 43 | 11 R 11 -1 44 | 11 L 10 -1 45 | 12 U 8 -1 46 | 12 D 12 -1 47 | 12 R 13 -1 48 | 12 L 12 -1 49 | 13 U 9 -1 50 | 13 D 13 -1 51 | 13 R 14 -1 52 | 13 L 12 -1 53 | 14 U 10 -1 54 | 14 D 14 -1 55 | 14 R 0 -1 56 | 14 L 13 -1 -------------------------------------------------------------------------------- /source code/3-Temporal Difference Learning (Chapter 6)/.ipynb_checkpoints/2-SARSA-on-policy control-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": 
"markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# SARSA\n", 8 | "- Algorithms from ```pp. 105 - 107``` in Sutton & Barto 2017" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "metadata": { 15 | "collapsed": true 16 | }, 17 | "outputs": [], 18 | "source": [ 19 | "import matplotlib.pyplot as plt\n", 20 | "import pandas as pd\n", 21 | "import numpy as np\n", 22 | "import seaborn, random\n", 23 | "\n", 24 | "from gridWorldEnvironment import GridWorld" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 2, 30 | "metadata": { 31 | "collapsed": true 32 | }, 33 | "outputs": [], 34 | "source": [ 35 | "# creating gridworld environment\n", 36 | "gw = GridWorld(gamma = .9, theta = .5)" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 3, 42 | "metadata": { 43 | "collapsed": true 44 | }, 45 | "outputs": [], 46 | "source": [ 47 | "def state_action_value(env):\n", 48 | " q = dict()\n", 49 | " for state, action, next_state, reward in env.transitions:\n", 50 | " q[(state, action)] = np.random.normal()\n", 51 | " for action in env.actions:\n", 52 | " q[0, action] = 0\n", 53 | " return q" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 4, 59 | "metadata": { 60 | "scrolled": true 61 | }, 62 | "outputs": [ 63 | { 64 | "data": { 65 | "text/plain": [ 66 | "{(0, 'D'): 0,\n", 67 | " (0, 'L'): 0,\n", 68 | " (0, 'R'): 0,\n", 69 | " (0, 'U'): 0,\n", 70 | " (1, 'D'): -1.4931958690513218,\n", 71 | " (1, 'L'): 0.45253247253698986,\n", 72 | " (1, 'R'): -1.789094792083647,\n", 73 | " (1, 'U'): 1.9103660029206884,\n", 74 | " (2, 'D'): -0.3082624959814856,\n", 75 | " (2, 'L'): -0.7483834716798741,\n", 76 | " (2, 'R'): 0.9839358952573672,\n", 77 | " (2, 'U'): -0.2724328270166447,\n", 78 | " (3, 'D'): -0.6283488057971491,\n", 79 | " (3, 'L'): -1.3156943242567214,\n", 80 | " (3, 'R'): 0.6211123056414489,\n", 81 | " (3, 'U'): 0.7976038544679848,\n", 82 | " (4, 'D'): -1.050452706273533,\n", 83 | " (4, 
'L'): -0.1741895951805081,\n", 84 | " (4, 'R'): 1.8182867323493,\n", 85 | " (4, 'U'): 1.801387714646261,\n", 86 | " (5, 'D'): 0.10021107202221527,\n", 87 | " (5, 'L'): -0.47465610710346245,\n", 88 | " (5, 'R'): -0.6918872493076375,\n", 89 | " (5, 'U'): -3.2003225504816597,\n", 90 | " (6, 'D'): 0.6278525454978494,\n", 91 | " (6, 'L'): -0.24213713052505115,\n", 92 | " (6, 'R'): -2.630728473016553,\n", 93 | " (6, 'U'): 1.2604497963698,\n", 94 | " (7, 'D'): 0.7220353661454653,\n", 95 | " (7, 'L'): -0.046864386969943536,\n", 96 | " (7, 'R'): 0.3648650012222875,\n", 97 | " (7, 'U'): -0.13955201414490984,\n", 98 | " (8, 'D'): -0.73401777416155,\n", 99 | " (8, 'L'): 0.32247567696601837,\n", 100 | " (8, 'R'): 0.5780991299024739,\n", 101 | " (8, 'U'): -0.44090209956778376,\n", 102 | " (9, 'D'): 1.4691269385027337,\n", 103 | " (9, 'L'): 0.052954400141560456,\n", 104 | " (9, 'R'): 0.13195379460282597,\n", 105 | " (9, 'U'): -0.3923944749627829,\n", 106 | " (10, 'D'): 1.7199147774865016,\n", 107 | " (10, 'L'): -1.9247987278054801,\n", 108 | " (10, 'R'): -0.143510086697551,\n", 109 | " (10, 'U'): -0.5647971071775687,\n", 110 | " (11, 'D'): -0.5741454748190717,\n", 111 | " (11, 'L'): -1.5720736584714539,\n", 112 | " (11, 'R'): -1.4601134792863597,\n", 113 | " (11, 'U'): -1.257386467083216,\n", 114 | " (12, 'D'): -0.7472490605228015,\n", 115 | " (12, 'L'): -0.22563599050611213,\n", 116 | " (12, 'R'): 0.4395772978492674,\n", 117 | " (12, 'U'): -0.7210157758438556,\n", 118 | " (13, 'D'): 0.6291351870645969,\n", 119 | " (13, 'L'): -1.1313917732858065,\n", 120 | " (13, 'R'): 1.7365108020870084,\n", 121 | " (13, 'U'): 0.1339793058824657,\n", 122 | " (14, 'D'): 1.028938708862189,\n", 123 | " (14, 'L'): 0.050033944728027975,\n", 124 | " (14, 'R'): 1.3081957299962321,\n", 125 | " (14, 'U'): 0.05849306927859837}" 126 | ] 127 | }, 128 | "execution_count": 4, 129 | "metadata": {}, 130 | "output_type": "execute_result" 131 | } 132 | ], 133 | "source": [ 134 | "state_action_value(gw)" 135 | ] 
136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": 5, 140 | "metadata": { 141 | "collapsed": true 142 | }, 143 | "outputs": [], 144 | "source": [ 145 | "def generate_greedy_policy(env, Q):\n", 146 | " pi = dict()\n", 147 | " for state in env.states:\n", 148 | " actions = []\n", 149 | " q_values = []\n", 150 | " prob = []\n", 151 | " \n", 152 | " for a in env.actions:\n", 153 | " actions.append(a)\n", 154 | " q_values.append(Q[state,a]) \n", 155 | " for i in range(len(q_values)):\n", 156 | " if i == np.argmax(q_values):\n", 157 | " prob.append(1)\n", 158 | " else:\n", 159 | " prob.append(0) \n", 160 | " \n", 161 | " pi[state] = (actions, prob)\n", 162 | " return pi" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": 6, 168 | "metadata": { 169 | "collapsed": true 170 | }, 171 | "outputs": [], 172 | "source": [ 173 | "# selects action epsilon-greedily, given current state\n", 174 | "def e_greedy(env, e, q, state):\n", 175 | " actions = env.actions\n", 176 | " action_values = []\n", 177 | " prob = []\n", 178 | " for action in actions:\n", 179 | " action_values.append(q[(state, action)])\n", 180 | " for i in range(len(action_values)):\n", 181 | " if i == np.argmax(action_values):\n", 182 | " prob.append(1 - e + e/len(action_values))\n", 183 | " else:\n", 184 | " prob.append(e/len(action_values))\n", 185 | " return np.random.choice(actions, p = prob)" 186 | ] 187 | }, 188 | { 189 | "cell_type": "markdown", 190 | "metadata": { 191 | "collapsed": true 192 | }, 193 | "source": [ 194 | "### SARSA: On-policy TD Control\n", 195 | "- Evaluates action-value function ($Q$)" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": 7, 201 | "metadata": { 202 | "collapsed": true 203 | }, 204 | "outputs": [], 205 | "source": [ 206 | "def sarsa(env, epsilon, alpha, num_iter):\n", 207 | " Q = state_action_value(env)\n", 208 | " for _ in range(num_iter):\n", 209 | " current_state = np.random.choice(env.states)\n", 210 | " 
current_action = e_greedy(env, epsilon, Q, current_state)\n", 211 | " while current_state != 0:\n", 212 | " next_state, reward = env.state_transition(current_state, current_action)\n", 213 | " next_action = e_greedy(env, epsilon, Q, next_state)\n", 214 | " Q[current_state, current_action] += alpha * (reward + env.gamma * Q[next_state, next_action] - Q[current_state, current_action])\n", 215 | " current_state, current_action = next_state, next_action\n", 216 | " return Q" 217 | ] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "execution_count": 8, 222 | "metadata": { 223 | "collapsed": true 224 | }, 225 | "outputs": [], 226 | "source": [ 227 | "Q = sarsa(gw, 0.1, 0.5, 10000)" 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": 9, 233 | "metadata": { 234 | "collapsed": true 235 | }, 236 | "outputs": [], 237 | "source": [ 238 | "pi_hat = generate_greedy_policy(gw, Q)" 239 | ] 240 | }, 241 | { 242 | "cell_type": "markdown", 243 | "metadata": {}, 244 | "source": [ 245 | "### Visualizing policy" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": 10, 251 | "metadata": { 252 | "collapsed": true 253 | }, 254 | "outputs": [], 255 | "source": [ 256 | "def show_policy(pi, env):\n", 257 | " temp = np.zeros(len(env.states) + 2)\n", 258 | " for s in env.states:\n", 259 | " a = pi[s][0][np.argmax(pi[s][1])]\n", 260 | " if a == \"U\":\n", 261 | " temp[s] = 0.25\n", 262 | " elif a == \"D\":\n", 263 | " temp[s] = 0.5\n", 264 | " elif a == \"R\":\n", 265 | " temp[s] = 0.75\n", 266 | " else:\n", 267 | " temp[s] = 1.0\n", 268 | " \n", 269 | " temp = temp.reshape(4,4)\n", 270 | " ax = seaborn.heatmap(temp, cmap = \"prism\", linecolor=\"#282828\", cbar = False, linewidths = 0.1)\n", 271 | " plt.show()" 272 | ] 273 | }, 274 | { 275 | "cell_type": "code", 276 | "execution_count": 11, 277 | "metadata": {}, 278 | "outputs": [ 279 | { 280 | "data": { 281 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAdwAAAFJCAYAAAAxCJwFAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAACs5JREFUeJzt3UtoXIXfxvFfmhEV4wUsSBGEbpoMCCku3HURaLwUJFQK\ntkoUJCsXbaEWiakXvLR1JXgJVbty4T+uLC6EYm1RUHBRaCEwUUQQbxtX0hab1pn/QuiLL01ek5d5\nxkw/n13OoWceegjfnkNCBzqdTqcAgK5a1+sBAHAtEFwACBBcAAgQXAAIEFwACBBcAAhodPPizWaz\nWgsL3fwIuqQ5MlJ3t9y7tWq+6f6tVfPNkaqqan3k/q1Fze0j1Wq1rnrOEy4ABAguAAQILgAECC4A\nBAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAE\nCC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQI\nLgAECC4ABPzj4Lbb7W7uAIC+1lju5I8//liHDh2q+fn5ajQa1W63a9OmTTU9PV0bN25MbQSANW/Z\n4M7MzNS+fftqdHT0yrEzZ87U9PR0zc3NdX0cAPSLZV8pLy4u/i22VVWbN2/u6iAA6EfLPuEODw/X\n9PR0bdmypW6++eY6f/58ff755zU8PJzaBwB9Ydngvvjii3XixIk6ffp0nTt3roaGhmpsbKzGx8dT\n+wCgLywb3IGBgRofHxdYAPh/8nu4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4\nABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgA\nECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQMNDpdDrduniz2ezWpQHgX6nVal31eKPb\nH7wwerTbH0EXjJydqrtbC72ewSrNN0fcvzVqvjlSVVWtj9y/tai5fWTJc14pA0CA4AJAgOACQIDg\nAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOAC\nQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJA\ngOACQIDgAkCA4AJAgOACQEBjuZOTk5N16dKlvx3rdDo1MDBQc3NzXR0GAP1k2eA+/fTTdeDAgXr7\n7bdrcHAwtQkA+s6ywR0dHa2JiYn65ptvanx8PLUJAPrOssGtqpqamkrsAIC+5oemACBAcAEgQHAB\nIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEg\nQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBA\ncAEgQHABIGCg0+l0unXxZrPZrUsDwL9Sq9W66vFGtz94YfRotz+CLhg5O1V3txZ6PYNVmm+O1MzC\nf3o9g1V4dWRXVVW1PvL9txY1t48sec4rZQAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBc\nAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwA\nCBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAhYcXAXFxe7sQMA+tqSwT15\n8mSNjY3V+Ph4ffLJJ1eOT01NRYYBQD9pLHXiyJEjdezYsWq327Vnz566ePFibd++vTqdTnIfAPSF\nJYN73XXX1a233lpVVbOzs/XEE0/Uhg0bamBgIDYOAPrFkq+U7
7zzzjp06FBduHChhoaG6q233qqX\nXnqpvv/+++Q+AOgLSwb34MGDNTw8fOWJdsOGDfX+++/Xgw8+GBsHAP1iyVfKjUajHn744b8dW79+\nfc3MzHR9FAD0G7+HCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYIL\nAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsA\nAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAEDnU6n062LN5vNbl0aAP6VWq3WVY83uv3BC6NH\nu/0RdMHI2amaWfhPr2ewSq+O7HL/1qhXR3ZVVVVrYaHHS1iN5sjIkue8UgaAAMEFgADBBYAAwQWA\nAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAA\nwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADB\nBYCAFQX3jz/+qMXFxW5tAYC+tWxwv/vuu3rqqadqenq6vvrqq9q2bVtt27atTp06ldoHAH2hsdzJ\nF154ofbs2VM///xz7d69u44fP17XX399TU1N1djYWGojAKx5ywa33W7XvffeW1VVX3/9dd1+++1/\n/aHGsn8MAPhfln2lvHHjxpqZmal2u12HDx+uqqp333231q9fHxkHAP1i2UfVV155pU6ePFnr1v1P\nl++4446anJzs+jAA6CfLBnfdunW1devWvx2bmJjo6iAA6Ed+DxcAAgQXAAIEFwACBBcAAgQXAAIE\nFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQX\nAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIGOp1Op1sX\nbzab3bo0APwrtVqtqx7vanABgL94pQwAAYILAAGCCwABggsAAYILAAGCCwABgrsK7Xa7nn/++Xrk\nkUdqcnKyfvjhh15PYoXOnj1bk5OTvZ7BCl26dKn2799fjz76aO3YsaM+++yzXk/iH/rzzz9renq6\ndu7cWbt27apvv/2215PiBHcVTpw4UYuLi/Xhhx/Wvn376vDhw72exAq89957deDAgbp48WKvp7BC\nH3/8cd122231wQcf1NGjR+vll1/u9ST+oVOnTlVV1dzcXO3du7def/31Hi/KE9xVOH36dG3ZsqWq\nqjZv3lzz8/M9XsRK3HXXXfXmm2/2egar8MADD9SePXuqqqrT6dTg4GCPF/FPbd269co/kH755Ze6\n5ZZberwor9HrAWvRuXPnamho6MrXg4ODdfny5Wo0/HWuBffff3/99NNPvZ7BKtx0001V9df34O7d\nu2vv3r09XsRKNBqNeuaZZ+rTTz+tN954o9dz4jzhrsLQ0FCdP3/+ytftdltsIeTXX3+txx9/vCYm\nJuqhhx7q9RxW6LXXXqvjx4/Xc889VxcuXOj1nCjBXYV77rmnvvjii6qqOnPmTG3atKnHi+Da8Ntv\nv9WTTz5Z+/fvrx07dvR6Ditw7Nixeuedd6qq6sYbb6yBgYFat+7aSpDHslUYHx+vL7/8snbu3Fmd\nTqcOHjzY60lwTThy5Ej9/vvvNTs7W7Ozs1X11w/B3XDDDT1exv/lvvvuq+np6Xrsscfq8uXL9eyz\nz15z983/FgQAAdfW8zwA9IjgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkDAfwHhhKWt2U7SnQAA\nAABJRU5ErkJggg==\n", 282 | "text/plain": [ 283 | "" 284 | ] 285 | }, 286 | "metadata": {}, 287 | 
"output_type": "display_data" 288 | } 289 | ], 290 | "source": [ 291 | "### RED = TERMINAL (0)\n", 292 | "### GREEN = LEFT\n", 293 | "### BLUE = UP\n", 294 | "### PURPLE = RIGHT\n", 295 | "### ORANGE = DOWN\n", 296 | "\n", 297 | "show_policy(pi_hat, gw)" 298 | ] 299 | } 300 | ], 301 | "metadata": { 302 | "kernelspec": { 303 | "display_name": "Python 3", 304 | "language": "python", 305 | "name": "python3" 306 | }, 307 | "language_info": { 308 | "codemirror_mode": { 309 | "name": "ipython", 310 | "version": 3 311 | }, 312 | "file_extension": ".py", 313 | "mimetype": "text/x-python", 314 | "name": "python", 315 | "nbconvert_exporter": "python", 316 | "pygments_lexer": "ipython3", 317 | "version": "3.6.1" 318 | } 319 | }, 320 | "nbformat": 4, 321 | "nbformat_minor": 2 322 | } 323 | -------------------------------------------------------------------------------- /source code/3-Temporal Difference Learning (Chapter 6)/.ipynb_checkpoints/4-double-Q-learning-off-policy-control-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Double Q-Learning\n", 8 | "- Algorithms from ```pp. 
110 - 111``` in Sutton & Barto 2017\n", 9 | "- The Double Q-learning algorithm employs two action-value functions ($Q_1$ and $Q_2$) to avoid maximization bias" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 15, 15 | "metadata": { 16 | "collapsed": true 17 | }, 18 | "outputs": [], 19 | "source": [ 20 | "import matplotlib.pyplot as plt\n", 21 | "import pandas as pd\n", 22 | "import numpy as np\n", 23 | "import seaborn, random\n", 24 | "\n", 25 | "from gridWorldEnvironment import GridWorld" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 16, 31 | "metadata": { 32 | "collapsed": true 33 | }, 34 | "outputs": [], 35 | "source": [ 36 | "# creating gridworld environment\n", 37 | "gw = GridWorld(gamma = .9, theta = .5)" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 17, 43 | "metadata": { 44 | "collapsed": true 45 | }, 46 | "outputs": [], 47 | "source": [ 48 | "def state_action_value(env):\n", 49 | " q = dict()\n", 50 | " for state, action, next_state, reward in env.transitions:\n", 51 | " q[(state, action)] = np.random.normal()\n", 52 | " for action in env.actions:\n", 53 | " q[0, action] = 0\n", 54 | " return q" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 18, 60 | "metadata": { 61 | "scrolled": true 62 | }, 63 | "outputs": [ 64 | { 65 | "data": { 66 | "text/plain": [ 67 | "{(0, 'D'): 0,\n", 68 | " (0, 'L'): 0,\n", 69 | " (0, 'R'): 0,\n", 70 | " (0, 'U'): 0,\n", 71 | " (1, 'D'): 1.0159263422385194,\n", 72 | " (1, 'L'): -0.9015277224413166,\n", 73 | " (1, 'R'): -0.8401929147381398,\n", 74 | " (1, 'U'): -0.3866748964867951,\n", 75 | " (2, 'D'): -1.6171831004555488,\n", 76 | " (2, 'L'): -0.757413177831847,\n", 77 | " (2, 'R'): 1.1948020656778442,\n", 78 | " (2, 'U'): -0.6814167904466197,\n", 79 | " (3, 'D'): 0.8048046240684537,\n", 80 | " (3, 'L'): 0.5058075927359411,\n", 81 | " (3, 'R'): -0.2396998941031068,\n", 82 | " (3, 'U'): 0.2155229655857796,\n", 83 | " (4, 'D'): 
-1.6147931728590075,\n", 84 | " (4, 'L'): -0.7788455232241043,\n", 85 | " (4, 'R'): -1.4834265811813951,\n", 86 | " (4, 'U'): -0.6831094436413707,\n", 87 | " (5, 'D'): 0.553787569878218,\n", 88 | " (5, 'L'): -0.7535460982740707,\n", 89 | " (5, 'R'): -0.5446399864366045,\n", 90 | " (5, 'U'): 0.3275946472542217,\n", 91 | " (6, 'D'): 0.4967878126487371,\n", 92 | " (6, 'L'): -0.08765553689024165,\n", 93 | " (6, 'R'): 2.561671991266108,\n", 94 | " (6, 'U'): -0.6251803976764682,\n", 95 | " (7, 'D'): -1.4239033515729715,\n", 96 | " (7, 'L'): 0.8961390571157126,\n", 97 | " (7, 'R'): -1.752510423129062,\n", 98 | " (7, 'U'): 0.186255694272231,\n", 99 | " (8, 'D'): 2.0464864623182057,\n", 100 | " (8, 'L'): -1.4145441882793333,\n", 101 | " (8, 'R'): -0.01479790948318322,\n", 102 | " (8, 'U'): -0.5827888737714622,\n", 103 | " (9, 'D'): -1.99334821662137,\n", 104 | " (9, 'L'): -0.2143702145160287,\n", 105 | " (9, 'R'): 0.3093816164828963,\n", 106 | " (9, 'U'): -0.13810316479560683,\n", 107 | " (10, 'D'): -0.4286970274384264,\n", 108 | " (10, 'L'): -1.0313201895326287,\n", 109 | " (10, 'R'): 0.674733204171753,\n", 110 | " (10, 'U'): -0.06485579239030625,\n", 111 | " (11, 'D'): -1.08761491406436,\n", 112 | " (11, 'L'): 0.37568985747762257,\n", 113 | " (11, 'R'): 0.6549082083274831,\n", 114 | " (11, 'U'): -0.044872797471142374,\n", 115 | " (12, 'D'): -0.8474472601122482,\n", 116 | " (12, 'L'): -0.6624032631781939,\n", 117 | " (12, 'R'): 0.4720691318622512,\n", 118 | " (12, 'U'): -1.2333806138625314,\n", 119 | " (13, 'D'): -0.38599325593593187,\n", 120 | " (13, 'L'): -1.677564721893633,\n", 121 | " (13, 'R'): -1.1425774085316547,\n", 122 | " (13, 'U'): -1.3606173055898687,\n", 123 | " (14, 'D'): -1.1641372208339686,\n", 124 | " (14, 'L'): 1.1457975665987994,\n", 125 | " (14, 'R'): -1.012618770180745,\n", 126 | " (14, 'U'): -0.06091464445082383}" 127 | ] 128 | }, 129 | "execution_count": 18, 130 | "metadata": {}, 131 | "output_type": "execute_result" 132 | } 133 | ], 134 | "source": 
[ 135 | "state_action_value(gw)" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": 19, 141 | "metadata": { 142 | "collapsed": true 143 | }, 144 | "outputs": [], 145 | "source": [ 146 | "def generate_greedy_policy(env, Q):\n", 147 | " pi = dict()\n", 148 | " for state in env.states:\n", 149 | " actions = []\n", 150 | " q_values = []\n", 151 | " prob = []\n", 152 | " \n", 153 | " for a in env.actions:\n", 154 | " actions.append(a)\n", 155 | " q_values.append(Q[state,a]) \n", 156 | " for i in range(len(q_values)):\n", 157 | " if i == np.argmax(q_values):\n", 158 | " prob.append(1)\n", 159 | " else:\n", 160 | " prob.append(0) \n", 161 | " \n", 162 | " pi[state] = (actions, prob)\n", 163 | " return pi" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": 20, 169 | "metadata": { 170 | "collapsed": true 171 | }, 172 | "outputs": [], 173 | "source": [ 174 | "def e_greedy(env, e, q, state):\n", 175 | " actions = env.actions\n", 176 | " action_values = []\n", 177 | " prob = []\n", 178 | " for action in actions:\n", 179 | " action_values.append(q[(state, action)])\n", 180 | " for i in range(len(action_values)):\n", 181 | " if i == np.argmax(action_values):\n", 182 | " prob.append(1 - e + e/len(action_values))\n", 183 | " else:\n", 184 | " prob.append(e/len(action_values))\n", 185 | " return np.random.choice(actions, p = prob)" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": 21, 191 | "metadata": { 192 | "collapsed": true 193 | }, 194 | "outputs": [], 195 | "source": [ 196 | "def greedy(env, q, state):\n", 197 | " actions = env.actions\n", 198 | " action_values = []\n", 199 | " for action in actions:\n", 200 | " action_values.append(q[state, action])\n", 201 | " return actions[np.argmax(action_values)]" 202 | ] 203 | }, 204 | { 205 | "cell_type": "markdown", 206 | "metadata": {}, 207 | "source": [ 208 | "### Double Q-learning" 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | 
"execution_count": 34, 214 | "metadata": { 215 | "collapsed": true 216 | }, 217 | "outputs": [], 218 | "source": [ 219 | "def double_q_learning(env, epsilon, alpha, num_iter):\n", 220 | " Q1, Q2 = state_action_value(env), state_action_value(env)\n", 221 | " for _ in range(num_iter):\n", 222 | " current_state = np.random.choice(env.states)\n", 223 | " while current_state != 0:\n", 224 | " Q = dict()\n", 225 | " for key in Q1.keys():\n", 226 | " Q[key] = Q1[key] + Q2[key]\n", 227 | " current_action = e_greedy(env, epsilon, Q, current_state)\n", 228 | " next_state, reward = env.state_transition(current_state, current_action)\n", 229 | " \n", 230 | " # choose Q1 or Q2 with equal probabilities (0.5)\n", 231 | " chosen_Q = np.random.choice([\"Q1\", \"Q2\"])\n", 232 | " if chosen_Q == \"Q1\": # when Q1 is chosen\n", 233 | " best_action = greedy(env, Q1, next_state)\n", 234 | " Q1[current_state, current_action] += alpha * \\\n", 235 | " (reward + env.gamma * Q2[next_state, best_action] - Q1[current_state, current_action])\n", 236 | " else: # when Q2 is chosen\n", 237 | " best_action = greedy(env, Q2, next_state)\n", 238 | " Q2[current_state, current_action] += alpha * \\\n", 239 | " (reward + env.gamma * Q1[next_state, best_action] - Q2[current_state, current_action])\n", 240 | " \n", 241 | " current_state = next_state\n", 242 | " return Q1, Q2" 243 | ] 244 | }, 245 | { 246 | "cell_type": "code", 247 | "execution_count": 38, 248 | "metadata": { 249 | "collapsed": true 250 | }, 251 | "outputs": [], 252 | "source": [ 253 | "Q1, Q2 = double_q_learning(gw, 0.2, 0.5, 5000)\n", 254 | "\n", 255 | "# sum Q1 & Q2 elementwise to obtain final Q-values\n", 256 | "Q = dict()\n", 257 | "for key in Q1.keys():\n", 258 | " Q[key] = Q1[key] + Q2[key]" 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": 39, 264 | "metadata": { 265 | "collapsed": true 266 | }, 267 | "outputs": [], 268 | "source": [ 269 | "pi_hat = generate_greedy_policy(gw, Q)" 270 | ] 271 | }, 272 | 
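{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "### Sketch: maximization bias vs. the double estimator\n",
  "- This is an illustrative addition. It is not part of the original notebook or the book's pseudocode. It assumes 10 hypothetical actions whose true values are all 0, with each value estimated as the mean of 5 noisy rewards.\n",
  "- Taking the max of a single estimate tends to overestimate the true maximum (0). Choosing the argmax with one estimate and reading its value from an independent second estimate avoids this. That is the idea behind `Q1` / `Q2` above, and its average should land near 0."
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "import numpy as np\n",
  "\n",
  "rng = np.random.RandomState(0)  # seeded so the comparison is reproducible\n",
  "n_actions, n_samples, n_trials = 10, 5, 2000\n",
  "single, double = [], []\n",
  "for _ in range(n_trials):\n",
  "    # two independent sample-mean estimates of the same (all-zero) action values\n",
  "    q1 = rng.normal(0.0, 1.0, size=(n_actions, n_samples)).mean(axis=1)\n",
  "    q2 = rng.normal(0.0, 1.0, size=(n_actions, n_samples)).mean(axis=1)\n",
  "    single.append(q1.max())           # single estimator: max over Q1\n",
  "    double.append(q2[np.argmax(q1)])  # double estimator: argmax from Q1, value from Q2\n",
  "\n",
  "# the first mean should be clearly positive (biased), the second close to 0\n",
  "print(np.mean(single), np.mean(double))"
 ]
},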
{ 273 | "cell_type": "code", 274 | "execution_count": 32, 275 | "metadata": { 276 | "collapsed": true 277 | }, 278 | "outputs": [], 279 | "source": [ 280 | "def show_policy(pi, env):\n", 281 | " temp = np.zeros(len(env.states) + 2)\n", 282 | " for s in env.states:\n", 283 | " a = pi_hat[s][0][np.argmax(pi_hat[s][1])]\n", 284 | " if a == \"U\":\n", 285 | " temp[s] = 0.25\n", 286 | " elif a == \"D\":\n", 287 | " temp[s] = 0.5\n", 288 | " elif a == \"R\":\n", 289 | " temp[s] = 0.75\n", 290 | " else:\n", 291 | " temp[s] = 1.0\n", 292 | " \n", 293 | " temp = temp.reshape(4,4)\n", 294 | " ax = seaborn.heatmap(temp, cmap = \"prism\", linecolor=\"#282828\", cbar = False, linewidths = 0.1)\n", 295 | " plt.show()" 296 | ] 297 | }, 298 | { 299 | "cell_type": "code", 300 | "execution_count": 40, 301 | "metadata": {}, 302 | "outputs": [ 303 | { 304 | "data": { 305 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAdwAAAFJCAYAAAAxCJwFAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAACrhJREFUeJzt3U2IlYUex/H/OCcqml4gISQI3Dhz4ILSop2LAacXIQZD\nSIspiFm1cASTOI29UF21VdDLYOWqRXdaXWkRSKYUFLQQFAaORQTR26ZVqORo59xFXC9ddG4zl/M7\nzfHz2c3zMM/54YN8fR5mcKjb7XYLAOipNf0eAADXAsEFgADBBYAAwQWAAMEFgADBBYCARi8v3mw2\nq33mTC8/gh5pjo3V39ru3Wq10HT/VquF5lhVlfu3Si00x6rdbl/xnCdcAAgQXAAIEFwACBBcAAgQ\nXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBc\nAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwA\nCBBcAAj408HtdDq93AEAA62x1MnvvvuuDhw4UAsLC9VoNKrT6dSGDRuq1WrV+vXrUxsBYNVbMriz\ns7O1Z8+e2rhx4+Vjp06dqlarVfPz8z0fBwCDYslXyouLi3+IbVXVpk2bejoIAAbRkk+4o6Oj1Wq1\navPmzXXzzTfXuXPn6pNPPqnR0dHUPgAYCEsG94UXXqhjx47VyZMn6+zZszUyMlLj4+M1MTGR2gcA\nA2HJ4A4NDdXExITAAsD/ye/hAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA\n4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDg\nAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAwFC32+326uLNZrNXlwaAv6R2u33F441ef/CZ\njYd7/RH0wNjpafduFXP/Vq+x09NVVdX+55k+L2ElmtvGrnrOK2UACBBcAAgQXAAIEFwACBBcAAgQ\nXAAIEFwACBBcAAgQXAAIEFwACBBc
AAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBc\nAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwA\nCBBcAAgQXAAIaCx1cmpqqi5evPiHY91ut4aGhmp+fr6nwwBgkCwZ3Keeeqr27dtXb775Zg0PD6c2\nAcDAWTK4GzdurMnJyfryyy9rYmIitQkABs6Swa2qmp6eTuwAgIHmh6YAIEBwASBAcAEgQHABIEBw\nASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHAB\nIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEg\nYKjb7XZ7dfFms9mrSwPAX1K73b7i8UavP/jMxsO9/gh6YOz0dP2tfabfM1ihheZYtf/p/q1GzW1j\nVVXu3yr17/t3JV4pA0CA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOAC\nQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJA\ngOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQMCyg7u4uNiLHQAw0K4a3OPHj9f4+HhNTEzUhx9+\nePn49PR0ZBgADJLG1U4cOnSojhw5Up1Op2ZmZurChQu1bdu26na7yX0AMBCuGtzrrruubr311qqq\nmpubq8cff7zWrVtXQ0NDsXEAMCiu+kr5zjvvrAMHDtT58+drZGSk3njjjXrxxRfrm2++Se4DgIFw\n1eDu37+/RkdHLz/Rrlu3rt5999164IEHYuMAYFBc9ZVyo9Gohx566A/H1q5dW7Ozsz0fBQCDxu/h\nAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOAC\nQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJA\ngOACQIDgAkCA4AJAgOACQMBQt9vt9urizWazV5cGgL+kdrt9xeONXn/w7Jl/9Poj6IG/j+1071Yx\n92/1+vvYzqqqap850+clrERzbOyq57xSBoAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEF\ngADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWA\nAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgIBlBffXX3+txcXFXm0BgIG1\nZHC//vrrevLJJ6vVatXnn39eW7dura1bt9aJEydS+wBgIDSWOvn888/XzMxM/fDDD7Vr1646evRo\nXX/99TU9PV3j4+OpjQCw6i0Z3E6nU/fcc09VVX3xxRd1++23//5NjSW/DQD4L0u+Ul6/fn3Nzs5W\np9OpgwcPVlXV22+/XWvXro2MA4BBseSj6ssvv1zHjx+vNWv+0+U77rijpqamej4MAAbJksFds2ZN\nbdmy5Q/HJicnezoIAAaR38MFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADB\nBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEF\ngADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYCAoW632+3VxZvNZq8uDQB/Se12+4rHexpcAOB3\nXikDQIDgAkCA4AJAgOACQIDgAkCA4AJAgOCuQKfTqeeee64efvjhmpqaqm+//bbfk1im06dP19TU\n
VL9nsEwXL16svXv31iOPPFLbt2+vjz/+uN+T+JN+++23arVatWPHjtq5c2d99dVX/Z4UJ7grcOzY\nsVpcXKz333+/9uzZUwcPHuz3JJbhnXfeqX379tWFCxf6PYVl+uCDD+q2226r9957rw4fPlwvvfRS\nvyfxJ504caKqqubn52v37t316quv9nlRnuCuwMmTJ2vz5s1VVbVp06ZaWFjo8yKW46677qrXX3+9\n3zNYgfvvv79mZmaqqqrb7dbw8HCfF/Fnbdmy5fI/kH788ce65ZZb+rwor9HvAavR2bNna2Rk5PLX\nw8PDdenSpWo0/HGuBvfdd199//33/Z7BCtx0001V9fvfwV27dtXu3bv7vIjlaDQa9fTTT9dHH31U\nr732Wr/nxHnCXYGRkZE6d+7c5a87nY7YQshPP/1Ujz32WE1OTtaDDz7Y7zks0yuvvFJHjx6tZ599\nts6fP9/vOVGCuwJ33313ffrpp1VVderUqdqwYUOfF8G14eeff64nnnii9u7dW9u3b+/3HJbhyJEj\n9dZbb1VV1Y033lhDQ0O1Zs21lSCPZSswMTFRn332We3YsaO63W7t37+/35PgmnDo0KH65Zdfam5u\nrubm5qrq9x+Cu+GGG/q8jP/l3nvvrVarVY8++mhdunSpnnnmmWvuvvnfggAg4Np6ngeAPhFcAAgQ\nXAAIEFwACBBcAAgQXAAIEFwACBBcAAj4F4Bkpa2zV4RVAAAAAElFTkSuQmCC\n", 306 | "text/plain": [ 307 | "" 308 | ] 309 | }, 310 | "metadata": {}, 311 | "output_type": "display_data" 312 | } 313 | ], 314 | "source": [ 315 | "### RED = TERMINAL (0)\n", 316 | "### GREEN = LEFT\n", 317 | "### BLUE = UP\n", 318 | "### PURPLE = RIGHT\n", 319 | "### ORANGE = DOWN\n", 320 | "\n", 321 | "show_policy(pi_hat, gw)" 322 | ] 323 | } 324 | ], 325 | "metadata": { 326 | "kernelspec": { 327 | "display_name": "Python 3", 328 | "language": "python", 329 | "name": "python3" 330 | }, 331 | "language_info": { 332 | "codemirror_mode": { 333 | "name": "ipython", 334 | "version": 3 335 | }, 336 | "file_extension": ".py", 337 | "mimetype": "text/x-python", 338 | "name": "python", 339 | "nbconvert_exporter": "python", 340 | "pygments_lexer": "ipython3", 341 | "version": "3.6.1" 342 | } 343 | }, 344 | "nbformat": 4, 345 | "nbformat_minor": 2 346 | } 347 | -------------------------------------------------------------------------------- /source code/3-Temporal Difference Learning (Chapter 6)/2-SARSA-on-policy control.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# SARSA\n", 8 | "- 
Algorithms from ```pp. 105 - 107``` in Sutton & Barto 2017" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "metadata": { 15 | "collapsed": true 16 | }, 17 | "outputs": [], 18 | "source": [ 19 | "import matplotlib.pyplot as plt\n", 20 | "import pandas as pd\n", 21 | "import numpy as np\n", 22 | "import seaborn, random\n", 23 | "\n", 24 | "from gridWorldEnvironment import GridWorld" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 2, 30 | "metadata": { 31 | "collapsed": true 32 | }, 33 | "outputs": [], 34 | "source": [ 35 | "# creating gridworld environment\n", 36 | "gw = GridWorld(gamma = .9, theta = .5)" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 3, 42 | "metadata": { 43 | "collapsed": true 44 | }, 45 | "outputs": [], 46 | "source": [ 47 | "def state_action_value(env):\n", 48 | " q = dict()\n", 49 | " for state, action, next_state, reward in env.transitions:\n", 50 | " q[(state, action)] = np.random.normal()\n", 51 | " for action in env.actions:\n", 52 | " q[0, action] = 0\n", 53 | " return q" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 4, 59 | "metadata": { 60 | "scrolled": true 61 | }, 62 | "outputs": [ 63 | { 64 | "data": { 65 | "text/plain": [ 66 | "{(0, 'D'): 0,\n", 67 | " (0, 'L'): 0,\n", 68 | " (0, 'R'): 0,\n", 69 | " (0, 'U'): 0,\n", 70 | " (1, 'D'): -1.4931958690513218,\n", 71 | " (1, 'L'): 0.45253247253698986,\n", 72 | " (1, 'R'): -1.789094792083647,\n", 73 | " (1, 'U'): 1.9103660029206884,\n", 74 | " (2, 'D'): -0.3082624959814856,\n", 75 | " (2, 'L'): -0.7483834716798741,\n", 76 | " (2, 'R'): 0.9839358952573672,\n", 77 | " (2, 'U'): -0.2724328270166447,\n", 78 | " (3, 'D'): -0.6283488057971491,\n", 79 | " (3, 'L'): -1.3156943242567214,\n", 80 | " (3, 'R'): 0.6211123056414489,\n", 81 | " (3, 'U'): 0.7976038544679848,\n", 82 | " (4, 'D'): -1.050452706273533,\n", 83 | " (4, 'L'): -0.1741895951805081,\n", 84 | " (4, 'R'): 1.8182867323493,\n", 85 | 
" (4, 'U'): 1.801387714646261,\n", 86 | " (5, 'D'): 0.10021107202221527,\n", 87 | " (5, 'L'): -0.47465610710346245,\n", 88 | " (5, 'R'): -0.6918872493076375,\n", 89 | " (5, 'U'): -3.2003225504816597,\n", 90 | " (6, 'D'): 0.6278525454978494,\n", 91 | " (6, 'L'): -0.24213713052505115,\n", 92 | " (6, 'R'): -2.630728473016553,\n", 93 | " (6, 'U'): 1.2604497963698,\n", 94 | " (7, 'D'): 0.7220353661454653,\n", 95 | " (7, 'L'): -0.046864386969943536,\n", 96 | " (7, 'R'): 0.3648650012222875,\n", 97 | " (7, 'U'): -0.13955201414490984,\n", 98 | " (8, 'D'): -0.73401777416155,\n", 99 | " (8, 'L'): 0.32247567696601837,\n", 100 | " (8, 'R'): 0.5780991299024739,\n", 101 | " (8, 'U'): -0.44090209956778376,\n", 102 | " (9, 'D'): 1.4691269385027337,\n", 103 | " (9, 'L'): 0.052954400141560456,\n", 104 | " (9, 'R'): 0.13195379460282597,\n", 105 | " (9, 'U'): -0.3923944749627829,\n", 106 | " (10, 'D'): 1.7199147774865016,\n", 107 | " (10, 'L'): -1.9247987278054801,\n", 108 | " (10, 'R'): -0.143510086697551,\n", 109 | " (10, 'U'): -0.5647971071775687,\n", 110 | " (11, 'D'): -0.5741454748190717,\n", 111 | " (11, 'L'): -1.5720736584714539,\n", 112 | " (11, 'R'): -1.4601134792863597,\n", 113 | " (11, 'U'): -1.257386467083216,\n", 114 | " (12, 'D'): -0.7472490605228015,\n", 115 | " (12, 'L'): -0.22563599050611213,\n", 116 | " (12, 'R'): 0.4395772978492674,\n", 117 | " (12, 'U'): -0.7210157758438556,\n", 118 | " (13, 'D'): 0.6291351870645969,\n", 119 | " (13, 'L'): -1.1313917732858065,\n", 120 | " (13, 'R'): 1.7365108020870084,\n", 121 | " (13, 'U'): 0.1339793058824657,\n", 122 | " (14, 'D'): 1.028938708862189,\n", 123 | " (14, 'L'): 0.050033944728027975,\n", 124 | " (14, 'R'): 1.3081957299962321,\n", 125 | " (14, 'U'): 0.05849306927859837}" 126 | ] 127 | }, 128 | "execution_count": 4, 129 | "metadata": {}, 130 | "output_type": "execute_result" 131 | } 132 | ], 133 | "source": [ 134 | "state_action_value(gw)" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": 5, 140 
| "metadata": { 141 | "collapsed": true 142 | }, 143 | "outputs": [], 144 | "source": [ 145 | "def generate_greedy_policy(env, Q):\n", 146 | " pi = dict()\n", 147 | " for state in env.states:\n", 148 | " actions = []\n", 149 | " q_values = []\n", 150 | " prob = []\n", 151 | " \n", 152 | " for a in env.actions:\n", 153 | " actions.append(a)\n", 154 | " q_values.append(Q[state,a]) \n", 155 | " for i in range(len(q_values)):\n", 156 | " if i == np.argmax(q_values):\n", 157 | " prob.append(1)\n", 158 | " else:\n", 159 | " prob.append(0) \n", 160 | " \n", 161 | " pi[state] = (actions, prob)\n", 162 | " return pi" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": 6, 168 | "metadata": { 169 | "collapsed": true 170 | }, 171 | "outputs": [], 172 | "source": [ 173 | "# selects action epsilon-greedily, given current state\n", 174 | "def e_greedy(env, e, q, state):\n", 175 | " actions = env.actions\n", 176 | " action_values = []\n", 177 | " prob = []\n", 178 | " for action in actions:\n", 179 | " action_values.append(q[(state, action)])\n", 180 | " for i in range(len(action_values)):\n", 181 | " if i == np.argmax(action_values):\n", 182 | " prob.append(1 - e + e/len(action_values))\n", 183 | " else:\n", 184 | " prob.append(e/len(action_values))\n", 185 | " return np.random.choice(actions, p = prob)" 186 | ] 187 | }, 188 | { 189 | "cell_type": "markdown", 190 | "metadata": { 191 | "collapsed": true 192 | }, 193 | "source": [ 194 | "### SARSA: On-policy TD Control\n", 195 | "- Evaluates action-value function ($Q$)" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": 7, 201 | "metadata": { 202 | "collapsed": true 203 | }, 204 | "outputs": [], 205 | "source": [ 206 | "def sarsa(env, epsilon, alpha, num_iter):\n", 207 | " Q = state_action_value(env)\n", 208 | " for _ in range(num_iter):\n", 209 | " current_state = np.random.choice(env.states)\n", 210 | " current_action = e_greedy(env, epsilon, Q, current_state)\n", 211 | " while 
current_state != 0:\n", 212 | " next_state, reward = env.state_transition(current_state, current_action)\n", 213 | " next_action = e_greedy(env, epsilon, Q, next_state)\n", 214 | " Q[current_state, current_action] += alpha * (reward + env.gamma * Q[next_state, next_action] - Q[current_state, current_action])\n", 215 | " current_state, current_action = next_state, next_action\n", 216 | " return Q" 217 | ] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "execution_count": 8, 222 | "metadata": { 223 | "collapsed": true 224 | }, 225 | "outputs": [], 226 | "source": [ 227 | "Q = sarsa(gw, 0.1, 0.5, 10000)" 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": 9, 233 | "metadata": { 234 | "collapsed": true 235 | }, 236 | "outputs": [], 237 | "source": [ 238 | "pi_hat = generate_greedy_policy(gw, Q)" 239 | ] 240 | }, 241 | { 242 | "cell_type": "markdown", 243 | "metadata": {}, 244 | "source": [ 245 | "### Visualizing policy" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": 10, 251 | "metadata": { 252 | "collapsed": true 253 | }, 254 | "outputs": [], 255 | "source": [ 256 | "def show_policy(pi, env):\n", 257 | " temp = np.zeros(len(env.states) + 2)\n", 258 | " for s in env.states:\n", 259 | " a = pi[s][0][np.argmax(pi[s][1])]\n", 260 | " if a == \"U\":\n", 261 | " temp[s] = 0.25\n", 262 | " elif a == \"D\":\n", 263 | " temp[s] = 0.5\n", 264 | " elif a == \"R\":\n", 265 | " temp[s] = 0.75\n", 266 | " else:\n", 267 | " temp[s] = 1.0\n", 268 | " \n", 269 | " temp = temp.reshape(4,4)\n", 270 | " ax = seaborn.heatmap(temp, cmap = \"prism\", linecolor=\"#282828\", cbar = False, linewidths = 0.1)\n", 271 | " plt.show()" 272 | ] 273 | }, 274 | { 275 | "cell_type": "code", 276 | "execution_count": 11, 277 | "metadata": {}, 278 | "outputs": [ 279 | { 280 | "data": { 281 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAdwAAAFJCAYAAAAxCJwFAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAACs5JREFUeJzt3UtoXIXfxvFfmhEV4wUsSBGEbpoMCCku3HURaLwUJFQK\ntkoUJCsXbaEWiakXvLR1JXgJVbty4T+uLC6EYm1RUHBRaCEwUUQQbxtX0hab1pn/QuiLL01ek5d5\nxkw/n13OoWceegjfnkNCBzqdTqcAgK5a1+sBAHAtEFwACBBcAAgQXAAIEFwACBBcAAhodPPizWaz\nWgsL3fwIuqQ5MlJ3t9y7tWq+6f6tVfPNkaqqan3k/q1Fze0j1Wq1rnrOEy4ABAguAAQILgAECC4A\nBAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAE\nCC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQI\nLgAECC4ABPzj4Lbb7W7uAIC+1lju5I8//liHDh2q+fn5ajQa1W63a9OmTTU9PV0bN25MbQSANW/Z\n4M7MzNS+fftqdHT0yrEzZ87U9PR0zc3NdX0cAPSLZV8pLy4u/i22VVWbN2/u6iAA6EfLPuEODw/X\n9PR0bdmypW6++eY6f/58ff755zU8PJzaBwB9Ydngvvjii3XixIk6ffp0nTt3roaGhmpsbKzGx8dT\n+wCgLywb3IGBgRofHxdYAPh/8nu4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4\nABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgA\nECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQMNDpdDrduniz2ezWpQHgX6nVal31eKPb\nH7wwerTbH0EXjJydqrtbC72ewSrNN0fcvzVqvjlSVVWtj9y/tai5fWTJc14pA0CA4AJAgOACQIDg\nAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOAC\nQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJA\ngOACQIDgAkCA4AJAgOACQEBjuZOTk5N16dKlvx3rdDo1MDBQc3NzXR0GAP1k2eA+/fTTdeDAgXr7\n7bdrcHAwtQkA+s6ywR0dHa2JiYn65ptvanx8PLUJAPrOssGtqpqamkrsAIC+5oemACBAcAEgQHAB\nIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEg\nQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBA\ncAEgQHABIGCg0+l0unXxZrPZrUsDwL9Sq9W66vFGtz94YfRotz+CLhg5O1V3txZ6PYNVmm+O1MzC\nf3o9g1V4dWRXVVW1PvL9txY1t48sec4rZQAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBc\nAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwA\nCBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAhYcXAXFxe7sQMA+tqSwT15\n8mSNjY3V+Ph4ffLJJ1eOT01NRYYBQD9pLHXiyJEjdezYsWq327Vnz566ePFibd++vTqdTnIfAPSF\nJYN73XXX1a233lpVVbOzs/XEE0/Uhg0bamBgIDYOAPrFkq+U7
7zzzjp06FBduHChhoaG6q233qqX\nXnqpvv/+++Q+AOgLSwb34MGDNTw8fOWJdsOGDfX+++/Xgw8+GBsHAP1iyVfKjUajHn744b8dW79+\nfc3MzHR9FAD0G7+HCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYIL\nAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsA\nAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAEDnU6n062LN5vNbl0aAP6VWq3WVY83uv3BC6NH\nu/0RdMHI2amaWfhPr2ewSq+O7HL/1qhXR3ZVVVVrYaHHS1iN5sjIkue8UgaAAMEFgADBBYAAwQWA\nAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAA\nwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADB\nBYCAFQX3jz/+qMXFxW5tAYC+tWxwv/vuu3rqqadqenq6vvrqq9q2bVtt27atTp06ldoHAH2hsdzJ\nF154ofbs2VM///xz7d69u44fP17XX399TU1N1djYWGojAKx5ywa33W7XvffeW1VVX3/9dd1+++1/\n/aHGsn8MAPhfln2lvHHjxpqZmal2u12HDx+uqqp333231q9fHxkHAP1i2UfVV155pU6ePFnr1v1P\nl++4446anJzs+jAA6CfLBnfdunW1devWvx2bmJjo6iAA6Ed+DxcAAgQXAAIEFwACBBcAAgQXAAIE\nFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQX\nAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIGOp1Op1sX\nbzab3bo0APwrtVqtqx7vanABgL94pQwAAYILAAGCCwABggsAAYILAAGCCwABgrsK7Xa7nn/++Xrk\nkUdqcnKyfvjhh15PYoXOnj1bk5OTvZ7BCl26dKn2799fjz76aO3YsaM+++yzXk/iH/rzzz9renq6\ndu7cWbt27apvv/2215PiBHcVTpw4UYuLi/Xhhx/Wvn376vDhw72exAq89957deDAgbp48WKvp7BC\nH3/8cd122231wQcf1NGjR+vll1/u9ST+oVOnTlVV1dzcXO3du7def/31Hi/KE9xVOH36dG3ZsqWq\nqjZv3lzz8/M9XsRK3HXXXfXmm2/2egar8MADD9SePXuqqqrT6dTg4GCPF/FPbd269co/kH755Ze6\n5ZZberwor9HrAWvRuXPnamho6MrXg4ODdfny5Wo0/HWuBffff3/99NNPvZ7BKtx0001V9df34O7d\nu2vv3r09XsRKNBqNeuaZZ+rTTz+tN954o9dz4jzhrsLQ0FCdP3/+ytftdltsIeTXX3+txx9/vCYm\nJuqhhx7q9RxW6LXXXqvjx4/Xc889VxcuXOj1nCjBXYV77rmnvvjii6qqOnPmTG3atKnHi+Da8Ntv\nv9WTTz5Z+/fvrx07dvR6Ditw7Nixeuedd6qq6sYbb6yBgYFat+7aSpDHslUYHx+vL7/8snbu3Fmd\nTqcOHjzY60lwTThy5Ej9/vvvNTs7W7Ozs1X11w/B3XDDDT1exv/lvvvuq+np6Xrsscfq8uXL9eyz\nz15z983/FgQAAdfW8zwA9IjgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkDAfwHhhKWt2U7SnQAA\nAABJRU5ErkJggg==\n", 282 | "text/plain": [ 283 | "" 284 | ] 285 | }, 286 | "metadata": {}, 287 | 
"output_type": "display_data" 288 | } 289 | ], 290 | "source": [ 291 | "### RED = TERMINAL (0)\n", 292 | "### GREEN = LEFT\n", 293 | "### BLUE = UP\n", 294 | "### PURPLE = RIGHT\n", 295 | "### ORANGE = DOWN\n", 296 | "\n", 297 | "show_policy(pi_hat, gw)" 298 | ] 299 | } 300 | ], 301 | "metadata": { 302 | "kernelspec": { 303 | "display_name": "Python 3", 304 | "language": "python", 305 | "name": "python3" 306 | }, 307 | "language_info": { 308 | "codemirror_mode": { 309 | "name": "ipython", 310 | "version": 3 311 | }, 312 | "file_extension": ".py", 313 | "mimetype": "text/x-python", 314 | "name": "python", 315 | "nbconvert_exporter": "python", 316 | "pygments_lexer": "ipython3", 317 | "version": "3.6.1" 318 | } 319 | }, 320 | "nbformat": 4, 321 | "nbformat_minor": 2 322 | } 323 | -------------------------------------------------------------------------------- /source code/3-Temporal Difference Learning (Chapter 6)/3-Q-learning-off-policy-control.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Q-Learning\n", 8 | "- Algorithms from ```pp. 
107 - 109``` in Sutton & Barto 2017" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 15, 14 | "metadata": { 15 | "collapsed": true 16 | }, 17 | "outputs": [], 18 | "source": [ 19 | "import matplotlib.pyplot as plt\n", 20 | "import pandas as pd\n", 21 | "import numpy as np\n", 22 | "import seaborn, random\n", 23 | "\n", 24 | "from gridWorldEnvironment import GridWorld" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 16, 30 | "metadata": { 31 | "collapsed": true 32 | }, 33 | "outputs": [], 34 | "source": [ 35 | "# creating gridworld environment\n", 36 | "gw = GridWorld(gamma = .9, theta = .5)" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 17, 42 | "metadata": { 43 | "collapsed": true 44 | }, 45 | "outputs": [], 46 | "source": [ 47 | "def state_action_value(env):\n", 48 | " q = dict()\n", 49 | " for state, action, next_state, reward in env.transitions:\n", 50 | " q[(state, action)] = np.random.normal()\n", 51 | " for action in env.actions:\n", 52 | " q[0, action] = 0\n", 53 | " return q" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 18, 59 | "metadata": { 60 | "scrolled": true 61 | }, 62 | "outputs": [ 63 | { 64 | "data": { 65 | "text/plain": [ 66 | "{(0, 'D'): 0,\n", 67 | " (0, 'L'): 0,\n", 68 | " (0, 'R'): 0,\n", 69 | " (0, 'U'): 0,\n", 70 | " (1, 'D'): 1.0159263422385194,\n", 71 | " (1, 'L'): -0.9015277224413166,\n", 72 | " (1, 'R'): -0.8401929147381398,\n", 73 | " (1, 'U'): -0.3866748964867951,\n", 74 | " (2, 'D'): -1.6171831004555488,\n", 75 | " (2, 'L'): -0.757413177831847,\n", 76 | " (2, 'R'): 1.1948020656778442,\n", 77 | " (2, 'U'): -0.6814167904466197,\n", 78 | " (3, 'D'): 0.8048046240684537,\n", 79 | " (3, 'L'): 0.5058075927359411,\n", 80 | " (3, 'R'): -0.2396998941031068,\n", 81 | " (3, 'U'): 0.2155229655857796,\n", 82 | " (4, 'D'): -1.6147931728590075,\n", 83 | " (4, 'L'): -0.7788455232241043,\n", 84 | " (4, 'R'): -1.4834265811813951,\n", 85 | " (4, 'U'): 
-0.6831094436413707,\n", 86 | " (5, 'D'): 0.553787569878218,\n", 87 | " (5, 'L'): -0.7535460982740707,\n", 88 | " (5, 'R'): -0.5446399864366045,\n", 89 | " (5, 'U'): 0.3275946472542217,\n", 90 | " (6, 'D'): 0.4967878126487371,\n", 91 | " (6, 'L'): -0.08765553689024165,\n", 92 | " (6, 'R'): 2.561671991266108,\n", 93 | " (6, 'U'): -0.6251803976764682,\n", 94 | " (7, 'D'): -1.4239033515729715,\n", 95 | " (7, 'L'): 0.8961390571157126,\n", 96 | " (7, 'R'): -1.752510423129062,\n", 97 | " (7, 'U'): 0.186255694272231,\n", 98 | " (8, 'D'): 2.0464864623182057,\n", 99 | " (8, 'L'): -1.4145441882793333,\n", 100 | " (8, 'R'): -0.01479790948318322,\n", 101 | " (8, 'U'): -0.5827888737714622,\n", 102 | " (9, 'D'): -1.99334821662137,\n", 103 | " (9, 'L'): -0.2143702145160287,\n", 104 | " (9, 'R'): 0.3093816164828963,\n", 105 | " (9, 'U'): -0.13810316479560683,\n", 106 | " (10, 'D'): -0.4286970274384264,\n", 107 | " (10, 'L'): -1.0313201895326287,\n", 108 | " (10, 'R'): 0.674733204171753,\n", 109 | " (10, 'U'): -0.06485579239030625,\n", 110 | " (11, 'D'): -1.08761491406436,\n", 111 | " (11, 'L'): 0.37568985747762257,\n", 112 | " (11, 'R'): 0.6549082083274831,\n", 113 | " (11, 'U'): -0.044872797471142374,\n", 114 | " (12, 'D'): -0.8474472601122482,\n", 115 | " (12, 'L'): -0.6624032631781939,\n", 116 | " (12, 'R'): 0.4720691318622512,\n", 117 | " (12, 'U'): -1.2333806138625314,\n", 118 | " (13, 'D'): -0.38599325593593187,\n", 119 | " (13, 'L'): -1.677564721893633,\n", 120 | " (13, 'R'): -1.1425774085316547,\n", 121 | " (13, 'U'): -1.3606173055898687,\n", 122 | " (14, 'D'): -1.1641372208339686,\n", 123 | " (14, 'L'): 1.1457975665987994,\n", 124 | " (14, 'R'): -1.012618770180745,\n", 125 | " (14, 'U'): -0.06091464445082383}" 126 | ] 127 | }, 128 | "execution_count": 18, 129 | "metadata": {}, 130 | "output_type": "execute_result" 131 | } 132 | ], 133 | "source": [ 134 | "state_action_value(gw)" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": 19, 140 | 
"metadata": { 141 | "collapsed": true 142 | }, 143 | "outputs": [], 144 | "source": [ 145 | "def generate_greedy_policy(env, Q):\n", 146 | " pi = dict()\n", 147 | " for state in env.states:\n", 148 | " actions = []\n", 149 | " q_values = []\n", 150 | " prob = []\n", 151 | " \n", 152 | " for a in env.actions:\n", 153 | " actions.append(a)\n", 154 | " q_values.append(Q[state,a]) \n", 155 | " for i in range(len(q_values)):\n", 156 | " if i == np.argmax(q_values):\n", 157 | " prob.append(1)\n", 158 | " else:\n", 159 | " prob.append(0) \n", 160 | " \n", 161 | " pi[state] = (actions, prob)\n", 162 | " return pi" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": 20, 168 | "metadata": { 169 | "collapsed": true 170 | }, 171 | "outputs": [], 172 | "source": [ 173 | "def e_greedy(env, e, q, state):\n", 174 | " actions = env.actions\n", 175 | " action_values = []\n", 176 | " prob = []\n", 177 | " for action in actions:\n", 178 | " action_values.append(q[(state, action)])\n", 179 | " for i in range(len(action_values)):\n", 180 | " if i == np.argmax(action_values):\n", 181 | " prob.append(1 - e + e/len(action_values))\n", 182 | " else:\n", 183 | " prob.append(e/len(action_values))\n", 184 | " return np.random.choice(actions, p = prob)" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": 21, 190 | "metadata": { 191 | "collapsed": true 192 | }, 193 | "outputs": [], 194 | "source": [ 195 | "def greedy(env, q, state):\n", 196 | " actions = env.actions\n", 197 | " action_values = []\n", 198 | " for action in actions:\n", 199 | " action_values.append(q[state, action])\n", 200 | " return actions[np.argmax(action_values)]" 201 | ] 202 | }, 203 | { 204 | "cell_type": "markdown", 205 | "metadata": { 206 | "collapsed": true 207 | }, 208 | "source": [ 209 | "### Q-Learning: Off-policy TD Control" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": 22, 215 | "metadata": { 216 | "collapsed": true 217 | }, 218 | 
"outputs": [], 219 | "source": [ 220 | "def q_learning(env, epsilon, alpha, num_iter):\n", 221 | " Q = state_action_value(env)\n", 222 | " \n", 223 | " for _ in range(num_iter):\n", 224 | " current_state = np.random.choice(env.states)\n", 225 | " while current_state != 0:\n", 226 | " current_action = e_greedy(env, epsilon, Q, current_state)\n", 227 | " next_state, reward = env.state_transition(current_state, current_action)\n", 228 | " best_action = greedy(env, Q, next_state)\n", 229 | " Q[current_state, current_action] += alpha * (reward + env.gamma * Q[next_state, best_action] - Q[current_state, current_action])\n", 230 | " current_state = next_state\n", 231 | " return Q" 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": 29, 237 | "metadata": { 238 | "collapsed": true 239 | }, 240 | "outputs": [], 241 | "source": [ 242 | "Q = q_learning(gw, 0.2, 1.0, 10000)" 243 | ] 244 | }, 245 | { 246 | "cell_type": "code", 247 | "execution_count": 30, 248 | "metadata": { 249 | "scrolled": true 250 | }, 251 | "outputs": [ 252 | { 253 | "data": { 254 | "text/plain": [ 255 | "{(0, 'D'): 0,\n", 256 | " (0, 'L'): 0,\n", 257 | " (0, 'R'): 0,\n", 258 | " (0, 'U'): 0,\n", 259 | " (1, 'D'): -2.71,\n", 260 | " (1, 'L'): -1.0,\n", 261 | " (1, 'R'): -2.71,\n", 262 | " (1, 'U'): -1.9,\n", 263 | " (2, 'D'): -3.439,\n", 264 | " (2, 'L'): -1.9,\n", 265 | " (2, 'R'): -3.439,\n", 266 | " (2, 'U'): -2.71,\n", 267 | " (3, 'D'): -2.71,\n", 268 | " (3, 'L'): -2.71,\n", 269 | " (3, 'R'): -3.439,\n", 270 | " (3, 'U'): -3.439,\n", 271 | " (4, 'D'): -2.71,\n", 272 | " (4, 'L'): -1.9,\n", 273 | " (4, 'R'): -2.71,\n", 274 | " (4, 'U'): -1.0,\n", 275 | " (5, 'D'): -3.439,\n", 276 | " (5, 'L'): -1.9,\n", 277 | " (5, 'R'): -3.439,\n", 278 | " (5, 'U'): -1.9,\n", 279 | " (6, 'D'): -2.71,\n", 280 | " (6, 'L'): -2.71,\n", 281 | " (6, 'R'): -2.71,\n", 282 | " (6, 'U'): -2.71,\n", 283 | " (7, 'D'): -1.9,\n", 284 | " (7, 'L'): -3.439,\n", 285 | " (7, 'R'): -2.71,\n", 286 | " (7, 'U'): 
-3.439,\n", 287 | " (8, 'D'): -3.439,\n", 288 | " (8, 'L'): -2.71,\n", 289 | " (8, 'R'): -3.439,\n", 290 | " (8, 'U'): -1.9,\n", 291 | " (9, 'D'): -2.71,\n", 292 | " (9, 'L'): -2.71,\n", 293 | " (9, 'R'): -2.71,\n", 294 | " (9, 'U'): -2.71,\n", 295 | " (10, 'D'): -1.9,\n", 296 | " (10, 'L'): -3.439,\n", 297 | " (10, 'R'): -1.9,\n", 298 | " (10, 'U'): -3.439,\n", 299 | " (11, 'D'): -1.0,\n", 300 | " (11, 'L'): -2.71,\n", 301 | " (11, 'R'): -1.9,\n", 302 | " (11, 'U'): -2.71,\n", 303 | " (12, 'D'): -3.439,\n", 304 | " (12, 'L'): -3.439,\n", 305 | " (12, 'R'): -2.71,\n", 306 | " (12, 'U'): -2.71,\n", 307 | " (13, 'D'): -2.71,\n", 308 | " (13, 'L'): -3.439,\n", 309 | " (13, 'R'): -1.9,\n", 310 | " (13, 'U'): -3.439,\n", 311 | " (14, 'D'): -1.9,\n", 312 | " (14, 'L'): -2.71,\n", 313 | " (14, 'R'): -1.0,\n", 314 | " (14, 'U'): -2.71}" 315 | ] 316 | }, 317 | "execution_count": 30, 318 | "metadata": {}, 319 | "output_type": "execute_result" 320 | } 321 | ], 322 | "source": [ 323 | "Q" 324 | ] 325 | }, 326 | { 327 | "cell_type": "code", 328 | "execution_count": 31, 329 | "metadata": { 330 | "collapsed": true 331 | }, 332 | "outputs": [], 333 | "source": [ 334 | "pi_hat = generate_greedy_policy(gw, Q)" 335 | ] 336 | }, 337 | { 338 | "cell_type": "markdown", 339 | "metadata": {}, 340 | "source": [ 341 | "### Visualizing policy" 342 | ] 343 | }, 344 | { 345 | "cell_type": "code", 346 | "execution_count": 32, 347 | "metadata": { 348 | "collapsed": true 349 | }, 350 | "outputs": [], 351 | "source": [ 352 | "def show_policy(pi, env):\n", 353 | " temp = np.zeros(len(env.states) + 2)\n", 354 | " for s in env.states:\n", 355 | " a = pi[s][0][np.argmax(pi[s][1])]\n", 356 | " if a == \"U\":\n", 357 | " temp[s] = 0.25\n", 358 | " elif a == \"D\":\n", 359 | " temp[s] = 0.5\n", 360 | " elif a == \"R\":\n", 361 | " temp[s] = 0.75\n", 362 | " else:\n", 363 | " temp[s] = 1.0\n", 364 | " \n", 365 | " temp = temp.reshape(4,4)\n", 366 | " ax = seaborn.heatmap(temp, cmap = \"prism\", 
linecolor=\"#282828\", cbar = False, linewidths = 0.1)\n", 367 | " plt.show()" 368 | ] 369 | }, 370 | { 371 | "cell_type": "code", 372 | "execution_count": 33, 373 | "metadata": {}, 374 | "outputs": [ 375 | { 376 | "data": { 377 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAdwAAAFJCAYAAAAxCJwFAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAACsFJREFUeJzt3U+IlYX+x/HvOCcqmv5AQkgQuHHmwAWlRTsXA05/hBgM\nIS2mIGbVwhFM4jT2h+qqrYL+DFauWnSnVdIikEwpKGghKAhnigiif5tWoZKjnfNbxM9LF53bzOV8\nTnN8vXbzPMxzPvggb5+HGRzqdrvdAgB6ak2/BwDAtUBwASBAcAEgQHABIEBwASBAcAEgoNHLizeb\nzWovLPTyI+iR5thY/aPt3q1WZ5ru32p1pjlWVVXtD9y/1ai5baza7fYVz3nCBYAAwQWAAMEFgADB\nBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEF\ngADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWA\nAMEFgADBBYCAvxzcTqfTyx0AMNAaS538/vvv68CBA3XmzJlqNBrV6XRqw4YN1Wq1av369amNALDq\nLRnc2dnZ2rNnT23cuPHysVOnTlWr1ar5+fmejwOAQbHkK+XFxcU/xbaqatOmTT0dBACDaMkn3NHR\n0Wq1WrV58+a6+eab69y5c/Xpp5/W6Ohoah8ADIQlg/vCCy/UsWPH6uTJk3X27NkaGRmp8fHxmpiY\nSO0DgIGwZHCHhoZqYmJCYAHgf+T3cAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBA\ncAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBw\nASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIGCo2+12e3XxZrPZq0sDwN9Su92+4vFG\nrz94YePhXn8EPTB2etq9W8Xcv9Vr7PR0VVW1P1jo8xJWorlt7KrnvFIGgADBBYAAwQWAAMEFgADB\nBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEF\ngADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWA\nAMEFgADBBYAAwQWAgMZSJ6empurixYt/OtbtdmtoaKjm5+d7OgwABsmSwX3qqadq37599eabb9bw\n8HBqEwAMnCWDu3HjxpqcnKyvvvqqJiYmUpsAYOAsGdyqqunp6cQOABhofmgKAAIEFwACBBcAAgQX\nAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcA\nAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwAC\nBBcAAoa63W63VxdvNpu9ujQA/C212+0rHm/0+oMXNh7u9UfQA2Onp927VWzs9HS1P1jo9wxWoLlt\nrKrK/Vul/v/+XYlXygAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4\nABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAA
QILgAECC4ABAguAAQILgAECC4ABAguAAQILgA\nECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABCw7OAuLi72YgcADLSrBvf48eM1Pj5eExMT9dFH\nH10+Pj09HRkGAIOkcbUThw4dqiNHjlSn06mZmZm6cOFCbdu2rbrdbnIfAAyEqwb3uuuuq1tvvbWq\nqubm5urxxx+vdevW1dDQUGwcAAyKq75SvvPOO+vAgQN1/vz5GhkZqTfeeKNefPHF+vbbb5P7AGAg\nXDW4+/fvr9HR0ctPtOvWrat33323Hnjggdg4ABgUV32l3Gg06qGHHvrTsbVr19bs7GzPRwHAoPF7\nuAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4\nABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgA\nECC4ABAguAAQILgAECC4ABAw1O12u726eLPZ7NWlAeBvqd1uX/F4o9cfvLDxcK8/gh4YOz1dswv/\n6vcMVuifYzvdv1Xqn2M7q6qqvbDQ5yWsRHNs7KrnvFIGgADBBYAAwQWAAMEFgADBBYAAwQWAAMEF\ngADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWA\nAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAgGUF97fffqvF\nxcVebQGAgbVkcL/55pt68sknq9Vq1RdffFFbt26trVu31okTJ1L7AGAgNJY6+fzzz9fMzEz9+OOP\ntWvXrjp69Ghdf/31NT09XePj46mNALDqLRncTqdT99xzT1VVffnll3X77bf/8U2NJb8NAPgPS75S\nXr9+fc3Ozlan06mDBw9WVdXbb79da9eujYwDgEGx5KPqyy+/XMePH681a/7d5TvuuKOmpqZ6PgwA\nBsmSwV2zZk1t2bLlT8cmJyd7OggABpHfwwWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADB\nBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEF\ngADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgIChbrfb7dXFm81mry4NAH9L7Xb7\nisd7GlwA4A9eKQNAgOACQIDgAkCA4AJAgOACQIDgAkCA4K5Ap9Op5557rh5++OGampqq7777rt+T\nWKbTp0/X1NRUv2ewTBcvXqy9e/fWI488Utu3b69PPvmk35P4i37//fdqtVq1Y8eO2rlzZ3399df9\nnhQnuCtw7NixWlxcrPfff7/27NlTBw8e7PckluGdd96pffv21YULF/o9hWX68MMP67bbbqv33nuv\nDh8+XC+99FK/J/EXnThxoqqq5ufna/fu3fXqq6/2eVGe4K7AyZMna/PmzVVVtWnTpjpz5kyfF7Ec\nd911V73++uv9nsEK3H///TUzM1NVVd1ut4aHh/u8iL9qy5Ytl/+B9NNPP9Utt9zS50V5jX4PWI3O\nnj1bIyMjl78eHh6uS5cuVaPhj3M1uO++++qHH37o9wxW4KabbqqqP/4O7tq1q3bv3t3nRSxHo9Go\np59+uj7++ON67bXX+j0nzhPuCoyMjNS5c+cuf93pdMQWQn7++ed67LHHanJysh588MF+z2GZXnnl\nlTp69Gg9++yzdf78+X7PiRLcFbj77rvrs88+q6qqU6dO1YYNG/q8CK4Nv/zySz3xxBO1d+/e2r59\ne7/nsAxHjhypt956q6qqbrzxxhoaGqo1a66tBHksW4GJiYn6/PPPa8eOHdXtdmv//v39ngTXhEOH\nDtWvv/5ac3N
zNTc3V1V//BDcDTfc0Odl/Df33ntvtVqtevTRR+vSpUv1zDPPXHP3zf8WBAAB19bz\nPAD0ieACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQMD/AWrkpa0qFuevAAAAAElFTkSuQmCC\n", 378 | "text/plain": [ 379 | "" 380 | ] 381 | }, 382 | "metadata": {}, 383 | "output_type": "display_data" 384 | } 385 | ], 386 | "source": [ 387 | "### RED = TERMINAL (0)\n", 388 | "### GREEN = LEFT\n", 389 | "### BLUE = UP\n", 390 | "### PURPLE = RIGHT\n", 391 | "### ORANGE = DOWN\n", 392 | "\n", 393 | "show_policy(pi_hat, gw)" 394 | ] 395 | } 396 | ], 397 | "metadata": { 398 | "kernelspec": { 399 | "display_name": "Python 3", 400 | "language": "python", 401 | "name": "python3" 402 | }, 403 | "language_info": { 404 | "codemirror_mode": { 405 | "name": "ipython", 406 | "version": 3 407 | }, 408 | "file_extension": ".py", 409 | "mimetype": "text/x-python", 410 | "name": "python", 411 | "nbconvert_exporter": "python", 412 | "pygments_lexer": "ipython3", 413 | "version": "3.6.1" 414 | } 415 | }, 416 | "nbformat": 4, 417 | "nbformat_minor": 2 418 | } 419 | -------------------------------------------------------------------------------- /source code/3-Temporal Difference Learning (Chapter 6)/4-double-Q-learning-off-policy-control.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Double Q-Learning\n", 8 | "- Algorithms from ```pp. 
110 - 111``` in Sutton & Barto 2017\n", 9 | "- The Double Q-learning algorithm employs two action-value functions ($Q_1$ and $Q_2$) to avoid maximization bias" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 15, 15 | "metadata": { 16 | "collapsed": true 17 | }, 18 | "outputs": [], 19 | "source": [ 20 | "import matplotlib.pyplot as plt\n", 21 | "import pandas as pd\n", 22 | "import numpy as np\n", 23 | "import seaborn, random\n", 24 | "\n", 25 | "from gridWorldEnvironment import GridWorld" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 16, 31 | "metadata": { 32 | "collapsed": true 33 | }, 34 | "outputs": [], 35 | "source": [ 36 | "# creating gridworld environment\n", 37 | "gw = GridWorld(gamma = .9, theta = .5)" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 17, 43 | "metadata": { 44 | "collapsed": true 45 | }, 46 | "outputs": [], 47 | "source": [ 48 | "def state_action_value(env):\n", 49 | " q = dict()\n", 50 | " for state, action, next_state, reward in env.transitions:\n", 51 | " q[(state, action)] = np.random.normal()\n", 52 | " for action in env.actions:\n", 53 | " q[0, action] = 0\n", 54 | " return q" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 18, 60 | "metadata": { 61 | "scrolled": true 62 | }, 63 | "outputs": [ 64 | { 65 | "data": { 66 | "text/plain": [ 67 | "{(0, 'D'): 0,\n", 68 | " (0, 'L'): 0,\n", 69 | " (0, 'R'): 0,\n", 70 | " (0, 'U'): 0,\n", 71 | " (1, 'D'): 1.0159263422385194,\n", 72 | " (1, 'L'): -0.9015277224413166,\n", 73 | " (1, 'R'): -0.8401929147381398,\n", 74 | " (1, 'U'): -0.3866748964867951,\n", 75 | " (2, 'D'): -1.6171831004555488,\n", 76 | " (2, 'L'): -0.757413177831847,\n", 77 | " (2, 'R'): 1.1948020656778442,\n", 78 | " (2, 'U'): -0.6814167904466197,\n", 79 | " (3, 'D'): 0.8048046240684537,\n", 80 | " (3, 'L'): 0.5058075927359411,\n", 81 | " (3, 'R'): -0.2396998941031068,\n", 82 | " (3, 'U'): 0.2155229655857796,\n", 83 | " (4, 'D'): 
-1.6147931728590075,\n", 84 | " (4, 'L'): -0.7788455232241043,\n", 85 | " (4, 'R'): -1.4834265811813951,\n", 86 | " (4, 'U'): -0.6831094436413707,\n", 87 | " (5, 'D'): 0.553787569878218,\n", 88 | " (5, 'L'): -0.7535460982740707,\n", 89 | " (5, 'R'): -0.5446399864366045,\n", 90 | " (5, 'U'): 0.3275946472542217,\n", 91 | " (6, 'D'): 0.4967878126487371,\n", 92 | " (6, 'L'): -0.08765553689024165,\n", 93 | " (6, 'R'): 2.561671991266108,\n", 94 | " (6, 'U'): -0.6251803976764682,\n", 95 | " (7, 'D'): -1.4239033515729715,\n", 96 | " (7, 'L'): 0.8961390571157126,\n", 97 | " (7, 'R'): -1.752510423129062,\n", 98 | " (7, 'U'): 0.186255694272231,\n", 99 | " (8, 'D'): 2.0464864623182057,\n", 100 | " (8, 'L'): -1.4145441882793333,\n", 101 | " (8, 'R'): -0.01479790948318322,\n", 102 | " (8, 'U'): -0.5827888737714622,\n", 103 | " (9, 'D'): -1.99334821662137,\n", 104 | " (9, 'L'): -0.2143702145160287,\n", 105 | " (9, 'R'): 0.3093816164828963,\n", 106 | " (9, 'U'): -0.13810316479560683,\n", 107 | " (10, 'D'): -0.4286970274384264,\n", 108 | " (10, 'L'): -1.0313201895326287,\n", 109 | " (10, 'R'): 0.674733204171753,\n", 110 | " (10, 'U'): -0.06485579239030625,\n", 111 | " (11, 'D'): -1.08761491406436,\n", 112 | " (11, 'L'): 0.37568985747762257,\n", 113 | " (11, 'R'): 0.6549082083274831,\n", 114 | " (11, 'U'): -0.044872797471142374,\n", 115 | " (12, 'D'): -0.8474472601122482,\n", 116 | " (12, 'L'): -0.6624032631781939,\n", 117 | " (12, 'R'): 0.4720691318622512,\n", 118 | " (12, 'U'): -1.2333806138625314,\n", 119 | " (13, 'D'): -0.38599325593593187,\n", 120 | " (13, 'L'): -1.677564721893633,\n", 121 | " (13, 'R'): -1.1425774085316547,\n", 122 | " (13, 'U'): -1.3606173055898687,\n", 123 | " (14, 'D'): -1.1641372208339686,\n", 124 | " (14, 'L'): 1.1457975665987994,\n", 125 | " (14, 'R'): -1.012618770180745,\n", 126 | " (14, 'U'): -0.06091464445082383}" 127 | ] 128 | }, 129 | "execution_count": 18, 130 | "metadata": {}, 131 | "output_type": "execute_result" 132 | } 133 | ], 134 | "source": 
[ 135 | "state_action_value(gw)" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": 19, 141 | "metadata": { 142 | "collapsed": true 143 | }, 144 | "outputs": [], 145 | "source": [ 146 | "def generate_greedy_policy(env, Q):\n", 147 | " pi = dict()\n", 148 | " for state in env.states:\n", 149 | " actions = []\n", 150 | " q_values = []\n", 151 | " prob = []\n", 152 | " \n", 153 | " for a in env.actions:\n", 154 | " actions.append(a)\n", 155 | " q_values.append(Q[state,a]) \n", 156 | " for i in range(len(q_values)):\n", 157 | " if i == np.argmax(q_values):\n", 158 | " prob.append(1)\n", 159 | " else:\n", 160 | " prob.append(0) \n", 161 | " \n", 162 | " pi[state] = (actions, prob)\n", 163 | " return pi" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": 20, 169 | "metadata": { 170 | "collapsed": true 171 | }, 172 | "outputs": [], 173 | "source": [ 174 | "def e_greedy(env, e, q, state):\n", 175 | " actions = env.actions\n", 176 | " action_values = []\n", 177 | " prob = []\n", 178 | " for action in actions:\n", 179 | " action_values.append(q[(state, action)])\n", 180 | " for i in range(len(action_values)):\n", 181 | " if i == np.argmax(action_values):\n", 182 | " prob.append(1 - e + e/len(action_values))\n", 183 | " else:\n", 184 | " prob.append(e/len(action_values))\n", 185 | " return np.random.choice(actions, p = prob)" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": 21, 191 | "metadata": { 192 | "collapsed": true 193 | }, 194 | "outputs": [], 195 | "source": [ 196 | "def greedy(env, q, state):\n", 197 | " actions = env.actions\n", 198 | " action_values = []\n", 199 | " for action in actions:\n", 200 | " action_values.append(q[state, action])\n", 201 | " return actions[np.argmax(action_values)]" 202 | ] 203 | }, 204 | { 205 | "cell_type": "markdown", 206 | "metadata": {}, 207 | "source": [ 208 | "### Double Q-learning" 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | 
"execution_count": 34, 214 | "metadata": { 215 | "collapsed": true 216 | }, 217 | "outputs": [], 218 | "source": [ 219 | "def double_q_learning(env, epsilon, alpha, num_iter):\n", 220 | " Q1, Q2 = state_action_value(env), state_action_value(env)\n", 221 | " for _ in range(num_iter):\n", 222 | " current_state = np.random.choice(env.states)\n", 223 | " while current_state != 0:\n", 224 | " Q = dict()\n", 225 | " for key in Q1.keys():\n", 226 | " Q[key] = Q1[key] + Q2[key]\n", 227 | " current_action = e_greedy(env, epsilon, Q, current_state)\n", 228 | " next_state, reward = env.state_transition(current_state, current_action)\n", 229 | " \n", 230 | " # choose Q1 or Q2 with equal probabilities (0.5)\n", 231 | " chosen_Q = np.random.choice([\"Q1\", \"Q2\"])\n", 232 | " if chosen_Q == \"Q1\": # when Q1 is chosen\n", 233 | " best_action = greedy(env, Q1, next_state)\n", 234 | " Q1[current_state, current_action] += alpha * \\\n", 235 | " (reward + env.gamma * Q2[next_state, best_action] - Q1[current_state, current_action])\n", 236 | " else: # when Q2 is chosen\n", 237 | " best_action = greedy(env, Q2, next_state)\n", 238 | " Q2[current_state, current_action] += alpha * \\\n", 239 | " (reward + env.gamma * Q1[next_state, best_action] - Q2[current_state, current_action])\n", 240 | " \n", 241 | " current_state = next_state\n", 242 | " return Q1, Q2" 243 | ] 244 | }, 245 | { 246 | "cell_type": "code", 247 | "execution_count": 38, 248 | "metadata": { 249 | "collapsed": true 250 | }, 251 | "outputs": [], 252 | "source": [ 253 | "Q1, Q2 = double_q_learning(gw, 0.2, 0.5, 5000)\n", 254 | "\n", 255 | "# sum Q1 & Q2 elementwise to obtain final Q-values\n", 256 | "Q = dict()\n", 257 | "for key in Q1.keys():\n", 258 | " Q[key] = Q1[key] + Q2[key]" 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": 39, 264 | "metadata": { 265 | "collapsed": true 266 | }, 267 | "outputs": [], 268 | "source": [ 269 | "pi_hat = generate_greedy_policy(gw, Q)" 270 | ] 271 | }, 272 | 
{ 273 | "cell_type": "code", 274 | "execution_count": 32, 275 | "metadata": { 276 | "collapsed": true 277 | }, 278 | "outputs": [], 279 | "source": [ 280 | "def show_policy(pi, env):\n", 281 | " temp = np.zeros(len(env.states) + 2)\n", 282 | " for s in env.states:\n", 283 | " a = pi_hat[s][0][np.argmax(pi_hat[s][1])]\n", 284 | " if a == \"U\":\n", 285 | " temp[s] = 0.25\n", 286 | " elif a == \"D\":\n", 287 | " temp[s] = 0.5\n", 288 | " elif a == \"R\":\n", 289 | " temp[s] = 0.75\n", 290 | " else:\n", 291 | " temp[s] = 1.0\n", 292 | " \n", 293 | " temp = temp.reshape(4,4)\n", 294 | " ax = seaborn.heatmap(temp, cmap = \"prism\", linecolor=\"#282828\", cbar = False, linewidths = 0.1)\n", 295 | " plt.show()" 296 | ] 297 | }, 298 | { 299 | "cell_type": "code", 300 | "execution_count": 40, 301 | "metadata": {}, 302 | "outputs": [ 303 | { 304 | "data": { 305 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAdwAAAFJCAYAAAAxCJwFAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAACrhJREFUeJzt3U2IlYUex/H/OCcqml4gISQI3Dhz4ILSop2LAacXIQZD\nSIspiFm1cASTOI29UF21VdDLYOWqRXdaXWkRSKYUFLQQFAaORQTR26ZVqORo59xFXC9ddG4zl/M7\nzfHz2c3zMM/54YN8fR5mcKjb7XYLAOipNf0eAADXAsEFgADBBYAAwQWAAMEFgADBBYCARi8v3mw2\nq33mTC8/gh5pjo3V39ru3Wq10HT/VquF5lhVlfu3Si00x6rdbl/xnCdcAAgQXAAIEFwACBBcAAgQ\nXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBc\nAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwA\nCBBcAAj408HtdDq93AEAA62x1MnvvvuuDhw4UAsLC9VoNKrT6dSGDRuq1WrV+vXrUxsBYNVbMriz\ns7O1Z8+e2rhx4+Vjp06dqlarVfPz8z0fBwCDYslXyouLi3+IbVXVpk2bejoIAAbRkk+4o6Oj1Wq1\navPmzXXzzTfXuXPn6pNPPqnR0dHUPgAYCEsG94UXXqhjx47VyZMn6+zZszUyMlLj4+M1MTGR2gcA\nA2HJ4A4NDdXExITAAsD/ye/hAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA\n4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDg\nAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAwFC32+326uLNZrNXlwaAv6R2u33F441ef/CZ\njYd7/RH0wNjpafduFXP/Vq+x09NVVdX+55k+L2ElmtvGrnrOK2UACBBcAAgQXAAIEFwACBBcAAgQ\nXAAIEFwACBBcAAgQXAAIEFwACBBc
AAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBc\nAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwA\nCBBcAAgQXAAIaCx1cmpqqi5evPiHY91ut4aGhmp+fr6nwwBgkCwZ3Keeeqr27dtXb775Zg0PD6c2\nAcDAWTK4GzdurMnJyfryyy9rYmIitQkABs6Swa2qmp6eTuwAgIHmh6YAIEBwASBAcAEgQHABIEBw\nASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHAB\nIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEg\nYKjb7XZ7dfFms9mrSwPAX1K73b7i8UavP/jMxsO9/gh6YOz0dP2tfabfM1ihheZYtf/p/q1GzW1j\nVVXu3yr17/t3JV4pA0CA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOAC\nQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJA\ngOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQMCyg7u4uNiLHQAw0K4a3OPHj9f4+HhNTEzUhx9+\nePn49PR0ZBgADJLG1U4cOnSojhw5Up1Op2ZmZurChQu1bdu26na7yX0AMBCuGtzrrruubr311qqq\nmpubq8cff7zWrVtXQ0NDsXEAMCiu+kr5zjvvrAMHDtT58+drZGSk3njjjXrxxRfrm2++Se4DgIFw\n1eDu37+/RkdHLz/Rrlu3rt5999164IEHYuMAYFBc9ZVyo9Gohx566A/H1q5dW7Ozsz0fBQCDxu/h\nAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOAC\nQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJA\ngOACQIDgAkCA4AJAgOACQMBQt9vt9urizWazV5cGgL+kdrt9xeONXn/w7Jl/9Poj6IG/j+1071Yx\n92/1+vvYzqqqap850+clrERzbOyq57xSBoAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEF\ngADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWA\nAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgIBlBffXX3+txcXFXm0BgIG1\nZHC//vrrevLJJ6vVatXnn39eW7dura1bt9aJEydS+wBgIDSWOvn888/XzMxM/fDDD7Vr1646evRo\nXX/99TU9PV3j4+OpjQCw6i0Z3E6nU/fcc09VVX3xxRd1++23//5NjSW/DQD4L0u+Ul6/fn3Nzs5W\np9OpgwcPVlXV22+/XWvXro2MA4BBseSj6ssvv1zHjx+vNWv+0+U77rijpqamej4MAAbJksFds2ZN\nbdmy5Q/HJicnezoIAAaR38MFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADB\nBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEF\ngADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYCAoW632+3VxZvNZq8uDQB/Se12+4rHexpcAOB3\nXikDQIDgAkCA4AJAgOACQIDgAkCA4AJAgOCuQKfTqeeee64efvjhmpqaqm+//bbfk1im06dP19TU\n
VL9nsEwXL16svXv31iOPPFLbt2+vjz/+uN+T+JN+++23arVatWPHjtq5c2d99dVX/Z4UJ7grcOzY\nsVpcXKz333+/9uzZUwcPHuz3JJbhnXfeqX379tWFCxf6PYVl+uCDD+q2226r9957rw4fPlwvvfRS\nvyfxJ504caKqqubn52v37t316quv9nlRnuCuwMmTJ2vz5s1VVbVp06ZaWFjo8yKW46677qrXX3+9\n3zNYgfvvv79mZmaqqqrb7dbw8HCfF/Fnbdmy5fI/kH788ce65ZZb+rwor9HvAavR2bNna2Rk5PLX\nw8PDdenSpWo0/HGuBvfdd199//33/Z7BCtx0001V9fvfwV27dtXu3bv7vIjlaDQa9fTTT9dHH31U\nr732Wr/nxHnCXYGRkZE6d+7c5a87nY7YQshPP/1Ujz32WE1OTtaDDz7Y7zks0yuvvFJHjx6tZ599\nts6fP9/vOVGCuwJ33313ffrpp1VVderUqdqwYUOfF8G14eeff64nnnii9u7dW9u3b+/3HJbhyJEj\n9dZbb1VV1Y033lhDQ0O1Zs21lSCPZSswMTFRn332We3YsaO63W7t37+/35PgmnDo0KH65Zdfam5u\nrubm5qrq9x+Cu+GGG/q8jP/l3nvvrVarVY8++mhdunSpnnnmmWvuvvnfggAg4Np6ngeAPhFcAAgQ\nXAAIEFwACBBcAAgQXAAIEFwACBBcAAj4F4Bkpa2zV4RVAAAAAElFTkSuQmCC\n", 306 | "text/plain": [ 307 | "" 308 | ] 309 | }, 310 | "metadata": {}, 311 | "output_type": "display_data" 312 | } 313 | ], 314 | "source": [ 315 | "### RED = TERMINAL (0)\n", 316 | "### GREEN = LEFT\n", 317 | "### BLUE = UP\n", 318 | "### PURPLE = RIGHT\n", 319 | "### ORANGE = DOWN\n", 320 | "\n", 321 | "show_policy(pi_hat, gw)" 322 | ] 323 | } 324 | ], 325 | "metadata": { 326 | "kernelspec": { 327 | "display_name": "Python 3", 328 | "language": "python", 329 | "name": "python3" 330 | }, 331 | "language_info": { 332 | "codemirror_mode": { 333 | "name": "ipython", 334 | "version": 3 335 | }, 336 | "file_extension": ".py", 337 | "mimetype": "text/x-python", 338 | "name": "python", 339 | "nbconvert_exporter": "python", 340 | "pygments_lexer": "ipython3", 341 | "version": "3.6.1" 342 | } 343 | }, 344 | "nbformat": 4, 345 | "nbformat_minor": 2 346 | } 347 | -------------------------------------------------------------------------------- /source code/3-Temporal Difference Learning (Chapter 6)/__pycache__/gridWorldEnvironment.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/buomsoo-kim/Tabular-RL-with-Python/0157c05126821524dc9d744613b9aa8dd5d47232/source code/3-Temporal Difference Learning (Chapter 6)/__pycache__/gridWorldEnvironment.cpython-36.pyc -------------------------------------------------------------------------------- /source code/3-Temporal Difference Learning (Chapter 6)/gridWorldEnvironment.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import seaborn 4 | from matplotlib.colors import ListedColormap 5 | 6 | class GridWorld: 7 | def __init__(self, gamma = 1.0, theta = 0.5): 8 | self.actions = ("U", "D", "L", "R") 9 | self.states = np.arange(1, 15) 10 | self.transitions = pd.read_csv("gridworld.txt", header = None, sep = "\t").values 11 | self.gamma = gamma 12 | self.theta = theta 13 | 14 | def state_transition(self, state, action): 15 | next_state, reward = None, None 16 | for tr in self.transitions: 17 | if tr[0] == state and tr[1] == action: 18 | next_state = tr[2] 19 | reward = tr[3] 20 | return next_state, reward 21 | 22 | def show_environment(self): 23 | all_states = np.concatenate(([0], self.states, [0])).reshape(4,4) 24 | colors = [] 25 | # colors = ["#ffffff"] 26 | for i in range(len(self.states) + 1): 27 | if i == 0: 28 | colors.append("#c4c4c4") 29 | else: 30 | colors.append("#ffffff") 31 | 32 | cmap = ListedColormap(seaborn.color_palette(colors).as_hex()) 33 | ax = seaborn.heatmap(all_states, cmap = cmap, \ 34 | annot = True, linecolor = "#282828", linewidths = 0.2, \ 35 | cbar = False) -------------------------------------------------------------------------------- /source code/3-Temporal Difference Learning (Chapter 6)/gridworld.txt: -------------------------------------------------------------------------------- 1 | 1 U 1 -1 2 | 1 D 5 -1 3 | 1 R 2 -1 4 | 1 L 0 -1 5 | 2 U 2 -1 6 | 2 D 6 -1 7 | 2 R 3 -1 8 | 2 L 1 -1 9 | 3 U 3 -1 10 | 3 D 7 -1 11 | 3 R 3 -1 12 | 3 L 2 -1 13 | 4 U 0 -1 14 | 
4 D 8 -1 15 | 4 R 5 -1 16 | 4 L 4 -1 17 | 5 U 1 -1 18 | 5 D 9 -1 19 | 5 R 6 -1 20 | 5 L 4 -1 21 | 6 U 2 -1 22 | 6 D 10 -1 23 | 6 R 7 -1 24 | 6 L 5 -1 25 | 7 U 3 -1 26 | 7 D 11 -1 27 | 7 R 7 -1 28 | 7 L 6 -1 29 | 8 U 4 -1 30 | 8 D 12 -1 31 | 8 R 9 -1 32 | 8 L 8 -1 33 | 9 U 5 -1 34 | 9 D 13 -1 35 | 9 R 10 -1 36 | 9 L 8 -1 37 | 10 U 6 -1 38 | 10 D 14 -1 39 | 10 R 11 -1 40 | 10 L 9 -1 41 | 11 U 7 -1 42 | 11 D 0 -1 43 | 11 R 11 -1 44 | 11 L 10 -1 45 | 12 U 8 -1 46 | 12 D 12 -1 47 | 12 R 13 -1 48 | 12 L 12 -1 49 | 13 U 9 -1 50 | 13 D 13 -1 51 | 13 R 14 -1 52 | 13 L 12 -1 53 | 14 U 10 -1 54 | 14 D 14 -1 55 | 14 R 0 -1 56 | 14 L 13 -1 -------------------------------------------------------------------------------- /source code/4-n-step Bootstrapping (Chapter 7)/.ipynb_checkpoints/4-n-step-off-policy-learning-wo-importance-sampling-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# $n$-step off-policy learning without importance sampling\n", 8 | "- Algorithms from ```pp. 
124 - 125``` in Sutton & Barto 2017\n", 9 | " - $n$-step Tree Backup Algorithm" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 1, 15 | "metadata": { 16 | "collapsed": true 17 | }, 18 | "outputs": [], 19 | "source": [ 20 | "import matplotlib.pyplot as plt\n", 21 | "import pandas as pd\n", 22 | "import numpy as np\n", 23 | "import seaborn, random\n", 24 | "\n", 25 | "from gridWorldEnvironment import GridWorld" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 2, 31 | "metadata": { 32 | "collapsed": true 33 | }, 34 | "outputs": [], 35 | "source": [ 36 | "# creating gridworld environment\n", 37 | "gw = GridWorld(gamma = .9)" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 7, 43 | "metadata": { 44 | "collapsed": true 45 | }, 46 | "outputs": [], 47 | "source": [ 48 | "def state_action_value(env):\n", 49 | " q = dict()\n", 50 | " for state, action, next_state, reward in env.transitions:\n", 51 | " q[(state, action)] = np.random.normal()\n", 52 | " return q" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 9, 58 | "metadata": { 59 | "collapsed": true 60 | }, 61 | "outputs": [], 62 | "source": [ 63 | "def e_greedy(env, e, q, state):\n", 64 | " actions = env.actions\n", 65 | " action_values = []\n", 66 | " prob = []\n", 67 | " for action in actions:\n", 68 | " action_values.append(q[(state, action)])\n", 69 | " for i in range(len(action_values)):\n", 70 | " if i == np.argmax(action_values):\n", 71 | " prob.append(1 - e + e/len(action_values))\n", 72 | " else:\n", 73 | " prob.append(e/len(action_values))\n", 74 | " return actions, prob" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 10, 80 | "metadata": { 81 | "collapsed": true 82 | }, 83 | "outputs": [], 84 | "source": [ 85 | "def generate_e_greedy_policy(env, e, Q):\n", 86 | " pi = dict()\n", 87 | " for state in env.states:\n", 88 | " pi[state] = e_greedy(env, e, Q, state)\n", 89 | " return pi" 90 | ] 91 | 
}, 92 | { 93 | "cell_type": "code", 94 | "execution_count": 11, 95 | "metadata": { 96 | "collapsed": true 97 | }, 98 | "outputs": [], 99 | "source": [ 100 | "def generate_random_policy(env):\n", 101 | " pi = dict()\n", 102 | " for state in env.states:\n", 103 | " actions = []\n", 104 | " prob = []\n", 105 | " for action in env.actions:\n", 106 | " actions.append(action)\n", 107 | " prob.append(0.25)\n", 108 | " pi[state] = (actions, prob)\n", 109 | " return pi" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": 19, 115 | "metadata": { 116 | "collapsed": true 117 | }, 118 | "outputs": [], 119 | "source": [ 120 | "# function for tree backup algorithm\n", 121 | "def avg_over_actions(pi, Q, state):\n", 122 | " actions, probs = pi[state]\n", 123 | " q_values = np.zeros(4)\n", 124 | " for s, a in Q.keys():\n", 125 | " if s == state:\n", 126 | " q_values[actions.index(a)] = Q[s,a]\n", 127 | " return np.dot(q_values, probs)" 128 | ] 129 | }, 130 | { 131 | "cell_type": "markdown", 132 | "metadata": {}, 133 | "source": [ 134 | "### $n$-step off-policy learning without importance sampling\n", 135 | "- The target also includes the estimated values of the dangling action nodes hanging off the sides, at all levels" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": 32, 141 | "metadata": { 142 | "collapsed": true 143 | }, 144 | "outputs": [], 145 | "source": [ 146 | "def n_step_tree_backup(env, epsilon, alpha, n, num_iter, learn_pi = True):\n", 147 | " Q = state_action_value(env)\n", 148 | " Q_, pi_, delta = dict(), dict(), dict() \n", 149 | " pi = generate_e_greedy_policy(env, epsilon, Q) \n", 150 | " b = generate_random_policy(env) # behavior policy (uniform random); must be defined before it is used below\n", 151 | " for _ in range(num_iter):\n", 152 | " current_state = np.random.choice(env.states)\n", 153 | " action = np.random.choice(b[current_state][0], p = b[current_state][1])\n", 154 | " state_trace, action_trace, reward_trace = [current_state], [action], [0]\n", 155 | " Q_[0] = Q[current_state, action]\n", 156 | " t, T = 0, 
10000\n", 157 | " while True:\n", 158 | " if t < T: \n", 159 | " next_state, reward = env.state_transition(current_state, action)\n", 160 | " state_trace.append(next_state)\n", 161 | " reward_trace.append(reward)\n", 162 | " if next_state == 0:\n", 163 | " T = t + 1\n", 164 | " delta[t] = reward - Q_[t]\n", 165 | " else: \n", 166 | " delta[t] = reward + env.gamma * avg_over_actions(pi, Q, next_state) - Q_[t]\n", 167 | " action = np.random.choice(pi[next_state][0], p = pi[next_state][1])\n", 168 | " action_trace.append(action)\n", 169 | " Q_[t+1] = Q[next_state, action]\n", 170 | " pi_[t+1] = pi[next_state][1][pi[next_state][0].index(action)]\n", 171 | " \n", 172 | " tau = t - n + 1\n", 173 | " if tau >= 0:\n", 174 | " Z = 1\n", 175 | " G = Q_[tau]\n", 176 | " for i in range(tau, min([tau + n - 1, T - 1]) + 1): # k = tau, ..., min(tau+n-1, T-1), upper bound inclusive\n", 177 | " G += Z * delta[i]\n", 178 | " Z *= env.gamma * pi_.get(i + 1, 0) # Z is unused after k = T-1, so a missing pi_[T] defaults to 0\n", 179 | " Q[state_trace[tau], action_trace[tau]] += alpha * (G - Q[state_trace[tau], action_trace[tau]])\n", 180 | " if learn_pi:\n", 181 | " pi[state_trace[tau]] = e_greedy(env, epsilon, Q, state_trace[tau])\n", 182 | " current_state = next_state \n", 183 | "# print(state_trace, action_trace, reward_trace)\n", 184 | " \n", 185 | " if tau == (T-1):\n", 186 | " break\n", 187 | " t += 1\n", 188 | " \n", 189 | " return pi, Q" 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": 39, 195 | "metadata": { 196 | "collapsed": true 197 | }, 198 | "outputs": [], 199 | "source": [ 200 | "pi, Q = n_step_tree_backup(gw, 0.2, 0.5, 1, 10000)" 201 | ] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "execution_count": 40, 206 | "metadata": {}, 207 | "outputs": [ 208 | { 209 | "data": { 210 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAdwAAAFJCAYAAAAxCJwFAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAACtdJREFUeJzt3UuIlYX/x/Hv6ImKpgskhASBG2cOBCMt2rkYcLoIMRhC\nWkxBzKqFI5jENHahi9oq6DJYuWpR0yppEUimFBS0EBQGzhQRRLdNq1DJ0c75LQT/9Efn18yP8znN\n8fXazfPg83xwOLx9HmZwoNPpdAoA6Ko1vR4AANcCwQWAAMEFgADBBYAAwQWAAMEFgIBGNy/ebDar\ntbDQzVvQJc3h4VoYOdzrGazQ8OnJan3ss7caNbcNV1XVzMKHPV7CSrw6vLNardYVz3nCBYAAwQWA\nAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAA\nwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADBBYAAwQWAAMEFgADB\nBYAAwQWAAMEFgADBBYCAfxzcdrvdzR0A0NcaS5386aef6sCBAzU/P1+NRqPa7XZt3Lixpqena8OG\nDamNALDqLRncmZmZ2rNnT42MjFw+durUqZqenq65ubmujwOAfrHkK+XFxcW/xbaqatOmTV0dBAD9\naMkn3KGhoZqenq7NmzfXzTffXGfPnq0vvviihoaGUvsAoC8sGdwXX3yxjh07VidPnqwzZ87U4OBg\njY6O1tjYWGofAPSFJYM7MDBQY2NjAgsA/yO/hwsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYIL\nAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsA\nAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAEDnU6n062LN5vNbl0aAP6V\nWq3WFY83un3jmYUPu30LuuDV4Z11d2uh1zNYofnmcC2MHO71DFZg+PRkVVW1Pvb5W42a24aves4r\nZQAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBc\nAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAgQXAAIEFwA\nCBBcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAhoLHVyYmKiLly48LdjnU6nBgYGam5urqvDAKCf\nLBncp59+uvbt21dvv/12rV27NrUJAPrOksEdGRmp8fHx+vbbb2tsbCy1CQD6zpLBraqanJxM7ACA\nvuaHpgAgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHAB\nIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEg\nQHABIEBwASBAcAEgQHABIEBwASBgoNPpdLp18Waz2a1LA8C/UqvVuuLxRrdvfHdrodu3oAvmm8M1\ns/Bhr2ewQq8O76zWxz57q1Fz23BVVS2MHO7xElZi+PTkVc95pQwAAYILAAGCCwABggsAAYILAAGC\nCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYIL\nAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwAByw7u\n4uJiN3YAQF+7anCPHz9eo6OjNTY2Vp9++unl45OTk5FhANBPGlc7cejQoTpy5Ei12+2ampqq8+fP\n17Zt26rT6ST3AUBfuGpwr7vuurr11lurqmp2draeeOKJWr9+f
Q0MDMTGAUC/uOor5TvvvLMOHDhQ\n586dq8HBwXrrrbfqpZdeqh9++CG5DwD6wlWDu3///hoaGrr8RLt+/fp6//3368EHH4yNA4B+cdVX\nyo1Gox5++OG/HVu3bl3NzMx0fRQA9Bu/hwsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGC\nCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYIL\nAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABA51Op9OtizebzW5dGgD+\nlVqt1hWPN7p947tbC92+BV0w3xz2vVvF5pvDtTByuNczWIHh05NVVdVa8PlbjZrDw1c955UyAAQI\nLgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAgu\nAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4A\nBAguAAQILgAECC4ABCwruH/++WctLi52awsA9K0lg/v999/XU089VdPT0/X111/X1q1ba+vWrXXi\nxInUPgDoC42lTr7wwgs1NTVVv/zyS+3atauOHj1a119/fU1OTtbo6GhqIwCseksGt91u17333ltV\nVd98803dfvvtl/5QY8k/BgD8P0u+Ut6wYUPNzMxUu92ugwcPVlXVu+++W+vWrYuMA4B+seSj6iuv\nvFLHjx+vNWv+r8t33HFHTUxMdH0YAPSTJYO7Zs2a2rJly9+OjY+Pd3UQAPQjv4cLAAGCCwABggsA\nAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwAB\nggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGC\nCwABA51Op9OtizebzW5dGgD+lVqt1hWPdzW4AMAlXikDQIDgAkCA4AJAgOACQIDgAkCA4AJAgOCu\nQLvdrueff74eeeSRmpiYqB9//LHXk1im06dP18TERK9nsEwXLlyovXv31qOPPlrbt2+vzz//vNeT\n+If++uuvmp6erh07dtTOnTvru+++6/WkOMFdgWPHjtXi4mJ99NFHtWfPnjp48GCvJ7EM7733Xu3b\nt6/Onz/f6yks0yeffFK33XZbffDBB3X48OF6+eWXez2Jf+jEiRNVVTU3N1e7d++u119/vceL8gR3\nBU6ePFmbN2+uqqpNmzbV/Px8jxexHHfddVe9+eabvZ7BCjzwwAM1NTVVVVWdTqfWrl3b40X8U1u2\nbLn8D6Rff/21brnllh4vymv0esBqdObMmRocHLz89dq1a+vixYvVaPjrXA3uv//++vnnn3s9gxW4\n6aabqurSZ3DXrl21e/fuHi9iORqNRj3zzDP12Wef1RtvvNHrOXGecFdgcHCwzp49e/nrdrstthDy\n22+/1eOPP17j4+P10EMP9XoOy/Taa6/V0aNH67nnnqtz5871ek6U4K7APffcU19++WVVVZ06dao2\nbtzY40Vwbfj999/rySefrL1799b27dt7PYdlOHLkSL3zzjtVVXXjjTfWwMBArVlzbSXIY9kKjI2N\n1VdffVU7duyoTqdT+/fv7/UkuCYcOnSo/vjjj5qdna3Z2dmquvRDcDfccEOPl/Hf3HfffTU9PV2P\nPfZYXbx4sZ599tlr7vvmfwsCgIBr63keAHpEcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEg4D+A\nZKWtiwAl/gAAAABJRU5ErkJggg==\n", 211 | "text/plain": [ 212 | "" 213 | ] 214 | }, 215 | "metadata": 
{}, 216 | "output_type": "display_data" 217 | } 218 | ], 219 | "source": [ 220 | "### RED = TERMINAL (0)\n", 221 | "### GREEN = LEFT\n", 222 | "### BLUE = UP\n", 223 | "### PURPLE = RIGHT\n", 224 | "### ORANGE = DOWN\n", 225 | "\n", 226 | "show_policy(pi, gw)" 227 | ] 228 | } 229 | ], 230 | "metadata": { 231 | "kernelspec": { 232 | "display_name": "Python 3", 233 | "language": "python", 234 | "name": "python3" 235 | }, 236 | "language_info": { 237 | "codemirror_mode": { 238 | "name": "ipython", 239 | "version": 3 240 | }, 241 | "file_extension": ".py", 242 | "mimetype": "text/x-python", 243 | "name": "python", 244 | "nbconvert_exporter": "python", 245 | "pygments_lexer": "ipython3", 246 | "version": "3.6.1" 247 | } 248 | }, 249 | "nbformat": 4, 250 | "nbformat_minor": 2 251 | } 252 | -------------------------------------------------------------------------------- /source code/4-n-step Bootstrapping (Chapter 7)/2-n-step-SARSA-on-policy-control.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# $n$-step SARSA\n", 8 | "- Algorithms from ```pp. 
119 - 120``` in Sutton & Barto 2017" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "metadata": { 15 | "collapsed": true 16 | }, 17 | "outputs": [], 18 | "source": [ 19 | "import matplotlib.pyplot as plt\n", 20 | "import pandas as pd\n", 21 | "import numpy as np\n", 22 | "import seaborn, random\n", 23 | "\n", 24 | "from gridWorldEnvironment import GridWorld" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 3, 30 | "metadata": { 31 | "collapsed": true 32 | }, 33 | "outputs": [], 34 | "source": [ 35 | "# creating gridworld environment\n", 36 | "gw = GridWorld(gamma = .9)" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 59, 42 | "metadata": { 43 | "collapsed": true 44 | }, 45 | "outputs": [], 46 | "source": [ 47 | "def state_action_value(env):\n", 48 | " q = dict()\n", 49 | " for state, action, next_state, reward in env.transitions:\n", 50 | " q[(state, action)] = np.random.normal()\n", 51 | " return q" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 60, 57 | "metadata": { 58 | "scrolled": true 59 | }, 60 | "outputs": [ 61 | { 62 | "data": { 63 | "text/plain": [ 64 | "{(1, 'D'): -0.7030356916846013,\n", 65 | " (1, 'L'): -0.324895205360978,\n", 66 | " (1, 'R'): -1.1734677062639605,\n", 67 | " (1, 'U'): -0.16629963855724197,\n", 68 | " (2, 'D'): -0.23584377637184736,\n", 69 | " (2, 'L'): 1.4901073639401272,\n", 70 | " (2, 'R'): 0.6934460377797941,\n", 71 | " (2, 'U'): 0.8636095279329232,\n", 72 | " (3, 'D'): -1.609655563626268,\n", 73 | " (3, 'L'): 0.9311445594943784,\n", 74 | " (3, 'R'): 0.8985239654154085,\n", 75 | " (3, 'U'): -0.4080380875477237,\n", 76 | " (4, 'D'): -0.17538315728764214,\n", 77 | " (4, 'L'): -1.2559883445535147,\n", 78 | " (4, 'R'): -0.8940561790577116,\n", 79 | " (4, 'U'): 0.21171585977065896,\n", 80 | " (5, 'D'): 0.22188493818420746,\n", 81 | " (5, 'L'): 0.2624753401534311,\n", 82 | " (5, 'R'): -0.47795119564774396,\n", 83 | " (5, 'U'): 
1.4805795529303916,\n", 84 | " (6, 'D'): -0.9419847476859503,\n", 85 | " (6, 'L'): -1.2209175101970824,\n", 86 | " (6, 'R'): 0.39723157885321236,\n", 87 | " (6, 'U'): 0.9834147564834003,\n", 88 | " (7, 'D'): 0.6846683809603875,\n", 89 | " (7, 'L'): -3.065126625005005,\n", 90 | " (7, 'R'): -0.9160309917151935,\n", 91 | " (7, 'U'): 0.6802182059638696,\n", 92 | " (8, 'D'): -2.079437344850202,\n", 93 | " (8, 'L'): -0.7871927916410061,\n", 94 | " (8, 'R'): 0.05422301115908965,\n", 95 | " (8, 'U'): -0.3719669376829563,\n", 96 | " (9, 'D'): 0.7168977139674119,\n", 97 | " (9, 'L'): -0.9203614384252929,\n", 98 | " (9, 'R'): -0.02171286359518462,\n", 99 | " (9, 'U'): 0.7715779552124056,\n", 100 | " (10, 'D'): -0.3829133127865197,\n", 101 | " (10, 'L'): 0.37456870550391924,\n", 102 | " (10, 'R'): 1.9331952589125025,\n", 103 | " (10, 'U'): -0.8416217832093862,\n", 104 | " (11, 'D'): -0.6215265359549603,\n", 105 | " (11, 'L'): 0.2577485614406314,\n", 106 | " (11, 'R'): 1.6237960742580893,\n", 107 | " (11, 'U'): -2.215032302441923,\n", 108 | " (12, 'D'): -0.12021109028974186,\n", 109 | " (12, 'L'): 0.1799993340884875,\n", 110 | " (12, 'R'): -0.19649725248693356,\n", 111 | " (12, 'U'): -0.873336849248908,\n", 112 | " (13, 'D'): 0.3872626481962563,\n", 113 | " (13, 'L'): 1.477848159613166,\n", 114 | " (13, 'R'): -0.5006331850401944,\n", 115 | " (13, 'U'): 0.42855578341556333,\n", 116 | " (14, 'D'): 0.5288272721071394,\n", 117 | " (14, 'L'): 1.0695144596153465,\n", 118 | " (14, 'R'): 2.0397741046686506,\n", 119 | " (14, 'U'): 0.6012823125125923}" 120 | ] 121 | }, 122 | "execution_count": 60, 123 | "metadata": {}, 124 | "output_type": "execute_result" 125 | } 126 | ], 127 | "source": [ 128 | "state_action_value(gw)" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": 5, 134 | "metadata": { 135 | "collapsed": true 136 | }, 137 | "outputs": [], 138 | "source": [ 139 | "def generate_greedy_policy(env, Q):\n", 140 | " pi = dict()\n", 141 | " for state in 
env.states:\n", 142 | " actions = []\n", 143 | " q_values = []\n", 144 | " prob = []\n", 145 | " \n", 146 | " for a in env.actions:\n", 147 | " actions.append(a)\n", 148 | " q_values.append(Q[state,a]) \n", 149 | " for i in range(len(q_values)):\n", 150 | " if i == np.argmax(q_values):\n", 151 | " prob.append(1)\n", 152 | " else:\n", 153 | " prob.append(0) \n", 154 | " \n", 155 | " pi[state] = (actions, prob)\n", 156 | " return pi" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": 6, 162 | "metadata": { 163 | "collapsed": true 164 | }, 165 | "outputs": [], 166 | "source": [ 167 | "def e_greedy(env, e, q, state):\n", 168 | " actions = env.actions\n", 169 | " action_values = []\n", 170 | " prob = []\n", 171 | " for action in actions:\n", 172 | " action_values.append(q[(state, action)])\n", 173 | " for i in range(len(action_values)):\n", 174 | " if i == np.argmax(action_values):\n", 175 | " prob.append(1 - e + e/len(action_values))\n", 176 | " else:\n", 177 | " prob.append(e/len(action_values))\n", 178 | " return actions, prob" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": 7, 184 | "metadata": { 185 | "collapsed": true 186 | }, 187 | "outputs": [], 188 | "source": [ 189 | "# e-greedy policy is an extension of e_greedy()\n", 190 | "def generate_e_greedy_policy(env, e, Q):\n", 191 | " pi = dict()\n", 192 | " for state in env.states:\n", 193 | " pi[state] = e_greedy(env, e, Q, state)\n", 194 | " return pi" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": 8, 200 | "metadata": {}, 201 | "outputs": [ 202 | { 203 | "data": { 204 | "text/plain": [ 205 | "{1: (('U', 'D', 'L', 'R'), [0.025, 0.025, 0.925, 0.025]),\n", 206 | " 2: (('U', 'D', 'L', 'R'), [0.025, 0.025, 0.925, 0.025]),\n", 207 | " 3: (('U', 'D', 'L', 'R'), [0.025, 0.025, 0.925, 0.025]),\n", 208 | " 4: (('U', 'D', 'L', 'R'), [0.025, 0.025, 0.025, 0.925]),\n", 209 | " 5: (('U', 'D', 'L', 'R'), [0.025, 0.025, 0.925, 0.025]),\n", 210 | 
" 6: (('U', 'D', 'L', 'R'), [0.025, 0.925, 0.025, 0.025]),\n", 211 | " 7: (('U', 'D', 'L', 'R'), [0.025, 0.025, 0.925, 0.025]),\n", 212 | " 8: (('U', 'D', 'L', 'R'), [0.025, 0.925, 0.025, 0.025]),\n", 213 | " 9: (('U', 'D', 'L', 'R'), [0.025, 0.925, 0.025, 0.025]),\n", 214 | " 10: (('U', 'D', 'L', 'R'), [0.025, 0.025, 0.025, 0.925]),\n", 215 | " 11: (('U', 'D', 'L', 'R'), [0.925, 0.025, 0.025, 0.025]),\n", 216 | " 12: (('U', 'D', 'L', 'R'), [0.025, 0.925, 0.025, 0.025]),\n", 217 | " 13: (('U', 'D', 'L', 'R'), [0.025, 0.025, 0.925, 0.025]),\n", 218 | " 14: (('U', 'D', 'L', 'R'), [0.025, 0.025, 0.025, 0.925])}" 219 | ] 220 | }, 221 | "execution_count": 8, 222 | "metadata": {}, 223 | "output_type": "execute_result" 224 | } 225 | ], 226 | "source": [ 227 | "generate_e_greedy_policy(gw, 0.1, state_action_value(gw))" 228 | ] 229 | }, 230 | { 231 | "cell_type": "markdown", 232 | "metadata": { 233 | "collapsed": true 234 | }, 235 | "source": [ 236 | "### $n$-step SARSA: On-policy TD Control" 237 | ] 238 | }, 239 | { 240 | "cell_type": "code", 241 | "execution_count": 61, 242 | "metadata": { 243 | "collapsed": true 244 | }, 245 | "outputs": [], 246 | "source": [ 247 | "def n_step_sarsa(env, epsilon, alpha, n, num_iter, learn_pi = True):\n", 248 | " Q = state_action_value(env)\n", 249 | " pi = generate_e_greedy_policy(env, epsilon, Q)\n", 250 | " for _ in range(num_iter):\n", 251 | " current_state = np.random.choice(env.states)\n", 252 | " action = np.random.choice(pi[current_state][0], p = pi[current_state][1])\n", 253 | " state_trace, action_trace, reward_trace = [current_state], [action], [0]\n", 254 | " t, T = 0, 10000\n", 255 | " while True:\n", 256 | " if t < T: \n", 257 | " next_state, reward = env.state_transition(current_state, action)\n", 258 | " state_trace.append(next_state)\n", 259 | " reward_trace.append(reward)\n", 260 | " if next_state == 0:\n", 261 | " T = t + 1\n", 262 | " else: \n", 263 | " action = np.random.choice(pi[next_state][0], p = 
pi[next_state][1])\n", 264 | " action_trace.append(action)\n", 265 | " \n", 266 | " tau = t - n + 1 # tau is the time step whose estimate is being updated\n", 267 | " if tau >= 0:\n", 268 | " G = 0\n", 269 | " for i in range(tau + 1, min([tau + n, T]) + 1):\n", 270 | " G += (env.gamma ** (i - tau - 1)) * reward_trace[i] # reward_trace[i] holds R_i (index 0 is a placeholder)\n", 271 | " if tau + n < T:\n", 272 | " G += (env.gamma ** n) * Q[state_trace[tau + n], action_trace[tau + n]]\n", 273 | " Q[state_trace[tau], action_trace[tau]] += alpha * (G - Q[state_trace[tau], action_trace[tau]])\n", 274 | " \n", 275 | " # the current policy, pi, can optionally be improved at every step\n", 276 | " if learn_pi:\n", 277 | " pi[state_trace[tau]] = e_greedy(env, epsilon, Q, state_trace[tau])\n", 278 | " current_state = next_state \n", 279 | " \n", 280 | " if tau == (T-1):\n", 281 | " break\n", 282 | " t += 1\n", 283 | " \n", 284 | " return pi, Q" 285 | ] 286 | }, 287 | { 288 | "cell_type": "code", 289 | "execution_count": 62, 290 | "metadata": { 291 | "collapsed": true 292 | }, 293 | "outputs": [], 294 | "source": [ 295 | "pi, Q = n_step_sarsa(gw, 0.2, 0.5, 3, 1000)" 296 | ] 297 | }, 298 | { 299 | "cell_type": "code", 300 | "execution_count": 63, 301 | "metadata": {}, 302 | "outputs": [ 303 | { 304 | "data": { 305 | "text/plain": [ 306 | "{1: (('U', 'D', 'L', 'R'), [0.05, 0.05, 0.8500000000000001, 0.05]),\n", 307 | " 2: (('U', 'D', 'L', 'R'), [0.05, 0.05, 0.8500000000000001, 0.05]),\n", 308 | " 3: (('U', 'D', 'L', 'R'), [0.05, 0.8500000000000001, 0.05, 0.05]),\n", 309 | " 4: (('U', 'D', 'L', 'R'), [0.8500000000000001, 0.05, 0.05, 0.05]),\n", 310 | " 5: (('U', 'D', 'L', 'R'), [0.8500000000000001, 0.05, 0.05, 0.05]),\n", 311 | " 6: (('U', 'D', 'L', 'R'), [0.05, 0.8500000000000001, 0.05, 0.05]),\n", 312 | " 7: (('U', 'D', 'L', 'R'), [0.05, 0.8500000000000001, 0.05, 0.05]),\n", 313 | " 8: (('U', 'D', 'L', 'R'), [0.8500000000000001, 0.05, 0.05, 0.05]),\n", 314 | " 9: (('U', 'D', 'L', 'R'), [0.05, 0.8500000000000001, 0.05, 0.05]),\n", 315 | " 
10: (('U', 'D', 'L', 'R'), [0.05, 0.05, 0.05, 0.8500000000000001]),\n", 316 | " 11: (('U', 'D', 'L', 'R'), [0.05, 0.8500000000000001, 0.05, 0.05]),\n", 317 | " 12: (('U', 'D', 'L', 'R'), [0.05, 0.05, 0.05, 0.8500000000000001]),\n", 318 | " 13: (('U', 'D', 'L', 'R'), [0.05, 0.05, 0.05, 0.8500000000000001]),\n", 319 | " 14: (('U', 'D', 'L', 'R'), [0.05, 0.05, 0.05, 0.8500000000000001])}" 320 | ] 321 | }, 322 | "execution_count": 63, 323 | "metadata": {}, 324 | "output_type": "execute_result" 325 | } 326 | ], 327 | "source": [ 328 | "pi" 329 | ] 330 | }, 331 | { 332 | "cell_type": "markdown", 333 | "metadata": {}, 334 | "source": [ 335 | "### Visualizing policy" 336 | ] 337 | }, 338 | { 339 | "cell_type": "code", 340 | "execution_count": 65, 341 | "metadata": { 342 | "collapsed": true 343 | }, 344 | "outputs": [], 345 | "source": [ 346 | "def show_policy(pi, env):\n", 347 | " temp = np.zeros(len(env.states) + 2)\n", 348 | " for s in env.states:\n", 349 | " a = pi[s][0][np.argmax(pi[s][1])]\n", 350 | " if a == \"U\":\n", 351 | " temp[s] = 0.25\n", 352 | " elif a == \"D\":\n", 353 | " temp[s] = 0.5\n", 354 | " elif a == \"R\":\n", 355 | " temp[s] = 0.75\n", 356 | " else:\n", 357 | " temp[s] = 1.0\n", 358 | " \n", 359 | " temp = temp.reshape(4,4)\n", 360 | " ax = seaborn.heatmap(temp, cmap = \"prism\", linecolor=\"#282828\", cbar = False, linewidths = 0.1)\n", 361 | " plt.show()" 362 | ] 363 | }, 364 | { 365 | "cell_type": "code", 366 | "execution_count": 66, 367 | "metadata": {}, 368 | "outputs": [ 369 | { 370 | "data": { 371 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAdwAAAFJCAYAAAAxCJwFAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAACsFJREFUeJzt3U2IlQXfx/H/OCcqml4gISQI3Dhz4AGlRTsXA04vQgyG\nkBZTELNq4QgmcRp7oRe1VdDLYOWqRfe0SloEkikFBS0EBeFMEUH0tmkVKjnaOfciHh+60bmbeTi/\n0xw/n91cF17nhxfD1+tiBoe63W63AICeWtPvAQBwLRBcAAgQXAAIEFwACBBcAAgQXAAIaPTy4s1m\ns9oLC738CHqkOTZW/9N271arM033b7U60xyrqqr2h+7fatTcNlbtdvuK5zzhAkCA4AJAgOACQIDg\nAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOAC\nQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJA\ngOACQIDgAkDA3w5up9Pp5Q4AGGiNpU7+8MMPdeDAgTpz5kw1Go3qdDq1YcOGarVatX79+tRGAFj1\nlgzu7Oxs7dmzpzZu3Hj52KlTp6rVatX8/HzPxwHAoFjylfLi4uJfYltVtWnTpp4OAoBBtOQT7ujo\naLVardq8eXPdfPPNde7cufrss89qdHQ0tQ8ABsKSwX3hhRfq2LFjdfLkyTp79myNjIzU+Ph4TUxM\npPYBwEBYMrhDQ0M1MTEhsADw/+T3cAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBA\ncAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBw\nASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIGCo2+12e3XxZrPZq0sDwD9Su92+4vFG\nrz94YePhXn8EPTB2etq9W8XGTk9X+8OFfs9gBZrbxqqq3L9V6n/v35V4pQwAAYILAAGCCwABggsA\nAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwAB\nggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGCCwABggsAAYILAAGC\nCwABggsAAYILAAGCCwABjaVOTk1N1cWLF/9yrNvt1tDQUM3Pz/d0GAAMkiWD+9RTT9W+ffvqrbfe\nquHh4dQmABg4SwZ348aNNTk5WV9//XVNTEykNgHAwFkyuFVV09PTiR0AMND80BQABAguAAQILgAE\nCC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQI\nLgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAgu\nAAQILgAEDHW73W6vLt5sNnt1aQD4R2q321c83uj1By9sPNzrj6AHxk5PV/vDhX7PYIWa28ZqduFf\n/Z7BCrwytrOqyvffKtXcNnbVc14pA0CA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDg\nAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOAC\nQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQMCyg7u4uNiLHQAw0K4a3OPHj9f4\n+HhNTEzUxx9/fPn49PR0ZBgADJLG1U4cOnSojhw5Up1Op2ZmZurChQu1bdu26na7yX0AMBCuGtzr\nrruubr311qqqmpubq8cff7zWrVtXQ0NDsXEAMCiu+kr5zjvvr
AMHDtT58+drZGSk3nzzzXrxxRfr\nu+++S+4DgIFw1eDu37+/RkdHLz/Rrlu3rt5777164IEHYuMAYFBc9ZVyo9Gohx566C/H1q5dW7Oz\nsz0fBQCDxu/hAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDg\nAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQIDgAkCA4AJAgOAC\nQIDgAkCA4AJAgOACQIDgAkCA4AJAgOACQMBQt9vt9urizWazV5cGgH+kdrt9xeONXn/w7MK/ev0R\n9MArYzvdu1XM/Vu9XhnbWVVV7YWFPi9hJZpjY1c955UyAAQILgAECC4ABAguAAQILgAECC4ABAgu\nAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4A\nBAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABAguAAQILgAECC4ABCwruL///nst\nLi72agsADKwlg/vtt9/Wk08+Wa1Wq7788svaunVrbd26tU6cOJHaBwADobHUyeeff75mZmbqp59+\nql27dtXRo0fr+uuvr+np6RofH09tBIBVb8ngdjqduueee6qq6quvvqrbb7/9zz/UWPKPAQD/YclX\nyuvXr6/Z2dnqdDp18ODBqqp65513au3atZFxADAolnxUffnll+v48eO1Zs3/dfmOO+6oqampng8D\ngEGyZHDXrFlTW7Zs+cuxycnJng4CgEHk93ABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBA\ncAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBw\nASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBAcAEgQHABIEBwASBgqNvtdnt18Waz2atLA8A/Urvd\nvuLxngYXAPiTV8oAECC4ABAguAAQILgAECC4ABAguAAQILgr0Ol06rnnnquHH364pqam6vvvv+/3\nJJbp9OnTNTU11e8ZLNPFixdr79699cgjj9T27dvr008/7fck/qY//vijWq1W7dixo3bu3FnffPNN\nvyfFCe4KHDt2rBYXF+uDDz6oPXv21MGDB/s9iWV49913a9++fXXhwoV+T2GZPvroo7rtttvq/fff\nr8OHD9dLL73U70n8TSdOnKiqqvn5+dq9e3e99tprfV6UJ7grcPLkydq8eXNVVW3atKnOnDnT50Us\nx1133VVvvPFGv2ewAvfff3/NzMxUVVW3263h4eE+L+Lv2rJly+V/IP388891yy239HlRXqPfA1aj\ns2fP1sjIyOWvh4eH69KlS9Vo+OtcDe6777768ccf+z2DFbjpppuq6s/vwV27dtXu3bv7vIjlaDQa\n9fTTT9cnn3xSr7/+er/nxHnCXYGRkZE6d+7c5a87nY7YQsgvv/xSjz32WE1OTtaDDz7Y7zks06uv\nvlpHjx6tZ599ts6fP9/vOVGCuwJ33313ff7551VVderUqdqwYUOfF8G14ddff60nnnii9u7dW9u3\nb+/3HJbhyJEj9fbbb1dV1Y033lhDQ0O1Zs21lSCPZSswMTFRX3zxRe3YsaO63W7t37+/35PgmnDo\n0KH67bffam5urubm5qrqzx+Cu+GGG/q8jP/m3nvvrVarVY8++mhdunSpnnnmmWvuvvnfggAg4Np6\nngeAPhFcAAgQXAAIEFwACBBcAAgQXAAIEFwACBBcAAj4N9pmpa1W3XhOAAAAAElFTkSuQmCC\n", 372 | "text/plain": [ 373 | "" 374 | ] 375 | }, 376 | "metadata": {}, 377 | "output_type": 
"display_data" 378 | } 379 | ], 380 | "source": [ 381 | "### RED = TERMINAL (0)\n", 382 | "### GREEN = LEFT\n", 383 | "### BLUE = UP\n", 384 | "### PURPLE = RIGHT\n", 385 | "### ORANGE = DOWN\n", 386 | "\n", 387 | "show_policy(pi, gw)" 388 | ] 389 | } 390 | ], 391 | "metadata": { 392 | "kernelspec": { 393 | "display_name": "Python 3", 394 | "language": "python", 395 | "name": "python3" 396 | }, 397 | "language_info": { 398 | "codemirror_mode": { 399 | "name": "ipython", 400 | "version": 3 401 | }, 402 | "file_extension": ".py", 403 | "mimetype": "text/x-python", 404 | "name": "python", 405 | "nbconvert_exporter": "python", 406 | "pygments_lexer": "ipython3", 407 | "version": "3.6.1" 408 | } 409 | }, 410 | "nbformat": 4, 411 | "nbformat_minor": 2 412 | } 413 | -------------------------------------------------------------------------------- /source code/4-n-step Bootstrapping (Chapter 7)/4-n-step-off-policy-learning-wo-importance-sampling.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# $n$-step off-policy learning without importance sampling\n", 8 | "- Algorithms from ```pp. 
124 - 125``` in Sutton & Barto 2017\n", 9 | " - $n$-step Tree Backup Algorithm" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 1, 15 | "metadata": { 16 | "collapsed": true 17 | }, 18 | "outputs": [], 19 | "source": [ 20 | "import matplotlib.pyplot as plt\n", 21 | "import pandas as pd\n", 22 | "import numpy as np\n", 23 | "import seaborn, random\n", 24 | "\n", 25 | "from gridWorldEnvironment import GridWorld" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 2, 31 | "metadata": { 32 | "collapsed": true 33 | }, 34 | "outputs": [], 35 | "source": [ 36 | "# creating gridworld environment\n", 37 | "gw = GridWorld(gamma = .9)" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 7, 43 | "metadata": { 44 | "collapsed": true 45 | }, 46 | "outputs": [], 47 | "source": [ 48 | "def state_action_value(env):\n", 49 | " q = dict()\n", 50 | " for state, action, next_state, reward in env.transitions:\n", 51 | " q[(state, action)] = np.random.normal()\n", 52 | " return q" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 9, 58 | "metadata": { 59 | "collapsed": true 60 | }, 61 | "outputs": [], 62 | "source": [ 63 | "def e_greedy(env, e, q, state):\n", 64 | " actions = env.actions\n", 65 | " action_values = []\n", 66 | " prob = []\n", 67 | " for action in actions:\n", 68 | " action_values.append(q[(state, action)])\n", 69 | " for i in range(len(action_values)):\n", 70 | " if i == np.argmax(action_values):\n", 71 | " prob.append(1 - e + e/len(action_values))\n", 72 | " else:\n", 73 | " prob.append(e/len(action_values))\n", 74 | " return actions, prob" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 10, 80 | "metadata": { 81 | "collapsed": true 82 | }, 83 | "outputs": [], 84 | "source": [ 85 | "def generate_e_greedy_policy(env, e, Q):\n", 86 | " pi = dict()\n", 87 | " for state in env.states:\n", 88 | " pi[state] = e_greedy(env, e, Q, state)\n", 89 | " return pi" 90 | ] 91 | 
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# uniform random policy: every action with probability 1/|A|\n",
    "def generate_random_policy(env):\n",
    "    pi = dict()\n",
    "    for state in env.states:\n",
    "        actions = list(env.actions)\n",
    "        prob = [1 / len(actions)] * len(actions)\n",
    "        pi[state] = (actions, prob)\n",
    "    return pi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# function for tree backup algorithm:\n",
    "# expected action value sum_a pi(a|state) * Q(state, a)\n",
    "def avg_over_actions(pi, Q, state):\n",
    "    actions, probs = pi[state]\n",
    "    q_values = np.array([Q[state, a] for a in actions])\n",
    "    return np.dot(q_values, probs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### $n$-step off-policy learning without importance sampling\n",
    "- The target also includes the estimated values of the dangling action nodes hanging off the sides, at all levels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def n_step_tree_backup(env, epsilon, alpha, n, num_iter, learn_pi = True):\n",
    "    Q = state_action_value(env)\n",
    "    pi = generate_e_greedy_policy(env, epsilon, Q)\n",
    "    # behaviour policy for the first action of each episode; `b` was undefined\n",
    "    # in the original cell, so a uniform random policy is assumed here\n",
    "    b = generate_random_policy(env)\n",
    "\n",
    "    for _ in range(num_iter):\n",
    "        # per-episode storage: Q_[t] = Q(S_t, A_t), pi_[t] = pi(A_t|S_t), delta[t] = TD error\n",
    "        Q_, pi_, delta = dict(), dict(), dict()\n",
    "        current_state = np.random.choice(env.states)\n",
    "        action = np.random.choice(b[current_state][0], p = b[current_state][1])\n",
    "        state_trace, action_trace, reward_trace = [current_state], [action], [0]\n",
    "        Q_[0] = Q[current_state, action]\n",
    "        t, T = 0, np.inf\n",
    "        while True:\n",
    "            if t < T:\n",
    "                next_state, reward = env.state_transition(current_state, action)\n",
    "                state_trace.append(next_state)\n",
    "                reward_trace.append(reward)\n",
    "                if next_state == 0:\n",
    "                    T = t + 1\n",
    "                    delta[t] = reward - Q_[t]\n",
    "                else:\n",
    "                    delta[t] = reward + env.gamma * avg_over_actions(pi, Q, next_state) - Q_[t]\n",
    "                    action = np.random.choice(pi[next_state][0], p = pi[next_state][1])\n",
    "                    action_trace.append(action)\n",
    "                    Q_[t+1] = Q[next_state, action]\n",
    "                    pi_[t+1] = pi[next_state][1][pi[next_state][0].index(action)]\n",
    "\n",
    "            tau = t - n + 1\n",
    "            if tau >= 0:\n",
    "                Z = 1\n",
    "                G = Q_[tau]\n",
    "                # i runs from tau up to and including min(tau + n - 1, T - 1)\n",
    "                for i in range(tau, min(tau + n - 1, T - 1) + 1):\n",
    "                    G += Z * delta[i]\n",
    "                    if i + 1 < T:  # pi_[T] is never stored\n",
    "                        Z *= env.gamma * pi_[i+1]\n",
    "                Q[state_trace[tau], action_trace[tau]] += alpha * (G - Q[state_trace[tau], action_trace[tau]])\n",
    "                if learn_pi:\n",
    "                    pi[state_trace[tau]] = e_greedy(env, epsilon, Q, state_trace[tau])\n",
    "            current_state = next_state\n",
    "            # print(state_trace, action_trace, reward_trace)\n",
    "\n",
    "            if tau == (T-1):\n",
    "                break\n",
    "            t += 1\n",
    "\n",
    "    return pi, Q"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "pi, Q = n_step_tree_backup(gw, 0.2, 0.5, 1, 10000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "[base64-encoded PNG omitted: heatmap of the learned policy drawn by show_policy(pi, gw)]",
      "text/plain": [
       ""
      ]
     },
     "metadata": 
{},
     "output_type": "display_data"
    }
   ],
   "source": [
    "### RED = TERMINAL (0)\n",
    "### GREEN = LEFT\n",
    "### BLUE = UP\n",
    "### PURPLE = RIGHT\n",
    "### ORANGE = DOWN\n",
    "\n",
    "show_policy(pi, gw)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}

--------------------------------------------------------------------------------
/source code/4-n-step Bootstrapping (Chapter 7)/__pycache__/gridWorldEnvironment.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buomsoo-kim/Tabular-RL-with-Python/0157c05126821524dc9d744613b9aa8dd5d47232/source code/4-n-step Bootstrapping (Chapter 7)/__pycache__/gridWorldEnvironment.cpython-36.pyc

--------------------------------------------------------------------------------
/source code/4-n-step Bootstrapping (Chapter 7)/gridWorldEnvironment.py:
--------------------------------------------------------------------------------
import numpy as np
import pandas as pd
import seaborn
from matplotlib.colors import ListedColormap

class GridWorld:
    def __init__(self, gamma = 1.0, theta = 0.5):
        self.actions = ("U", "D", "L", "R")
        self.states = np.arange(1, 15)
        self.transitions = pd.read_csv("gridworld.txt", header = None, sep = "\t").values
        self.gamma = gamma
        self.theta = theta

    def state_transition(self, state, action):
        # linear scan of the (state, action, next_state, reward) rows;
        # (None, None) if no row matches
        for tr in self.transitions:
            if tr[0] == state and tr[1] == action:
                return tr[2], tr[3]
        return None, None

    def show_environment(self):
        all_states = np.concatenate(([0], self.states, [0])).reshape(4,4)
        colors = []
        # colors = ["#ffffff"]
        for i in range(len(self.states) + 1):
            if i == 0:
                colors.append("#c4c4c4")
            else:
                colors.append("#ffffff")

        cmap = ListedColormap(seaborn.color_palette(colors).as_hex())
        ax = seaborn.heatmap(all_states, cmap = cmap,
                             annot = True, linecolor = "#282828", linewidths = 0.2,
                             cbar = False)

--------------------------------------------------------------------------------
/source code/4-n-step Bootstrapping (Chapter 7)/gridworld.txt:
--------------------------------------------------------------------------------
1	U	1	-1
1	D	5	-1
1	R	2	-1
1	L	0	-1
2	U	2	-1
2	D	6	-1
2	R	3	-1
2	L	1	-1
3	U	3	-1
3	D	7	-1
3	R	3	-1
3	L	2	-1
4	U	0	-1
4	D	8	-1
4	R	5	-1
4	L	4	-1
5	U	1	-1
5	D	9	-1
5	R	6	-1
5	L	4	-1
6	U	2	-1
6	D	10	-1
6	R	7	-1
6	L	5	-1
7	U	3	-1
7	D	11	-1
7	R	7	-1
7	L	6	-1
8	U	4	-1
8	D	12	-1
8	R	9	-1
8	L	8	-1
9	U	5	-1
9	D	13	-1
9	R	10	-1
9	L	8	-1
10	U	6	-1
10	D	14	-1
10	R	11	-1
10	L	9	-1
11	U	7	-1
11	D	0	-1
11	R	11	-1
11	L	10	-1
12	U	8	-1
12	D	12	-1
12	R	13	-1
12	L	12	-1
13	U	9	-1
13	D	13	-1
13	R	14	-1
13	L	12	-1
14	U	10	-1
14	D	14	-1
14	R	0	-1
14	L	13	-1
--------------------------------------------------------------------------------
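In `n_step_tree_backup` the target is accumulated as $G = Q_\tau + \sum_k Z_k \delta_k$ with $Z_k = \prod_{i=\tau+1}^{k} \gamma\, \pi(A_i \mid S_i)$, rather than through the recursive tree-backup return $G_{t:t+n} = R_{t+1} + \gamma \sum_{a \neq A_{t+1}} \pi(a \mid S_{t+1}) Q(S_{t+1}, a) + \gamma\, \pi(A_{t+1} \mid S_{t+1})\, G_{t+1:t+n}$. The minimal, stdlib-only sketch below checks that the two agree when $Q$ is held fixed during the episode. The helper names (`tree_backup_target`, `tree_backup_recursive`), the list-based trajectory encoding and the episode numbers are hypothetical and exist only for illustration. Note that the sum over $k$ includes its upper bound $\min(\tau+n-1, T-1)$.

```python
# Sketch only: compares the delta/Z accumulation used in n_step_tree_backup with
# the recursive n-step tree-backup return, for a fixed Q. Helper names and the
# trajectory encoding below are hypothetical, not taken from the repository.
#   R[j]  = reward received on entering S_j               (j = 1..T)
#   V[j]  = sum_a pi(a|S_j) * Q(S_j, a), with V[T] = 0     (j = 1..T)
#   QA[j] = Q(S_j, A_j)                                    (j = 0..T-1)
#   P[j]  = pi(A_j | S_j)                                  (j = 1..T-1)


def tree_backup_target(tau, n, T, R, V, QA, P, gamma):
    """G = QA[tau] + sum_k Z_k * delta_k, k = tau..min(tau + n - 1, T - 1) inclusive."""
    G, Z = QA[tau], 1.0
    for k in range(tau, min(tau + n - 1, T - 1) + 1):
        delta = R[k + 1] + gamma * V[k + 1] - QA[k]   # V[T] = 0 covers termination
        G += Z * delta
        if k + 1 < T:                                  # P[T] does not exist
            Z *= gamma * P[k + 1]
    return G


def tree_backup_recursive(t, h, T, R, V, QA, P, gamma):
    """Recursive tree-backup return G_{t:h}, where h = min(t + n, T)."""
    if t + 1 == T:                 # S_{t+1} is terminal
        return R[T]
    if t + 1 == h:                 # last level: full expectation over actions at S_{t+1}
        return R[t + 1] + gamma * V[t + 1]
    # untaken actions enter through their Q values, the taken action through recursion
    untaken = V[t + 1] - P[t + 1] * QA[t + 1]
    return (R[t + 1] + gamma * untaken
            + gamma * P[t + 1] * tree_backup_recursive(t + 1, h, T, R, V, QA, P, gamma))


# Made-up four-step episode (T = 4), used only to compare the two forms.
T, GAMMA = 4, 0.9
R = [0.0, -1.0, 0.5, -2.0, 1.0]
V = [0.0, 0.3, -0.7, 1.2, 0.0]
QA = [0.2, -0.4, 0.9, 0.1]
P = [1.0, 0.6, 0.25, 0.8]

for tau in range(T):
    for n in range(1, 6):
        g_delta = tree_backup_target(tau, n, T, R, V, QA, P, GAMMA)
        g_rec = tree_backup_recursive(tau, min(tau + n, T), T, R, V, QA, P, GAMMA)
        assert abs(g_delta - g_rec) < 1e-12, (tau, n, g_delta, g_rec)

print(round(tree_backup_target(0, 1, T, R, V, QA, P, GAMMA), 6))   # -0.73
```

With $n = 1$ the target reduces to $R_{t+1} + \gamma \sum_a \pi(a \mid S_{t+1}) Q(S_{t+1}, a)$, the Expected Sarsa target, which is the case exercised by the notebook's call with `n = 1`.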
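`GridWorld.state_transition` scans all 56 rows of `gridworld.txt` on every call, and the tree-backup loop calls it once per time step. A minimal alternative, assuming the same four-column layout (state, action, next state, reward), is to index the rows by `(state, action)` once and then look transitions up with a dictionary access. `build_transition_table` and `lookup` are hypothetical helpers, not part of the repository; `lookup` keeps the `(None, None)` convention of the original method.

```python
# Sketch (hypothetical helpers): index gridworld transition rows by
# (state, action) so that each lookup is a dict access instead of a linear scan.
# Assumes gridworld.txt-style rows: state, action, next_state, reward.


def build_transition_table(lines):
    """Map (state, action) -> (next_state, reward) from whitespace-separated rows."""
    table = {}
    for line in lines:
        fields = line.split()
        if not fields:                      # skip blank lines
            continue
        state, action, next_state, reward = fields
        table[(int(state), action)] = (int(next_state), int(reward))
    return table


def lookup(table, state, action):
    """Same contract as GridWorld.state_transition: (None, None) if no row matches."""
    return table.get((state, action), (None, None))


# A few rows copied from gridworld.txt (tab-separated, as read by pd.read_csv).
SAMPLE_ROWS = ["1\tL\t0\t-1", "1\tR\t2\t-1", "14\tR\t0\t-1", "14\tU\t10\t-1"]

table = build_transition_table(SAMPLE_ROWS)
print(lookup(table, 1, "R"))   # (2, -1)
print(lookup(table, 0, "U"))   # (None, None): state 0 is terminal
```

Because NumPy integer scalars hash and compare equal to Python ints of the same value, states drawn with `np.random.choice(env.states)` would still find their keys. Whether building the table in `GridWorld.__init__` is worth it depends on how long episodes run.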
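The last notebook cell calls `show_policy(pi, gw)`, which is defined neither in this notebook nor in this chapter's `gridWorldEnvironment.py`. As a text-only stand-in, the sketch below places the greedy action of each state of a learned `pi` (the `(actions, probs)` pairs produced by `e_greedy`) on the same 4 × 4 layout that `show_environment` uses, with the terminal state 0 in the top-left and bottom-right cells. `greedy_action_grid` is a hypothetical helper. It does not reproduce the colours of the original plot, and the toy policy is made up.

```python
# Sketch (hypothetical helper, not the repository's show_policy): lay out the
# greedy action of each state on the 4x4 grid used by GridWorld.show_environment,
# where terminal state 0 occupies the top-left and bottom-right cells.


def greedy_action_grid(pi, n_states=14, terminal="T"):
    """Return a 4x4 list of greedy action labels from pi[state] = (actions, probs)."""
    cells = [terminal]
    for state in range(1, n_states + 1):
        actions, probs = pi[state]
        # e_greedy gives its greedy action the largest probability;
        # ties resolve to the first such action
        best = max(range(len(actions)), key=lambda i: probs[i])
        cells.append(actions[best])
    cells.append(terminal)
    return [cells[row * 4:(row + 1) * 4] for row in range(4)]


# Toy policy: states 1-7 prefer "L", states 8-14 prefer "D".
ACTIONS = ("U", "D", "L", "R")
toy_pi = {s: (ACTIONS, [0.05, 0.05, 0.85, 0.05] if s <= 7 else [0.05, 0.85, 0.05, 0.05])
          for s in range(1, 15)}

for row in greedy_action_grid(toy_pi):
    print(" ".join(row))
# T L L L
# L L L L
# D D D D
# D D D T
```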