├── LICENSE ├── README.md ├── Glossary.md └── Intro_to_Deep_RL_Part_1.ipynb /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Logan Thomson 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep Reinforcement Learning Course 2 | 3 | This course aims to teach from the basics of RL to advanced algorithms such as PPO. 4 | 5 | [![GitHub stars](https://img.shields.io/github/stars/xycoord/deep-rl-course?style=social)](https://github.com/xycoord/deep-rl-course/stargazers) 6 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 7 | 8 | ## 📋 Prerequisites 9 | 10 | - Machine Learning (Gradient Descent, Neural Networks) 11 | - Basic Probability Theory (Expectations and Distributions) 12 | - Multivariate Calculus 13 | - Python and PyTorch 14 | 15 | ## 📚 Course Style 16 | 17 | Each module consists of: 18 | - Formal mathematical definitions and theory 19 | - Step-by-step algorithm derivations 20 | - Complete PyTorch implementations 21 | - Runnable experiments 22 | 23 | I recommend working through the notebooks carefully, especially the mathematical derivations and proofs, and ensuring you understand each concept before moving on. This material is designed to be precise and concise so that you can learn efficiently - without rushing through. 24 | 25 | ## 🗺️ Course Roadmap 26 | 27 | | Module | Topic | Colab | Key Concepts | 28 | |--------|-------|--------|--------------| 29 | | Part 1 | RL Basics & Policy Gradients | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Lm_TI-Vrzai-WZQeZL3o7US07vVKWXlQ) | MDPs, Policies, Trajectories, Policy Gradient Theorem, Reward-to-go | 30 | | Part 2 | Discounting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1UULTQYnymQOpa7nuaw6mDXnvWRV9R_2y) | Temporal Discounting, Convergence of Infinite Horizons, Variance Reduction | 31 | | Part 3 | Baselines & Advantage Estimation | ![Coming Soon](https://img.shields.io/badge/Coming_Soon-gray?logo=google-colab&logoColor=white) | Value Functions, Advantage Functions, Variance Reduction | 32 | | Part 4 | Actor-Critic Methods | ![Coming Soon](https://img.shields.io/badge/Coming_Soon-gray?logo=google-colab&logoColor=white) | Value Approximation, Policy-Value Learning | 33 | | Part 5 | GAE | ![Coming Soon](https://img.shields.io/badge/Coming_Soon-gray?logo=google-colab&logoColor=white) | Generalised Advantage Estimation | 34 | | Part 6 | PPO | ![Coming Soon](https://img.shields.io/badge/Coming_Soon-gray?logo=google-colab&logoColor=white) | Proximal Policy Optimisation | 35 | -------------------------------------------------------------------------------- /Glossary.md: -------------------------------------------------------------------------------- 1 | # Glossary 2 | 3 | > **Reinforcement Learning** is a framework in which an *agent* learns to perform tasks through trial-and-error interaction with its *environment*, receiving *rewards/penalties* as feedback from actions. 4 | 5 | > An **environment** is a 5-tuple $(\mathcal S, \mathcal A, p, p_0, r)$ where: 6 | > - $\mathcal S$ is the State Space, the set of all states. A state $s\in\mathcal S$ is a complete description of the environment 7 | > - $\mathcal A$ is the Action Space, the set of all actions an agent can take in the environment 8 | > - $p: \mathcal S\times\mathcal A\times\mathcal S\to[0,1]$ is the transition distribution s.t. $p(s_{t+1} | s_t,a_t)$ is the probability that the environment transitions from state $s_t$ to state $s_{t+1}$ when action $a_t$ is taken 9 | > - $p_0:\mathcal S\to[0,1]$ is the initial state distribution s.t. $p_0(s)$ is the probability that $s_0 = s$ 10 | > - $r:\mathcal S\times\mathcal A\to\mathbb R$ is the reward function s.t. $r(s_t,a_t)$ is the reward given when the environment is in state $s_t$ and action $a_t$ is taken 11 | 12 | > The **Markov Property**: the probability of transitioning to the next state $s_{t+1}$ depends only on the current state $s_t$ and $a_t$, not on the history of previous states or actions. 13 | 14 | > $\mathcal O$ is the **observation space**, the set of all possible observations. An **observation** $o\in\mathcal O$ is a partial description of the environment. 15 | 16 | > A **Terminal State** is one that cannot be left (such as winning, losing, dying etc.) 17 | 18 | > An **Episodic Task** has a *terminal state* 19 | > e.g. a game of chess, a level in a video game, landing a rocket 20 | 21 | > A **Continuing Task** has no *terminal state* 22 | > e.g. trading on the stock market, walking, keeping the pole upright 23 | 24 | > **Truncation** is the act of forcibly ending interaction with the environment before reaching a natural conclusion 25 | 26 | > An **agent** is an entity that interacts with an environment by taking actions according to its policy. 27 | 28 | > A **policy**, $\pi:\mathcal O\to P(\mathcal A)$, is a function from observations to distributions from which an action may be sampled (selected) given an observation: 29 | > $a_t \sim \pi(\cdot|o_t)$ 30 | > Equivalently, $\pi(a_t|o_t)$ is the probability that action $a_t$ is selected given observation $o_t$ 31 | 32 | > An **experience** is the triple $(o_t, a_t, r_t)$ 33 | 34 | > A **trajectory** $\tau$ is a sequence of experiences from an agent interacting with its environment according to its policy until it reaches a terminal state or is truncated. 35 | 36 | > A **reward**, $r_t$, is a numerical feedback signal provided by the reward function at timestep $t$ 37 | 38 | > A **reward function**, $r: \mathcal S \times \mathcal A \to \mathbb R$, is a function such that the accumulation of rewards over a trajectory corresponds to performance on the task in that trajectory. 39 | -------------------------------------------------------------------------------- /Intro_to_Deep_RL_Part_1.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [], 7 | "collapsed_sections": [ 8 | "df42qv7V8PRe" 9 | ] 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | }, 15 | "language_info": { 16 | "name": "python" 17 | } 18 | }, 19 | "cells": [ 20 | { 21 | "cell_type": "markdown", 22 | "source": [ 23 | "# Introduction to Deep Reinforcement Learning\n", 24 | "\n", 25 | "**Author: Logan Thomson**\n", 26 | "\n", 27 | "[![GitHub](https://img.shields.io/badge/GitHub-Connect-181717?logo=github&logoColor=white)](https://github.com/xycoord)\n", 28 | "[![LinkedIn](https://img.shields.io/badge/LinkedIn-Connect-blue)](https://www.linkedin.com/in/logan-thomson-01a4942ab)\n", 29 | "\n", 30 | "[![GitHub](https://img.shields.io/badge/Star_Repository-gold?logo=github&logoColor=white)](https://github.com/xycoord/deep-rl-course)\n", 31 | "[![GitHub Issues](https://img.shields.io/badge/Submit_Issue-red?logo=github)](https://github.com/xycoord/deep-rl-course/issues)\n", 32 | "\n", 33 | "Welcome to this introduction to Deep Reinforcement Learning. In this notebook we introduce the core abstractions and framework used in reinforcement learning as well as the policy gradients algorithm.\n", 34 | "\n", 35 | "The approach is to both provide rigorous mathematical definitions and explanations along with implementations. I encourage students to fully read and understand all three of these and not rush ahead. I don't expect you to complete the notebook in one sitting if you are to fully understand all the content.\n", 36 | "\n", 37 | "## Prerequisites\n", 38 | "\n", 39 | "- Machine Learning (Gradient Descent, Neural Networks)\n", 40 | "- Basic Probability Theory (Expectations and Distributions)\n", 41 | "- Multivariate Calculus\n", 42 | "- Python and Pytorch\n", 43 | "\n", 44 | "## Contents\n", 45 | "\n", 46 | "1. Reinforcement Learning Basics (Environments, Agents and Policies)\n", 47 | "2. The Reinforcement Learning Loop (Experiences and Trajectories)\n", 48 | "3. The Reinforcement Learning Objective\n", 49 | "4. Policy Gradients algorithm (Theory and Implementation)\n", 50 | "5. Variance\n", 51 | "6. Reward-to-go\n", 52 | "\n" 53 | ], 54 | "metadata": { 55 | "id": "2-D5VdYpn9gl" 56 | } 57 | }, 58 | { 59 | "cell_type": "markdown", 60 | "source": [ 61 | "# Dependencies & Utils\n", 62 | "Run to install and import all dependencies and set-up some utility functions for use with PyTorch. *For this notebook, a CPU runtime is sufficient.*\n" 63 | ], 64 | "metadata": { 65 | "id": "df42qv7V8PRe" 66 | } 67 | }, 68 | { 69 | "cell_type": "code", 70 | "source": [ 71 | "%%capture\n", 72 | "!pip install swig\n", 73 | "!pip install gymnasium[Box2D,mujoco]" 74 | ], 75 | "metadata": { 76 | "id": "NIvrybMt8N-3" 77 | }, 78 | "execution_count": null, 79 | "outputs": [] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "source": [ 84 | "from typing import Union, List, Callable\n", 85 | "from abc import ABC, abstractmethod\n", 86 | "from dataclasses import dataclass\n", 87 | "\n", 88 | "import numpy as np\n", 89 | "import torch\n", 90 | "from torch import nn, optim, distributions\n", 91 | "\n", 92 | "import gymnasium as gym\n", 93 | "\n", 94 | "import itertools\n", 95 | "from tqdm import tqdm\n", 96 | "import matplotlib.pyplot as plt" 97 | ], 98 | "metadata": { 99 | "id": "BECPTmHB8WEC" 100 | }, 101 | "execution_count": null, 102 | "outputs": [] 103 | }, 104 | { 105 | "cell_type": "markdown", 106 | "source": [ 107 | "## Pytorch Utils" 108 | ], 109 | "metadata": { 110 | "id": "-2aXq-C08nGn" 111 | } 112 | }, 113 | { 114 | "cell_type": "code", 115 | "source": [ 116 | "device = torch.device(\"cpu\")\n", 117 | "dtype = torch.float32\n", 118 | "\n", 119 | "Activation = Union[str, nn.Module]\n", 120 | "\n", 121 | "_str_to_activation = {\n", 122 | " 'relu': nn.ReLU(),\n", 123 | " 'tanh': nn.Tanh(),\n", 124 | " 'leaky_relu': nn.LeakyReLU(),\n", 125 | " 'sigmoid': nn.Sigmoid(),\n", 126 | " 'selu': nn.SELU(),\n", 127 | " 'softplus': nn.Softplus(),\n", 128 | " 'identity': nn.Identity(),\n", 129 | "}\n", 130 | "\n", 131 | "@dataclass\n", 132 | "class MLPConfig:\n", 133 | " input_size: int\n", 134 | " output_size: int\n", 135 | " n_layers: int = 2\n", 136 | " size: int = 64\n", 137 | " activation: Activation = 'tanh'\n", 138 | " output_activation: Activation = 'identity'\n", 139 | "\n", 140 | " def __post_init__(self):\n", 141 | " if isinstance(self.activation, str):\n", 142 | " self.activation = _str_to_activation[self.activation]\n", 143 | " if isinstance(self.output_activation, str):\n", 144 | " self.output_activation = _str_to_activation[self.output_activation]\n", 145 | "\n", 146 | "def build_mlp(config: MLPConfig) -> nn.Module:\n", 147 | " \"\"\"\n", 148 | " Builds a feedforward neural network specified by config\n", 149 | " \"\"\"\n", 150 | "\n", 151 | " layers = []\n", 152 | " in_size = config.input_size\n", 153 | " for _ in range(config.n_layers):\n", 154 | " layers.append(nn.Linear(in_size, config.size))\n", 155 | " layers.append(config.activation)\n", 156 | " in_size = config.size\n", 157 | " layers.append(nn.Linear(in_size, config.output_size))\n", 158 | " layers.append(config.output_activation)\n", 159 | "\n", 160 | " mlp = nn.Sequential(*layers)\n", 161 | " mlp.to(device)\n", 162 | " return mlp\n", 163 | "\n", 164 | "\n", 165 | "def combined_shape(length, shape=None):\n", 166 | " if shape is None:\n", 167 | " return (length,)\n", 168 | " return (length, shape) if np.isscalar(shape) else (length, *shape)\n", 169 | "\n", 170 | "def set_seed(seed):\n", 171 | " \"\"\"Set all the seeds for reproducibility\"\"\"\n", 172 | " np.random.seed(seed) # NumPy\n", 173 | " torch.manual_seed(seed) # PyTorch on CPU\n", 174 | " torch.cuda.manual_seed(seed) # PyTorch on GPU (single GPU)\n", 175 | " torch.cuda.manual_seed_all(seed) # PyTorch on all GPUs (multi-GPU)\n", 176 | " torch.backends.cudnn.deterministic = True # Make CUDA deterministic\n", 177 | " torch.backends.cudnn.benchmark = False # Disable CUDA benchmarking" 178 | ], 179 | "metadata": { 180 | "id": "1Y0oSmam8kXK" 181 | }, 182 | "execution_count": null, 183 | "outputs": [] 184 | }, 185 | { 186 | "cell_type": "markdown", 187 | "source": [ 188 | "# Reinforcement Learning Basics\n", 189 | "\n", 190 | "> **Reinforcement Learning** is a framework in which an *agent* learns to perform tasks through trial-and-error interaction with its *environment*, receiving *rewards/penalties* as feedback from actions.\n", 191 | "\n", 192 | "In this first section, we'll introduce all the formal definitions for the key terms in this definition (shown in *italics*) along with their implementations." 193 | ], 194 | "metadata": { 195 | "id": "ML0g0GCxvu7p" 196 | } 197 | }, 198 | { 199 | "cell_type": "markdown", 200 | "source": [ 201 | "## Environments" 202 | ], 203 | "metadata": { 204 | "id": "qb2frBw81eXk" 205 | } 206 | }, 207 | { 208 | "cell_type": "markdown", 209 | "source": [ 210 | "> An **environment** is a 5-tuple $(\\mathcal S, \\mathcal A, p, p_0, r)$ where:\n", 211 | "> - $\\mathcal S$ is the State Space, the set of all states. A state $s\\in\\mathcal S$ is a complete description of the environment\n", 212 | "> - $\\mathcal A$ is the Action Space, the set of all actions an agent can take in the environment\n", 213 | "> - $p: \\mathcal S\\times\\mathcal A\\times\\mathcal S\\to[0,1]$ is the transition distribution s.t. $p(s_{t+1} | s_t,a_t)$ is the probability that the environment transitions from state $s_t$ to state $s_{t+1}$ when action $a_t$ is taken\n", 214 | "> - $p_0:\\mathcal S\\to[0,1]$ is the initial state distribution s.t. $p_0(s)$ is the probability that $s_0 = s$\n", 215 | "> - $r:\\mathcal S\\times\\mathcal A\\to\\mathbb R$ is the reward function s.t. $r(s_t,a_t)$ is the reward given when the environment is in state $s_t$ and action $a_t$ is taken\n", 216 | "\n", 217 | "This formalization is a **Markov Decision Process (MDP)**, the mathematical framework used in reinforcement learning for modelling sequential decision-making problems. MDPs capture problems where outcomes are partly random and partly under the control of a decision maker. Their key characteristic is the Markov Property:\n", 218 | "\n", 219 | "> The **Markov Property**: the probability of transitioning to the next state $s_{t+1}$ depends only on the current state $s_t$ and $a_t$, not on the history of previous states or actions.\n", 220 | "\n", 221 | "For this tutorial, we will use `gymnasium` to make and work with environments. It is a library that comes with a collates many standard environments used for testing RL algorithms, all with a common interface. We have imported it as `gym` for convenience.\n", 222 | "\n", 223 | "One of these is Cart Pole, a classic control problem with the aim of balancing the pole in the cart with left and right forces:\n", 224 | "![cart_pole.gif](https://gymnasium.farama.org/_images/cart_pole.gif)\n", 225 | "\n", 226 | "Let's make an instance using `gym.make`" 227 | ], 228 | "metadata": { 229 | "id": "uL1V8oiKveP_" 230 | } 231 | }, 232 | { 233 | "cell_type": "code", 234 | "source": [ 235 | "cartpole_env = gym.make(\"CartPole-v1\")\n", 236 | "# Reset the environment to an initial state sampled from p_0\n", 237 | "observation, info = cartpole_env.reset()" 238 | ], 239 | "metadata": { 240 | "id": "ZJNTUDTY9F_E" 241 | }, 242 | "execution_count": null, 243 | "outputs": [] 244 | }, 245 | { 246 | "cell_type": "markdown", 247 | "source": [ 248 | "### Observations\n", 249 | "\n", 250 | "In reality, we are rarely fortunate enough to know the whole state of the environment and the best we can get is a partial observation. Let's add this to the formalism:\n", 251 | "\n", 252 | "> $\\mathcal O$ is the observation space, the set of all possible observations. An observation $o\\in\\mathcal O$ is a partial description of the environment.\n", 253 | "\n", 254 | "Let's take a look at the observation space of our environment observation we get of the initial state." 255 | ], 256 | "metadata": { 257 | "id": "iaU6gj531868" 258 | } 259 | }, 260 | { 261 | "cell_type": "code", 262 | "source": [ 263 | "cartpole_env.observation_space" 264 | ], 265 | "metadata": { 266 | "id": "-aPqX5LSDnIb" 267 | }, 268 | "execution_count": null, 269 | "outputs": [] 270 | }, 271 | { 272 | "cell_type": "code", 273 | "source": [ 274 | "observation" 275 | ], 276 | "metadata": { 277 | "id": "OYp6muoI3CW_" 278 | }, 279 | "execution_count": null, 280 | "outputs": [] 281 | }, 282 | { 283 | "cell_type": "markdown", 284 | "source": [ 285 | "This is a numpy ndarray with shape (4,). The [Docs](https://gymnasium.farama.org/environments/classic_control/cart_pole/) tell us that these values refer to:\n", 286 | "\n", 287 | "`[Cart Position, Cart Velocity, Pole Angle, Pole Angular Velocity]`\n", 288 | "\n", 289 | "and the `observation_space` property tells us the minimum and maximum values.\n", 290 | "\n", 291 | "Cart Pole has a continuous state space and observation space but these can also be discrete. For example, consider what the state space of a noughts and crosses environment would be." 292 | ], 293 | "metadata": { 294 | "id": "-29de1nSEySm" 295 | } 296 | }, 297 | { 298 | "cell_type": "markdown", 299 | "source": [ 300 | "Throughout this course, I'll use states and observations somewhat interchangeably to maintain notational simplicity. When discussing theoretical concepts and derivations, I'll primarily refer to states ($s$) to align with standard RL literature. In the code implementations, however, I'll use observations since Gymnasium environments provide observations rather than full states.\n", 301 | "\n", 302 | "It's worth noting that CartPole, which we'll use throughout this notebook, is a fully-observable environment where the observation effectively represents the complete state of the system. Therefore, with this environment, the distinction is less important. However, this isn't always the case." 303 | ], 304 | "metadata": { 305 | "id": "l7HwEEMifXfc" 306 | } 307 | }, 308 | { 309 | "cell_type": "markdown", 310 | "source": [ 311 | "### Actions and Action Spaces" 312 | ], 313 | "metadata": { 314 | "id": "22L-YKeO12K4" 315 | } 316 | }, 317 | { 318 | "cell_type": "code", 319 | "source": [ 320 | "cartpole_env.action_space" 321 | ], 322 | "metadata": { 323 | "id": "odf2gaNmRAo7" 324 | }, 325 | "execution_count": null, 326 | "outputs": [] 327 | }, 328 | { 329 | "cell_type": "markdown", 330 | "source": [ 331 | "This tells us that the action space of Cart Pole is discrete, containing two actions: $\\mathcal A :=\\{0,1\\}$\n", 332 | "\n", 333 | " 0: Push cart to the left\n", 334 | " 1: Push cart to the right\n", 335 | "\n", 336 | "We can sample the action space:" 337 | ], 338 | "metadata": { 339 | "id": "_0BH_qD1RscR" 340 | } 341 | }, 342 | { 343 | "cell_type": "code", 344 | "source": [ 345 | "cartpole_env.action_space.sample()" 346 | ], 347 | "metadata": { 348 | "id": "cl64OdbHRJ8U" 349 | }, 350 | "execution_count": null, 351 | "outputs": [] 352 | }, 353 | { 354 | "cell_type": "markdown", 355 | "source": [ 356 | "Admittedly, sampling $\\{0,1\\}$ is quite underwhelming. Just as with state and observation spaces, action spaces can also be continuous.\n", 357 | "\n", 358 | "Let's look at a different example environment." 359 | ], 360 | "metadata": { 361 | "id": "g3NywXN7U82I" 362 | } 363 | }, 364 | { 365 | "cell_type": "code", 366 | "source": [ 367 | "half_cheetah_env = gym.make(\"HalfCheetah-v5\")\n", 368 | "observation, info = half_cheetah_env.reset()\n", 369 | "\n", 370 | "print(\"Observation Space:\", half_cheetah_env.observation_space)\n", 371 | "print(\"Initial Observation:\", observation)\n", 372 | "print(\"Action Space:\", half_cheetah_env.action_space)\n", 373 | "print(\"Sample Action:\", half_cheetah_env.action_space.sample())" 374 | ], 375 | "metadata": { 376 | "id": "ksTyTxd8YBM7" 377 | }, 378 | "execution_count": null, 379 | "outputs": [] 380 | }, 381 | { 382 | "cell_type": "markdown", 383 | "source": [ 384 | "See the [Docs](https://gymnasium.farama.org/environments/mujoco/half_cheetah/) for more info on the Half Cheetah environment." 385 | ], 386 | "metadata": { 387 | "id": "ZP5aWu_dZ38t" 388 | } 389 | }, 390 | { 391 | "cell_type": "markdown", 392 | "source": [ 393 | "### Take a Step\n", 394 | "Environments exist throughout time, transitioning between states when actions are taken. At each *timestep* $t$,\n", 395 | "\n", 396 | "1. We receive an observation $o_t\\in\\mathcal O$ of the environment's state $s_t\\in\\mathcal S$\n", 397 | "\n", 398 | "2. We select an action $a_t\\in\\mathcal A$\n", 399 | "\n", 400 | "3. The environment transitions to a new state $s_{t+1}$ sampled as:\n", 401 | "\n", 402 | " $s_{t+1} \\sim p(\\cdot|s_t,a_t)$\n", 403 | "\n", 404 | "4. The reward function provides a reward based on the action taken:\n", 405 | "\n", 406 | " $r_t = r(s_t, a_t)$\n", 407 | "\n", 408 | " Higher rewards indicate that a good thing has happened.\n", 409 | "\n", 410 | "With `gymnasium` we take an action, incrementing the timestep, using `env.step(action)`" 411 | ], 412 | "metadata": { 413 | "id": "81cybtloaz4U" 414 | } 415 | }, 416 | { 417 | "cell_type": "code", 418 | "source": [ 419 | "action = cartpole_env.action_space.sample()\n", 420 | "observation, reward, terminated, truncated, info = cartpole_env.step(action)" 421 | ], 422 | "metadata": { 423 | "id": "JwQUl_Aha9S2" 424 | }, 425 | "execution_count": null, 426 | "outputs": [] 427 | }, 428 | { 429 | "cell_type": "markdown", 430 | "source": [ 431 | "This gives us an observation $o_{t+1}$ of the new state $s_{t+1}$ and the reward $r_t$.\n", 432 | "\n", 433 | "In Cart Pole, the aim is to keep the pole upright for as long as possible. Hence, a reward of $+1$ is given at every timestep in which the pole is not on the floor." 434 | ], 435 | "metadata": { 436 | "id": "BWttmbxwDCo0" 437 | } 438 | }, 439 | { 440 | "cell_type": "markdown", 441 | "source": [ 442 | "### Episodes, Termination and Truncation\n", 443 | "\n", 444 | "> A **Terminal State** is one that cannot be left (such as winning, losing, dying etc.)\n", 445 | "\n", 446 | "This leads to a distinction between two categories of task:\n", 447 | "\n", 448 | "> An **Episodic Task** has a *terminal state*\n", 449 | ">\n", 450 | "> e.g. a game of chess, a level in a video game, landing a rocket\n", 451 | "\n", 452 | ">A **Continuing Task** has no *terminal state*\n", 453 | "> \n", 454 | "> e.g. trading on the stock market, walking, keeping the pole upright\n", 455 | "\n", 456 | "When an environment enters a terminal state, we say that the episode has ended. In `gymnasium`, we are notified of this by the `terminated` flag returning `True`.\n", 457 | "\n", 458 | "> **Truncation** is the act of forcibly ending interaction with the environment before reaching a natural conclusion\n", 459 | "\n", 460 | "We can make a continuing task episodic, or truncate an episodic task by enforcing a time limit (i.e. the maximum number of timesteps). Once the time limit is reached, the `truncated` flag returns `True`.\n", 461 | "\n", 462 | "While both termination and truncation end an episode it is important to keep them distinct, since when an episode is truncated, it has not entered a terminal state.\n", 463 | "\n", 464 | "Cart Pole has a terminal state (when the pole falls or the cart leaves the track) and a time limit of 500 steps." 465 | ], 466 | "metadata": { 467 | "id": "p_kFMWKy2Hdp" 468 | } 469 | }, 470 | { 471 | "cell_type": "markdown", 472 | "source": [ 473 | "## Agents and Policies\n", 474 | "\n", 475 | "So far we have directly sampled actions from the action space however in RL, it is the role of the agent to select actions.\n", 476 | "\n", 477 | "> An **agent** is an entity that interacts with an environment by taking actions according to its policy.\n", 478 | "\n", 479 | "> A **policy**, $\\pi:\\mathcal O\\to P(\\mathcal A)$, is a function from observations to distributions from which an action may be sampled (selected) given an observation:\n", 480 | "> $a_t \\sim \\pi(\\cdot|o_t)$\n", 481 | ">\n", 482 | "> Equivalently, $\\pi(a_t|o_t)$ is the probability that action $a_t$ is selected given observation $o_t$\n", 483 | "\n", 484 | "Let's setup some abstract base classes `Policy` and `Agent` to ensure that all the policies and agents we write in this tutorial have the same interface and meet the definitions." 485 | ], 486 | "metadata": { 487 | "id": "oz4dLUCSNYiZ" 488 | } 489 | }, 490 | { 491 | "cell_type": "code", 492 | "source": [ 493 | "class Policy(ABC):\n", 494 | "\n", 495 | " @abstractmethod\n", 496 | " def get_action(self, observation: np.ndarray) -> np.ndarray:\n", 497 | " \"\"\" Select an action given the observation.\n", 498 | " Args:\n", 499 | " observation: an observation of the environment\n", 500 | " Returns:\n", 501 | " action: an action to take\n", 502 | " \"\"\"" 503 | ], 504 | "metadata": { 505 | "id": "ZnDoOlf39Lrd" 506 | }, 507 | "execution_count": null, 508 | "outputs": [] 509 | }, 510 | { 511 | "cell_type": "code", 512 | "source": [ 513 | "class Agent(ABC):\n", 514 | "\n", 515 | " @property\n", 516 | " @abstractmethod\n", 517 | " def policy(self) -> Policy:\n", 518 | " \"\"\" The agent's policy \"\"\"\n", 519 | "\n", 520 | " @abstractmethod\n", 521 | " def update(self, experiences):\n", 522 | " \"\"\" Update the agent's policy \"\"\"\n", 523 | " raise NotImplementedError" 524 | ], 525 | "metadata": { 526 | "id": "PnoPqyDm9dEi" 527 | }, 528 | "execution_count": null, 529 | "outputs": [] 530 | }, 531 | { 532 | "cell_type": "markdown", 533 | "source": [ 534 | "For the next sections, let's define a very simple agent and policy where the actions are sampled at random from the action space. This doesn't provide any new functionality, but it will allow us to implement and test some key bits of infrastructure before implementing proper agents." 535 | ], 536 | "metadata": { 537 | "id": "B_X6fqhR9H9K" 538 | } 539 | }, 540 | { 541 | "cell_type": "code", 542 | "source": [ 543 | "class RandomPolicy(Policy):\n", 544 | " def __init__(self, action_space: gym.Space):\n", 545 | " self.action_space = action_space\n", 546 | "\n", 547 | " def get_action(self, observation: np.ndarray) -> np.ndarray:\n", 548 | " # Sample a random action from the action space\n", 549 | " action = self.action_space.sample()\n", 550 | " return action" 551 | ], 552 | "metadata": { 553 | "id": "Ua8g69IELNgA" 554 | }, 555 | "execution_count": null, 556 | "outputs": [] 557 | }, 558 | { 559 | "cell_type": "code", 560 | "source": [ 561 | "class RandomAgent(Agent):\n", 562 | " def __init__(self, action_space: gym.Space):\n", 563 | " self._policy = RandomPolicy(action_space)\n", 564 | "\n", 565 | " @property\n", 566 | " def policy(self) -> Policy:\n", 567 | " return self._policy\n", 568 | "\n", 569 | " def update(self, experiences):\n", 570 | " pass" 571 | ], 572 | "metadata": { 573 | "id": "a9MX9pN1MNOu" 574 | }, 575 | "execution_count": null, 576 | "outputs": [] 577 | }, 578 | { 579 | "cell_type": "code", 580 | "source": [ 581 | "random_agent = RandomAgent(cartpole_env.action_space)" 582 | ], 583 | "metadata": { 584 | "id": "zsrYBfne8Dtg" 585 | }, 586 | "execution_count": null, 587 | "outputs": [] 588 | }, 589 | { 590 | "cell_type": "markdown", 591 | "source": [ 592 | "It's worth noting that while in some cases the agent is no more than a thin wrapper around the policy, the concepts are distinct. In more complex algorithms than we discuss in this part, the agent has additional functionality." 593 | ], 594 | "metadata": { 595 | "id": "XoWBIa_QEMfm" 596 | } 597 | }, 598 | { 599 | "cell_type": "markdown", 600 | "source": [ 601 | "# The RL Loop\n", 602 | "\n", 603 | "Let us now rewrite the 4 stages of a timestep from above except have the agent select the actions. This gives us the **RL Loop**:\n", 604 | "1. The agent receives an observation $o_t\\in\\mathcal O$ from the environment\n", 605 | "2. The agent selects an action $a_t$ given $o_t$ using it's policy $\\pi$\n", 606 | " \n", 607 | " $a_t\\sim\\pi(\\cdot|o_t)$\n", 608 | "3. The environment transitions to a new state $s_{t+1}$ sampled as:\n", 609 | "\n", 610 | " $s_{t+1} \\sim p(\\cdot|s_t,a_t)$\n", 611 | "\n", 612 | "4. The agent receives a reward $r_t$ from the enviroment for taking $a_t$:\n", 613 | "\n", 614 | " $r_t = r(s_t, a_t)$\n", 615 | "\n", 616 | "In this section, we will implement the RL Loop." 617 | ], 618 | "metadata": { 619 | "id": "HswrXvOk7MAF" 620 | } 621 | }, 622 | { 623 | "cell_type": "markdown", 624 | "source": [ 625 | "## Experiences\n", 626 | "\n", 627 | "In each cycle of the loop, i.e. at each timestep, the agent receives $o_t$, selects $a_t$ and receives $r_t$.\n", 628 | "\n", 629 | "> An **experience** is the triple $(o_t, a_t, r_t)$\n", 630 | "\n", 631 | "It is the atomic unit of agent-environment interaction. In reinforcement learning, the agent learns from experiences.\n", 632 | "\n", 633 | "Let's create a dataclass to record experiences:" 634 | ], 635 | "metadata": { 636 | "id": "yhTGApdOV4eo" 637 | } 638 | }, 639 | { 640 | "cell_type": "code", 641 | "source": [ 642 | "@dataclass\n", 643 | "class ExperienceData:\n", 644 | " observations: torch.Tensor\n", 645 | " actions: torch.Tensor\n", 646 | " rewards: torch.Tensor" 647 | ], 648 | "metadata": { 649 | "id": "XadpBHXQXoR5" 650 | }, 651 | "execution_count": null, 652 | "outputs": [] 653 | }, 654 | { 655 | "cell_type": "markdown", 656 | "source": [ 657 | "An experience is just a single sample and as we know from supervised learning, it can be helpful to learn from multiple samples at once i.e. a batch. For this reason, our experience dataclass is flexible enough to record either one or multiple experiences." 658 | ], 659 | "metadata": { 660 | "id": "MSuyMdxyysWj" 661 | } 662 | }, 663 | { 664 | "cell_type": "markdown", 665 | "source": [ 666 | "## Trajectories\n", 667 | "> A **trajectory** $\\tau$ is a sequence of experiences from an agent interacting with its environment according to its policy until it reaches a terminal state or is truncated.\n", 668 | "\n", 669 | "Note that two trajectories from the same policy and environment may vary considerably due to the stochasiticity in both.\n", 670 | "\n", 671 | "Let's make a dataclass for trajectories:" 672 | ], 673 | "metadata": { 674 | "id": "9eQswR5VLQqX" 675 | } 676 | }, 677 | { 678 | "cell_type": "code", 679 | "source": [ 680 | "@dataclass\n", 681 | "class Trajectory():\n", 682 | " observations: torch.Tensor\n", 683 | " actions: torch.Tensor\n", 684 | " rewards: torch.Tensor\n", 685 | " terminals: torch.Tensor\n", 686 | " truncateds: torch.Tensor\n", 687 | " total_reward: float\n", 688 | " length: int" 689 | ], 690 | "metadata": { 691 | "id": "TgI6VYbsvgbH" 692 | }, 693 | "execution_count": null, 694 | "outputs": [] 695 | }, 696 | { 697 | "cell_type": "markdown", 698 | "source": [ 699 | "While this appears similar to our `ExperienceData` class, trajectories have distinct properties that warrant their own class representation:\n", 700 | "- **Sequential integrity**: experiences in a trajectory follow a strict temporal order, where each directly follows the previous one.\n", 701 | "- **Bounded Completion**: a trajectory ends exactly once, either by termination or truncation. Hence, the final experience must have at least one of its terminated or truncated flags set to true, and all other experiences in the trajectory must have both flags set to false.\n", 702 | "\n", 703 | "To enforce, or at least encourage these properties, we can make a `TrajectoryBuilder` class to record experiences as we run through the RL Loop.\n", 704 | "This implementation:\n", 705 | "- allows us to record an experience using add_experience\n", 706 | "- keeps track of whether the trajectory is finished\n", 707 | "- enforces Bounded Completion through assertions\n", 708 | "- encourages Sequential Integrity since only one experience may be added at a time\n", 709 | "- uses lists for efficiency until the trajectory is complete" 710 | ], 711 | "metadata": { 712 | "id": "zngQXbtOyXqm" 713 | } 714 | }, 715 | { 716 | "cell_type": "code", 717 | "source": [ 718 | "class TrajectoryBuilder():\n", 719 | "\n", 720 | " def __init__(self):\n", 721 | " self._observations = []\n", 722 | " self._actions = []\n", 723 | " self._rewards = []\n", 724 | " self._terminals = []\n", 725 | " self._truncateds = []\n", 726 | " self._done = False\n", 727 | " self._length = 0\n", 728 | "\n", 729 | " def add_experience(self, observation, action, reward, terminal, truncated):\n", 730 | " assert not self._done, \"Trajectory is already done. Call get_trajectory()\"\n", 731 | "\n", 732 | " self._done = terminal or truncated\n", 733 | "\n", 734 | " self._observations.append(observation)\n", 735 | " self._actions.append(action)\n", 736 | " self._rewards.append(reward)\n", 737 | " self._terminals.append(self._done)\n", 738 | " self._truncateds.append(truncated)\n", 739 | "\n", 740 | " self._length += 1\n", 741 | "\n", 742 | " return self._done\n", 743 | "\n", 744 | " @property\n", 745 | " def done(self) -> bool:\n", 746 | " return self._done\n", 747 | "\n", 748 | " @property\n", 749 | " def trajectory(self) -> Trajectory:\n", 750 | " assert self._done, \"Trajectory is not complete. It must be either terminated or truncated\"\n", 751 | "\n", 752 | " def list_to_tensor(x):\n", 753 | " return torch.tensor(np.array(x), dtype=dtype, device=device, requires_grad=False)\n", 754 | "\n", 755 | " # Convert lists to tensors\n", 756 | " observations = list_to_tensor(self._observations)\n", 757 | " actions = list_to_tensor(self._actions)\n", 758 | " rewards = list_to_tensor(self._rewards)\n", 759 | " terminals = list_to_tensor(self._terminals)\n", 760 | " truncateds = list_to_tensor(self._truncateds)\n", 761 | "\n", 762 | " total_reward = rewards.sum()\n", 763 | "\n", 764 | " return Trajectory(\n", 765 | " observations,\n", 766 | " actions,\n", 767 | " rewards,\n", 768 | " terminals,\n", 769 | " truncateds,\n", 770 | " total_reward,\n", 771 | " self._length\n", 772 | " )" 773 | ], 774 | "metadata": { 775 | "id": "phfJ3IK0vkXL" 776 | }, 777 | "execution_count": null, 778 | "outputs": [] 779 | }, 780 | { 781 | "cell_type": "markdown", 782 | "source": [ 783 | "Now we're ready to implement the RL Loop!" 784 | ], 785 | "metadata": { 786 | "id": "ES6ox-F29u1m" 787 | } 788 | }, 789 | { 790 | "cell_type": "code", 791 | "source": [ 792 | "def sample_trajectory(env: gym.Env, agent: Agent) -> Trajectory:\n", 793 | " \"\"\"Sample a trajectory in the environment from a policy.\"\"\"\n", 794 | "\n", 795 | " # Reset the environment to start a new trajectory\n", 796 | " observation, info = env.reset()\n", 797 | "\n", 798 | " traj_builder = TrajectoryBuilder()\n", 799 | "\n", 800 | " # RL Loop\n", 801 | " while not traj_builder.done:\n", 802 | "\n", 803 | " # Select the action based on the observation\n", 804 | " action = agent.policy.get_action(observation)\n", 805 | "\n", 806 | " # Use that action to take a step in the environment\n", 807 | " new_observation, reward, terminated, truncated, info = env.step(action)\n", 808 | "\n", 809 | " # Record the experience\n", 810 | " # Note that we record the o_t not o_t+1\n", 811 | " traj_builder.add_experience(observation, action, reward, terminated, truncated)\n", 812 | "\n", 813 | " # Update the observation for the next loop\n", 814 | " observation = new_observation\n", 815 | "\n", 816 | " return traj_builder.trajectory" 817 | ], 818 | "metadata": { 819 | "id": "fUv-pTyN90UA" 820 | }, 821 | "execution_count": null, 822 | "outputs": [] 823 | }, 824 | { 825 | "cell_type": "markdown", 826 | "source": [ 827 | "Let's test it!" 828 | ], 829 | "metadata": { 830 | "id": "Q6pPlQTzBSpt" 831 | } 832 | }, 833 | { 834 | "cell_type": "code", 835 | "source": [ 836 | "sample_trajectory(cartpole_env, random_agent)" 837 | ], 838 | "metadata": { 839 | "id": "xyUnSjpZBSNT" 840 | }, 841 | "execution_count": null, 842 | "outputs": [] 843 | }, 844 | { 845 | "cell_type": "markdown", 846 | "source": [ 847 | "# Measuring Performance\n", 848 | "\n", 849 | "## Rewards\n", 850 | "\n", 851 | "A key idea in RL is that rewards provide feedback to the agent such that it can learn to perform the task better. Formally:\n", 852 | "\n", 853 | "> A **reward**, $r_t$, is a numerical feedback signal provided by the reward function at timestep $t$\n", 854 | "\n", 855 | "> A **reward function**, $r: \\mathcal S \\times \\mathcal A \\to \\mathbb R$, is a function such that the accumulation of rewards over a trajectory corresponds to performance on the task in that trajectory.\n", 856 | "\n", 857 | "## Cumulative Reward\n", 858 | "\n", 859 | "We can overload the definition of $r$ to accept trajectories giving us a measure of performance in a trajectory:\n", 860 | "\n", 861 | "$$r(\\tau) = r(s_0, a_0, ..., s_{T-1}, a_{T-1}) = \\sum_{t=0}^{T-1} r(s_t, a_t)$$\n", 862 | "\n", 863 | "Let's look at some examples of reward functions. For each, consider how the cumulative reward relates to the goal.\n", 864 | "\n", 865 | "**Cart Pole**\n", 866 | "\n", 867 | " Goal: keep the pole upright for as long as possible\n", 868 | " Reward: +1 for every timestep the pole remains upright\n", 869 | "\n", 870 | "**Lunar Lander**\n", 871 | "\n", 872 | " Goal: Safely land a spacecraft on the moon's surface with minimal fuel usage and proper positioning\n", 873 | " Reward:\n", 874 | " Increases/decreases based on proximity to landing pad (closer = higher reward)\n", 875 | " Increases/decreases based on spacecraft velocity (slower = higher reward)\n", 876 | " Decreases based on tilt angle (penalty for non-horizontal orientation)\n", 877 | " +10 for each leg in contact with the ground\n", 878 | " -0.03 for each frame a side engine is firing\n", 879 | " -0.3 for each frame the main engine is firing\n", 880 | " +100 bonus for landing safely\n", 881 | " -100 penalty for crashing\n", 882 | "\n", 883 | "You may be tempted to characterise rewards as being received for \"good actions\" or \"making progress towards the goal\".\n", 884 | "While rewards are given for the state-action pairs at a particular timestep, it's more accurate to start with the cumulative reward $r(\\tau)$ being a measure of performance and working back from there to design a reward function.\n", 885 | "This is because it is the task performance over the whole trajectory that actually matters rather than the actions taken at any individual timestep.\n", 886 | "The relationship between immediate rewards and overall task performance can sometimes be non-obvious or counterintuitive, and designing effective reward functions plays a large part in reinforcement learning in practice.\n", 887 | "\n", 888 | "## Trajectory Distribution\n", 889 | "\n", 890 | "To measure the success of a policy $\\pi_\\theta$ parameterised by $\\theta$ we also need a notion of the probability of a trajectory occuring under that policy, and transition distribution $p$.\n", 891 | "\n", 892 | "$$p_\\theta(\\tau) = p_\\theta(s_0, a_0, ..., s_{T-1}, a_{T-1}) = p(s_0)\\pi_\\theta(a_0|s_0)\\prod_{t=1}^{T-1} p(s_t|s_{t-1}, a_{t-1})\\pi_\\theta(a_t|s_t)$$\n", 893 | "\n", 894 | "This is just the chain rule of probabilities used to break down the trajectory into its individual steps." 895 | ], 896 | "metadata": { 897 | "id": "7j_rlM4KJpf1" 898 | } 899 | }, 900 | { 901 | "cell_type": "markdown", 902 | "source": [ 903 | "## The RL Objective\n", 904 | "\n", 905 | "In reinforcement learning, we seek to find policy parameters $\\theta$ with maximal task performance. We do so by optimising the RL Objective - a measure of the policy's average performance. It is the expected cumulative reward over trajectories from the policy. Formally:\n", 906 | "\n", 907 | "$$J(\\theta)=\\mathbb E_{\\tau\\sim p_\\theta(\\tau)}[r(\\tau)]$$\n", 908 | "\n", 909 | "In other words, this objective quantifies how well a policy performs across all possible trajectories it might generate, weighted by how likely each trajectory is to occur." 910 | ], 911 | "metadata": { 912 | "id": "6pXbe0_PbcKP" 913 | } 914 | }, 915 | { 916 | "cell_type": "markdown", 917 | "source": [ 918 | "# Vanilla Policy Gradient Algorithm\n", 919 | "\n", 920 | "In this section, we will introduce and implement the simplest algorithm to optimise the RL objective. Up to this point, we have considered the policy as an arbitrary parameterised function but in deep reinforcement learning, it is specifically a neural network. As with most deep learning, the tool of choice to optimise model parameters to maximise an objective is gradient ascent. Therefore, we need to find the gradient of $J(\\theta)$.\n", 921 | "\n", 922 | "$$\\nabla_\\theta J(\\theta)= \\nabla_\\theta\\mathbb E_{\\tau\\sim p_\\theta(\\tau)}[r(\\tau)]$$\n", 923 | "\n", 924 | "Express the expectation as an integral:\n", 925 | "\n", 926 | "$$\n", 927 | "= \\nabla_\\theta \\int p_\\theta(\\tau)r(\\tau) \\rm d\\tau\n", 928 | "$$\n", 929 | "\n", 930 | "Take gradient inside the integral:\n", 931 | "$$\n", 932 | "= \\int\\nabla_\\theta p_\\theta(\\tau)r(\\tau) \\rm d\\tau\n", 933 | "$$\n", 934 | "\n", 935 | "Since the agent doesn't know the transition probabilities, $\\nabla_\\theta p_\\theta(\\tau)$ cannot be found directly. However, as we will see, $\\nabla_\\theta \\log p_\\theta(\\tau)$ can be computed without the transition probabilities. The next step, uses a convenient identity derived from the derivative of a logarithm:\n", 936 | "\n", 937 | "$$\n", 938 | "\\nabla_\\theta p_\\theta(\\tau)\n", 939 | "= p_\\theta(\\tau)\\frac{\\nabla_\\theta p_\\theta(\\tau)}{p_\\theta(\\tau)}\n", 940 | "= p_\\theta(\\tau)\\nabla_\\theta \\log p_\\theta(\\tau)\n", 941 | "$$\n", 942 | "\n", 943 | "Substituting this in gives:\n", 944 | "$$\n", 945 | "\\nabla_\\theta J(\\theta)\n", 946 | "= \\int p_\\theta(\\tau)\\nabla_\\theta \\log p_\\theta(\\tau)r(\\tau) \\rm d\\tau\n", 947 | "$$\n", 948 | "\n", 949 | "Which is in fact just an expectation over trajectories:\n", 950 | "$$ = \\mathbb E_{\\tau\\sim p_\\theta(\\tau)}[\\nabla_\\theta \\log p_\\theta(\\tau)r(\\tau)]$$\n", 951 | "\n", 952 | "\n", 953 | "Now we expand the gradient part to show that it doesn't require the transition probabilities:\n", 954 | "$$\n", 955 | "\\nabla_\\theta \\log p_\\theta(\\tau)\n", 956 | "= \\nabla_\\theta \\log p(s_0)\\pi_\\theta(a_0|s_0)\\prod_{t=1}^{T-1} p(s_t|s_{t-1}, a_{t-1})\\pi_\\theta(a_t|s_t)\n", 957 | "\\\\=\\nabla_\\theta \\big(\\log p(s_0) + \\log \\pi_\\theta(a_0|s_0) + \\sum_{t=1}^{T-1} \\log p(s_t|s_{t-1}, a_{t-1}) + \\log \\pi_\\theta(a_t|s_t)\\big)\n", 958 | "$$\n", 959 | "\n", 960 | "The $p$ terms don't depend on $\\theta$ so their gradients are zero.\n", 961 | "$$\n", 962 | "=\\nabla_\\theta \\big(\\log \\pi_\\theta(a_0|s_0) + \\sum_{t=1}^{T-1} \\log \\pi_\\theta(a_t|s_t)\\big)\n", 963 | "\\\\\n", 964 | "=\\nabla_\\theta \\sum_{t=0}^{T-1} \\log \\pi_\\theta(a_t|s_t)\n", 965 | "$$\n", 966 | "\n", 967 | "Pluging this back into the expectation and expanding the cumulative reward, we get:\n", 968 | "$$\\nabla_\\theta J(\\theta) = \\mathbb E_{\\tau\\sim p_\\theta(\\tau)}\\left[\\left(\\sum_{t=0}^{T-1} \\nabla_\\theta \\log \\pi_\\theta(a_{t}|s_{t})\\right)\\left(\\sum_{t=0}^{T-1} r(s_{t}, a_{t})\\right)\\right]$$\n", 969 | "\n", 970 | "Since we can't enumerate all possible trajectories, we cannot compute this expectation exactly. However, we can use a Monte Carlo approximation by sampling $N$ trajectories:\n", 971 | "\n", 972 | "$$\n", 973 | "\\approx \\frac{1}{N}\\sum_{i=1}^N \\left(\\sum_{t=0}^{T-1} \\nabla_\\theta \\log \\pi_\\theta(a_{it}|s_{it})\\right)\\left(\\sum_{t=0}^{T-1} r(s_{it}, a_{it})\\right)\n", 974 | "$$\n", 975 | "\n", 976 | "Where $i$ is the index of the trajectory and $t$ is the timestep within each trajectory.\n", 977 | "\n", 978 | "Finally, distributing the cumulative reward across the gradient sum gives:\n", 979 | "\n", 980 | "$$\n", 981 | "= \\frac{1}{N}\\sum_{i=1}^N \\sum_{t=0}^{T-1} \\left[\\nabla_\\theta\\log \\pi_\\theta(a_{it}|s_{it}) \\cdot \\left(\\sum_{k=0}^{T-1} r(s_{ik}, a_{ik})\\right)\\right]\n", 982 | "$$\n", 983 | "\n", 984 | "\n", 985 | "This can be implemented as a weighted log likelihood loss where the log likelihood of a state-action pair is weighted by the cumulative reward of the trajectory to which it belongs. We call this a surrogate loss, since it serves as a substitute objective whose gradient matches that of our true objective $J(\\theta)$.\n", 986 | "\n", 987 | "$$\n", 988 | "L_{\\text{surrogate}}(\\theta) = \\frac{1}{N}\\sum_{i=1}^N \\sum_{t=0}^{T-1} \\log \\pi_\\theta(a_{it}|s_{it}) \\cdot \\left(\\sum_{k=0}^{T-1} r(s_{ik}, a_{ik})\\right)\n", 989 | "$$\n", 990 | "\n", 991 | "Unlike in supervised learning, where all examples are good examples and have weight 1, in the surrogate loss, the cumulative reward measures how good the policy decisions in that trajectory are. The rigorous derivation was essential to ensure that gradient ascent on this loss genuinely optimises our original objective, but this interpretation also provides a reassuring intuition for why it works. Indeed, when we used the convenient identity to remove the dependence on the transition distributions, we changed the objective from $J(\\theta)$ to the surrogate loss. While these objectives are not equivalent, they have identical gradients, so gradient ascent on the surrogate loss will also maximise $J(\\theta)$.\n", 992 | "\n", 993 | "\n", 994 | "## Algorithm\n", 995 | "\n", 996 | "Having derived the gradient, we can now formalise the policy gradient algorithm:\n", 997 | "\n", 998 | "1. Sample $N$ trajectories $\\{\\tau^i\\}$ by running the policy\n", 999 | "2. Calculate the gradient\n", 1000 | "$\n", 1001 | "\\nabla_\\theta J(\\theta) \\approx \\sum_i \\sum_t \\left[\\nabla_\\theta\\log \\pi_\\theta(a_{it}|s_{it}) \\cdot \\left(\\sum_k r(s_{ik}, a_{ik})\\right)\\right]\n", 1002 | "$\n", 1003 | "3. Take a gradient ascent step $\\theta \\leftarrow \\theta + \\alpha\\nabla_\\theta J(\\theta)$\n", 1004 | "\n", 1005 | " where $\\alpha$ is the learning rate\n", 1006 | "\n", 1007 | "Repeat steps 1 to 3.\n", 1008 | "\n", 1009 | "We omit the constant $\\frac{1}{N}$ in the gradient since it will just scale the learning rate." 1010 | ], 1011 | "metadata": { 1012 | "id": "mViQAwQ37OP3" 1013 | } 1014 | }, 1015 | { 1016 | "cell_type": "markdown", 1017 | "source": [ 1018 | "## Implemenation\n", 1019 | "\n", 1020 | "Let's work through the algorithm step-by-step." 1021 | ], 1022 | "metadata": { 1023 | "id": "_sqm3XssigUn" 1024 | } 1025 | }, 1026 | { 1027 | "cell_type": "markdown", 1028 | "source": [ 1029 | "### Sampling Trajectories\n", 1030 | "\n", 1031 | "The first step is to sample N trajectories. We can do this with a simple for loop and our sample_trajectory function." 1032 | ], 1033 | "metadata": { 1034 | "id": "wYskxcKIO0Xx" 1035 | } 1036 | }, 1037 | { 1038 | "cell_type": "code", 1039 | "source": [ 1040 | "def sample_N_trajectories(environment, agent, N) -> List[Trajectory]:\n", 1041 | " trajectories = []\n", 1042 | "\n", 1043 | " for _ in range(N):\n", 1044 | " trajectory = sample_trajectory(environment, agent)\n", 1045 | " trajectories.append(trajectory)\n", 1046 | "\n", 1047 | " return trajectories" 1048 | ], 1049 | "metadata": { 1050 | "id": "RwasMDTbJlPR" 1051 | }, 1052 | "execution_count": null, 1053 | "outputs": [] 1054 | }, 1055 | { 1056 | "cell_type": "markdown", 1057 | "source": [ 1058 | "In environments with terminal states, we may find that early in training, the trajectories are much shorter than later on. This happens with Cart Pole. Therefore, with this sampling function the gradient steps early in training are based on fewer experiences than those later on. This leads to less stable gradient updates when the agent is learning most rapidly. Another approach is to fix the number of timesteps/experiences we sample per gradient update. As in supervised learning, this number is called the batch size.\n", 1059 | "\n", 1060 | "To avoid truncating any trajectories, we fix a minimum batch size and sample just enough trajectories to fill it. This ensures we always have sufficient data for stable gradient estimates while respecting episode boundaries. Here's the updated implementation:" 1061 | ], 1062 | "metadata": { 1063 | "id": "KptZ4rE1oNee" 1064 | } 1065 | }, 1066 | { 1067 | "cell_type": "code", 1068 | "source": [ 1069 | "def sample_trajectories(environment, agent, batch_size) -> List[Trajectory]:\n", 1070 | " trajectories = []\n", 1071 | "\n", 1072 | " timesteps_so_far = 0\n", 1073 | " while timesteps_so_far < batch_size:\n", 1074 | " trajectory = sample_trajectory(environment, agent)\n", 1075 | " trajectories.append(trajectory)\n", 1076 | " timesteps_so_far += trajectory.length\n", 1077 | "\n", 1078 | " return trajectories" 1079 | ], 1080 | "metadata": { 1081 | "id": "J_qY63hkWVpb" 1082 | }, 1083 | "execution_count": null, 1084 | "outputs": [] 1085 | }, 1086 | { 1087 | "cell_type": "markdown", 1088 | "source": [ 1089 | "Test it!" 1090 | ], 1091 | "metadata": { 1092 | "id": "f-O0YFsFOiuV" 1093 | } 1094 | }, 1095 | { 1096 | "cell_type": "code", 1097 | "source": [ 1098 | "trajectories = sample_trajectories(cartpole_env, random_agent, 10)\n", 1099 | "print(trajectories)" 1100 | ], 1101 | "metadata": { 1102 | "id": "U7zBQ0DhMsIr" 1103 | }, 1104 | "execution_count": null, 1105 | "outputs": [] 1106 | }, 1107 | { 1108 | "cell_type": "markdown", 1109 | "source": [ 1110 | "### From Trajectories to Experiences\n", 1111 | "\n", 1112 | "Previously, our ExperienceData class stored rewards for individual state-action pairs. However, our theoretical derivation shows that in the policy gradient calculation, these individual rewards aren't directly used. Instead, what matters is the weight applied to each log probability.\n", 1113 | "To better align with this understanding, let's update our ExperienceData class:" 1114 | ], 1115 | "metadata": { 1116 | "id": "QxpK7cmDOsUN" 1117 | } 1118 | }, 1119 | { 1120 | "cell_type": "code", 1121 | "source": [ 1122 | "@dataclass\n", 1123 | "class ExperienceData:\n", 1124 | " observations: torch.Tensor\n", 1125 | " actions: torch.Tensor\n", 1126 | " weights: torch.Tensor" 1127 | ], 1128 | "metadata": { 1129 | "id": "tZfYs4TTZr0-" 1130 | }, 1131 | "execution_count": null, 1132 | "outputs": [] 1133 | }, 1134 | { 1135 | "cell_type": "markdown", 1136 | "source": [ 1137 | "We can now write a function which takes a list of sampled trajectories and returns an experience data object to update the policy with. We abstract the weights calculation using a function `calculate_weights` that we pass in." 1138 | ], 1139 | "metadata": { 1140 | "id": "NFcaphuLgtJ0" 1141 | } 1142 | }, 1143 | { 1144 | "cell_type": "code", 1145 | "source": [ 1146 | "def compile_trajectories(trajectories: list[Trajectory], calculate_weights: Callable[[torch.Tensor], torch.Tensor]) -> ExperienceData:\n", 1147 | "\n", 1148 | " # Use torch.cat to directly concatenate tensors from all trajectories\n", 1149 | " observations = torch.cat([t.observations for t in trajectories], dim=0)\n", 1150 | " actions = torch.cat([t.actions for t in trajectories], dim=0)\n", 1151 | "\n", 1152 | " # Calculate weights for each trajectory from its rewards\n", 1153 | " weights = torch.cat([calculate_weights(t.rewards) for t in trajectories], dim=0)\n", 1154 | "\n", 1155 | " return ExperienceData(observations, actions, weights)" 1156 | ], 1157 | "metadata": { 1158 | "id": "9zsVzdGl9E_q" 1159 | }, 1160 | "execution_count": null, 1161 | "outputs": [] 1162 | }, 1163 | { 1164 | "cell_type": "markdown", 1165 | "source": [ 1166 | "Let's implement the particular function we use to calculate the weights for a trajectory. This takes the sum of rewards across the trajectory and broadcasts it to the length of the trajectory such that there is one weight for each experience." 1167 | ], 1168 | "metadata": { 1169 | "id": "9voL9xHSNmWU" 1170 | } 1171 | }, 1172 | { 1173 | "cell_type": "code", 1174 | "source": [ 1175 | "def sum_of_rewards(rewards: torch.Tensor) -> torch.Tensor:\n", 1176 | " total_reward = rewards.sum()\n", 1177 | " return torch.ones_like(rewards) * total_reward" 1178 | ], 1179 | "metadata": { 1180 | "id": "C1TUJZEi9GYm" 1181 | }, 1182 | "execution_count": null, 1183 | "outputs": [] 1184 | }, 1185 | { 1186 | "cell_type": "markdown", 1187 | "source": [ 1188 | "Test it!" 1189 | ], 1190 | "metadata": { 1191 | "id": "D95EHSWHMkLY" 1192 | } 1193 | }, 1194 | { 1195 | "cell_type": "code", 1196 | "source": [ 1197 | "discrete_experience_data = compile_trajectories(trajectories, sum_of_rewards)\n", 1198 | "print(discrete_experience_data)" 1199 | ], 1200 | "metadata": { 1201 | "id": "sAGybikczk0j" 1202 | }, 1203 | "execution_count": null, 1204 | "outputs": [] 1205 | }, 1206 | { 1207 | "cell_type": "markdown", 1208 | "source": [ 1209 | "### Policy\n", 1210 | "\n", 1211 | "#### Network Architecture\n", 1212 | "\n", 1213 | "Our policy will use a fully connected MLP which takes an observation as input and outputs a distribution over actions. The policy handles both discrete and continuous action spaces using PyTorch's built-in distribution classes, which support sampling and differentiation.\n", 1214 | "\n", 1215 | "For discrete action spaces, we use a `Categorical` distribution parameterized by a vector of logits. The MLP outputs this vector with one logit for each possible action. When we sample from this distribution, we get a single integer action.\n", 1216 | "\n", 1217 | "For continuous action spaces, we use a `MultivariateNormal` distribution parameterized by a mean vector and covariance matrix. The MLP outputs the mean vector directly, with its size matching the action space dimension. The covariance matrix is diagonal, with each diagonal element being the variance for a corresponding action dimension. We learn these variances as separate parameters outside the MLP. This choice simplifies things by removing state-dependence from the variances. When we sample from this distribution, we get a continuous vector action.\n", 1218 | "\n", 1219 | "Note that even though we will only demonstrate the algorithm on a discrete environement, we implement the code for continuous environments too for completeness.\n", 1220 | "\n", 1221 | "To simplify construction of the MLP, we use the helper function `build_mlp` which returns a fully connected MLP with specified by an `MLPConfig`.\n", 1222 | "\n", 1223 | "```\n", 1224 | "@dataclass\n", 1225 | "class MLPConfig:\n", 1226 | " input_size: int\n", 1227 | " output_size: int\n", 1228 | " n_layers: int = 2\n", 1229 | " size: int = 64\n", 1230 | " activation: Activation = 'tanh'\n", 1231 | " output_activation: Activation = 'identity'\n", 1232 | "```\n", 1233 | "\n", 1234 | "#### Sampling Actions\n", 1235 | "\n", 1236 | "Because we're using PyTorch distributions, sampling an action given an observation is as easy as running a forward pass and sampling the resulting distribution. It is worth noting that we do this with `@torch.no_grad()` since, in general, we only need the gradients when performing an update step.\n", 1237 | "\n", 1238 | "#### Update\n", 1239 | "\n", 1240 | "The `update` method follows the classic PyTorch gradient descent structure:\n", 1241 | "1. Zero Gradients\n", 1242 | "2. Forward Pass\n", 1243 | "3. Compute Loss\n", 1244 | "4. Backward Pass (compute gradients)\n", 1245 | "5. Optimiser Step\n", 1246 | "\n", 1247 | "We perform the forward pass again here except with gradients on. Another possible implementation would be to perform only a single forward pass, when sampling an action. However, this would require remembering the `log_prob` throughout the pipeline, which is less clear.\n", 1248 | "\n", 1249 | "Note that since PyTorch implements gradient *decent*, not gradient *ascent*, we negate the loss.\n", 1250 | "\n", 1251 | "#### Optimizer\n", 1252 | "\n", 1253 | "We use the Adam optimiser for more stable training; however, you can think of this as just a gradient decent optimiser which implements the gradient step." 1254 | ], 1255 | "metadata": { 1256 | "id": "9uVBF0pqO9f6" 1257 | } 1258 | }, 1259 | { 1260 | "cell_type": "code", 1261 | "source": [ 1262 | "class PGPolicy(Policy, nn.Module):\n", 1263 | " def __init__(\n", 1264 | " self,\n", 1265 | " mlp_config: MLPConfig,\n", 1266 | " discrete: bool,\n", 1267 | " learning_rate: float,\n", 1268 | " ):\n", 1269 | " # Call both parent initializers\n", 1270 | " Policy.__init__(self)\n", 1271 | " nn.Module.__init__(self)\n", 1272 | "\n", 1273 | " # Define the network\n", 1274 | " if discrete:\n", 1275 | " self.logits_net = build_mlp(mlp_config)\n", 1276 | " parameters = self.logits_net.parameters()\n", 1277 | " else:\n", 1278 | " self.mean_net = build_mlp(mlp_config)\n", 1279 | " # For continuous policies, logvar needs to match the action dimension (output_size)\n", 1280 | "\n", 1281 | " # We use log variances as parameters for numerical stability\n", 1282 | " # This ensures variances remain positive when transformed via exp()\n", 1283 | " # and prevents collapse to extremely small values during training\n", 1284 | " self.logvar = nn.Parameter(\n", 1285 | " torch.zeros(mlp_config.output_size, dtype=dtype)\n", 1286 | " )\n", 1287 | " parameters = itertools.chain([self.logvar], self.mean_net.parameters())\n", 1288 | "\n", 1289 | " self.discrete = discrete\n", 1290 | "\n", 1291 | " # Initialise optimizer\n", 1292 | " self.optimizer = optim.Adam(\n", 1293 | " parameters,\n", 1294 | " learning_rate,\n", 1295 | " )\n", 1296 | "\n", 1297 | " # Move entire model to device\n", 1298 | " self.to(device)\n", 1299 | "\n", 1300 | "\n", 1301 | " def forward(self, observation: torch.FloatTensor) -> distributions.Distribution:\n", 1302 | " \"\"\"\n", 1303 | " This function defines the forward pass of the network.\n", 1304 | " \"\"\"\n", 1305 | " if self.discrete:\n", 1306 | " logits = self.logits_net(observation)\n", 1307 | " return distributions.Categorical(logits=logits)\n", 1308 | " else:\n", 1309 | " mean = self.mean_net(observation)\n", 1310 | " covariance_matrix = torch.diag(torch.exp(self.logvar))\n", 1311 | " return distributions.MultivariateNormal(mean, covariance_matrix)\n", 1312 | "\n", 1313 | "\n", 1314 | " @torch.no_grad()\n", 1315 | " def get_action(self, observation: np.ndarray) -> np.ndarray:\n", 1316 | " \"\"\"Takes a single observation (as a numpy array) and returns a single action (as a numpy array).\"\"\"\n", 1317 | " observation = torch.tensor(observation, dtype=dtype, device=device)\n", 1318 | " distribution = self.forward(observation)\n", 1319 | " action = distribution.sample()\n", 1320 | " action = action.to('cpu').detach().numpy()\n", 1321 | " return action\n", 1322 | "\n", 1323 | "\n", 1324 | " def update(self, experiences: ExperienceData) -> float:\n", 1325 | "\n", 1326 | " self.optimizer.zero_grad()\n", 1327 | "\n", 1328 | " distribution = self.forward(experiences.observations)\n", 1329 | " log_probs = distribution.log_prob(experiences.actions)\n", 1330 | " surrogate_loss = -(log_probs * experiences.weights).mean()\n", 1331 | "\n", 1332 | " surrogate_loss.backward()\n", 1333 | " self.optimizer.step()\n", 1334 | "\n", 1335 | " return surrogate_loss.item()\n" 1336 | ], 1337 | "metadata": { 1338 | "id": "KhpCsHeMbV7m" 1339 | }, 1340 | "execution_count": null, 1341 | "outputs": [] 1342 | }, 1343 | { 1344 | "cell_type": "markdown", 1345 | "source": [ 1346 | "### Agent" 1347 | ], 1348 | "metadata": { 1349 | "id": "4oA-6e31PEWM" 1350 | } 1351 | }, 1352 | { 1353 | "cell_type": "markdown", 1354 | "source": [ 1355 | "In the vanilla policy gradients algorithm, the agent is just a thin wrapper around the policy, which makes it very simple to implement." 1356 | ], 1357 | "metadata": { 1358 | "id": "JN0_ukinWqmi" 1359 | } 1360 | }, 1361 | { 1362 | "cell_type": "code", 1363 | "source": [ 1364 | "class PGAgent(Agent):\n", 1365 | " def __init__(self, policy_args):\n", 1366 | " self._policy = PGPolicy(**policy_args)\n", 1367 | "\n", 1368 | " @property\n", 1369 | " def policy(self) -> Policy:\n", 1370 | " return self._policy\n", 1371 | "\n", 1372 | " def update(self, experiences: ExperienceData) -> float:\n", 1373 | " return self._policy.update(experiences)" 1374 | ], 1375 | "metadata": { 1376 | "id": "Icq6AHhxIA31" 1377 | }, 1378 | "execution_count": null, 1379 | "outputs": [] 1380 | }, 1381 | { 1382 | "cell_type": "code", 1383 | "source": [ 1384 | "discrete = isinstance(cartpole_env.action_space, gym.spaces.Discrete)\n", 1385 | "ac_dim = cartpole_env.action_space.n if discrete else cartpole_env.action_space.shape[0]\n", 1386 | "ob_dim = cartpole_env.observation_space.shape[0]\n", 1387 | "\n", 1388 | "mlp_config = MLPConfig(\n", 1389 | " input_size=ob_dim,\n", 1390 | " output_size=ac_dim,\n", 1391 | " n_layers=2,\n", 1392 | " size=64,\n", 1393 | " activation='tanh',\n", 1394 | " output_activation='identity',\n", 1395 | ")\n", 1396 | "\n", 1397 | "policy_args = {\n", 1398 | " \"mlp_config\": mlp_config,\n", 1399 | " \"discrete\": discrete,\n", 1400 | " \"learning_rate\": 5e-3,\n", 1401 | "}\n", 1402 | "pgagent = PGAgent(policy_args)" 1403 | ], 1404 | "metadata": { 1405 | "id": "z3hbmOgJzNMZ" 1406 | }, 1407 | "execution_count": null, 1408 | "outputs": [] 1409 | }, 1410 | { 1411 | "cell_type": "markdown", 1412 | "source": [ 1413 | "### Training Loop\n", 1414 | "\n", 1415 | "Finally, we can put this all together into a training loop." 1416 | ], 1417 | "metadata": { 1418 | "id": "Gd6LKHHxPHw2" 1419 | } 1420 | }, 1421 | { 1422 | "cell_type": "code", 1423 | "source": [ 1424 | "def train_agent_simple(agent, environment, batch_size, num_updates):\n", 1425 | "\n", 1426 | " for step in range(num_updates):\n", 1427 | " trajectories = sample_trajectories(environment, agent, batch_size)\n", 1428 | " experience_data = compile_trajectories(trajectories, sum_of_rewards)\n", 1429 | " loss = agent.update(experience_data)\n", 1430 | "\n", 1431 | " environment.close()\n", 1432 | " return agent" 1433 | ], 1434 | "metadata": { 1435 | "id": "YVXK-_uWPZr8" 1436 | }, 1437 | "execution_count": null, 1438 | "outputs": [] 1439 | }, 1440 | { 1441 | "cell_type": "markdown", 1442 | "source": [ 1443 | "While this perfectly implements the algorithm, it would be useful to track how the training has gone and plot the metrics. This `Tracker` class lets us do just that. Don't worry too much about the implementation." 1444 | ], 1445 | "metadata": { 1446 | "id": "0F4e5B7VPjXz" 1447 | } 1448 | }, 1449 | { 1450 | "cell_type": "code", 1451 | "source": [ 1452 | "class Tracker:\n", 1453 | " def __init__(self):\n", 1454 | " self.metrics = {}\n", 1455 | "\n", 1456 | " def log(self, metrics_dict, step=None):\n", 1457 | " \"\"\"Log metrics at a specific step\"\"\"\n", 1458 | " for key, value in metrics_dict.items():\n", 1459 | " if key not in self.metrics:\n", 1460 | " self.metrics[key] = []\n", 1461 | " self.metrics[key].append(value)\n", 1462 | "\n", 1463 | " def plot(self, metric_name, y_max=500):\n", 1464 | " \"\"\"Plot a single metric\"\"\"\n", 1465 | " if metric_name in self.metrics:\n", 1466 | " plt.figure(figsize=(8, 4))\n", 1467 | " plt.plot(self.metrics[metric_name])\n", 1468 | " plt.title(metric_name.capitalize())\n", 1469 | " plt.xlabel('Step')\n", 1470 | " plt.ylabel(metric_name)\n", 1471 | " plt.ylim(top=y_max)\n", 1472 | " plt.grid(True)\n", 1473 | " plt.show()\n", 1474 | "\n", 1475 | " def plot_all(self, y_max=500):\n", 1476 | " \"\"\"Plot all metrics\"\"\"\n", 1477 | " for metric_name in self.metrics:\n", 1478 | " self.plot(metric_name, y_max=y_max)" 1479 | ], 1480 | "metadata": { 1481 | "id": "O_3vKL0-EpAf" 1482 | }, 1483 | "execution_count": null, 1484 | "outputs": [] 1485 | }, 1486 | { 1487 | "cell_type": "markdown", 1488 | "source": [ 1489 | "Add a tracker to the train agent and plot the metrics at the end of training." 1490 | ], 1491 | "metadata": { 1492 | "id": "ceYT8azIJMSj" 1493 | } 1494 | }, 1495 | { 1496 | "cell_type": "code", 1497 | "source": [ 1498 | "def train_agent(agent, environment, batch_size, num_updates):\n", 1499 | "\n", 1500 | " tracker = Tracker()\n", 1501 | "\n", 1502 | " # tqdm renders a progress bar during training\n", 1503 | " for step in tqdm(range(num_updates)):\n", 1504 | "\n", 1505 | " trajectories = sample_trajectories(environment, agent, batch_size)\n", 1506 | " experience_data = compile_trajectories(trajectories, sum_of_rewards)\n", 1507 | " surrogate_loss = agent.update(experience_data)\n", 1508 | "\n", 1509 | " average_cumulative_reward = np.mean([t.total_reward for t in trajectories])\n", 1510 | " tracker.log({\n", 1511 | " \"Surrogate Loss\": surrogate_loss,\n", 1512 | " \"Average Cumulative Reward\": average_cumulative_reward,\n", 1513 | " })\n", 1514 | "\n", 1515 | " tracker.plot_all()\n", 1516 | "\n", 1517 | " environment.close()\n", 1518 | " return agent" 1519 | ], 1520 | "metadata": { 1521 | "id": "uy0SUWz1_8lX" 1522 | }, 1523 | "execution_count": null, 1524 | "outputs": [] 1525 | }, 1526 | { 1527 | "cell_type": "markdown", 1528 | "source": [ 1529 | "### Seeds\n", 1530 | "\n", 1531 | "In reinforcement learning, many components involve randomness: environment transitions, reward functions, neural network initialisation, and action sampling from policy distributions. This randomness can significantly impact training outcomes. By setting specific random seeds, we ensure that these random processes produce the same sequence of values each time, making experiments reproducible.\n", 1532 | "\n", 1533 | "The `set_seed` function I've provided sets seeds for both PyTorch (controlling network initialisation and policy sampling) and NumPy (which Gymnasium uses internally). Additionally, we must set the seed when resetting each environment with `env.reset(seed=seed)`. Once seeded, both the neural networks and environment will behave deterministically unless explicitly reseeded.\n", 1534 | "\n", 1535 | "To make sure you have a successful first run, I have pre-selected `seed=2` since it performed well in my tests.\n", 1536 | "However, when evaluating algorithms, we should try multiple seeds (typically 5-10) to account for the variance in performance that different random initialisations can produce. When comparing different algorithms or hyperparameters, use the same set of seeds across all variations to ensure fair comparisons." 1537 | ], 1538 | "metadata": { 1539 | "id": "5vw_5kKc3-Ef" 1540 | } 1541 | }, 1542 | { 1543 | "cell_type": "code", 1544 | "source": [ 1545 | "seed = 2\n", 1546 | "set_seed(seed)\n", 1547 | "cartpole_env.reset(seed=seed)" 1548 | ], 1549 | "metadata": { 1550 | "id": "3VfBno2MdRB9" 1551 | }, 1552 | "execution_count": null, 1553 | "outputs": [] 1554 | }, 1555 | { 1556 | "cell_type": "markdown", 1557 | "source": [ 1558 | "### Train\n", 1559 | "\n", 1560 | "With the seed set, it's time to create a new agent and train it." 1561 | ], 1562 | "metadata": { 1563 | "id": "SjMFPYtT58aS" 1564 | } 1565 | }, 1566 | { 1567 | "cell_type": "code", 1568 | "source": [ 1569 | "pgagent = PGAgent(policy_args)\n", 1570 | "train_agent(pgagent, cartpole_env, batch_size=1000, num_updates=100)" 1571 | ], 1572 | "metadata": { 1573 | "id": "w8CMc3lF57zS" 1574 | }, 1575 | "execution_count": null, 1576 | "outputs": [] 1577 | }, 1578 | { 1579 | "cell_type": "markdown", 1580 | "source": [ 1581 | "### Surrogate Loss vs. Cumulative Reward\n", 1582 | "\n", 1583 | "The two plots show our surrogate loss and the actual RL objective (cumulative reward) across the training run. Notice how these metrics differ: while their general shapes are similar, their scales and detailed patterns vary significantly. This illustrates our earlier theoretical point—these two objectives aren't identical, but they share the same gradients.\n", 1584 | "\n", 1585 | "A crucial difference from supervised learning is that the surrogate loss doesn't consistently decrease, despite our use of gradient descent. As the policy improves, the distribution of experiences shifts, changing the experience weights. However, the optimizer treats these weights as constants that don't depend on the policy parameters. As a result, the loss isn't as reliable a metric in RL as it would be in supervised learning.\n", 1586 | "\n", 1587 | "For this reason, we'll focus primarily on tracking and plotting the average cumulative reward in subsequent sections." 1588 | ], 1589 | "metadata": { 1590 | "id": "8HXR08XEKagY" 1591 | } 1592 | }, 1593 | { 1594 | "cell_type": "code", 1595 | "source": [ 1596 | "def train_agent(agent, environment, batch_size, num_updates):\n", 1597 | "\n", 1598 | " tracker = Tracker()\n", 1599 | "\n", 1600 | " # tqdm renders a progress bar during training\n", 1601 | " for step in tqdm(range(num_updates)):\n", 1602 | "\n", 1603 | " trajectories = sample_trajectories(environment, agent, batch_size)\n", 1604 | " experience_data = compile_trajectories(trajectories, sum_of_rewards)\n", 1605 | " loss = agent.update(experience_data)\n", 1606 | "\n", 1607 | " average_cumulative_reward = np.mean([t.total_reward for t in trajectories])\n", 1608 | " tracker.log({\"Average Cumulative Reward\": average_cumulative_reward})\n", 1609 | "\n", 1610 | " tracker.plot_all()\n", 1611 | "\n", 1612 | " environment.close()\n", 1613 | " return agent" 1614 | ], 1615 | "metadata": { 1616 | "id": "_wSC2tirJ85S" 1617 | }, 1618 | "execution_count": null, 1619 | "outputs": [] 1620 | }, 1621 | { 1622 | "cell_type": "markdown", 1623 | "source": [ 1624 | "# Reducing Variance\n", 1625 | "\n", 1626 | "Above, we used a seed which I knew to perform well. However, in practice we don't have this luxury. To demonstrate how performance varies with randomness, we'll train 5 agents using different seeds [0,1,2,3,4]. To isolate the effect of randomness during training specifically, we'll fix the policy initialisation seed (42) so all agents start with identical parameters, while varying only the training seed between runs.\n" 1627 | ], 1628 | "metadata": { 1629 | "id": "U03fA3n-7XwH" 1630 | } 1631 | }, 1632 | { 1633 | "cell_type": "code", 1634 | "source": [ 1635 | "# This experiment takes about 4mins to run.\n", 1636 | "\n", 1637 | "for seed in range(5):\n", 1638 | " # Constant policy initialisation\n", 1639 | " set_seed(42)\n", 1640 | " pgagent = PGAgent(policy_args)\n", 1641 | "\n", 1642 | " # Seeded training\n", 1643 | " print(f\"Seed {seed}\")\n", 1644 | " set_seed(seed)\n", 1645 | " cartpole_env.reset(seed=seed)\n", 1646 | " train_agent(pgagent, cartpole_env, batch_size=1000, num_updates=100)" 1647 | ], 1648 | "metadata": { 1649 | "id": "r8PYhccir8nH" 1650 | }, 1651 | "execution_count": null, 1652 | "outputs": [] 1653 | }, 1654 | { 1655 | "cell_type": "markdown", 1656 | "source": [ 1657 | "As you can see, both the maximum cumulative reward and time taken to reach it vary significantly between seeds. This high variance stems from the Monte Carlo approximation used in the surrogate loss function. When only a few trajectories are sampled per gradient update, different training runs naturally encounter varied experiences, leading to inconsistent learning outcomes. The most straightforward approach to decrease this variance is simply to sample more trajectories per update. However, doing so is often expensive—especially in real-world environments where each sample requires physical action. As a result, the field has prioritized developing methods that make better use of limited samples.\n", 1658 | "\n", 1659 | "In the last part of this notebook, we'll introduce a simple but very powerful variance reduction technique called reward-to-go, which has become a standard component in virtually all modern RL algorithms." 1660 | ], 1661 | "metadata": { 1662 | "id": "t-S6t9XHxfPL" 1663 | } 1664 | }, 1665 | { 1666 | "cell_type": "markdown", 1667 | "source": [ 1668 | "## Reward-to-go\n", 1669 | "\n", 1670 | "In the current loss, the log probability of the policy action at timestep $t$ is weighted by the sum-of-rewards across *all* timesteps. However, this approach includes rewards that occurred before taking action at time\n", 1671 | "$t$, which couldn't possibly have been influenced by this action due to causality - actions can only affect future rewards, not past ones. If we want the weighting to reflect how good *that specific* action was, we should only consider rewards that could have resulted from it. Therefore, the weight should be only the sum of rewards from timestep $t$ onward, which is known as the reward-to-go.\n", 1672 | "\n" 1673 | ], 1674 | "metadata": { 1675 | "id": "UWtpBL2NvizS" 1676 | } 1677 | }, 1678 | { 1679 | "cell_type": "markdown", 1680 | "source": [ 1681 | "### Unbiased Proof\n", 1682 | "\n", 1683 | "So far, we've given an intuitive explanation of why reward-to-go is a good idea, but to rigorously justify it, we need to prove that using reward-to-go weights in the surrogate loss will still maximise the RL objective. To do this, we need to show that the expectations of the gradient estimators are equal. Formally,\n", 1684 | "\n", 1685 | "$$\n", 1686 | "\\mathbb E_{\\tau\\sim p_\\theta}\\left[ \\sum_{t=0}^{T-1} \\nabla_\\theta \\log \\pi_\\theta(a_{t}|s_{t}) \\sum_{t'=0}^{T-1} r(s_{t'}, a_{t'})\\right] =\n", 1687 | "\\mathbb E_{\\tau\\sim p_\\theta}\\left[ \\sum_{t=0}^{T-1} \\nabla_\\theta \\log \\pi_\\theta(a_{t}|s_{t}) \\sum_{t'=t}^{T-1} r(s_{t'}, a_{t'})\\right]\n", 1688 | "$$\n", 1689 | "\n", 1690 | "To prove this equality, we need to show that the contribution of terms where $t'< t$ to the left hand side expectation is zero.\n", 1691 | "\n", 1692 | "First, rearrange the left-hand side to get:\n", 1693 | "$$\n", 1694 | "\\mathbb E_{\\tau\\sim p_\\theta}\\left[\n", 1695 | " \\sum_{t=0}^{T-1} \\sum_{t'=0}^{T-1}\n", 1696 | " \\nabla_\\theta \\log \\pi_\\theta(a_{t}|s_{t}) r(s_{t'}, a_{t'})\n", 1697 | "\\right]\n", 1698 | "$$\n", 1699 | "\n", 1700 | "Apply linearity of expectations:\n", 1701 | "$$\n", 1702 | "=\\sum_{t=0}^{T-1} \\sum_{t'=0}^{T-1}\n", 1703 | "\\mathbb E_{s_t,a_t,s_{t'},a_{t'}}\\left[\n", 1704 | " \\nabla_\\theta \\log \\pi_\\theta(a_{t}|s_{t}) r(s_{t'}, a_{t'})\n", 1705 | "\\right]\n", 1706 | "$$\n", 1707 | "Note that we've changed to writing the expectations over only the relevant state-action pairs.\n", 1708 | "\n", 1709 | "For all summands, we decompose the expectation by conditioning on $s_t$:\n", 1710 | "$$\n", 1711 | "\\mathbb E_{s_t,a_t,s_{t'},a_{t'}}\\Big[\\nabla_\\theta \\log \\pi_\\theta(a_{t}|s_{t}) r(s_{t'}, a_{t'})\\Big]\n", 1712 | "=\n", 1713 | "\\mathbb E_{s_t}\\bigg[\\mathbb E_{a_t,s_{t'},a_{t'}|s_t}\\Big[\\nabla_\\theta \\log \\pi_\\theta(a_{t}|s_{t}) r(s_{t'}, a_{t'})\\big|s_t\\Big]\\bigg]\n", 1714 | "$$\n", 1715 | "\n", 1716 | "Now consider only the summands for which $t'< t$. By causality, $r(s_{t'},a_{t'})$ is independent of $a_t$, so we can factor out the expectation. This step only holds for $t'< t$!\n", 1717 | "\n", 1718 | "$$\n", 1719 | "=\n", 1720 | "\\mathbb E_{s_t}\\bigg[\n", 1721 | "\\mathbb E_{a_t|s_t}\\Big[\\nabla_\\theta \\log \\pi_\\theta(a_{t}|s_{t}) \\big|s_t\\Big]\n", 1722 | "\\mathbb E_{s_{t'},a_{t'}|s_t}\\Big[r(s_{t'}, a_{t'})\\big|s_t\\Big]\n", 1723 | "\\bigg]\n", 1724 | "$$\n", 1725 | "\n", 1726 | "The next step is to prove that this expectation is zero. To do this, we use a reverse of the log-probability gradient trick we used to derive the original surrogate loss.\n", 1727 | "$$\n", 1728 | "\\begin{align}\n", 1729 | " \\mathbb E_{a_t|s_t}[\\nabla_\\theta \\log \\pi_\\theta(a_t|s_t)|s_t]\n", 1730 | "&= \\sum_{a_t} \\pi_\\theta(a_t|s_t)\\nabla_\\theta \\log \\pi_\\theta(a_t|s_t)\\\\\n", 1731 | "&= \\sum_{a_t} \\pi_\\theta(a_t|s_t)\\frac{\\nabla_\\theta \\pi_\\theta(a_t|s_t)}{\\pi_\\theta(a_t|s_t)}\\\\\n", 1732 | "&= \\nabla_\\theta \\sum_{a_t} \\pi_\\theta(a_t|s_t)\\\\\n", 1733 | "\\end{align}\n", 1734 | "$$\n", 1735 | "\n", 1736 | "This sum is equal to $1$, since it is the total probability. Therefore, the gradient is $0$.\n", 1737 | "\n", 1738 | "This shows that the terms where $t' < t$ contribute zero to the gradient expectation. Since the terms where $t' \\geq t$ are identical in both estimators, the reward-to-go estimator has the same expected gradient as the sum-of-rewards estimator.\n", 1739 | "\n" 1740 | ], 1741 | "metadata": { 1742 | "id": "9eAkiuBhxpu4" 1743 | } 1744 | }, 1745 | { 1746 | "cell_type": "markdown", 1747 | "source": [ 1748 | "### Reduced Variance\n", 1749 | "\n", 1750 | "In the previous section, we proved that the terms removed in the reward-to-go estimator (where $t' < t$) have an expectation of zero. However, these terms still have non-zero variance. Since these past rewards cannot be causally influenced by the current action, they contribute only statistical noise to the gradient estimate without providing useful signal. By removing these zero-expectation terms, the reward-to-go estimator maintains the same expected gradient value while eliminating a source of variance." 1751 | ], 1752 | "metadata": { 1753 | "id": "GaDWc-L99OqE" 1754 | } 1755 | }, 1756 | { 1757 | "cell_type": "markdown", 1758 | "source": [ 1759 | "### Implementation\n", 1760 | "\n", 1761 | "To exemplify the improvements in performance we get from using reward-to-go, let's implement it and run the same seeded experiments as before.\n", 1762 | "\n", 1763 | "First we need a `calculate_weights` function. We can implement this manually using the simple dynamic programming technique of working from the last timestep backwards." 1764 | ], 1765 | "metadata": { 1766 | "id": "uQYwNcvZxrLZ" 1767 | } 1768 | }, 1769 | { 1770 | "cell_type": "code", 1771 | "source": [ 1772 | "def reward_to_go(rewards: torch.Tensor) -> torch.Tensor:\n", 1773 | " T = len(rewards)\n", 1774 | " rtg = torch.zeros_like(rewards)\n", 1775 | "\n", 1776 | " # Base case:\n", 1777 | " rtg[T-1] = rewards[T-1]\n", 1778 | "\n", 1779 | " # Work backwards:\n", 1780 | " for t in range(T-2, -1, -1):\n", 1781 | " rtg[t] = rewards[t] + rtg[t+1]\n", 1782 | "\n", 1783 | " return rtg" 1784 | ], 1785 | "metadata": { 1786 | "id": "oYD86zEJGFTp" 1787 | }, 1788 | "execution_count": null, 1789 | "outputs": [] 1790 | }, 1791 | { 1792 | "cell_type": "markdown", 1793 | "source": [ 1794 | "To implement reward-to-go using Pytorch functions to avoid the overhead of python for loops, we can use `cumsum`. Unfortunately, `cumsum` accumulates from the first element rather than the last. Therefore we need to flip the array before and afterwards." 1795 | ], 1796 | "metadata": { 1797 | "id": "8tBiBw3IGG6O" 1798 | } 1799 | }, 1800 | { 1801 | "cell_type": "code", 1802 | "source": [ 1803 | "def reward_to_go(rewards: torch.Tensor) -> torch.Tensor:\n", 1804 | " rewards_reversed = torch.flip(rewards, [0])\n", 1805 | " rtg_reversed = torch.cumsum(rewards_reversed, 0)\n", 1806 | " rtg = torch.flip(rtg_reversed, [0])\n", 1807 | " return rtg" 1808 | ], 1809 | "metadata": { 1810 | "id": "5KoDRmLHC7hs" 1811 | }, 1812 | "execution_count": null, 1813 | "outputs": [] 1814 | }, 1815 | { 1816 | "cell_type": "markdown", 1817 | "source": [ 1818 | "Let's test this implementation with a simple example:" 1819 | ], 1820 | "metadata": { 1821 | "id": "gmQtIOA3JzUL" 1822 | } 1823 | }, 1824 | { 1825 | "cell_type": "code", 1826 | "source": [ 1827 | "rewards = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])\n", 1828 | "reward_to_go(rewards)\n", 1829 | "# Expected result for reward-to-go:\n", 1830 | "# [1+2+3+4+5, 2+3+4+5, 3+4+5, 4+5, 5]\n", 1831 | "# [15.0, 14.0, 12.0, 9.0, 5.0]" 1832 | ], 1833 | "metadata": { 1834 | "id": "zDZS_t2yJ6pw" 1835 | }, 1836 | "execution_count": null, 1837 | "outputs": [] 1838 | }, 1839 | { 1840 | "cell_type": "markdown", 1841 | "source": [ 1842 | "Now update the training function:" 1843 | ], 1844 | "metadata": { 1845 | "id": "Etx7gdw4Kxy4" 1846 | } 1847 | }, 1848 | { 1849 | "cell_type": "code", 1850 | "source": [ 1851 | "def train_agent_rtg(agent, environment, batch_size, num_updates):\n", 1852 | "\n", 1853 | " tracker = Tracker()\n", 1854 | "\n", 1855 | " for step in tqdm(range(num_updates)):\n", 1856 | "\n", 1857 | " trajectories = sample_trajectories(environment, agent, batch_size)\n", 1858 | " experience_data = compile_trajectories(trajectories, reward_to_go) # REWARD TO GO\n", 1859 | " loss = agent.update(experience_data)\n", 1860 | "\n", 1861 | " average_cumulative_reward = np.mean([t.total_reward for t in trajectories])\n", 1862 | " tracker.log({\"Average Cumulative Reward\": average_cumulative_reward})\n", 1863 | "\n", 1864 | " tracker.plot_all()\n", 1865 | "\n", 1866 | " environment.close()\n", 1867 | " return agent" 1868 | ], 1869 | "metadata": { 1870 | "id": "JVS7TsjCK_Sf" 1871 | }, 1872 | "execution_count": null, 1873 | "outputs": [] 1874 | }, 1875 | { 1876 | "cell_type": "markdown", 1877 | "source": [ 1878 | "Test for variance:" 1879 | ], 1880 | "metadata": { 1881 | "id": "B7uGOMwpLqJG" 1882 | } 1883 | }, 1884 | { 1885 | "cell_type": "code", 1886 | "source": [ 1887 | "for seed in range(5):\n", 1888 | " # Constant policy initialisation\n", 1889 | " set_seed(42)\n", 1890 | " pgagent = PGAgent(policy_args)\n", 1891 | "\n", 1892 | " # Seeded training\n", 1893 | " print(f\"Seed {seed}\")\n", 1894 | " set_seed(seed)\n", 1895 | " cartpole_env.reset(seed=seed)\n", 1896 | " train_agent_rtg(pgagent, cartpole_env, batch_size=1000, num_updates=100)" 1897 | ], 1898 | "metadata": { 1899 | "id": "8tIEvPs4LdpW" 1900 | }, 1901 | "execution_count": null, 1902 | "outputs": [] 1903 | }, 1904 | { 1905 | "cell_type": "markdown", 1906 | "source": [ 1907 | "Woah! That's a huge improvement in performance, well worth the long proof. Now all of the seeds we tested reach 500 cumulative reward, which means we've effectively solved the CartPole environment (reaching its maximum possible reward). It's worth noting that CartPole is considered one of the simpler RL benchmark environments. For more complex environments with higher-dimensional spaces, longer time horizons, and sparse rewards, even reward-to-go—while powerful—isn't enough on its own. We'll need additional variance reduction techniques to achieve stable learning in these scenarios." 1908 | ], 1909 | "metadata": { 1910 | "id": "Iv1SMS5YOLU-" 1911 | } 1912 | }, 1913 | { 1914 | "cell_type": "markdown", 1915 | "source": [ 1916 | "# Next Steps\n", 1917 | "\n", 1918 | "Well done for getting through this introduction to deep RL! In the next parts of the course, we will explore more advanced variance reduction methods which we'll use to improve the policy gradients algorithm:\n", 1919 | "- Discounting\n", 1920 | "- Baselines\n", 1921 | "- Actor-Critic algorithms\n", 1922 | "- Generalised Advantage Estimation (GAE)\n", 1923 | "\n", 1924 | "These techniques form the backbone of modern deep RL algorithms and will enable us to tackle increasingly complex environments beyond CartPole.\n", 1925 | "\n", 1926 | "**Part 2: Discounting**\n", 1927 | "\n", 1928 | "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1UULTQYnymQOpa7nuaw6mDXnvWRV9R_2y?usp=sharing)" 1929 | ], 1930 | "metadata": { 1931 | "id": "N-mMhnwiySZl" 1932 | } 1933 | }, 1934 | { 1935 | "cell_type": "markdown", 1936 | "source": [ 1937 | "# Feedback and Contact Information\n", 1938 | "\n", 1939 | "If you found this notebook useful, please consider:\n", 1940 | "- Sharing it with friends and colleagues who might benefit\n", 1941 | "- Starring the repository on GitHub\n", 1942 | "- Connecting with me on LinkedIn\n", 1943 | "\n", 1944 | "[![GitHub](https://img.shields.io/badge/Star_Repository-gold?logo=github&logoColor=white)](https://github.com/xycoord/deep-rl-course)\n", 1945 | "[![LinkedIn](https://img.shields.io/badge/LinkedIn-Connect-blue)](https://linkedin.com/in/logan-thomson-01a4942ab)\n", 1946 | "\n", 1947 | "If you found anything confusing, I'd be happy to answer any questions:\n", 1948 | "\n", 1949 | "[![GitHub Issues](https://img.shields.io/badge/Submit_Issue-red?logo=github)](https://github.com/xycoord/deep-rl-course/issues)\n", 1950 | "[![GitHub Discussions](https://img.shields.io/badge/Join_Discussion-green?logo=github)](https://github.com/xycoord/deep-rl-course/discussions)\n", 1951 | "\n", 1952 | "I'm continuously working to improve this material and extend it with additional topics. Your feedback helps make this resource better for everyone!" 1953 | ], 1954 | "metadata": { 1955 | "id": "Hu4jPwTx2Otb" 1956 | } 1957 | } 1958 | ] 1959 | } --------------------------------------------------------------------------------