├── .gitignore
├── Classic_Control_Introduction.ipynb
├── MDP_introduction.ipynb
├── Section_10_reinforce_CartPole.ipynb
├── Section_10_reinforce_CartPole_complete.ipynb
├── Section_11_advantage_actor_critic.ipynb
├── Section_11_advantage_actor_critic_complete.ipynb
├── Section_3_policy_iteration.ipynb
├── Section_3_policy_iteration_complete.ipynb
├── Section_3_value_iteration.ipynb
├── Section_3_value_iteration_complete.ipynb
├── Section_4_off_policy_control.ipynb
├── Section_4_off_policy_control_complete.ipynb
├── Section_4_on_policy_constant_alpha_mc.ipynb
├── Section_4_on_policy_constant_alpha_mc_complete.ipynb
├── Section_4_on_policy_control.ipynb
├── Section_4_on_policy_control_complete.ipynb
├── Section_5_qlearning.ipynb
├── Section_5_qlearning_complete.ipynb
├── Section_5_sarsa.ipynb
├── Section_5_sarsa_complete.ipynb
├── Section_6_n_step_sarsa.ipynb
├── Section_6_n_step_sarsa_complete.ipynb
├── Section_7_continuous_observation_spaces.ipynb
├── Section_7_continuous_observation_spaces_complete.ipynb
├── Section_8_deep_sarsa.ipynb
├── Section_8_deep_sarsa_complete.ipynb
├── Section_9_deep_q_learning.ipynb
└── Section_9_deep_q_learning_complete.ipynb

/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | .ipynb_checkpoints/
--------------------------------------------------------------------------------
/Classic_Control_Introduction.ipynb:
--------------------------------------------------------------------------------
1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "Urmh5ntuuEaL" 7 | }, 8 | "source": [ 9 | "# Classic Control: Control theory problems from the classic RL literature\n", 10 | "\n", 11 | "

\n", 12 | "\n", 13 | "In this notebook we will present some classic environments in Reinforcement Learning research. These environments have continuous states spaces (i.e., infinite possible states) and therefore tabular methods cannot solve them. To tackle these environments (and more complex ones) we will have two tools:\n", 14 | "\n", 15 | "- Extend the tabular methods with the techniques of discretization and tile coding\n", 16 | "- Use function approximators (Neural Networks)\n", 17 | "\n", 18 | "
" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "source": [ 24 | "# @title Setup code (not important) - Run this cell by pressing \"Shift + Enter\"\n", 25 | "\n", 26 | "\n", 27 | "!pip install -qq gym==0.23.0\n", 28 | "\n", 29 | "\n", 30 | "import matplotlib\n", 31 | "from matplotlib import animation\n", 32 | "from IPython.display import HTML\n", 33 | "\n", 34 | "\n", 35 | "def display_video(frames):\n", 36 | " # Copied from: https://colab.research.google.com/github/deepmind/dm_control/blob/master/tutorial.ipynb\n", 37 | " orig_backend = matplotlib.get_backend()\n", 38 | " matplotlib.use('Agg')\n", 39 | " fig, ax = plt.subplots(1, 1, figsize=(5, 5))\n", 40 | " matplotlib.use(orig_backend)\n", 41 | " ax.set_axis_off()\n", 42 | " ax.set_aspect('equal')\n", 43 | " ax.set_position([0, 0, 1, 1])\n", 44 | " im = ax.imshow(frames[0])\n", 45 | " def update(frame):\n", 46 | " im.set_data(frame)\n", 47 | " return [im]\n", 48 | " anim = animation.FuncAnimation(fig=fig, func=update, frames=frames,\n", 49 | " interval=50, blit=True, repeat=False)\n", 50 | " return HTML(anim.to_html5_video())\n", 51 | "\n", 52 | "\n", 53 | "def test_env(environment, episodes=10):\n", 54 | " frames = []\n", 55 | " for episode in range(episodes):\n", 56 | " state = environment.reset()\n", 57 | " done = False\n", 58 | " frames.append(environment.render(mode=\"rgb_array\"))\n", 59 | "\n", 60 | " while not done:\n", 61 | " action = environment.action_space.sample()\n", 62 | " next_state, reward, done, extra_info = environment.step(action)\n", 63 | " img = environment.render(mode=\"rgb_array\")\n", 64 | " frames.append(img)\n", 65 | " state = next_state\n", 66 | "\n", 67 | " return display_video(frames)\n", 68 | "\n" 69 | ], 70 | "metadata": { 71 | "colab": { 72 | "base_uri": "https://localhost:8080/" 73 | }, 74 | "cellView": "form", 75 | "id": "FX6nq6g7wAys", 76 | "outputId": "07f12cbc-90a9-4028-dd20-ce89c8f35b26" 77 | }, 78 | "execution_count": 1, 79 | "outputs": [ 80 | { 81 | "output_type": "stream", 82 | "name": "stdout", 83 | "text": [ 84 | "\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/624.4 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[91m━━━━━━━━━━━━━━━\u001b[0m\u001b[90m╺\u001b[0m\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m235.5/624.4 kB\u001b[0m \u001b[31m6.7 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r\u001b[2K \u001b[91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[90m╺\u001b[0m \u001b[32m614.4/624.4 kB\u001b[0m \u001b[31m9.9 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m624.4/624.4 kB\u001b[0m \u001b[31m8.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 85 | "\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", 86 | " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", 87 | " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", 88 | " Building wheel for gym (pyproject.toml) ... 
\u001b[?25l\u001b[?25hdone\n" 89 | ] 90 | } 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 2, 96 | "metadata": { 97 | "id": "m_fXRjU9uEaO" 98 | }, 99 | "outputs": [], 100 | "source": [ 101 | "import gym\n", 102 | "import numpy as np\n", 103 | "from IPython import display\n", 104 | "from matplotlib import pyplot as plt\n", 105 | "%matplotlib inline" 106 | ] 107 | }, 108 | { 109 | "cell_type": "markdown", 110 | "metadata": { 111 | "id": "M9vnZWVouEaP" 112 | }, 113 | "source": [ 114 | "## CartPole: Keep the tip of the pole straight." 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": 3, 120 | "metadata": { 121 | "id": "rEwvRFsauEaP" 122 | }, 123 | "outputs": [], 124 | "source": [ 125 | "env = gym.make('CartPole-v1')\n", 126 | "test_env(env, 1)\n", 127 | "env.close()" 128 | ] 129 | }, 130 | { 131 | "cell_type": "markdown", 132 | "metadata": { 133 | "id": "JS5l98EauEaQ" 134 | }, 135 | "source": [ 136 | "##### The state\n", 137 | "\n", 138 | "The states of the cartpole task will be represented by a vector of four real numbers:\n", 139 | "\n", 140 | " Num Observation Min Max\n", 141 | " 0 Cart Position -4.8 4.8\n", 142 | " 1 Cart Velocity -Inf Inf\n", 143 | " 2 Pole Angle -0.418 rad (-24 deg) 0.418 rad (24 deg)\n", 144 | " 3 Pole Angular Velocity -Inf Inf\n" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": 4, 150 | "metadata": { 151 | "id": "kC_5K14quEaQ", 152 | "colab": { 153 | "base_uri": "https://localhost:8080/" 154 | }, 155 | "outputId": "417e9ffb-50d7-4824-cdc0-664590b41d3d" 156 | }, 157 | "outputs": [ 158 | { 159 | "output_type": "execute_result", 160 | "data": { 161 | "text/plain": [ 162 | "Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32)" 163 | ] 164 | }, 165 | "metadata": {}, 166 | "execution_count": 4 167 | } 168 | ], 169 | "source": [ 170 | "env.observation_space" 171 | ] 172 | }, 173 | { 174 | "cell_type": "markdown", 175 | "metadata": { 176 | "id": "fwSR1g5GuEaQ" 177 | }, 178 | "source": [ 179 | "##### The actions available\n", 180 | "\n", 181 | "We can perform two actions in this environment:\n", 182 | "\n", 183 | " 0 Push cart to the left.\n", 184 | " 1 Push cart to the right.\n", 185 | "\n" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": 5, 191 | "metadata": { 192 | "id": "hphxRLb1uEaQ", 193 | "colab": { 194 | "base_uri": "https://localhost:8080/" 195 | }, 196 | "outputId": "09f47fb7-3a68-49c6-d637-0ee76b5b7286" 197 | }, 198 | "outputs": [ 199 | { 200 | "output_type": "execute_result", 201 | "data": { 202 | "text/plain": [ 203 | "Discrete(2)" 204 | ] 205 | }, 206 | "metadata": {}, 207 | "execution_count": 5 208 | } 209 | ], 210 | "source": [ 211 | "env.action_space" 212 | ] 213 | }, 214 | { 215 | "cell_type": "markdown", 216 | "metadata": { 217 | "id": "9NPvm9fxuEaQ" 218 | }, 219 | "source": [ 220 | "## Acrobot: Swing the bar up to a certain height." 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": 6, 226 | "metadata": { 227 | "id": "kUp2Pj45uEaQ" 228 | }, 229 | "outputs": [], 230 | "source": [ 231 | "env = gym.make('Acrobot-v1')\n", 232 | "test_env(env, 1)\n", 233 | "env.close()" 234 | ] 235 | }, 236 | { 237 | "cell_type": "markdown", 238 | "metadata": { 239 | "id": "QILrkNtAuEaQ" 240 | }, 241 | "source": [ 242 | "##### The state\n", 243 | "\n", 244 | "The states of the cartpole task will be represented by a vector of six real numbers. 
The first two are the cosine and sine of the first joint. The next two are the cosine and sine of the other joint. The last two are the angular velocities of each joint.\n", 245 | " \n", 246 | "$\\cos(\\theta_1), \\sin(\\theta_1), \\cos(\\theta_2), \\sin(\\theta_2), \\dot\\theta_1, \\dot\\theta_2$" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": 7, 252 | "metadata": { 253 | "id": "cx-K5Z_8uEaR", 254 | "colab": { 255 | "base_uri": "https://localhost:8080/" 256 | }, 257 | "outputId": "dcd649d1-4bd6-47e1-dd60-0ca43582b51f" 258 | }, 259 | "outputs": [ 260 | { 261 | "output_type": "execute_result", 262 | "data": { 263 | "text/plain": [ 264 | "Box([ -1. -1. -1. -1. -12.566371 -28.274334], [ 1. 1. 1. 1. 12.566371 28.274334], (6,), float32)" 265 | ] 266 | }, 267 | "metadata": {}, 268 | "execution_count": 7 269 | } 270 | ], 271 | "source": [ 272 | "env.observation_space" 273 | ] 274 | }, 275 | { 276 | "cell_type": "markdown", 277 | "metadata": { 278 | "id": "SfVu41CEuEaR" 279 | }, 280 | "source": [ 281 | "##### The actions available\n", 282 | "\n", 283 | "We can perform two actions in this environment:\n", 284 | "\n", 285 | " 0 Apply +1 torque on the joint between the links.\n", 286 | " 1 Apply -1 torque on the joint between the links." 287 | ] 288 | }, 289 | { 290 | "cell_type": "code", 291 | "execution_count": 8, 292 | "metadata": { 293 | "id": "UJOCoO9QuEaR", 294 | "colab": { 295 | "base_uri": "https://localhost:8080/" 296 | }, 297 | "outputId": "d3fed225-c8ca-459d-b14b-4b7dcbab3952" 298 | }, 299 | "outputs": [ 300 | { 301 | "output_type": "execute_result", 302 | "data": { 303 | "text/plain": [ 304 | "Discrete(3)" 305 | ] 306 | }, 307 | "metadata": {}, 308 | "execution_count": 8 309 | } 310 | ], 311 | "source": [ 312 | "env.action_space" 313 | ] 314 | }, 315 | { 316 | "cell_type": "markdown", 317 | "metadata": { 318 | "id": "MPNceYXHuEaR" 319 | }, 320 | "source": [ 321 | "## MountainCar: Reach the goal from the bottom of the valley." 322 | ] 323 | }, 324 | { 325 | "cell_type": "code", 326 | "execution_count": 9, 327 | "metadata": { 328 | "id": "UI00Fw4LuEaR" 329 | }, 330 | "outputs": [], 331 | "source": [ 332 | "env = gym.make('MountainCar-v0')\n", 333 | "test_env(env, 1)\n", 334 | "env.close()" 335 | ] 336 | }, 337 | { 338 | "cell_type": "markdown", 339 | "metadata": { 340 | "id": "vY6aE8TduEaR" 341 | }, 342 | "source": [ 343 | "##### The state\n", 344 | "\n", 345 | "The observation space consists of the car position $\\in [-1.2, 0.6]$ and car velocity $\\in [-0.07, 0.07]$" 346 | ] 347 | }, 348 | { 349 | "cell_type": "code", 350 | "execution_count": 10, 351 | "metadata": { 352 | "id": "7vOD0LvbuEaR", 353 | "colab": { 354 | "base_uri": "https://localhost:8080/" 355 | }, 356 | "outputId": "a6325253-06e1-496d-d9ee-d3790ed7b12a" 357 | }, 358 | "outputs": [ 359 | { 360 | "output_type": "execute_result", 361 | "data": { 362 | "text/plain": [ 363 | "Box([-1.2 -0.07], [0.6 0.07], (2,), float32)" 364 | ] 365 | }, 366 | "metadata": {}, 367 | "execution_count": 10 368 | } 369 | ], 370 | "source": [ 371 | "env.observation_space" 372 | ] 373 | }, 374 | { 375 | "cell_type": "markdown", 376 | "metadata": { 377 | "id": "fs5q1ZYZuEaR" 378 | }, 379 | "source": [ 380 | "##### The actions available\n", 381 | "\n", 382 | "\n", 383 | "The actions available three:\n", 384 | "\n", 385 | " 0 Accelerate to the left.\n", 386 | " 1 Don't accelerate.\n", 387 | " 2 Accelerate to the right." 
388 | ] 389 | }, 390 | { 391 | "cell_type": "code", 392 | "execution_count": 11, 393 | "metadata": { 394 | "id": "XrF5V1SIuEaR", 395 | "colab": { 396 | "base_uri": "https://localhost:8080/" 397 | }, 398 | "outputId": "62124d99-e05e-4edf-e7e0-c6a1352900ca" 399 | }, 400 | "outputs": [ 401 | { 402 | "output_type": "execute_result", 403 | "data": { 404 | "text/plain": [ 405 | "Discrete(3)" 406 | ] 407 | }, 408 | "metadata": {}, 409 | "execution_count": 11 410 | } 411 | ], 412 | "source": [ 413 | "env.action_space" 414 | ] 415 | }, 416 | { 417 | "cell_type": "markdown", 418 | "metadata": { 419 | "id": "vPmwqfEjuEaS" 420 | }, 421 | "source": [ 422 | "## Pendulum: swing it and keep it upright" 423 | ] 424 | }, 425 | { 426 | "cell_type": "code", 427 | "execution_count": 12, 428 | "metadata": { 429 | "id": "Q7csEunxuEaS" 430 | }, 431 | "outputs": [], 432 | "source": [ 433 | "env = gym.make('Pendulum-v1')\n", 434 | "test_env(env, 1)\n", 435 | "env.close()" 436 | ] 437 | }, 438 | { 439 | "cell_type": "markdown", 440 | "metadata": { 441 | "id": "UusKijA0uEaS" 442 | }, 443 | "source": [ 444 | "##### The state\n", 445 | "\n", 446 | "The state is represented by a vector of three values representing $\\cos(\\theta), \\sin(\\theta)$ and speed ($\\theta$ is the angle of the pendulum)." 447 | ] 448 | }, 449 | { 450 | "cell_type": "code", 451 | "execution_count": 13, 452 | "metadata": { 453 | "id": "zIQaIFAsuEaS", 454 | "colab": { 455 | "base_uri": "https://localhost:8080/" 456 | }, 457 | "outputId": "0ff28f6e-03fb-46a1-cab6-a617d9afe252" 458 | }, 459 | "outputs": [ 460 | { 461 | "output_type": "execute_result", 462 | "data": { 463 | "text/plain": [ 464 | "Box([-1. -1. -8.], [1. 1. 8.], (3,), float32)" 465 | ] 466 | }, 467 | "metadata": {}, 468 | "execution_count": 13 469 | } 470 | ], 471 | "source": [ 472 | "env.observation_space" 473 | ] 474 | }, 475 | { 476 | "cell_type": "markdown", 477 | "metadata": { 478 | "id": "QeYAfguwuEaS" 479 | }, 480 | "source": [ 481 | "##### The actions available\n", 482 | "\n", 483 | "The action is a real number in the interval $[-2, 2]$ that represents the torque applied on the pendulum." 
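Beyond printing them, the Box and Discrete space objects shown above can be queried programmatically. A small sketch (not part of the original notebook), using the same Pendulum-v1 environment as the cells above:

import gym

env = gym.make('Pendulum-v1')

# Bounds and shape of the continuous spaces (gym Box objects).
print(env.observation_space.low, env.observation_space.high)   # [-1. -1. -8.]  [1. 1. 8.]
print(env.action_space.low, env.action_space.high)             # [-2.]  [2.]
print(env.action_space.shape)                                   # (1,)

# Sample a random action, which is exactly what test_env() above does at every step.
obs = env.reset()
action = env.action_space.sample()               # random torque in [-2, 2]
next_obs, reward, done, info = env.step(action)  # gym 0.23 step() returns a 4-tuple
env.close()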
484 | ] 485 | }, 486 | { 487 | "cell_type": "code", 488 | "execution_count": 14, 489 | "metadata": { 490 | "id": "01d2orXsuEaS", 491 | "colab": { 492 | "base_uri": "https://localhost:8080/" 493 | }, 494 | "outputId": "dce5aced-a685-4739-9afc-1e9710afa9e8" 495 | }, 496 | "outputs": [ 497 | { 498 | "output_type": "execute_result", 499 | "data": { 500 | "text/plain": [ 501 | "Box(-2.0, 2.0, (1,), float32)" 502 | ] 503 | }, 504 | "metadata": {}, 505 | "execution_count": 14 506 | } 507 | ], 508 | "source": [ 509 | "env.action_space" 510 | ] 511 | }, 512 | { 513 | "cell_type": "markdown", 514 | "metadata": { 515 | "id": "Q0kA6q7auEaS" 516 | }, 517 | "source": [ 518 | "## Resources" 519 | ] 520 | }, 521 | { 522 | "cell_type": "markdown", 523 | "metadata": { 524 | "id": "7PnV4kmDuEaS" 525 | }, 526 | "source": [ 527 | "[[1] OpenAI gym: classic control environments](https://gym.openai.com/envs/#classic_control)" 528 | ] 529 | }, 530 | { 531 | "cell_type": "code", 532 | "source": [], 533 | "metadata": { 534 | "id": "-fnF-sXIwwsn" 535 | }, 536 | "execution_count": 14, 537 | "outputs": [] 538 | } 539 | ], 540 | "metadata": { 541 | "kernelspec": { 542 | "display_name": "Python 3 (ipykernel)", 543 | "language": "python", 544 | "name": "python3" 545 | }, 546 | "language_info": { 547 | "codemirror_mode": { 548 | "name": "ipython", 549 | "version": 3 550 | }, 551 | "file_extension": ".py", 552 | "mimetype": "text/x-python", 553 | "name": "python", 554 | "nbconvert_exporter": "python", 555 | "pygments_lexer": "ipython3", 556 | "version": "3.8.0" 557 | }, 558 | "colab": { 559 | "provenance": [] 560 | } 561 | }, 562 | "nbformat": 4, 563 | "nbformat_minor": 0 564 | } -------------------------------------------------------------------------------- /Section_11_advantage_actor_critic.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "67rhCAvE7vV4", 7 | "pycharm": { 8 | "name": "#%%\n" 9 | } 10 | }, 11 | "source": [ 12 | "
\n", 13 | "

\n", 14 | " Advantage Actor-Critic (A2C)\n", 15 | "

\n", 16 | "
\n", 17 | "\n", 18 | "

\n", 19 | "\n", 20 | "
\n", 21 | "In this notebook we are going to combine temporal difference learning (TD) with policy gradient methods. The resulting algorithm is called Advantage Actor-Critic (A2C) and uses a one-step estimate of the return to update the policy:\n", 22 | "
\n", 23 | "\n", 24 | "\\begin{equation}\n", 25 | "\\hat G_t = R_{t+1} + \\gamma v(S_{t+1}|w)\n", 26 | "\\end{equation}\n", 27 | "\n", 28 | "\n", 29 | "
" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "metadata": { 36 | "cellView": "form", 37 | "id": "_uUI8vNy703R" 38 | }, 39 | "outputs": [], 40 | "source": [ 41 | "# @title Setup code (not important) - Run this cell by pressing \"Shift + Enter\"\n", 42 | "\n", 43 | "\n", 44 | "\n", 45 | "!pip install -qq gym==0.23.0 numpy==1.26.4 seaborn==0.12\n", 46 | "\n", 47 | "\n", 48 | "from typing import Tuple, Dict, Optional, Iterable, Callable\n", 49 | "\n", 50 | "import numpy as np\n", 51 | "import seaborn as sns\n", 52 | "import matplotlib\n", 53 | "import torch\n", 54 | "from matplotlib import animation\n", 55 | "import matplotlib.patches as mpatches\n", 56 | "\n", 57 | "from IPython.display import HTML\n", 58 | "\n", 59 | "import gym\n", 60 | "from gym import spaces\n", 61 | "from gym.error import DependencyNotInstalled\n", 62 | "\n", 63 | "import pygame\n", 64 | "from pygame import gfxdraw\n", 65 | "\n", 66 | "\n", 67 | "class Maze(gym.Env):\n", 68 | "\n", 69 | " def __init__(self, exploring_starts: bool = False,\n", 70 | " shaped_rewards: bool = False, size: int = 5) -> None:\n", 71 | " super().__init__()\n", 72 | " self.exploring_starts = exploring_starts\n", 73 | " self.shaped_rewards = shaped_rewards\n", 74 | " self.state = (size - 1, size - 1)\n", 75 | " self.goal = (size - 1, size - 1)\n", 76 | " self.maze = self._create_maze(size=size)\n", 77 | " self.distances = self._compute_distances(self.goal, self.maze)\n", 78 | " self.action_space = spaces.Discrete(n=4)\n", 79 | " self.action_space.action_meanings = {0: 'UP', 1: 'RIGHT', 2: 'DOWN', 3: \"LEFT\"}\n", 80 | " self.observation_space = spaces.MultiDiscrete([size, size])\n", 81 | "\n", 82 | " self.screen = None\n", 83 | " self.agent_transform = None\n", 84 | "\n", 85 | " def step(self, action: int) -> Tuple[Tuple[int, int], float, bool, Dict]:\n", 86 | " reward = self.compute_reward(self.state, action)\n", 87 | " self.state = self._get_next_state(self.state, action)\n", 88 | " done = self.state == self.goal\n", 89 | " info = {}\n", 90 | " return self.state, reward, done, info\n", 91 | "\n", 92 | " def reset(self) -> Tuple[int, int]:\n", 93 | " if self.exploring_starts:\n", 94 | " while self.state == self.goal:\n", 95 | " self.state = tuple(self.observation_space.sample())\n", 96 | " else:\n", 97 | " self.state = (0, 0)\n", 98 | " return self.state\n", 99 | "\n", 100 | " def render(self, mode: str = 'human') -> Optional[np.ndarray]:\n", 101 | " assert mode in ['human', 'rgb_array']\n", 102 | "\n", 103 | " screen_size = 600\n", 104 | " scale = screen_size / 5\n", 105 | "\n", 106 | " if self.screen is None:\n", 107 | " pygame.init()\n", 108 | " self.screen = pygame.Surface((screen_size, screen_size))\n", 109 | "\n", 110 | " surf = pygame.Surface((screen_size, screen_size))\n", 111 | " surf.fill((22, 36, 71))\n", 112 | "\n", 113 | "\n", 114 | " for row in range(5):\n", 115 | " for col in range(5):\n", 116 | "\n", 117 | " state = (row, col)\n", 118 | " for next_state in [(row + 1, col), (row - 1, col), (row, col + 1), (row, col - 1)]:\n", 119 | " if next_state not in self.maze[state]:\n", 120 | "\n", 121 | " # Add the geometry of the edges and walls (i.e. 
the boundaries between\n", 122 | " # adjacent squares that are not connected).\n", 123 | " row_diff, col_diff = np.subtract(next_state, state)\n", 124 | " left = (col + (col_diff > 0)) * scale - 2 * (col_diff != 0)\n", 125 | " right = ((col + 1) - (col_diff < 0)) * scale + 2 * (col_diff != 0)\n", 126 | " top = (5 - (row + (row_diff > 0))) * scale - 2 * (row_diff != 0)\n", 127 | " bottom = (5 - ((row + 1) - (row_diff < 0))) * scale + 2 * (row_diff != 0)\n", 128 | "\n", 129 | " gfxdraw.filled_polygon(surf, [(left, bottom), (left, top), (right, top), (right, bottom)], (255, 255, 255))\n", 130 | "\n", 131 | " # Add the geometry of the goal square to the viewer.\n", 132 | " left, right, top, bottom = scale * 4 + 10, scale * 5 - 10, scale - 10, 10\n", 133 | " gfxdraw.filled_polygon(surf, [(left, bottom), (left, top), (right, top), (right, bottom)], (40, 199, 172))\n", 134 | "\n", 135 | " # Add the geometry of the agent to the viewer.\n", 136 | " agent_row = int(screen_size - scale * (self.state[0] + .5))\n", 137 | " agent_col = int(scale * (self.state[1] + .5))\n", 138 | " gfxdraw.filled_circle(surf, agent_col, agent_row, int(scale * .6 / 2), (228, 63, 90))\n", 139 | "\n", 140 | " surf = pygame.transform.flip(surf, False, True)\n", 141 | " self.screen.blit(surf, (0, 0))\n", 142 | "\n", 143 | " return np.transpose(\n", 144 | " np.array(pygame.surfarray.pixels3d(self.screen)), axes=(1, 0, 2)\n", 145 | " )\n", 146 | "\n", 147 | " def close(self) -> None:\n", 148 | " if self.screen is not None:\n", 149 | " pygame.display.quit()\n", 150 | " pygame.quit()\n", 151 | " self.screen = None\n", 152 | "\n", 153 | " def compute_reward(self, state: Tuple[int, int], action: int) -> float:\n", 154 | " next_state = self._get_next_state(state, action)\n", 155 | " if self.shaped_rewards:\n", 156 | " return - (self.distances[next_state] / self.distances.max())\n", 157 | " return - float(state != self.goal)\n", 158 | "\n", 159 | " def simulate_step(self, state: Tuple[int, int], action: int):\n", 160 | " reward = self.compute_reward(state, action)\n", 161 | " next_state = self._get_next_state(state, action)\n", 162 | " done = next_state == self.goal\n", 163 | " info = {}\n", 164 | " return next_state, reward, done, info\n", 165 | "\n", 166 | " def _get_next_state(self, state: Tuple[int, int], action: int) -> Tuple[int, int]:\n", 167 | " if action == 0:\n", 168 | " next_state = (state[0] - 1, state[1])\n", 169 | " elif action == 1:\n", 170 | " next_state = (state[0], state[1] + 1)\n", 171 | " elif action == 2:\n", 172 | " next_state = (state[0] + 1, state[1])\n", 173 | " elif action == 3:\n", 174 | " next_state = (state[0], state[1] - 1)\n", 175 | " else:\n", 176 | " raise ValueError(\"Action value not supported:\", action)\n", 177 | " if next_state in self.maze[state]:\n", 178 | " return next_state\n", 179 | " return state\n", 180 | "\n", 181 | " @staticmethod\n", 182 | " def _create_maze(size: int) -> Dict[Tuple[int, int], Iterable[Tuple[int, int]]]:\n", 183 | " maze = {(row, col): [(row - 1, col), (row + 1, col), (row, col - 1), (row, col + 1)]\n", 184 | " for row in range(size) for col in range(size)}\n", 185 | "\n", 186 | " left_edges = [[(row, 0), (row, -1)] for row in range(size)]\n", 187 | " right_edges = [[(row, size - 1), (row, size)] for row in range(size)]\n", 188 | " upper_edges = [[(0, col), (-1, col)] for col in range(size)]\n", 189 | " lower_edges = [[(size - 1, col), (size, col)] for col in range(size)]\n", 190 | " walls = [\n", 191 | " [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)],\n", 192 | " 
[(1, 1), (1, 2)], [(2, 1), (2, 2)], [(3, 1), (3, 2)],\n", 193 | " [(3, 1), (4, 1)], [(0, 2), (1, 2)], [(1, 2), (1, 3)],\n", 194 | " [(2, 2), (3, 2)], [(2, 3), (3, 3)], [(2, 4), (3, 4)],\n", 195 | " [(4, 2), (4, 3)], [(1, 3), (1, 4)], [(2, 3), (2, 4)],\n", 196 | " ]\n", 197 | "\n", 198 | " obstacles = upper_edges + lower_edges + left_edges + right_edges + walls\n", 199 | "\n", 200 | " for src, dst in obstacles:\n", 201 | " maze[src].remove(dst)\n", 202 | "\n", 203 | " if dst in maze:\n", 204 | " maze[dst].remove(src)\n", 205 | "\n", 206 | " return maze\n", 207 | "\n", 208 | " @staticmethod\n", 209 | " def _compute_distances(goal: Tuple[int, int],\n", 210 | " maze: Dict[Tuple[int, int], Iterable[Tuple[int, int]]]) -> np.ndarray:\n", 211 | " distances = np.full((5, 5), np.inf)\n", 212 | " visited = set()\n", 213 | " distances[goal] = 0.\n", 214 | "\n", 215 | " while visited != set(maze):\n", 216 | " sorted_dst = [(v // 5, v % 5) for v in distances.argsort(axis=None)]\n", 217 | " closest = next(x for x in sorted_dst if x not in visited)\n", 218 | " visited.add(closest)\n", 219 | "\n", 220 | " for neighbour in maze[closest]:\n", 221 | " distances[neighbour] = min(distances[neighbour], distances[closest] + 1)\n", 222 | " return distances\n", 223 | "\n", 224 | "\n", 225 | "def display_video(frames):\n", 226 | " # Copied from: https://colab.research.google.com/github/deepmind/dm_control/blob/master/tutorial.ipynb\n", 227 | " orig_backend = matplotlib.get_backend()\n", 228 | " matplotlib.use('Agg')\n", 229 | " fig, ax = plt.subplots(1, 1, figsize=(5, 5))\n", 230 | " matplotlib.use(orig_backend)\n", 231 | " ax.set_axis_off()\n", 232 | " ax.set_aspect('equal')\n", 233 | " ax.set_position([0, 0, 1, 1])\n", 234 | " im = ax.imshow(frames[0])\n", 235 | " def update(frame):\n", 236 | " im.set_data(frame)\n", 237 | " return [im]\n", 238 | " anim = animation.FuncAnimation(fig=fig, func=update, frames=frames,\n", 239 | " interval=50, blit=True, repeat=False)\n", 240 | " return HTML(anim.to_html5_video())\n", 241 | "\n", 242 | "\n", 243 | "def seed_everything(env: gym.Env, seed: int = 42) -> None:\n", 244 | " env.seed(seed)\n", 245 | " env.action_space.seed(seed)\n", 246 | " env.observation_space.seed(seed)\n", 247 | " np.random.seed(seed)\n", 248 | " torch.manual_seed(seed)\n", 249 | " torch.use_deterministic_algorithms(True)\n", 250 | "\n", 251 | "\n", 252 | "def plot_stats(stats):\n", 253 | " rows = len(stats)\n", 254 | " cols = 1\n", 255 | "\n", 256 | " fig, ax = plt.subplots(rows, cols, figsize=(12, 6))\n", 257 | "\n", 258 | " for i, key in enumerate(stats):\n", 259 | " vals = stats[key]\n", 260 | " vals = [np.mean(vals[i-10:i+10]) for i in range(10, len(vals)-10)]\n", 261 | " if len(stats) > 1:\n", 262 | " ax[i].plot(range(len(vals)), vals)\n", 263 | " ax[i].set_title(key, size=18)\n", 264 | " else:\n", 265 | " ax.plot(range(len(vals)), vals)\n", 266 | " ax.set_title(key, size=18)\n", 267 | " plt.tight_layout()\n", 268 | " plt.show()\n", 269 | "\n", 270 | "\n", 271 | "def test_policy_network(env, policy, episodes=10):\n", 272 | " frames = []\n", 273 | " for episode in range(episodes):\n", 274 | " state = env.reset()\n", 275 | " done = False\n", 276 | " frames.append(env.render(mode=\"rgb_array\"))\n", 277 | "\n", 278 | " while not done:\n", 279 | " state = torch.from_numpy(state).unsqueeze(0).float()\n", 280 | " action = policy(state).multinomial(1).item()\n", 281 | " next_state, _, done, _ = env.step(action)\n", 282 | " img = env.render(mode=\"rgb_array\")\n", 283 | " frames.append(img)\n", 284 | " 
state = next_state\n", 285 | "\n", 286 | " return display_video(frames)\n", 287 | "\n", 288 | "\n", 289 | "def plot_action_probs(probs, labels):\n", 290 | " plt.figure(figsize=(6, 4))\n", 291 | " plt.bar(labels, probs, color ='orange')\n", 292 | " plt.title(\"$\\pi(s)$\", size=16)\n", 293 | " plt.xticks(fontsize=12)\n", 294 | " plt.yticks(fontsize=12)\n", 295 | " plt.tight_layout()\n", 296 | " plt.show()\n", 297 | "\n" 298 | ] 299 | }, 300 | { 301 | "cell_type": "markdown", 302 | "metadata": { 303 | "id": "RlB0Tbp07vV6" 304 | }, 305 | "source": [ 306 | "## Import the necessary software libraries:" 307 | ] 308 | }, 309 | { 310 | "cell_type": "code", 311 | "execution_count": null, 312 | "metadata": { 313 | "id": "2OnbUU8t7vV7" 314 | }, 315 | "outputs": [], 316 | "source": [ 317 | "import os\n", 318 | "import torch\n", 319 | "import gym\n", 320 | "import numpy as np\n", 321 | "import matplotlib.pyplot as plt\n", 322 | "from tqdm import tqdm\n", 323 | "from torch import nn as nn\n", 324 | "from torch.optim import AdamW\n", 325 | "import torch.nn.functional as F" 326 | ] 327 | }, 328 | { 329 | "cell_type": "markdown", 330 | "metadata": { 331 | "id": "pPEwlOrt7vV8" 332 | }, 333 | "source": [ 334 | "## Create and preprocess the environment" 335 | ] 336 | }, 337 | { 338 | "cell_type": "markdown", 339 | "metadata": { 340 | "id": "j37j_pOh7vV8" 341 | }, 342 | "source": [ 343 | "### Create the environment" 344 | ] 345 | }, 346 | { 347 | "cell_type": "code", 348 | "execution_count": null, 349 | "metadata": { 350 | "id": "RDViC8L47vV8" 351 | }, 352 | "outputs": [], 353 | "source": [ 354 | "env = gym.make('Acrobot-v1')" 355 | ] 356 | }, 357 | { 358 | "cell_type": "code", 359 | "execution_count": null, 360 | "metadata": { 361 | "id": "LuJ9Hx4E7vV8" 362 | }, 363 | "outputs": [], 364 | "source": [ 365 | "dims = env.observation_space.shape[0]\n", 366 | "actions = env.action_space.n\n", 367 | "\n", 368 | "print(f\"State dimensions: {dims}. 
Actions: {actions}\")\n", 369 | "print(f\"Sample state: {env.reset()}\")" 370 | ] 371 | }, 372 | { 373 | "cell_type": "code", 374 | "execution_count": null, 375 | "metadata": { 376 | "id": "QVHSTC827vV8" 377 | }, 378 | "outputs": [], 379 | "source": [ 380 | "plt.imshow(env.render(mode='rgb_array'))" 381 | ] 382 | }, 383 | { 384 | "cell_type": "markdown", 385 | "metadata": { 386 | "id": "8piDIP4E7vV9" 387 | }, 388 | "source": [ 389 | "### Prepare the environment to work with PyTorch" 390 | ] 391 | }, 392 | { 393 | "cell_type": "code", 394 | "execution_count": null, 395 | "metadata": { 396 | "id": "bBqxt5hP7vV9" 397 | }, 398 | "outputs": [], 399 | "source": [ 400 | "class PreprocessEnv(gym.Wrapper):\n", 401 | "\n", 402 | " def __init__(self, env):\n", 403 | " gym.Wrapper.__init__(self, env)\n", 404 | "\n", 405 | " def reset(self):\n", 406 | " state = self.env.reset()\n", 407 | " return torch.from_numpy(state).float()\n", 408 | "\n", 409 | " def step(self, actions):\n", 410 | " actions = actions.squeeze().numpy()\n", 411 | " next_state, reward, done, info = self.env.step(actions)\n", 412 | " next_state = torch.from_numpy(next_state).float()\n", 413 | " reward = torch.tensor(reward).unsqueeze(1).float()\n", 414 | " done = torch.tensor(done).unsqueeze(1)\n", 415 | " return next_state, reward, done, info" 416 | ] 417 | }, 418 | { 419 | "cell_type": "code", 420 | "execution_count": null, 421 | "metadata": { 422 | "id": "jkRkODiA7vV-" 423 | }, 424 | "outputs": [], 425 | "source": [ 426 | "num_envs = 8\n", 427 | "parallel_env = gym.vector.make('Acrobot-v1', num_envs=num_envs)\n", 428 | "seed_everything(parallel_env)\n", 429 | "parallel_env = PreprocessEnv(parallel_env)" 430 | ] 431 | }, 432 | { 433 | "cell_type": "markdown", 434 | "metadata": { 435 | "id": "H_35Rcxy7vV-" 436 | }, 437 | "source": [ 438 | "### Create the policy $\\pi(s)$" 439 | ] 440 | }, 441 | { 442 | "cell_type": "code", 443 | "execution_count": null, 444 | "metadata": { 445 | "id": "rF1_C5Xb7vV-" 446 | }, 447 | "outputs": [], 448 | "source": [] 449 | }, 450 | { 451 | "cell_type": "markdown", 452 | "metadata": { 453 | "id": "LPY8NEB17vV-" 454 | }, 455 | "source": [ 456 | "### Create the value network $v(s)$" 457 | ] 458 | }, 459 | { 460 | "cell_type": "code", 461 | "execution_count": null, 462 | "metadata": { 463 | "id": "EKD6vk4i7vV-" 464 | }, 465 | "outputs": [], 466 | "source": [] 467 | }, 468 | { 469 | "cell_type": "markdown", 470 | "metadata": { 471 | "id": "SH5RVyq-7vV-" 472 | }, 473 | "source": [ 474 | "## Implement the algorithm\n" 475 | ] 476 | }, 477 | { 478 | "cell_type": "code", 479 | "execution_count": null, 480 | "metadata": { 481 | "id": "lI3Zju7u7vV-" 482 | }, 483 | "outputs": [], 484 | "source": [] 485 | }, 486 | { 487 | "cell_type": "code", 488 | "execution_count": null, 489 | "metadata": { 490 | "id": "eEWU63Z07vV-", 491 | "scrolled": true 492 | }, 493 | "outputs": [], 494 | "source": [] 495 | }, 496 | { 497 | "cell_type": "markdown", 498 | "metadata": { 499 | "id": "DoxzCbPz7vV-" 500 | }, 501 | "source": [ 502 | "## Show results" 503 | ] 504 | }, 505 | { 506 | "cell_type": "markdown", 507 | "metadata": { 508 | "id": "3cmyUHP67vV-" 509 | }, 510 | "source": [ 511 | "### Show execution stats" 512 | ] 513 | }, 514 | { 515 | "cell_type": "code", 516 | "execution_count": null, 517 | "metadata": { 518 | "id": "oV46xCdU7vV-" 519 | }, 520 | "outputs": [], 521 | "source": [ 522 | "plot_stats(stats)" 523 | ] 524 | }, 525 | { 526 | "cell_type": "markdown", 527 | "metadata": { 528 | "id": "KD1Khhk17vV-" 529 | }, 530 | 
"source": [ 531 | "### Test the resulting agent" 532 | ] 533 | }, 534 | { 535 | "cell_type": "code", 536 | "execution_count": null, 537 | "metadata": { 538 | "id": "w6EwEhPd7vV_" 539 | }, 540 | "outputs": [], 541 | "source": [ 542 | "test_policy_network(env, actor, episodes=2)" 543 | ] 544 | }, 545 | { 546 | "cell_type": "markdown", 547 | "metadata": { 548 | "id": "UHS56xgc7vV_" 549 | }, 550 | "source": [ 551 | "## Resources" 552 | ] 553 | }, 554 | { 555 | "cell_type": "markdown", 556 | "metadata": { 557 | "id": "Yk1oi1-_7vV_" 558 | }, 559 | "source": [ 560 | "[[1] Reinforcement Learning: An Introduction. Ch.13](https://web.stanford.edu/class/psych209/Readings/SuttonBartoIPRLBook2ndEd.pdf)" 561 | ] 562 | } 563 | ], 564 | "metadata": { 565 | "colab": { 566 | "provenance": [] 567 | }, 568 | "kernelspec": { 569 | "display_name": "Python 3", 570 | "language": "python", 571 | "name": "python3" 572 | }, 573 | "language_info": { 574 | "codemirror_mode": { 575 | "name": "ipython", 576 | "version": 3 577 | }, 578 | "file_extension": ".py", 579 | "mimetype": "text/x-python", 580 | "name": "python", 581 | "nbconvert_exporter": "python", 582 | "pygments_lexer": "ipython3", 583 | "version": "3.8.5" 584 | } 585 | }, 586 | "nbformat": 4, 587 | "nbformat_minor": 0 588 | } 589 | -------------------------------------------------------------------------------- /Section_11_advantage_actor_critic_complete.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "67rhCAvE7vV4", 7 | "pycharm": { 8 | "name": "#%%\n" 9 | } 10 | }, 11 | "source": [ 12 | "
\n", 13 | "

\n", 14 | " Advantage Actor-Critic (A2C)\n", 15 | "

\n", 16 | "
\n", 17 | "\n", 18 | "

\n", 19 | "\n", 20 | "
\n", 21 | "In this notebook we are going to combine temporal difference learning (TD) with policy gradient methods. The resulting algorithm is called Advantage Actor-Critic (A2C) and uses a one-step estimate of the return to update the policy:\n", 22 | "
\n", 23 | "\n", 24 | "\\begin{equation}\n", 25 | "\\hat G_t = R_{t+1} + \\gamma v(S_{t+1}|w)\n", 26 | "\\end{equation}\n", 27 | "\n", 28 | "\n", 29 | "
" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "metadata": { 36 | "cellView": "form", 37 | "id": "_uUI8vNy703R" 38 | }, 39 | "outputs": [], 40 | "source": [ 41 | "# @title Setup code (not important) - Run this cell by pressing \"Shift + Enter\"\n", 42 | "\n", 43 | "\n", 44 | "!pip install -qq gym==0.23.0 numpy==1.26.4 seaborn==0.12\n", 45 | "\n", 46 | "\n", 47 | "from typing import Tuple, Dict, Optional, Iterable, Callable\n", 48 | "\n", 49 | "import numpy as np\n", 50 | "import seaborn as sns\n", 51 | "import matplotlib\n", 52 | "import torch\n", 53 | "from matplotlib import animation\n", 54 | "import matplotlib.patches as mpatches\n", 55 | "\n", 56 | "from IPython.display import HTML\n", 57 | "\n", 58 | "import gym\n", 59 | "from gym import spaces\n", 60 | "from gym.error import DependencyNotInstalled\n", 61 | "\n", 62 | "import pygame\n", 63 | "from pygame import gfxdraw\n", 64 | "\n", 65 | "\n", 66 | "class Maze(gym.Env):\n", 67 | "\n", 68 | " def __init__(self, exploring_starts: bool = False,\n", 69 | " shaped_rewards: bool = False, size: int = 5) -> None:\n", 70 | " super().__init__()\n", 71 | " self.exploring_starts = exploring_starts\n", 72 | " self.shaped_rewards = shaped_rewards\n", 73 | " self.state = (size - 1, size - 1)\n", 74 | " self.goal = (size - 1, size - 1)\n", 75 | " self.maze = self._create_maze(size=size)\n", 76 | " self.distances = self._compute_distances(self.goal, self.maze)\n", 77 | " self.action_space = spaces.Discrete(n=4)\n", 78 | " self.action_space.action_meanings = {0: 'UP', 1: 'RIGHT', 2: 'DOWN', 3: \"LEFT\"}\n", 79 | " self.observation_space = spaces.MultiDiscrete([size, size])\n", 80 | "\n", 81 | " self.screen = None\n", 82 | " self.agent_transform = None\n", 83 | "\n", 84 | " def step(self, action: int) -> Tuple[Tuple[int, int], float, bool, Dict]:\n", 85 | " reward = self.compute_reward(self.state, action)\n", 86 | " self.state = self._get_next_state(self.state, action)\n", 87 | " done = self.state == self.goal\n", 88 | " info = {}\n", 89 | " return self.state, reward, done, info\n", 90 | "\n", 91 | " def reset(self) -> Tuple[int, int]:\n", 92 | " if self.exploring_starts:\n", 93 | " while self.state == self.goal:\n", 94 | " self.state = tuple(self.observation_space.sample())\n", 95 | " else:\n", 96 | " self.state = (0, 0)\n", 97 | " return self.state\n", 98 | "\n", 99 | " def render(self, mode: str = 'human') -> Optional[np.ndarray]:\n", 100 | " assert mode in ['human', 'rgb_array']\n", 101 | "\n", 102 | " screen_size = 600\n", 103 | " scale = screen_size / 5\n", 104 | "\n", 105 | " if self.screen is None:\n", 106 | " pygame.init()\n", 107 | " self.screen = pygame.Surface((screen_size, screen_size))\n", 108 | "\n", 109 | " surf = pygame.Surface((screen_size, screen_size))\n", 110 | " surf.fill((22, 36, 71))\n", 111 | "\n", 112 | "\n", 113 | " for row in range(5):\n", 114 | " for col in range(5):\n", 115 | "\n", 116 | " state = (row, col)\n", 117 | " for next_state in [(row + 1, col), (row - 1, col), (row, col + 1), (row, col - 1)]:\n", 118 | " if next_state not in self.maze[state]:\n", 119 | "\n", 120 | " # Add the geometry of the edges and walls (i.e. 
the boundaries between\n", 121 | " # adjacent squares that are not connected).\n", 122 | " row_diff, col_diff = np.subtract(next_state, state)\n", 123 | " left = (col + (col_diff > 0)) * scale - 2 * (col_diff != 0)\n", 124 | " right = ((col + 1) - (col_diff < 0)) * scale + 2 * (col_diff != 0)\n", 125 | " top = (5 - (row + (row_diff > 0))) * scale - 2 * (row_diff != 0)\n", 126 | " bottom = (5 - ((row + 1) - (row_diff < 0))) * scale + 2 * (row_diff != 0)\n", 127 | "\n", 128 | " gfxdraw.filled_polygon(surf, [(left, bottom), (left, top), (right, top), (right, bottom)], (255, 255, 255))\n", 129 | "\n", 130 | " # Add the geometry of the goal square to the viewer.\n", 131 | " left, right, top, bottom = scale * 4 + 10, scale * 5 - 10, scale - 10, 10\n", 132 | " gfxdraw.filled_polygon(surf, [(left, bottom), (left, top), (right, top), (right, bottom)], (40, 199, 172))\n", 133 | "\n", 134 | " # Add the geometry of the agent to the viewer.\n", 135 | " agent_row = int(screen_size - scale * (self.state[0] + .5))\n", 136 | " agent_col = int(scale * (self.state[1] + .5))\n", 137 | " gfxdraw.filled_circle(surf, agent_col, agent_row, int(scale * .6 / 2), (228, 63, 90))\n", 138 | "\n", 139 | " surf = pygame.transform.flip(surf, False, True)\n", 140 | " self.screen.blit(surf, (0, 0))\n", 141 | "\n", 142 | " return np.transpose(\n", 143 | " np.array(pygame.surfarray.pixels3d(self.screen)), axes=(1, 0, 2)\n", 144 | " )\n", 145 | "\n", 146 | " def close(self) -> None:\n", 147 | " if self.screen is not None:\n", 148 | " pygame.display.quit()\n", 149 | " pygame.quit()\n", 150 | " self.screen = None\n", 151 | "\n", 152 | " def compute_reward(self, state: Tuple[int, int], action: int) -> float:\n", 153 | " next_state = self._get_next_state(state, action)\n", 154 | " if self.shaped_rewards:\n", 155 | " return - (self.distances[next_state] / self.distances.max())\n", 156 | " return - float(state != self.goal)\n", 157 | "\n", 158 | " def simulate_step(self, state: Tuple[int, int], action: int):\n", 159 | " reward = self.compute_reward(state, action)\n", 160 | " next_state = self._get_next_state(state, action)\n", 161 | " done = next_state == self.goal\n", 162 | " info = {}\n", 163 | " return next_state, reward, done, info\n", 164 | "\n", 165 | " def _get_next_state(self, state: Tuple[int, int], action: int) -> Tuple[int, int]:\n", 166 | " if action == 0:\n", 167 | " next_state = (state[0] - 1, state[1])\n", 168 | " elif action == 1:\n", 169 | " next_state = (state[0], state[1] + 1)\n", 170 | " elif action == 2:\n", 171 | " next_state = (state[0] + 1, state[1])\n", 172 | " elif action == 3:\n", 173 | " next_state = (state[0], state[1] - 1)\n", 174 | " else:\n", 175 | " raise ValueError(\"Action value not supported:\", action)\n", 176 | " if next_state in self.maze[state]:\n", 177 | " return next_state\n", 178 | " return state\n", 179 | "\n", 180 | " @staticmethod\n", 181 | " def _create_maze(size: int) -> Dict[Tuple[int, int], Iterable[Tuple[int, int]]]:\n", 182 | " maze = {(row, col): [(row - 1, col), (row + 1, col), (row, col - 1), (row, col + 1)]\n", 183 | " for row in range(size) for col in range(size)}\n", 184 | "\n", 185 | " left_edges = [[(row, 0), (row, -1)] for row in range(size)]\n", 186 | " right_edges = [[(row, size - 1), (row, size)] for row in range(size)]\n", 187 | " upper_edges = [[(0, col), (-1, col)] for col in range(size)]\n", 188 | " lower_edges = [[(size - 1, col), (size, col)] for col in range(size)]\n", 189 | " walls = [\n", 190 | " [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)],\n", 191 | " 
[(1, 1), (1, 2)], [(2, 1), (2, 2)], [(3, 1), (3, 2)],\n", 192 | " [(3, 1), (4, 1)], [(0, 2), (1, 2)], [(1, 2), (1, 3)],\n", 193 | " [(2, 2), (3, 2)], [(2, 3), (3, 3)], [(2, 4), (3, 4)],\n", 194 | " [(4, 2), (4, 3)], [(1, 3), (1, 4)], [(2, 3), (2, 4)],\n", 195 | " ]\n", 196 | "\n", 197 | " obstacles = upper_edges + lower_edges + left_edges + right_edges + walls\n", 198 | "\n", 199 | " for src, dst in obstacles:\n", 200 | " maze[src].remove(dst)\n", 201 | "\n", 202 | " if dst in maze:\n", 203 | " maze[dst].remove(src)\n", 204 | "\n", 205 | " return maze\n", 206 | "\n", 207 | " @staticmethod\n", 208 | " def _compute_distances(goal: Tuple[int, int],\n", 209 | " maze: Dict[Tuple[int, int], Iterable[Tuple[int, int]]]) -> np.ndarray:\n", 210 | " distances = np.full((5, 5), np.inf)\n", 211 | " visited = set()\n", 212 | " distances[goal] = 0.\n", 213 | "\n", 214 | " while visited != set(maze):\n", 215 | " sorted_dst = [(v // 5, v % 5) for v in distances.argsort(axis=None)]\n", 216 | " closest = next(x for x in sorted_dst if x not in visited)\n", 217 | " visited.add(closest)\n", 218 | "\n", 219 | " for neighbour in maze[closest]:\n", 220 | " distances[neighbour] = min(distances[neighbour], distances[closest] + 1)\n", 221 | " return distances\n", 222 | "\n", 223 | "\n", 224 | "def display_video(frames):\n", 225 | " # Copied from: https://colab.research.google.com/github/deepmind/dm_control/blob/master/tutorial.ipynb\n", 226 | " orig_backend = matplotlib.get_backend()\n", 227 | " matplotlib.use('Agg')\n", 228 | " fig, ax = plt.subplots(1, 1, figsize=(5, 5))\n", 229 | " matplotlib.use(orig_backend)\n", 230 | " ax.set_axis_off()\n", 231 | " ax.set_aspect('equal')\n", 232 | " ax.set_position([0, 0, 1, 1])\n", 233 | " im = ax.imshow(frames[0])\n", 234 | " def update(frame):\n", 235 | " im.set_data(frame)\n", 236 | " return [im]\n", 237 | " anim = animation.FuncAnimation(fig=fig, func=update, frames=frames,\n", 238 | " interval=50, blit=True, repeat=False)\n", 239 | " return HTML(anim.to_html5_video())\n", 240 | "\n", 241 | "\n", 242 | "def seed_everything(env: gym.Env, seed: int = 42) -> None:\n", 243 | " env.seed(seed)\n", 244 | " env.action_space.seed(seed)\n", 245 | " env.observation_space.seed(seed)\n", 246 | " np.random.seed(seed)\n", 247 | " torch.manual_seed(seed)\n", 248 | " torch.use_deterministic_algorithms(True)\n", 249 | "\n", 250 | "\n", 251 | "def plot_stats(stats):\n", 252 | " rows = len(stats)\n", 253 | " cols = 1\n", 254 | "\n", 255 | " fig, ax = plt.subplots(rows, cols, figsize=(12, 6))\n", 256 | "\n", 257 | " for i, key in enumerate(stats):\n", 258 | " vals = stats[key]\n", 259 | " vals = [np.mean(vals[i-10:i+10]) for i in range(10, len(vals)-10)]\n", 260 | " if len(stats) > 1:\n", 261 | " ax[i].plot(range(len(vals)), vals)\n", 262 | " ax[i].set_title(key, size=18)\n", 263 | " else:\n", 264 | " ax.plot(range(len(vals)), vals)\n", 265 | " ax.set_title(key, size=18)\n", 266 | " plt.tight_layout()\n", 267 | " plt.show()\n", 268 | "\n", 269 | "\n", 270 | "def test_policy_network(env, policy, episodes=10):\n", 271 | " frames = []\n", 272 | " for episode in range(episodes):\n", 273 | " state = env.reset()\n", 274 | " done = False\n", 275 | " frames.append(env.render(mode=\"rgb_array\"))\n", 276 | "\n", 277 | " while not done:\n", 278 | " state = torch.from_numpy(state).unsqueeze(0).float()\n", 279 | " action = policy(state).multinomial(1).item()\n", 280 | " next_state, _, done, _ = env.step(action)\n", 281 | " img = env.render(mode=\"rgb_array\")\n", 282 | " frames.append(img)\n", 283 | " 
state = next_state\n", 284 | "\n", 285 | " return display_video(frames)\n", 286 | "\n", 287 | "\n", 288 | "def plot_action_probs(probs, labels):\n", 289 | " plt.figure(figsize=(6, 4))\n", 290 | " plt.bar(labels, probs, color ='orange')\n", 291 | " plt.title(\"$\\pi(s)$\", size=16)\n", 292 | " plt.xticks(fontsize=12)\n", 293 | " plt.yticks(fontsize=12)\n", 294 | " plt.tight_layout()\n", 295 | " plt.show()\n", 296 | "\n" 297 | ] 298 | }, 299 | { 300 | "cell_type": "markdown", 301 | "metadata": { 302 | "id": "RlB0Tbp07vV6" 303 | }, 304 | "source": [ 305 | "## Import the necessary software libraries:" 306 | ] 307 | }, 308 | { 309 | "cell_type": "code", 310 | "execution_count": null, 311 | "metadata": { 312 | "id": "2OnbUU8t7vV7" 313 | }, 314 | "outputs": [], 315 | "source": [ 316 | "import os\n", 317 | "import torch\n", 318 | "import gym\n", 319 | "import numpy as np\n", 320 | "import matplotlib.pyplot as plt\n", 321 | "from tqdm import tqdm\n", 322 | "from torch import nn as nn\n", 323 | "from torch.optim import AdamW\n", 324 | "import torch.nn.functional as F" 325 | ] 326 | }, 327 | { 328 | "cell_type": "markdown", 329 | "metadata": { 330 | "id": "pPEwlOrt7vV8" 331 | }, 332 | "source": [ 333 | "## Create and preprocess the environment" 334 | ] 335 | }, 336 | { 337 | "cell_type": "markdown", 338 | "metadata": { 339 | "id": "j37j_pOh7vV8" 340 | }, 341 | "source": [ 342 | "### Create the environment" 343 | ] 344 | }, 345 | { 346 | "cell_type": "code", 347 | "execution_count": null, 348 | "metadata": { 349 | "id": "RDViC8L47vV8" 350 | }, 351 | "outputs": [], 352 | "source": [ 353 | "env = gym.make('Acrobot-v1')" 354 | ] 355 | }, 356 | { 357 | "cell_type": "code", 358 | "execution_count": null, 359 | "metadata": { 360 | "id": "LuJ9Hx4E7vV8" 361 | }, 362 | "outputs": [], 363 | "source": [ 364 | "dims = env.observation_space.shape[0]\n", 365 | "actions = env.action_space.n\n", 366 | "\n", 367 | "print(f\"State dimensions: {dims}. 
Actions: {actions}\")\n", 368 | "print(f\"Sample state: {env.reset()}\")" 369 | ] 370 | }, 371 | { 372 | "cell_type": "code", 373 | "execution_count": null, 374 | "metadata": { 375 | "id": "QVHSTC827vV8" 376 | }, 377 | "outputs": [], 378 | "source": [ 379 | "plt.imshow(env.render(mode='rgb_array'))" 380 | ] 381 | }, 382 | { 383 | "cell_type": "markdown", 384 | "metadata": { 385 | "id": "8piDIP4E7vV9" 386 | }, 387 | "source": [ 388 | "### Prepare the environment to work with PyTorch" 389 | ] 390 | }, 391 | { 392 | "cell_type": "code", 393 | "execution_count": null, 394 | "metadata": { 395 | "id": "bBqxt5hP7vV9" 396 | }, 397 | "outputs": [], 398 | "source": [ 399 | "class PreprocessEnv(gym.Wrapper):\n", 400 | "\n", 401 | " def __init__(self, env):\n", 402 | " gym.Wrapper.__init__(self, env)\n", 403 | "\n", 404 | " def reset(self):\n", 405 | " state = self.env.reset()\n", 406 | " return torch.from_numpy(state).float()\n", 407 | "\n", 408 | " def step(self, actions):\n", 409 | " actions = actions.squeeze().numpy()\n", 410 | " next_state, reward, done, info = self.env.step(actions)\n", 411 | " next_state = torch.from_numpy(next_state).float()\n", 412 | " reward = torch.tensor(reward).unsqueeze(1).float()\n", 413 | " done = torch.tensor(done).unsqueeze(1)\n", 414 | " return next_state, reward, done, info" 415 | ] 416 | }, 417 | { 418 | "cell_type": "code", 419 | "execution_count": null, 420 | "metadata": { 421 | "id": "jkRkODiA7vV-" 422 | }, 423 | "outputs": [], 424 | "source": [ 425 | "num_envs = 8\n", 426 | "parallel_env = gym.vector.make('Acrobot-v1', num_envs=num_envs)\n", 427 | "seed_everything(parallel_env)\n", 428 | "parallel_env = PreprocessEnv(parallel_env)" 429 | ] 430 | }, 431 | { 432 | "cell_type": "markdown", 433 | "metadata": { 434 | "id": "H_35Rcxy7vV-" 435 | }, 436 | "source": [ 437 | "### Create the policy $\\pi(s)$" 438 | ] 439 | }, 440 | { 441 | "cell_type": "code", 442 | "execution_count": null, 443 | "metadata": { 444 | "id": "rF1_C5Xb7vV-" 445 | }, 446 | "outputs": [], 447 | "source": [ 448 | "actor = nn.Sequential(\n", 449 | " nn.Linear(dims, 128),\n", 450 | " nn.ReLU(),\n", 451 | " nn.Linear(128, 64),\n", 452 | " nn.ReLU(),\n", 453 | " nn.Linear(64, actions),\n", 454 | " nn.Softmax(dim=-1))" 455 | ] 456 | }, 457 | { 458 | "cell_type": "markdown", 459 | "metadata": { 460 | "id": "LPY8NEB17vV-" 461 | }, 462 | "source": [ 463 | "### Create the value network $v(s)$" 464 | ] 465 | }, 466 | { 467 | "cell_type": "code", 468 | "execution_count": null, 469 | "metadata": { 470 | "id": "EKD6vk4i7vV-" 471 | }, 472 | "outputs": [], 473 | "source": [ 474 | "critic = nn.Sequential(\n", 475 | " nn.Linear(dims, 128),\n", 476 | " nn.ReLU(),\n", 477 | " nn.Linear(128, 64),\n", 478 | " nn.ReLU(),\n", 479 | " nn.Linear(64, 1))" 480 | ] 481 | }, 482 | { 483 | "cell_type": "markdown", 484 | "metadata": { 485 | "id": "SH5RVyq-7vV-" 486 | }, 487 | "source": [ 488 | "## Implement the algorithm\n" 489 | ] 490 | }, 491 | { 492 | "cell_type": "code", 493 | "execution_count": null, 494 | "metadata": { 495 | "id": "lI3Zju7u7vV-" 496 | }, 497 | "outputs": [], 498 | "source": [ 499 | "def actor_critic(actor, critic, episodes, alpha=1e-4, gamma=0.99):\n", 500 | " actor_optim = AdamW(actor.parameters(), lr=1e-3)\n", 501 | " critic_optim = AdamW(critic.parameters(), lr=1e-4)\n", 502 | " stats = {'Actor Loss': [], 'Critic Loss': [], 'Returns': []}\n", 503 | "\n", 504 | " for episode in tqdm(range(1, episodes + 1)):\n", 505 | " state = parallel_env.reset()\n", 506 | " done_b = torch.zeros((num_envs, 1), 
dtype=torch.bool)\n", 507 | " ep_return = torch.zeros((num_envs, 1))\n", 508 | " I = 1.\n", 509 | "\n", 510 | " while not done_b.all():\n", 511 | " action = actor(state).multinomial(1).detach()\n", 512 | " next_state, reward, done, _ = parallel_env.step(action)\n", 513 | "\n", 514 | " value = critic(state)\n", 515 | " target = reward + ~done * gamma * critic(next_state).detach()\n", 516 | " critic_loss = F.mse_loss(value, target)\n", 517 | " critic.zero_grad()\n", 518 | " critic_loss.backward()\n", 519 | " critic_optim.step()\n", 520 | "\n", 521 | " advantage = (target - value).detach()\n", 522 | " probs = actor(state)\n", 523 | " log_probs = torch.log(probs + 1e-6)\n", 524 | " action_log_prob = log_probs.gather(1, action)\n", 525 | " entropy = - torch.sum(probs * log_probs, dim=-1, keepdim=True)\n", 526 | " actor_loss = - I * action_log_prob * advantage - 0.01 * entropy\n", 527 | " actor_loss = actor_loss.mean()\n", 528 | " actor.zero_grad()\n", 529 | " actor_loss.backward()\n", 530 | " actor_optim.step()\n", 531 | "\n", 532 | " ep_return += reward\n", 533 | " done_b |= done\n", 534 | " state = next_state\n", 535 | " I = I * gamma\n", 536 | "\n", 537 | " stats['Actor Loss'].append(actor_loss.item())\n", 538 | " stats['Critic Loss'].append(critic_loss.item())\n", 539 | " stats['Returns'].append(ep_return.mean().item())\n", 540 | "\n", 541 | " return stats" 542 | ] 543 | }, 544 | { 545 | "cell_type": "code", 546 | "execution_count": null, 547 | "metadata": { 548 | "id": "eEWU63Z07vV-", 549 | "scrolled": true 550 | }, 551 | "outputs": [], 552 | "source": [ 553 | "stats = actor_critic(actor, critic, 100)" 554 | ] 555 | }, 556 | { 557 | "cell_type": "markdown", 558 | "metadata": { 559 | "id": "DoxzCbPz7vV-" 560 | }, 561 | "source": [ 562 | "## Show results" 563 | ] 564 | }, 565 | { 566 | "cell_type": "markdown", 567 | "metadata": { 568 | "id": "3cmyUHP67vV-" 569 | }, 570 | "source": [ 571 | "### Show execution stats" 572 | ] 573 | }, 574 | { 575 | "cell_type": "code", 576 | "execution_count": null, 577 | "metadata": { 578 | "id": "oV46xCdU7vV-" 579 | }, 580 | "outputs": [], 581 | "source": [ 582 | "plot_stats(stats)" 583 | ] 584 | }, 585 | { 586 | "cell_type": "markdown", 587 | "metadata": { 588 | "id": "KD1Khhk17vV-" 589 | }, 590 | "source": [ 591 | "### Test the resulting agent" 592 | ] 593 | }, 594 | { 595 | "cell_type": "code", 596 | "execution_count": null, 597 | "metadata": { 598 | "id": "w6EwEhPd7vV_" 599 | }, 600 | "outputs": [], 601 | "source": [ 602 | "test_policy_network(env, actor, episodes=2)" 603 | ] 604 | }, 605 | { 606 | "cell_type": "markdown", 607 | "metadata": { 608 | "id": "UHS56xgc7vV_" 609 | }, 610 | "source": [ 611 | "## Resources" 612 | ] 613 | }, 614 | { 615 | "cell_type": "markdown", 616 | "metadata": { 617 | "id": "Yk1oi1-_7vV_" 618 | }, 619 | "source": [ 620 | "[[1] Reinforcement Learning: An Introduction. 
Ch.13](https://web.stanford.edu/class/psych209/Readings/SuttonBartoIPRLBook2ndEd.pdf)" 621 | ] 622 | } 623 | ], 624 | "metadata": { 625 | "colab": { 626 | "provenance": [] 627 | }, 628 | "kernelspec": { 629 | "display_name": "Python 3", 630 | "language": "python", 631 | "name": "python3" 632 | }, 633 | "language_info": { 634 | "codemirror_mode": { 635 | "name": "ipython", 636 | "version": 3 637 | }, 638 | "file_extension": ".py", 639 | "mimetype": "text/x-python", 640 | "name": "python", 641 | "nbconvert_exporter": "python", 642 | "pygments_lexer": "ipython3", 643 | "version": "3.8.5" 644 | } 645 | }, 646 | "nbformat": 4, 647 | "nbformat_minor": 0 648 | } 649 | -------------------------------------------------------------------------------- /Section_3_policy_iteration.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "pycharm": { 7 | "name": "#%%\n" 8 | }, 9 | "id": "b9jmpP4J6VbY" 10 | }, 11 | "source": [ 12 | "
\n", 13 | "

\n", 14 | " Policy Iteration\n", 15 | "

\n", 16 | "
\n", 17 | "
\n", 18 | "\n", 19 | "
\n", 20 | "

\n", 21 | " In this notebook we are going to look at a dynamic programming algorithm called policy iteration. In it, we will iteratively interleave two processes: policy evaluation and policy improvement, until the optimal policy and state values are found.\n", 22 | "

\n", 23 | "
\n", 24 | "\n", 25 | "
" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "source": [ 31 | "# @title Setup code (not important) - Run this cell by pressing \"Shift + Enter\"\n", 32 | "\n", 33 | "\n", 34 | "\n", 35 | "!pip install -qq gym==0.23.0\n", 36 | "\n", 37 | "\n", 38 | "from typing import Tuple, Dict, Optional, Iterable, Callable\n", 39 | "\n", 40 | "import numpy as np\n", 41 | "import seaborn as sns\n", 42 | "import matplotlib\n", 43 | "from matplotlib import animation\n", 44 | "\n", 45 | "from IPython.display import HTML\n", 46 | "\n", 47 | "import gym\n", 48 | "from gym import spaces\n", 49 | "from gym.error import DependencyNotInstalled\n", 50 | "\n", 51 | "import pygame\n", 52 | "from pygame import gfxdraw\n", 53 | "\n", 54 | "\n", 55 | "class Maze(gym.Env):\n", 56 | "\n", 57 | " def __init__(self, exploring_starts: bool = False,\n", 58 | " shaped_rewards: bool = False, size: int = 5) -> None:\n", 59 | " super().__init__()\n", 60 | " self.exploring_starts = exploring_starts\n", 61 | " self.shaped_rewards = shaped_rewards\n", 62 | " self.state = (size - 1, size - 1)\n", 63 | " self.goal = (size - 1, size - 1)\n", 64 | " self.maze = self._create_maze(size=size)\n", 65 | " self.distances = self._compute_distances(self.goal, self.maze)\n", 66 | " self.action_space = spaces.Discrete(n=4)\n", 67 | " self.action_space.action_meanings = {0: 'UP', 1: 'RIGHT', 2: 'DOWN', 3: \"LEFT\"}\n", 68 | " self.observation_space = spaces.MultiDiscrete([size, size])\n", 69 | "\n", 70 | " self.screen = None\n", 71 | " self.agent_transform = None\n", 72 | "\n", 73 | " def step(self, action: int) -> Tuple[Tuple[int, int], float, bool, Dict]:\n", 74 | " reward = self.compute_reward(self.state, action)\n", 75 | " self.state = self._get_next_state(self.state, action)\n", 76 | " done = self.state == self.goal\n", 77 | " info = {}\n", 78 | " return self.state, reward, done, info\n", 79 | "\n", 80 | " def reset(self) -> Tuple[int, int]:\n", 81 | " if self.exploring_starts:\n", 82 | " while self.state == self.goal:\n", 83 | " self.state = tuple(self.observation_space.sample())\n", 84 | " else:\n", 85 | " self.state = (0, 0)\n", 86 | " return self.state\n", 87 | "\n", 88 | " def render(self, mode: str = 'human') -> Optional[np.ndarray]:\n", 89 | " assert mode in ['human', 'rgb_array']\n", 90 | "\n", 91 | " screen_size = 600\n", 92 | " scale = screen_size / 5\n", 93 | "\n", 94 | " if self.screen is None:\n", 95 | " pygame.init()\n", 96 | " self.screen = pygame.Surface((screen_size, screen_size))\n", 97 | "\n", 98 | " surf = pygame.Surface((screen_size, screen_size))\n", 99 | " surf.fill((22, 36, 71))\n", 100 | "\n", 101 | "\n", 102 | " for row in range(5):\n", 103 | " for col in range(5):\n", 104 | "\n", 105 | " state = (row, col)\n", 106 | " for next_state in [(row + 1, col), (row - 1, col), (row, col + 1), (row, col - 1)]:\n", 107 | " if next_state not in self.maze[state]:\n", 108 | "\n", 109 | " # Add the geometry of the edges and walls (i.e. 
the boundaries between\n", 110 | " # adjacent squares that are not connected).\n", 111 | " row_diff, col_diff = np.subtract(next_state, state)\n", 112 | " left = (col + (col_diff > 0)) * scale - 2 * (col_diff != 0)\n", 113 | " right = ((col + 1) - (col_diff < 0)) * scale + 2 * (col_diff != 0)\n", 114 | " top = (5 - (row + (row_diff > 0))) * scale - 2 * (row_diff != 0)\n", 115 | " bottom = (5 - ((row + 1) - (row_diff < 0))) * scale + 2 * (row_diff != 0)\n", 116 | "\n", 117 | " gfxdraw.filled_polygon(surf, [(left, bottom), (left, top), (right, top), (right, bottom)], (255, 255, 255))\n", 118 | "\n", 119 | " # Add the geometry of the goal square to the viewer.\n", 120 | " left, right, top, bottom = scale * 4 + 10, scale * 5 - 10, scale - 10, 10\n", 121 | " gfxdraw.filled_polygon(surf, [(left, bottom), (left, top), (right, top), (right, bottom)], (40, 199, 172))\n", 122 | "\n", 123 | " # Add the geometry of the agent to the viewer.\n", 124 | " agent_row = int(screen_size - scale * (self.state[0] + .5))\n", 125 | " agent_col = int(scale * (self.state[1] + .5))\n", 126 | " gfxdraw.filled_circle(surf, agent_col, agent_row, int(scale * .6 / 2), (228, 63, 90))\n", 127 | "\n", 128 | " surf = pygame.transform.flip(surf, False, True)\n", 129 | " self.screen.blit(surf, (0, 0))\n", 130 | "\n", 131 | " return np.transpose(\n", 132 | " np.array(pygame.surfarray.pixels3d(self.screen)), axes=(1, 0, 2)\n", 133 | " )\n", 134 | "\n", 135 | " def close(self) -> None:\n", 136 | " if self.screen is not None:\n", 137 | " pygame.display.quit()\n", 138 | " pygame.quit()\n", 139 | " self.screen = None\n", 140 | "\n", 141 | " def compute_reward(self, state: Tuple[int, int], action: int) -> float:\n", 142 | " next_state = self._get_next_state(state, action)\n", 143 | " if self.shaped_rewards:\n", 144 | " return - (self.distances[next_state] / self.distances.max())\n", 145 | " return - float(state != self.goal)\n", 146 | "\n", 147 | " def simulate_step(self, state: Tuple[int, int], action: int):\n", 148 | " reward = self.compute_reward(state, action)\n", 149 | " next_state = self._get_next_state(state, action)\n", 150 | " done = next_state == self.goal\n", 151 | " info = {}\n", 152 | " return next_state, reward, done, info\n", 153 | "\n", 154 | " def _get_next_state(self, state: Tuple[int, int], action: int) -> Tuple[int, int]:\n", 155 | " if action == 0:\n", 156 | " next_state = (state[0] - 1, state[1])\n", 157 | " elif action == 1:\n", 158 | " next_state = (state[0], state[1] + 1)\n", 159 | " elif action == 2:\n", 160 | " next_state = (state[0] + 1, state[1])\n", 161 | " elif action == 3:\n", 162 | " next_state = (state[0], state[1] - 1)\n", 163 | " else:\n", 164 | " raise ValueError(\"Action value not supported:\", action)\n", 165 | " if next_state in self.maze[state]:\n", 166 | " return next_state\n", 167 | " return state\n", 168 | "\n", 169 | " @staticmethod\n", 170 | " def _create_maze(size: int) -> Dict[Tuple[int, int], Iterable[Tuple[int, int]]]:\n", 171 | " maze = {(row, col): [(row - 1, col), (row + 1, col), (row, col - 1), (row, col + 1)]\n", 172 | " for row in range(size) for col in range(size)}\n", 173 | "\n", 174 | " left_edges = [[(row, 0), (row, -1)] for row in range(size)]\n", 175 | " right_edges = [[(row, size - 1), (row, size)] for row in range(size)]\n", 176 | " upper_edges = [[(0, col), (-1, col)] for col in range(size)]\n", 177 | " lower_edges = [[(size - 1, col), (size, col)] for col in range(size)]\n", 178 | " walls = [\n", 179 | " [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)],\n", 180 | " 
[(1, 1), (1, 2)], [(2, 1), (2, 2)], [(3, 1), (3, 2)],\n", 181 | " [(3, 1), (4, 1)], [(0, 2), (1, 2)], [(1, 2), (1, 3)],\n", 182 | " [(2, 2), (3, 2)], [(2, 3), (3, 3)], [(2, 4), (3, 4)],\n", 183 | " [(4, 2), (4, 3)], [(1, 3), (1, 4)], [(2, 3), (2, 4)],\n", 184 | " ]\n", 185 | "\n", 186 | " obstacles = upper_edges + lower_edges + left_edges + right_edges + walls\n", 187 | "\n", 188 | " for src, dst in obstacles:\n", 189 | " maze[src].remove(dst)\n", 190 | "\n", 191 | " if dst in maze:\n", 192 | " maze[dst].remove(src)\n", 193 | "\n", 194 | " return maze\n", 195 | "\n", 196 | " @staticmethod\n", 197 | " def _compute_distances(goal: Tuple[int, int],\n", 198 | " maze: Dict[Tuple[int, int], Iterable[Tuple[int, int]]]) -> np.ndarray:\n", 199 | " distances = np.full((5, 5), np.inf)\n", 200 | " visited = set()\n", 201 | " distances[goal] = 0.\n", 202 | "\n", 203 | " while visited != set(maze):\n", 204 | " sorted_dst = [(v // 5, v % 5) for v in distances.argsort(axis=None)]\n", 205 | " closest = next(x for x in sorted_dst if x not in visited)\n", 206 | " visited.add(closest)\n", 207 | "\n", 208 | " for neighbour in maze[closest]:\n", 209 | " distances[neighbour] = min(distances[neighbour], distances[closest] + 1)\n", 210 | " return distances\n", 211 | "\n", 212 | "\n", 213 | "def plot_policy(probs_or_qvals, frame, action_meanings=None):\n", 214 | " if action_meanings is None:\n", 215 | " action_meanings = {0: 'U', 1: 'R', 2: 'D', 3: 'L'}\n", 216 | " fig, axes = plt.subplots(1, 2, figsize=(8, 4))\n", 217 | " max_prob_actions = probs_or_qvals.argmax(axis=-1)\n", 218 | " probs_copy = max_prob_actions.copy().astype(object)\n", 219 | " for key in action_meanings:\n", 220 | " probs_copy[probs_copy == key] = action_meanings[key]\n", 221 | " sns.heatmap(max_prob_actions, annot=probs_copy, fmt='', cbar=False, cmap='coolwarm',\n", 222 | " annot_kws={'weight': 'bold', 'size': 12}, linewidths=2, ax=axes[0])\n", 223 | " axes[1].imshow(frame)\n", 224 | " axes[0].axis('off')\n", 225 | " axes[1].axis('off')\n", 226 | " plt.suptitle(\"Policy\", size=18)\n", 227 | " plt.tight_layout()\n", 228 | "\n", 229 | "\n", 230 | "def plot_values(state_values, frame):\n", 231 | " f, axes = plt.subplots(1, 2, figsize=(10, 4))\n", 232 | " sns.heatmap(state_values, annot=True, fmt=\".2f\", cmap='coolwarm',\n", 233 | " annot_kws={'weight': 'bold', 'size': 12}, linewidths=2, ax=axes[0])\n", 234 | " axes[1].imshow(frame)\n", 235 | " axes[0].axis('off')\n", 236 | " axes[1].axis('off')\n", 237 | " plt.tight_layout()\n", 238 | "\n", 239 | "\n", 240 | "def display_video(frames):\n", 241 | " # Copied from: https://colab.research.google.com/github/deepmind/dm_control/blob/master/tutorial.ipynb\n", 242 | " orig_backend = matplotlib.get_backend()\n", 243 | " matplotlib.use('Agg')\n", 244 | " fig, ax = plt.subplots(1, 1, figsize=(5, 5))\n", 245 | " matplotlib.use(orig_backend)\n", 246 | " ax.set_axis_off()\n", 247 | " ax.set_aspect('equal')\n", 248 | " ax.set_position([0, 0, 1, 1])\n", 249 | " im = ax.imshow(frames[0])\n", 250 | " def update(frame):\n", 251 | " im.set_data(frame)\n", 252 | " return [im]\n", 253 | " anim = animation.FuncAnimation(fig=fig, func=update, frames=frames,\n", 254 | " interval=50, blit=True, repeat=False)\n", 255 | " return HTML(anim.to_html5_video())\n", 256 | "\n", 257 | "\n", 258 | "def test_agent(environment, policy, episodes=10):\n", 259 | " frames = []\n", 260 | " for episode in range(episodes):\n", 261 | " state = env.reset()\n", 262 | " done = False\n", 263 | " frames.append(env.render(mode=\"rgb_array\"))\n", 
264 | "\n", 265 | " while not done:\n", 266 | " p = policy(state)\n", 267 | " if isinstance(p, np.ndarray):\n", 268 | " action = np.random.choice(4, p=p)\n", 269 | " else:\n", 270 | " action = p\n", 271 | " next_state, reward, done, extra_info = env.step(action)\n", 272 | " img = env.render(mode=\"rgb_array\")\n", 273 | " frames.append(img)\n", 274 | " state = next_state\n", 275 | "\n", 276 | " return display_video(frames)\n", 277 | "\n" 278 | ], 279 | "metadata": { 280 | "cellView": "form", 281 | "id": "dnULhrB06fg1" 282 | }, 283 | "execution_count": null, 284 | "outputs": [] 285 | }, 286 | { 287 | "cell_type": "markdown", 288 | "metadata": { 289 | "id": "j-pns0SL6Vbc" 290 | }, 291 | "source": [ 292 | "## Import the necessary software libraries:" 293 | ] 294 | }, 295 | { 296 | "cell_type": "code", 297 | "execution_count": null, 298 | "metadata": { 299 | "id": "MiP1siGs6Vbc" 300 | }, 301 | "outputs": [], 302 | "source": [ 303 | "import numpy as np\n", 304 | "import matplotlib.pyplot as plt" 305 | ] 306 | }, 307 | { 308 | "cell_type": "markdown", 309 | "metadata": { 310 | "id": "NM6THrGl6Vbd" 311 | }, 312 | "source": [ 313 | "## Initialize the environment" 314 | ] 315 | }, 316 | { 317 | "cell_type": "code", 318 | "execution_count": null, 319 | "metadata": { 320 | "id": "q2ytq7oR6Vbe" 321 | }, 322 | "outputs": [], 323 | "source": [ 324 | "env = Maze()" 325 | ] 326 | }, 327 | { 328 | "cell_type": "code", 329 | "execution_count": null, 330 | "metadata": { 331 | "id": "glvnH2SX6Vbe" 332 | }, 333 | "outputs": [], 334 | "source": [ 335 | "frame = env.render(mode='rgb_array')\n", 336 | "plt.axis('off')\n", 337 | "plt.imshow(frame)" 338 | ] 339 | }, 340 | { 341 | "cell_type": "code", 342 | "execution_count": null, 343 | "metadata": { 344 | "id": "YrpWHqxu6Vbe" 345 | }, 346 | "outputs": [], 347 | "source": [ 348 | "print(f\"Observation space shape: {env.observation_space.nvec}\")\n", 349 | "print(f\"Number of actions: {env.action_space.n}\")" 350 | ] 351 | }, 352 | { 353 | "cell_type": "markdown", 354 | "metadata": { 355 | "id": "Rx947Ii46Vbf" 356 | }, 357 | "source": [ 358 | "## Define the policy $\\pi(\\cdot|s)$" 359 | ] 360 | }, 361 | { 362 | "cell_type": "markdown", 363 | "metadata": { 364 | "id": "mj-DEzgD6Vbf" 365 | }, 366 | "source": [ 367 | "#### Create the policy $\\pi(\\cdot|s)$" 368 | ] 369 | }, 370 | { 371 | "cell_type": "code", 372 | "execution_count": null, 373 | "metadata": { 374 | "id": "kivjfd-t6Vbf" 375 | }, 376 | "outputs": [], 377 | "source": [ 378 | "policy_probs = np.full((5, 5, 4), 0.25)" 379 | ] 380 | }, 381 | { 382 | "cell_type": "code", 383 | "execution_count": null, 384 | "metadata": { 385 | "id": "sunXzggw6Vbf" 386 | }, 387 | "outputs": [], 388 | "source": [ 389 | "def policy(state):\n", 390 | " return policy_probs[state]" 391 | ] 392 | }, 393 | { 394 | "cell_type": "markdown", 395 | "metadata": { 396 | "id": "PGrxXujC6Vbg" 397 | }, 398 | "source": [ 399 | "#### Test the policy with state (0, 0)" 400 | ] 401 | }, 402 | { 403 | "cell_type": "code", 404 | "execution_count": null, 405 | "metadata": { 406 | "id": "VlXckfZ96Vbg" 407 | }, 408 | "outputs": [], 409 | "source": [ 410 | "action_probabilities = policy((0,0))\n", 411 | "for action, prob in zip(range(4), action_probabilities):\n", 412 | " print(f\"Probability of taking action {action}: {prob}\")" 413 | ] 414 | }, 415 | { 416 | "cell_type": "markdown", 417 | "metadata": { 418 | "id": "THcQNGHa6Vbg" 419 | }, 420 | "source": [ 421 | "#### See how the random policy does in the maze" 422 | ] 423 | }, 424 | { 425 | 
"cell_type": "code", 426 | "execution_count": null, 427 | "metadata": { 428 | "id": "gsSYs5Ry6Vbg" 429 | }, 430 | "outputs": [], 431 | "source": [ 432 | "test_agent(env, policy, episodes=1)" 433 | ] 434 | }, 435 | { 436 | "cell_type": "markdown", 437 | "metadata": { 438 | "id": "tF2UjZft6Vbg" 439 | }, 440 | "source": [ 441 | "#### Plot the policy" 442 | ] 443 | }, 444 | { 445 | "cell_type": "code", 446 | "execution_count": null, 447 | "metadata": { 448 | "id": "HJdUvc1T6Vbg" 449 | }, 450 | "outputs": [], 451 | "source": [ 452 | "plot_policy(policy_probs, frame)" 453 | ] 454 | }, 455 | { 456 | "cell_type": "markdown", 457 | "metadata": { 458 | "id": "btCVXrhJ6Vbg" 459 | }, 460 | "source": [ 461 | "## Define value table $V(s)$" 462 | ] 463 | }, 464 | { 465 | "cell_type": "markdown", 466 | "metadata": { 467 | "id": "byOFEsaA6Vbg" 468 | }, 469 | "source": [ 470 | "#### Create the $V(s)$ table" 471 | ] 472 | }, 473 | { 474 | "cell_type": "code", 475 | "execution_count": null, 476 | "metadata": { 477 | "id": "j2lX_33D6Vbg" 478 | }, 479 | "outputs": [], 480 | "source": [ 481 | "state_values = np.zeros(shape=(5,5))" 482 | ] 483 | }, 484 | { 485 | "cell_type": "markdown", 486 | "metadata": { 487 | "id": "D9_EfMnY6Vbh" 488 | }, 489 | "source": [ 490 | "#### Plot V(s)" 491 | ] 492 | }, 493 | { 494 | "cell_type": "code", 495 | "execution_count": null, 496 | "metadata": { 497 | "id": "TI5YZTSF6Vbh" 498 | }, 499 | "outputs": [], 500 | "source": [ 501 | "plot_values(state_values, frame)" 502 | ] 503 | }, 504 | { 505 | "cell_type": "markdown", 506 | "metadata": { 507 | "id": "qfiam4Ef6Vbh" 508 | }, 509 | "source": [ 510 | "## Implement the Policy Iteration algorithm" 511 | ] 512 | }, 513 | { 514 | "cell_type": "code", 515 | "execution_count": null, 516 | "metadata": { 517 | "id": "N_5XQChT6Vbh" 518 | }, 519 | "outputs": [], 520 | "source": [] 521 | }, 522 | { 523 | "cell_type": "code", 524 | "execution_count": null, 525 | "metadata": { 526 | "id": "JRj7ZxuO6Vbh" 527 | }, 528 | "outputs": [], 529 | "source": [] 530 | }, 531 | { 532 | "cell_type": "code", 533 | "execution_count": null, 534 | "metadata": { 535 | "id": "-0F0aB_f6Vbh" 536 | }, 537 | "outputs": [], 538 | "source": [] 539 | }, 540 | { 541 | "cell_type": "code", 542 | "execution_count": null, 543 | "metadata": { 544 | "id": "0eZr-1ko6Vbh" 545 | }, 546 | "outputs": [], 547 | "source": [] 548 | }, 549 | { 550 | "cell_type": "markdown", 551 | "metadata": { 552 | "id": "zHHT9iLY6Vbh" 553 | }, 554 | "source": [ 555 | "## Show results" 556 | ] 557 | }, 558 | { 559 | "cell_type": "markdown", 560 | "source": [ 561 | "Show resulting value table V(s)" 562 | ], 563 | "metadata": { 564 | "id": "4bVf7AMd7ajW" 565 | } 566 | }, 567 | { 568 | "cell_type": "code", 569 | "source": [ 570 | "plot_values(state_values, frame)" 571 | ], 572 | "metadata": { 573 | "id": "Y1GDXzvt7atl" 574 | }, 575 | "execution_count": null, 576 | "outputs": [] 577 | }, 578 | { 579 | "cell_type": "markdown", 580 | "source": [ 581 | "Show resulting policy $\\pi(\\cdot|s)$" 582 | ], 583 | "metadata": { 584 | "id": "1OdMtp1n7hBU" 585 | } 586 | }, 587 | { 588 | "cell_type": "code", 589 | "source": [ 590 | "plot_policy(policy_probs, frame)" 591 | ], 592 | "metadata": { 593 | "id": "y4u6alml7hHQ" 594 | }, 595 | "execution_count": null, 596 | "outputs": [] 597 | }, 598 | { 599 | "cell_type": "markdown", 600 | "metadata": { 601 | "id": "onu9Yu8D6Vbh" 602 | }, 603 | "source": [ 604 | "#### Test the resulting agent" 605 | ] 606 | }, 607 | { 608 | "cell_type": "code", 609 | "execution_count": null, 
610 | "metadata": { 611 | "id": "U_8oUWxl6Vbh" 612 | }, 613 | "outputs": [], 614 | "source": [ 615 | "test_agent(env, policy)" 616 | ] 617 | }, 618 | { 619 | "cell_type": "markdown", 620 | "metadata": { 621 | "id": "bixFf0kp6Vbh" 622 | }, 623 | "source": [ 624 | "## Resources" 625 | ] 626 | }, 627 | { 628 | "cell_type": "markdown", 629 | "metadata": { 630 | "id": "NIU61hO36Vbh" 631 | }, 632 | "source": [ 633 | "[[1] Reinforcement Learning: An Introduction. Ch. 4: Dynamic Programming](https://web.stanford.edu/class/psych209/Readings/SuttonBartoIPRLBook2ndEd.pdf)" 634 | ] 635 | } 636 | ], 637 | "metadata": { 638 | "kernelspec": { 639 | "display_name": "Python 3", 640 | "language": "python", 641 | "name": "python3" 642 | }, 643 | "language_info": { 644 | "codemirror_mode": { 645 | "name": "ipython", 646 | "version": 3 647 | }, 648 | "file_extension": ".py", 649 | "mimetype": "text/x-python", 650 | "name": "python", 651 | "nbconvert_exporter": "python", 652 | "pygments_lexer": "ipython3", 653 | "version": "3.8.5" 654 | }, 655 | "colab": { 656 | "provenance": [] 657 | } 658 | }, 659 | "nbformat": 4, 660 | "nbformat_minor": 0 661 | } -------------------------------------------------------------------------------- /Section_3_value_iteration.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "pycharm": { 7 | "name": "#%%\n" 8 | }, 9 | "id": "AlZ3upHcxuON" 10 | }, 11 | "source": [ 12 | "
\n", 13 | "

\n", 14 | " Value Iteration\n", 15 | "

\n", 16 | "
\n", 17 | "
\n", 18 | "\n", 19 | "
\n", 20 | "

\n", 21 | " In this notebook we are going to look at a dynamic programming algorithm called value iteration. In it, we will sweep the state space and update all the V(s) values.\n", 22 | "

\n", 23 | "
" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "source": [ 29 | "# @title Setup code (not important) - Run this cell by pressing \"Shift + Enter\"\n", 30 | "\n", 31 | "\n", 32 | "\n", 33 | "!pip install -qq gym==0.23.0\n", 34 | "\n", 35 | "\n", 36 | "from typing import Tuple, Dict, Optional, Iterable, Callable\n", 37 | "\n", 38 | "import numpy as np\n", 39 | "import seaborn as sns\n", 40 | "import matplotlib\n", 41 | "from matplotlib import animation\n", 42 | "\n", 43 | "from IPython.display import HTML\n", 44 | "\n", 45 | "import gym\n", 46 | "from gym import spaces\n", 47 | "from gym.error import DependencyNotInstalled\n", 48 | "\n", 49 | "import pygame\n", 50 | "from pygame import gfxdraw\n", 51 | "\n", 52 | "\n", 53 | "class Maze(gym.Env):\n", 54 | "\n", 55 | " def __init__(self, exploring_starts: bool = False,\n", 56 | " shaped_rewards: bool = False, size: int = 5) -> None:\n", 57 | " super().__init__()\n", 58 | " self.exploring_starts = exploring_starts\n", 59 | " self.shaped_rewards = shaped_rewards\n", 60 | " self.state = (size - 1, size - 1)\n", 61 | " self.goal = (size - 1, size - 1)\n", 62 | " self.maze = self._create_maze(size=size)\n", 63 | " self.distances = self._compute_distances(self.goal, self.maze)\n", 64 | " self.action_space = spaces.Discrete(n=4)\n", 65 | " self.action_space.action_meanings = {0: 'UP', 1: 'RIGHT', 2: 'DOWN', 3: \"LEFT\"}\n", 66 | " self.observation_space = spaces.MultiDiscrete([size, size])\n", 67 | "\n", 68 | " self.screen = None\n", 69 | " self.agent_transform = None\n", 70 | "\n", 71 | " def step(self, action: int) -> Tuple[Tuple[int, int], float, bool, Dict]:\n", 72 | " reward = self.compute_reward(self.state, action)\n", 73 | " self.state = self._get_next_state(self.state, action)\n", 74 | " done = self.state == self.goal\n", 75 | " info = {}\n", 76 | " return self.state, reward, done, info\n", 77 | "\n", 78 | " def reset(self) -> Tuple[int, int]:\n", 79 | " if self.exploring_starts:\n", 80 | " while self.state == self.goal:\n", 81 | " self.state = tuple(self.observation_space.sample())\n", 82 | " else:\n", 83 | " self.state = (0, 0)\n", 84 | " return self.state\n", 85 | "\n", 86 | " def render(self, mode: str = 'human') -> Optional[np.ndarray]:\n", 87 | " assert mode in ['human', 'rgb_array']\n", 88 | "\n", 89 | " screen_size = 600\n", 90 | " scale = screen_size / 5\n", 91 | "\n", 92 | " if self.screen is None:\n", 93 | " pygame.init()\n", 94 | " self.screen = pygame.Surface((screen_size, screen_size))\n", 95 | "\n", 96 | " surf = pygame.Surface((screen_size, screen_size))\n", 97 | " surf.fill((22, 36, 71))\n", 98 | "\n", 99 | "\n", 100 | " for row in range(5):\n", 101 | " for col in range(5):\n", 102 | "\n", 103 | " state = (row, col)\n", 104 | " for next_state in [(row + 1, col), (row - 1, col), (row, col + 1), (row, col - 1)]:\n", 105 | " if next_state not in self.maze[state]:\n", 106 | "\n", 107 | " # Add the geometry of the edges and walls (i.e. 
the boundaries between\n", 108 | " # adjacent squares that are not connected).\n", 109 | " row_diff, col_diff = np.subtract(next_state, state)\n", 110 | " left = (col + (col_diff > 0)) * scale - 2 * (col_diff != 0)\n", 111 | " right = ((col + 1) - (col_diff < 0)) * scale + 2 * (col_diff != 0)\n", 112 | " top = (5 - (row + (row_diff > 0))) * scale - 2 * (row_diff != 0)\n", 113 | " bottom = (5 - ((row + 1) - (row_diff < 0))) * scale + 2 * (row_diff != 0)\n", 114 | "\n", 115 | " gfxdraw.filled_polygon(surf, [(left, bottom), (left, top), (right, top), (right, bottom)], (255, 255, 255))\n", 116 | "\n", 117 | " # Add the geometry of the goal square to the viewer.\n", 118 | " left, right, top, bottom = scale * 4 + 10, scale * 5 - 10, scale - 10, 10\n", 119 | " gfxdraw.filled_polygon(surf, [(left, bottom), (left, top), (right, top), (right, bottom)], (40, 199, 172))\n", 120 | "\n", 121 | " # Add the geometry of the agent to the viewer.\n", 122 | " agent_row = int(screen_size - scale * (self.state[0] + .5))\n", 123 | " agent_col = int(scale * (self.state[1] + .5))\n", 124 | " gfxdraw.filled_circle(surf, agent_col, agent_row, int(scale * .6 / 2), (228, 63, 90))\n", 125 | "\n", 126 | " surf = pygame.transform.flip(surf, False, True)\n", 127 | " self.screen.blit(surf, (0, 0))\n", 128 | "\n", 129 | " return np.transpose(\n", 130 | " np.array(pygame.surfarray.pixels3d(self.screen)), axes=(1, 0, 2)\n", 131 | " )\n", 132 | "\n", 133 | " def close(self) -> None:\n", 134 | " if self.screen is not None:\n", 135 | " pygame.display.quit()\n", 136 | " pygame.quit()\n", 137 | " self.screen = None\n", 138 | "\n", 139 | " def compute_reward(self, state: Tuple[int, int], action: int) -> float:\n", 140 | " next_state = self._get_next_state(state, action)\n", 141 | " if self.shaped_rewards:\n", 142 | " return - (self.distances[next_state] / self.distances.max())\n", 143 | " return - float(state != self.goal)\n", 144 | "\n", 145 | " def simulate_step(self, state: Tuple[int, int], action: int):\n", 146 | " reward = self.compute_reward(state, action)\n", 147 | " next_state = self._get_next_state(state, action)\n", 148 | " done = next_state == self.goal\n", 149 | " info = {}\n", 150 | " return next_state, reward, done, info\n", 151 | "\n", 152 | " def _get_next_state(self, state: Tuple[int, int], action: int) -> Tuple[int, int]:\n", 153 | " if action == 0:\n", 154 | " next_state = (state[0] - 1, state[1])\n", 155 | " elif action == 1:\n", 156 | " next_state = (state[0], state[1] + 1)\n", 157 | " elif action == 2:\n", 158 | " next_state = (state[0] + 1, state[1])\n", 159 | " elif action == 3:\n", 160 | " next_state = (state[0], state[1] - 1)\n", 161 | " else:\n", 162 | " raise ValueError(\"Action value not supported:\", action)\n", 163 | " if next_state in self.maze[state]:\n", 164 | " return next_state\n", 165 | " return state\n", 166 | "\n", 167 | " @staticmethod\n", 168 | " def _create_maze(size: int) -> Dict[Tuple[int, int], Iterable[Tuple[int, int]]]:\n", 169 | " maze = {(row, col): [(row - 1, col), (row + 1, col), (row, col - 1), (row, col + 1)]\n", 170 | " for row in range(size) for col in range(size)}\n", 171 | "\n", 172 | " left_edges = [[(row, 0), (row, -1)] for row in range(size)]\n", 173 | " right_edges = [[(row, size - 1), (row, size)] for row in range(size)]\n", 174 | " upper_edges = [[(0, col), (-1, col)] for col in range(size)]\n", 175 | " lower_edges = [[(size - 1, col), (size, col)] for col in range(size)]\n", 176 | " walls = [\n", 177 | " [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)],\n", 178 | " 
[(1, 1), (1, 2)], [(2, 1), (2, 2)], [(3, 1), (3, 2)],\n", 179 | " [(3, 1), (4, 1)], [(0, 2), (1, 2)], [(1, 2), (1, 3)],\n", 180 | " [(2, 2), (3, 2)], [(2, 3), (3, 3)], [(2, 4), (3, 4)],\n", 181 | " [(4, 2), (4, 3)], [(1, 3), (1, 4)], [(2, 3), (2, 4)],\n", 182 | " ]\n", 183 | "\n", 184 | " obstacles = upper_edges + lower_edges + left_edges + right_edges + walls\n", 185 | "\n", 186 | " for src, dst in obstacles:\n", 187 | " maze[src].remove(dst)\n", 188 | "\n", 189 | " if dst in maze:\n", 190 | " maze[dst].remove(src)\n", 191 | "\n", 192 | " return maze\n", 193 | "\n", 194 | " @staticmethod\n", 195 | " def _compute_distances(goal: Tuple[int, int],\n", 196 | " maze: Dict[Tuple[int, int], Iterable[Tuple[int, int]]]) -> np.ndarray:\n", 197 | " distances = np.full((5, 5), np.inf)\n", 198 | " visited = set()\n", 199 | " distances[goal] = 0.\n", 200 | "\n", 201 | " while visited != set(maze):\n", 202 | " sorted_dst = [(v // 5, v % 5) for v in distances.argsort(axis=None)]\n", 203 | " closest = next(x for x in sorted_dst if x not in visited)\n", 204 | " visited.add(closest)\n", 205 | "\n", 206 | " for neighbour in maze[closest]:\n", 207 | " distances[neighbour] = min(distances[neighbour], distances[closest] + 1)\n", 208 | " return distances\n", 209 | "\n", 210 | "\n", 211 | "def plot_policy(probs_or_qvals, frame, action_meanings=None):\n", 212 | " if action_meanings is None:\n", 213 | " action_meanings = {0: 'U', 1: 'R', 2: 'D', 3: 'L'}\n", 214 | " fig, axes = plt.subplots(1, 2, figsize=(8, 4))\n", 215 | " max_prob_actions = probs_or_qvals.argmax(axis=-1)\n", 216 | " probs_copy = max_prob_actions.copy().astype(object)\n", 217 | " for key in action_meanings:\n", 218 | " probs_copy[probs_copy == key] = action_meanings[key]\n", 219 | " sns.heatmap(max_prob_actions, annot=probs_copy, fmt='', cbar=False, cmap='coolwarm',\n", 220 | " annot_kws={'weight': 'bold', 'size': 12}, linewidths=2, ax=axes[0])\n", 221 | " axes[1].imshow(frame)\n", 222 | " axes[0].axis('off')\n", 223 | " axes[1].axis('off')\n", 224 | " plt.suptitle(\"Policy\", size=18)\n", 225 | " plt.tight_layout()\n", 226 | "\n", 227 | "\n", 228 | "def plot_values(state_values, frame):\n", 229 | " f, axes = plt.subplots(1, 2, figsize=(10, 4))\n", 230 | " sns.heatmap(state_values, annot=True, fmt=\".2f\", cmap='coolwarm',\n", 231 | " annot_kws={'weight': 'bold', 'size': 12}, linewidths=2, ax=axes[0])\n", 232 | " axes[1].imshow(frame)\n", 233 | " axes[0].axis('off')\n", 234 | " axes[1].axis('off')\n", 235 | " plt.tight_layout()\n", 236 | "\n", 237 | "\n", 238 | "def display_video(frames):\n", 239 | " # Copied from: https://colab.research.google.com/github/deepmind/dm_control/blob/master/tutorial.ipynb\n", 240 | " orig_backend = matplotlib.get_backend()\n", 241 | " matplotlib.use('Agg')\n", 242 | " fig, ax = plt.subplots(1, 1, figsize=(5, 5))\n", 243 | " matplotlib.use(orig_backend)\n", 244 | " ax.set_axis_off()\n", 245 | " ax.set_aspect('equal')\n", 246 | " ax.set_position([0, 0, 1, 1])\n", 247 | " im = ax.imshow(frames[0])\n", 248 | " def update(frame):\n", 249 | " im.set_data(frame)\n", 250 | " return [im]\n", 251 | " anim = animation.FuncAnimation(fig=fig, func=update, frames=frames,\n", 252 | " interval=50, blit=True, repeat=False)\n", 253 | " return HTML(anim.to_html5_video())\n", 254 | "\n", 255 | "\n", 256 | "def test_agent(environment, policy, episodes=10):\n", 257 | " frames = []\n", 258 | " for episode in range(episodes):\n", 259 | " state = env.reset()\n", 260 | " done = False\n", 261 | " frames.append(env.render(mode=\"rgb_array\"))\n", 
262 | "\n", 263 | " while not done:\n", 264 | " p = policy(state)\n", 265 | " if isinstance(p, np.ndarray):\n", 266 | " action = np.random.choice(4, p=p)\n", 267 | " else:\n", 268 | " action = p\n", 269 | " next_state, reward, done, extra_info = env.step(action)\n", 270 | " img = env.render(mode=\"rgb_array\")\n", 271 | " frames.append(img)\n", 272 | " state = next_state\n", 273 | "\n", 274 | " return display_video(frames)\n", 275 | "\n" 276 | ], 277 | "metadata": { 278 | "cellView": "form", 279 | "id": "q43J78D0zMXC" 280 | }, 281 | "execution_count": null, 282 | "outputs": [] 283 | }, 284 | { 285 | "cell_type": "markdown", 286 | "metadata": { 287 | "id": "g4GSSPpAxuOS" 288 | }, 289 | "source": [ 290 | "## Import the necessary software libraries:" 291 | ] 292 | }, 293 | { 294 | "cell_type": "code", 295 | "execution_count": null, 296 | "metadata": { 297 | "id": "nforYaTCxuOT" 298 | }, 299 | "outputs": [], 300 | "source": [ 301 | "import numpy as np\n", 302 | "import matplotlib.pyplot as plt" 303 | ] 304 | }, 305 | { 306 | "cell_type": "markdown", 307 | "metadata": { 308 | "id": "cC2dbTlhxuOT" 309 | }, 310 | "source": [ 311 | "## Initialize the environment" 312 | ] 313 | }, 314 | { 315 | "cell_type": "code", 316 | "execution_count": null, 317 | "metadata": { 318 | "id": "ZoZ7a19kxuOU" 319 | }, 320 | "outputs": [], 321 | "source": [] 322 | }, 323 | { 324 | "cell_type": "code", 325 | "execution_count": null, 326 | "metadata": { 327 | "id": "kj5M3uj-xuOU" 328 | }, 329 | "outputs": [], 330 | "source": [] 331 | }, 332 | { 333 | "cell_type": "code", 334 | "execution_count": null, 335 | "metadata": { 336 | "id": "WeudzDeoxuOU" 337 | }, 338 | "outputs": [], 339 | "source": [] 340 | }, 341 | { 342 | "cell_type": "markdown", 343 | "metadata": { 344 | "id": "W43hrBlexuOV" 345 | }, 346 | "source": [ 347 | "## Define the policy $\\pi(\\cdot|s)$" 348 | ] 349 | }, 350 | { 351 | "cell_type": "markdown", 352 | "metadata": { 353 | "id": "-tgt_gf3xuOV" 354 | }, 355 | "source": [ 356 | "#### Create the policy $\\pi(\\cdot|s)$" 357 | ] 358 | }, 359 | { 360 | "cell_type": "code", 361 | "execution_count": null, 362 | "metadata": { 363 | "id": "DGLBsP09xuOV" 364 | }, 365 | "outputs": [], 366 | "source": [] 367 | }, 368 | { 369 | "cell_type": "code", 370 | "execution_count": null, 371 | "metadata": { 372 | "id": "2u0NINGsxuOW" 373 | }, 374 | "outputs": [], 375 | "source": [] 376 | }, 377 | { 378 | "cell_type": "markdown", 379 | "metadata": { 380 | "id": "Wyf78bRVxuOW" 381 | }, 382 | "source": [ 383 | "#### Test the policy with state (0, 0)" 384 | ] 385 | }, 386 | { 387 | "cell_type": "code", 388 | "execution_count": null, 389 | "metadata": { 390 | "id": "t7Xu15pQxuOW" 391 | }, 392 | "outputs": [], 393 | "source": [] 394 | }, 395 | { 396 | "cell_type": "markdown", 397 | "metadata": { 398 | "id": "kjKfUVHnxuOW" 399 | }, 400 | "source": [ 401 | "#### See how the random policy does in the maze" 402 | ] 403 | }, 404 | { 405 | "cell_type": "code", 406 | "execution_count": null, 407 | "metadata": { 408 | "id": "4tM6oVe8xuOW" 409 | }, 410 | "outputs": [], 411 | "source": [] 412 | }, 413 | { 414 | "cell_type": "markdown", 415 | "metadata": { 416 | "id": "h9MHt5yMxuOW" 417 | }, 418 | "source": [ 419 | "#### Plot the policy" 420 | ] 421 | }, 422 | { 423 | "cell_type": "code", 424 | "execution_count": null, 425 | "metadata": { 426 | "id": "jqPH9s4rxuOX" 427 | }, 428 | "outputs": [], 429 | "source": [] 430 | }, 431 | { 432 | "cell_type": "markdown", 433 | "metadata": { 434 | "id": "D0oYEMu7xuOX" 435 | }, 436 | "source": [ 437 | 
"## Define value table $V(s)$" 438 | ] 439 | }, 440 | { 441 | "cell_type": "markdown", 442 | "metadata": { 443 | "id": "_J93fQINxuOX" 444 | }, 445 | "source": [ 446 | "#### Create the $V(s)$ table" 447 | ] 448 | }, 449 | { 450 | "cell_type": "code", 451 | "execution_count": null, 452 | "metadata": { 453 | "id": "tv_Y-x4-xuOX" 454 | }, 455 | "outputs": [], 456 | "source": [] 457 | }, 458 | { 459 | "cell_type": "markdown", 460 | "metadata": { 461 | "id": "ff2B1KprxuOX" 462 | }, 463 | "source": [ 464 | "#### Plot V(s)" 465 | ] 466 | }, 467 | { 468 | "cell_type": "code", 469 | "execution_count": null, 470 | "metadata": { 471 | "id": "W1H0pYVbxuOX" 472 | }, 473 | "outputs": [], 474 | "source": [] 475 | }, 476 | { 477 | "cell_type": "markdown", 478 | "metadata": { 479 | "id": "e7yPNJlnxuOX" 480 | }, 481 | "source": [ 482 | "## Implement the Value Iteration algorithm\n", 483 | "\n", 484 | "
\n", 485 | "\n", 486 | "\n", 487 | "\n", 488 | "\n", 489 | "
\n", 490 | " Adapted from Barto & Sutton: \"Reinforcement Learning: An Introduction\".\n", 491 | "
" 492 | ] 493 | }, 494 | { 495 | "cell_type": "code", 496 | "execution_count": null, 497 | "metadata": { 498 | "id": "jQAjDBoaxuOX" 499 | }, 500 | "outputs": [], 501 | "source": [] 502 | }, 503 | { 504 | "cell_type": "code", 505 | "execution_count": null, 506 | "metadata": { 507 | "id": "KcfSk96rxuOY" 508 | }, 509 | "outputs": [], 510 | "source": [] 511 | }, 512 | { 513 | "cell_type": "markdown", 514 | "metadata": { 515 | "id": "Nktnqle1xuOY" 516 | }, 517 | "source": [ 518 | "## Show results" 519 | ] 520 | }, 521 | { 522 | "cell_type": "markdown", 523 | "metadata": { 524 | "id": "K0SKfO-7xuOY" 525 | }, 526 | "source": [ 527 | "#### Show resulting value table $V(s)$" 528 | ] 529 | }, 530 | { 531 | "cell_type": "code", 532 | "execution_count": null, 533 | "metadata": { 534 | "id": "Ki9R0UTzxuOY" 535 | }, 536 | "outputs": [], 537 | "source": [] 538 | }, 539 | { 540 | "cell_type": "markdown", 541 | "metadata": { 542 | "id": "1UQJjMpVxuOY" 543 | }, 544 | "source": [ 545 | "#### Show resulting policy $\\pi(\\cdot|s)$" 546 | ] 547 | }, 548 | { 549 | "cell_type": "code", 550 | "execution_count": null, 551 | "metadata": { 552 | "id": "AfUKiL7sxuOY" 553 | }, 554 | "outputs": [], 555 | "source": [] 556 | }, 557 | { 558 | "cell_type": "markdown", 559 | "metadata": { 560 | "id": "fn_aj8SHxuOY" 561 | }, 562 | "source": [ 563 | "#### Test the resulting agent" 564 | ] 565 | }, 566 | { 567 | "cell_type": "code", 568 | "execution_count": null, 569 | "metadata": { 570 | "id": "knki0tRqxuOY" 571 | }, 572 | "outputs": [], 573 | "source": [] 574 | }, 575 | { 576 | "cell_type": "markdown", 577 | "metadata": { 578 | "id": "bwV_rsU2xuOY" 579 | }, 580 | "source": [ 581 | "## Resources" 582 | ] 583 | }, 584 | { 585 | "cell_type": "markdown", 586 | "metadata": { 587 | "id": "m_IZCUpixuOY" 588 | }, 589 | "source": [ 590 | "[[1] Reinforcement Learning: An Introduction. Ch. 4: Dynamic Programming](https://web.stanford.edu/class/psych209/Readings/SuttonBartoIPRLBook2ndEd.pdf)" 591 | ] 592 | } 593 | ], 594 | "metadata": { 595 | "kernelspec": { 596 | "display_name": "Python 3", 597 | "language": "python", 598 | "name": "python3" 599 | }, 600 | "language_info": { 601 | "codemirror_mode": { 602 | "name": "ipython", 603 | "version": 3 604 | }, 605 | "file_extension": ".py", 606 | "mimetype": "text/x-python", 607 | "name": "python", 608 | "nbconvert_exporter": "python", 609 | "pygments_lexer": "ipython3", 610 | "version": "3.8.5" 611 | }, 612 | "colab": { 613 | "provenance": [] 614 | } 615 | }, 616 | "nbformat": 4, 617 | "nbformat_minor": 0 618 | } -------------------------------------------------------------------------------- /Section_3_value_iteration_complete.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "pycharm": { 7 | "name": "#%%\n" 8 | }, 9 | "id": "AlZ3upHcxuON" 10 | }, 11 | "source": [ 12 | "
\n", 13 | "

\n", 14 | " Value Iteration\n", 15 | "

\n", 16 | "
\n", 17 | "
\n", 18 | "\n", 19 | "
\n", 20 | "

\n", 21 | " In this notebook we are going to look at a dynamic programming algorithm called value iteration. In it, we will sweep the state space and update all the V(s) values.\n", 22 | "

\n", 23 | "
" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "source": [ 29 | "# @title Setup code (not important) - Run this cell by pressing \"Shift + Enter\"\n", 30 | "\n", 31 | "\n", 32 | "\n", 33 | "!pip install -qq gym==0.23.0\n", 34 | "\n", 35 | "\n", 36 | "from typing import Tuple, Dict, Optional, Iterable, Callable\n", 37 | "\n", 38 | "import numpy as np\n", 39 | "import seaborn as sns\n", 40 | "import matplotlib\n", 41 | "from matplotlib import animation\n", 42 | "\n", 43 | "from IPython.display import HTML\n", 44 | "\n", 45 | "import gym\n", 46 | "from gym import spaces\n", 47 | "from gym.error import DependencyNotInstalled\n", 48 | "\n", 49 | "import pygame\n", 50 | "from pygame import gfxdraw\n", 51 | "\n", 52 | "\n", 53 | "class Maze(gym.Env):\n", 54 | "\n", 55 | " def __init__(self, exploring_starts: bool = False,\n", 56 | " shaped_rewards: bool = False, size: int = 5) -> None:\n", 57 | " super().__init__()\n", 58 | " self.exploring_starts = exploring_starts\n", 59 | " self.shaped_rewards = shaped_rewards\n", 60 | " self.state = (size - 1, size - 1)\n", 61 | " self.goal = (size - 1, size - 1)\n", 62 | " self.maze = self._create_maze(size=size)\n", 63 | " self.distances = self._compute_distances(self.goal, self.maze)\n", 64 | " self.action_space = spaces.Discrete(n=4)\n", 65 | " self.action_space.action_meanings = {0: 'UP', 1: 'RIGHT', 2: 'DOWN', 3: \"LEFT\"}\n", 66 | " self.observation_space = spaces.MultiDiscrete([size, size])\n", 67 | "\n", 68 | " self.screen = None\n", 69 | " self.agent_transform = None\n", 70 | "\n", 71 | " def step(self, action: int) -> Tuple[Tuple[int, int], float, bool, Dict]:\n", 72 | " reward = self.compute_reward(self.state, action)\n", 73 | " self.state = self._get_next_state(self.state, action)\n", 74 | " done = self.state == self.goal\n", 75 | " info = {}\n", 76 | " return self.state, reward, done, info\n", 77 | "\n", 78 | " def reset(self) -> Tuple[int, int]:\n", 79 | " if self.exploring_starts:\n", 80 | " while self.state == self.goal:\n", 81 | " self.state = tuple(self.observation_space.sample())\n", 82 | " else:\n", 83 | " self.state = (0, 0)\n", 84 | " return self.state\n", 85 | "\n", 86 | " def render(self, mode: str = 'human') -> Optional[np.ndarray]:\n", 87 | " assert mode in ['human', 'rgb_array']\n", 88 | "\n", 89 | " screen_size = 600\n", 90 | " scale = screen_size / 5\n", 91 | "\n", 92 | " if self.screen is None:\n", 93 | " pygame.init()\n", 94 | " self.screen = pygame.Surface((screen_size, screen_size))\n", 95 | "\n", 96 | " surf = pygame.Surface((screen_size, screen_size))\n", 97 | " surf.fill((22, 36, 71))\n", 98 | "\n", 99 | "\n", 100 | " for row in range(5):\n", 101 | " for col in range(5):\n", 102 | "\n", 103 | " state = (row, col)\n", 104 | " for next_state in [(row + 1, col), (row - 1, col), (row, col + 1), (row, col - 1)]:\n", 105 | " if next_state not in self.maze[state]:\n", 106 | "\n", 107 | " # Add the geometry of the edges and walls (i.e. 
the boundaries between\n", 108 | " # adjacent squares that are not connected).\n", 109 | " row_diff, col_diff = np.subtract(next_state, state)\n", 110 | " left = (col + (col_diff > 0)) * scale - 2 * (col_diff != 0)\n", 111 | " right = ((col + 1) - (col_diff < 0)) * scale + 2 * (col_diff != 0)\n", 112 | " top = (5 - (row + (row_diff > 0))) * scale - 2 * (row_diff != 0)\n", 113 | " bottom = (5 - ((row + 1) - (row_diff < 0))) * scale + 2 * (row_diff != 0)\n", 114 | "\n", 115 | " gfxdraw.filled_polygon(surf, [(left, bottom), (left, top), (right, top), (right, bottom)], (255, 255, 255))\n", 116 | "\n", 117 | " # Add the geometry of the goal square to the viewer.\n", 118 | " left, right, top, bottom = scale * 4 + 10, scale * 5 - 10, scale - 10, 10\n", 119 | " gfxdraw.filled_polygon(surf, [(left, bottom), (left, top), (right, top), (right, bottom)], (40, 199, 172))\n", 120 | "\n", 121 | " # Add the geometry of the agent to the viewer.\n", 122 | " agent_row = int(screen_size - scale * (self.state[0] + .5))\n", 123 | " agent_col = int(scale * (self.state[1] + .5))\n", 124 | " gfxdraw.filled_circle(surf, agent_col, agent_row, int(scale * .6 / 2), (228, 63, 90))\n", 125 | "\n", 126 | " surf = pygame.transform.flip(surf, False, True)\n", 127 | " self.screen.blit(surf, (0, 0))\n", 128 | "\n", 129 | " return np.transpose(\n", 130 | " np.array(pygame.surfarray.pixels3d(self.screen)), axes=(1, 0, 2)\n", 131 | " )\n", 132 | "\n", 133 | " def close(self) -> None:\n", 134 | " if self.screen is not None:\n", 135 | " pygame.display.quit()\n", 136 | " pygame.quit()\n", 137 | " self.screen = None\n", 138 | "\n", 139 | " def compute_reward(self, state: Tuple[int, int], action: int) -> float:\n", 140 | " next_state = self._get_next_state(state, action)\n", 141 | " if self.shaped_rewards:\n", 142 | " return - (self.distances[next_state] / self.distances.max())\n", 143 | " return - float(state != self.goal)\n", 144 | "\n", 145 | " def simulate_step(self, state: Tuple[int, int], action: int):\n", 146 | " reward = self.compute_reward(state, action)\n", 147 | " next_state = self._get_next_state(state, action)\n", 148 | " done = next_state == self.goal\n", 149 | " info = {}\n", 150 | " return next_state, reward, done, info\n", 151 | "\n", 152 | " def _get_next_state(self, state: Tuple[int, int], action: int) -> Tuple[int, int]:\n", 153 | " if action == 0:\n", 154 | " next_state = (state[0] - 1, state[1])\n", 155 | " elif action == 1:\n", 156 | " next_state = (state[0], state[1] + 1)\n", 157 | " elif action == 2:\n", 158 | " next_state = (state[0] + 1, state[1])\n", 159 | " elif action == 3:\n", 160 | " next_state = (state[0], state[1] - 1)\n", 161 | " else:\n", 162 | " raise ValueError(\"Action value not supported:\", action)\n", 163 | " if next_state in self.maze[state]:\n", 164 | " return next_state\n", 165 | " return state\n", 166 | "\n", 167 | " @staticmethod\n", 168 | " def _create_maze(size: int) -> Dict[Tuple[int, int], Iterable[Tuple[int, int]]]:\n", 169 | " maze = {(row, col): [(row - 1, col), (row + 1, col), (row, col - 1), (row, col + 1)]\n", 170 | " for row in range(size) for col in range(size)}\n", 171 | "\n", 172 | " left_edges = [[(row, 0), (row, -1)] for row in range(size)]\n", 173 | " right_edges = [[(row, size - 1), (row, size)] for row in range(size)]\n", 174 | " upper_edges = [[(0, col), (-1, col)] for col in range(size)]\n", 175 | " lower_edges = [[(size - 1, col), (size, col)] for col in range(size)]\n", 176 | " walls = [\n", 177 | " [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)],\n", 178 | " 
[(1, 1), (1, 2)], [(2, 1), (2, 2)], [(3, 1), (3, 2)],\n", 179 | " [(3, 1), (4, 1)], [(0, 2), (1, 2)], [(1, 2), (1, 3)],\n", 180 | " [(2, 2), (3, 2)], [(2, 3), (3, 3)], [(2, 4), (3, 4)],\n", 181 | " [(4, 2), (4, 3)], [(1, 3), (1, 4)], [(2, 3), (2, 4)],\n", 182 | " ]\n", 183 | "\n", 184 | " obstacles = upper_edges + lower_edges + left_edges + right_edges + walls\n", 185 | "\n", 186 | " for src, dst in obstacles:\n", 187 | " maze[src].remove(dst)\n", 188 | "\n", 189 | " if dst in maze:\n", 190 | " maze[dst].remove(src)\n", 191 | "\n", 192 | " return maze\n", 193 | "\n", 194 | " @staticmethod\n", 195 | " def _compute_distances(goal: Tuple[int, int],\n", 196 | " maze: Dict[Tuple[int, int], Iterable[Tuple[int, int]]]) -> np.ndarray:\n", 197 | " distances = np.full((5, 5), np.inf)\n", 198 | " visited = set()\n", 199 | " distances[goal] = 0.\n", 200 | "\n", 201 | " while visited != set(maze):\n", 202 | " sorted_dst = [(v // 5, v % 5) for v in distances.argsort(axis=None)]\n", 203 | " closest = next(x for x in sorted_dst if x not in visited)\n", 204 | " visited.add(closest)\n", 205 | "\n", 206 | " for neighbour in maze[closest]:\n", 207 | " distances[neighbour] = min(distances[neighbour], distances[closest] + 1)\n", 208 | " return distances\n", 209 | "\n", 210 | "\n", 211 | "def plot_policy(probs_or_qvals, frame, action_meanings=None):\n", 212 | " if action_meanings is None:\n", 213 | " action_meanings = {0: 'U', 1: 'R', 2: 'D', 3: 'L'}\n", 214 | " fig, axes = plt.subplots(1, 2, figsize=(8, 4))\n", 215 | " max_prob_actions = probs_or_qvals.argmax(axis=-1)\n", 216 | " probs_copy = max_prob_actions.copy().astype(object)\n", 217 | " for key in action_meanings:\n", 218 | " probs_copy[probs_copy == key] = action_meanings[key]\n", 219 | " sns.heatmap(max_prob_actions, annot=probs_copy, fmt='', cbar=False, cmap='coolwarm',\n", 220 | " annot_kws={'weight': 'bold', 'size': 12}, linewidths=2, ax=axes[0])\n", 221 | " axes[1].imshow(frame)\n", 222 | " axes[0].axis('off')\n", 223 | " axes[1].axis('off')\n", 224 | " plt.suptitle(\"Policy\", size=18)\n", 225 | " plt.tight_layout()\n", 226 | "\n", 227 | "\n", 228 | "def plot_values(state_values, frame):\n", 229 | " f, axes = plt.subplots(1, 2, figsize=(10, 4))\n", 230 | " sns.heatmap(state_values, annot=True, fmt=\".2f\", cmap='coolwarm',\n", 231 | " annot_kws={'weight': 'bold', 'size': 12}, linewidths=2, ax=axes[0])\n", 232 | " axes[1].imshow(frame)\n", 233 | " axes[0].axis('off')\n", 234 | " axes[1].axis('off')\n", 235 | " plt.tight_layout()\n", 236 | "\n", 237 | "\n", 238 | "def display_video(frames):\n", 239 | " # Copied from: https://colab.research.google.com/github/deepmind/dm_control/blob/master/tutorial.ipynb\n", 240 | " orig_backend = matplotlib.get_backend()\n", 241 | " matplotlib.use('Agg')\n", 242 | " fig, ax = plt.subplots(1, 1, figsize=(5, 5))\n", 243 | " matplotlib.use(orig_backend)\n", 244 | " ax.set_axis_off()\n", 245 | " ax.set_aspect('equal')\n", 246 | " ax.set_position([0, 0, 1, 1])\n", 247 | " im = ax.imshow(frames[0])\n", 248 | " def update(frame):\n", 249 | " im.set_data(frame)\n", 250 | " return [im]\n", 251 | " anim = animation.FuncAnimation(fig=fig, func=update, frames=frames,\n", 252 | " interval=50, blit=True, repeat=False)\n", 253 | " return HTML(anim.to_html5_video())\n", 254 | "\n", 255 | "\n", 256 | "def test_agent(environment, policy, episodes=10):\n", 257 | " frames = []\n", 258 | " for episode in range(episodes):\n", 259 | " state = env.reset()\n", 260 | " done = False\n", 261 | " frames.append(env.render(mode=\"rgb_array\"))\n", 
262 | "\n", 263 | " while not done:\n", 264 | " p = policy(state)\n", 265 | " if isinstance(p, np.ndarray):\n", 266 | " action = np.random.choice(4, p=p)\n", 267 | " else:\n", 268 | " action = p\n", 269 | " next_state, reward, done, extra_info = env.step(action)\n", 270 | " img = env.render(mode=\"rgb_array\")\n", 271 | " frames.append(img)\n", 272 | " state = next_state\n", 273 | "\n", 274 | " return display_video(frames)\n", 275 | "\n" 276 | ], 277 | "metadata": { 278 | "cellView": "form", 279 | "id": "q43J78D0zMXC" 280 | }, 281 | "execution_count": null, 282 | "outputs": [] 283 | }, 284 | { 285 | "cell_type": "markdown", 286 | "metadata": { 287 | "id": "g4GSSPpAxuOS" 288 | }, 289 | "source": [ 290 | "## Import the necessary software libraries:" 291 | ] 292 | }, 293 | { 294 | "cell_type": "code", 295 | "execution_count": null, 296 | "metadata": { 297 | "id": "nforYaTCxuOT" 298 | }, 299 | "outputs": [], 300 | "source": [ 301 | "import numpy as np\n", 302 | "import matplotlib.pyplot as plt" 303 | ] 304 | }, 305 | { 306 | "cell_type": "markdown", 307 | "metadata": { 308 | "id": "cC2dbTlhxuOT" 309 | }, 310 | "source": [ 311 | "## Initialize the environment" 312 | ] 313 | }, 314 | { 315 | "cell_type": "code", 316 | "execution_count": null, 317 | "metadata": { 318 | "id": "ZoZ7a19kxuOU" 319 | }, 320 | "outputs": [], 321 | "source": [ 322 | "env = Maze()" 323 | ] 324 | }, 325 | { 326 | "cell_type": "code", 327 | "execution_count": null, 328 | "metadata": { 329 | "id": "kj5M3uj-xuOU" 330 | }, 331 | "outputs": [], 332 | "source": [ 333 | "frame = env.render(mode='rgb_array')\n", 334 | "plt.figure(figsize=(6,6))\n", 335 | "plt.axis('off')\n", 336 | "plt.imshow(frame)" 337 | ] 338 | }, 339 | { 340 | "cell_type": "code", 341 | "execution_count": null, 342 | "metadata": { 343 | "id": "WeudzDeoxuOU" 344 | }, 345 | "outputs": [], 346 | "source": [ 347 | "print(f\"Observation space shape: {env.observation_space.nvec}\")\n", 348 | "print(f\"Number of actions: {env.action_space.n}\")" 349 | ] 350 | }, 351 | { 352 | "cell_type": "markdown", 353 | "metadata": { 354 | "id": "W43hrBlexuOV" 355 | }, 356 | "source": [ 357 | "## Define the policy $\\pi(\\cdot|s)$" 358 | ] 359 | }, 360 | { 361 | "cell_type": "markdown", 362 | "metadata": { 363 | "id": "-tgt_gf3xuOV" 364 | }, 365 | "source": [ 366 | "#### Create the policy $\\pi(\\cdot|s)$" 367 | ] 368 | }, 369 | { 370 | "cell_type": "code", 371 | "execution_count": null, 372 | "metadata": { 373 | "id": "DGLBsP09xuOV" 374 | }, 375 | "outputs": [], 376 | "source": [ 377 | "policy_probs = np.full((5, 5, 4), 0.25)" 378 | ] 379 | }, 380 | { 381 | "cell_type": "code", 382 | "execution_count": null, 383 | "metadata": { 384 | "id": "2u0NINGsxuOW" 385 | }, 386 | "outputs": [], 387 | "source": [ 388 | "def policy(state):\n", 389 | " return policy_probs[state]" 390 | ] 391 | }, 392 | { 393 | "cell_type": "markdown", 394 | "metadata": { 395 | "id": "Wyf78bRVxuOW" 396 | }, 397 | "source": [ 398 | "#### Test the policy with state (0, 0)" 399 | ] 400 | }, 401 | { 402 | "cell_type": "code", 403 | "execution_count": null, 404 | "metadata": { 405 | "id": "t7Xu15pQxuOW" 406 | }, 407 | "outputs": [], 408 | "source": [ 409 | "action_probabilities = policy((0,0))\n", 410 | "for action, prob in zip(range(4), action_probabilities):\n", 411 | " print(f\"Probability of taking action {action}: {prob}\")" 412 | ] 413 | }, 414 | { 415 | "cell_type": "markdown", 416 | "metadata": { 417 | "id": "kjKfUVHnxuOW" 418 | }, 419 | "source": [ 420 | "#### See how the random policy does in the maze" 
421 | ] 422 | }, 423 | { 424 | "cell_type": "code", 425 | "execution_count": null, 426 | "metadata": { 427 | "id": "4tM6oVe8xuOW" 428 | }, 429 | "outputs": [], 430 | "source": [ 431 | "test_agent(env, policy, episodes=1)" 432 | ] 433 | }, 434 | { 435 | "cell_type": "markdown", 436 | "metadata": { 437 | "id": "h9MHt5yMxuOW" 438 | }, 439 | "source": [ 440 | "#### Plot the policy" 441 | ] 442 | }, 443 | { 444 | "cell_type": "code", 445 | "execution_count": null, 446 | "metadata": { 447 | "id": "jqPH9s4rxuOX" 448 | }, 449 | "outputs": [], 450 | "source": [ 451 | "plot_policy(policy_probs, frame)" 452 | ] 453 | }, 454 | { 455 | "cell_type": "markdown", 456 | "metadata": { 457 | "id": "D0oYEMu7xuOX" 458 | }, 459 | "source": [ 460 | "## Define value table $V(s)$" 461 | ] 462 | }, 463 | { 464 | "cell_type": "markdown", 465 | "metadata": { 466 | "id": "_J93fQINxuOX" 467 | }, 468 | "source": [ 469 | "#### Create the $V(s)$ table" 470 | ] 471 | }, 472 | { 473 | "cell_type": "code", 474 | "execution_count": null, 475 | "metadata": { 476 | "id": "tv_Y-x4-xuOX" 477 | }, 478 | "outputs": [], 479 | "source": [ 480 | "state_values = np.zeros(shape=(5,5))" 481 | ] 482 | }, 483 | { 484 | "cell_type": "markdown", 485 | "metadata": { 486 | "id": "ff2B1KprxuOX" 487 | }, 488 | "source": [ 489 | "#### Plot V(s)" 490 | ] 491 | }, 492 | { 493 | "cell_type": "code", 494 | "execution_count": null, 495 | "metadata": { 496 | "id": "W1H0pYVbxuOX" 497 | }, 498 | "outputs": [], 499 | "source": [ 500 | "plot_values(state_values, frame)" 501 | ] 502 | }, 503 | { 504 | "cell_type": "markdown", 505 | "metadata": { 506 | "id": "e7yPNJlnxuOX" 507 | }, 508 | "source": [ 509 | "## Implement the Value Iteration algorithm\n", 510 | "\n", 511 | "
\n", 512 | "\n", 513 | "\n", 514 | "\n", 515 | "\n", 516 | "
\n", 517 | " Adapted from Barto & Sutton: \"Reinforcement Learning: An Introduction\".\n", 518 | "
" 519 | ] 520 | }, 521 | { 522 | "cell_type": "code", 523 | "execution_count": null, 524 | "metadata": { 525 | "id": "jQAjDBoaxuOX" 526 | }, 527 | "outputs": [], 528 | "source": [ 529 | "def value_iteration(policy_probs, state_values, theta=1e-6, gamma=0.99):\n", 530 | " delta = float('inf')\n", 531 | "\n", 532 | " while delta > theta:\n", 533 | " delta = 0\n", 534 | " for row in range(5):\n", 535 | " for col in range(5):\n", 536 | " old_value = state_values[(row, col)]\n", 537 | " action_probs = None\n", 538 | " max_qsa = float('-inf')\n", 539 | "\n", 540 | " for action in range(4):\n", 541 | " next_state, reward, _, _ = env.simulate_step((row, col), action)\n", 542 | " qsa = reward + gamma * state_values[next_state]\n", 543 | " if qsa > max_qsa:\n", 544 | " max_qsa = qsa\n", 545 | " action_probs = np.zeros(4)\n", 546 | " action_probs[action] = 1.\n", 547 | "\n", 548 | " state_values[(row, col)] = max_qsa\n", 549 | " policy_probs[(row, col)] = action_probs\n", 550 | "\n", 551 | " delta = max(delta, abs(max_qsa - old_value))" 552 | ] 553 | }, 554 | { 555 | "cell_type": "code", 556 | "execution_count": null, 557 | "metadata": { 558 | "id": "KcfSk96rxuOY" 559 | }, 560 | "outputs": [], 561 | "source": [ 562 | "value_iteration(policy_probs, state_values)" 563 | ] 564 | }, 565 | { 566 | "cell_type": "markdown", 567 | "metadata": { 568 | "id": "Nktnqle1xuOY" 569 | }, 570 | "source": [ 571 | "## Show results" 572 | ] 573 | }, 574 | { 575 | "cell_type": "markdown", 576 | "metadata": { 577 | "id": "K0SKfO-7xuOY" 578 | }, 579 | "source": [ 580 | "#### Show resulting value table $V(s)$" 581 | ] 582 | }, 583 | { 584 | "cell_type": "code", 585 | "execution_count": null, 586 | "metadata": { 587 | "id": "Ki9R0UTzxuOY" 588 | }, 589 | "outputs": [], 590 | "source": [ 591 | "plot_values(state_values, frame)" 592 | ] 593 | }, 594 | { 595 | "cell_type": "markdown", 596 | "metadata": { 597 | "id": "1UQJjMpVxuOY" 598 | }, 599 | "source": [ 600 | "#### Show resulting policy $\\pi(\\cdot|s)$" 601 | ] 602 | }, 603 | { 604 | "cell_type": "code", 605 | "execution_count": null, 606 | "metadata": { 607 | "id": "AfUKiL7sxuOY" 608 | }, 609 | "outputs": [], 610 | "source": [ 611 | "plot_policy(policy_probs, frame)" 612 | ] 613 | }, 614 | { 615 | "cell_type": "markdown", 616 | "metadata": { 617 | "id": "fn_aj8SHxuOY" 618 | }, 619 | "source": [ 620 | "#### Test the resulting agent" 621 | ] 622 | }, 623 | { 624 | "cell_type": "code", 625 | "execution_count": null, 626 | "metadata": { 627 | "id": "knki0tRqxuOY" 628 | }, 629 | "outputs": [], 630 | "source": [ 631 | "test_agent(env, policy)" 632 | ] 633 | }, 634 | { 635 | "cell_type": "markdown", 636 | "metadata": { 637 | "id": "bwV_rsU2xuOY" 638 | }, 639 | "source": [ 640 | "## Resources" 641 | ] 642 | }, 643 | { 644 | "cell_type": "markdown", 645 | "metadata": { 646 | "id": "m_IZCUpixuOY" 647 | }, 648 | "source": [ 649 | "[[1] Reinforcement Learning: An Introduction. Ch. 
4: Dynamic Programming](https://web.stanford.edu/class/psych209/Readings/SuttonBartoIPRLBook2ndEd.pdf)" 650 | ] 651 | } 652 | ], 653 | "metadata": { 654 | "kernelspec": { 655 | "display_name": "Python 3", 656 | "language": "python", 657 | "name": "python3" 658 | }, 659 | "language_info": { 660 | "codemirror_mode": { 661 | "name": "ipython", 662 | "version": 3 663 | }, 664 | "file_extension": ".py", 665 | "mimetype": "text/x-python", 666 | "name": "python", 667 | "nbconvert_exporter": "python", 668 | "pygments_lexer": "ipython3", 669 | "version": "3.8.5" 670 | }, 671 | "colab": { 672 | "provenance": [] 673 | } 674 | }, 675 | "nbformat": 4, 676 | "nbformat_minor": 0 677 | } -------------------------------------------------------------------------------- /Section_4_on_policy_control.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "pycharm": { 7 | "name": "#%%\n" 8 | }, 9 | "id": "FegQfsZN8Pap" 10 | }, 11 | "source": [ 12 | "
\n", 13 | "

\n", 14 | " On-policy Monte Carlo Control\n", 15 | "

\n", 16 | "
\n", 17 | "
\n", 18 | "\n", 19 | "
\n", 20 | "

\n", 21 | " In this notebook we are going to implement one of the two major strategies that exist when learning a policy by interacting with the environment, called on-policy learning. The agent will perform the task from start to finish and based on the sample experience generated, update its estimates of the q-values of each state-action pair $Q(s,a)$.\n", 22 | "

\n", 23 | "
\n", 24 | "\n", 25 | "
" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "source": [ 31 | "# @title Setup code (not important) - Run this cell by pressing \"Shift + Enter\"\n", 32 | "\n", 33 | "\n", 34 | "\n", 35 | "!pip install -qq gym==0.23.0\n", 36 | "\n", 37 | "\n", 38 | "from typing import Tuple, Dict, Optional, Iterable, Callable\n", 39 | "\n", 40 | "import numpy as np\n", 41 | "import seaborn as sns\n", 42 | "import matplotlib\n", 43 | "from matplotlib import animation\n", 44 | "\n", 45 | "from IPython.display import HTML\n", 46 | "\n", 47 | "import gym\n", 48 | "from gym import spaces\n", 49 | "from gym.error import DependencyNotInstalled\n", 50 | "\n", 51 | "import pygame\n", 52 | "from pygame import gfxdraw\n", 53 | "\n", 54 | "\n", 55 | "class Maze(gym.Env):\n", 56 | "\n", 57 | " def __init__(self, exploring_starts: bool = False,\n", 58 | " shaped_rewards: bool = False, size: int = 5) -> None:\n", 59 | " super().__init__()\n", 60 | " self.exploring_starts = exploring_starts\n", 61 | " self.shaped_rewards = shaped_rewards\n", 62 | " self.state = (size - 1, size - 1)\n", 63 | " self.goal = (size - 1, size - 1)\n", 64 | " self.maze = self._create_maze(size=size)\n", 65 | " self.distances = self._compute_distances(self.goal, self.maze)\n", 66 | " self.action_space = spaces.Discrete(n=4)\n", 67 | " self.action_space.action_meanings = {0: 'UP', 1: 'RIGHT', 2: 'DOWN', 3: \"LEFT\"}\n", 68 | " self.observation_space = spaces.MultiDiscrete([size, size])\n", 69 | "\n", 70 | " self.screen = None\n", 71 | " self.agent_transform = None\n", 72 | "\n", 73 | " def step(self, action: int) -> Tuple[Tuple[int, int], float, bool, Dict]:\n", 74 | " reward = self.compute_reward(self.state, action)\n", 75 | " self.state = self._get_next_state(self.state, action)\n", 76 | " done = self.state == self.goal\n", 77 | " info = {}\n", 78 | " return self.state, reward, done, info\n", 79 | "\n", 80 | " def reset(self) -> Tuple[int, int]:\n", 81 | " if self.exploring_starts:\n", 82 | " while self.state == self.goal:\n", 83 | " self.state = tuple(self.observation_space.sample())\n", 84 | " else:\n", 85 | " self.state = (0, 0)\n", 86 | " return self.state\n", 87 | "\n", 88 | " def render(self, mode: str = 'human') -> Optional[np.ndarray]:\n", 89 | " assert mode in ['human', 'rgb_array']\n", 90 | "\n", 91 | " screen_size = 600\n", 92 | " scale = screen_size / 5\n", 93 | "\n", 94 | " if self.screen is None:\n", 95 | " pygame.init()\n", 96 | " self.screen = pygame.Surface((screen_size, screen_size))\n", 97 | "\n", 98 | " surf = pygame.Surface((screen_size, screen_size))\n", 99 | " surf.fill((22, 36, 71))\n", 100 | "\n", 101 | "\n", 102 | " for row in range(5):\n", 103 | " for col in range(5):\n", 104 | "\n", 105 | " state = (row, col)\n", 106 | " for next_state in [(row + 1, col), (row - 1, col), (row, col + 1), (row, col - 1)]:\n", 107 | " if next_state not in self.maze[state]:\n", 108 | "\n", 109 | " # Add the geometry of the edges and walls (i.e. 
the boundaries between\n", 110 | " # adjacent squares that are not connected).\n", 111 | " row_diff, col_diff = np.subtract(next_state, state)\n", 112 | " left = (col + (col_diff > 0)) * scale - 2 * (col_diff != 0)\n", 113 | " right = ((col + 1) - (col_diff < 0)) * scale + 2 * (col_diff != 0)\n", 114 | " top = (5 - (row + (row_diff > 0))) * scale - 2 * (row_diff != 0)\n", 115 | " bottom = (5 - ((row + 1) - (row_diff < 0))) * scale + 2 * (row_diff != 0)\n", 116 | "\n", 117 | " gfxdraw.filled_polygon(surf, [(left, bottom), (left, top), (right, top), (right, bottom)], (255, 255, 255))\n", 118 | "\n", 119 | " # Add the geometry of the goal square to the viewer.\n", 120 | " left, right, top, bottom = scale * 4 + 10, scale * 5 - 10, scale - 10, 10\n", 121 | " gfxdraw.filled_polygon(surf, [(left, bottom), (left, top), (right, top), (right, bottom)], (40, 199, 172))\n", 122 | "\n", 123 | " # Add the geometry of the agent to the viewer.\n", 124 | " agent_row = int(screen_size - scale * (self.state[0] + .5))\n", 125 | " agent_col = int(scale * (self.state[1] + .5))\n", 126 | " gfxdraw.filled_circle(surf, agent_col, agent_row, int(scale * .6 / 2), (228, 63, 90))\n", 127 | "\n", 128 | " surf = pygame.transform.flip(surf, False, True)\n", 129 | " self.screen.blit(surf, (0, 0))\n", 130 | "\n", 131 | " return np.transpose(\n", 132 | " np.array(pygame.surfarray.pixels3d(self.screen)), axes=(1, 0, 2)\n", 133 | " )\n", 134 | "\n", 135 | " def close(self) -> None:\n", 136 | " if self.screen is not None:\n", 137 | " pygame.display.quit()\n", 138 | " pygame.quit()\n", 139 | " self.screen = None\n", 140 | "\n", 141 | " def compute_reward(self, state: Tuple[int, int], action: int) -> float:\n", 142 | " next_state = self._get_next_state(state, action)\n", 143 | " if self.shaped_rewards:\n", 144 | " return - (self.distances[next_state] / self.distances.max())\n", 145 | " return - float(state != self.goal)\n", 146 | "\n", 147 | " def simulate_step(self, state: Tuple[int, int], action: int):\n", 148 | " reward = self.compute_reward(state, action)\n", 149 | " next_state = self._get_next_state(state, action)\n", 150 | " done = next_state == self.goal\n", 151 | " info = {}\n", 152 | " return next_state, reward, done, info\n", 153 | "\n", 154 | " def _get_next_state(self, state: Tuple[int, int], action: int) -> Tuple[int, int]:\n", 155 | " if action == 0:\n", 156 | " next_state = (state[0] - 1, state[1])\n", 157 | " elif action == 1:\n", 158 | " next_state = (state[0], state[1] + 1)\n", 159 | " elif action == 2:\n", 160 | " next_state = (state[0] + 1, state[1])\n", 161 | " elif action == 3:\n", 162 | " next_state = (state[0], state[1] - 1)\n", 163 | " else:\n", 164 | " raise ValueError(\"Action value not supported:\", action)\n", 165 | " if next_state in self.maze[state]:\n", 166 | " return next_state\n", 167 | " return state\n", 168 | "\n", 169 | " @staticmethod\n", 170 | " def _create_maze(size: int) -> Dict[Tuple[int, int], Iterable[Tuple[int, int]]]:\n", 171 | " maze = {(row, col): [(row - 1, col), (row + 1, col), (row, col - 1), (row, col + 1)]\n", 172 | " for row in range(size) for col in range(size)}\n", 173 | "\n", 174 | " left_edges = [[(row, 0), (row, -1)] for row in range(size)]\n", 175 | " right_edges = [[(row, size - 1), (row, size)] for row in range(size)]\n", 176 | " upper_edges = [[(0, col), (-1, col)] for col in range(size)]\n", 177 | " lower_edges = [[(size - 1, col), (size, col)] for col in range(size)]\n", 178 | " walls = [\n", 179 | " [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)],\n", 180 | " 
[(1, 1), (1, 2)], [(2, 1), (2, 2)], [(3, 1), (3, 2)],\n", 181 | " [(3, 1), (4, 1)], [(0, 2), (1, 2)], [(1, 2), (1, 3)],\n", 182 | " [(2, 2), (3, 2)], [(2, 3), (3, 3)], [(2, 4), (3, 4)],\n", 183 | " [(4, 2), (4, 3)], [(1, 3), (1, 4)], [(2, 3), (2, 4)],\n", 184 | " ]\n", 185 | "\n", 186 | " obstacles = upper_edges + lower_edges + left_edges + right_edges + walls\n", 187 | "\n", 188 | " for src, dst in obstacles:\n", 189 | " maze[src].remove(dst)\n", 190 | "\n", 191 | " if dst in maze:\n", 192 | " maze[dst].remove(src)\n", 193 | "\n", 194 | " return maze\n", 195 | "\n", 196 | " @staticmethod\n", 197 | " def _compute_distances(goal: Tuple[int, int],\n", 198 | " maze: Dict[Tuple[int, int], Iterable[Tuple[int, int]]]) -> np.ndarray:\n", 199 | " distances = np.full((5, 5), np.inf)\n", 200 | " visited = set()\n", 201 | " distances[goal] = 0.\n", 202 | "\n", 203 | " while visited != set(maze):\n", 204 | " sorted_dst = [(v // 5, v % 5) for v in distances.argsort(axis=None)]\n", 205 | " closest = next(x for x in sorted_dst if x not in visited)\n", 206 | " visited.add(closest)\n", 207 | "\n", 208 | " for neighbour in maze[closest]:\n", 209 | " distances[neighbour] = min(distances[neighbour], distances[closest] + 1)\n", 210 | " return distances\n", 211 | "\n", 212 | "\n", 213 | "def plot_policy(probs_or_qvals, frame, action_meanings=None):\n", 214 | " if action_meanings is None:\n", 215 | " action_meanings = {0: 'U', 1: 'R', 2: 'D', 3: 'L'}\n", 216 | " fig, axes = plt.subplots(1, 2, figsize=(8, 4))\n", 217 | " max_prob_actions = probs_or_qvals.argmax(axis=-1)\n", 218 | " probs_copy = max_prob_actions.copy().astype(object)\n", 219 | " for key in action_meanings:\n", 220 | " probs_copy[probs_copy == key] = action_meanings[key]\n", 221 | " sns.heatmap(max_prob_actions, annot=probs_copy, fmt='', cbar=False, cmap='coolwarm',\n", 222 | " annot_kws={'weight': 'bold', 'size': 12}, linewidths=2, ax=axes[0])\n", 223 | " axes[1].imshow(frame)\n", 224 | " axes[0].axis('off')\n", 225 | " axes[1].axis('off')\n", 226 | " plt.suptitle(\"Policy\", size=18)\n", 227 | " plt.tight_layout()\n", 228 | "\n", 229 | "\n", 230 | "def plot_values(state_values, frame):\n", 231 | " f, axes = plt.subplots(1, 2, figsize=(10, 4))\n", 232 | " sns.heatmap(state_values, annot=True, fmt=\".2f\", cmap='coolwarm',\n", 233 | " annot_kws={'weight': 'bold', 'size': 12}, linewidths=2, ax=axes[0])\n", 234 | " axes[1].imshow(frame)\n", 235 | " axes[0].axis('off')\n", 236 | " axes[1].axis('off')\n", 237 | " plt.tight_layout()\n", 238 | "\n", 239 | "\n", 240 | "def display_video(frames):\n", 241 | " # Copied from: https://colab.research.google.com/github/deepmind/dm_control/blob/master/tutorial.ipynb\n", 242 | " orig_backend = matplotlib.get_backend()\n", 243 | " matplotlib.use('Agg')\n", 244 | " fig, ax = plt.subplots(1, 1, figsize=(5, 5))\n", 245 | " matplotlib.use(orig_backend)\n", 246 | " ax.set_axis_off()\n", 247 | " ax.set_aspect('equal')\n", 248 | " ax.set_position([0, 0, 1, 1])\n", 249 | " im = ax.imshow(frames[0])\n", 250 | " def update(frame):\n", 251 | " im.set_data(frame)\n", 252 | " return [im]\n", 253 | " anim = animation.FuncAnimation(fig=fig, func=update, frames=frames,\n", 254 | " interval=50, blit=True, repeat=False)\n", 255 | " return HTML(anim.to_html5_video())\n", 256 | "\n", 257 | "\n", 258 | "def test_agent(environment, policy, episodes=10):\n", 259 | " frames = []\n", 260 | " for episode in range(episodes):\n", 261 | " state = env.reset()\n", 262 | " done = False\n", 263 | " frames.append(env.render(mode=\"rgb_array\"))\n", 
264 | "\n", 265 | " while not done:\n", 266 | " p = policy(state)\n", 267 | " if isinstance(p, np.ndarray):\n", 268 | " action = np.random.choice(4, p=p)\n", 269 | " else:\n", 270 | " action = p\n", 271 | " next_state, reward, done, extra_info = env.step(action)\n", 272 | " img = env.render(mode=\"rgb_array\")\n", 273 | " frames.append(img)\n", 274 | " state = next_state\n", 275 | "\n", 276 | " return display_video(frames)\n", 277 | "\n", 278 | "\n", 279 | "def plot_action_values(action_values):\n", 280 | "\n", 281 | " text_positions = [\n", 282 | " [(0.35, 4.75), (1.35, 4.75), (2.35, 4.75), (3.35, 4.75), (4.35, 4.75),\n", 283 | " (0.35, 3.75), (1.35, 3.75), (2.35, 3.75), (3.35, 3.75), (4.35, 3.75),\n", 284 | " (0.35, 2.75), (1.35, 2.75), (2.35, 2.75), (3.35, 2.75), (4.35, 2.75),\n", 285 | " (0.35, 1.75), (1.35, 1.75), (2.35, 1.75), (3.35, 1.75), (4.35, 1.75),\n", 286 | " (0.35, 0.75), (1.35, 0.75), (2.35, 0.75), (3.35, 0.75), (4.35, 0.75)],\n", 287 | " [(0.6, 4.45), (1.6, 4.45), (2.6, 4.45), (3.6, 4.45), (4.6, 4.45),\n", 288 | " (0.6, 3.45), (1.6, 3.45), (2.6, 3.45), (3.6, 3.45), (4.6, 3.45),\n", 289 | " (0.6, 2.45), (1.6, 2.45), (2.6, 2.45), (3.6, 2.45), (4.6, 2.45),\n", 290 | " (0.6, 1.45), (1.6, 1.45), (2.6, 1.45), (3.6, 1.45), (4.6, 1.45),\n", 291 | " (0.6, 0.45), (1.6, 0.45), (2.6, 0.45), (3.6, 0.45), (4.6, 0.45)],\n", 292 | " [(0.35, 4.15), (1.35, 4.15), (2.35, 4.15), (3.35, 4.15), (4.35, 4.15),\n", 293 | " (0.35, 3.15), (1.35, 3.15), (2.35, 3.15), (3.35, 3.15), (4.35, 3.15),\n", 294 | " (0.35, 2.15), (1.35, 2.15), (2.35, 2.15), (3.35, 2.15), (4.35, 2.15),\n", 295 | " (0.35, 1.15), (1.35, 1.15), (2.35, 1.15), (3.35, 1.15), (4.35, 1.15),\n", 296 | " (0.35, 0.15), (1.35, 0.15), (2.35, 0.15), (3.35, 0.15), (4.35, 0.15)],\n", 297 | " [(0.05, 4.45), (1.05, 4.45), (2.05, 4.45), (3.05, 4.45), (4.05, 4.45),\n", 298 | " (0.05, 3.45), (1.05, 3.45), (2.05, 3.45), (3.05, 3.45), (4.05, 3.45),\n", 299 | " (0.05, 2.45), (1.05, 2.45), (2.05, 2.45), (3.05, 2.45), (4.05, 2.45),\n", 300 | " (0.05, 1.45), (1.05, 1.45), (2.05, 1.45), (3.05, 1.45), (4.05, 1.45),\n", 301 | " (0.05, 0.45), (1.05, 0.45), (2.05, 0.45), (3.05, 0.45), (4.05, 0.45)]]\n", 302 | "\n", 303 | " fig, ax = plt.subplots(figsize=(7, 7))\n", 304 | " tripcolor = quatromatrix(action_values, ax=ax,\n", 305 | " triplotkw={\"color\": \"k\", \"lw\": 1}, tripcolorkw={\"cmap\": \"coolwarm\"})\n", 306 | " ax.margins(0)\n", 307 | " ax.set_aspect(\"equal\")\n", 308 | " fig.colorbar(tripcolor)\n", 309 | "\n", 310 | " for j, av in enumerate(text_positions):\n", 311 | " for i, (xi, yi) in enumerate(av):\n", 312 | " plt.text(xi, yi, round(action_values[:, :, j].flatten()[i], 2), size=8, color=\"w\", weight=\"bold\")\n", 313 | "\n", 314 | " plt.title(\"Action values Q(s,a)\", size=18)\n", 315 | " plt.tight_layout()\n", 316 | " plt.show()\n", 317 | "\n", 318 | "\n", 319 | "def quatromatrix(action_values, ax=None, triplotkw=None, tripcolorkw=None):\n", 320 | " action_values = np.flipud(action_values)\n", 321 | " n = 5\n", 322 | " m = 5\n", 323 | " a = np.array([[0, 0], [0, 1], [.5, .5], [1, 0], [1, 1]])\n", 324 | " tr = np.array([[0, 1, 2], [0, 2, 3], [2, 3, 4], [1, 2, 4]])\n", 325 | " A = np.zeros((n * m * 5, 2))\n", 326 | " Tr = np.zeros((n * m * 4, 3))\n", 327 | " for i in range(n):\n", 328 | " for j in range(m):\n", 329 | " k = i * m + j\n", 330 | " A[k * 5:(k + 1) * 5, :] = np.c_[a[:, 0] + j, a[:, 1] + i]\n", 331 | " Tr[k * 4:(k + 1) * 4, :] = tr + k * 5\n", 332 | " C = np.c_[action_values[:, :, 3].flatten(), action_values[:, :, 2].flatten(),\n", 
333 | " action_values[:, :, 1].flatten(), action_values[:, :, 0].flatten()].flatten()\n", 334 | "\n", 335 | " ax.triplot(A[:, 0], A[:, 1], Tr, **triplotkw)\n", 336 | " tripcolor = ax.tripcolor(A[:, 0], A[:, 1], Tr, facecolors=C, **tripcolorkw)\n", 337 | " return tripcolor\n", 338 | "\n", 339 | "\n" 340 | ], 341 | "metadata": { 342 | "cellView": "form", 343 | "id": "WQ3CjcEb9Fkx" 344 | }, 345 | "execution_count": null, 346 | "outputs": [] 347 | }, 348 | { 349 | "cell_type": "markdown", 350 | "metadata": { 351 | "id": "i1QHHcC38Pas" 352 | }, 353 | "source": [ 354 | "## Import the necessary software libraries:" 355 | ] 356 | }, 357 | { 358 | "cell_type": "code", 359 | "execution_count": null, 360 | "metadata": { 361 | "id": "ddSL9RtK8Pas" 362 | }, 363 | "outputs": [], 364 | "source": [ 365 | "import numpy as np\n", 366 | "import matplotlib.pyplot as plt" 367 | ] 368 | }, 369 | { 370 | "cell_type": "markdown", 371 | "metadata": { 372 | "id": "aujcXnZg8Pat" 373 | }, 374 | "source": [ 375 | "## Initialize the environment" 376 | ] 377 | }, 378 | { 379 | "cell_type": "code", 380 | "execution_count": null, 381 | "metadata": { 382 | "id": "aPtijbjT8Pat" 383 | }, 384 | "outputs": [], 385 | "source": [ 386 | "env = Maze()" 387 | ] 388 | }, 389 | { 390 | "cell_type": "code", 391 | "execution_count": null, 392 | "metadata": { 393 | "id": "bEYRDn1Q8Pau" 394 | }, 395 | "outputs": [], 396 | "source": [ 397 | "frame = env.render(mode='rgb_array')\n", 398 | "plt.axis('off')\n", 399 | "plt.imshow(frame)" 400 | ] 401 | }, 402 | { 403 | "cell_type": "code", 404 | "execution_count": null, 405 | "metadata": { 406 | "id": "tWRgQvYn8Pau" 407 | }, 408 | "outputs": [], 409 | "source": [ 410 | "print(f\"Observation space shape: {env.observation_space.nvec}\")\n", 411 | "print(f\"Number of actions: {env.action_space.n}\")" 412 | ] 413 | }, 414 | { 415 | "cell_type": "markdown", 416 | "metadata": { 417 | "id": "nYSxH6vb8Pau" 418 | }, 419 | "source": [ 420 | "## Define value table $Q(s, a)$" 421 | ] 422 | }, 423 | { 424 | "cell_type": "markdown", 425 | "metadata": { 426 | "id": "thgvYa_S8Pav" 427 | }, 428 | "source": [ 429 | "#### Create the $Q(s, a)$ table" 430 | ] 431 | }, 432 | { 433 | "cell_type": "code", 434 | "execution_count": null, 435 | "metadata": { 436 | "id": "fehyO_5p8Pav" 437 | }, 438 | "outputs": [], 439 | "source": [] 440 | }, 441 | { 442 | "cell_type": "markdown", 443 | "metadata": { 444 | "id": "aVNIyhiH8Pav" 445 | }, 446 | "source": [ 447 | "#### Plot Q(s, a)" 448 | ] 449 | }, 450 | { 451 | "cell_type": "code", 452 | "execution_count": null, 453 | "metadata": { 454 | "id": "ZSvw1wgF8Pav" 455 | }, 456 | "outputs": [], 457 | "source": [] 458 | }, 459 | { 460 | "cell_type": "markdown", 461 | "metadata": { 462 | "id": "ysJQJGXm8Pav" 463 | }, 464 | "source": [ 465 | "## Define the policy $\\pi(s)$" 466 | ] 467 | }, 468 | { 469 | "cell_type": "markdown", 470 | "metadata": { 471 | "id": "0lRYXyx68Pav" 472 | }, 473 | "source": [ 474 | "#### Create the policy $\\pi(s)$" 475 | ] 476 | }, 477 | { 478 | "cell_type": "code", 479 | "execution_count": null, 480 | "metadata": { 481 | "id": "Zc2nSOTb8Pav" 482 | }, 483 | "outputs": [], 484 | "source": [] 485 | }, 486 | { 487 | "cell_type": "markdown", 488 | "metadata": { 489 | "id": "V6y0E6nx8Pav" 490 | }, 491 | "source": [ 492 | "#### Test the policy with state (0, 0)" 493 | ] 494 | }, 495 | { 496 | "cell_type": "code", 497 | "execution_count": null, 498 | "metadata": { 499 | "id": "ck167K4M8Pav" 500 | }, 501 | "outputs": [], 502 | "source": [] 503 | }, 504 | { 505 | 
"cell_type": "markdown", 506 | "metadata": { 507 | "id": "g2xH7qSZ8Pav" 508 | }, 509 | "source": [ 510 | "#### Plot the policy" 511 | ] 512 | }, 513 | { 514 | "cell_type": "code", 515 | "execution_count": null, 516 | "metadata": { 517 | "scrolled": false, 518 | "id": "KzYsJsnO8Pav" 519 | }, 520 | "outputs": [], 521 | "source": [] 522 | }, 523 | { 524 | "cell_type": "markdown", 525 | "metadata": { 526 | "id": "PqAqHkpk8Paw" 527 | }, 528 | "source": [ 529 | "## Implement the algorithm\n" 530 | ] 531 | }, 532 | { 533 | "cell_type": "code", 534 | "execution_count": null, 535 | "metadata": { 536 | "id": "Os79GeR88Paw" 537 | }, 538 | "outputs": [], 539 | "source": [] 540 | }, 541 | { 542 | "cell_type": "code", 543 | "execution_count": null, 544 | "metadata": { 545 | "id": "GMHevlFT8Paw" 546 | }, 547 | "outputs": [], 548 | "source": [] 549 | }, 550 | { 551 | "cell_type": "markdown", 552 | "metadata": { 553 | "id": "VMNGJYQ38Paw" 554 | }, 555 | "source": [ 556 | "## Show results" 557 | ] 558 | }, 559 | { 560 | "cell_type": "markdown", 561 | "metadata": { 562 | "id": "rvgNcDkX8Paw" 563 | }, 564 | "source": [ 565 | "#### Show resulting value table $Q(s, a)$" 566 | ] 567 | }, 568 | { 569 | "cell_type": "code", 570 | "execution_count": null, 571 | "metadata": { 572 | "id": "yIyUuFZE8Paw" 573 | }, 574 | "outputs": [], 575 | "source": [] 576 | }, 577 | { 578 | "cell_type": "markdown", 579 | "metadata": { 580 | "id": "61A_1lhs8Paw" 581 | }, 582 | "source": [ 583 | "#### Show resulting policy $\\pi(\\cdot|s)$" 584 | ] 585 | }, 586 | { 587 | "cell_type": "code", 588 | "execution_count": null, 589 | "metadata": { 590 | "id": "lmmo4uMN8Paw" 591 | }, 592 | "outputs": [], 593 | "source": [] 594 | }, 595 | { 596 | "cell_type": "markdown", 597 | "metadata": { 598 | "id": "qaQWXzb28Paw" 599 | }, 600 | "source": [ 601 | "#### Test the resulting agent" 602 | ] 603 | }, 604 | { 605 | "cell_type": "code", 606 | "execution_count": null, 607 | "metadata": { 608 | "id": "aMm5uC9x8Paw" 609 | }, 610 | "outputs": [], 611 | "source": [ 612 | "test_agent(env, policy)" 613 | ] 614 | }, 615 | { 616 | "cell_type": "markdown", 617 | "metadata": { 618 | "id": "MNTpDkGO8Paw" 619 | }, 620 | "source": [ 621 | "## Resources" 622 | ] 623 | }, 624 | { 625 | "cell_type": "markdown", 626 | "metadata": { 627 | "id": "NDaGZUNh8Paw" 628 | }, 629 | "source": [ 630 | "[[1] Reinforcement Learning: An Introduction. Ch. 4: Dynamic Programming](https://web.stanford.edu/class/psych209/Readings/SuttonBartoIPRLBook2ndEd.pdf)" 631 | ] 632 | } 633 | ], 634 | "metadata": { 635 | "kernelspec": { 636 | "display_name": "Python 3", 637 | "language": "python", 638 | "name": "python3" 639 | }, 640 | "language_info": { 641 | "codemirror_mode": { 642 | "name": "ipython", 643 | "version": 3 644 | }, 645 | "file_extension": ".py", 646 | "mimetype": "text/x-python", 647 | "name": "python", 648 | "nbconvert_exporter": "python", 649 | "pygments_lexer": "ipython3", 650 | "version": "3.8.5" 651 | }, 652 | "colab": { 653 | "provenance": [] 654 | } 655 | }, 656 | "nbformat": 4, 657 | "nbformat_minor": 0 658 | } -------------------------------------------------------------------------------- /Section_5_sarsa.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "pycharm": { 7 | "name": "#%%\n" 8 | }, 9 | "id": "CqXtasNJqBFG" 10 | }, 11 | "source": [ 12 | "
\n", 13 | "

\n", 14 | " SARSA\n", 15 | "

\n", 16 | "
\n", 17 | "\n", 18 | "

\n", 19 | "\n", 20 | "
\n", 21 | " In this notebook we are going to implement a method that learns from experience and uses bootstrapping.\n", 22 | " It is known as SARSA because of the elements involved in the update rule:\n", 23 | "
\n", 24 | "\n", 25 | "\\begin{equation}\n", 26 | "\\text{State}_t, \\text{Action}_t, \\text{Reward}_t, \\text{State}_{t+1}, \\text{Action}_{t+1}\n", 27 | "\\end{equation}\n", 28 | "\n", 29 | "
\n", 30 | "\n", 31 | "
\n", 32 | " This method follows an on-policy strategy, in which the same policy that is optimized is responsible for scanning the environment.\n", 33 | "
\n", 34 | "\n", 35 | "\n", 36 | "
" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "source": [ 42 | "# @title Setup code (not important) - Run this cell by pressing \"Shift + Enter\"\n", 43 | "\n", 44 | "\n", 45 | "\n", 46 | "!pip install -qq gym==0.23.0\n", 47 | "\n", 48 | "\n", 49 | "from typing import Tuple, Dict, Optional, Iterable, Callable\n", 50 | "\n", 51 | "import numpy as np\n", 52 | "import seaborn as sns\n", 53 | "import matplotlib\n", 54 | "from matplotlib import animation\n", 55 | "\n", 56 | "from IPython.display import HTML\n", 57 | "\n", 58 | "import gym\n", 59 | "from gym import spaces\n", 60 | "from gym.error import DependencyNotInstalled\n", 61 | "\n", 62 | "import pygame\n", 63 | "from pygame import gfxdraw\n", 64 | "\n", 65 | "\n", 66 | "class Maze(gym.Env):\n", 67 | "\n", 68 | " def __init__(self, exploring_starts: bool = False,\n", 69 | " shaped_rewards: bool = False, size: int = 5) -> None:\n", 70 | " super().__init__()\n", 71 | " self.exploring_starts = exploring_starts\n", 72 | " self.shaped_rewards = shaped_rewards\n", 73 | " self.state = (size - 1, size - 1)\n", 74 | " self.goal = (size - 1, size - 1)\n", 75 | " self.maze = self._create_maze(size=size)\n", 76 | " self.distances = self._compute_distances(self.goal, self.maze)\n", 77 | " self.action_space = spaces.Discrete(n=4)\n", 78 | " self.action_space.action_meanings = {0: 'UP', 1: 'RIGHT', 2: 'DOWN', 3: \"LEFT\"}\n", 79 | " self.observation_space = spaces.MultiDiscrete([size, size])\n", 80 | "\n", 81 | " self.screen = None\n", 82 | " self.agent_transform = None\n", 83 | "\n", 84 | " def step(self, action: int) -> Tuple[Tuple[int, int], float, bool, Dict]:\n", 85 | " reward = self.compute_reward(self.state, action)\n", 86 | " self.state = self._get_next_state(self.state, action)\n", 87 | " done = self.state == self.goal\n", 88 | " info = {}\n", 89 | " return self.state, reward, done, info\n", 90 | "\n", 91 | " def reset(self) -> Tuple[int, int]:\n", 92 | " if self.exploring_starts:\n", 93 | " while self.state == self.goal:\n", 94 | " self.state = tuple(self.observation_space.sample())\n", 95 | " else:\n", 96 | " self.state = (0, 0)\n", 97 | " return self.state\n", 98 | "\n", 99 | " def render(self, mode: str = 'human') -> Optional[np.ndarray]:\n", 100 | " assert mode in ['human', 'rgb_array']\n", 101 | "\n", 102 | " screen_size = 600\n", 103 | " scale = screen_size / 5\n", 104 | "\n", 105 | " if self.screen is None:\n", 106 | " pygame.init()\n", 107 | " self.screen = pygame.Surface((screen_size, screen_size))\n", 108 | "\n", 109 | " surf = pygame.Surface((screen_size, screen_size))\n", 110 | " surf.fill((22, 36, 71))\n", 111 | "\n", 112 | "\n", 113 | " for row in range(5):\n", 114 | " for col in range(5):\n", 115 | "\n", 116 | " state = (row, col)\n", 117 | " for next_state in [(row + 1, col), (row - 1, col), (row, col + 1), (row, col - 1)]:\n", 118 | " if next_state not in self.maze[state]:\n", 119 | "\n", 120 | " # Add the geometry of the edges and walls (i.e. 
the boundaries between\n", 121 | " # adjacent squares that are not connected).\n", 122 | " row_diff, col_diff = np.subtract(next_state, state)\n", 123 | " left = (col + (col_diff > 0)) * scale - 2 * (col_diff != 0)\n", 124 | " right = ((col + 1) - (col_diff < 0)) * scale + 2 * (col_diff != 0)\n", 125 | " top = (5 - (row + (row_diff > 0))) * scale - 2 * (row_diff != 0)\n", 126 | " bottom = (5 - ((row + 1) - (row_diff < 0))) * scale + 2 * (row_diff != 0)\n", 127 | "\n", 128 | " gfxdraw.filled_polygon(surf, [(left, bottom), (left, top), (right, top), (right, bottom)], (255, 255, 255))\n", 129 | "\n", 130 | " # Add the geometry of the goal square to the viewer.\n", 131 | " left, right, top, bottom = scale * 4 + 10, scale * 5 - 10, scale - 10, 10\n", 132 | " gfxdraw.filled_polygon(surf, [(left, bottom), (left, top), (right, top), (right, bottom)], (40, 199, 172))\n", 133 | "\n", 134 | " # Add the geometry of the agent to the viewer.\n", 135 | " agent_row = int(screen_size - scale * (self.state[0] + .5))\n", 136 | " agent_col = int(scale * (self.state[1] + .5))\n", 137 | " gfxdraw.filled_circle(surf, agent_col, agent_row, int(scale * .6 / 2), (228, 63, 90))\n", 138 | "\n", 139 | " surf = pygame.transform.flip(surf, False, True)\n", 140 | " self.screen.blit(surf, (0, 0))\n", 141 | "\n", 142 | " return np.transpose(\n", 143 | " np.array(pygame.surfarray.pixels3d(self.screen)), axes=(1, 0, 2)\n", 144 | " )\n", 145 | "\n", 146 | " def close(self) -> None:\n", 147 | " if self.screen is not None:\n", 148 | " pygame.display.quit()\n", 149 | " pygame.quit()\n", 150 | " self.screen = None\n", 151 | "\n", 152 | " def compute_reward(self, state: Tuple[int, int], action: int) -> float:\n", 153 | " next_state = self._get_next_state(state, action)\n", 154 | " if self.shaped_rewards:\n", 155 | " return - (self.distances[next_state] / self.distances.max())\n", 156 | " return - float(state != self.goal)\n", 157 | "\n", 158 | " def simulate_step(self, state: Tuple[int, int], action: int):\n", 159 | " reward = self.compute_reward(state, action)\n", 160 | " next_state = self._get_next_state(state, action)\n", 161 | " done = next_state == self.goal\n", 162 | " info = {}\n", 163 | " return next_state, reward, done, info\n", 164 | "\n", 165 | " def _get_next_state(self, state: Tuple[int, int], action: int) -> Tuple[int, int]:\n", 166 | " if action == 0:\n", 167 | " next_state = (state[0] - 1, state[1])\n", 168 | " elif action == 1:\n", 169 | " next_state = (state[0], state[1] + 1)\n", 170 | " elif action == 2:\n", 171 | " next_state = (state[0] + 1, state[1])\n", 172 | " elif action == 3:\n", 173 | " next_state = (state[0], state[1] - 1)\n", 174 | " else:\n", 175 | " raise ValueError(\"Action value not supported:\", action)\n", 176 | " if next_state in self.maze[state]:\n", 177 | " return next_state\n", 178 | " return state\n", 179 | "\n", 180 | " @staticmethod\n", 181 | " def _create_maze(size: int) -> Dict[Tuple[int, int], Iterable[Tuple[int, int]]]:\n", 182 | " maze = {(row, col): [(row - 1, col), (row + 1, col), (row, col - 1), (row, col + 1)]\n", 183 | " for row in range(size) for col in range(size)}\n", 184 | "\n", 185 | " left_edges = [[(row, 0), (row, -1)] for row in range(size)]\n", 186 | " right_edges = [[(row, size - 1), (row, size)] for row in range(size)]\n", 187 | " upper_edges = [[(0, col), (-1, col)] for col in range(size)]\n", 188 | " lower_edges = [[(size - 1, col), (size, col)] for col in range(size)]\n", 189 | " walls = [\n", 190 | " [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)],\n", 191 | " 
[(1, 1), (1, 2)], [(2, 1), (2, 2)], [(3, 1), (3, 2)],\n", 192 | " [(3, 1), (4, 1)], [(0, 2), (1, 2)], [(1, 2), (1, 3)],\n", 193 | " [(2, 2), (3, 2)], [(2, 3), (3, 3)], [(2, 4), (3, 4)],\n", 194 | " [(4, 2), (4, 3)], [(1, 3), (1, 4)], [(2, 3), (2, 4)],\n", 195 | " ]\n", 196 | "\n", 197 | " obstacles = upper_edges + lower_edges + left_edges + right_edges + walls\n", 198 | "\n", 199 | " for src, dst in obstacles:\n", 200 | " maze[src].remove(dst)\n", 201 | "\n", 202 | " if dst in maze:\n", 203 | " maze[dst].remove(src)\n", 204 | "\n", 205 | " return maze\n", 206 | "\n", 207 | " @staticmethod\n", 208 | " def _compute_distances(goal: Tuple[int, int],\n", 209 | " maze: Dict[Tuple[int, int], Iterable[Tuple[int, int]]]) -> np.ndarray:\n", 210 | " distances = np.full((5, 5), np.inf)\n", 211 | " visited = set()\n", 212 | " distances[goal] = 0.\n", 213 | "\n", 214 | " while visited != set(maze):\n", 215 | " sorted_dst = [(v // 5, v % 5) for v in distances.argsort(axis=None)]\n", 216 | " closest = next(x for x in sorted_dst if x not in visited)\n", 217 | " visited.add(closest)\n", 218 | "\n", 219 | " for neighbour in maze[closest]:\n", 220 | " distances[neighbour] = min(distances[neighbour], distances[closest] + 1)\n", 221 | " return distances\n", 222 | "\n", 223 | "\n", 224 | "def plot_policy(probs_or_qvals, frame, action_meanings=None):\n", 225 | " if action_meanings is None:\n", 226 | " action_meanings = {0: 'U', 1: 'R', 2: 'D', 3: 'L'}\n", 227 | " fig, axes = plt.subplots(1, 2, figsize=(8, 4))\n", 228 | " max_prob_actions = probs_or_qvals.argmax(axis=-1)\n", 229 | " probs_copy = max_prob_actions.copy().astype(object)\n", 230 | " for key in action_meanings:\n", 231 | " probs_copy[probs_copy == key] = action_meanings[key]\n", 232 | " sns.heatmap(max_prob_actions, annot=probs_copy, fmt='', cbar=False, cmap='coolwarm',\n", 233 | " annot_kws={'weight': 'bold', 'size': 12}, linewidths=2, ax=axes[0])\n", 234 | " axes[1].imshow(frame)\n", 235 | " axes[0].axis('off')\n", 236 | " axes[1].axis('off')\n", 237 | " plt.suptitle(\"Policy\", size=18)\n", 238 | " plt.tight_layout()\n", 239 | "\n", 240 | "\n", 241 | "def plot_values(state_values, frame):\n", 242 | " f, axes = plt.subplots(1, 2, figsize=(10, 4))\n", 243 | " sns.heatmap(state_values, annot=True, fmt=\".2f\", cmap='coolwarm',\n", 244 | " annot_kws={'weight': 'bold', 'size': 12}, linewidths=2, ax=axes[0])\n", 245 | " axes[1].imshow(frame)\n", 246 | " axes[0].axis('off')\n", 247 | " axes[1].axis('off')\n", 248 | " plt.tight_layout()\n", 249 | "\n", 250 | "\n", 251 | "def display_video(frames):\n", 252 | " # Copied from: https://colab.research.google.com/github/deepmind/dm_control/blob/master/tutorial.ipynb\n", 253 | " orig_backend = matplotlib.get_backend()\n", 254 | " matplotlib.use('Agg')\n", 255 | " fig, ax = plt.subplots(1, 1, figsize=(5, 5))\n", 256 | " matplotlib.use(orig_backend)\n", 257 | " ax.set_axis_off()\n", 258 | " ax.set_aspect('equal')\n", 259 | " ax.set_position([0, 0, 1, 1])\n", 260 | " im = ax.imshow(frames[0])\n", 261 | " def update(frame):\n", 262 | " im.set_data(frame)\n", 263 | " return [im]\n", 264 | " anim = animation.FuncAnimation(fig=fig, func=update, frames=frames,\n", 265 | " interval=50, blit=True, repeat=False)\n", 266 | " return HTML(anim.to_html5_video())\n", 267 | "\n", 268 | "\n", 269 | "def test_agent(environment, policy, episodes=10):\n", 270 | " frames = []\n", 271 | " for episode in range(episodes):\n", 272 | " state = env.reset()\n", 273 | " done = False\n", 274 | " frames.append(env.render(mode=\"rgb_array\"))\n", 
275 | "\n", 276 | " while not done:\n", 277 | " p = policy(state)\n", 278 | " if isinstance(p, np.ndarray):\n", 279 | " action = np.random.choice(4, p=p)\n", 280 | " else:\n", 281 | " action = p\n", 282 | " next_state, reward, done, extra_info = env.step(action)\n", 283 | " img = env.render(mode=\"rgb_array\")\n", 284 | " frames.append(img)\n", 285 | " state = next_state\n", 286 | "\n", 287 | " return display_video(frames)\n", 288 | "\n", 289 | "\n", 290 | "def plot_action_values(action_values):\n", 291 | "\n", 292 | " text_positions = [\n", 293 | " [(0.35, 4.75), (1.35, 4.75), (2.35, 4.75), (3.35, 4.75), (4.35, 4.75),\n", 294 | " (0.35, 3.75), (1.35, 3.75), (2.35, 3.75), (3.35, 3.75), (4.35, 3.75),\n", 295 | " (0.35, 2.75), (1.35, 2.75), (2.35, 2.75), (3.35, 2.75), (4.35, 2.75),\n", 296 | " (0.35, 1.75), (1.35, 1.75), (2.35, 1.75), (3.35, 1.75), (4.35, 1.75),\n", 297 | " (0.35, 0.75), (1.35, 0.75), (2.35, 0.75), (3.35, 0.75), (4.35, 0.75)],\n", 298 | " [(0.6, 4.45), (1.6, 4.45), (2.6, 4.45), (3.6, 4.45), (4.6, 4.45),\n", 299 | " (0.6, 3.45), (1.6, 3.45), (2.6, 3.45), (3.6, 3.45), (4.6, 3.45),\n", 300 | " (0.6, 2.45), (1.6, 2.45), (2.6, 2.45), (3.6, 2.45), (4.6, 2.45),\n", 301 | " (0.6, 1.45), (1.6, 1.45), (2.6, 1.45), (3.6, 1.45), (4.6, 1.45),\n", 302 | " (0.6, 0.45), (1.6, 0.45), (2.6, 0.45), (3.6, 0.45), (4.6, 0.45)],\n", 303 | " [(0.35, 4.15), (1.35, 4.15), (2.35, 4.15), (3.35, 4.15), (4.35, 4.15),\n", 304 | " (0.35, 3.15), (1.35, 3.15), (2.35, 3.15), (3.35, 3.15), (4.35, 3.15),\n", 305 | " (0.35, 2.15), (1.35, 2.15), (2.35, 2.15), (3.35, 2.15), (4.35, 2.15),\n", 306 | " (0.35, 1.15), (1.35, 1.15), (2.35, 1.15), (3.35, 1.15), (4.35, 1.15),\n", 307 | " (0.35, 0.15), (1.35, 0.15), (2.35, 0.15), (3.35, 0.15), (4.35, 0.15)],\n", 308 | " [(0.05, 4.45), (1.05, 4.45), (2.05, 4.45), (3.05, 4.45), (4.05, 4.45),\n", 309 | " (0.05, 3.45), (1.05, 3.45), (2.05, 3.45), (3.05, 3.45), (4.05, 3.45),\n", 310 | " (0.05, 2.45), (1.05, 2.45), (2.05, 2.45), (3.05, 2.45), (4.05, 2.45),\n", 311 | " (0.05, 1.45), (1.05, 1.45), (2.05, 1.45), (3.05, 1.45), (4.05, 1.45),\n", 312 | " (0.05, 0.45), (1.05, 0.45), (2.05, 0.45), (3.05, 0.45), (4.05, 0.45)]]\n", 313 | "\n", 314 | " fig, ax = plt.subplots(figsize=(7, 7))\n", 315 | " tripcolor = quatromatrix(action_values, ax=ax,\n", 316 | " triplotkw={\"color\": \"k\", \"lw\": 1}, tripcolorkw={\"cmap\": \"coolwarm\"})\n", 317 | " ax.margins(0)\n", 318 | " ax.set_aspect(\"equal\")\n", 319 | " fig.colorbar(tripcolor)\n", 320 | "\n", 321 | " for j, av in enumerate(text_positions):\n", 322 | " for i, (xi, yi) in enumerate(av):\n", 323 | " plt.text(xi, yi, round(action_values[:, :, j].flatten()[i], 2), size=8, color=\"w\", weight=\"bold\")\n", 324 | "\n", 325 | " plt.title(\"Action values Q(s,a)\", size=18)\n", 326 | " plt.tight_layout()\n", 327 | " plt.show()\n", 328 | "\n", 329 | "\n", 330 | "def quatromatrix(action_values, ax=None, triplotkw=None, tripcolorkw=None):\n", 331 | " action_values = np.flipud(action_values)\n", 332 | " n = 5\n", 333 | " m = 5\n", 334 | " a = np.array([[0, 0], [0, 1], [.5, .5], [1, 0], [1, 1]])\n", 335 | " tr = np.array([[0, 1, 2], [0, 2, 3], [2, 3, 4], [1, 2, 4]])\n", 336 | " A = np.zeros((n * m * 5, 2))\n", 337 | " Tr = np.zeros((n * m * 4, 3))\n", 338 | " for i in range(n):\n", 339 | " for j in range(m):\n", 340 | " k = i * m + j\n", 341 | " A[k * 5:(k + 1) * 5, :] = np.c_[a[:, 0] + j, a[:, 1] + i]\n", 342 | " Tr[k * 4:(k + 1) * 4, :] = tr + k * 5\n", 343 | " C = np.c_[action_values[:, :, 3].flatten(), action_values[:, :, 2].flatten(),\n", 
344 | " action_values[:, :, 1].flatten(), action_values[:, :, 0].flatten()].flatten()\n", 345 | "\n", 346 | " ax.triplot(A[:, 0], A[:, 1], Tr, **triplotkw)\n", 347 | " tripcolor = ax.tripcolor(A[:, 0], A[:, 1], Tr, facecolors=C, **tripcolorkw)\n", 348 | " return tripcolor\n", 349 | "\n", 350 | "\n" 351 | ], 352 | "metadata": { 353 | "cellView": "form", 354 | "id": "ykuSUp-6qM0x" 355 | }, 356 | "execution_count": null, 357 | "outputs": [] 358 | }, 359 | { 360 | "cell_type": "markdown", 361 | "metadata": { 362 | "id": "b4FlUGb-qBFL" 363 | }, 364 | "source": [ 365 | "## Import the necessary software libraries:" 366 | ] 367 | }, 368 | { 369 | "cell_type": "code", 370 | "execution_count": null, 371 | "metadata": { 372 | "id": "XZ5pjEI9qBFM" 373 | }, 374 | "outputs": [], 375 | "source": [ 376 | "import numpy as np\n", 377 | "import matplotlib.pyplot as plt" 378 | ] 379 | }, 380 | { 381 | "cell_type": "markdown", 382 | "metadata": { 383 | "id": "uPOLSChNqBFO" 384 | }, 385 | "source": [ 386 | "## Create the environment, value table and policy" 387 | ] 388 | }, 389 | { 390 | "cell_type": "markdown", 391 | "metadata": { 392 | "id": "n8I0R27cqBFO" 393 | }, 394 | "source": [ 395 | "#### Create the environment" 396 | ] 397 | }, 398 | { 399 | "cell_type": "code", 400 | "execution_count": null, 401 | "metadata": { 402 | "id": "vKz0yD5hqBFP" 403 | }, 404 | "outputs": [], 405 | "source": [] 406 | }, 407 | { 408 | "cell_type": "markdown", 409 | "metadata": { 410 | "id": "1XinBDNSqBFP" 411 | }, 412 | "source": [ 413 | "#### Create the $Q(s, a)$ table" 414 | ] 415 | }, 416 | { 417 | "cell_type": "code", 418 | "execution_count": null, 419 | "metadata": { 420 | "id": "sXX5YdTnqBFQ" 421 | }, 422 | "outputs": [], 423 | "source": [] 424 | }, 425 | { 426 | "cell_type": "markdown", 427 | "metadata": { 428 | "id": "nB_aaY8WqBFQ" 429 | }, 430 | "source": [ 431 | "#### Create the policy $\\pi(s)$" 432 | ] 433 | }, 434 | { 435 | "cell_type": "code", 436 | "execution_count": null, 437 | "metadata": { 438 | "id": "zcG4n5L8qBFR" 439 | }, 440 | "outputs": [], 441 | "source": [] 442 | }, 443 | { 444 | "cell_type": "markdown", 445 | "metadata": { 446 | "id": "KbwXSKQCqBFR" 447 | }, 448 | "source": [ 449 | "#### Plot the value table $Q(s,a)$" 450 | ] 451 | }, 452 | { 453 | "cell_type": "code", 454 | "execution_count": null, 455 | "metadata": { 456 | "id": "RPJAHlTRqBFR" 457 | }, 458 | "outputs": [], 459 | "source": [ 460 | "plot_action_values(action_values)" 461 | ] 462 | }, 463 | { 464 | "cell_type": "markdown", 465 | "metadata": { 466 | "id": "ldE0EH4kqBFR" 467 | }, 468 | "source": [ 469 | "#### Plot the policy" 470 | ] 471 | }, 472 | { 473 | "cell_type": "code", 474 | "execution_count": null, 475 | "metadata": { 476 | "scrolled": false, 477 | "id": "uTMuUFlYqBFS" 478 | }, 479 | "outputs": [], 480 | "source": [ 481 | "plot_policy(action_values, env.render(mode='rgb_array'))" 482 | ] 483 | }, 484 | { 485 | "cell_type": "markdown", 486 | "metadata": { 487 | "id": "rOJiawyuqBFS" 488 | }, 489 | "source": [ 490 | "## Implement the algorithm\n" 491 | ] 492 | }, 493 | { 494 | "cell_type": "code", 495 | "execution_count": null, 496 | "metadata": { 497 | "id": "FrfUBijoqBFS" 498 | }, 499 | "outputs": [], 500 | "source": [] 501 | }, 502 | { 503 | "cell_type": "code", 504 | "execution_count": null, 505 | "metadata": { 506 | "id": "EhRr4G7pqBFS" 507 | }, 508 | "outputs": [], 509 | "source": [] 510 | }, 511 | { 512 | "cell_type": "markdown", 513 | "metadata": { 514 | "id": "lB7zVHxDqBFS" 515 | }, 516 | "source": [ 517 | "## Show results" 
518 | ] 519 | }, 520 | { 521 | "cell_type": "markdown", 522 | "metadata": { 523 | "id": "VafpyTWkqBFT" 524 | }, 525 | "source": [ 526 | "#### Show resulting value table $Q(s,a)$" 527 | ] 528 | }, 529 | { 530 | "cell_type": "code", 531 | "execution_count": null, 532 | "metadata": { 533 | "id": "i-uQ6vb-qBFT" 534 | }, 535 | "outputs": [], 536 | "source": [ 537 | "plot_action_values(action_values)" 538 | ] 539 | }, 540 | { 541 | "cell_type": "markdown", 542 | "metadata": { 543 | "id": "lsXpPsKnqBFT" 544 | }, 545 | "source": [ 546 | "#### Show resulting policy $\\pi(\\cdot|s)$" 547 | ] 548 | }, 549 | { 550 | "cell_type": "code", 551 | "execution_count": null, 552 | "metadata": { 553 | "id": "mZKL4RXvqBFT" 554 | }, 555 | "outputs": [], 556 | "source": [ 557 | "plot_policy(action_values, env.render(mode='rgb_array'))" 558 | ] 559 | }, 560 | { 561 | "cell_type": "markdown", 562 | "metadata": { 563 | "id": "AoJbX_79qBFT" 564 | }, 565 | "source": [ 566 | "#### Test the resulting agent" 567 | ] 568 | }, 569 | { 570 | "cell_type": "code", 571 | "execution_count": null, 572 | "metadata": { 573 | "id": "ckonhYqBqBFT" 574 | }, 575 | "outputs": [], 576 | "source": [ 577 | "test_agent(env, policy)" 578 | ] 579 | }, 580 | { 581 | "cell_type": "markdown", 582 | "metadata": { 583 | "id": "ZmJOa0huqBFT" 584 | }, 585 | "source": [ 586 | "## Resources" 587 | ] 588 | }, 589 | { 590 | "cell_type": "markdown", 591 | "metadata": { 592 | "id": "joHvXmxXqBFU" 593 | }, 594 | "source": [ 595 | "[[1] Reinforcement Learning: An Introduction. Ch. 6: Temporal difference learning](https://web.stanford.edu/class/psych209/Readings/SuttonBartoIPRLBook2ndEd.pdf)" 596 | ] 597 | } 598 | ], 599 | "metadata": { 600 | "kernelspec": { 601 | "display_name": "Python 3", 602 | "language": "python", 603 | "name": "python3" 604 | }, 605 | "language_info": { 606 | "codemirror_mode": { 607 | "name": "ipython", 608 | "version": 3 609 | }, 610 | "file_extension": ".py", 611 | "mimetype": "text/x-python", 612 | "name": "python", 613 | "nbconvert_exporter": "python", 614 | "pygments_lexer": "ipython3", 615 | "version": "3.8.5" 616 | }, 617 | "colab": { 618 | "provenance": [] 619 | } 620 | }, 621 | "nbformat": 4, 622 | "nbformat_minor": 0 623 | } -------------------------------------------------------------------------------- /Section_6_n_step_sarsa.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "pycharm": { 7 | "name": "#%%\n" 8 | }, 9 | "id": "Kbe3qi92tJh8" 10 | }, 11 | "source": [ 12 | "
\n", 13 | "

\n", 14 | " n-step SARSA\n", 15 | "

\n", 16 | "
\n", 17 | "\n", 18 | "

\n", 19 | "\n", 20 | "
\n", 21 | " In this notebook we are going to combine the temporal difference method SARSA with n-step bootstrapping. The resulting algorithm is called n-step SARSA and uses the following target for the updates:\n", 22 | "
\n", 23 | "\n", 24 | "\\begin{equation}\n", 25 | "\\hat G_t = R_{t+1} + \\gamma R_{t+2} + \\cdots + \\gamma^{n-1} R_{n} + \\gamma Q(S_n, A_n)\n", 26 | "\\end{equation}\n", 27 | "\n", 28 | "
\n", 29 | "\n", 30 | "
\n", 31 | " This method follows an on-policy strategy, in which the same policy that is optimized is responsible for exploring the environment.\n", 32 | "
\n", 33 | "\n" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "source": [ 39 | "# @title Setup code (not important) - Run this cell by pressing \"Shift + Enter\"\n", 40 | "\n", 41 | "\n", 42 | "\n", 43 | "!pip install -qq gym==0.23.0\n", 44 | "\n", 45 | "\n", 46 | "from typing import Tuple, Dict, Optional, Iterable, Callable\n", 47 | "\n", 48 | "import numpy as np\n", 49 | "import seaborn as sns\n", 50 | "import matplotlib\n", 51 | "from matplotlib import animation\n", 52 | "\n", 53 | "from IPython.display import HTML\n", 54 | "\n", 55 | "import gym\n", 56 | "from gym import spaces\n", 57 | "from gym.error import DependencyNotInstalled\n", 58 | "\n", 59 | "import pygame\n", 60 | "from pygame import gfxdraw\n", 61 | "\n", 62 | "\n", 63 | "class Maze(gym.Env):\n", 64 | "\n", 65 | " def __init__(self, exploring_starts: bool = False,\n", 66 | " shaped_rewards: bool = False, size: int = 5) -> None:\n", 67 | " super().__init__()\n", 68 | " self.exploring_starts = exploring_starts\n", 69 | " self.shaped_rewards = shaped_rewards\n", 70 | " self.state = (size - 1, size - 1)\n", 71 | " self.goal = (size - 1, size - 1)\n", 72 | " self.maze = self._create_maze(size=size)\n", 73 | " self.distances = self._compute_distances(self.goal, self.maze)\n", 74 | " self.action_space = spaces.Discrete(n=4)\n", 75 | " self.action_space.action_meanings = {0: 'UP', 1: 'RIGHT', 2: 'DOWN', 3: \"LEFT\"}\n", 76 | " self.observation_space = spaces.MultiDiscrete([size, size])\n", 77 | "\n", 78 | " self.screen = None\n", 79 | " self.agent_transform = None\n", 80 | "\n", 81 | " def step(self, action: int) -> Tuple[Tuple[int, int], float, bool, Dict]:\n", 82 | " reward = self.compute_reward(self.state, action)\n", 83 | " self.state = self._get_next_state(self.state, action)\n", 84 | " done = self.state == self.goal\n", 85 | " info = {}\n", 86 | " return self.state, reward, done, info\n", 87 | "\n", 88 | " def reset(self) -> Tuple[int, int]:\n", 89 | " if self.exploring_starts:\n", 90 | " while self.state == self.goal:\n", 91 | " self.state = tuple(self.observation_space.sample())\n", 92 | " else:\n", 93 | " self.state = (0, 0)\n", 94 | " return self.state\n", 95 | "\n", 96 | " def render(self, mode: str = 'human') -> Optional[np.ndarray]:\n", 97 | " assert mode in ['human', 'rgb_array']\n", 98 | "\n", 99 | " screen_size = 600\n", 100 | " scale = screen_size / 5\n", 101 | "\n", 102 | " if self.screen is None:\n", 103 | " pygame.init()\n", 104 | " self.screen = pygame.Surface((screen_size, screen_size))\n", 105 | "\n", 106 | " surf = pygame.Surface((screen_size, screen_size))\n", 107 | " surf.fill((22, 36, 71))\n", 108 | "\n", 109 | "\n", 110 | " for row in range(5):\n", 111 | " for col in range(5):\n", 112 | "\n", 113 | " state = (row, col)\n", 114 | " for next_state in [(row + 1, col), (row - 1, col), (row, col + 1), (row, col - 1)]:\n", 115 | " if next_state not in self.maze[state]:\n", 116 | "\n", 117 | " # Add the geometry of the edges and walls (i.e. 
the boundaries between\n", 118 | " # adjacent squares that are not connected).\n", 119 | " row_diff, col_diff = np.subtract(next_state, state)\n", 120 | " left = (col + (col_diff > 0)) * scale - 2 * (col_diff != 0)\n", 121 | " right = ((col + 1) - (col_diff < 0)) * scale + 2 * (col_diff != 0)\n", 122 | " top = (5 - (row + (row_diff > 0))) * scale - 2 * (row_diff != 0)\n", 123 | " bottom = (5 - ((row + 1) - (row_diff < 0))) * scale + 2 * (row_diff != 0)\n", 124 | "\n", 125 | " gfxdraw.filled_polygon(surf, [(left, bottom), (left, top), (right, top), (right, bottom)], (255, 255, 255))\n", 126 | "\n", 127 | " # Add the geometry of the goal square to the viewer.\n", 128 | " left, right, top, bottom = scale * 4 + 10, scale * 5 - 10, scale - 10, 10\n", 129 | " gfxdraw.filled_polygon(surf, [(left, bottom), (left, top), (right, top), (right, bottom)], (40, 199, 172))\n", 130 | "\n", 131 | " # Add the geometry of the agent to the viewer.\n", 132 | " agent_row = int(screen_size - scale * (self.state[0] + .5))\n", 133 | " agent_col = int(scale * (self.state[1] + .5))\n", 134 | " gfxdraw.filled_circle(surf, agent_col, agent_row, int(scale * .6 / 2), (228, 63, 90))\n", 135 | "\n", 136 | " surf = pygame.transform.flip(surf, False, True)\n", 137 | " self.screen.blit(surf, (0, 0))\n", 138 | "\n", 139 | " return np.transpose(\n", 140 | " np.array(pygame.surfarray.pixels3d(self.screen)), axes=(1, 0, 2)\n", 141 | " )\n", 142 | "\n", 143 | " def close(self) -> None:\n", 144 | " if self.screen is not None:\n", 145 | " pygame.display.quit()\n", 146 | " pygame.quit()\n", 147 | " self.screen = None\n", 148 | "\n", 149 | " def compute_reward(self, state: Tuple[int, int], action: int) -> float:\n", 150 | " next_state = self._get_next_state(state, action)\n", 151 | " if self.shaped_rewards:\n", 152 | " return - (self.distances[next_state] / self.distances.max())\n", 153 | " return - float(state != self.goal)\n", 154 | "\n", 155 | " def simulate_step(self, state: Tuple[int, int], action: int):\n", 156 | " reward = self.compute_reward(state, action)\n", 157 | " next_state = self._get_next_state(state, action)\n", 158 | " done = next_state == self.goal\n", 159 | " info = {}\n", 160 | " return next_state, reward, done, info\n", 161 | "\n", 162 | " def _get_next_state(self, state: Tuple[int, int], action: int) -> Tuple[int, int]:\n", 163 | " if action == 0:\n", 164 | " next_state = (state[0] - 1, state[1])\n", 165 | " elif action == 1:\n", 166 | " next_state = (state[0], state[1] + 1)\n", 167 | " elif action == 2:\n", 168 | " next_state = (state[0] + 1, state[1])\n", 169 | " elif action == 3:\n", 170 | " next_state = (state[0], state[1] - 1)\n", 171 | " else:\n", 172 | " raise ValueError(\"Action value not supported:\", action)\n", 173 | " if next_state in self.maze[state]:\n", 174 | " return next_state\n", 175 | " return state\n", 176 | "\n", 177 | " @staticmethod\n", 178 | " def _create_maze(size: int) -> Dict[Tuple[int, int], Iterable[Tuple[int, int]]]:\n", 179 | " maze = {(row, col): [(row - 1, col), (row + 1, col), (row, col - 1), (row, col + 1)]\n", 180 | " for row in range(size) for col in range(size)}\n", 181 | "\n", 182 | " left_edges = [[(row, 0), (row, -1)] for row in range(size)]\n", 183 | " right_edges = [[(row, size - 1), (row, size)] for row in range(size)]\n", 184 | " upper_edges = [[(0, col), (-1, col)] for col in range(size)]\n", 185 | " lower_edges = [[(size - 1, col), (size, col)] for col in range(size)]\n", 186 | " walls = [\n", 187 | " [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)],\n", 188 | " 
[(1, 1), (1, 2)], [(2, 1), (2, 2)], [(3, 1), (3, 2)],\n", 189 | " [(3, 1), (4, 1)], [(0, 2), (1, 2)], [(1, 2), (1, 3)],\n", 190 | " [(2, 2), (3, 2)], [(2, 3), (3, 3)], [(2, 4), (3, 4)],\n", 191 | " [(4, 2), (4, 3)], [(1, 3), (1, 4)], [(2, 3), (2, 4)],\n", 192 | " ]\n", 193 | "\n", 194 | " obstacles = upper_edges + lower_edges + left_edges + right_edges + walls\n", 195 | "\n", 196 | " for src, dst in obstacles:\n", 197 | " maze[src].remove(dst)\n", 198 | "\n", 199 | " if dst in maze:\n", 200 | " maze[dst].remove(src)\n", 201 | "\n", 202 | " return maze\n", 203 | "\n", 204 | " @staticmethod\n", 205 | " def _compute_distances(goal: Tuple[int, int],\n", 206 | " maze: Dict[Tuple[int, int], Iterable[Tuple[int, int]]]) -> np.ndarray:\n", 207 | " distances = np.full((5, 5), np.inf)\n", 208 | " visited = set()\n", 209 | " distances[goal] = 0.\n", 210 | "\n", 211 | " while visited != set(maze):\n", 212 | " sorted_dst = [(v // 5, v % 5) for v in distances.argsort(axis=None)]\n", 213 | " closest = next(x for x in sorted_dst if x not in visited)\n", 214 | " visited.add(closest)\n", 215 | "\n", 216 | " for neighbour in maze[closest]:\n", 217 | " distances[neighbour] = min(distances[neighbour], distances[closest] + 1)\n", 218 | " return distances\n", 219 | "\n", 220 | "\n", 221 | "def plot_policy(probs_or_qvals, frame, action_meanings=None):\n", 222 | " if action_meanings is None:\n", 223 | " action_meanings = {0: 'U', 1: 'R', 2: 'D', 3: 'L'}\n", 224 | " fig, axes = plt.subplots(1, 2, figsize=(8, 4))\n", 225 | " max_prob_actions = probs_or_qvals.argmax(axis=-1)\n", 226 | " probs_copy = max_prob_actions.copy().astype(object)\n", 227 | " for key in action_meanings:\n", 228 | " probs_copy[probs_copy == key] = action_meanings[key]\n", 229 | " sns.heatmap(max_prob_actions, annot=probs_copy, fmt='', cbar=False, cmap='coolwarm',\n", 230 | " annot_kws={'weight': 'bold', 'size': 12}, linewidths=2, ax=axes[0])\n", 231 | " axes[1].imshow(frame)\n", 232 | " axes[0].axis('off')\n", 233 | " axes[1].axis('off')\n", 234 | " plt.suptitle(\"Policy\", size=18)\n", 235 | " plt.tight_layout()\n", 236 | "\n", 237 | "\n", 238 | "def plot_values(state_values, frame):\n", 239 | " f, axes = plt.subplots(1, 2, figsize=(10, 4))\n", 240 | " sns.heatmap(state_values, annot=True, fmt=\".2f\", cmap='coolwarm',\n", 241 | " annot_kws={'weight': 'bold', 'size': 12}, linewidths=2, ax=axes[0])\n", 242 | " axes[1].imshow(frame)\n", 243 | " axes[0].axis('off')\n", 244 | " axes[1].axis('off')\n", 245 | " plt.tight_layout()\n", 246 | "\n", 247 | "\n", 248 | "def display_video(frames):\n", 249 | " # Copied from: https://colab.research.google.com/github/deepmind/dm_control/blob/master/tutorial.ipynb\n", 250 | " orig_backend = matplotlib.get_backend()\n", 251 | " matplotlib.use('Agg')\n", 252 | " fig, ax = plt.subplots(1, 1, figsize=(5, 5))\n", 253 | " matplotlib.use(orig_backend)\n", 254 | " ax.set_axis_off()\n", 255 | " ax.set_aspect('equal')\n", 256 | " ax.set_position([0, 0, 1, 1])\n", 257 | " im = ax.imshow(frames[0])\n", 258 | " def update(frame):\n", 259 | " im.set_data(frame)\n", 260 | " return [im]\n", 261 | " anim = animation.FuncAnimation(fig=fig, func=update, frames=frames,\n", 262 | " interval=50, blit=True, repeat=False)\n", 263 | " return HTML(anim.to_html5_video())\n", 264 | "\n", 265 | "\n", 266 | "def test_agent(environment, policy, episodes=10):\n", 267 | " frames = []\n", 268 | " for episode in range(episodes):\n", 269 | " state = env.reset()\n", 270 | " done = False\n", 271 | " frames.append(env.render(mode=\"rgb_array\"))\n", 
272 | "\n", 273 | " while not done:\n", 274 | " p = policy(state)\n", 275 | " if isinstance(p, np.ndarray):\n", 276 | " action = np.random.choice(4, p=p)\n", 277 | " else:\n", 278 | " action = p\n", 279 | " next_state, reward, done, extra_info = env.step(action)\n", 280 | " img = env.render(mode=\"rgb_array\")\n", 281 | " frames.append(img)\n", 282 | " state = next_state\n", 283 | "\n", 284 | " return display_video(frames)\n", 285 | "\n", 286 | "\n", 287 | "def plot_action_values(action_values):\n", 288 | "\n", 289 | " text_positions = [\n", 290 | " [(0.35, 4.75), (1.35, 4.75), (2.35, 4.75), (3.35, 4.75), (4.35, 4.75),\n", 291 | " (0.35, 3.75), (1.35, 3.75), (2.35, 3.75), (3.35, 3.75), (4.35, 3.75),\n", 292 | " (0.35, 2.75), (1.35, 2.75), (2.35, 2.75), (3.35, 2.75), (4.35, 2.75),\n", 293 | " (0.35, 1.75), (1.35, 1.75), (2.35, 1.75), (3.35, 1.75), (4.35, 1.75),\n", 294 | " (0.35, 0.75), (1.35, 0.75), (2.35, 0.75), (3.35, 0.75), (4.35, 0.75)],\n", 295 | " [(0.6, 4.45), (1.6, 4.45), (2.6, 4.45), (3.6, 4.45), (4.6, 4.45),\n", 296 | " (0.6, 3.45), (1.6, 3.45), (2.6, 3.45), (3.6, 3.45), (4.6, 3.45),\n", 297 | " (0.6, 2.45), (1.6, 2.45), (2.6, 2.45), (3.6, 2.45), (4.6, 2.45),\n", 298 | " (0.6, 1.45), (1.6, 1.45), (2.6, 1.45), (3.6, 1.45), (4.6, 1.45),\n", 299 | " (0.6, 0.45), (1.6, 0.45), (2.6, 0.45), (3.6, 0.45), (4.6, 0.45)],\n", 300 | " [(0.35, 4.15), (1.35, 4.15), (2.35, 4.15), (3.35, 4.15), (4.35, 4.15),\n", 301 | " (0.35, 3.15), (1.35, 3.15), (2.35, 3.15), (3.35, 3.15), (4.35, 3.15),\n", 302 | " (0.35, 2.15), (1.35, 2.15), (2.35, 2.15), (3.35, 2.15), (4.35, 2.15),\n", 303 | " (0.35, 1.15), (1.35, 1.15), (2.35, 1.15), (3.35, 1.15), (4.35, 1.15),\n", 304 | " (0.35, 0.15), (1.35, 0.15), (2.35, 0.15), (3.35, 0.15), (4.35, 0.15)],\n", 305 | " [(0.05, 4.45), (1.05, 4.45), (2.05, 4.45), (3.05, 4.45), (4.05, 4.45),\n", 306 | " (0.05, 3.45), (1.05, 3.45), (2.05, 3.45), (3.05, 3.45), (4.05, 3.45),\n", 307 | " (0.05, 2.45), (1.05, 2.45), (2.05, 2.45), (3.05, 2.45), (4.05, 2.45),\n", 308 | " (0.05, 1.45), (1.05, 1.45), (2.05, 1.45), (3.05, 1.45), (4.05, 1.45),\n", 309 | " (0.05, 0.45), (1.05, 0.45), (2.05, 0.45), (3.05, 0.45), (4.05, 0.45)]]\n", 310 | "\n", 311 | " fig, ax = plt.subplots(figsize=(7, 7))\n", 312 | " tripcolor = quatromatrix(action_values, ax=ax,\n", 313 | " triplotkw={\"color\": \"k\", \"lw\": 1}, tripcolorkw={\"cmap\": \"coolwarm\"})\n", 314 | " ax.margins(0)\n", 315 | " ax.set_aspect(\"equal\")\n", 316 | " fig.colorbar(tripcolor)\n", 317 | "\n", 318 | " for j, av in enumerate(text_positions):\n", 319 | " for i, (xi, yi) in enumerate(av):\n", 320 | " plt.text(xi, yi, round(action_values[:, :, j].flatten()[i], 2), size=8, color=\"w\", weight=\"bold\")\n", 321 | "\n", 322 | " plt.title(\"Action values Q(s,a)\", size=18)\n", 323 | " plt.tight_layout()\n", 324 | " plt.show()\n", 325 | "\n", 326 | "\n", 327 | "def quatromatrix(action_values, ax=None, triplotkw=None, tripcolorkw=None):\n", 328 | " action_values = np.flipud(action_values)\n", 329 | " n = 5\n", 330 | " m = 5\n", 331 | " a = np.array([[0, 0], [0, 1], [.5, .5], [1, 0], [1, 1]])\n", 332 | " tr = np.array([[0, 1, 2], [0, 2, 3], [2, 3, 4], [1, 2, 4]])\n", 333 | " A = np.zeros((n * m * 5, 2))\n", 334 | " Tr = np.zeros((n * m * 4, 3))\n", 335 | " for i in range(n):\n", 336 | " for j in range(m):\n", 337 | " k = i * m + j\n", 338 | " A[k * 5:(k + 1) * 5, :] = np.c_[a[:, 0] + j, a[:, 1] + i]\n", 339 | " Tr[k * 4:(k + 1) * 4, :] = tr + k * 5\n", 340 | " C = np.c_[action_values[:, :, 3].flatten(), action_values[:, :, 2].flatten(),\n", 
341 | " action_values[:, :, 1].flatten(), action_values[:, :, 0].flatten()].flatten()\n", 342 | "\n", 343 | " ax.triplot(A[:, 0], A[:, 1], Tr, **triplotkw)\n", 344 | " tripcolor = ax.tripcolor(A[:, 0], A[:, 1], Tr, facecolors=C, **tripcolorkw)\n", 345 | " return tripcolor\n", 346 | "\n", 347 | "\n" 348 | ], 349 | "metadata": { 350 | "cellView": "form", 351 | "id": "4xWInwTqtqOt" 352 | }, 353 | "execution_count": null, 354 | "outputs": [] 355 | }, 356 | { 357 | "cell_type": "markdown", 358 | "metadata": { 359 | "id": "QZhkAlOYtJh_" 360 | }, 361 | "source": [ 362 | "## Import the necessary software libraries:" 363 | ] 364 | }, 365 | { 366 | "cell_type": "code", 367 | "execution_count": null, 368 | "metadata": { 369 | "id": "mL6kB3EKtJh_" 370 | }, 371 | "outputs": [], 372 | "source": [ 373 | "import numpy as np\n", 374 | "import matplotlib.pyplot as plt" 375 | ] 376 | }, 377 | { 378 | "cell_type": "markdown", 379 | "metadata": { 380 | "id": "cBt6vFhEtJiA" 381 | }, 382 | "source": [ 383 | "## Create the environment, value table and policy" 384 | ] 385 | }, 386 | { 387 | "cell_type": "markdown", 388 | "metadata": { 389 | "id": "xp1nhsSVtJiA" 390 | }, 391 | "source": [ 392 | "#### Create the environment" 393 | ] 394 | }, 395 | { 396 | "cell_type": "code", 397 | "execution_count": null, 398 | "metadata": { 399 | "id": "7UyBHR1EtJiB" 400 | }, 401 | "outputs": [], 402 | "source": [ 403 | "env = Maze()" 404 | ] 405 | }, 406 | { 407 | "cell_type": "markdown", 408 | "metadata": { 409 | "id": "OGXvY4R-tJiB" 410 | }, 411 | "source": [ 412 | "#### Create the $Q(s, a)$ table" 413 | ] 414 | }, 415 | { 416 | "cell_type": "code", 417 | "execution_count": null, 418 | "metadata": { 419 | "id": "R8FDqx-YtJiB" 420 | }, 421 | "outputs": [], 422 | "source": [ 423 | "action_values = np.zeros(shape=(5, 5, 4))" 424 | ] 425 | }, 426 | { 427 | "cell_type": "markdown", 428 | "metadata": { 429 | "id": "CsBBY99_tJiC" 430 | }, 431 | "source": [ 432 | "#### Create the policy $\\pi(s)$" 433 | ] 434 | }, 435 | { 436 | "cell_type": "code", 437 | "execution_count": null, 438 | "metadata": { 439 | "id": "qiU2V00ytJiC" 440 | }, 441 | "outputs": [], 442 | "source": [ 443 | "def policy(state, epsilon=0.):\n", 444 | " if np.random.random() < epsilon:\n", 445 | " return np.random.randint(4)\n", 446 | " else:\n", 447 | " av = action_values[state]\n", 448 | " return np.random.choice(np.flatnonzero(av == av.max()))" 449 | ] 450 | }, 451 | { 452 | "cell_type": "markdown", 453 | "metadata": { 454 | "id": "UxmVgniwtJiC" 455 | }, 456 | "source": [ 457 | "#### Plot the value table $Q(s,a)$" 458 | ] 459 | }, 460 | { 461 | "cell_type": "code", 462 | "execution_count": null, 463 | "metadata": { 464 | "id": "o7h36l_mtJiC" 465 | }, 466 | "outputs": [], 467 | "source": [ 468 | "plot_action_values(action_values)" 469 | ] 470 | }, 471 | { 472 | "cell_type": "markdown", 473 | "metadata": { 474 | "id": "-0zAZkS2tJiD" 475 | }, 476 | "source": [ 477 | "#### Plot the policy" 478 | ] 479 | }, 480 | { 481 | "cell_type": "code", 482 | "execution_count": null, 483 | "metadata": { 484 | "scrolled": false, 485 | "id": "trn2wXKstJiD" 486 | }, 487 | "outputs": [], 488 | "source": [ 489 | "plot_policy(action_values, env.render(mode='rgb_array'))" 490 | ] 491 | }, 492 | { 493 | "cell_type": "markdown", 494 | "metadata": { 495 | "id": "IETb2I2PtJiD" 496 | }, 497 | "source": [ 498 | "## Implement the algorithm\n" 499 | ] 500 | }, 501 | { 502 | "cell_type": "code", 503 | "execution_count": null, 504 | "metadata": { 505 | "id": "WAB21vpatJiD" 506 | }, 507 | "outputs": 
[], 508 | "source": [] 509 | }, 510 | { 511 | "cell_type": "code", 512 | "execution_count": null, 513 | "metadata": { 514 | "scrolled": false, 515 | "id": "o1SfGVcbtJiD" 516 | }, 517 | "outputs": [], 518 | "source": [] 519 | }, 520 | { 521 | "cell_type": "markdown", 522 | "metadata": { 523 | "id": "zohenO_EtJiD" 524 | }, 525 | "source": [ 526 | "## Show results" 527 | ] 528 | }, 529 | { 530 | "cell_type": "markdown", 531 | "metadata": { 532 | "id": "Fm_sKMEwtJiD" 533 | }, 534 | "source": [ 535 | "#### Show resulting value table $Q(s, a)$" 536 | ] 537 | }, 538 | { 539 | "cell_type": "code", 540 | "execution_count": null, 541 | "metadata": { 542 | "id": "_AGIajOitJiD" 543 | }, 544 | "outputs": [], 545 | "source": [ 546 | "plot_action_values(action_values)" 547 | ] 548 | }, 549 | { 550 | "cell_type": "markdown", 551 | "metadata": { 552 | "id": "yHgLEGQ6tJiD" 553 | }, 554 | "source": [ 555 | "#### Show resulting policy $\\pi(\\cdot|s)$" 556 | ] 557 | }, 558 | { 559 | "cell_type": "code", 560 | "execution_count": null, 561 | "metadata": { 562 | "id": "iEQ1P6lstJiD" 563 | }, 564 | "outputs": [], 565 | "source": [ 566 | "plot_policy(action_values, env.render(mode='rgb_array'))" 567 | ] 568 | }, 569 | { 570 | "cell_type": "markdown", 571 | "metadata": { 572 | "id": "8U2qqgTytJiD" 573 | }, 574 | "source": [ 575 | "#### Test the resulting agent" 576 | ] 577 | }, 578 | { 579 | "cell_type": "code", 580 | "execution_count": null, 581 | "metadata": { 582 | "id": "TTJGndBmtJiD" 583 | }, 584 | "outputs": [], 585 | "source": [ 586 | "test_agent(env, policy)" 587 | ] 588 | }, 589 | { 590 | "cell_type": "markdown", 591 | "metadata": { 592 | "id": "8fP_400RtJiE" 593 | }, 594 | "source": [ 595 | "## Resources" 596 | ] 597 | }, 598 | { 599 | "cell_type": "markdown", 600 | "metadata": { 601 | "id": "nen4XtVftJiE" 602 | }, 603 | "source": [ 604 | "[[1] Reinforcement Learning: An Introduction. Ch. 7: n-step bootstrapping](https://web.stanford.edu/class/psych209/Readings/SuttonBartoIPRLBook2ndEd.pdf)" 605 | ] 606 | } 607 | ], 608 | "metadata": { 609 | "kernelspec": { 610 | "display_name": "Python 3", 611 | "language": "python", 612 | "name": "python3" 613 | }, 614 | "language_info": { 615 | "codemirror_mode": { 616 | "name": "ipython", 617 | "version": 3 618 | }, 619 | "file_extension": ".py", 620 | "mimetype": "text/x-python", 621 | "name": "python", 622 | "nbconvert_exporter": "python", 623 | "pygments_lexer": "ipython3", 624 | "version": "3.8.5" 625 | }, 626 | "colab": { 627 | "provenance": [] 628 | } 629 | }, 630 | "nbformat": 4, 631 | "nbformat_minor": 0 632 | } --------------------------------------------------------------------------------