├── .gitignore ├── banner.png ├── README.md ├── LICENSE ├── labs ├── lab_one.ipynb └── lab_two.ipynb └── solutions └── lab_one_complete.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints/ 2 | .DS_store 3 | -------------------------------------------------------------------------------- /banner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eje24/iap-diffusion-labs/HEAD/banner.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![banner](banner.png) 2 | # Welcome! 3 | This repository contains the labs for [**6.S184/6.S975: Generative AI with Stochastic Differential Equations**](https://diffusion.csail.mit.edu), as taught at MIT over IAP 2025. For questions and concerns, please email us at `ezraerives@gmail.com` and `phold@mit.edu`. Enjoy :) 4 | 5 | ### Changelog 6 | - 1/22/25: Lab 1: Fix several typos in lab one. 7 | - 1/27/25: Labs 1 + 2: Add clarification regarding truncated PDFs when submitting w/ Colab + Chrome. 8 | - 1/27/25: Lab 2: Clarify numerical issues when sampling conditional SDE for large noise levels. Clarify differences between score parameterizations. 9 | - 3/9/25: Lab 1: resolve timestep bug in Langevin dynamics (thanks Ádám!) 10 | - 3/12/25: Lab 3: Fix guidance embedding dimension bug (thanks Roger!) 11 | - 3/13/25: Lab 3 solutions: Fix bug when sampling conditioning variable (thanks Zewen!) 
12 | 13 | ### Acknowledgements 14 | We would like to thank the following individuals for their invaluable feedback and suggestions: 15 | - Cameron Diao 16 | - Tally Portnoi 17 | - Andi Qu 18 | - Ádám Burián 19 | - Roger Trullo 20 | - Zewen Yang 21 | 22 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Ezra Erives and Peter Holderrieth 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /labs/lab_one.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "93d49b57-85cb-407c-a2e3-c79f682a3dc1", 6 | "metadata": {}, 7 | "source": [ 8 | "# Lab One: Simulating ODEs and SDEs" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "76e99dea-acb4-4fad-a347-1b99567c3188", 14 | "metadata": {}, 15 | "source": [ 16 | "Welcome to lab one! In this lab, we will provide an intuitive and hands-on walk-through of ODEs and SDEs. If you find any mistakes, or have any other feedback, please feel free to email us at `erives@mit.edu` and `phold@mit.edu`. Enjoy!" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "id": "fe737f00-0dff-4ffe-b40a-8187f8c615b8", 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "from abc import ABC, abstractmethod\n", 27 | "from typing import Optional\n", 28 | "import math\n", 29 | "\n", 30 | "import numpy as np\n", 31 | "from matplotlib import pyplot as plt\n", 32 | "from matplotlib.axes._axes import Axes\n", 33 | "import torch\n", 34 | "import torch.distributions as D\n", 35 | "from torch.func import vmap, jacrev\n", 36 | "from tqdm import tqdm\n", 37 | "import seaborn as sns\n", 38 | "\n", 39 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')" 40 | ] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "id": "2dd11ba4-80a5-4354-913b-ae17827f59e1", 45 | "metadata": {}, 46 | "source": [ 47 | "# Part 0: Introduction" 48 | ] 49 | }, 50 | { 51 | "cell_type": "markdown", 52 | "id": "43565591-329c-43ba-85fe-620b5eb2219b", 53 | "metadata": {}, 54 | "source": [ 55 | "First, let us make precise the central objects of study: *ordinary differential equations* (ODEs) and *stochastic differential equations* (SDEs). 
Both ODEs and SDEs are built on time-dependent *vector fields*, which we recall from lecture as being functions $u$ defined by $$u:\\mathbb{R}^d\\times [0,1]\\to \\mathbb{R}^d,\\quad (x,t)\\mapsto u_t(x)$$\n", 56 | "That is, $u_t(x)$ takes in *where in space we are* ($x$) and *where in time we are* ($t$), and spits out the *direction we should be going in* $u_t(x)$. An ODE is then given by $$d X_t = u_t(X_t)dt, \\quad \\quad X_0 = x_0.$$\n", 57 | "Similarly, an SDE is of the form $$d X_t = u_t(X_t)dt + \\sigma_t d W_t, \\quad \\quad X_0 = x_0,$$\n", 58 | "which can be thought of as starting with an ODE given by $u_t$, and adding noise via the *Brownian motion* $(W_t)_{0 \\le t \\le 1}$. The deterministic term $u_t(x)$ is referred to as the *drift coefficient*, and the amount of noise added, $\\sigma_t$, is referred to as the *diffusion coefficient*." 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "id": "738edb40-d990-4ba6-8749-261c66ebe51e", 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "class ODE(ABC):\n", 69 | " @abstractmethod\n", 70 | " def drift_coefficient(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n", 71 | " \"\"\"\n", 72 | " Returns the drift coefficient of the ODE.\n", 73 | " Args:\n", 74 | " - xt: state at time t, shape (batch_size, dim)\n", 75 | " - t: time, shape ()\n", 76 | " Returns:\n", 77 | " - drift_coefficient: shape (batch_size, dim)\n", 78 | " \"\"\"\n", 79 | " pass\n", 80 | "\n", 81 | "class SDE(ABC):\n", 82 | " @abstractmethod\n", 83 | " def drift_coefficient(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n", 84 | " \"\"\"\n", 85 | " Returns the drift coefficient of the SDE.\n", 86 | " Args:\n", 87 | " - xt: state at time t, shape (batch_size, dim)\n", 88 | " - t: time, shape ()\n", 89 | " Returns:\n", 90 | " - drift_coefficient: shape (batch_size, dim)\n", 91 | " \"\"\"\n", 92 | " pass\n", 93 | "\n", 94 | " @abstractmethod\n", 95 | " def diffusion_coefficient(self, xt: 
torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n", 96 | " \"\"\"\n", 97 | " Returns the diffusion coefficient of the SDE.\n", 98 | " Args:\n", 99 | " - xt: state at time t, shape (batch_size, dim)\n", 100 | " - t: time, shape ()\n", 101 | " Returns:\n", 102 | " - diffusion_coefficient: shape (batch_size, dim)\n", 103 | " \"\"\"\n", 104 | " pass" 105 | ] 106 | }, 107 | { 108 | "cell_type": "markdown", 109 | "id": "81f02840-39af-41c0-b105-fdcc3aa3d5d2", 110 | "metadata": {}, 111 | "source": [ 112 | "**Note**: One might consider an ODE to be a special case of an SDE with zero diffusion coefficient. This intuition is valid; however, for pedagogical (and performance) reasons, we will treat them separately for the scope of this lab." 113 | ] 114 | }, 115 | { 116 | "cell_type": "markdown", 117 | "id": "8a220737-5fbb-4405-852d-a9610d27e345", 118 | "metadata": {}, 119 | "source": [ 120 | "# Part 1: Numerical Methods for Simulating ODEs and SDEs\n", 121 | "We may think of ODEs and SDEs as describing the motion of a particle through space. Intuitively, the ODE above says \"start at $X_0=x_0$, and move so that your instantaneous velocity is given by $u_t(X_t)$\". Similarly, the SDE says \"start at $X_0=x_0$, and move so that your instantaneous velocity is given by $u_t(X_t)$ plus a little bit of random noise scaled by $\\sigma_t$\". Formally, the trajectories traced out by these intuitive descriptions are said to be *solutions* to the ODEs and SDEs, respectively. Numerical methods for computing these solutions are all essentially based on *simulating*, or *integrating*, the ODE or SDE. \n", 122 | "\n", 123 | "In this section we'll implement the *Euler* and *Euler-Maruyama* numerical simulation schemes for integrating ODEs and SDEs, respectively. 
Recall from lecture that the Euler simulation scheme corresponds to the discretization\n", 124 | "$$d X_t = u_t(X_t) dt \\quad \\quad \\rightarrow \\quad \\quad X_{t + h} = X_t + hu_t(X_t),$$\n", 125 | "where $h = \\Delta t$ is the *step size*. Similarly, the Euler-Maruyama scheme corresponds to the discretization \n", 126 | "$$ dX_t = u_t(X_t) dt + \\sigma_t d W_t \\quad \\quad \\rightarrow \\quad \\quad X_{t + h} = X_t + hu_t(X_t) + \\sqrt{h} \\sigma_t z_t, \\quad z_t \\sim N(0,I_d).$$ \n", 127 | "Let's implement these!" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": null, 133 | "id": "98202014-cf8e-4d50-a819-61d937ebea20", 134 | "metadata": {}, 135 | "outputs": [], 136 | "source": [ 137 | "class Simulator(ABC):\n", 138 | " @abstractmethod\n", 139 | " def step(self, xt: torch.Tensor, t: torch.Tensor, dt: torch.Tensor):\n", 140 | " \"\"\"\n", 141 | " Takes one simulation step\n", 142 | " Args:\n", 143 | " - xt: state at time t, shape (batch_size, dim)\n", 144 | " - t: time, shape ()\n", 145 | " - dt: step size, shape ()\n", 146 | " Returns:\n", 147 | " - nxt: state at time t + dt\n", 148 | " \"\"\"\n", 149 | " pass\n", 150 | "\n", 151 | " @torch.no_grad()\n", 152 | " def simulate(self, x: torch.Tensor, ts: torch.Tensor):\n", 153 | " \"\"\"\n", 154 | " Simulates using the discretization given by ts\n", 155 | " Args:\n", 156 | " - x: initial state at time ts[0], shape (batch_size, dim)\n", 157 | " - ts: timesteps, shape (nts,)\n", 158 | " Returns:\n", 159 | " - x_final: final state at time ts[-1], shape (batch_size, dim)\n", 160 | " \"\"\"\n", 161 | " for t_idx in range(len(ts) - 1):\n", 162 | " t = ts[t_idx]\n", 163 | " h = ts[t_idx + 1] - ts[t_idx]\n", 164 | " x = self.step(x, t, h)\n", 165 | " return x\n", 166 | "\n", 167 | " @torch.no_grad()\n", 168 | " def simulate_with_trajectory(self, x: torch.Tensor, ts: torch.Tensor):\n", 169 | " \"\"\"\n", 170 | " Simulates using the discretization given by ts\n", 171 | " Args:\n", 172 | " - 
x: initial state at time ts[0], shape (batch_size, dim)\n", 173 | " - ts: timesteps, shape (num_timesteps,)\n", 174 | " Returns:\n", 175 | " - xs: trajectory of xts over ts, shape (batch_size, num_timesteps, dim)\n", 176 | " \"\"\"\n", 177 | " xs = [x.clone()]\n", 178 | " for t_idx in tqdm(range(len(ts) - 1)):\n", 179 | " t = ts[t_idx]\n", 180 | " h = ts[t_idx + 1] - ts[t_idx]\n", 181 | " x = self.step(x, t, h)\n", 182 | " xs.append(x.clone())\n", 183 | " return torch.stack(xs, dim=1)" 184 | ] 185 | }, 186 | { 187 | "cell_type": "markdown", 188 | "id": "27d9b6b5-e7eb-44a1-a180-aa55d16d8356", 189 | "metadata": {}, 190 | "source": [ 191 | "### Question 1.1: Implement EulerSimulator and EulerMaruyamaSimulator" 192 | ] 193 | }, 194 | { 195 | "cell_type": "markdown", 196 | "id": "c0e5b471-a228-4688-903f-7a0898d8b736", 197 | "metadata": {}, 198 | "source": [ 199 | "**Your job**: Fill in the `step` methods of `EulerSimulator` and `EulerMaruyamaSimulator`." 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": null, 205 | "id": "3326c993-b5ef-4b6c-97b5-a403a509b9ac", 206 | "metadata": {}, 207 | "outputs": [], 208 | "source": [ 209 | "class EulerSimulator(Simulator):\n", 210 | " def __init__(self, ode: ODE):\n", 211 | " self.ode = ode\n", 212 | " \n", 213 | " def step(self, xt: torch.Tensor, t: torch.Tensor, h: torch.Tensor):\n", 214 | " raise NotImplementedError(\"Fill me in for Question 1.1!\")" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": null, 220 | "id": "3151adb4-e013-4728-8c96-8bab07ee6262", 221 | "metadata": {}, 222 | "outputs": [], 223 | "source": [ 224 | "class EulerMaruyamaSimulator(Simulator):\n", 225 | " def __init__(self, sde: SDE):\n", 226 | " self.sde = sde\n", 227 | " \n", 228 | " def step(self, xt: torch.Tensor, t: torch.Tensor, h: torch.Tensor):\n", 229 | " raise NotImplementedError(\"Fill me in for Question 1.1!\")" 230 | ] 231 | }, 232 | { 233 | "cell_type": "markdown", 234 | "id": 
"fcb930b6-63a1-4cc2-8f60-d30de5d380ea", 235 | "metadata": {}, 236 | "source": [ 237 | "**Note:** When the diffusion coefficient is zero, the Euler and Euler-Maruyama simulation are equivalent! " 238 | ] 239 | }, 240 | { 241 | "cell_type": "markdown", 242 | "id": "4afdf59a-25e3-4537-a357-bbbb8ee5332b", 243 | "metadata": {}, 244 | "source": [ 245 | "# Part 2: Visualizing Solutions to SDEs\n", 246 | "Let's get a feel for what the solutions to these SDEs look like in practice (we'll get to ODEs later...). To do so, we we'll implement and visualize two special choices of SDEs from lecture: a (scaled) *Brownian motion*, and an *Ornstein-Uhlenbeck* (OU) process." 247 | ] 248 | }, 249 | { 250 | "cell_type": "markdown", 251 | "id": "84c480e6-bfcf-4105-b21a-002d3a110923", 252 | "metadata": {}, 253 | "source": [ 254 | "### Question 2.1: Implementing Brownian Motion\n", 255 | "First, recall that a Brownian motion is recovered (by definition) by setting $u_t = 0$ and $\\sigma_t = \\sigma$, viz.,\n", 256 | "$$ dX_t = \\sigma dW_t, \\quad \\quad X_0 = 0.$$" 257 | ] 258 | }, 259 | { 260 | "cell_type": "markdown", 261 | "id": "d0e83673-54dc-451e-8422-0110c462b295", 262 | "metadata": {}, 263 | "source": [ 264 | "**Your job**: Intuitively, what might be expect the trajectories of $X_t$ to look like when $\\sigma$ is very large? What about when $\\sigma$ is close to zero?\n", 265 | "\n", 266 | "**Your answer**:" 267 | ] 268 | }, 269 | { 270 | "cell_type": "markdown", 271 | "id": "62ba5c9e-f5cc-41a9-a850-43e46d79b3fb", 272 | "metadata": {}, 273 | "source": [ 274 | "**Your job**: Fill in the `drift_coefficient` and `difusion_coefficient` methods of the `BrownianMotion` class below." 
275 | ] 276 | }, 277 | { 278 | "cell_type": "code", 279 | "execution_count": null, 280 | "id": "2c372a79-585e-4bbb-a78e-3870dd5c458b", 281 | "metadata": {}, 282 | "outputs": [], 283 | "source": [ 284 | "class BrownianMotion(SDE):\n", 285 | " def __init__(self, sigma: float):\n", 286 | " self.sigma = sigma\n", 287 | " \n", 288 | " def drift_coefficient(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n", 289 | " \"\"\"\n", 290 | " Returns the drift coefficient of the ODE.\n", 291 | " Args:\n", 292 | " - xt: state at time t, shape (bs, dim)\n", 293 | " - t: time, shape ()\n", 294 | " Returns:\n", 295 | " - drift: shape (bs, dim)\n", 296 | " \"\"\"\n", 297 | " raise NotImplementedError(\"Fill me in for Question 2.1!\")\n", 298 | " \n", 299 | " def diffusion_coefficient(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n", 300 | " \"\"\"\n", 301 | " Returns the diffusion coefficient of the ODE.\n", 302 | " Args:\n", 303 | " - xt: state at time t, shape (bs, dim)\n", 304 | " - t: time, shape ()\n", 305 | " Returns:\n", 306 | " - diffusion: shape (bs, dim)\n", 307 | " \"\"\"\n", 308 | " raise NotImplementedError(\"Fill me in for Question 2.1!\")" 309 | ] 310 | }, 311 | { 312 | "cell_type": "markdown", 313 | "id": "79f2318e-ae09-426a-bf57-c44cd5074288", 314 | "metadata": {}, 315 | "source": [ 316 | "Now let's plot! We'll make use of the following utility function." 
317 | ] 318 | }, 319 | { 320 | "cell_type": "code", 321 | "execution_count": null, 322 | "id": "e577649b-2290-4308-9d5a-7ab614984b4c", 323 | "metadata": {}, 324 | "outputs": [], 325 | "source": [ 326 | "def plot_trajectories_1d(x0: torch.Tensor, simulator: Simulator, timesteps: torch.Tensor, ax: Optional[Axes] = None):\n", 327 | " \"\"\"\n", 328 | " Graphs the trajectories of a one-dimensional SDE with given initial values (x0) and simulation timesteps (timesteps).\n", 329 | " Args:\n", 330 | " - x0: initial state, shape (num_trajectories, 1)\n", 331 | " - simulator: Simulator object used to simulate\n", 332 | " - timesteps: timesteps to simulate along, shape (num_timesteps,)\n", 333 | " - ax: pyplot Axes object to plot on\n", 334 | " \"\"\"\n", 335 | " if ax is None:\n", 336 | " ax = plt.gca()\n", 337 | " trajectories = simulator.simulate_with_trajectory(x0, timesteps) # (num_trajectories, num_timesteps, ...)\n", 338 | " for trajectory_idx in range(trajectories.shape[0]):\n", 339 | " trajectory = trajectories[trajectory_idx, :, 0] # (num_timesteps,)\n", 340 | " ax.plot(timesteps.cpu(), trajectory.cpu())" 341 | ] 342 | }, 343 | { 344 | "cell_type": "code", 345 | "execution_count": null, 346 | "id": "c4245ac9-3092-4c5b-a149-75a909ceaa9e", 347 | "metadata": {}, 348 | "outputs": [], 349 | "source": [ 350 | "sigma = 1.0\n", 351 | "brownian_motion = BrownianMotion(sigma)\n", 352 | "simulator = EulerMaruyamaSimulator(sde=brownian_motion)\n", 353 | "x0 = torch.zeros(5,1).to(device) # Initial values - let's start at zero\n", 354 | "ts = torch.linspace(0.0,5.0,500).to(device) # simulation timesteps\n", 355 | "\n", 356 | "plt.figure(figsize=(8, 8))\n", 357 | "ax = plt.gca()\n", 358 | "ax.set_title(r'Trajectories of Brownian Motion with $\\sigma=$' + str(sigma), fontsize=18)\n", 359 | "ax.set_xlabel(r'Time ($t$)', fontsize=18)\n", 360 | "ax.set_ylabel(r'$X_t$', fontsize=18)\n", 361 | "plot_trajectories_1d(x0, simulator, ts, ax)\n", 362 | "plt.show()" 363 | ] 364 | }, 365 | { 366 | 
"cell_type": "markdown", 367 | "id": "e3a7bd5c-d558-45e0-b741-6de4d7003776", 368 | "metadata": {}, 369 | "source": [ 370 | "**Your job**: What happens when you vary the value of `sigma`?\n", 371 | "\n", 372 | "**Your answer**:" 373 | ] 374 | }, 375 | { 376 | "cell_type": "markdown", 377 | "id": "9b22c81b-7ce9-4b84-82ab-fac03f741b03", 378 | "metadata": {}, 379 | "source": [ 380 | "### Question 2.2: Implementing an Ornstein-Uhlenbeck Process\n", 381 | "An OU process is given by setting $u_t(X_t) = - \\theta X_t$ and $\\sigma_t = \\sigma$, viz.,\n", 382 | "$$ dX_t = -\\theta X_t\\, dt + \\sigma\\, dW_t, \\quad \\quad X_0 = x_0.$$" 383 | ] 384 | }, 385 | { 386 | "cell_type": "markdown", 387 | "id": "4b0cdb7d-f8bf-4826-9c32-8cb101376103", 388 | "metadata": {}, 389 | "source": [ 390 | "**Your job**: Intuitively, what would the trajectory of $X_t$ look like for a very small value of $\\theta$? What about a very large value of $\\theta$?\n", 391 | "\n", 392 | "**Your answer**:" 393 | ] 394 | }, 395 | { 396 | "cell_type": "markdown", 397 | "id": "12325951-709c-4486-9ea7-f4c22b3cc1ef", 398 | "metadata": {}, 399 | "source": [ 400 | "**Your job**: Fill in the `drift_coefficient` and `difusion_coefficient` methods of the `OUProcess` class below." 
401 | ] 402 | }, 403 | { 404 | "cell_type": "code", 405 | "execution_count": null, 406 | "id": "214d0e94-698c-4729-878e-3c2a9881dbe4", 407 | "metadata": {}, 408 | "outputs": [], 409 | "source": [ 410 | "class OUProcess(SDE):\n", 411 | " def __init__(self, theta: float, sigma: float):\n", 412 | " self.theta = theta\n", 413 | " self.sigma = sigma\n", 414 | " \n", 415 | " def drift_coefficient(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n", 416 | " \"\"\"\n", 417 | " Returns the drift coefficient of the ODE.\n", 418 | " Args:\n", 419 | " - xt: state at time t, shape (bs, dim)\n", 420 | " - t: time, shape ()\n", 421 | " Returns:\n", 422 | " - drift: shape (bs, dim)\n", 423 | " \"\"\"\n", 424 | " raise NotImplementedError(\"Fill me in for Question 2.2!\")\n", 425 | " \n", 426 | " def diffusion_coefficient(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n", 427 | " \"\"\"\n", 428 | " Returns the diffusion coefficient of the ODE.\n", 429 | " Args:\n", 430 | " - xt: state at time t, shape (bs, dim)\n", 431 | " - t: time, shape ()\n", 432 | " Returns:\n", 433 | " - diffusion: shape (bs, dim)\n", 434 | " \"\"\"\n", 435 | " raise NotImplementedError(\"Fill me in for Question 2.2!\")" 436 | ] 437 | }, 438 | { 439 | "cell_type": "code", 440 | "execution_count": null, 441 | "id": "bd7dc249-826e-4a81-91f7-fa8fb17278c4", 442 | "metadata": {}, 443 | "outputs": [], 444 | "source": [ 445 | "# Try comparing multiple choices side-by-side\n", 446 | "thetas_and_sigmas = [\n", 447 | " (0.25, 0.0),\n", 448 | " (0.25, 0.25),\n", 449 | " (0.25, 0.5),\n", 450 | " (0.25, 1.0),\n", 451 | "]\n", 452 | "simulation_time = 20.0\n", 453 | "\n", 454 | "num_plots = len(thetas_and_sigmas)\n", 455 | "fig, axes = plt.subplots(1, num_plots, figsize=(8 * num_plots, 7))\n", 456 | "\n", 457 | "for idx, (theta, sigma) in enumerate(thetas_and_sigmas):\n", 458 | " ou_process = OUProcess(theta, sigma)\n", 459 | " simulator = EulerMaruyamaSimulator(sde=ou_process)\n", 460 | " x0 = 
torch.linspace(-10.0,10.0,10).view(-1,1).to(device) # Initial values - spread evenly over [-10, 10]\n", 461 | " ts = torch.linspace(0.0,simulation_time,1000).to(device) # simulation timesteps\n", 462 | "\n", 463 | " ax = axes[idx]\n", 464 | " ax.set_title(f'Trajectories of OU Process with $\\\\sigma = ${sigma}, $\\\\theta = ${theta}', fontsize=15)\n", 465 | " ax.set_xlabel(r'Time ($t$)', fontsize=15)\n", 466 | " ax.set_ylabel(r'$X_t$', fontsize=15)\n", 467 | " plot_trajectories_1d(x0, simulator, ts, ax)\n", 468 | "plt.show()" 469 | ] 470 | }, 471 | { 472 | "cell_type": "markdown", 473 | "id": "76e0c7eb-0eb3-4eb3-a854-5cac18a85baa", 474 | "metadata": {}, 475 | "source": [ 476 | "**Your job**: What do you notice about the convergence of the solutions? Are they converging to a particular point? Or to a distribution? Your answer should be two *qualitative* sentences of the form: \"When ($\\theta$ or $\\sigma$) goes (up or down), we see...\".\n", 477 | "\n", 478 | "**Hint**: Pay close attention to the ratio $D \\triangleq \\frac{\\sigma^2}{2\\theta}$ (see the next few cells below!).\n", 479 | "\n", 480 | "**Your answer**:" 481 | ] 482 | }, 483 | { 484 | "cell_type": "code", 485 | "execution_count": null, 486 | "id": "31502fba-c582-4d67-8fdf-8d474604cd44", 487 | "metadata": {}, 488 | "outputs": [], 489 | "source": [ 490 | "def plot_scaled_trajectories_1d(x0: torch.Tensor, simulator: Simulator, timesteps: torch.Tensor, time_scale: float, label: str, ax: Optional[Axes] = None):\n", 491 | " \"\"\"\n", 492 | " Graphs the trajectories of a one-dimensional SDE with given initial values (x0) and simulation timesteps (timesteps).\n", 493 | " Args:\n", 494 | " - x0: initial state, shape (num_trajectories, 1)\n", 495 | " - simulator: Simulator object used to simulate\n", 496 | " - timesteps: timesteps to simulate along, shape (num_timesteps,)\n", 497 | " - time_scale: scalar by which to scale time\n", 498 | " - label: legend label for the plotted trajectories\n", 499 | " - ax: pyplot Axes object to plot on\n", 500 | 
\"\"\"\n", 501 | " if ax is None:\n", 502 | " ax = plt.gca()\n", 503 | " trajectories = simulator.simulate_with_trajectory(x0, timesteps) # (num_trajectories, num_timesteps, ...)\n", 504 | " for trajectory_idx in range(trajectories.shape[0]):\n", 505 | " trajectory = trajectories[trajectory_idx, :, 0] # (num_timesteps,)\n", 506 | " ax.plot(ts.cpu() * time_scale, trajectory.cpu(), label=label)" 507 | ] 508 | }, 509 | { 510 | "cell_type": "code", 511 | "execution_count": null, 512 | "id": "ad50c533-a8cf-4fc0-be0c-b5fd5c3cbc1c", 513 | "metadata": {}, 514 | "outputs": [], 515 | "source": [ 516 | "# Let's try rescaling with time\n", 517 | "sigmas = [1.0, 2.0, 10.0]\n", 518 | "ds = [0.25, 1.0, 4.0] # sigma**2 / 2t\n", 519 | "simulation_time = 10.0\n", 520 | "\n", 521 | "fig, axes = plt.subplots(len(ds), len(sigmas), figsize=(8 * len(sigmas), 8 * len(ds)))\n", 522 | "axes = axes.reshape((len(ds), len(sigmas)))\n", 523 | "for d_idx, d in enumerate(ds):\n", 524 | " for s_idx, sigma in enumerate(sigmas):\n", 525 | " theta = sigma**2 / 2 / d\n", 526 | " ou_process = OUProcess(theta, sigma)\n", 527 | " simulator = EulerMaruyamaSimulator(sde=ou_process)\n", 528 | " x0 = torch.linspace(-20.0,20.0,20).view(-1,1).to(device)\n", 529 | " time_scale = sigma**2\n", 530 | " ts = torch.linspace(0.0,simulation_time / time_scale,1000).to(device) # simulation timesteps\n", 531 | " ax = axes[d_idx, s_idx]\n", 532 | " plot_scaled_trajectories_1d(x0=x0, simulator=simulator, timesteps=ts, time_scale=time_scale, label=f'Sigma = {sigma}', ax=ax)\n", 533 | " ax.set_title(f'OU Trajectories with Sigma={sigma}, Theta={theta}, D={d}')\n", 534 | " ax.set_xlabel(f't / (sigma^2)')\n", 535 | " ax.set_ylabel('X_t')\n", 536 | "plt.show()" 537 | ] 538 | }, 539 | { 540 | "cell_type": "markdown", 541 | "id": "850111a6-30be-4265-b423-ed23671deaf0", 542 | "metadata": {}, 543 | "source": [ 544 | "**Your job**: What conclusion can we draw from the figure above? One qualitative sentence is fine. 
We'll revisit this in Section 3.2.\n", 545 | "\n", 546 | "**Your answer**:" 547 | ] 548 | }, 549 | { 550 | "cell_type": "markdown", 551 | "id": "d49bcbdc-237c-4639-a447-b0c13f575a8d", 552 | "metadata": {}, 553 | "source": [ 554 | "# Part 3: Transforming Distributions with SDEs\n", 555 | "In the previous section, we observed how individual *points* are transformed by an SDE. Ultimately, we are interested in understanding how *distributions* are transformed by an SDE (or an ODE...). After all, our goal is to design ODEs and SDEs which transform a noisy distribution (such as the Gaussian $N(0, I_d)$) into the data distribution $p_{\\text{data}}$ of interest. In this section, we will visualize how distributions are transformed by a very particular family of SDEs: *Langevin dynamics*.\n", 556 | "\n", 557 | "First, let's define some distributions to play around with. In practice, there are two qualities one might hope a distribution to have:\n", 558 | "1. The first quality is that one can measure the *density* of a distribution $p(x)$. This ensures that we can compute the gradient $\\nabla \\log p(x)$ of the log density. This quantity is known as the *score* of $p$, and paints a picture of the local geometry of the distribution. Using the score, we will construct and simulate the *Langevin dynamics*, a family of SDEs which \"drive\" samples toward the distribution $p$. In particular, the Langevin dynamics *preserve* the distribution $p(x)$. In Lecture 2, we will make this notion of driving more precise.\n", 559 | "2. The second quality is that we can draw samples from the distribution $p(x)$.\n", 560 | "For simple, toy distributions, such as Gaussians and simple mixture models, it is often true that both qualities are satisfied. For more complex choices of $p$, such as distributions over images, we can sample but cannot measure the density." 
561 | ] 562 | }, 563 | { 564 | "cell_type": "code", 565 | "execution_count": null, 566 | "id": "6e2b64f2-732a-4ea3-84a4-8f955ca64f7f", 567 | "metadata": {}, 568 | "outputs": [], 569 | "source": [ 570 | "class Density(ABC):\n", 571 | " \"\"\"\n", 572 | " Distribution with tractable density\n", 573 | " \"\"\"\n", 574 | " @abstractmethod\n", 575 | " def log_density(self, x: torch.Tensor) -> torch.Tensor:\n", 576 | " \"\"\"\n", 577 | " Returns the log density at x.\n", 578 | " Args:\n", 579 | " - x: shape (batch_size, dim)\n", 580 | " Returns:\n", 581 | " - log_density: shape (batch_size, 1)\n", 582 | " \"\"\"\n", 583 | " pass\n", 584 | "\n", 585 | " def score(self, x: torch.Tensor) -> torch.Tensor:\n", 586 | " \"\"\"\n", 587 | " Returns the score dx log density(x)\n", 588 | " Args:\n", 589 | " - x: (batch_size, dim)\n", 590 | " Returns:\n", 591 | " - score: (batch_size, dim)\n", 592 | " \"\"\"\n", 593 | " x = x.unsqueeze(1) # (batch_size, 1, ...)\n", 594 | " score = vmap(jacrev(self.log_density))(x) # (batch_size, 1, 1, 1, ...)\n", 595 | " return score.squeeze((1, 2, 3)) # (batch_size, ...)\n", 596 | "\n", 597 | "class Sampleable(ABC):\n", 598 | " \"\"\"\n", 599 | " Distribution which can be sampled from\n", 600 | " \"\"\"\n", 601 | " @abstractmethod\n", 602 | " def sample(self, num_samples: int) -> torch.Tensor:\n", 603 | " \"\"\"\n", 604 | " Returns samples from the distribution.\n", 605 | " Args:\n", 606 | " - num_samples: the desired number of samples\n", 607 | " Returns:\n", 608 | " - samples: shape (num_samples, dim)\n", 609 | " \"\"\"\n", 610 | " pass" 611 | ] 612 | }, 613 | { 614 | "cell_type": "code", 615 | "execution_count": null, 616 | "id": "3805b3f8-f0ab-4bb0-a41a-4d97c65e24e8", 617 | "metadata": {}, 618 | "outputs": [], 619 | "source": [ 620 | "# Several plotting utility functions\n", 621 | "def hist2d_sampleable(sampleable: Sampleable, num_samples: int, ax: Optional[Axes] = None, **kwargs):\n", 622 | " if ax is None:\n", 623 | " ax = plt.gca()\n", 624 | " samples 
= sampleable.sample(num_samples) # (ns, 2)\n", 625 | " ax.hist2d(samples[:,0].cpu(), samples[:,1].cpu(), **kwargs)\n", 626 | "\n", 627 | "def scatter_sampleable(sampleable: Sampleable, num_samples: int, ax: Optional[Axes] = None, **kwargs):\n", 628 | " if ax is None:\n", 629 | " ax = plt.gca()\n", 630 | " samples = sampleable.sample(num_samples) # (ns, 2)\n", 631 | " ax.scatter(samples[:,0].cpu(), samples[:,1].cpu(), **kwargs)\n", 632 | "\n", 633 | "def imshow_density(density: Density, bins: int, scale: float, ax: Optional[Axes] = None, **kwargs):\n", 634 | " if ax is None:\n", 635 | " ax = plt.gca()\n", 636 | " x = torch.linspace(-scale, scale, bins).to(device)\n", 637 | " y = torch.linspace(-scale, scale, bins).to(device)\n", 638 | " X, Y = torch.meshgrid(x, y)\n", 639 | " xy = torch.stack([X.reshape(-1), Y.reshape(-1)], dim=-1)\n", 640 | " density = density.log_density(xy).reshape(bins, bins).T\n", 641 | " im = ax.imshow(density.cpu(), extent=[-scale, scale, -scale, scale], origin='lower', **kwargs)\n", 642 | "\n", 643 | "def contour_density(density: Density, bins: int, scale: float, ax: Optional[Axes] = None, **kwargs):\n", 644 | " if ax is None:\n", 645 | " ax = plt.gca()\n", 646 | " x = torch.linspace(-scale, scale, bins).to(device)\n", 647 | " y = torch.linspace(-scale, scale, bins).to(device)\n", 648 | " X, Y = torch.meshgrid(x, y)\n", 649 | " xy = torch.stack([X.reshape(-1), Y.reshape(-1)], dim=-1)\n", 650 | " density = density.log_density(xy).reshape(bins, bins).T\n", 651 | " im = ax.contour(density.cpu(), extent=[-scale, scale, -scale, scale], origin='lower', **kwargs)" 652 | ] 653 | }, 654 | { 655 | "cell_type": "code", 656 | "execution_count": null, 657 | "id": "498eb6cb-1261-4cc1-b1d0-281e5f73d2cc", 658 | "metadata": {}, 659 | "outputs": [], 660 | "source": [ 661 | "class Gaussian(torch.nn.Module, Sampleable, Density):\n", 662 | " \"\"\"\n", 663 | " Two-dimensional Gaussian. Is a Density and a Sampleable. 
Wrapper around torch.distributions.MultivariateNormal\n", 664 | " \"\"\"\n", 665 | " def __init__(self, mean, cov):\n", 666 | " \"\"\"\n", 667 | " mean: shape (2,)\n", 668 | " cov: shape (2,2)\n", 669 | " \"\"\"\n", 670 | " super().__init__()\n", 671 | " self.register_buffer(\"mean\", mean)\n", 672 | " self.register_buffer(\"cov\", cov)\n", 673 | "\n", 674 | " @property\n", 675 | " def distribution(self):\n", 676 | " return D.MultivariateNormal(self.mean, self.cov, validate_args=False)\n", 677 | "\n", 678 | " def sample(self, num_samples) -> torch.Tensor:\n", 679 | " return self.distribution.sample((num_samples,))\n", 680 | "\n", 681 | " def log_density(self, x: torch.Tensor):\n", 682 | " return self.distribution.log_prob(x).view(-1, 1)\n", 683 | "\n", 684 | "class GaussianMixture(torch.nn.Module, Sampleable, Density):\n", 685 | " \"\"\"\n", 686 | " Two-dimensional Gaussian mixture model. Is a Density and a Sampleable. Wrapper around torch.distributions.MixtureSameFamily.\n", 687 | " \"\"\"\n", 688 | " def __init__(\n", 689 | " self,\n", 690 | " means: torch.Tensor, # nmodes x data_dim\n", 691 | " covs: torch.Tensor, # nmodes x data_dim x data_dim\n", 692 | " weights: torch.Tensor, # nmodes\n", 693 | " ):\n", 694 | " \"\"\"\n", 695 | " means: shape (nmodes, 2)\n", 696 | " covs: shape (nmodes, 2, 2)\n", 697 | " weights: shape (nmodes,)\n", 698 | " \"\"\"\n", 699 | " super().__init__()\n", 700 | " self.nmodes = means.shape[0]\n", 701 | " self.register_buffer(\"means\", means)\n", 702 | " self.register_buffer(\"covs\", covs)\n", 703 | " self.register_buffer(\"weights\", weights)\n", 704 | "\n", 705 | " @property\n", 706 | " def dim(self) -> int:\n", 707 | " return self.means.shape[1]\n", 708 | "\n", 709 | " @property\n", 710 | " def distribution(self):\n", 711 | " return D.MixtureSameFamily(\n", 712 | " mixture_distribution=D.Categorical(probs=self.weights, validate_args=False),\n", 713 | " component_distribution=D.MultivariateNormal(\n", 714 | " 
loc=self.means,\n", 715 | " covariance_matrix=self.covs,\n", 716 | " validate_args=False,\n", 717 | " ),\n", 718 | " validate_args=False,\n", 719 | " )\n", 720 | "\n", 721 | " def log_density(self, x: torch.Tensor) -> torch.Tensor:\n", 722 | " return self.distribution.log_prob(x).view(-1, 1)\n", 723 | "\n", 724 | " def sample(self, num_samples: int) -> torch.Tensor:\n", 725 | " return self.distribution.sample(torch.Size((num_samples,)))\n", 726 | "\n", 727 | " @classmethod\n", 728 | " def random_2D(\n", 729 | " cls, nmodes: int, std: float, scale: float = 10.0, seed = 0.0\n", 730 | " ) -> \"GaussianMixture\":\n", 731 | " torch.manual_seed(seed)\n", 732 | " means = (torch.rand(nmodes, 2) - 0.5) * scale\n", 733 | " covs = torch.diag_embed(torch.ones(nmodes, 2)) * std ** 2\n", 734 | " weights = torch.ones(nmodes)\n", 735 | " return cls(means, covs, weights)\n", 736 | "\n", 737 | " @classmethod\n", 738 | " def symmetric_2D(\n", 739 | " cls, nmodes: int, std: float, scale: float = 10.0,\n", 740 | " ) -> \"GaussianMixture\":\n", 741 | " angles = torch.linspace(0, 2 * np.pi, nmodes + 1)[:nmodes]\n", 742 | " means = torch.stack([torch.cos(angles), torch.sin(angles)], dim=1) * scale\n", 743 | " covs = torch.diag_embed(torch.ones(nmodes, 2) * std ** 2)\n", 744 | " weights = torch.ones(nmodes) / nmodes\n", 745 | " return cls(means, covs, weights)" 746 | ] 747 | }, 748 | { 749 | "cell_type": "code", 750 | "execution_count": null, 751 | "id": "36ce6533-2e56-42cb-98e7-958c45d583f8", 752 | "metadata": {}, 753 | "outputs": [], 754 | "source": [ 755 | "# Visualize densities\n", 756 | "densities = {\n", 757 | " \"Gaussian\": Gaussian(mean=torch.zeros(2), cov=10 * torch.eye(2)).to(device),\n", 758 | " \"Random Mixture\": GaussianMixture.random_2D(nmodes=5, std=1.0, scale=20.0, seed=3.0).to(device),\n", 759 | " \"Symmetric Mixture\": GaussianMixture.symmetric_2D(nmodes=5, std=1.0, scale=8.0).to(device),\n", 760 | "}\n", 761 | "\n", 762 | "fig, axes = plt.subplots(1,3, figsize=(18, 
6))\n", 763 | "bins = 100\n", 764 | "scale = 15\n", 765 | "for idx, (name, density) in enumerate(densities.items()):\n", 766 | " ax = axes[idx]\n", 767 | " ax.set_title(name)\n", 768 | " imshow_density(density, bins, scale, ax, vmin=-15, cmap=plt.get_cmap('Blues'))\n", 769 | " contour_density(density, bins, scale, ax, colors='grey', linestyles='solid', alpha=0.25, levels=20)\n", 770 | "plt.show()\n" 771 | ] 772 | }, 773 | { 774 | "cell_type": "markdown", 775 | "id": "5b51093e-b25a-4cdb-bc0d-ab8c1a3145b3", 776 | "metadata": {}, 777 | "source": [ 778 | "### Question 3.1: Implementing Langevin Dynamics" 779 | ] 780 | }, 781 | { 782 | "cell_type": "markdown", 783 | "id": "91888056-b900-401b-bebb-4e4bf5301afe", 784 | "metadata": {}, 785 | "source": [ 786 | "In this section, we'll simulate the (overdamped) Langevin dynamics $$dX_t = \\frac{1}{2} \\sigma^2\\nabla \\log p(X_t) dt + \\sigma dW_t.$$\n", 787 | "\n", 788 | "**Your job**: Fill in the `drift_coefficient` and `diffusion_coefficient` methods of the class `LangevinSDE` below.\n", 789 | "\n", 790 | "**Hint**: Use `Density.score` to access the score." 
791 | ] 792 | }, 793 | { 794 | "cell_type": "code", 795 | "execution_count": null, 796 | "id": "371f00c5-3030-4f6c-945f-4a7d3483ff56", 797 | "metadata": {}, 798 | "outputs": [], 799 | "source": [ 800 | "class LangevinSDE(SDE):\n", 801 | " def __init__(self, sigma: float, density: Density):\n", 802 | " self.sigma = sigma\n", 803 | " self.density = density\n", 804 | " \n", 805 | " def drift_coefficient(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n", 806 | " \"\"\"\n", 807 | " Returns the drift coefficient of the ODE.\n", 808 | " Args:\n", 809 | " - xt: state at time t, shape (bs, dim)\n", 810 | " - t: time, shape ()\n", 811 | " Returns:\n", 812 | " - drift: shape (bs, dim)\n", 813 | " \"\"\"\n", 814 | " raise NotImplementedError(\"Fill me in for Question 3.1!\")\n", 815 | " \n", 816 | " def diffusion_coefficient(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n", 817 | " \"\"\"\n", 818 | " Returns the diffusion coefficient of the ODE.\n", 819 | " Args:\n", 820 | " - xt: state at time t, shape (bs, dim)\n", 821 | " - t: time, shape ()\n", 822 | " Returns:\n", 823 | " - diffusion: shape (bs, dim)\n", 824 | " \"\"\"\n", 825 | " raise NotImplementedError(\"Fill me in for Question 3.1!\")" 826 | ] 827 | }, 828 | { 829 | "cell_type": "markdown", 830 | "id": "52e2235a-befd-4f25-b0ac-849c9868cd5d", 831 | "metadata": {}, 832 | "source": [ 833 | "Now, let's graph the results!" 
834 | ] 835 | }, 836 | { 837 | "cell_type": "code", 838 | "execution_count": null, 839 | "id": "f7347543-533a-48c9-9c08-193898b139c9", 840 | "metadata": {}, 841 | "outputs": [], 842 | "source": [ 843 | "# First, let's define two utility functions...\n", 844 | "def every_nth_index(num_timesteps: int, n: int) -> torch.Tensor:\n", 845 | " \"\"\"\n", 846 | " Compute the indices to record in the trajectory\n", 847 | " \"\"\"\n", 848 | " if n == 1:\n", 849 | " return torch.arange(num_timesteps)\n", 850 | " return torch.cat(\n", 851 | " [\n", 852 | " torch.arange(0, num_timesteps - 1, n),\n", 853 | " torch.tensor([num_timesteps - 1]),\n", 854 | " ]\n", 855 | " )\n", 856 | "\n", 857 | "def graph_dynamics(\n", 858 | " num_samples: int,\n", 859 | " source_distribution: Sampleable,\n", 860 | " simulator: Simulator, \n", 861 | " density: Density,\n", 862 | " timesteps: torch.Tensor, \n", 863 | " plot_every: int,\n", 864 | " bins: int,\n", 865 | " scale: float\n", 866 | "):\n", 867 | " \"\"\"\n", 868 | " Plot the evolution of samples from source under the simulation scheme given by simulator (itself a discretization of an ODE or SDE).\n", 869 | " Args:\n", 870 | " - num_samples: the number of samples to simulate\n", 871 | " - source_distribution: distribution from which we draw initial samples at t=0\n", 872 | " - simulator: the discretized simulation scheme used to simulate the dynamics\n", 873 | " - density: the target density\n", 874 | " - timesteps: the timesteps used by the simulator\n", 875 | " - plot_every: number of timesteps between consecutive plots\n", 876 | " - bins: number of bins for imshow\n", 877 | " - scale: scale for imshow\n", 878 | " \"\"\"\n", 879 | " # Simulate\n", 880 | " x0 = source_distribution.sample(num_samples)\n", 881 | " xts = simulator.simulate_with_trajectory(x0, timesteps)\n", 882 | " indices_to_plot = every_nth_index(len(timesteps), plot_every)\n", 883 | " plot_timesteps = timesteps[indices_to_plot]\n", 884 | " plot_xts = 
xts[:,indices_to_plot]\n", 885 | "\n", 886 | " # Graph\n", 887 | " fig, axes = plt.subplots(2, len(plot_timesteps), figsize=(8*len(plot_timesteps), 16))\n", 888 | " axes = axes.reshape((2,len(plot_timesteps)))\n", 889 | " for t_idx in range(len(plot_timesteps)):\n", 890 | " t = plot_timesteps[t_idx].item()\n", 891 | " xt = plot_xts[:,t_idx]\n", 892 | " # Scatter axes\n", 893 | " scatter_ax = axes[0, t_idx]\n", 894 | " imshow_density(density, bins, scale, scatter_ax, vmin=-15, alpha=0.25, cmap=plt.get_cmap('Blues'))\n", 895 | " scatter_ax.scatter(xt[:,0].cpu(), xt[:,1].cpu(), marker='x', color='black', alpha=0.75, s=15)\n", 896 | " scatter_ax.set_title(f'Samples at t={t:.1f}', fontsize=15)\n", 897 | " scatter_ax.set_xticks([])\n", 898 | " scatter_ax.set_yticks([])\n", 899 | "\n", 900 | " # Kdeplot axes\n", 901 | " kdeplot_ax = axes[1, t_idx]\n", 902 | " imshow_density(density, bins, scale, kdeplot_ax, vmin=-15, alpha=0.5, cmap=plt.get_cmap('Blues'))\n", 903 | " sns.kdeplot(x=xt[:,0].cpu(), y=xt[:,1].cpu(), alpha=0.5, ax=kdeplot_ax,color='grey')\n", 904 | " kdeplot_ax.set_title(f'Density of Samples at t={t:.1f}', fontsize=15)\n", 905 | " kdeplot_ax.set_xticks([])\n", 906 | " kdeplot_ax.set_yticks([])\n", 907 | " kdeplot_ax.set_xlabel(\"\")\n", 908 | " kdeplot_ax.set_ylabel(\"\")\n", 909 | "\n", 910 | " plt.show()" 911 | ] 912 | }, 913 | { 914 | "cell_type": "code", 915 | "execution_count": null, 916 | "id": "f0b244a6-5b25-4b83-a4fb-ff14c29d5eb2", 917 | "metadata": {}, 918 | "outputs": [], 919 | "source": [ 920 | "# Construct the simulator\n", 921 | "target = GaussianMixture.random_2D(nmodes=5, std=0.75, scale=15.0, seed=3.0).to(device)\n", 922 | "sde = LangevinSDE(sigma = 0.6, density = target)\n", 923 | "simulator = EulerMaruyamaSimulator(sde)\n", 924 | "\n", 925 | "# Graph the results!\n", 926 | "graph_dynamics(\n", 927 | " num_samples = 1000,\n", 928 | " source_distribution = Gaussian(mean=torch.zeros(2), cov=20 * torch.eye(2)).to(device),\n", 929 | " 
simulator=simulator,\n", 930 | " density=target,\n", 931 | " timesteps=torch.linspace(0,5.0,1000).to(device),\n", 932 | " plot_every=334,\n", 933 | " bins=200,\n", 934 | " scale=15\n", 935 | ") " 936 | ] 937 | }, 938 | { 939 | "cell_type": "markdown", 940 | "id": "683d75ef-806a-4012-ae59-6cc7faa5eaf1", 941 | "metadata": {}, 942 | "source": [ 943 | "**Your job**: Try varying the value of $\\sigma$, the number and range of the simulation steps, the source distribution, and the target density. What do you notice? Why?\n", 944 | "\n", 945 | "**Your answer**:" 946 | ] 947 | }, 948 | { 949 | "cell_type": "markdown", 950 | "id": "2e3d552e-0d1d-4cae-b5e6-23d270b7193f", 951 | "metadata": {}, 952 | "source": [ 953 | "Note: To run the following two **optional** cells, you will need to install the `ffmpeg` library. You can do so using e.g., `conda install -c conda-forge ffmpeg` (or, ideally, `mamba`). Running `pip install ffmpeg` or similar will likely **not** work." 954 | ] 955 | }, 956 | { 957 | "cell_type": "code", 958 | "execution_count": null, 959 | "id": "e98a53ea-d24e-4113-a190-e53d661bacec", 960 | "metadata": {}, 961 | "outputs": [], 962 | "source": [ 963 | "from celluloid import Camera\n", 964 | "from IPython.display import HTML\n", 965 | "\n", 966 | "def animate_dynamics(\n", 967 | " num_samples: int,\n", 968 | " source_distribution: Sampleable,\n", 969 | " simulator: Simulator, \n", 970 | " density: Density,\n", 971 | " timesteps: torch.Tensor, \n", 972 | " animate_every: int,\n", 973 | " bins: int,\n", 974 | " scale: float,\n", 975 | " save_path: str = 'dynamics_animation.mp4'\n", 976 | "):\n", 977 | " \"\"\"\n", 978 | " Animate the evolution of samples from source under the simulation scheme given by simulator (itself a discretization of an ODE or SDE).\n", 979 | " Args:\n", 980 | " - num_samples: the number of samples to simulate\n", 981 | " - source_distribution: distribution from which we draw initial samples at t=0\n", 982 | " - simulator: the discretized simulation 
scheme used to simulate the dynamics\n", 983 | " - density: the target density\n", 984 | " - timesteps: the timesteps used by the simulator\n", 985 | " - animate_every: number of timesteps between consecutive frames in the resulting animation\n", 986 | " \"\"\"\n", 987 | " # Simulate\n", 988 | " x0 = source_distribution.sample(num_samples)\n", 989 | " xts = simulator.simulate_with_trajectory(x0, timesteps)\n", 990 | " indices_to_animate = every_nth_index(len(timesteps), animate_every)\n", 991 | " animate_timesteps = timesteps[indices_to_animate]\n", 992 | " animate_xts = xts[:, indices_to_animate]\n", 993 | "\n", 994 | " # Graph\n", 995 | " fig, axes = plt.subplots(1, 2, figsize=(16, 8))\n", 996 | " camera = Camera(fig)\n", 997 | " for t_idx in range(len(animate_timesteps)):\n", 998 | " t = animate_timesteps[t_idx].item()\n", 999 | " xt = animate_xts[:,t_idx]\n", 1000 | " # Scatter axes\n", 1001 | " scatter_ax = axes[0]\n", 1002 | " imshow_density(density, bins, scale, scatter_ax, vmin=-15, alpha=0.25, cmap=plt.get_cmap('Blues'))\n", 1003 | " scatter_ax.scatter(xt[:,0].cpu(), xt[:,1].cpu(), marker='x', color='black', alpha=0.75, s=15)\n", 1004 | " scatter_ax.set_title(f'Samples')\n", 1005 | "\n", 1006 | " # Kdeplot axes\n", 1007 | " kdeplot_ax = axes[1]\n", 1008 | " imshow_density(density, bins, scale, kdeplot_ax, vmin=-15, alpha=0.5, cmap=plt.get_cmap('Blues'))\n", 1009 | " sns.kdeplot(x=xt[:,0].cpu(), y=xt[:,1].cpu(), alpha=0.5, ax=kdeplot_ax,color='grey')\n", 1010 | " kdeplot_ax.set_title(f'Density of Samples', fontsize=15)\n", 1011 | " kdeplot_ax.set_xticks([])\n", 1012 | " kdeplot_ax.set_yticks([])\n", 1013 | " kdeplot_ax.set_xlabel(\"\")\n", 1014 | " kdeplot_ax.set_ylabel(\"\")\n", 1015 | " camera.snap()\n", 1016 | " \n", 1017 | " animation = camera.animate()\n", 1018 | " animation.save(save_path)\n", 1019 | " plt.close()\n", 1020 | " return HTML(animation.to_html5_video())" 1021 | ] 1022 | }, 1023 | { 1024 | "cell_type": "code", 1025 | "execution_count": 
null, 1026 | "id": "b1f4269e-9e6d-4a50-8353-3b380eae4ed7", 1027 | "metadata": {}, 1028 | "outputs": [], 1029 | "source": [ 1030 | "# OPTIONAL CELL\n", 1031 | "# Construct the simulator\n", 1032 | "target = GaussianMixture.random_2D(nmodes=5, std=0.75, scale=15.0, seed=3.0).to(device)\n", 1033 | "sde = LangevinSDE(sigma = 0.6, density = target)\n", 1034 | "simulator = EulerMaruyamaSimulator(sde)\n", 1035 | "\n", 1036 | "# Graph the results!\n", 1037 | "animate_dynamics(\n", 1038 | " num_samples = 1000,\n", 1039 | " source_distribution = Gaussian(mean=torch.zeros(2), cov=20 * torch.eye(2)).to(device),\n", 1040 | " simulator=simulator,\n", 1041 | " density=target,\n", 1042 | " timesteps=torch.linspace(0,5.0,1000).to(device),\n", 1043 | " bins=200,\n", 1044 | " scale=15,\n", 1045 | " animate_every=100\n", 1046 | ") " 1047 | ] 1048 | }, 1049 | { 1050 | "cell_type": "markdown", 1051 | "id": "d149c323-3c9e-40a7-8e00-15fe8b87f3f8", 1052 | "metadata": {}, 1053 | "source": [ 1054 | "### Question 3.2: Ornstein-Uhlenbeck as Langevin Dynamics\n", 1055 | "In this section, we'll finish off with a brief mathematical exercise connecting Langevin dynamics and Ornstein-Uhlenbeck processes. 
Recall that for (suitably nice) distribution $p$, the *Langevin dynamics* are given by\n", 1056 | "$$dX_t = \\frac{1}{2} \\sigma^2\\nabla \\log p(X_t) dt + \\sigma\\, dW_t, \\quad \\quad X_0 = x_0,$$\n", 1057 | "while for given $\\theta, \\sigma$, the Ornstein-Uhlenbeck process is given by\n", 1058 | "$$dX_t = -\\theta X_t\\, dt + \\sigma\\, dW_t, \\quad \\quad X_0 = x_0.$$" 1059 | ] 1060 | }, 1061 | { 1062 | "cell_type": "markdown", 1063 | "id": "86954c67-510b-4d10-aea1-b5636f4dbb47", 1064 | "metadata": {}, 1065 | "source": [ 1066 | "**Your job**: Show that when $p(x) = N(0, \\frac{\\sigma^2}{2\\theta})$, the score is given by $$\\nabla \\log p(x) = -\\frac{2\\theta}{\\sigma^2}x.$$\n", 1067 | "\n", 1068 | "**Hint**: The probability density of the Gaussian $p(x) = N(0, \\frac{\\sigma^2}{2\\theta})$ is given by $$p(x) = \\frac{\\sqrt{\\theta}}{\\sigma\\sqrt{\\pi}} \\exp\\left(-\\frac{x^2\\theta}{\\sigma^2}\\right).$$\n", 1069 | "\n", 1070 | "**Your answer**:" 1071 | ] 1072 | }, 1073 | { 1074 | "cell_type": "markdown", 1075 | "id": "f3a0f761-c85c-4e30-b497-bc7622bc8e72", 1076 | "metadata": {}, 1077 | "source": [ 1078 | "**Your job**: Conclude that when $p(x) = N(0, \\frac{\\sigma^2}{2\\theta})$, the Langevin dynamics \n", 1079 | "$$dX_t = \\frac{1}{2} \\sigma^2\\nabla \\log p(X_t) dt + \\sigma dW_t,$$\n", 1080 | "is equivalent to the Ornstein-Uhlenbeck process\n", 1081 | "$$ dX_t = -\\theta X_t\\, dt + \\sigma\\, dW_t, \\quad \\quad X_0 = 0.$$\n", 1082 | "\n", 1083 | "**Your answer**:" 1084 | ] 1085 | } 1086 | ], 1087 | "metadata": { 1088 | "kernelspec": { 1089 | "display_name": "mtds", 1090 | "language": "python", 1091 | "name": "mtds" 1092 | }, 1093 | "language_info": { 1094 | "codemirror_mode": { 1095 | "name": "ipython", 1096 | "version": 3 1097 | }, 1098 | "file_extension": ".py", 1099 | "mimetype": "text/x-python", 1100 | "name": "python", 1101 | "nbconvert_exporter": "python", 1102 | "pygments_lexer": "ipython3", 1103 | "version": "3.9.20" 1104 | } 1105 | }, 
1106 | "nbformat": 4, 1107 | "nbformat_minor": 5 1108 | } 1109 | -------------------------------------------------------------------------------- /solutions/lab_one_complete.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "93d49b57-85cb-407c-a2e3-c79f682a3dc1", 6 | "metadata": {}, 7 | "source": [ 8 | "# Lab One: Simulating ODEs and SDEs" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "76e99dea-acb4-4fad-a347-1b99567c3188", 14 | "metadata": {}, 15 | "source": [ 16 | "Welcome to lab one! In this lab, we will provide an intuitive and hands-on walk-through of ODEs and SDEs. If you find any mistakes, or have any other feedback, please feel free to email us at `erives@mit.edu` and `phold@mit.edu`. Enjoy!" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "id": "fe737f00-0dff-4ffe-b40a-8187f8c615b8", 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "from abc import ABC, abstractmethod\n", 27 | "from typing import Optional\n", 28 | "import math\n", 29 | "\n", 30 | "import numpy as np\n", 31 | "from matplotlib import pyplot as plt\n", 32 | "from matplotlib.axes._axes import Axes\n", 33 | "import torch\n", 34 | "import torch.distributions as D\n", 35 | "from torch.func import vmap, jacrev\n", 36 | "from tqdm import tqdm\n", 37 | "import seaborn as sns\n", 38 | "\n", 39 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')" 40 | ] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "id": "2dd11ba4-80a5-4354-913b-ae17827f59e1", 45 | "metadata": {}, 46 | "source": [ 47 | "# Part 0: Introduction" 48 | ] 49 | }, 50 | { 51 | "cell_type": "markdown", 52 | "id": "43565591-329c-43ba-85fe-620b5eb2219b", 53 | "metadata": {}, 54 | "source": [ 55 | "First, let us make precise the central objects of study: *ordinary differential equations* (ODEs) and *stochastic differential equations* (SDEs). 
The basis of both ODEs and SDEs are time-dependent *vector fields*, which we recall from lecture as being functions $u$ defined by $$u:\\mathbb{R}^d\\times [0,1]\\to \\mathbb{R}^d,\\quad (x,t)\\mapsto u_t(x)$$\n", 56 | "That is, $u_t(x)$ takes in *where in space we are* ($x$) and *where in time we are* ($t$), and spits out the *direction we should be going in* $u_t(x)$. An ODE is then given by $$d X_t = u_t(X_t)dt, \\quad \\quad X_0 = x_0.$$\n", 57 | "Similarly, an SDE is of the form $$d X_t = u_t(X_t)dt + \\sigma_t d W_t, \\quad \\quad X_0 = x_0,$$\n", 58 | "which can be thought of as starting with an ODE given by $u_t$, and adding noise via the *Brownian motion* $(W_t)_{0 \\le t \\le 1}$. The deterministic term is referred to as the *drift coefficient* $u_t(x)$ and amount of noise added is referred to as the *diffusion coefficient* $\\sigma_t$." 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "id": "738edb40-d990-4ba6-8749-261c66ebe51e", 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "class ODE(ABC):\n", 69 | " @abstractmethod\n", 70 | " def drift_coefficient(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n", 71 | " \"\"\"\n", 72 | " Returns the drift coefficient of the ODE.\n", 73 | " Args:\n", 74 | " - xt: state at time t, shape (bs, dim)\n", 75 | " - t: time, shape ()\n", 76 | " Returns:\n", 77 | " - drift_coefficient: shape (batch_size, dim)\n", 78 | " \"\"\"\n", 79 | " pass\n", 80 | "\n", 81 | "class SDE(ABC):\n", 82 | " @abstractmethod\n", 83 | " def drift_coefficient(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n", 84 | " \"\"\"\n", 85 | " Returns the drift coefficient of the ODE.\n", 86 | " Args:\n", 87 | " - xt: state at time t, shape (batch_size, dim)\n", 88 | " - t: time, shape ()\n", 89 | " Returns:\n", 90 | " - drift_coefficient: shape (batch_size, dim)\n", 91 | " \"\"\"\n", 92 | " pass\n", 93 | "\n", 94 | " @abstractmethod\n", 95 | " def diffusion_coefficient(self, xt: 
torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n", 96 | " \"\"\"\n", 97 | " Returns the diffusion coefficient of the SDE.\n", 98 | " Args:\n", 99 | " - xt: state at time t, shape (batch_size, dim)\n", 100 | " - t: time, shape ()\n", 101 | " Returns:\n", 102 | " - diffusion_coefficient: shape (batch_size, dim)\n", 103 | " \"\"\"\n", 104 | " pass" 105 | ] 106 | }, 107 | { 108 | "cell_type": "markdown", 109 | "id": "81f02840-39af-41c0-b105-fdcc3aa3d5d2", 110 | "metadata": {}, 111 | "source": [ 112 | "**Note**: One might consider an ODE to be a special case of an SDE with zero diffusion coefficient. This intuition is valid; however, for pedagogical (and performance) reasons, we will treat them separately for the scope of this lab." 113 | ] 114 | }, 115 | { 116 | "cell_type": "markdown", 117 | "id": "8a220737-5fbb-4405-852d-a9610d27e345", 118 | "metadata": {}, 119 | "source": [ 120 | "# Part 1: Numerical Methods for Simulating ODEs and SDEs\n", 121 | "We may think of ODEs and SDEs as describing the motion of a particle through space. Intuitively, the ODE above says \"start at $X_0=x_0$\", and move so that your instantaneous velocity is given by $u_t(X_t)$. Similarly, the SDE says \"start at $X_0=x_0$\", and move so that your instantaneous velocity is given by $u_t(X_t)$ plus a little bit of random noise scaled by $\\sigma_t$. Formally, the trajectories traced out by these intuitive descriptions are said to be *solutions* to the ODEs and SDEs, respectively. Numerical methods for computing these solutions are all essentially based on *simulating*, or *integrating*, the ODE or SDE. \n", 122 | "\n", 123 | "In this section we'll implement the *Euler* and *Euler-Maruyama* numerical simulation schemes for integrating ODEs and SDEs, respectively. 
Recall from lecture that the Euler simulation scheme corresponds to the discretization\n", 124 | "$$d X_t = u_t(X_t) dt \\quad \\quad \\rightarrow \\quad \\quad X_{t + h} = X_t + hu_t(X_t),$$\n", 125 | "where $h = \\Delta t$ is the *step size*. Similarly, the Euler-Maruyama scheme corresponds to the discretization \n", 126 | "$$ dX_t = u_t(X_t) dt + \\sigma_t d W_t \\quad \\quad \\rightarrow \\quad \\quad X_{t + h} = X_t + hu_t(X_t) + \\sqrt{h} \\sigma_t z_t, \\quad z_t \\sim N(0,I_d).$$ \n", 127 | "Let's implement these!" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": null, 133 | "id": "98202014-cf8e-4d50-a819-61d937ebea20", 134 | "metadata": {}, 135 | "outputs": [], 136 | "source": [ 137 | "class Simulator(ABC):\n", 138 | " @abstractmethod\n", 139 | " def step(self, xt: torch.Tensor, t: torch.Tensor, dt: torch.Tensor):\n", 140 | " \"\"\"\n", 141 | " Takes one simulation step\n", 142 | " Args:\n", 143 | " - xt: state at time t, shape (batch_size, dim)\n", 144 | " - t: time, shape ()\n", 145 | " - dt: step size, shape ()\n", 146 | " Returns:\n", 147 | " - nxt: state at time t + dt\n", 148 | " \"\"\"\n", 149 | " pass\n", 150 | "\n", 151 | " @torch.no_grad()\n", 152 | " def simulate(self, x: torch.Tensor, ts: torch.Tensor):\n", 153 | " \"\"\"\n", 154 | " Simulates using the discretization given by ts\n", 155 | " Args:\n", 156 | " - x_init: initial state at time ts[0], shape (batch_size, dim)\n", 157 | " - ts: timesteps, shape (nts,)\n", 158 | " Returns:\n", 159 | " - x_final: final state at time ts[-1], shape (batch_size, dim)\n", 160 | " \"\"\"\n", 161 | " for t_idx in range(len(ts) - 1):\n", 162 | " t = ts[t_idx]\n", 163 | " h = ts[t_idx + 1] - ts[t_idx]\n", 164 | " x = self.step(x, t, h)\n", 165 | " return x\n", 166 | "\n", 167 | " @torch.no_grad()\n", 168 | " def simulate_with_trajectory(self, x: torch.Tensor, ts: torch.Tensor):\n", 169 | " \"\"\"\n", 170 | " Simulates using the discretization given by ts\n", 171 | " Args:\n", 172 | " - 
x_init: initial state at time ts[0], shape (bs, dim)\n", 173 | " - ts: timesteps, shape (num_timesteps,)\n", 174 | " Returns:\n", 175 | " - xs: trajectory of xts over ts, shape (batch_size, num_timesteps, dim)\n", 176 | " \"\"\"\n", 177 | " xs = [x.clone()]\n", 178 | " for t_idx in tqdm(range(len(ts) - 1)):\n", 179 | " t = ts[t_idx]\n", 180 | " h = ts[t_idx + 1] - ts[t_idx]\n", 181 | " x = self.step(x, t, h)\n", 182 | " xs.append(x.clone())\n", 183 | " return torch.stack(xs, dim=1)" 184 | ] 185 | }, 186 | { 187 | "cell_type": "markdown", 188 | "id": "27d9b6b5-e7eb-44a1-a180-aa55d16d8356", 189 | "metadata": {}, 190 | "source": [ 191 | "### Question 1.1: Implement EulerSimulator and EulerMaruyamaSimulator" 192 | ] 193 | }, 194 | { 195 | "cell_type": "markdown", 196 | "id": "c0e5b471-a228-4688-903f-7a0898d8b736", 197 | "metadata": {}, 198 | "source": [ 199 | "**Your job**: Fill in the `step` methods of `EulerSimulator` and `EulerMaruyamaSimulator`." 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": null, 205 | "id": "3326c993-b5ef-4b6c-97b5-a403a509b9ac", 206 | "metadata": {}, 207 | "outputs": [], 208 | "source": [ 209 | "class EulerSimulator(Simulator):\n", 210 | " def __init__(self, ode: ODE):\n", 211 | " self.ode = ode\n", 212 | " \n", 213 | " def step(self, xt: torch.Tensor, t: torch.Tensor, h: torch.Tensor):\n", 214 | " return xt + self.ode.drift_coefficient(xt,t) * h" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": null, 220 | "id": "3151adb4-e013-4728-8c96-8bab07ee6262", 221 | "metadata": {}, 222 | "outputs": [], 223 | "source": [ 224 | "class EulerMaruyamaSimulator(Simulator):\n", 225 | " def __init__(self, sde: SDE):\n", 226 | " self.sde = sde\n", 227 | " \n", 228 | " def step(self, xt: torch.Tensor, t: torch.Tensor, h: torch.Tensor):\n", 229 | " return xt + self.sde.drift_coefficient(xt,t) * h + self.sde.diffusion_coefficient(xt,t) * torch.sqrt(h) * torch.randn_like(xt)" 230 | ] 231 | }, 232 | { 233 | 
"cell_type": "markdown", 234 | "id": "fcb930b6-63a1-4cc2-8f60-d30de5d380ea", 235 | "metadata": {}, 236 | "source": [ 237 | "**Note:** When the diffusion coefficient is zero, the Euler and Euler-Maruyama simulation schemes are equivalent! " 238 | ] 239 | }, 240 | { 241 | "cell_type": "markdown", 242 | "id": "4afdf59a-25e3-4537-a357-bbbb8ee5332b", 243 | "metadata": {}, 244 | "source": [ 245 | "# Part 2: Visualizing Solutions to SDEs\n", 246 | "Let's get a feel for what the solutions to these SDEs look like in practice (we'll get to ODEs later...). To do so, we'll implement and visualize two special choices of SDEs from lecture: a (scaled) *Brownian motion*, and an *Ornstein-Uhlenbeck* (OU) process." 247 | ] 248 | }, 249 | { 250 | "cell_type": "markdown", 251 | "id": "84c480e6-bfcf-4105-b21a-002d3a110923", 252 | "metadata": {}, 253 | "source": [ 254 | "### Question 2.1: Implementing Brownian Motion\n", 255 | "First, recall that a Brownian motion is recovered (by definition) by setting $u_t = 0$ and $\\sigma_t = \\sigma$, viz.,\n", 256 | "$$ dX_t = \\sigma dW_t, \\quad \\quad X_0 = 0.$$" 257 | ] 258 | }, 259 | { 260 | "cell_type": "markdown", 261 | "id": "d0e83673-54dc-451e-8422-0110c462b295", 262 | "metadata": {}, 263 | "source": [ 264 | "**Your job**: Intuitively, what might we expect the trajectories of $X_t$ to look like when $\\sigma$ is very large? What about when $\\sigma$ is close to zero?\n", 265 | "\n", 266 | "**Your answer**:" 267 | ] 268 | }, 269 | { 270 | "cell_type": "markdown", 271 | "id": "62ba5c9e-f5cc-41a9-a850-43e46d79b3fb", 272 | "metadata": {}, 273 | "source": [ 274 | "**Your job**: Fill in the `drift_coefficient` and `diffusion_coefficient` methods of the `BrownianMotion` class below." 
275 | ] 276 | }, 277 | { 278 | "cell_type": "code", 279 | "execution_count": null, 280 | "id": "2c372a79-585e-4bbb-a78e-3870dd5c458b", 281 | "metadata": {}, 282 | "outputs": [], 283 | "source": [ 284 | "class BrownianMotion(SDE):\n", 285 | " def __init__(self, sigma: float):\n", 286 | " self.sigma = sigma\n", 287 | " \n", 288 | " def drift_coefficient(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n", 289 | " \"\"\"\n", 290 | " Returns the drift coefficient of the ODE.\n", 291 | " Args:\n", 292 | " - xt: state at time t, shape (bs, dim)\n", 293 | " - t: time, shape ()\n", 294 | " Returns:\n", 295 | " - drift: shape (bs, dim)\n", 296 | " \"\"\"\n", 297 | " return torch.zeros_like(xt)\n", 298 | "\n", 299 | " def diffusion_coefficient(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n", 300 | " \"\"\"\n", 301 | " Returns the diffusion coefficient of the ODE.\n", 302 | " Args:\n", 303 | " - xt: state at time t, shape (bs, dim)\n", 304 | " - t: time, shape ()\n", 305 | " Returns:\n", 306 | " - diffusion: shape (bs, dim)\n", 307 | " \"\"\"\n", 308 | " return self.sigma * torch.ones_like(xt)" 309 | ] 310 | }, 311 | { 312 | "cell_type": "markdown", 313 | "id": "79f2318e-ae09-426a-bf57-c44cd5074288", 314 | "metadata": {}, 315 | "source": [ 316 | "Now let's plot! We'll make use of the following utility function." 
317 | ] 318 | }, 319 | { 320 | "cell_type": "code", 321 | "execution_count": null, 322 | "id": "e577649b-2290-4308-9d5a-7ab614984b4c", 323 | "metadata": {}, 324 | "outputs": [], 325 | "source": [ 326 | "def plot_trajectories_1d(x0: torch.Tensor, simulator: Simulator, timesteps: torch.Tensor, ax: Optional[Axes] = None):\n", 327 | " \"\"\"\n", 328 | " Graphs the trajectories of a one-dimensional SDE with given initial values (x0) and simulation timesteps (timesteps).\n", 329 | " Args:\n", 330 | " - x0: initial states at time timesteps[0], shape (num_trajectories, 1)\n", 331 | " - simulator: Simulator object used to simulate\n", 332 | " - timesteps: timesteps to simulate along, shape (num_timesteps,)\n", 333 | " - ax: pyplot Axes object to plot on\n", 334 | " \"\"\"\n", 335 | " if ax is None:\n", 336 | " ax = plt.gca()\n", 337 | " trajectories = simulator.simulate_with_trajectory(x0, timesteps) # (num_trajectories, num_timesteps, ...)\n", 338 | " for trajectory_idx in range(trajectories.shape[0]):\n", 339 | " trajectory = trajectories[trajectory_idx, :, 0] # (num_timesteps,)\n", 340 | " ax.plot(timesteps.cpu(), trajectory.cpu())" 341 | ] 342 | }, 343 | { 344 | "cell_type": "code", 345 | "execution_count": null, 346 | "id": "c4245ac9-3092-4c5b-a149-75a909ceaa9e", 347 | "metadata": {}, 348 | "outputs": [], 349 | "source": [ 350 | "sigma = 1.0\n", 351 | "brownian_motion = BrownianMotion(sigma)\n", 352 | "simulator = EulerMaruyamaSimulator(sde=brownian_motion)\n", 353 | "x0 = torch.zeros(5,1).to(device) # Initial values - let's start at zero\n", 354 | "ts = torch.linspace(0.0,5.0,500).to(device) # simulation timesteps\n", 355 | "\n", 356 | "plt.figure(figsize=(8, 8))\n", 357 | "ax = plt.gca()\n", 358 | "ax.set_title(r'Trajectories of Brownian Motion with $\\sigma=$' + str(sigma), fontsize=18)\n", 359 | "ax.set_xlabel(r'Time ($t$)', fontsize=18)\n", 360 | "ax.set_ylabel(r'$X_t$', fontsize=18)\n", 361 | "plot_trajectories_1d(x0, simulator, ts, ax)\n", 362 | "plt.show()" 363 | ] 364 | }, 365 | { 366 | 
"cell_type": "markdown", 367 | "id": "e3a7bd5c-d558-45e0-b741-6de4d7003776", 368 | "metadata": {}, 369 | "source": [ 370 | "**Your job**: What happens when you vary the value of `sigma`?\n", 371 | "\n", 372 | "**Your answer**:" 373 | ] 374 | }, 375 | { 376 | "cell_type": "markdown", 377 | "id": "9b22c81b-7ce9-4b84-82ab-fac03f741b03", 378 | "metadata": {}, 379 | "source": [ 380 | "### Question 2.2: Implementing an Ornstein-Uhlenbeck Process\n", 381 | "An OU process is given by setting $u_t(X_t) = - \\theta X_t$ and $\\sigma_t = \\sigma$, viz.,\n", 382 | "$$ dX_t = -\\theta X_t\\, dt + \\sigma\\, dW_t, \\quad \\quad X_0 = x_0.$$" 383 | ] 384 | }, 385 | { 386 | "cell_type": "markdown", 387 | "id": "4b0cdb7d-f8bf-4826-9c32-8cb101376103", 388 | "metadata": {}, 389 | "source": [ 390 | "**Your job**: Intuitively, what would the trajectory of $X_t$ look like for a very small value of $\\theta$? What about a very large value of $\\theta$?\n", 391 | "\n", 392 | "**Your answer**:" 393 | ] 394 | }, 395 | { 396 | "cell_type": "markdown", 397 | "id": "12325951-709c-4486-9ea7-f4c22b3cc1ef", 398 | "metadata": {}, 399 | "source": [ 400 | "**Your job**: Fill in the `drift_coefficient` and `diffusion_coefficient` methods of the `OUProcess` class below." 
401 | ] 402 | }, 403 | { 404 | "cell_type": "code", 405 | "execution_count": null, 406 | "id": "214d0e94-698c-4729-878e-3c2a9881dbe4", 407 | "metadata": {}, 408 | "outputs": [], 409 | "source": [ 410 | "class OUProcess(SDE):\n", 411 | " def __init__(self, theta: float, sigma: float):\n", 412 | " self.theta = theta\n", 413 | " self.sigma = sigma\n", 414 | " \n", 415 | " def drift_coefficient(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n", 416 | " \"\"\"\n", 417 | " Returns the drift coefficient of the ODE.\n", 418 | " Args:\n", 419 | " - xt: state at time t, shape (bs, dim)\n", 420 | " - t: time, shape ()\n", 421 | " Returns:\n", 422 | " - drift: shape (bs, dim)\n", 423 | " \"\"\"\n", 424 | " return - self.theta * xt\n", 425 | "\n", 426 | " def diffusion_coefficient(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n", 427 | " \"\"\"\n", 428 | " Returns the diffusion coefficient of the ODE.\n", 429 | " Args:\n", 430 | " - xt: state at time t, shape (bs, dim)\n", 431 | " - t: time, shape ()\n", 432 | " Returns:\n", 433 | " - diffusion: shape (bs, dim)\n", 434 | " \"\"\"\n", 435 | " return self.sigma * torch.ones_like(xt)" 436 | ] 437 | }, 438 | { 439 | "cell_type": "code", 440 | "execution_count": null, 441 | "id": "bd7dc249-826e-4a81-91f7-fa8fb17278c4", 442 | "metadata": {}, 443 | "outputs": [], 444 | "source": [ 445 | "# Try comparing multiple choices side-by-side\n", 446 | "thetas_and_sigmas = [\n", 447 | " (0.25, 0.0),\n", 448 | " (0.25, 0.25),\n", 449 | " (0.25, 0.5),\n", 450 | " (0.25, 1.0),\n", 451 | "]\n", 452 | "simulation_time = 20.0\n", 453 | "\n", 454 | "num_plots = len(thetas_and_sigmas)\n", 455 | "fig, axes = plt.subplots(1, num_plots, figsize=(8 * num_plots, 7))\n", 456 | "\n", 457 | "for idx, (theta, sigma) in enumerate(thetas_and_sigmas):\n", 458 | " ou_process = OUProcess(theta, sigma)\n", 459 | " simulator = EulerMaruyamaSimulator(sde=ou_process)\n", 460 | " x0 = torch.linspace(-10.0,10.0,10).view(-1,1).to(device) # Initial 
values: ten evenly spaced points in [-10, 10]\n", 461 | " ts = torch.linspace(0.0,simulation_time,1000).to(device) # simulation timesteps\n", 462 | "\n", 463 | " ax = axes[idx]\n", 464 | " ax.set_title(f'Trajectories of OU Process with $\\\\sigma = ${sigma}, $\\\\theta = ${theta}', fontsize=15)\n", 465 | " ax.set_xlabel(r'Time ($t$)', fontsize=15)\n", 466 | " ax.set_ylabel(r'$X_t$', fontsize=15)\n", 467 | " plot_trajectories_1d(x0, simulator, ts, ax)\n", 468 | "plt.show()" 469 | ] 470 | }, 471 | { 472 | "cell_type": "markdown", 473 | "id": "76e0c7eb-0eb3-4eb3-a854-5cac18a85baa", 474 | "metadata": {}, 475 | "source": [ 476 | "**Your job**: What do you notice about the convergence of the solutions? Are they converging to a particular point? Or to a distribution? Your answer should be two *qualitative* sentences of the form: \"When ($\\theta$ or $\\sigma$) goes (up or down), we see...\".\n", 477 | "\n", 478 | "**Hint**: Pay close attention to the ratio $D \\triangleq \\frac{\\sigma^2}{2\\theta}$ (see the next few cells below!).\n", 479 | "\n", 480 | "**Your answer**:" 481 | ] 482 | }, 483 | { 484 | "cell_type": "code", 485 | "execution_count": null, 486 | "id": "31502fba-c582-4d67-8fdf-8d474604cd44", 487 | "metadata": {}, 488 | "outputs": [], 489 | "source": [ 490 | "def plot_scaled_trajectories_1d(x0: torch.Tensor, simulator: Simulator, timesteps: torch.Tensor, time_scale: float, label: str, ax: Optional[Axes] = None):\n", 491 | " \"\"\"\n", 492 | " Graphs the trajectories of a one-dimensional SDE with given initial values (x0) and simulation timesteps (timesteps).\n", 493 | " Args:\n", 494 | " - x0: initial state, shape (num_trajectories, 1)\n", 495 | " - simulator: Simulator object used to simulate\n", 496 | " - timesteps: timesteps to simulate along, shape (num_timesteps,)\n", 497 | " - time_scale: scalar by which to scale time\n", 498 | " - label: legend label attached to each plotted trajectory\n", 499 | " - ax: pyplot Axes object to plot on\n", 500 | " \"\"\"\n", 501 | " if ax is None:\n", 
503 | " ax = plt.gca()\n", 504 | " trajectories = simulator.simulate_with_trajectory(x0, timesteps) # (num_trajectories, num_timesteps, ...)\n", 505 | " for trajectory_idx in range(trajectories.shape[0]):\n", 506 | " trajectory = trajectories[trajectory_idx, :, 0] # (num_timesteps,)\n", 507 | " ax.plot(timesteps.cpu() * time_scale, trajectory.cpu(), label=label)" 508 | ] 509 | }, 510 | { 511 | "cell_type": "code", 512 | "execution_count": null, 513 | "id": "ad50c533-a8cf-4fc0-be0c-b5fd5c3cbc1c", 514 | "metadata": {}, 515 | "outputs": [], 516 | "source": [ 517 | "# Let's try rescaling with time\n", 518 | "sigmas = [1.0, 2.0, 10.0]\n", 519 | "ds = [0.25, 1.0, 4.0] # D = sigma**2 / (2 * theta)\n", 520 | "simulation_time = 10.0\n", 521 | "\n", 522 | "fig, axes = plt.subplots(len(ds), len(sigmas), figsize=(8 * len(sigmas), 8 * len(ds)))\n", 523 | "axes = axes.reshape((len(ds), len(sigmas)))\n", 524 | "for d_idx, d in enumerate(ds):\n", 525 | " for s_idx, sigma in enumerate(sigmas):\n", 526 | " theta = sigma**2 / 2 / d\n", 527 | " ou_process = OUProcess(theta, sigma)\n", 528 | " simulator = EulerMaruyamaSimulator(sde=ou_process)\n", 529 | " x0 = torch.linspace(-20.0,20.0,20).view(-1,1).to(device)\n", 530 | " time_scale = sigma**2\n", 531 | " ts = torch.linspace(0.0,simulation_time / time_scale,1000).to(device) # simulation timesteps\n", 532 | " ax = axes[d_idx, s_idx]\n", 533 | " plot_scaled_trajectories_1d(x0=x0, simulator=simulator, timesteps=ts, time_scale=time_scale, label=f'Sigma = {sigma}', ax=ax)\n", 534 | " ax.set_title(f'OU Trajectories with Sigma={sigma}, Theta={theta}, D={d}')\n", 535 | " ax.set_xlabel('t * sigma^2')\n", 536 | " ax.set_ylabel('X_t')\n", 537 | "plt.show()" 538 | ] 539 | }, 540 | { 541 | "cell_type": "markdown", 542 | "id": "850111a6-30be-4265-b423-ed23671deaf0", 543 | "metadata": {}, 544 | "source": [ 545 | "**Your job**: What conclusion can we draw from the figure above? One qualitative sentence is fine. 
We'll revisit this in Section 3.2.\n", 546 | "\n", 547 | "**Your answer**:" 548 | ] 549 | }, 550 | { 551 | "cell_type": "markdown", 552 | "id": "d49bcbdc-237c-4639-a447-b0c13f575a8d", 553 | "metadata": {}, 554 | "source": [ 555 | "# Part 3: Transforming Distributions with SDEs\n", 556 | "In the previous section, we observed how individual *points* are transformed by an SDE. Ultimately, we are interested in understanding how *distributions* are transformed by an SDE (or an ODE...). After all, our goal is to design ODEs and SDEs which transform a noisy distribution (such as the Gaussian $N(0, I_d)$) to the data distribution $p_{\\text{data}}$ of interest. In this section, we will visualize how distributions are transformed by a very particular family of SDEs: *Langevin dynamics*.\n", 557 | "\n", 558 | "First, let's define some distributions to play around with. In practice, there are two qualities one might hope a distribution to have:\n", 559 | "1. The first quality is that one can measure the *density* of a distribution $p(x)$. This ensures that we can compute the gradient $\\nabla \\log p(x)$ of the log density. This quantity is known as the *score* of $p$, and paints a picture of the local geometry of the distribution. Using the score, we will construct and simulate the *Langevin dynamics*, a family of SDEs which \"drive\" samples toward the distribution $p$. In particular, the Langevin dynamics *preserve* the distribution $p(x)$. In Lecture 2, we will make this notion of driving more precise.\n", 560 | "2. The second quality is that we can draw samples from the distribution $p(x)$.\n", 561 | "For simple, toy distributions, such as Gaussians and simple mixture models, both qualities are often satisfied. For more complex choices of $p$, such as distributions over images, we can sample but cannot measure the density." 
562 | ] 563 | }, 564 | { 565 | "cell_type": "code", 566 | "execution_count": null, 567 | "id": "6e2b64f2-732a-4ea3-84a4-8f955ca64f7f", 568 | "metadata": {}, 569 | "outputs": [], 570 | "source": [ 571 | "class Density(ABC):\n", 572 | " \"\"\"\n", 573 | " Distribution with tractable density\n", 574 | " \"\"\"\n", 575 | " @abstractmethod\n", 576 | " def log_density(self, x: torch.Tensor) -> torch.Tensor:\n", 577 | " \"\"\"\n", 578 | " Returns the log density at x.\n", 579 | " Args:\n", 580 | " - x: shape (batch_size, dim)\n", 581 | " Returns:\n", 582 | " - log_density: shape (batch_size, 1)\n", 583 | " \"\"\"\n", 584 | " pass\n", 585 | "\n", 586 | " def score(self, x: torch.Tensor) -> torch.Tensor:\n", 587 | " \"\"\"\n", 588 | " Returns the score dx log density(x)\n", 589 | " Args:\n", 590 | " - x: (batch_size, dim)\n", 591 | " Returns:\n", 592 | " - score: (batch_size, dim)\n", 593 | " \"\"\"\n", 594 | " x = x.unsqueeze(1) # (batch_size, 1, ...)\n", 595 | " score = vmap(jacrev(self.log_density))(x) # (batch_size, 1, 1, 1, ...)\n", 596 | " return score.squeeze((1, 2, 3)) # (batch_size, ...)\n", 597 | "\n", 598 | "class Sampleable(ABC):\n", 599 | " \"\"\"\n", 600 | " Distribution which can be sampled from\n", 601 | " \"\"\"\n", 602 | " @abstractmethod\n", 603 | " def sample(self, num_samples: int) -> torch.Tensor:\n", 604 | " \"\"\"\n", 605 | " Draws samples from the distribution.\n", 606 | " Args:\n", 607 | " - num_samples: the desired number of samples\n", 608 | " Returns:\n", 609 | " - samples: shape (num_samples, dim)\n", 610 | " \"\"\"\n", 611 | " pass" 612 | ] 613 | }, 614 | { 615 | "cell_type": "code", 616 | "execution_count": null, 617 | "id": "3805b3f8-f0ab-4bb0-a41a-4d97c65e24e8", 618 | "metadata": {}, 619 | "outputs": [], 620 | "source": [ 621 | "# Several plotting utility functions\n", 622 | "def hist2d_sampleable(sampleable: Sampleable, num_samples: int, ax: Optional[Axes] = None, **kwargs):\n", 623 | " if ax is None:\n", 624 | " ax = plt.gca()\n", 625 | " samples 
= sampleable.sample(num_samples) # (ns, 2)\n", 626 | " ax.hist2d(samples[:,0].cpu(), samples[:,1].cpu(), **kwargs)\n", 627 | "\n", 628 | "def scatter_sampleable(sampleable: Sampleable, num_samples: int, ax: Optional[Axes] = None, **kwargs):\n", 629 | " if ax is None:\n", 630 | " ax = plt.gca()\n", 631 | " samples = sampleable.sample(num_samples) # (ns, 2)\n", 632 | " ax.scatter(samples[:,0].cpu(), samples[:,1].cpu(), **kwargs)\n", 633 | "\n", 634 | "def imshow_density(density: Density, bins: int, scale: float, ax: Optional[Axes] = None, **kwargs):\n", 635 | " if ax is None:\n", 636 | " ax = plt.gca()\n", 637 | " x = torch.linspace(-scale, scale, bins).to(device)\n", 638 | " y = torch.linspace(-scale, scale, bins).to(device)\n", 639 | " X, Y = torch.meshgrid(x, y)\n", 640 | " xy = torch.stack([X.reshape(-1), Y.reshape(-1)], dim=-1)\n", 641 | " density = density.log_density(xy).reshape(bins, bins).T\n", 642 | " im = ax.imshow(density.cpu(), extent=[-scale, scale, -scale, scale], origin='lower', **kwargs)\n", 643 | "\n", 644 | "def contour_density(density: Density, bins: int, scale: float, ax: Optional[Axes] = None, **kwargs):\n", 645 | " if ax is None:\n", 646 | " ax = plt.gca()\n", 647 | " x = torch.linspace(-scale, scale, bins).to(device)\n", 648 | " y = torch.linspace(-scale, scale, bins).to(device)\n", 649 | " X, Y = torch.meshgrid(x, y)\n", 650 | " xy = torch.stack([X.reshape(-1), Y.reshape(-1)], dim=-1)\n", 651 | " density = density.log_density(xy).reshape(bins, bins).T\n", 652 | " im = ax.contour(density.cpu(), extent=[-scale, scale, -scale, scale], origin='lower', **kwargs)" 653 | ] 654 | }, 655 | { 656 | "cell_type": "code", 657 | "execution_count": null, 658 | "id": "498eb6cb-1261-4cc1-b1d0-281e5f73d2cc", 659 | "metadata": {}, 660 | "outputs": [], 661 | "source": [ 662 | "class Gaussian(torch.nn.Module, Sampleable, Density):\n", 663 | " \"\"\"\n", 664 | " Two-dimensional Gaussian. Is a Density and a Sampleable. 
Wrapper around torch.distributions.MultivariateNormal\n", 665 | " \"\"\"\n", 666 | " def __init__(self, mean, cov):\n", 667 | " \"\"\"\n", 668 | " mean: shape (2,)\n", 669 | " cov: shape (2,2)\n", 670 | " \"\"\"\n", 671 | " super().__init__()\n", 672 | " self.register_buffer(\"mean\", mean)\n", 673 | " self.register_buffer(\"cov\", cov)\n", 674 | "\n", 675 | " @property\n", 676 | " def distribution(self):\n", 677 | " return D.MultivariateNormal(self.mean, self.cov, validate_args=False)\n", 678 | "\n", 679 | " def sample(self, num_samples) -> torch.Tensor:\n", 680 | " return self.distribution.sample((num_samples,))\n", 681 | "\n", 682 | " def log_density(self, x: torch.Tensor):\n", 683 | " return self.distribution.log_prob(x).view(-1, 1)\n", 684 | "\n", 685 | "class GaussianMixture(torch.nn.Module, Sampleable, Density):\n", 686 | " \"\"\"\n", 687 | " Two-dimensional Gaussian mixture model, and is a Density and a Sampleable. Wrapper around torch.distributions.MixtureSameFamily.\n", 688 | " \"\"\"\n", 689 | " def __init__(\n", 690 | " self,\n", 691 | " means: torch.Tensor, # nmodes x data_dim\n", 692 | " covs: torch.Tensor, # nmodes x data_dim x data_dim\n", 693 | " weights: torch.Tensor, # nmodes\n", 694 | " ):\n", 695 | " \"\"\"\n", 696 | " means: shape (nmodes, 2)\n", 697 | " covs: shape (nmodes, 2, 2)\n", 698 | " weights: shape (nmodes, 1)\n", 699 | " \"\"\"\n", 700 | " super().__init__()\n", 701 | " self.nmodes = means.shape[0]\n", 702 | " self.register_buffer(\"means\", means)\n", 703 | " self.register_buffer(\"covs\", covs)\n", 704 | " self.register_buffer(\"weights\", weights)\n", 705 | "\n", 706 | " @property\n", 707 | " def dim(self) -> int:\n", 708 | " return self.means.shape[1]\n", 709 | "\n", 710 | " @property\n", 711 | " def distribution(self):\n", 712 | " return D.MixtureSameFamily(\n", 713 | " mixture_distribution=D.Categorical(probs=self.weights, validate_args=False),\n", 714 | " component_distribution=D.MultivariateNormal(\n", 715 | " 
loc=self.means,\n", 716 | " covariance_matrix=self.covs,\n", 717 | " validate_args=False,\n", 718 | " ),\n", 719 | " validate_args=False,\n", 720 | " )\n", 721 | "\n", 722 | " def log_density(self, x: torch.Tensor) -> torch.Tensor:\n", 723 | " return self.distribution.log_prob(x).view(-1, 1)\n", 724 | "\n", 725 | " def sample(self, num_samples: int) -> torch.Tensor:\n", 726 | " return self.distribution.sample(torch.Size((num_samples,)))\n", 727 | "\n", 728 | " @classmethod\n", 729 | " def random_2D(\n", 730 | " cls, nmodes: int, std: float, scale: float = 10.0, seed = 0.0\n", 731 | " ) -> \"GaussianMixture\":\n", 732 | " torch.manual_seed(seed)\n", 733 | " means = (torch.rand(nmodes, 2) - 0.5) * scale\n", 734 | " covs = torch.diag_embed(torch.ones(nmodes, 2)) * std ** 2\n", 735 | " weights = torch.ones(nmodes)\n", 736 | " return cls(means, covs, weights)\n", 737 | "\n", 738 | " @classmethod\n", 739 | " def symmetric_2D(\n", 740 | " cls, nmodes: int, std: float, scale: float = 10.0,\n", 741 | " ) -> \"GaussianMixture\":\n", 742 | " angles = torch.linspace(0, 2 * np.pi, nmodes + 1)[:nmodes]\n", 743 | " means = torch.stack([torch.cos(angles), torch.sin(angles)], dim=1) * scale\n", 744 | " covs = torch.diag_embed(torch.ones(nmodes, 2) * std ** 2)\n", 745 | " weights = torch.ones(nmodes) / nmodes\n", 746 | " return cls(means, covs, weights)" 747 | ] 748 | }, 749 | { 750 | "cell_type": "code", 751 | "execution_count": null, 752 | "id": "36ce6533-2e56-42cb-98e7-958c45d583f8", 753 | "metadata": {}, 754 | "outputs": [], 755 | "source": [ 756 | "# Visualize densities\n", 757 | "densities = {\n", 758 | " \"Gaussian\": Gaussian(mean=torch.zeros(2), cov=10 * torch.eye(2)).to(device),\n", 759 | " \"Random Mixture\": GaussianMixture.random_2D(nmodes=5, std=1.0, scale=20.0, seed=3.0).to(device),\n", 760 | " \"Symmetric Mixture\": GaussianMixture.symmetric_2D(nmodes=5, std=1.0, scale=8.0).to(device),\n", 761 | "}\n", 762 | "\n", 763 | "fig, axes = plt.subplots(1,3, figsize=(18, 
6))\n", 764 | "bins = 100\n", 765 | "scale = 15\n", 766 | "for idx, (name, density) in enumerate(densities.items()):\n", 767 | " ax = axes[idx]\n", 768 | " ax.set_title(name)\n", 769 | " imshow_density(density, bins, scale, ax, vmin=-15, cmap=plt.get_cmap('Blues'))\n", 770 | " contour_density(density, bins, scale, ax, colors='grey', linestyles='solid', alpha=0.25, levels=20)\n", 771 | "plt.show()\n" 772 | ] 773 | }, 774 | { 775 | "cell_type": "markdown", 776 | "id": "5b51093e-b25a-4cdb-bc0d-ab8c1a3145b3", 777 | "metadata": {}, 778 | "source": [ 779 | "### Question 3.1: Implementing Langevin Dynamics" 780 | ] 781 | }, 782 | { 783 | "cell_type": "markdown", 784 | "id": "91888056-b900-401b-bebb-4e4bf5301afe", 785 | "metadata": {}, 786 | "source": [ 787 | "In this section, we'll simulate the (overdamped) Langevin dynamics $$dX_t = \\frac{1}{2} \\sigma^2\\nabla \\log p(X_t)\\, dt + \\sigma\\, dW_t.$$\n", 788 | "\n", 789 | "**Your job**: Fill in the `drift_coefficient` and `diffusion_coefficient` methods of the class `LangevinSDE` below." 
790 | ] 791 | }, 792 | { 793 | "cell_type": "code", 794 | "execution_count": null, 795 | "id": "371f00c5-3030-4f6c-945f-4a7d3483ff56", 796 | "metadata": {}, 797 | "outputs": [], 798 | "source": [ 799 | "class LangevinSDE(SDE):\n", 800 | " def __init__(self, sigma: float, density: Density):\n", 801 | " self.sigma = sigma\n", 802 | " self.density = density\n", 803 | " \n", 804 | " def drift_coefficient(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n", 805 | " \"\"\"\n", 806 | " Returns the drift coefficient of the SDE.\n", 807 | " Args:\n", 808 | " - xt: state at time t, shape (bs, dim)\n", 809 | " - t: time, shape ()\n", 810 | " Returns:\n", 811 | " - drift: shape (bs, dim)\n", 812 | " \"\"\"\n", 813 | " return 0.5 * self.sigma ** 2 * self.density.score(xt)\n", 814 | "\n", 815 | " def diffusion_coefficient(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n", 816 | " \"\"\"\n", 817 | " Returns the diffusion coefficient of the SDE.\n", 818 | " Args:\n", 819 | " - xt: state at time t, shape (bs, dim)\n", 820 | " - t: time, shape ()\n", 821 | " Returns:\n", 822 | " - diffusion: shape (bs, dim)\n", 823 | " \"\"\"\n", 824 | " return self.sigma * torch.ones_like(xt)" 825 | ] 826 | }, 827 | { 828 | "cell_type": "markdown", 829 | "id": "52e2235a-befd-4f25-b0ac-849c9868cd5d", 830 | "metadata": {}, 831 | "source": [ 832 | "Now, let's graph the results!" 
833 | ] 834 | }, 835 | { 836 | "cell_type": "code", 837 | "execution_count": null, 838 | "id": "f7347543-533a-48c9-9c08-193898b139c9", 839 | "metadata": {}, 840 | "outputs": [], 841 | "source": [ 842 | "# First, let's define two utility functions...\n", 843 | "def every_nth_index(num_timesteps: int, n: int) -> torch.Tensor:\n", 844 | " \"\"\"\n", 845 | " Compute the indices to record in the trajectory given a record_every parameter\n", 846 | " \"\"\"\n", 847 | " if n == 1:\n", 848 | " return torch.arange(num_timesteps)\n", 849 | " return torch.cat(\n", 850 | " [\n", 851 | " torch.arange(0, num_timesteps - 1, n),\n", 852 | " torch.tensor([num_timesteps - 1]),\n", 853 | " ]\n", 854 | " )\n", 855 | "\n", 856 | "def graph_dynamics(\n", 857 | " num_samples: int,\n", 858 | " source_distribution: Sampleable,\n", 859 | " simulator: Simulator, \n", 860 | " density: Density,\n", 861 | " timesteps: torch.Tensor, \n", 862 | " plot_every: int,\n", 863 | " bins: int,\n", 864 | " scale: float\n", 865 | "):\n", 866 | " \"\"\"\n", 867 | " Plot the evolution of samples from source under the simulation scheme given by simulator (itself a discretization of an ODE or SDE).\n", 868 | " Args:\n", 869 | " - num_samples: the number of samples to simulate\n", 870 | " - source_distribution: distribution from which we draw initial samples at t=0\n", 871 | " - simulator: the discretized simulation scheme used to simulate the dynamics\n", 872 | " - density: the target density\n", 873 | " - timesteps: the timesteps used by the simulator\n", 874 | " - plot_every: number of timesteps between consecutive plots\n", 875 | " - bins: number of bins for imshow\n", 876 | " - scale: scale for imshow\n", 877 | " \"\"\"\n", 878 | " # Simulate\n", 879 | " x0 = source_distribution.sample(num_samples)\n", 880 | " xts = simulator.simulate_with_trajectory(x0, timesteps)\n", 881 | " indices_to_plot = every_nth_index(len(timesteps), plot_every)\n", 882 | " plot_timesteps = timesteps[indices_to_plot]\n", 883 | " 
plot_xts = xts[:,indices_to_plot]\n", 884 | "\n", 885 | " # Graph\n", 886 | " fig, axes = plt.subplots(2, len(plot_timesteps), figsize=(8*len(plot_timesteps), 16))\n", 887 | " axes = axes.reshape((2,len(plot_timesteps)))\n", 888 | " for t_idx in range(len(plot_timesteps)):\n", 889 | " t = plot_timesteps[t_idx].item()\n", 890 | " xt = plot_xts[:,t_idx]\n", 891 | " # Scatter axes\n", 892 | " scatter_ax = axes[0, t_idx]\n", 893 | " imshow_density(density, bins, scale, scatter_ax, vmin=-15, alpha=0.25, cmap=plt.get_cmap('Blues'))\n", 894 | " scatter_ax.scatter(xt[:,0].cpu(), xt[:,1].cpu(), marker='x', color='black', alpha=0.75, s=15)\n", 895 | " scatter_ax.set_title(f'Samples at t={t:.1f}', fontsize=15)\n", 896 | " scatter_ax.set_xticks([])\n", 897 | " scatter_ax.set_yticks([])\n", 898 | "\n", 899 | " # Kdeplot axes\n", 900 | " kdeplot_ax = axes[1, t_idx]\n", 901 | " imshow_density(density, bins, scale, kdeplot_ax, vmin=-15, alpha=0.5, cmap=plt.get_cmap('Blues'))\n", 902 | " sns.kdeplot(x=xt[:,0].cpu(), y=xt[:,1].cpu(), alpha=0.5, ax=kdeplot_ax,color='grey')\n", 903 | " kdeplot_ax.set_title(f'Density of Samples at t={t:.1f}', fontsize=15)\n", 904 | " kdeplot_ax.set_xticks([])\n", 905 | " kdeplot_ax.set_yticks([])\n", 906 | " kdeplot_ax.set_xlabel(\"\")\n", 907 | " kdeplot_ax.set_ylabel(\"\")\n", 908 | "\n", 909 | " plt.show()" 910 | ] 911 | }, 912 | { 913 | "cell_type": "code", 914 | "execution_count": null, 915 | "id": "f0b244a6-5b25-4b83-a4fb-ff14c29d5eb2", 916 | "metadata": {}, 917 | "outputs": [], 918 | "source": [ 919 | "# Construct the simulator\n", 920 | "target = GaussianMixture.random_2D(nmodes=5, std=0.75, scale=15.0, seed=3.0).to(device)\n", 921 | "sde = LangevinSDE(sigma = 0.6, density = target)\n", 922 | "simulator = EulerMaruyamaSimulator(sde)\n", 923 | "\n", 924 | "# Graph the results!\n", 925 | "graph_dynamics(\n", 926 | " num_samples = 1000,\n", 927 | " source_distribution = Gaussian(mean=torch.zeros(2), cov=20 * torch.eye(2)).to(device),\n", 928 | " 
simulator=simulator,\n", 929 | " density=target,\n", 930 | " timesteps=torch.linspace(0,5.0,1000).to(device),\n", 931 | " plot_every=334,\n", 932 | " bins=200,\n", 933 | " scale=15\n", 934 | ") " 935 | ] 936 | }, 937 | { 938 | "cell_type": "markdown", 939 | "id": "683d75ef-806a-4012-ae59-6cc7faa5eaf1", 940 | "metadata": {}, 941 | "source": [ 942 | "**Your job**: Try varying the value of $\\sigma$, the number and range of the simulation steps, the source distribution, and target density. What do you notice? Why?\n", 943 | "\n", 944 | "**Your answer**:" 945 | ] 946 | }, 947 | { 948 | "cell_type": "markdown", 949 | "id": "2e3d552e-0d1d-4cae-b5e6-23d270b7193f", 950 | "metadata": {}, 951 | "source": [ 952 | "Note: To run the following two **optional** cells, you will need to install `ffmpeg`. You can do so using e.g., `conda install -c conda-forge ffmpeg` (or, ideally, `mamba`). Running `pip install ffmpeg` or similar will likely **not** work." 953 | ] 954 | }, 955 | { 956 | "cell_type": "code", 957 | "execution_count": null, 958 | "id": "e98a53ea-d24e-4113-a190-e53d661bacec", 959 | "metadata": {}, 960 | "outputs": [], 961 | "source": [ 962 | "from celluloid import Camera\n", 963 | "from IPython.display import HTML\n", 964 | "\n", 965 | "def animate_dynamics(\n", 966 | " num_samples: int,\n", 967 | " source_distribution: Sampleable,\n", 968 | " simulator: Simulator, \n", 969 | " density: Density,\n", 970 | " timesteps: torch.Tensor, \n", 971 | " animate_every: int,\n", 972 | " bins: int,\n", 973 | " scale: float,\n", 974 | " save_path: str = 'dynamics_animation.mp4'\n", 975 | "):\n", 976 | " \"\"\"\n", 977 | " Plot the evolution of samples from source under the simulation scheme given by simulator (itself a discretization of an ODE or SDE).\n", 978 | " Args:\n", 979 | " - num_samples: the number of samples to simulate\n", 980 | " - source_distribution: distribution from which we draw initial samples at t=0\n", 981 | " - simulator: the discretized simulation 
scheme used to simulate the dynamics\n", 982 | " - density: the target density\n", 983 | " - timesteps: the timesteps used by the simulator\n", 984 | " - animate_every: number of timesteps between consecutive frames in the resulting animation\n", 985 | " \"\"\"\n", 986 | " # Simulate\n", 987 | " x0 = source_distribution.sample(num_samples)\n", 988 | " xts = simulator.simulate_with_trajectory(x0, timesteps)\n", 989 | " indices_to_animate = every_nth_index(len(timesteps), animate_every)\n", 990 | " animate_timesteps = timesteps[indices_to_animate]\n", 991 | " animate_xts = xts[:, indices_to_animate]\n", 992 | "\n", 993 | " # Graph\n", 994 | " fig, axes = plt.subplots(1, 2, figsize=(16, 8))\n", 995 | " camera = Camera(fig)\n", 996 | " for t_idx in range(len(animate_timesteps)):\n", 997 | " t = animate_timesteps[t_idx].item()\n", 998 | " xt = animate_xts[:,t_idx]\n", 999 | " # Scatter axes\n", 1000 | " scatter_ax = axes[0]\n", 1001 | " imshow_density(density, bins, scale, scatter_ax, vmin=-15, alpha=0.25, cmap=plt.get_cmap('Blues'))\n", 1002 | " scatter_ax.scatter(xt[:,0].cpu(), xt[:,1].cpu(), marker='x', color='black', alpha=0.75, s=15)\n", 1003 | " scatter_ax.set_title(f'Samples')\n", 1004 | "\n", 1005 | " # Kdeplot axes\n", 1006 | " kdeplot_ax = axes[1]\n", 1007 | " imshow_density(density, bins, scale, kdeplot_ax, vmin=-15, alpha=0.5, cmap=plt.get_cmap('Blues'))\n", 1008 | " sns.kdeplot(x=xt[:,0].cpu(), y=xt[:,1].cpu(), alpha=0.5, ax=kdeplot_ax,color='grey')\n", 1009 | " kdeplot_ax.set_title(f'Density of Samples', fontsize=15)\n", 1010 | " kdeplot_ax.set_xticks([])\n", 1011 | " kdeplot_ax.set_yticks([])\n", 1012 | " kdeplot_ax.set_xlabel(\"\")\n", 1013 | " kdeplot_ax.set_ylabel(\"\")\n", 1014 | " camera.snap()\n", 1015 | " \n", 1016 | " animation = camera.animate()\n", 1017 | " animation.save(save_path)\n", 1018 | " plt.close()\n", 1019 | " return HTML(animation.to_html5_video())" 1020 | ] 1021 | }, 1022 | { 1023 | "cell_type": "code", 1024 | "execution_count": 
null, 1025 | "id": "b1f4269e-9e6d-4a50-8353-3b380eae4ed7", 1026 | "metadata": {}, 1027 | "outputs": [], 1028 | "source": [ 1029 | "# OPTIONAL CELL\n", 1030 | "# Construct the simulator\n", 1031 | "target = GaussianMixture.random_2D(nmodes=5, std=0.75, scale=15.0, seed=3.0).to(device)\n", 1032 | "sde = LangevinSDE(sigma = 0.6, density = target)\n", 1033 | "simulator = EulerMaruyamaSimulator(sde)\n", 1034 | "\n", 1035 | "# Graph the results!\n", 1036 | "animate_dynamics(\n", 1037 | " num_samples = 1000,\n", 1038 | " source_distribution = Gaussian(mean=torch.zeros(2), cov=20 * torch.eye(2)).to(device),\n", 1039 | " simulator=simulator,\n", 1040 | " density=target,\n", 1041 | " timesteps=torch.linspace(0,5.0,1000).to(device),\n", 1042 | " bins=200,\n", 1043 | " scale=15,\n", 1044 | " animate_every=100\n", 1045 | ") " 1046 | ] 1047 | }, 1048 | { 1049 | "cell_type": "markdown", 1050 | "id": "d149c323-3c9e-40a7-8e00-15fe8b87f3f8", 1051 | "metadata": {}, 1052 | "source": [ 1053 | "### Question 3.2: Ornstein-Uhlenbeck as Langevin Dynamics\n", 1054 | "In this section, we'll finish off with a brief mathematical exercise connecting Langevin dynamics and Ornstein-Uhlenbeck processes. 
Recall that for a (suitably nice) distribution $p$, the *Langevin dynamics* are given by\n", 1055 | "$$dX_t = \\frac{1}{2} \\sigma^2\\nabla \\log p(X_t) dt + \\sigma\\, dW_t, \\quad \\quad X_0 = x_0,$$\n", 1056 | "while for given $\\theta, \\sigma$, the Ornstein-Uhlenbeck process is given by\n", 1057 | "$$ dX_t = -\\theta X_t\\, dt + \\sigma\\, dW_t, \\quad \\quad X_0 = x_0.$$" 1058 | ] 1059 | }, 1060 | { 1061 | "cell_type": "markdown", 1062 | "id": "86954c67-510b-4d10-aea1-b5636f4dbb47", 1063 | "metadata": {}, 1064 | "source": [ 1065 | "**Your job**: Show that when $p(x) = N(0, \\frac{\\sigma^2}{2\\theta})$, the score is given by $$\\nabla \\log p(x) = -\\frac{2\\theta}{\\sigma^2}x.$$\n", 1066 | "\n", 1067 | "**Hint**: The probability density of the Gaussian $p(x) = N(0, \\frac{\\sigma^2}{2\\theta})$ is given by $$p(x) = \\frac{\\sqrt{\\theta}}{\\sigma\\sqrt{\\pi}} \\exp\\left(-\\frac{x^2\\theta}{\\sigma^2}\\right).$$\n", 1068 | "\n", 1069 | "**Your answer**: From the hint,\n", 1070 | "$$\\log p(x) = - \\frac{x^2\\theta}{\\sigma^2} + C.$$\n", 1071 | "Thus, $\\nabla \\log p(x) = \\frac{d}{dx} \\log p(x)$ is given by \n", 1072 | "$$ \\frac{d}{dx} \\left(- \\frac{\\theta}{\\sigma^2}x^2\\right) = \\boxed{- \\frac{2\\theta}{\\sigma^2}x}.$$" 1073 | ] 1074 | }, 1075 | { 1076 | "cell_type": "markdown", 1077 | "id": "f3a0f761-c85c-4e30-b497-bc7622bc8e72", 1078 | "metadata": {}, 1079 | "source": [ 1080 | "**Your job**: Conclude that when $p(x) = N(0, \\frac{\\sigma^2}{2\\theta})$, the Langevin dynamics \n", 1081 | "$$dX_t = \\frac{1}{2} \\sigma^2\\nabla \\log p(X_t) dt + \\sigma dW_t,$$\n", 1082 | "is equivalent to the Ornstein-Uhlenbeck process\n", 1083 | "$$ dX_t = -\\theta X_t\\, dt + \\sigma\\, dW_t, \\quad \\quad X_0 = 0.$$\n", 1084 | "\n", 1085 | "**Your answer**: Substituting the score from the previous part into the Langevin drift gives $\\frac{1}{2}\\sigma^2 \\nabla \\log p(X_t) = \\frac{1}{2}\\sigma^2 \\left(-\\frac{2\\theta}{\\sigma^2} X_t\\right) = -\\theta X_t$, which is exactly the drift of the Ornstein-Uhlenbeck process; the diffusion term $\\sigma\\, dW_t$ is the same in both." 
1086 | ] 1087 | } 1088 | ], 1089 | "metadata": { 1090 | "kernelspec": { 1091 | "display_name": "mtds", 1092 | "language": "python", 1093 | "name": "mtds" 1094 | }, 1095 | "language_info": { 1096 | "codemirror_mode": { 1097 | "name": "ipython", 1098 | "version": 3 1099 | }, 1100 | "file_extension": ".py", 1101 | "mimetype": "text/x-python", 1102 | "name": "python", 1103 | "nbconvert_exporter": "python", 1104 | "pygments_lexer": "ipython3", 1105 | "version": "3.9.20" 1106 | } 1107 | }, 1108 | "nbformat": 4, 1109 | "nbformat_minor": 5 1110 | } 1111 | -------------------------------------------------------------------------------- /labs/lab_two.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "f4ca5863-5112-49cd-800c-61199d955cb6", 6 | "metadata": {}, 7 | "source": [ 8 | "# Lab Two: Flow Matching and Score Matching" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "ebb2a99a-6d00-4f94-aa23-a2146897321f", 14 | "metadata": {}, 15 | "source": [ 16 | "Welcome to lab two! In this lab, we will provide an intuitive and hands-on walk-through of *flow matching* and *score matching*. If you find any mistakes, or have any other feedback, please feel free to email us at `erives@mit.edu` and `phold@mit.edu`. Enjoy!" 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "id": "82a79a8b-a061-4c93-aedb-3d2269011f36", 22 | "metadata": {}, 23 | "source": [ 24 | "### Part 0: Miscellaneous Imports and Utility Functions\n", 25 | "No questions here, but feel free to read through to familiarize yourself with these helper functions. Most of this is what you already completed in lab one!" 
26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": null, 31 | "id": "e32fa50e-30d9-4048-9c8b-f0661aedeffe", 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "from abc import ABC, abstractmethod\n", 36 | "from typing import Optional, List, Type, Tuple, Dict\n", 37 | "import math\n", 38 | "\n", 39 | "import numpy as np\n", 40 | "from matplotlib import pyplot as plt\n", 41 | "import matplotlib.cm as cm\n", 42 | "from matplotlib.axes._axes import Axes\n", 43 | "import torch\n", 44 | "import torch.distributions as D\n", 45 | "from torch.func import vmap, jacrev\n", 46 | "from tqdm import tqdm\n", 47 | "import seaborn as sns\n", 48 | "from sklearn.datasets import make_moons, make_circles\n", 49 | "\n", 50 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "id": "ef0e04ac-f9a1-4be0-9e8c-162f85901207", 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "class Sampleable(ABC):\n", 61 | " \"\"\"\n", 62 | " Distribution which can be sampled from\n", 63 | " \"\"\"\n", 64 | " @property\n", 65 | " @abstractmethod\n", 66 | " def dim(self) -> int:\n", 67 | " \"\"\"\n", 68 | " Returns:\n", 69 | " - Dimensionality of the distribution\n", 70 | " \"\"\"\n", 71 | " pass\n", 72 | " \n", 73 | " @abstractmethod\n", 74 | " def sample(self, num_samples: int) -> torch.Tensor:\n", 75 | " \"\"\"\n", 76 | " Args:\n", 77 | " - num_samples: the desired number of samples\n", 78 | " Returns:\n", 79 | " - samples: shape (batch_size, dim)\n", 80 | " \"\"\"\n", 81 | " pass\n", 82 | "\n", 83 | "class Density(ABC):\n", 84 | " \"\"\"\n", 85 | " Distribution with tractable density\n", 86 | " \"\"\"\n", 87 | " @abstractmethod\n", 88 | " def log_density(self, x: torch.Tensor) -> torch.Tensor:\n", 89 | " \"\"\"\n", 90 | " Returns the log density at x.\n", 91 | " Args:\n", 92 | " - x: shape (batch_size, dim)\n", 93 | " Returns:\n", 94 | " - log_density: shape 
(batch_size, 1)\n", 95 | " \"\"\"\n", 96 | " pass\n", 97 | "\n", 98 | "class Gaussian(torch.nn.Module, Sampleable, Density):\n", 99 | " \"\"\"\n", 100 | " Multivariate Gaussian distribution\n", 101 | " \"\"\"\n", 102 | " def __init__(self, mean: torch.Tensor, cov: torch.Tensor):\n", 103 | " \"\"\"\n", 104 | " mean: shape (dim,)\n", 105 | " cov: shape (dim,dim)\n", 106 | " \"\"\"\n", 107 | " super().__init__()\n", 108 | " self.register_buffer(\"mean\", mean)\n", 109 | " self.register_buffer(\"cov\", cov)\n", 110 | "\n", 111 | " @property\n", 112 | " def dim(self) -> int:\n", 113 | " return self.mean.shape[0]\n", 114 | "\n", 115 | " @property\n", 116 | " def distribution(self):\n", 117 | " return D.MultivariateNormal(self.mean, self.cov, validate_args=False)\n", 118 | "\n", 119 | " def sample(self, num_samples) -> torch.Tensor:\n", 120 | " return self.distribution.sample((num_samples,))\n", 121 | " \n", 122 | " def log_density(self, x: torch.Tensor):\n", 123 | " return self.distribution.log_prob(x).view(-1, 1)\n", 124 | "\n", 125 | " @classmethod\n", 126 | " def isotropic(cls, dim: int, std: float) -> \"Gaussian\":\n", 127 | " mean = torch.zeros(dim)\n", 128 | " cov = torch.eye(dim) * std ** 2\n", 129 | " return cls(mean, cov)\n", 130 | "\n", 131 | "class GaussianMixture(torch.nn.Module, Sampleable, Density):\n", 132 | " \"\"\"\n", 133 | " Two-dimensional Gaussian mixture model, and is a Density and a Sampleable. 
Wrapper around torch.distributions.MixtureSameFamily.\n", 134 | " \"\"\"\n", 135 | " def __init__(\n", 136 | " self,\n", 137 | " means: torch.Tensor, # nmodes x data_dim\n", 138 | " covs: torch.Tensor, # nmodes x data_dim x data_dim\n", 139 | " weights: torch.Tensor, # nmodes\n", 140 | " ):\n", 141 | " \"\"\"\n", 142 | " means: shape (nmodes, 2)\n", 143 | " covs: shape (nmodes, 2, 2)\n", 144 | " weights: shape (nmodes,)\n", 145 | " \"\"\"\n", 146 | " super().__init__()\n", 147 | " self.nmodes = means.shape[0]\n", 148 | " self.register_buffer(\"means\", means)\n", 149 | " self.register_buffer(\"covs\", covs)\n", 150 | " self.register_buffer(\"weights\", weights)\n", 151 | "\n", 152 | " @property\n", 153 | " def dim(self) -> int:\n", 154 | " return self.means.shape[1]\n", 155 | "\n", 156 | " @property\n", 157 | " def distribution(self):\n", 158 | " return D.MixtureSameFamily(\n", 159 | " mixture_distribution=D.Categorical(probs=self.weights, validate_args=False),\n", 160 | " component_distribution=D.MultivariateNormal(\n", 161 | " loc=self.means,\n", 162 | " covariance_matrix=self.covs,\n", 163 | " validate_args=False,\n", 164 | " ),\n", 165 | " validate_args=False,\n", 166 | " )\n", 167 | "\n", 168 | " def log_density(self, x: torch.Tensor) -> torch.Tensor:\n", 169 | " return self.distribution.log_prob(x).view(-1, 1)\n", 170 | "\n", 171 | " def sample(self, num_samples: int) -> torch.Tensor:\n", 172 | " return self.distribution.sample(torch.Size((num_samples,)))\n", 173 | "\n", 174 | " @classmethod\n", 175 | " def random_2D(\n", 176 | " cls, nmodes: int, std: float, scale: float = 10.0, x_offset: float = 0.0, seed: int = 0\n", 177 | " ) -> \"GaussianMixture\":\n", 178 | " torch.manual_seed(seed)\n", 179 | " means = (torch.rand(nmodes, 2) - 0.5) * scale + x_offset * torch.Tensor([1.0, 0.0])\n", 180 | " covs = torch.diag_embed(torch.ones(nmodes, 2)) * std ** 2\n", 181 | " weights = torch.ones(nmodes) # unnormalized; D.Categorical normalizes probs\n", 182 | " return cls(means, covs, weights)\n", 183 | "\n", 184 | " 
@classmethod\n", 185 | " def symmetric_2D(\n", 186 | " cls, nmodes: int, std: float, scale: float = 10.0, x_offset: float = 0.0\n", 187 | " ) -> \"GaussianMixture\":\n", 188 | " angles = torch.linspace(0, 2 * np.pi, nmodes + 1)[:nmodes]\n", 189 | " means = torch.stack([torch.cos(angles), torch.sin(angles)], dim=1) * scale + torch.Tensor([1.0, 0.0]) * x_offset\n", 190 | " covs = torch.diag_embed(torch.ones(nmodes, 2) * std ** 2)\n", 191 | " weights = torch.ones(nmodes) / nmodes\n", 192 | " return cls(means, covs, weights)" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": null, 198 | "id": "037723a8-e212-46f5-ae75-997230282515", 199 | "metadata": {}, 200 | "outputs": [], 201 | "source": [ 202 | "# Several plotting utility functions\n", 203 | "def hist2d_samples(samples, ax: Optional[Axes] = None, bins: int = 200, scale: float = 5.0, percentile: int = 99, **kwargs):\n", 204 | " H, xedges, yedges = np.histogram2d(samples[:, 0], samples[:, 1], bins=bins, range=[[-scale, scale], [-scale, scale]])\n", 205 | " ax = plt.gca() if ax is None else ax # fall back to the current axes\n", 206 | " # Determine color normalization based on the given percentile (default: 99th)\n", 207 | " cmax = np.percentile(H, percentile)\n", 208 | " cmin = 0.0\n", 209 | " norm = cm.colors.Normalize(vmax=cmax, vmin=cmin)\n", 210 | " \n", 211 | " # Plot using imshow for more control\n", 212 | " extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]\n", 213 | " ax.imshow(H.T, extent=extent, origin='lower', norm=norm, **kwargs)\n", 214 | "\n", 215 | "def hist2d_sampleable(sampleable: Sampleable, num_samples: int, ax: Optional[Axes] = None, bins=200, scale: float = 5.0, percentile: int = 99, **kwargs):\n", 216 | " assert sampleable.dim == 2\n", 217 | " if ax is None:\n", 218 | " ax = plt.gca()\n", 219 | " samples = sampleable.sample(num_samples).detach().cpu() # (ns, 2)\n", 220 | " hist2d_samples(samples, ax, bins, scale, percentile, **kwargs)\n", 221 | "\n", 222 | "def scatter_sampleable(sampleable: Sampleable, num_samples: int, ax: Optional[Axes] = None, 
**kwargs):\n", 223 | " assert sampleable.dim == 2\n", 224 | " if ax is None:\n", 225 | " ax = plt.gca()\n", 226 | " samples = sampleable.sample(num_samples) # (ns, 2)\n", 227 | " ax.scatter(samples[:,0].cpu(), samples[:,1].cpu(), **kwargs)\n", 228 | "\n", 229 | "def kdeplot_sampleable(sampleable: Sampleable, num_samples: int, ax: Optional[Axes] = None, **kwargs):\n", 230 | " assert sampleable.dim == 2\n", 231 | " if ax is None:\n", 232 | " ax = plt.gca()\n", 233 | " samples = sampleable.sample(num_samples) # (ns, 2)\n", 234 | " sns.kdeplot(x=samples[:,0].cpu(), y=samples[:,1].cpu(), ax=ax, **kwargs)\n", 235 | "\n", 236 | "def imshow_density(density: Density, x_bounds: Tuple[float, float], y_bounds: Tuple[float, float], bins: int, ax: Optional[Axes] = None, x_offset: float = 0.0, **kwargs):\n", 237 | " if ax is None:\n", 238 | " ax = plt.gca()\n", 239 | " x_min, x_max = x_bounds\n", 240 | " y_min, y_max = y_bounds\n", 241 | " x = torch.linspace(x_min, x_max, bins).to(device) + x_offset\n", 242 | " y = torch.linspace(y_min, y_max, bins).to(device)\n", 243 | " X, Y = torch.meshgrid(x, y, indexing='ij') # explicit 'ij' matches the old default and silences the warning\n", 244 | " xy = torch.stack([X.reshape(-1), Y.reshape(-1)], dim=-1)\n", 245 | " density = density.log_density(xy).reshape(bins, bins).T\n", 246 | " im = ax.imshow(density.cpu(), extent=[x_min, x_max, y_min, y_max], origin='lower', **kwargs)\n", 247 | "\n", 248 | "def contour_density(density: Density, bins: int, scale: float, ax: Optional[Axes] = None, x_offset:float = 0.0, **kwargs):\n", 249 | " if ax is None:\n", 250 | " ax = plt.gca()\n", 251 | " x = torch.linspace(-scale + x_offset, scale + x_offset, bins).to(device)\n", 252 | " y = torch.linspace(-scale, scale, bins).to(device)\n", 253 | " X, Y = torch.meshgrid(x, y, indexing='ij') # explicit 'ij' matches the old default and silences the warning\n", 254 | " xy = torch.stack([X.reshape(-1), Y.reshape(-1)], dim=-1)\n", 255 | " density = density.log_density(xy).reshape(bins, bins).T\n", 256 | " im = ax.contour(density.cpu(), origin='lower', **kwargs)" 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | 
"execution_count": null, 262 | "id": "f47118d5-30d0-4374-81de-f47e3b96f6e3", 263 | "metadata": {}, 264 | "outputs": [], 265 | "source": [ 266 | "class ODE(ABC):\n", 267 | " @abstractmethod\n", 268 | " def drift_coefficient(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n", 269 | " \"\"\"\n", 270 | " Returns the drift coefficient of the ODE.\n", 271 | " Args:\n", 272 | " - xt: state at time t, shape (batch_size, dim)\n", 273 | " - t: time, shape (batch_size, 1)\n", 274 | " Returns:\n", 275 | " - drift_coefficient: shape (batch_size, dim)\n", 276 | " \"\"\"\n", 277 | " pass\n", 278 | "\n", 279 | "class SDE(ABC):\n", 280 | " @abstractmethod\n", 281 | " def drift_coefficient(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n", 282 | " \"\"\"\n", 283 | " Returns the drift coefficient of the SDE.\n", 284 | " Args:\n", 285 | " - xt: state at time t, shape (batch_size, dim)\n", 286 | " - t: time, shape (batch_size, 1)\n", 287 | " Returns:\n", 288 | " - drift_coefficient: shape (batch_size, dim)\n", 289 | " \"\"\"\n", 290 | " pass\n", 291 | "\n", 292 | " @abstractmethod\n", 293 | " def diffusion_coefficient(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n", 294 | " \"\"\"\n", 295 | " Returns the diffusion coefficient of the SDE.\n", 296 | " Args:\n", 297 | " - xt: state at time t, shape (batch_size, dim)\n", 298 | " - t: time, shape (batch_size, 1)\n", 299 | " Returns:\n", 300 | " - diffusion_coefficient: shape (batch_size, dim)\n", 301 | " \"\"\"\n", 302 | " pass" 303 | ] 304 | }, 305 | { 306 | "cell_type": "code", 307 | "execution_count": null, 308 | "id": "7f98a6a4-67a7-4740-837f-156aac725c2b", 309 | "metadata": {}, 310 | "outputs": [], 311 | "source": [ 312 | "class Simulator(ABC):\n", 313 | " @abstractmethod\n", 314 | " def step(self, xt: torch.Tensor, t: torch.Tensor, dt: torch.Tensor):\n", 315 | " \"\"\"\n", 316 | " Takes one simulation step\n", 317 | " Args:\n", 318 | " - xt: state at time t, shape (bs, dim)\n", 319 | " - t: time, shape 
(bs,1)\n", 320 | " - dt: time step, shape (bs,1)\n", 321 | " Returns:\n", 322 | " - nxt: state at time t + dt (bs, dim)\n", 323 | " \"\"\"\n", 324 | " pass\n", 325 | "\n", 326 | " @torch.no_grad()\n", 327 | " def simulate(self, x: torch.Tensor, ts: torch.Tensor):\n", 328 | " \"\"\"\n", 329 | " Simulates using the discretization given by ts\n", 330 | " Args:\n", 331 | " - x: initial state at time ts[:, 0], shape (batch_size, dim)\n", 332 | " - ts: timesteps, shape (bs, num_timesteps, 1)\n", 333 | " Returns:\n", 334 | " - x_final: final state at time ts[:, -1], shape (batch_size, dim)\n", 335 | " \"\"\"\n", 336 | " for t_idx in range(ts.shape[1] - 1): # iterate over timesteps (dim 1), not over the batch\n", 337 | " t = ts[:, t_idx]\n", 338 | " h = ts[:, t_idx + 1] - ts[:, t_idx]\n", 339 | " x = self.step(x, t, h)\n", 340 | " return x\n", 341 | "\n", 342 | " @torch.no_grad()\n", 343 | " def simulate_with_trajectory(self, x: torch.Tensor, ts: torch.Tensor):\n", 344 | " \"\"\"\n", 345 | " Simulates using the discretization given by ts\n", 346 | " Args:\n", 347 | " - x: initial state at time ts[:, 0], shape (bs, dim)\n", 348 | " - ts: timesteps, shape (bs, num_timesteps, 1)\n", 349 | " Returns:\n", 350 | " - xs: trajectory of xts over ts,\n", 351 | " shape (batch_size, num_timesteps, dim)\n", 352 | " \"\"\"\n", 353 | " xs = [x.clone()]\n", 354 | " nts = ts.shape[1]\n", 355 | " for t_idx in tqdm(range(nts - 1)):\n", 356 | " t = ts[:,t_idx]\n", 357 | " h = ts[:, t_idx + 1] - ts[:, t_idx]\n", 358 | " x = self.step(x, t, h)\n", 359 | " xs.append(x.clone())\n", 360 | " return torch.stack(xs, dim=1)\n", 361 | "\n", 362 | "class EulerSimulator(Simulator):\n", 363 | " def __init__(self, ode: ODE):\n", 364 | " self.ode = ode\n", 365 | " \n", 366 | " def step(self, xt: torch.Tensor, t: torch.Tensor, h: torch.Tensor):\n", 367 | " return xt + self.ode.drift_coefficient(xt,t) * h\n", 368 | "\n", 369 | "class EulerMaruyamaSimulator(Simulator):\n", 370 | " def __init__(self, sde: SDE):\n", 371 | " self.sde = sde\n", 372 | " \n", 373 | " def 
step(self, xt: torch.Tensor, t: torch.Tensor, h: torch.Tensor):\n", 374 | " return xt + self.sde.drift_coefficient(xt,t) * h + self.sde.diffusion_coefficient(xt,t) * torch.sqrt(h) * torch.randn_like(xt)\n", 375 | "\n", 376 | "def record_every(num_timesteps: int, record_every: int) -> torch.Tensor:\n", 377 | " \"\"\"\n", 378 | " Compute the indices to record in the trajectory given a record_every parameter\n", 379 | " \"\"\"\n", 380 | " if record_every == 1:\n", 381 | " return torch.arange(num_timesteps)\n", 382 | " return torch.cat(\n", 383 | " [\n", 384 | " torch.arange(0, num_timesteps - 1, record_every),\n", 385 | " torch.tensor([num_timesteps - 1]),\n", 386 | " ]\n", 387 | " )" 388 | ] 389 | }, 390 | { 391 | "cell_type": "markdown", 392 | "id": "9b91789e-d50d-4aef-9d98-4188aa20ceed", 393 | "metadata": {}, 394 | "source": [ 395 | "### Part 1: Implementing Conditional Probability Paths\n", 396 | "Recall from lecture and the class notes the basic premise of conditional flow matching: describe a *conditional probability path* $p_t(x|z)$ such that $p_1(x|z) = \\delta_z(x)$, $p_0(x|z) = p_{\\text{simple}}$ (e.g., a Gaussian), and $p_t(x|z)$ interpolates continuously (we are not being rigorous here) between $p_0(x|z)$ and $p_1(x|z)$. Such a conditional path can be seen as corresponding to some corruption process which (in reverse time) drives the point $z$ at $t=1$ to be distributed as $p_0(x|z)$ at time $t=0$. Such a corruption process is given by the ODE\n", 397 | "$$dX_t = u_t^{\\text{ref}}(X_t|z)\\,dt,\\quad \\quad X_0 \\sim p_{\\text{simple}}.$$\n", 398 | "The drift $u_t^{\\text{ref}}(X_t|z)$ is referred to as the *conditional vector field*. By averaging $u_t^{\\text{ref}}(x|z)$ over all such choices of $z$, we obtain the *marginal* vector field $u_t^{\\text{ref}}(x)$. 
Flow matching proposes to exploit the fact that the *marginal probability path* $p_t(x)$ generated by the marginal vector field $u_t^{\\text{ref}}(x)$ bridges $p_{\\text{simple}}$ to $p_{\\text{data}}$. Since the conditional vector field $u_t^{\\text{ref}}(x|z)$ is often analytically available, we may implicitly regress against the unknown marginal vector field $u_t^{\\text{ref}}(x)$ by explicitly regressing against the conditional vector field $u_t^{\\text{ref}}(x|z)$." 399 | ] 400 | }, 401 | { 402 | "cell_type": "markdown", 403 | "id": "8ca98cf8-4eae-4d3a-8f88-6424b80b06e5", 404 | "metadata": {}, 405 | "source": [ 406 | "The central object in this construction is a *conditional probability path*, whose interface is implemented below in the class `ConditionalProbabilityPath`. In this lab, you will implement two subclasses, `GaussianConditionalProbabilityPath` and `LinearConditionalProbabilityPath`, corresponding to the probability paths of the same names from the lectures and notes." 407 | ] 408 | }, 409 | { 410 | "cell_type": "code", 411 | "execution_count": null, 412 | "id": "4706df1f-33d0-4484-b702-99282fff23cf", 413 | "metadata": {}, 414 | "outputs": [], 415 | "source": [ 416 | "class ConditionalProbabilityPath(torch.nn.Module, ABC):\n", 417 | " \"\"\"\n", 418 | " Abstract base class for conditional probability paths\n", 419 | " \"\"\"\n", 420 | " def __init__(self, p_simple: Sampleable, p_data: Sampleable):\n", 421 | " super().__init__()\n", 422 | " self.p_simple = p_simple\n", 423 | " self.p_data = p_data\n", 424 | "\n", 425 | " def sample_marginal_path(self, t: torch.Tensor) -> torch.Tensor:\n", 426 | " \"\"\"\n", 427 | " Samples from the marginal distribution p_t(x) = E_{z ~ p(z)}[p_t(x|z)]\n", 428 | " Args:\n", 429 | " - t: time (num_samples, 1)\n", 430 | " Returns:\n", 431 | " - x: samples from p_t(x), (num_samples, dim)\n", 432 | " \"\"\"\n", 433 | " num_samples = t.shape[0]\n", 434 | " # Sample conditioning variable z ~ p(z)\n", 435 | " z = 
self.sample_conditioning_variable(num_samples) # (num_samples, dim)\n", 436 | " # Sample conditional probability path x ~ p_t(x|z)\n", 437 | " x = self.sample_conditional_path(z, t) # (num_samples, dim)\n", 438 | " return x\n", 439 | "\n", 440 | " @abstractmethod\n", 441 | " def sample_conditioning_variable(self, num_samples: int) -> torch.Tensor:\n", 442 | " \"\"\"\n", 443 | " Samples the conditioning variable z\n", 444 | " Args:\n", 445 | " - num_samples: the number of samples\n", 446 | " Returns:\n", 447 | " - z: samples from p(z), (num_samples, dim)\n", 448 | " \"\"\"\n", 449 | " pass\n", 450 | " \n", 451 | " @abstractmethod\n", 452 | " def sample_conditional_path(self, z: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n", 453 | " \"\"\"\n", 454 | " Samples from the conditional distribution p_t(x|z)\n", 455 | " Args:\n", 456 | " - z: conditioning variable (num_samples, dim)\n", 457 | " - t: time (num_samples, 1)\n", 458 | " Returns:\n", 459 | " - x: samples from p_t(x|z), (num_samples, dim)\n", 460 | " \"\"\"\n", 461 | " pass\n", 462 | " \n", 463 | " @abstractmethod\n", 464 | " def conditional_vector_field(self, x: torch.Tensor, z: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n", 465 | " \"\"\"\n", 466 | " Evaluates the conditional vector field u_t(x|z)\n", 467 | " Args:\n", 468 | " - x: position variable (num_samples, dim)\n", 469 | " - z: conditioning variable (num_samples, dim)\n", 470 | " - t: time (num_samples, 1)\n", 471 | " Returns:\n", 472 | " - conditional_vector_field: conditional vector field (num_samples, dim)\n", 473 | " \"\"\" \n", 474 | " pass\n", 475 | "\n", 476 | " @abstractmethod\n", 477 | " def conditional_score(self, x: torch.Tensor, z: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n", 478 | " \"\"\"\n", 479 | " Evaluates the conditional score of p_t(x|z)\n", 480 | " Args:\n", 481 | " - x: position variable (num_samples, dim)\n", 482 | " - z: conditioning variable (num_samples, dim)\n", 483 | " - t: time (num_samples, 1)\n", 484 | " 
Returns:\n", 485 | " - conditional_score: conditional score (num_samples, dim)\n", 486 | " \"\"\" \n", 487 | " pass" 488 | ] 489 | }, 490 | { 491 | "cell_type": "markdown", 492 | "id": "7351c076-f555-4068-9337-d3d0311bc6de", 493 | "metadata": {}, 494 | "source": [ 495 | "# Part 2: Gaussian Conditional Probability Paths\n", 496 | "In this section, we'll implement a **Gaussian conditional probability path** via the class `GaussianConditionalProbabilityPath`. We will then use it to transform a simple source $p_{\\text{simple}} = N(0, I_d)$ into a Gaussian mixture $p_{\\text{data}}$. Later, we'll experiment with more exciting distributions. Recall that a Gaussian conditional probability path is given by\n", 497 | "$$p_t(x|z) = N(x;\\alpha_t z,\\beta_t^2 I_d),\\quad\\quad\\quad p_{\\text{simple}}=N(0,I_d),$$\n", 498 | "where $\\alpha_t: [0,1] \\to \\mathbb{R}$ and $\\beta_t: [0,1] \\to \\mathbb{R}$ are monotonic, continuously differentiable functions satisfying $\\alpha_1 = \\beta_0 = 1$ and $\\alpha_0 = \\beta_1 = 0$. In other words, this implies that $p_1(x|z) = \\delta_z$ and $p_0(x|z) = N(0, I_d)$ is a unit Gaussian. Before we dive into things, let's take a look at $p_{\\text{simple}}$ and $p_{\\text{data}}$. 
" 499 | ] 500 | }, 501 | { 502 | "cell_type": "code", 503 | "execution_count": null, 504 | "id": "d4d99507-2f5d-44df-8aca-cbcd9940df34", 505 | "metadata": {}, 506 | "outputs": [], 507 | "source": [ 508 | "# Constants for the duration of our use of Gaussian conditional probability paths, to avoid polluting the namespace...\n", 509 | "PARAMS = {\n", 510 | " \"scale\": 15.0,\n", 511 | " \"target_scale\": 10.0,\n", 512 | " \"target_std\": 1.0,\n", 513 | "}" 514 | ] 515 | }, 516 | { 517 | "cell_type": "code", 518 | "execution_count": null, 519 | "id": "56ab35ea-7c63-4bb7-ac35-48d92cc56517", 520 | "metadata": {}, 521 | "outputs": [], 522 | "source": [ 523 | "p_simple = Gaussian.isotropic(dim=2, std = 1.0).to(device)\n", 524 | "p_data = GaussianMixture.symmetric_2D(nmodes=5, std=PARAMS[\"target_std\"], scale=PARAMS[\"target_scale\"]).to(device)\n", 525 | "\n", 526 | "fig, axes = plt.subplots(1,3, figsize=(24,8))\n", 527 | "bins = 200\n", 528 | "\n", 529 | "scale = PARAMS[\"scale\"]\n", 530 | "x_bounds = [-scale,scale]\n", 531 | "y_bounds = [-scale,scale]\n", 532 | "\n", 533 | "axes[0].set_title('Heatmap of p_simple')\n", 534 | "axes[0].set_xticks([])\n", 535 | "axes[0].set_yticks([])\n", 536 | "imshow_density(density=p_simple, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=axes[0], vmin=-10, alpha=0.25, cmap=plt.get_cmap('Reds'))\n", 537 | "\n", 538 | "\n", 539 | "axes[1].set_title('Heatmap of p_data')\n", 540 | "axes[1].set_xticks([])\n", 541 | "axes[1].set_yticks([])\n", 542 | "imshow_density(density=p_data, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=axes[1], vmin=-10, alpha=0.25, cmap=plt.get_cmap('Blues'))\n", 543 | "\n", 544 | "axes[2].set_title('Heatmap of p_simple and p_data')\n", 545 | "axes[2].set_xticks([])\n", 546 | "axes[2].set_yticks([])\n", 547 | "imshow_density(density=p_simple, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Reds'))\n", 548 | "imshow_density(density=p_data, x_bounds=x_bounds, 
y_bounds=y_bounds, bins=200, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Blues'))" 549 | ] 550 | }, 551 | { 552 | "cell_type": "markdown", 553 | "id": "574e24a6-59f7-4ea5-b7d8-efc1106a412a", 554 | "metadata": {}, 555 | "source": [ 556 | "### Problem 2.1: Implementing $\\alpha_t$ and $\\beta_t$ " 557 | ] 558 | }, 559 | { 560 | "cell_type": "markdown", 561 | "id": "2690fc06-d90e-4505-aff4-01d5e8279465", 562 | "metadata": {}, 563 | "source": [ 564 | "Let's get started by implementing $\\alpha_t$ and $\\beta_t$. We can think of these simply as callable objects which fulfill the simple contract $\\alpha_1 = \\beta_0 = 1$ and $\\alpha_0 = \\beta_1 = 0$, and which can compute their time derivatives $\\dot{\\alpha}_t$ and $\\dot{\\beta}_t$. We implement them below via the classes `Alpha` and `Beta`." 565 | ] 566 | }, 567 | { 568 | "cell_type": "code", 569 | "execution_count": null, 570 | "id": "f28999de-9f02-4439-a462-39965d932eb3", 571 | "metadata": {}, 572 | "outputs": [], 573 | "source": [ 574 | "class Alpha(ABC):\n", 575 | " def __init__(self):\n", 576 | " # Check alpha_t(0) = 0\n", 577 | " assert torch.allclose(\n", 578 | " self(torch.zeros(1,1)), torch.zeros(1,1)\n", 579 | " )\n", 580 | " # Check alpha_1 = 1\n", 581 | " assert torch.allclose(\n", 582 | " self(torch.ones(1,1)), torch.ones(1,1)\n", 583 | " )\n", 584 | " \n", 585 | " @abstractmethod\n", 586 | " def __call__(self, t: torch.Tensor) -> torch.Tensor:\n", 587 | " \"\"\"\n", 588 | " Evaluates alpha_t. 
Should satisfy: self(0.0) = 0.0, self(1.0) = 1.0.\n", 589 | " Args:\n", 590 | " - t: time (num_samples, 1)\n", 591 | " Returns:\n", 592 | " - alpha_t (num_samples, 1)\n", 593 | " \"\"\" \n", 594 | " pass\n", 595 | "\n", 596 | " def dt(self, t: torch.Tensor) -> torch.Tensor:\n", 597 | " \"\"\"\n", 598 | " Evaluates d/dt alpha_t.\n", 599 | " Args:\n", 600 | " - t: time (num_samples, 1)\n", 601 | " Returns:\n", 602 | " - d/dt alpha_t (num_samples, 1)\n", 603 | " \"\"\" \n", 604 | " t = t.unsqueeze(1) # (num_samples, 1, 1)\n", 605 | " dt = vmap(jacrev(self))(t) # (num_samples, 1, 1, 1, 1)\n", 606 | " return dt.view(-1, 1)\n", 607 | " \n", 608 | "class Beta(ABC):\n", 609 | " def __init__(self):\n", 610 | " # Check beta_0 = 1\n", 611 | " assert torch.allclose(\n", 612 | " self(torch.zeros(1,1)), torch.ones(1,1)\n", 613 | " )\n", 614 | " # Check beta_1 = 0\n", 615 | " assert torch.allclose(\n", 616 | " self(torch.ones(1,1)), torch.zeros(1,1)\n", 617 | " )\n", 618 | " \n", 619 | " @abstractmethod\n", 620 | " def __call__(self, t: torch.Tensor) -> torch.Tensor:\n", 621 | " \"\"\"\n", 622 | " Evaluates beta_t. 
Should satisfy: self(0.0) = 1.0, self(1.0) = 0.0.\n", 623 | " Args:\n", 624 | " - t: time (num_samples, 1)\n", 625 | " Returns:\n", 626 | " - beta_t (num_samples, 1)\n", 627 | " \"\"\" \n", 628 | " pass \n", 629 | "\n", 630 | " def dt(self, t: torch.Tensor) -> torch.Tensor:\n", 631 | " \"\"\"\n", 632 | " Evaluates d/dt beta_t.\n", 633 | " Args:\n", 634 | " - t: time (num_samples, 1)\n", 635 | " Returns:\n", 636 | " - d/dt beta_t (num_samples, 1)\n", 637 | " \"\"\" \n", 638 | " t = t.unsqueeze(1) # (num_samples, 1, 1)\n", 639 | " dt = vmap(jacrev(self))(t) # (num_samples, 1, 1, 1, 1)\n", 640 | " return dt.view(-1, 1)" 641 | ] 642 | }, 643 | { 644 | "cell_type": "markdown", 645 | "id": "205ab204-8f54-4011-88f7-d6765c1ae4e4", 646 | "metadata": {}, 647 | "source": [ 648 | "In this section, we'll be using $$\\alpha_t = t \\quad \\quad \\text{and} \\quad \\quad \\beta_t = \\sqrt{1-t}.$$ It is not hard to check that both functions are continuously differentiable on $[0,1)$, and monotonic, that $\\alpha_1 = \\beta_0 = 1$, and that $\\alpha_0 = \\beta_1 = 0$.\n", 649 | "\n", 650 | "**Your job**: Implement the `__call__` methods of the classes `LinearAlpha` and `SquareRootBeta` below." 
651 | ] 652 | }, 653 | { 654 | "cell_type": "code", 655 | "execution_count": null, 656 | "id": "39faaaef-d2af-4617-b172-c1191761e129", 657 | "metadata": {}, 658 | "outputs": [], 659 | "source": [ 660 | "class LinearAlpha(Alpha):\n", 661 | " \"\"\"\n", 662 | " Implements alpha_t = t\n", 663 | " \"\"\"\n", 664 | " \n", 665 | " def __call__(self, t: torch.Tensor) -> torch.Tensor:\n", 666 | " \"\"\"\n", 667 | " Args:\n", 668 | " - t: time (num_samples, 1)\n", 669 | " Returns:\n", 670 | " - alpha_t (num_samples, 1)\n", 671 | " \"\"\" \n", 672 | " raise NotImplementedError(\"Fill me in for Question 2.1!\")\n", 673 | " \n", 674 | " def dt(self, t: torch.Tensor) -> torch.Tensor:\n", 675 | " \"\"\"\n", 676 | " Evaluates d/dt alpha_t.\n", 677 | " Args:\n", 678 | " - t: time (num_samples, 1)\n", 679 | " Returns:\n", 680 | " - d/dt alpha_t (num_samples, 1)\n", 681 | " \"\"\" \n", 682 | " return torch.ones_like(t)\n", 683 | "\n", 684 | "class SquareRootBeta(Beta):\n", 685 | " \"\"\"\n", 686 | " Implements beta_t = sqrt(1-t)\n", 687 | " \"\"\"\n", 688 | " def __call__(self, t: torch.Tensor) -> torch.Tensor:\n", 689 | " \"\"\"\n", 690 | " Args:\n", 691 | " - t: time (num_samples, 1)\n", 692 | " Returns:\n", 693 | " - beta_t (num_samples, 1)\n", 694 | " \"\"\" \n", 695 | " raise NotImplementedError(\"Fill me in for Question 2.1!\")\n", 696 | "\n", 697 | " def dt(self, t: torch.Tensor) -> torch.Tensor:\n", 698 | " \"\"\"\n", 699 | " Evaluates d/dt beta_t.\n", 700 | " Args:\n", 701 | " - t: time (num_samples, 1)\n", 702 | " Returns:\n", 703 | " - d/dt beta_t (num_samples, 1)\n", 704 | " \"\"\" \n", 705 | " return - 0.5 / (torch.sqrt(1 - t) + 1e-4)" 706 | ] 707 | }, 708 | { 709 | "cell_type": "markdown", 710 | "id": "b9c2bfe7-2bab-4a69-99fa-1553ddfcd93c", 711 | "metadata": {}, 712 | "source": [ 713 | "Let us now turn to the task of implementing the `GaussianConditionalProbabilityPath` class. 
" 714 | ] 715 | }, 716 | { 717 | "cell_type": "code", 718 | "execution_count": null, 719 | "id": "4cc2084d-68ab-420e-adcd-ca8dfff34f74", 720 | "metadata": {}, 721 | "outputs": [], 722 | "source": [ 723 | "class GaussianConditionalProbabilityPath(ConditionalProbabilityPath):\n", 724 | " def __init__(self, p_data: Sampleable, alpha: Alpha, beta: Beta):\n", 725 | " p_simple = Gaussian.isotropic(p_data.dim, 1.0)\n", 726 | " super().__init__(p_simple, p_data)\n", 727 | " self.alpha = alpha\n", 728 | " self.beta = beta\n", 729 | "\n", 730 | " def sample_conditioning_variable(self, num_samples: int) -> torch.Tensor:\n", 731 | " \"\"\"\n", 732 | " Samples the conditioning variable z ~ p_data(x)\n", 733 | " Args:\n", 734 | " - num_samples: the number of samples\n", 735 | " Returns:\n", 736 | " - z: samples from p(z), (num_samples, dim)\n", 737 | " \"\"\"\n", 738 | " return self.p_data.sample(num_samples) # use this path's own p_data, not the notebook-level global\n", 739 | " \n", 740 | " def sample_conditional_path(self, z: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n", 741 | " \"\"\"\n", 742 | " Samples from the conditional distribution p_t(x|z) = N(alpha_t * z, beta_t**2 * I_d)\n", 743 | " Args:\n", 744 | " - z: conditioning variable (num_samples, dim)\n", 745 | " - t: time (num_samples, 1)\n", 746 | " Returns:\n", 747 | " - x: samples from p_t(x|z), (num_samples, dim)\n", 748 | " \"\"\"\n", 749 | " raise NotImplementedError(\"Fill me in for Question 2.2!\")\n", 750 | " \n", 751 | " def conditional_vector_field(self, x: torch.Tensor, z: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n", 752 | " \"\"\"\n", 753 | " Evaluates the conditional vector field u_t(x|z)\n", 754 | " Note: Only defined on t in [0,1)\n", 755 | " Args:\n", 756 | " - x: position variable (num_samples, dim)\n", 757 | " - z: conditioning variable (num_samples, dim)\n", 758 | " - t: time (num_samples, 1)\n", 759 | " Returns:\n", 760 | " - conditional_vector_field: conditional vector field (num_samples, dim)\n", 761 | " \"\"\" \n", 762 | " raise 
NotImplementedError(\"Fill me in for Question 2.3!\")\n", 763 | "\n", 764 | " def conditional_score(self, x: torch.Tensor, z: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n", 765 | " \"\"\"\n", 766 | " Evaluates the conditional score of p_t(x|z) = N(alpha_t * z, beta_t**2 * I_d)\n", 767 | " Note: Only defined on t in [0,1)\n", 768 | " Args:\n", 769 | " - x: position variable (num_samples, dim)\n", 770 | " - z: conditioning variable (num_samples, dim)\n", 771 | " - t: time (num_samples, 1)\n", 772 | " Returns:\n", 773 | " - conditional_score: conditional score (num_samples, dim)\n", 774 | " \"\"\" \n", 775 | " raise NotImplementedError(\"Fill me in for Question 2.4!\")" 776 | ] 777 | }, 778 | { 779 | "cell_type": "markdown", 780 | "id": "6ae68e0c-b534-41ca-8f05-86d3994ea34d", 781 | "metadata": {}, 782 | "source": [ 783 | "### Problem 2.2: Gaussian Conditional Probability Path" 784 | ] 785 | }, 786 | { 787 | "cell_type": "markdown", 788 | "id": "7f362c33-10fb-440d-aede-48486004c4b8", 789 | "metadata": {}, 790 | "source": [ 791 | "**Your work**: Implement the method `sample_conditional_path` to sample from the conditional distribution $p_t(x|z) = N(x;\\alpha_t z,\\beta_t^2 I_d)$. You can check the correctness of your implementation by running the next two cells to generate an image of the conditional probability path and comparing it to the corresponding plot from Figure 6 in the lecture notes (the one labeled \"Ground-Truth Conditional Probability Path\").\n", 792 | "\n", 793 | "**Hint**: You may use the fact that the random variable $X \\sim N(\\mu, \\sigma^2 I_d)$ is obtained via $X = \\mu + \\sigma Z$, where $Z \\sim N(0, I_d)$." 794 | ] 795 | }, 796 | { 797 | "cell_type": "markdown", 798 | "id": "3b3744cf-1ba6-438c-9e53-25e823ce696f", 799 | "metadata": {}, 800 | "source": [ 801 | "We can now sample from, and thus visualize, the *conditional* probability path." 
802 | ] 803 | }, 804 | { 805 | "cell_type": "code", 806 | "execution_count": null, 807 | "id": "a8d49b40-083d-4318-b028-a1f32fe839bb", 808 | "metadata": {}, 809 | "outputs": [], 810 | "source": [ 811 | "# Construct conditional probability path\n", 812 | "path = GaussianConditionalProbabilityPath(\n", 813 | " p_data = GaussianMixture.symmetric_2D(nmodes=5, std=PARAMS[\"target_std\"], scale=PARAMS[\"target_scale\"]).to(device), \n", 814 | " alpha = LinearAlpha(),\n", 815 | " beta = SquareRootBeta()\n", 816 | ").to(device)\n", 817 | "\n", 818 | "scale = PARAMS[\"scale\"]\n", 819 | "x_bounds = [-scale,scale]\n", 820 | "y_bounds = [-scale,scale]\n", 821 | "\n", 822 | "plt.figure(figsize=(10,10))\n", 823 | "plt.xlim(*x_bounds)\n", 824 | "plt.ylim(*y_bounds)\n", 825 | "plt.title('Gaussian Conditional Probability Path')\n", 826 | "\n", 827 | "# Plot source and target\n", 828 | "imshow_density(density=p_simple, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Reds'))\n", 829 | "imshow_density(density=p_data, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Blues'))\n", 830 | "\n", 831 | "# Sample conditioning variable z\n", 832 | "z = path.sample_conditioning_variable(1) # (1,2)\n", 833 | "ts = torch.linspace(0.0, 1.0, 7).to(device)\n", 834 | "\n", 835 | "# Plot z\n", 836 | "plt.scatter(z[:,0].cpu(), z[:,1].cpu(), marker='*', color='red', s=75, label='z')\n", 837 | "plt.xticks([])\n", 838 | "plt.yticks([])\n", 839 | "\n", 840 | "# Plot conditional probability path at each intermediate t\n", 841 | "num_samples = 1000\n", 842 | "for t in ts:\n", 843 | " zz = z.expand(num_samples, 2)\n", 844 | " tt = t.unsqueeze(0).expand(num_samples, 1) # (samples, 1)\n", 845 | " samples = path.sample_conditional_path(zz, tt) # (samples, 2)\n", 846 | " plt.scatter(samples[:,0].cpu(), samples[:,1].cpu(), alpha=0.25, s=8, label=f't={t.item():.1f}')\n", 847 | "\n", 848 | "plt.legend(prop={'size': 18}, 
markerscale=3)\n", 849 | "plt.show()" 850 | ] 851 | }, 852 | { 853 | "cell_type": "markdown", 854 | "id": "ec77b226-8faf-4fd6-80f9-f2d151d529d0", 855 | "metadata": {}, 856 | "source": [ 857 | "### Problem 2.3: Conditional Vector Field\n", 858 | "From lecture and the notes, we know that the conditional vector field $u_t(x|z)$ is given by\n", 859 | "$$u_t(x|z) = \\left(\\dot{\\alpha}_t-\\frac{\\dot{\\beta}_t}{\\beta_t}\\alpha_t\\right)z+\\frac{\\dot{\\beta}_t}{\\beta_t}x.$$" 860 | ] 861 | }, 862 | { 863 | "cell_type": "markdown", 864 | "id": "76df0399-be66-4bcb-9095-03db72a6298e", 865 | "metadata": {}, 866 | "source": [ 867 | "**Your work**: Implement the class method `conditional_vector_field` to compute the conditional vector field $u_t(x|z)$.\n", 868 | "\n", 869 | "**Hint**: You can compute $\\dot{\\alpha}_t$ with `self.alpha.dt(t)`, which has been implemented for you. You may compute $\\dot{\\beta}_t$ similarly." 870 | ] 871 | }, 872 | { 873 | "cell_type": "markdown", 874 | "id": "65dbeee3-36a7-4004-adf3-828bcfa66032", 875 | "metadata": {}, 876 | "source": [ 877 | "We may now visualize the conditional trajectories corresponding to the ODE $$d X_t = u_t(X_t|z)dt, \\quad \\quad X_0 = x_0 \\sim p_{\\text{simple}}.$$" 878 | ] 879 | }, 880 | { 881 | "cell_type": "code", 882 | "execution_count": null, 883 | "id": "b5aee7e7-a730-4270-85d8-f860ca162830", 884 | "metadata": {}, 885 | "outputs": [], 886 | "source": [ 887 | "class ConditionalVectorFieldODE(ODE):\n", 888 | " def __init__(self, path: ConditionalProbabilityPath, z: torch.Tensor):\n", 889 | " \"\"\"\n", 890 | " Args:\n", 891 | " - path: the ConditionalProbabilityPath object to which this vector field corresponds\n", 892 | " - z: the conditioning variable, (1, dim)\n", 893 | " \"\"\"\n", 894 | " super().__init__()\n", 895 | " self.path = path\n", 896 | " self.z = z\n", 897 | "\n", 898 | " def drift_coefficient(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n", 899 | " \"\"\"\n", 900 | " Returns the 
conditional vector field u_t(x|z)\n", 901 | " Args:\n", 902 | " - x: state at time t, shape (bs, dim)\n", 903 | " - t: time, shape (bs,.)\n", 904 | " Returns:\n", 905 | " - u_t(x|z): shape (batch_size, dim)\n", 906 | " \"\"\"\n", 907 | " bs = x.shape[0]\n", 908 | " z = self.z.expand(bs, *self.z.shape[1:])\n", 909 | " return self.path.conditional_vector_field(x,z,t)" 910 | ] 911 | }, 912 | { 913 | "cell_type": "code", 914 | "execution_count": null, 915 | "id": "0f7b9b84-eb6d-4323-80ab-dc41b6281cb4", 916 | "metadata": {}, 917 | "outputs": [], 918 | "source": [ 919 | "# Run me for Problem 2.3!\n", 920 | "\n", 921 | "#######################\n", 922 | "# Change these values #\n", 923 | "#######################\n", 924 | "num_samples = 1000\n", 925 | "num_timesteps = 1000\n", 926 | "num_marginals = 3\n", 927 | "\n", 928 | "########################\n", 929 | "# Setup path and plot #\n", 930 | "########################\n", 931 | "\n", 932 | "path = GaussianConditionalProbabilityPath(\n", 933 | " p_data = GaussianMixture.symmetric_2D(nmodes=5, std=PARAMS[\"target_std\"], scale=PARAMS[\"target_scale\"]).to(device), \n", 934 | " alpha = LinearAlpha(),\n", 935 | " beta = SquareRootBeta()\n", 936 | ").to(device)\n", 937 | "\n", 938 | "\n", 939 | "# Setup figure\n", 940 | "fig, axes = plt.subplots(1,3, figsize=(36, 12))\n", 941 | "scale = PARAMS[\"scale\"]\n", 942 | "legend_size = 24\n", 943 | "markerscale = 1.8\n", 944 | "x_bounds = [-scale,scale]\n", 945 | "y_bounds = [-scale,scale]\n", 946 | "\n", 947 | "# Sample conditioning variable z\n", 948 | "torch.cuda.manual_seed(1)\n", 949 | "z = path.sample_conditioning_variable(1) # (1,2)\n", 950 | "\n", 951 | "######################################\n", 952 | "# Graph samples from conditional ODE #\n", 953 | "######################################\n", 954 | "ax = axes[1]\n", 955 | "\n", 956 | "ax.set_xlim(*x_bounds)\n", 957 | "ax.set_ylim(*y_bounds)\n", 958 | "ax.set_xticks([])\n", 959 | "ax.set_yticks([])\n", 960 | 
"ax.set_title('Samples from Conditional ODE', fontsize=20)\n", 961 | "ax.scatter(z[:,0].cpu(), z[:,1].cpu(), marker='*', color='red', s=200, label='z',zorder=20) # Plot z\n", 962 | "\n", 963 | "# Plot source and target\n", 964 | "imshow_density(density=p_simple, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Reds'))\n", 965 | "imshow_density(density=p_data, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Blues'))\n", 966 | "\n", 967 | "\n", 968 | "# Construct integrator and plot trajectories\n", 969 | "sigma = 0.5 # Can't make this too high or integration is numerically unstable!\n", 970 | "ode = ConditionalVectorFieldODE(path, z)\n", 971 | "simulator = EulerSimulator(ode)\n", 972 | "x0 = path.p_simple.sample(num_samples) # (num_samples, 2)\n", 973 | "ts = torch.linspace(0.0, 1.0, num_timesteps).view(1,-1,1).expand(num_samples,-1,1).to(device) # (num_samples, nts, 1)\n", 974 | "xts = simulator.simulate_with_trajectory(x0, ts) # (bs, nts, dim)\n", 975 | "\n", 976 | "# Extract every n-th integration step to plot\n", 977 | "every_n = record_every(num_timesteps=num_timesteps, record_every=num_timesteps // num_marginals)\n", 978 | "xts_every_n = xts[:,every_n,:] # (bs, nts // n, dim)\n", 979 | "ts_every_n = ts[0,every_n] # (nts // n,)\n", 980 | "for plot_idx in range(xts_every_n.shape[1]):\n", 981 | " tt = ts_every_n[plot_idx].item()\n", 982 | " ax.scatter(xts_every_n[:,plot_idx,0].detach().cpu(), xts_every_n[:,plot_idx,1].detach().cpu(), marker='o', alpha=0.5, label=f't={tt:.2f}')\n", 983 | "ax.legend(prop={'size': legend_size}, loc='upper right', markerscale=markerscale)\n", 984 | "\n", 985 | "\n", 986 | "#########################################\n", 987 | "# Graph Trajectories of Conditional ODE #\n", 988 | "#########################################\n", 989 | "ax = axes[2]\n", 990 | "\n", 991 | "ax.set_xlim(*x_bounds)\n", 992 | "ax.set_ylim(*y_bounds)\n", 993 | 
"ax.set_xticks([])\n", 994 | "ax.set_yticks([])\n", 995 | "ax.set_title('Trajectories of Conditional ODE', fontsize=20)\n", 996 | "ax.scatter(z[:,0].cpu(), z[:,1].cpu(), marker='*', color='red', s=200, label='z',zorder=20) # Plot z\n", 997 | "\n", 998 | "\n", 999 | "# Plot source and target\n", 1000 | "imshow_density(density=p_simple, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Reds'))\n", 1001 | "imshow_density(density=p_data, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Blues'))\n", 1002 | "\n", 1003 | "for traj_idx in range(15):\n", 1004 | " ax.plot(xts[traj_idx,:,0].detach().cpu(), xts[traj_idx,:,1].detach().cpu(), alpha=0.5, color='black')\n", 1005 | "ax.legend(prop={'size': legend_size}, loc='upper right', markerscale=markerscale)\n", 1006 | "\n", 1007 | "\n", 1008 | "###################################################\n", 1009 | "# Graph Ground-Truth Conditional Probability Path #\n", 1010 | "###################################################\n", 1011 | "ax = axes[0]\n", 1012 | "\n", 1013 | "ax.set_xlim(*x_bounds)\n", 1014 | "ax.set_ylim(*y_bounds)\n", 1015 | "ax.set_xticks([])\n", 1016 | "ax.set_yticks([])\n", 1017 | "ax.set_title('Ground-Truth Conditional Probability Path', fontsize=20)\n", 1018 | "ax.scatter(z[:,0].cpu(), z[:,1].cpu(), marker='*', color='red', s=200, label='z',zorder=20) # Plot z\n", 1019 | "\n", 1020 | "\n", 1021 | "for plot_idx in range(xts_every_n.shape[1]):\n", 1022 | " tt = ts_every_n[plot_idx].unsqueeze(0).expand(num_samples, 1)\n", 1023 | " zz = z.expand(num_samples, 2)\n", 1024 | " marginal_samples = path.sample_conditional_path(zz, tt)\n", 1025 | " ax.scatter(marginal_samples[:,0].detach().cpu(), marginal_samples[:,1].detach().cpu(), marker='o', alpha=0.5, label=f't={tt[0,0].item():.2f}')\n", 1026 | "\n", 1027 | "# Plot source and target\n", 1028 | "imshow_density(density=p_simple, x_bounds=x_bounds, y_bounds=y_bounds, 
bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Reds'))\n", 1029 | "imshow_density(density=p_data, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Blues'))\n", 1030 | "ax.legend(prop={'size': legend_size}, loc='upper right', markerscale=markerscale)\n", 1031 | "\n", 1032 | "plt.show()" 1033 | ] 1034 | }, 1035 | { 1036 | "cell_type": "markdown", 1037 | "id": "3f9a579c-a0e7-48c3-b1b5-dae685d90b20", 1038 | "metadata": {}, 1039 | "source": [ 1040 | "**Note**: You may have noticed that, since $z \\sim p_{\\text{data}}$ for Gaussian probability paths, the method `GaussianConditionalProbabilityPath.sample_conditioning_variable` is effectively sampling from the data distribution. But wait - aren't we trying to learn to sample from $p_{\\text{data}}$ in the first place? This is a subtlety that we have glossed over thus far. The answer is that *in practice*, `sample_conditioning_variable` would return points from a finite *training set*, which is formally assumed to have been sampled IID from the true distribution $z \\sim p_{\\text{data}}$." 1041 | ] 1042 | }, 1043 | { 1044 | "cell_type": "markdown", 1045 | "id": "75225acb-5522-467e-82df-9fba32c0b708", 1046 | "metadata": {}, 1047 | "source": [ 1048 | "### Problem 2.4: The Conditional Score" 1049 | ] 1050 | }, 1051 | { 1052 | "cell_type": "markdown", 1053 | "id": "705cac87-6d1e-4df6-868c-e3d8dbd1c55d", 1054 | "metadata": {}, 1055 | "source": [ 1056 | "As in lecture, we may now visualize the conditional trajectories corresponding to the SDE $$d X_t = \\left[u_t(X_t|z) + \\frac{1}{2}\\sigma^2 \\nabla_x \\log p_t(X_t|z) \\right]dt + \\sigma\\, dW_t, \\quad \\quad X_0 = x_0 \\sim p_{\\text{simple}},$$\n", 1057 | "obtained by adding *Langevin dynamics* to the original ODE." 
1058 | ] 1059 | }, 1060 | { 1061 | "cell_type": "markdown", 1062 | "id": "bf5de941-a496-4f79-8814-dd5d47f0118e", 1063 | "metadata": {}, 1064 | "source": [ 1065 | "**Your work**: Implement the class method `conditional_score` to compute the conditional score $\\nabla_x \\log p_t(x|z)$, which we compute to be\n", 1066 | "$$\\nabla_x \\log p_t(x|z) = \\nabla_x \\log N(x;\\alpha_t z,\\beta_t^2 I_d) = \\frac{\\alpha_t z - x}{\\beta_t^2}.$$\n", 1067 | "To check for correctness, use the next two cells to verify that samples from the conditional SDE match the samples drawn analytically from the conditional probability path." 1068 | ] 1069 | }, 1070 | { 1071 | "cell_type": "code", 1072 | "execution_count": null, 1073 | "id": "9c0c2adf-0c0b-4b30-9de0-8c0dd8db47a8", 1074 | "metadata": {}, 1075 | "outputs": [], 1076 | "source": [ 1077 | "class ConditionalVectorFieldSDE(SDE):\n", 1078 | " def __init__(self, path: ConditionalProbabilityPath, z: torch.Tensor, sigma: float):\n", 1079 | " \"\"\"\n", 1080 | " Args:\n", 1081 | " - path: the ConditionalProbabilityPath object to which this vector field corresponds\n", 1082 | " - z: the conditioning variable, (1, ...)\n", 1083 | " \"\"\"\n", 1084 | " super().__init__()\n", 1085 | " self.path = path\n", 1086 | " self.z = z\n", 1087 | " self.sigma = sigma\n", 1088 | "\n", 1089 | " def drift_coefficient(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n", 1090 | " \"\"\"\n", 1091 | " Returns the SDE drift u_t(x|z) + 0.5 * sigma^2 * grad_x log p_t(x|z)\n", 1092 | " Args:\n", 1093 | " - x: state at time t, shape (bs, dim)\n", 1094 | " - t: time, shape (bs,.)\n", 1095 | " Returns:\n", 1096 | " - drift: shape (batch_size, dim)\n", 1097 | " \"\"\"\n", 1098 | " bs = x.shape[0]\n", 1099 | " z = self.z.expand(bs, *self.z.shape[1:])\n", 1100 | " return self.path.conditional_vector_field(x,z,t) + 0.5 * self.sigma**2 * self.path.conditional_score(x,z,t)\n", 1101 | "\n", 1102 | " def diffusion_coefficient(self, x: torch.Tensor, t: torch.Tensor) -> 
torch.Tensor:\n", 1103 | " \"\"\"\n", 1104 | " Args:\n", 1105 | " - x: state at time t, shape (bs, dim)\n", 1106 | " - t: time, shape (bs,.)\n", 1107 | " Returns:\n", 1108 | " - diffusion coefficient: shape (batch_size, dim)\n", 1109 | " \"\"\"\n", 1110 | " return self.sigma * torch.randn_like(x)" 1111 | ] 1112 | }, 1113 | { 1114 | "cell_type": "markdown", 1115 | "id": "529d91e4-c4ab-4f6e-ab7a-fac8cc40dd7e", 1116 | "metadata": {}, 1117 | "source": [ 1118 | "**Note**: You may notice that strange things happen for large (or even not-so-large) values of $\\sigma$. Plugging $$\\nabla_x \\log p_t(x|z) = \\frac{\\alpha_t z - x}{\\beta_t^2}$$ into $$d X_t = \\left[u_t(X_t|z) + \\frac{1}{2}\\sigma^2 \\nabla_x \\log p_t(X_t|z) \\right]dt + \\sigma\\, dW_t$$ yields\n", 1119 | "$$d X_t = \\left[u_t(X_t|z) + \\frac{1}{2}\\sigma^2 \\left(\\frac{\\alpha_t z - X_t}{\\beta_t^2}\\right) \\right]dt + \\sigma\\, dW_t.$$\n", 1120 | "As $t \\to 1$, $\\beta_t \\to 0$, so the second term of the drift explodes (and this explosion scales quadratically with $\\sigma$). With a finite number of simulation steps, we cannot accurately simulate this explosion and thus encounter numerical issues. In practice, this is usually circumvented by setting, e.g., $\\sigma_t = \\beta_t$, so that the exploding effect is canceled out by a gradually decreasing noise level." 
1121 | ] 1122 | }, 1123 | { 1124 | "cell_type": "code", 1125 | "execution_count": null, 1126 | "id": "4ac806db-9f0f-49cc-b367-84739bda7d1e", 1127 | "metadata": {}, 1128 | "outputs": [], 1129 | "source": [ 1130 | "# Run me for Problem 2.4!\n", 1131 | "\n", 1132 | "#######################\n", 1133 | "# Change these values #\n", 1134 | "#######################\n", 1135 | "num_samples = 1000\n", 1136 | "num_timesteps = 1000\n", 1137 | "num_marginals = 3\n", 1138 | "sigma = 2.5\n", 1139 | "\n", 1140 | "########################\n", 1141 | "# Setup path and plot #\n", 1142 | "########################\n", 1143 | "\n", 1144 | "path = GaussianConditionalProbabilityPath(\n", 1145 | " p_data = GaussianMixture.symmetric_2D(nmodes=5, std=PARAMS[\"target_std\"], scale=PARAMS[\"target_scale\"]).to(device), \n", 1146 | " alpha = LinearAlpha(),\n", 1147 | " beta = SquareRootBeta()\n", 1148 | ").to(device)\n", 1149 | "\n", 1150 | "\n", 1151 | "# Setup figure\n", 1152 | "fig, axes = plt.subplots(1,3, figsize=(36, 12))\n", 1153 | "scale = PARAMS[\"scale\"]\n", 1154 | "x_bounds = [-scale,scale]\n", 1155 | "y_bounds = [-scale,scale]\n", 1156 | "legend_size = 24\n", 1157 | "markerscale = 1.8\n", 1158 | "\n", 1159 | "# Sample conditioning variable z\n", 1160 | "torch.cuda.manual_seed(1)\n", 1161 | "z = path.sample_conditioning_variable(1) # (1,2)\n", 1162 | "\n", 1163 | "######################################\n", 1164 | "# Graph Samples from Conditional SDE #\n", 1165 | "######################################\n", 1166 | "ax = axes[1]\n", 1167 | "\n", 1168 | "ax.set_xlim(*x_bounds)\n", 1169 | "ax.set_ylim(*y_bounds)\n", 1170 | "ax.set_xticks([])\n", 1171 | "ax.set_yticks([])\n", 1172 | "ax.set_title('Samples from Conditional SDE', fontsize=20)\n", 1173 | "ax.scatter(z[:,0].cpu(), z[:,1].cpu(), marker='*', color='red', s=200, label='z',zorder=20) # Plot z\n", 1174 | "\n", 1175 | "# Plot source and target\n", 1176 | "imshow_density(density=p_simple, x_bounds=x_bounds, y_bounds=y_bounds, 
bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Reds'))\n", 1177 | "imshow_density(density=p_data, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Blues'))\n", 1178 | "\n", 1179 | "\n", 1180 | "# Construct integrator and plot trajectories\n", 1181 | "sde = ConditionalVectorFieldSDE(path, z, sigma)\n", 1182 | "simulator = EulerMaruyamaSimulator(sde)\n", 1183 | "x0 = path.p_simple.sample(num_samples) # (num_samples, 2)\n", 1184 | "ts = torch.linspace(0.0, 1.0, num_timesteps).view(1,-1,1).expand(num_samples,-1,1).to(device) # (num_samples, nts, 1)\n", 1185 | "xts = simulator.simulate_with_trajectory(x0, ts) # (bs, nts, dim)\n", 1186 | "\n", 1187 | "# Extract every n-th integration step to plot\n", 1188 | "every_n = record_every(num_timesteps=num_timesteps, record_every=num_timesteps // num_marginals)\n", 1189 | "xts_every_n = xts[:,every_n,:] # (bs, nts // n, dim)\n", 1190 | "ts_every_n = ts[0,every_n] # (nts // n,)\n", 1191 | "for plot_idx in range(xts_every_n.shape[1]):\n", 1192 | " tt = ts_every_n[plot_idx].item()\n", 1193 | " ax.scatter(xts_every_n[:,plot_idx,0].detach().cpu(), xts_every_n[:,plot_idx,1].detach().cpu(), marker='o', alpha=0.5, label=f't={tt:.2f}')\n", 1194 | "ax.legend(prop={'size': legend_size}, loc='upper right', markerscale=markerscale)\n", 1195 | "\n", 1196 | "\n", 1197 | "##########################################\n", 1198 | "# Graph Trajectories of Conditional SDE #\n", 1199 | "##########################################\n", 1200 | "ax = axes[2]\n", 1201 | "\n", 1202 | "ax.set_xlim(*x_bounds)\n", 1203 | "ax.set_ylim(*y_bounds)\n", 1204 | "ax.set_xticks([])\n", 1205 | "ax.set_yticks([])\n", 1206 | "ax.set_title('Trajectories of Conditional SDE', fontsize=20)\n", 1207 | "ax.scatter(z[:,0].cpu(), z[:,1].cpu(), marker='*', color='red', s=200, label='z',zorder=20) # Plot z\n", 1208 | "\n", 1209 | "\n", 1210 | "# Plot source and target\n", 1211 | "imshow_density(density=p_simple, 
x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Reds'))\n", 1212 | "imshow_density(density=p_data, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Blues'))\n", 1213 | "\n", 1214 | "for traj_idx in range(5):\n", 1215 | " ax.plot(xts[traj_idx,:,0].detach().cpu(), xts[traj_idx,:,1].detach().cpu(), alpha=0.5, color='black')\n", 1216 | "ax.legend(prop={'size': legend_size}, loc='upper right', markerscale=markerscale)\n", 1217 | "\n", 1218 | "\n", 1219 | "###################################################\n", 1220 | "# Graph Ground-Truth Conditional Probability Path #\n", 1221 | "###################################################\n", 1222 | "ax = axes[0]\n", 1223 | "\n", 1224 | "ax.set_xlim(*x_bounds)\n", 1225 | "ax.set_ylim(*y_bounds)\n", 1226 | "ax.set_xticks([])\n", 1227 | "ax.set_yticks([])\n", 1228 | "ax.set_title('Ground-Truth Conditional Probability Path', fontsize=20)\n", 1229 | "ax.scatter(z[:,0].cpu(), z[:,1].cpu(), marker='*', color='red', s=200, label='z',zorder=20) # Plot z\n", 1230 | "\n", 1231 | "\n", 1232 | "for plot_idx in range(xts_every_n.shape[1]):\n", 1233 | " tt = ts_every_n[plot_idx].unsqueeze(0).expand(num_samples, 1)\n", 1234 | " zz = z.expand(num_samples, 2)\n", 1235 | " marginal_samples = path.sample_conditional_path(zz, tt)\n", 1236 | " ax.scatter(marginal_samples[:,0].detach().cpu(), marginal_samples[:,1].detach().cpu(), marker='o', alpha=0.5, label=f't={tt[0,0].item():.2f}')\n", 1237 | "\n", 1238 | "# Plot source and target\n", 1239 | "imshow_density(density=p_simple, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Reds'))\n", 1240 | "imshow_density(density=p_data, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Blues'))\n", 1241 | "ax.legend(prop={'size': legend_size}, loc='upper right', markerscale=markerscale)\n", 1242 | "\n", 1243 | "plt.show()" 
1244 | ] 1245 | }, 1246 | { 1247 | "cell_type": "markdown", 1248 | "id": "4609e0a6-cc7f-429e-9f81-2423164b5c8d", 1249 | "metadata": {}, 1250 | "source": [ 1251 | "# Part 3: Flow Matching and Score Matching with Gaussian Conditional Probability Paths\n" 1252 | ] 1253 | }, 1254 | { 1255 | "cell_type": "markdown", 1256 | "id": "565e1f4f-f02e-4663-ae99-b62d9fee95b3", 1257 | "metadata": {}, 1258 | "source": [ 1259 | "### Problem 3.1: Flow Matching with Gaussian Conditional Probability Paths" 1260 | ] 1261 | }, 1262 | { 1263 | "cell_type": "markdown", 1264 | "id": "760936a6-fc17-452e-82a1-abb6e1e43c53", 1265 | "metadata": {}, 1266 | "source": [ 1267 | "Recall from lecture that our goal is to learn the *marginal vector field* $u_t^{\\text{ref}}(x)$ given by $$u_t^{\\text{ref}}(x) = \\mathbb{E}_{z \\sim p_t(z|x)}\\left[u_t^{\\text{ref}}(x|z)\\right].$$\n", 1268 | "Unfortunately, we don't actually know what $u_t^{\\text{ref}}(x)$ is! We will thus approximate $u_t^{\\text{ref}}(x)$ with a neural network $u_t^{\\theta}(x)$, and exploit the identity $$ u_t^{\\text{ref}}(x) = \\text{argmin}_{u_t(x)} \\,\\,\\mathbb{E}_{z \\sim p_t(z|x)} \\left[\\lVert u_t(x) - u_t^{\\text{ref}}(x|z)\\rVert^2\\right]$$ to obtain the **conditional flow matching objective**\n", 1269 | "$$ \\mathcal{L}_{\\text{CFM}}(\\theta) = \\,\\,\\mathbb{E}_{t \\sim \\mathcal{U}[0,1), z \\sim p(z), x \\sim p_t(x|z)} \\left[\\lVert u_t^{\\theta}(x) - u_t^{\\text{ref}}(x|z)\\rVert^2\\right].$$\n", 1270 | "To model $u_t^{\\theta}(x)$, we'll use a simple MLP. This network will take in both $x$ and $t$, and will return the learned vector field $u_t^{\\theta}(x)$." 
1271 | ] 1272 | }, 1273 | { 1274 | "cell_type": "code", 1275 | "execution_count": null, 1276 | "id": "e57626e7-765e-4e39-aa46-ae403f7960ef", 1277 | "metadata": {}, 1278 | "outputs": [], 1279 | "source": [ 1280 | "def build_mlp(dims: List[int], activation: Type[torch.nn.Module] = torch.nn.SiLU):\n", 1281 | " mlp = []\n", 1282 | " for idx in range(len(dims) - 1):\n", 1283 | " mlp.append(torch.nn.Linear(dims[idx], dims[idx + 1]))\n", 1284 | " if idx < len(dims) - 2:\n", 1285 | " mlp.append(activation())\n", 1286 | " return torch.nn.Sequential(*mlp)\n", 1287 | "\n", 1288 | "class MLPVectorField(torch.nn.Module):\n", 1289 | " \"\"\"\n", 1290 | " MLP-parameterization of the learned vector field u_t^theta(x)\n", 1291 | " \"\"\"\n", 1292 | " def __init__(self, dim: int, hiddens: List[int]):\n", 1293 | " super().__init__()\n", 1294 | " self.dim = dim\n", 1295 | " self.net = build_mlp([dim + 1] + hiddens + [dim])\n", 1296 | "\n", 1297 | " def forward(self, x: torch.Tensor, t: torch.Tensor):\n", 1298 | " \"\"\"\n", 1299 | " Args:\n", 1300 | " - x: (bs, dim)\n", 1301 | " Returns:\n", 1302 | " - u_t^theta(x): (bs, dim)\n", 1303 | " \"\"\"\n", 1304 | " xt = torch.cat([x,t], dim=-1)\n", 1305 | " return self.net(xt) " 1306 | ] 1307 | }, 1308 | { 1309 | "cell_type": "markdown", 1310 | "id": "73bdbaaa-fecb-4753-a9a2-043b1dee585a", 1311 | "metadata": {}, 1312 | "source": [ 1313 | "Let's first define a general-purpose class `Trainer` to keep things tidy as we start training." 
1314 | ] 1315 | }, 1316 | { 1317 | "cell_type": "code", 1318 | "execution_count": null, 1319 | "id": "94f618c8-5033-4f16-8826-f29d4f1772f8", 1320 | "metadata": {}, 1321 | "outputs": [], 1322 | "source": [ 1323 | "class Trainer(ABC):\n", 1324 | " def __init__(self, model: torch.nn.Module):\n", 1325 | " super().__init__()\n", 1326 | " self.model = model\n", 1327 | "\n", 1328 | " @abstractmethod\n", 1329 | " def get_train_loss(self, **kwargs) -> torch.Tensor:\n", 1330 | " pass\n", 1331 | "\n", 1332 | " def get_optimizer(self, lr: float):\n", 1333 | " return torch.optim.Adam(self.model.parameters(), lr=lr)\n", 1334 | "\n", 1335 | " def train(self, num_epochs: int, device: torch.device, lr: float = 1e-3, **kwargs) -> torch.Tensor:\n", 1336 | " # Start\n", 1337 | " self.model.to(device)\n", 1338 | " opt = self.get_optimizer(lr)\n", 1339 | " self.model.train()\n", 1340 | "\n", 1341 | " # Train loop\n", 1342 | " pbar = tqdm(enumerate(range(num_epochs)))\n", 1343 | " for idx, epoch in pbar:\n", 1344 | " opt.zero_grad()\n", 1345 | " loss = self.get_train_loss(**kwargs)\n", 1346 | " loss.backward()\n", 1347 | " opt.step()\n", 1348 | " pbar.set_description(f'Epoch {idx}, loss: {loss.item()}')\n", 1349 | "\n", 1350 | " # Finish\n", 1351 | " self.model.eval()" 1352 | ] 1353 | }, 1354 | { 1355 | "cell_type": "markdown", 1356 | "id": "4f270b97-ce2c-4b6b-b924-e4d976813e0e", 1357 | "metadata": {}, 1358 | "source": [ 1359 | "**Your work**: Fill in `ConditionalFlowMatchingTrainer.get_train_loss` below. 
This function should implement the conditional flow matching objective $$\\mathcal{L}_{\\text{CFM}}(\\theta) = \\,\\,\\mathbb{E}_{\\textcolor{blue}{t \\sim \\mathcal{U}[0,1), z \\sim p(z), x \\sim p_t(x|z)}} \\textcolor{green}{\\lVert u_t^{\\theta}(x) - u_t^{\\text{ref}}(x|z)\\rVert^2}$$\n", 1360 | "using a Monte-Carlo estimate of the form\n", 1361 | "$$\\frac{1}{N}\\sum_{i=1}^N \\textcolor{green}{\\lVert u_{t_i}^{\\theta}(x_i) - u_{t_i}^{\\text{ref}}(x_i|z_i)\\rVert^2}, \\quad \\quad \\quad \\forall i\\in\\{1, \\dots, N\\}: \\textcolor{blue}{\\,z_i \\sim p_{\\text{data}},\\, t_i \\sim \\mathcal{U}[0,1),\\, x_i \\sim p_{t_i}(\\cdot | z_i)}.$$\n", 1362 | "Here, $N$ is our *batch size*.\n", 1363 | "\n", 1364 | "\n", 1365 | "**Hint 1**: For sampling:\n", 1366 | "- You can sample `batch_size` points $z$ from $p_{\\text{data}}$ using `self.path.p_data.sample(batch_size)`.\n", 1367 | "- You can sample `batch_size` values of `t` using `torch.rand(batch_size, 1)` (this tensor is created on the CPU by default, so move it to the device of `z` if needed).\n", 1368 | "- You can sample `batch_size` points from $p_t(x|z)$ using `self.path.sample_conditional_path(z,t)`.\n", 1369 | "\n", 1370 | "**Hint 2**: For the loss function:\n", 1371 | "- You can access $u_t^{\\theta}(x)$ using `self.model(x,t)`.\n", 1372 | "- You can access $u_t^{\\text{ref}}(x|z)$ using `self.path.conditional_vector_field(x,z,t)`." 
1373 | ] 1374 | }, 1375 | { 1376 | "cell_type": "code", 1377 | "execution_count": null, 1378 | "id": "0f01a592-40da-4e71-a7ce-b53363100c64", 1379 | "metadata": {}, 1380 | "outputs": [], 1381 | "source": [ 1382 | "class ConditionalFlowMatchingTrainer(Trainer):\n", 1383 | " def __init__(self, path: ConditionalProbabilityPath, model: MLPVectorField, **kwargs):\n", 1384 | " super().__init__(model, **kwargs)\n", 1385 | " self.path = path\n", 1386 | "\n", 1387 | " def get_train_loss(self, batch_size: int) -> torch.Tensor:\n", 1388 | " raise NotImplementedError(\"Fill me in for Question 3.1!\")" 1389 | ] 1390 | }, 1391 | { 1392 | "cell_type": "markdown", 1393 | "id": "75fc5768-d14c-4477-88d1-7b27dfc61538", 1394 | "metadata": {}, 1395 | "source": [ 1396 | "Now let's train! This may take about a minute... **Remember, the loss should converge, but not to zero!**" 1397 | ] 1398 | }, 1399 | { 1400 | "cell_type": "code", 1401 | "execution_count": null, 1402 | "id": "3a2d13db-7fd3-429b-ab54-86d937b4a9ff", 1403 | "metadata": {}, 1404 | "outputs": [], 1405 | "source": [ 1406 | "# Construct conditional probability path\n", 1407 | "path = GaussianConditionalProbabilityPath(\n", 1408 | " p_data = GaussianMixture.symmetric_2D(nmodes=5, std=PARAMS[\"target_std\"], scale=PARAMS[\"target_scale\"]).to(device), \n", 1409 | " alpha = LinearAlpha(),\n", 1410 | " beta = SquareRootBeta()\n", 1411 | ").to(device)\n", 1412 | "\n", 1413 | "# Construct learnable vector field\n", 1414 | "flow_model = MLPVectorField(dim=2, hiddens=[64,64,64,64])\n", 1415 | "\n", 1416 | "# Construct trainer\n", 1417 | "trainer = ConditionalFlowMatchingTrainer(path, flow_model)\n", 1418 | "losses = trainer.train(num_epochs=5000, device=device, lr=1e-3, batch_size=1000)" 1419 | ] 1420 | }, 1421 | { 1422 | "cell_type": "markdown", 1423 | "id": "b3e40ead-7bfc-4083-9d9c-bf0abd83a5c7", 1424 | "metadata": {}, 1425 | "source": [ 1426 | "Is our model any good? Let's visualize! 
First, we need to wrap our learned vector field in a subclass of `ODE` so that we can simulate it using our `Simulator` class." 1427 | ] 1428 | }, 1429 | { 1430 | "cell_type": "code", 1431 | "execution_count": null, 1432 | "id": "a25f8b48-cc63-4b52-b59e-5bc57813e338", 1433 | "metadata": {}, 1434 | "outputs": [], 1435 | "source": [ 1436 | "class LearnedVectorFieldODE(ODE):\n", 1437 | " def __init__(self, net: MLPVectorField):\n", 1438 | " self.net = net\n", 1439 | "\n", 1440 | " def drift_coefficient(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n", 1441 | " \"\"\"\n", 1442 | " Args:\n", 1443 | " - x: (bs, dim)\n", 1444 | " - t: (bs, 1)\n", 1445 | " Returns:\n", 1446 | " - u_t: (bs, dim)\n", 1447 | " \"\"\"\n", 1448 | " return self.net(x, t)" 1449 | ] 1450 | }, 1451 | { 1452 | "cell_type": "code", 1453 | "execution_count": null, 1454 | "id": "56152a16-abf6-47cb-aa0f-f8f69b8ab690", 1455 | "metadata": {}, 1456 | "outputs": [], 1457 | "source": [ 1458 | "#######################\n", 1459 | "# Change these values #\n", 1460 | "#######################\n", 1461 | "num_samples = 1000\n", 1462 | "num_timesteps = 1000\n", 1463 | "num_marginals = 3\n", 1464 | "\n", 1465 | "\n", 1466 | "##############\n", 1467 | "# Setup Plot #\n", 1468 | "##############\n", 1469 | "\n", 1470 | "scale = PARAMS[\"scale\"]\n", 1471 | "x_bounds = [-scale,scale]\n", 1472 | "y_bounds = [-scale,scale]\n", 1473 | "legend_size=24\n", 1474 | "markerscale=1.8\n", 1475 | "\n", 1476 | "# Setup figure\n", 1477 | "fig, axes = plt.subplots(1,3, figsize=(36, 12))\n", 1478 | "\n", 1479 | "###########################################\n", 1480 | "# Graph Samples from Learned Marginal ODE #\n", 1481 | "###########################################\n", 1482 | "ax = axes[1]\n", 1483 | "\n", 1484 | "ax.set_xlim(*x_bounds)\n", 1485 | "ax.set_ylim(*y_bounds)\n", 1486 | "ax.set_xticks([])\n", 1487 | "ax.set_yticks([])\n", 1488 | "ax.set_title(\"Samples from Learned Marginal ODE\", fontsize=20)\n", 1489 | 
"\n", 1490 | "# Plot source and target\n", 1491 | "imshow_density(density=p_simple, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Reds'))\n", 1492 | "imshow_density(density=p_data, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Blues'))\n", 1493 | "\n", 1494 | "\n", 1495 | "# Construct integrator and plot trajectories\n", 1496 | "ode = LearnedVectorFieldODE(flow_model)\n", 1497 | "simulator = EulerSimulator(ode)\n", 1498 | "x0 = path.p_simple.sample(num_samples) # (num_samples, 2)\n", 1499 | "ts = torch.linspace(0.0, 1.0, num_timesteps).view(1,-1,1).expand(num_samples,-1,1).to(device) # (num_samples, nts, 1)\n", 1500 | "xts = simulator.simulate_with_trajectory(x0, ts) # (bs, nts, dim)\n", 1501 | "\n", 1502 | "# Extract every n-th integration step to plot\n", 1503 | "every_n = record_every(num_timesteps=num_timesteps, record_every=num_timesteps // num_marginals)\n", 1504 | "xts_every_n = xts[:,every_n,:] # (bs, nts // n, dim)\n", 1505 | "ts_every_n = ts[0,every_n] # (nts // n,)\n", 1506 | "for plot_idx in range(xts_every_n.shape[1]):\n", 1507 | " tt = ts_every_n[plot_idx].item()\n", 1508 | " ax.scatter(xts_every_n[:,plot_idx,0].detach().cpu(), xts_every_n[:,plot_idx,1].detach().cpu(), marker='o', alpha=0.5, label=f't={tt:.2f}')\n", 1509 | "\n", 1510 | "ax.legend(prop={'size': legend_size}, loc='upper right', markerscale=markerscale)\n", 1511 | "\n", 1512 | "##############################################\n", 1513 | "# Graph Trajectories of Learned Marginal ODE #\n", 1514 | "##############################################\n", 1515 | "ax = axes[2]\n", 1516 | "ax.set_title(\"Trajectories of Learned Marginal ODE\", fontsize=20)\n", 1517 | "ax.set_xlim(*x_bounds)\n", 1518 | "ax.set_ylim(*y_bounds)\n", 1519 | "ax.set_xticks([])\n", 1520 | "ax.set_yticks([])\n", 1521 | "\n", 1522 | "# Plot source and target\n", 1523 | "imshow_density(density=p_simple, x_bounds=x_bounds, 
y_bounds=y_bounds, bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Reds'))\n", 1524 | "imshow_density(density=p_data, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Blues'))\n", 1525 | "\n", 1526 | "for traj_idx in range(num_samples // 10):\n", 1527 | " ax.plot(xts[traj_idx,:,0].detach().cpu(), xts[traj_idx,:,1].detach().cpu(), alpha=0.5, color='black')\n", 1528 | "\n", 1529 | "################################################\n", 1530 | "# Graph Ground-Truth Marginal Probability Path #\n", 1531 | "################################################\n", 1532 | "ax = axes[0]\n", 1533 | "ax.set_title(\"Ground-Truth Marginal Probability Path\", fontsize=20)\n", 1534 | "ax.set_xlim(*x_bounds)\n", 1535 | "ax.set_ylim(*y_bounds)\n", 1536 | "ax.set_xticks([])\n", 1537 | "ax.set_yticks([])\n", 1538 | "\n", 1539 | "for plot_idx in range(xts_every_n.shape[1]):\n", 1540 | " tt = ts_every_n[plot_idx].unsqueeze(0).expand(num_samples, 1)\n", 1541 | " marginal_samples = path.sample_marginal_path(tt)\n", 1542 | " ax.scatter(marginal_samples[:,0].detach().cpu(), marginal_samples[:,1].detach().cpu(), marker='o', alpha=0.5, label=f't={tt[0,0].item():.2f}')\n", 1543 | "\n", 1544 | "# Plot source and target\n", 1545 | "imshow_density(density=p_simple, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Reds'))\n", 1546 | "imshow_density(density=p_data, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Blues'))\n", 1547 | "\n", 1548 | "ax.legend(prop={'size': legend_size}, loc='upper right', markerscale=markerscale)\n", 1549 | " \n", 1550 | "plt.show()" 1551 | ] 1552 | }, 1553 | { 1554 | "cell_type": "markdown", 1555 | "id": "24b0491a-9d82-413d-b6c9-95199e0f23e4", 1556 | "metadata": {}, 1557 | "source": [ 1558 | "### Problem 3.2: Score Matching with Gaussian Conditional Probability Paths" 1559 | ] 1560 | }, 1561 | { 1562 | "cell_type": 
"markdown", 1563 | "id": "48fe3eb5-51ee-4137-9626-6c1d1af910fe", 1564 | "metadata": {}, 1565 | "source": [ 1566 | "We have thus far used flow matching to train a model $u_t^{\\theta}(x) \\approx u_t^{\\text{ref}}(x)$ so that $$d X_t = u_t^{\\theta}(X_t) dt $$ approximately passes through the desired marginal probability path $p_t(x)$. Now recall from lecture that we may augment the reference marginal vector field $u_t^{\\text{ref}}(x)$ with *Langevin dynamics* to add stochasticity while preserving the marginals, viz., $$dX_t = \\left[u_t^{\\text{ref}}(X_t) + \\frac{1}{2}\\sigma^2 \\nabla \\log p_t(X_t)\\right] dt + \\sigma d W_t.$$\n", 1567 | "Substituting our learned approximation $u_t^{\\theta}(x) \\approx u_t^{\\text{ref}}(x)$ therefore yields \n", 1568 | "$$dX_t = \\left[u_t^{\\theta}(X_t) + \\frac{1}{2}\\sigma^2 \\nabla \\log p_t(X_t)\\right] dt + \\sigma d W_t.$$\n", 1569 | "There's just one issue: what is the marginal score $\\nabla \\log p_t(x)$? In Problem 2.4, we computed the conditional score $\\nabla \\log p_t(x|z)$ of the Gaussian probability path. In the same way that we learned an approximation $u_t^{\\theta}(x) \\approx u_t^{\\text{ref}}(x)$, we'd like to be able to learn a similar approximation $s_t^{\\theta}(x) \\approx \\nabla \\log p_t(x)$. 
Recall from lecture the identity $$\\nabla \\log p_t(x) = \\mathbb{E}_{z \\sim p_t(z|x)}\\left[\\nabla \\log p_t(x|z) \\right].$$ It then immediately follows that\n", 1570 | "$$\\nabla \\log p_t(x) = \\text{argmin}_{s_t(x)} \\,\\,\\mathbb{E}_{z \\sim p(z), x \\sim p_t(x|z)} \\left[\\lVert s_t(x) - \\nabla \\log p_t(x|z)\\rVert^2\\right].$$\n", 1571 | "We thus obtain the **conditional score matching** loss\n", 1572 | "$$\\mathcal{L}_{\\text{CSM}}(\\theta) \\triangleq \\mathbb{E}_{t \\sim \\mathcal{U}[0,1), z \\sim p(z), x \\sim p_t(x|z)} \\left[\\lVert s_t^{\\theta}(x) - \\nabla \\log p_t(x|z)\\rVert^2\\right].$$\n", 1573 | "Here, we will parameterize $s_t^{\\theta}(x): \\mathbb{R}^2 \\to \\mathbb{R}^2$ as a simple MLP, just like $u_t^{\\theta}(x)$." 1574 | ] 1575 | }, 1576 | { 1577 | "cell_type": "markdown", 1578 | "id": "7efcae65-2c86-4018-9f7c-ae9175607e06", 1579 | "metadata": {}, 1580 | "source": [ 1581 | "**Your job**: Fill in method `ConditionalScoreMatchingTrainer.get_train_loss` to implement the conditional score matching loss $\\mathcal{L}_{\\text{CSM}}(\\theta)$.\n", 1582 | "\n", 1583 | "**Hint:** Remember to re-use your implementation of `GaussianConditionalProbabilityPath.conditional_score`!" 
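A small numerical sketch may help make the conditional score matching objective concrete. It is not the lab's `ConditionalScoreMatchingTrainer`: it assumes a hypothetical one-dimensional data distribution that is uniform over the two points -1 and +1, uses alpha_t = t and beta_t = sqrt(1 - t), and estimates the loss at a single fixed time t with the standard library only. All function names are illustrative. Under these assumptions the exact marginal score is available in closed form, so we can check that it attains a lower loss than a trivial guess while the loss itself stays strictly positive.

```python
import math
import random

def conditional_score(x, z, t):
    """Score of p_t(x|z) = N(t * z, 1 - t), i.e. alpha_t = t, beta_t = sqrt(1 - t)."""
    return (t * z - x) / (1.0 - t)

def marginal_score(x, t, modes=(-1.0, 1.0)):
    """Exact marginal score when p_data is uniform over `modes` (toy assumption)."""
    # Posterior weights p_t(z|x) are proportional to N(x; t * z, 1 - t).
    logits = [-(x - t * z) ** 2 / (2.0 * (1.0 - t)) for z in modes]
    m = max(logits)
    weights = [math.exp(l - m) for l in logits]
    total = sum(weights)
    # Marginal score = posterior average of the conditional scores.
    return sum(w / total * conditional_score(x, z, t) for w, z in zip(weights, modes))

def csm_loss(score_fn, t, num_samples=20000, modes=(-1.0, 1.0), seed=0):
    """Monte Carlo estimate of the CSM objective at one fixed time t."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_samples):
        z = rng.choice(modes)                                  # z ~ p_data
        x = t * z + math.sqrt(1.0 - t) * rng.gauss(0.0, 1.0)   # x ~ p_t(x|z)
        total += (score_fn(x, t) - conditional_score(x, z, t)) ** 2
    return total / num_samples

loss_at_optimum = csm_loss(marginal_score, t=0.5)       # best achievable score
loss_at_zero = csm_loss(lambda x, t: 0.0, t=0.5)        # trivial guess s = 0
print(loss_at_optimum, loss_at_zero)
```

The gap between the two values is the part of the loss a model can remove. What remains is the spread of the conditional score around the marginal score, which is why a training loss of this kind is expected to plateau above zero.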
1584 | ] 1585 | }, 1586 | { 1587 | "cell_type": "code", 1588 | "execution_count": null, 1589 | "id": "96d654b2-13d4-47bf-a5c3-98606c580771", 1590 | "metadata": {}, 1591 | "outputs": [], 1592 | "source": [ 1593 | "class MLPScore(torch.nn.Module):\n", 1594 | " \"\"\"\n", 1595 | " MLP-parameterization of the learned score field\n", 1596 | " \"\"\"\n", 1597 | " def __init__(self, dim: int, hiddens: List[int]):\n", 1598 | " super().__init__()\n", 1599 | " self.dim = dim\n", 1600 | " self.net = build_mlp([dim + 1] + hiddens + [dim])\n", 1601 | "\n", 1602 | " def forward(self, x: torch.Tensor, t: torch.Tensor):\n", 1603 | " \"\"\"\n", 1604 | " Args:\n", 1605 | " - x: (bs, dim)\n", 1606 | " Returns:\n", 1607 | " - s_t^theta(x): (bs, dim)\n", 1608 | " \"\"\"\n", 1609 | " xt = torch.cat([x,t], dim=-1)\n", 1610 | " return self.net(xt) \n", 1611 | " \n", 1612 | "class ConditionalScoreMatchingTrainer(Trainer):\n", 1613 | " def __init__(self, path: ConditionalProbabilityPath, model: MLPScore, **kwargs):\n", 1614 | " super().__init__(model, **kwargs)\n", 1615 | " self.path = path\n", 1616 | "\n", 1617 | " def get_train_loss(self, batch_size: int) -> torch.Tensor:\n", 1618 | " raise NotImplementedError(\"Fill me in for Question 3.2!\")" 1619 | ] 1620 | }, 1621 | { 1622 | "cell_type": "markdown", 1623 | "id": "4f35e47c-1836-41b4-8234-37acc589ee70", 1624 | "metadata": {}, 1625 | "source": [ 1626 | "Now let's train! 
**Remember, the loss should converge, but not to zero!**" 1627 | ] 1628 | }, 1629 | { 1630 | "cell_type": "code", 1631 | "execution_count": null, 1632 | "id": "bfff0733-cd1d-4e90-970a-3f9948f0ea65", 1633 | "metadata": {}, 1634 | "outputs": [], 1635 | "source": [ 1636 | "# Construct conditional probability path\n", 1637 | "path = GaussianConditionalProbabilityPath(\n", 1638 | " p_data = GaussianMixture.symmetric_2D(nmodes=5, std=PARAMS[\"target_std\"], scale=PARAMS[\"target_scale\"]).to(device), \n", 1639 | " alpha = LinearAlpha(),\n", 1640 | " beta = SquareRootBeta()\n", 1641 | ").to(device)\n", 1642 | "\n", 1643 | "# Construct learnable score network\n", 1644 | "score_model = MLPScore(dim=2, hiddens=[64,64,64,64])\n", 1645 | "\n", 1646 | "# Construct trainer\n", 1647 | "trainer = ConditionalScoreMatchingTrainer(path, score_model)\n", 1648 | "losses = trainer.train(num_epochs=1000, device=device, lr=1e-3, batch_size=1000)" 1649 | ] 1650 | }, 1651 | { 1652 | "cell_type": "markdown", 1653 | "id": "a946925a-1230-48bc-bc6c-19b72bd26099", 1654 | "metadata": {}, 1655 | "source": [ 1656 | "Now let's visualize our work! Before we do, however, we'll need to wrap our learned flow model and score model in an instance of `SDE` so that we can simulate it using our `EulerMaruyamaSimulator` class. 
This new class, `LangevinFlowSDE`, corresponds to the dynamics $$dX_t = \\left[u_t^{\\theta}(X_t) + \\frac{1}{2}\\sigma^2 s_t^{\\theta}(X_t)\\right] dt + \\sigma d W_t.$$" 1657 | ] 1658 | }, 1659 | { 1660 | "cell_type": "code", 1661 | "execution_count": null, 1662 | "id": "afc4fae7-9dbb-4e5e-9839-8088f491549e", 1663 | "metadata": {}, 1664 | "outputs": [], 1665 | "source": [ 1666 | "class LangevinFlowSDE(SDE):\n", 1667 | " def __init__(self, flow_model: MLPVectorField, score_model: MLPScore, sigma: float):\n", 1668 | " \"\"\"\n", 1669 | " Args:\n", 1670 | " - flow_model: the learned vector field u_t^theta(x)\n", 1671 | " - score_model: the learned score s_t^theta(x); sigma: the noise level\n", 1672 | " \"\"\"\n", 1673 | " super().__init__()\n", 1674 | " self.flow_model = flow_model\n", 1675 | " self.score_model = score_model\n", 1676 | " self.sigma = sigma\n", 1677 | "\n", 1678 | " def drift_coefficient(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n", 1679 | " \"\"\"\n", 1680 | " Args:\n", 1681 | " - x: state at time t, shape (bs, dim)\n", 1682 | " - t: time, shape (bs, 1)\n", 1683 | " Returns:\n", 1684 | " - u_t^theta(x) + 0.5 * sigma^2 * s_t^theta(x): shape (bs, dim)\n", 1685 | " \"\"\"\n", 1686 | " return self.flow_model(x,t) + 0.5 * self.sigma ** 2 * self.score_model(x, t)\n", 1687 | "\n", 1688 | " def diffusion_coefficient(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n", 1689 | " \"\"\"\n", 1690 | " Args:\n", 1691 | " - x: state at time t, shape (bs, dim)\n", 1692 | " - t: time, shape (bs, 1)\n", 1693 | " Returns:\n", 1694 | " - sigma-scaled standard Gaussian noise sample: shape (bs, dim)\n", 1695 | " \"\"\"\n", 1696 | " return self.sigma * torch.randn_like(x)" 1697 | ] 1698 | }, 1699 | { 1700 | "cell_type": "code", 1701 | "execution_count": null, 1702 | "id": "e40fa803-25d7-4bbd-8230-e3650bdb79d4", 1703 | "metadata": {}, 1704 | "outputs": [], 1705 | "source": [ 1706 | "#######################\n", 1707 | "# Change these values #\n", 1708 | "#######################\n", 1709 | "num_samples = 
1000\n", 1710 | "num_timesteps = 300\n", 1711 | "num_marginals = 3\n", 1712 | "sigma = 2.0 # Don't set sigma too large or you'll get numerical issues!\n", 1713 | "\n", 1714 | "\n", 1715 | "##############\n", 1716 | "# Setup Plot #\n", 1717 | "##############\n", 1718 | "\n", 1719 | "scale = PARAMS[\"scale\"]\n", 1720 | "x_bounds = [-scale,scale]\n", 1721 | "y_bounds = [-scale,scale]\n", 1722 | "legend_size = 24\n", 1723 | "markerscale = 1.8\n", 1724 | "\n", 1725 | "# Setup figure\n", 1726 | "fig, axes = plt.subplots(1,3, figsize=(36, 12))\n", 1727 | "\n", 1728 | "###########################################\n", 1729 | "# Graph Samples from Learned Marginal SDE #\n", 1730 | "###########################################\n", 1731 | "ax = axes[1]\n", 1732 | "ax.set_title(\"Samples from Learned Marginal SDE\", fontsize=20)\n", 1733 | "ax.set_xlim(*x_bounds)\n", 1734 | "ax.set_ylim(*y_bounds)\n", 1735 | "ax.set_xticks([])\n", 1736 | "ax.set_yticks([])\n", 1737 | "\n", 1738 | "# Plot source and target\n", 1739 | "imshow_density(density=path.p_simple, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Reds'))\n", 1740 | "imshow_density(density=path.p_data, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Blues'))\n", 1741 | "\n", 1742 | "\n", 1743 | "# Construct integrator and plot trajectories\n", 1744 | "sde = LangevinFlowSDE(flow_model, score_model, sigma)\n", 1745 | "simulator = EulerMaruyamaSimulator(sde)\n", 1746 | "x0 = path.p_simple.sample(num_samples) # (num_samples, 2)\n", 1747 | "ts = torch.linspace(0.0, 1.0, num_timesteps).view(1,-1,1).expand(num_samples,-1,1).to(device) # (num_samples, nts, 1)\n", 1748 | "xts = simulator.simulate_with_trajectory(x0, ts) # (bs, nts, dim)\n", 1749 | "\n", 1750 | "# Extract every n-th integration step to plot\n", 1751 | "every_n = record_every(num_timesteps=num_timesteps, record_every=num_timesteps // num_marginals)\n", 1752 | "xts_every_n = 
xts[:,every_n,:] # (bs, nts // n, dim)\n", 1753 | "ts_every_n = ts[0,every_n] # (nts // n,)\n", 1754 | "for plot_idx in range(xts_every_n.shape[1]):\n", 1755 | " tt = ts_every_n[plot_idx].item()\n", 1756 | " ax.scatter(xts_every_n[:,plot_idx,0].detach().cpu(), xts_every_n[:,plot_idx,1].detach().cpu(), marker='o', alpha=0.5, label=f't={tt:.2f}')\n", 1757 | "\n", 1758 | "ax.legend(prop={'size': legend_size}, loc='upper right', markerscale=markerscale)\n", 1759 | "\n", 1760 | "###############################################\n", 1761 | "# Graph Trajectories of Learned Marginal SDE #\n", 1762 | "###############################################\n", 1763 | "ax = axes[2]\n", 1764 | "ax.set_title(\"Trajectories of Learned Marginal SDE\", fontsize=20)\n", 1765 | "ax.set_xlim(*x_bounds)\n", 1766 | "ax.set_ylim(*y_bounds)\n", 1767 | "ax.set_xticks([])\n", 1768 | "ax.set_yticks([])\n", 1769 | "\n", 1770 | "# Plot source and target\n", 1771 | "imshow_density(density=path.p_simple, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Reds'))\n", 1772 | "imshow_density(density=path.p_data, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Blues'))\n", 1773 | "\n", 1774 | "for traj_idx in range(num_samples // 10):\n", 1775 | " ax.plot(xts[traj_idx,:,0].detach().cpu(), xts[traj_idx,:,1].detach().cpu(), alpha=0.5, color='black')\n", 1776 | "\n", 1777 | "################################################\n", 1778 | "# Graph Ground-Truth Marginal Probability Path #\n", 1779 | "################################################\n", 1780 | "ax = axes[0]\n", 1781 | "ax.set_title(\"Ground-Truth Marginal Probability Path\", fontsize=20)\n", 1782 | "ax.set_xlim(*x_bounds)\n", 1783 | "ax.set_ylim(*y_bounds)\n", 1784 | "ax.set_xticks([])\n", 1785 | "ax.set_yticks([])\n", 1786 | "\n", 1787 | "for plot_idx in range(xts_every_n.shape[1]):\n", 1788 | " tt = ts_every_n[plot_idx].unsqueeze(0).expand(num_samples, 
1)\n", 1789 | " marginal_samples = path.sample_marginal_path(tt)\n", 1790 | " ax.scatter(marginal_samples[:,0].detach().cpu(), marginal_samples[:,1].detach().cpu(), marker='o', alpha=0.5, label=f't={tt[0,0].item():.2f}')\n", 1791 | "\n", 1792 | "# Plot source and target\n", 1793 | "imshow_density(density=path.p_simple, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Reds'))\n", 1794 | "imshow_density(density=path.p_data, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Blues'))\n", 1795 | "\n", 1796 | "ax.legend(prop={'size': legend_size}, loc='upper right', markerscale=markerscale)\n", 1797 | " \n", 1798 | "plt.show()" 1799 | ] 1800 | }, 1801 | { 1802 | "cell_type": "markdown", 1803 | "id": "57aaf7f6-d7d0-4fe8-b415-8a069ac02f93", 1804 | "metadata": {}, 1805 | "source": [ 1806 | "### Question 3.3: Deriving the Marginal Score from the Marginal Flow\n", 1807 | "Recall from the notes and the lecture that for Gaussian probability paths $$u_t^{\\text{ref}}(x) = a_tx + b_t\\nabla \\log p_t^{\\text{ref}}(x).$$\n", 1808 | "\n", 1809 | "where $(a_t, b_t) = \\left(\\frac{\\dot{\\alpha}_t}{\\alpha_t}, \\beta_t^2 \\frac{\\dot{\\alpha}_t}{\\alpha_t} - \\dot{\\beta}_t \\beta_t\\right)$. Rearranging yields $$\\nabla \\log p_t^{\\text{ref}}(x) = \\frac{u_t^{\\text{ref}}(x) - a_tx}{b_t}.$$\n", 1810 | "\n", 1811 | "Therefore, we may instead exploit the fact that we have already trained $u_t^{\\theta}(x)$, to parameterize $s_t^{\\theta}(x)$ via\n", 1812 | "$$\\tilde{s}_t^{\\theta}(x) = \\frac{u_t^{\\theta}(x) - a_tx}{b_t} = \\frac{\\alpha_t u_t^{\\theta}(x) - \\dot{\\alpha}_t x}{\\beta_t^2 \\dot{\\alpha}_t - \\alpha_t \\dot{\\beta}_t \\beta_t},$$\n", 1813 | "so long as $\\beta_t^2 \\dot{\\alpha}_t - \\alpha_t \\dot{\\beta}_t \\beta_t \\neq 0$ (which is true for $t \\in [0,1)$ by monotonicity). 
Here, we distinguish $\\tilde{s}_t^{\\theta}(x)$, parameterized via $u_t^{\\theta}(x)$, from $s_t^{\\theta}(x)$, learned independently using score matching. Plugging in $\\alpha_t = t$ and $\\beta_t = \\sqrt{1-t}$, we find that $$\\beta_t^2 \\dot{\\alpha}_t - \\alpha_t \\dot{\\beta}_t \\beta_t = \\begin{cases} 1 - \\frac{t}{2} & \\text{if}\\,\\,t\\in [0,1)\\\\0 & \\text{if}\\,\\,{t=1} \\end{cases}.$$ In the following visualization, we'll circumvent the issue at $t=1.0$ by taking $t=1 - \\varepsilon$ in place of $t=1$, for small $\\varepsilon \\approx 0$." 1814 | ] 1815 | }, 1816 | { 1817 | "cell_type": "markdown", 1818 | "id": "3e1f0dd1-dab6-44bd-aeb1-e472a6294af4", 1819 | "metadata": {}, 1820 | "source": [ 1821 | "**Your job**: Implement $\\tilde{s}_t^{\\theta}(x)$ by filling in the body of `ScoreFromVectorField.forward` below. The next several cells generate a visualization comparing the flow-parameterized score $\\tilde{s}_t^{\\theta}(x)$ to our independently learned score $s_t^{\\theta}(x)$. You can check that your implementation is correct by making sure that the two visualizations roughly agree." 
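The conversion can be sketched in a few lines for the schedule used here, alpha_t = t and beta_t = sqrt(1 - t). This is a standalone illustration with hypothetical function names, not the body of `ScoreFromVectorField.forward`, which should query `self.alpha` and `self.beta` rather than hard-code the schedule. It relies on plain arithmetic only, so it applies equally to floats and to broadcastable tensors. To have something to compare against, it also assumes the standard closed forms for the conditional vector field and conditional score of a Gaussian path, for which the same identity holds with z held fixed.

```python
def score_from_vector_field(u, x, t):
    """Convert a velocity u_t(x) into a score, assuming alpha_t = t, beta_t = sqrt(1 - t)."""
    alpha, alpha_dot = t, 1.0
    beta_sq = 1.0 - t          # beta_t^2
    beta_beta_dot = -0.5       # beta_t * d(beta_t)/dt, constant on [0, 1)
    denom = beta_sq * alpha_dot - alpha * beta_beta_dot   # equals 1 - t/2
    return (alpha * u - alpha_dot * x) / denom

def conditional_vector_field(x, z, t):
    """u_t(x|z) = (alpha_dot - (beta_dot/beta) * alpha) * z + (beta_dot/beta) * x."""
    beta_dot_over_beta = -0.5 / (1.0 - t)
    return (1.0 - beta_dot_over_beta * t) * z + beta_dot_over_beta * x

def conditional_score(x, z, t):
    """Score of N(alpha_t * z, beta_t^2) evaluated at x."""
    return (t * z - x) / (1.0 - t)

# With z fixed, converting the conditional velocity should recover the conditional score.
x, z, t = -1.3, 0.7, 0.2
u = conditional_vector_field(x, z, t)
print(score_from_vector_field(u, x, t), conditional_score(x, z, t))
```

Working with the second form of the fraction, the one with denominator beta_t^2 * alpha_dot_t - alpha_t * beta_dot_t * beta_t, avoids dividing by alpha_t, which vanishes at t = 0.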
1822 | ] 1823 | }, 1824 | { 1825 | "cell_type": "code", 1826 | "execution_count": null, 1827 | "id": "91efc10a-7ef7-4ec6-975f-8e52e0cac2a1", 1828 | "metadata": {}, 1829 | "outputs": [], 1830 | "source": [ 1831 | "class ScoreFromVectorField(torch.nn.Module):\n", 1832 | " \"\"\"\n", 1833 | " Parameterization of score via learned vector field (for the special case of a Gaussian conditional probability path)\n", 1834 | " \"\"\"\n", 1835 | " def __init__(self, vector_field: MLPVectorField, alpha: Alpha, beta: Beta):\n", 1836 | " super().__init__()\n", 1837 | " self.vector_field = vector_field\n", 1838 | " self.alpha = alpha\n", 1839 | " self.beta = beta\n", 1840 | "\n", 1841 | " def forward(self, x: torch.Tensor, t: torch.Tensor):\n", 1842 | " \"\"\"\n", 1843 | " Args:\n", 1844 | " - x: (bs, dim)\n", 1845 | " Returns:\n", 1846 | " - score: (bs, dim)\n", 1847 | " \"\"\"\n", 1848 | " raise NotImplementedError(\"Fill me in for Question 3.3!\") " 1849 | ] 1850 | }, 1851 | { 1852 | "cell_type": "markdown", 1853 | "id": "0aefa1ba-794f-47bb-8fa1-2f6a53b486b1", 1854 | "metadata": {}, 1855 | "source": [ 1856 | "Now, let's compare our learned marginal score $s_t^{\\theta}(x)$ (an instance of `MLPScore`) to our flow-parameterized score (an instance of `ScoreFromVectorField`). We'll do so by plotting the vector fields across time and space.\n", 1857 | "\n", 1858 | "**Note**: The two score parameterizations will probably look a bit different, but should generally point in the same direction, especially around modes. To sanity check your output, you may consult Figure 10 from the lecture notes." 
1859 | ] 1860 | }, 1861 | { 1862 | "cell_type": "code", 1863 | "execution_count": null, 1864 | "id": "f2b99180-43d1-4ef9-bcf9-5e53ee0819b1", 1865 | "metadata": {}, 1866 | "outputs": [], 1867 | "source": [ 1868 | "#######################\n", 1869 | "# Change these values #\n", 1870 | "#######################\n", 1871 | "num_bins = 30\n", 1872 | "num_marginals = 4\n", 1873 | "\n", 1874 | "##############################\n", 1875 | "# Construct probability path #\n", 1876 | "##############################\n", 1877 | "path = GaussianConditionalProbabilityPath(\n", 1878 | " p_data = GaussianMixture.symmetric_2D(nmodes=5, std=PARAMS[\"target_std\"], scale=PARAMS[\"target_scale\"]).to(device), \n", 1879 | " alpha = LinearAlpha(),\n", 1880 | " beta = SquareRootBeta()\n", 1881 | ").to(device)\n", 1882 | "\n", 1883 | "#########################\n", 1884 | "# Define score networks #\n", 1885 | "#########################\n", 1886 | "learned_score_model = score_model\n", 1887 | "flow_score_model = ScoreFromVectorField(flow_model, path.alpha, path.beta)\n", 1888 | "\n", 1889 | "\n", 1890 | "###############################\n", 1891 | "# Plot score fields over time #\n", 1892 | "###############################\n", 1893 | "fig, axes = plt.subplots(2, num_marginals, figsize=(6 * num_marginals, 12))\n", 1894 | "axes = axes.reshape((2, num_marginals))\n", 1895 | "\n", 1896 | "scale = PARAMS[\"scale\"]\n", 1897 | "ts = torch.linspace(0.0, 0.9999, num_marginals).to(device)\n", 1898 | "xs = torch.linspace(-scale, scale, num_bins).to(device)\n", 1899 | "ys = torch.linspace(-scale, scale, num_bins).to(device)\n", 1900 | "xx, yy = torch.meshgrid(xs, ys)\n", 1901 | "xx = xx.reshape(-1,1)\n", 1902 | "yy = yy.reshape(-1,1)\n", 1903 | "xy = torch.cat([xx,yy], dim=-1)\n", 1904 | "\n", 1905 | "axes[0,0].set_ylabel(\"Learned with Score Matching\", fontsize=12)\n", 1906 | "axes[1,0].set_ylabel(\"Computed from $u_t^{{\\\\theta}}(x)$\", fontsize=12)\n", 1907 | "for idx in range(num_marginals):\n", 1908 
| " t = ts[idx]\n", 1909 | " bs = num_bins ** 2\n", 1910 | " tt = t.view(1,1).expand(bs, 1)\n", 1911 | " \n", 1912 | " # Learned scores\n", 1913 | " learned_scores = learned_score_model(xy, tt)\n", 1914 | " learned_scores_x = learned_scores[:,0]\n", 1915 | " learned_scores_y = learned_scores[:,1]\n", 1916 | "\n", 1917 | " ax = axes[0, idx]\n", 1918 | " ax.quiver(xx.detach().cpu(), yy.detach().cpu(), learned_scores_x.detach().cpu(), learned_scores_y.detach().cpu(), scale=125, alpha=0.5)\n", 1919 | " imshow_density(density=path.p_simple, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Reds'))\n", 1920 | " imshow_density(density=path.p_data, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Blues'))\n", 1921 | " ax.set_title(f'$s_{{t}}^{{\\\\theta}}$ at t={t.item():.2f}')\n", 1922 | " ax.set_xticks([])\n", 1923 | " ax.set_yticks([])\n", 1924 | " \n", 1925 | "\n", 1926 | " # Flow score model\n", 1927 | " ax = axes\n", 1928 | " flow_scores = flow_score_model(xy,tt)\n", 1929 | " flow_scores_x = flow_scores[:,0]\n", 1930 | " flow_scores_y = flow_scores[:,1]\n", 1931 | "\n", 1932 | " ax = axes[1, idx]\n", 1933 | " ax.quiver(xx.detach().cpu(), yy.detach().cpu(), flow_scores_x.detach().cpu(), flow_scores_y.detach().cpu(), scale=125, alpha=0.5)\n", 1934 | " imshow_density(density=path.p_simple, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Reds'))\n", 1935 | " imshow_density(density=path.p_data, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Blues'))\n", 1936 | " ax.set_title(f'$\\\\tilde{{s}}_{{t}}^{{\\\\theta}}$ at t={t.item():.2f}')\n", 1937 | " ax.set_xticks([])\n", 1938 | " ax.set_yticks([])" 1939 | ] 1940 | }, 1941 | { 1942 | "cell_type": "markdown", 1943 | "id": "b58fd169-6891-45ab-b917-14cc3912bec8", 1944 | "metadata": {}, 1945 | "source": [ 1946 | "# Part 4: Flow Matching 
Between Arbitrary Distributions with a Linear Probability Path\n", 1947 | "In this section, we will consider an alternative conditional probability path - the **linear conditional probability path** - which can be constructed as follows. Given a source distribution $p_{\\text{simple}}$ and a data distribution $p_{\\text{data}}$, for a fixed $z$ we may consider the *interpolant* $$X_t = (1-t) X_0 + tz$$\n", 1948 | "where $X_0 \\sim p_{\\text{simple}}$ is a random variable. We may then define $p_t(x|z)$ so that $X_t \\sim p_t(x|z)$. Then it is apparent that $p_0(x|z) = p_{\\text{simple}}(x)$ and $p_1(x| z)= \\delta_z(x)$. It is also not difficult to show that the conditional vector field is given by $u_t^{\\text{ref}}(x|z) = \\frac{z - x}{1-t}$ for $t \\in [0,1)$. We make two observations about the linear conditional probability path: First, unlike in the Gaussian probability path, we do not have a closed form for the conditional score $\\nabla \\log p_t(x|z)$. Second, there is no constraint that $p_{\\text{simple}}$ be a Gaussian, which we will exploit in Problem 4.3 to construct flows between arbitrary choices of $p_{\\text{simple}}$ and $p_{\\text{data}}$. To begin, let's examine some more complicated choices of $p_{\\text{data}}$." 
1949 | ] 1950 | }, 1951 | { 1952 | "cell_type": "code", 1953 | "execution_count": null, 1954 | "id": "c24d596d-e307-4344-b683-aa7b55adf0c0", 1955 | "metadata": {}, 1956 | "outputs": [], 1957 | "source": [ 1958 | "class MoonsSampleable(Sampleable):\n", 1959 | " \"\"\"\n", 1960 | " Implementation of the Moons distribution using sklearn's make_moons\n", 1961 | " \"\"\"\n", 1962 | " def __init__(self, device: torch.device, noise: float = 0.05, scale: float = 5.0, offset: Optional[torch.Tensor] = None):\n", 1963 | " \"\"\"\n", 1964 | " Args:\n", 1965 | " noise: Standard deviation of Gaussian noise added to the data\n", 1966 | " scale: How much to scale the data\n", 1967 | " offset: How much to shift the samples from the original distribution (2,)\n", 1968 | " \"\"\"\n", 1969 | " self.noise = noise\n", 1970 | " self.scale = scale\n", 1971 | " self.device = device\n", 1972 | " if offset is None:\n", 1973 | " offset = torch.zeros(2)\n", 1974 | " self.offset = offset.to(device)\n", 1975 | "\n", 1976 | " @property\n", 1977 | " def dim(self) -> int:\n", 1978 | " return 2\n", 1979 | "\n", 1980 | " def sample(self, num_samples: int) -> torch.Tensor:\n", 1981 | " \"\"\"\n", 1982 | " Args:\n", 1983 | " num_samples: Number of samples to generate\n", 1984 | " Returns:\n", 1985 | " torch.Tensor: Generated samples with shape (num_samples, 2)\n", 1986 | " \"\"\"\n", 1987 | " samples, _ = make_moons(\n", 1988 | " n_samples=num_samples,\n", 1989 | " noise=self.noise,\n", 1990 | " random_state=None # Allow for random generation each time\n", 1991 | " )\n", 1992 | " return self.scale * torch.from_numpy(samples.astype(np.float32)).to(self.device) + self.offset\n", 1993 | "\n", 1994 | "class CirclesSampleable(Sampleable):\n", 1995 | " \"\"\"\n", 1996 | " Implementation of concentric circle distribution using sklearn's make_circles\n", 1997 | " \"\"\"\n", 1998 | " def __init__(self, device: torch.device, noise: float = 0.05, scale: float = 5.0, offset: Optional[torch.Tensor] = None):\n", 1999 | " 
\"\"\"\n", 2000 | " Args:\n", 2001 | " noise: standard deviation of Gaussian noise added to the data\n", 2002 | " \"\"\"\n", 2003 | " self.noise = noise\n", 2004 | " self.scale = scale\n", 2005 | " self.device = device\n", 2006 | " if offset is None:\n", 2007 | " offset = torch.zeros(2)\n", 2008 | " self.offset = offset.to(device)\n", 2009 | "\n", 2010 | " @property\n", 2011 | " def dim(self) -> int:\n", 2012 | " return 2\n", 2013 | "\n", 2014 | " def sample(self, num_samples: int) -> torch.Tensor:\n", 2015 | " \"\"\"\n", 2016 | " Args:\n", 2017 | " num_samples: number of samples to generate\n", 2018 | " Returns:\n", 2019 | " torch.Tensor: shape (num_samples, 2)\n", 2020 | " \"\"\"\n", 2021 | " samples, _ = make_circles(\n", 2022 | " n_samples=num_samples,\n", 2023 | " noise=self.noise,\n", 2024 | " factor=0.5,\n", 2025 | " random_state=None\n", 2026 | " )\n", 2027 | " return self.scale * torch.from_numpy(samples.astype(np.float32)).to(self.device) + self.offset\n", 2028 | "\n", 2029 | "class CheckerboardSampleable(Sampleable):\n", 2030 | " \"\"\"\n", 2031 | " Checkerboard-esque distribution\n", 2032 | " \"\"\"\n", 2033 | " def __init__(self, device: torch.device, grid_size: int = 3, scale: float = 5.0):\n", 2034 | " \"\"\"\n", 2035 | " Args:\n", 2036 | " grid_size: number of checkerboard cells along each axis\n", 2037 | " \"\"\"\n", 2038 | " self.grid_size = grid_size\n", 2039 | " self.scale = scale\n", 2040 | " self.device = device\n", 2041 | "\n", 2042 | " @property\n", 2043 | " def dim(self) -> int:\n", 2044 | " return 2\n", 2045 | "\n", 2046 | " def sample(self, num_samples: int) -> torch.Tensor:\n", 2047 | " \"\"\"\n", 2048 | " Args:\n", 2049 | " num_samples: number of samples to generate\n", 2050 | " Returns:\n", 2051 | " torch.Tensor: shape (num_samples, 2)\n", 2052 | " \"\"\"\n", 2053 | " grid_length = 2 * self.scale / self.grid_size\n", 2054 | " samples = torch.zeros(0,2).to(self.device)\n", 2055 | " while samples.shape[0] < num_samples:\n", 2056 | " # Sample 
num_samples\n", 2057 | " new_samples = (torch.rand(num_samples,2).to(self.device) - 0.5) * 2 * self.scale\n", 2058 | " x_mask = torch.floor((new_samples[:,0] + self.scale) / grid_length) % 2 == 0 # (bs,)\n", 2059 | " y_mask = torch.floor((new_samples[:,1] + self.scale) / grid_length) % 2 == 0 # (bs,)\n", 2060 | " accept_mask = torch.logical_xor(~x_mask, y_mask)\n", 2061 | " samples = torch.cat([samples, new_samples[accept_mask]], dim=0)\n", 2062 | " return samples[:num_samples]" 2063 | ] 2064 | }, 2065 | { 2066 | "cell_type": "code", 2067 | "execution_count": null, 2068 | "id": "b3c44f89-ddbc-45c4-99fd-d1c8ea5027a3", 2069 | "metadata": {}, 2070 | "outputs": [], 2071 | "source": [ 2072 | "# Visualize alternative choices of p_data\n", 2073 | "targets = {\n", 2074 | " \"circles\": CirclesSampleable(device),\n", 2075 | " \"moons\": MoonsSampleable(device, scale=3.5),\n", 2076 | " \"checkerboard\": CheckerboardSampleable(device, grid_size=4)\n", 2077 | "}\n", 2078 | "\n", 2079 | "###################################\n", 2080 | "# Graph Various Choices of p_data #\n", 2081 | "###################################\n", 2082 | "\n", 2083 | "fig, axes = plt.subplots(1, len(targets), figsize=(6 * len(targets), 6))\n", 2084 | "\n", 2085 | "num_samples = 20000\n", 2086 | "num_bins = 100\n", 2087 | "for idx, (target_name, target) in enumerate(targets.items()):\n", 2088 | " ax = axes[idx]\n", 2089 | " hist2d_sampleable(target, num_samples, bins=num_bins, scale=7.5, ax=ax)\n", 2090 | " ax.set_aspect('equal')\n", 2091 | " ax.set_xticks([])\n", 2092 | " ax.set_yticks([])\n", 2093 | " ax.set_title(f'Histogram of {target_name}')\n", 2094 | "\n", 2095 | "plt.show()" 2096 | ] 2097 | }, 2098 | { 2099 | "cell_type": "markdown", 2100 | "id": "5e14a9a6-a5c7-40df-9acc-39af296b22e2", 2101 | "metadata": {}, 2102 | "source": [ 2103 | "### Problem 4.1: Linear Probability Paths\n", 2104 | "Below we define the `LinearConditionalProbabilityPath`. 
We purposely omit the implementation of `conditional_score` because, as mentioned earlier, there is no nice form for it!" 2105 | ] 2106 | }, 2107 | { 2108 | "cell_type": "code", 2109 | "execution_count": null, 2110 | "id": "2d0aa7a5-49b2-48af-a2c3-552c00c3322d", 2111 | "metadata": {}, 2112 | "outputs": [], 2113 | "source": [ 2114 | "class LinearConditionalProbabilityPath(ConditionalProbabilityPath):\n", 2115 | " def __init__(self, p_simple: Sampleable, p_data: Sampleable):\n", 2116 | " super().__init__(p_simple, p_data)\n", 2117 | "\n", 2118 | " def sample_conditioning_variable(self, num_samples: int) -> torch.Tensor:\n", 2119 | " \"\"\"\n", 2120 | " Samples the conditioning variable z ~ p_data(x)\n", 2121 | " Args:\n", 2122 | " - num_samples: the number of samples\n", 2123 | " Returns:\n", 2124 | " - z: samples from p(z), (num_samples, ...)\n", 2125 | " \"\"\"\n", 2126 | " return self.p_data.sample(num_samples)\n", 2127 | " \n", 2128 | " def sample_conditional_path(self, z: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n", 2129 | " \"\"\"\n", 2130 | " Samples the random variable X_t = (1-t) X_0 + tz\n", 2131 | " Args:\n", 2132 | " - z: conditioning variable (num_samples, dim)\n", 2133 | " - t: time (num_samples, 1)\n", 2134 | " Returns:\n", 2135 | " - x: samples from p_t(x|z), (num_samples, dim)\n", 2136 | " \"\"\"\n", 2137 | " raise NotImplementedError(\"Fill me in for Question 4.1!\")\n", 2138 | " \n", 2139 | " def conditional_vector_field(self, x: torch.Tensor, z: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n", 2140 | " \"\"\"\n", 2141 | " Evaluates the conditional vector field u_t(x|z) = (z - x) / (1 - t)\n", 2142 | " Note: Only defined on t in [0,1)\n", 2143 | " Args:\n", 2144 | " - x: position variable (num_samples, dim)\n", 2145 | " - z: conditioning variable (num_samples, dim)\n", 2146 | " - t: time (num_samples, 1)\n", 2147 | " Returns:\n", 2148 | " - conditional_vector_field: conditional vector field (num_samples, dim)\n", 2149 | " \"\"\" \n", 2150 
| " raise NotImplementedError(\"Fill me in for Question 4.1!\")\n", 2151 | "\n", 2152 | " def conditional_score(self, x: torch.Tensor, z: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n", 2153 | " \"\"\"\n", 2154 | " Not known for Linear Conditional Probability Paths\n", 2155 | " \"\"\" \n", 2156 | " raise Exception(\"You should not be calling this function!\")" 2157 | ] 2158 | }, 2159 | { 2160 | "cell_type": "markdown", 2161 | "id": "9fea4cf9-4365-4b45-a3c3-be267c162ea9", 2162 | "metadata": {}, 2163 | "source": [ 2164 | "**Your work**: Implement `LinearConditionalProbabilityPath.sample_conditional_path` and `LinearConditionalProbabilityPath.conditional_vector_field`." 2165 | ] 2166 | }, 2167 | { 2168 | "cell_type": "markdown", 2169 | "id": "9f357d6e-577e-47d9-96ea-6d13f3bfa795", 2170 | "metadata": {}, 2171 | "source": [ 2172 | "You can sanity check that the implementations are correct by ensuring that they are consistent with one another. The following visualization provides three sequences of graphs:\n", 2173 | "1. The first row shows the conditional probability path, as produced by your implementation of `sample_conditional_path`.\n", 2174 | "2. The second row shows the conditional probability path, as obtained by simulating the ODE defined by your implementation of `conditional_vector_field`.\n", 2175 | "3. The third row shows the marginal probability path, as produced by `sample_marginal_path`." 
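The consistency check described in the list above can also be done numerically on a single trajectory. The sketch below is a standalone illustration with hypothetical free functions, not the lab's `LinearConditionalProbabilityPath`, `ConditionalVectorFieldODE`, or `EulerSimulator`. It takes the start point x0 as an explicit argument instead of sampling it, and it uses plain arithmetic so that scalars and broadcastable tensors both work. For this particular ODE each trajectory is a straight line in t, so explicit Euler steps should reproduce the closed-form interpolant up to floating-point error.

```python
def sample_conditional_path(x0, z, t):
    """Closed-form interpolant X_t = (1 - t) * X_0 + t * z."""
    return (1.0 - t) * x0 + t * z

def conditional_vector_field(x, z, t):
    """u_t(x|z) = (z - x) / (1 - t), only defined for t in [0, 1)."""
    return (z - x) / (1.0 - t)

def euler_simulate(x0, z, t_end, num_steps):
    """Integrate dx/dt = u_t(x|z) from t = 0 to t = t_end with explicit Euler steps."""
    x, h = x0, t_end / num_steps
    for k in range(num_steps):
        x = x + h * conditional_vector_field(x, z, k * h)
    return x

# Stop short of t = 1, where the vector field is undefined.
x0, z, t_end = 1.5, -2.0, 0.8
x_ode = euler_simulate(x0, z, t_end, num_steps=200)
x_closed_form = sample_conditional_path(x0, z, t_end)
print(x_ode, x_closed_form)
```

If the two printed values disagree noticeably, the sign or the (1 - t) factor in one of the two functions is a likely culprit.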
2176 | ] 2177 | }, 2178 | { 2179 | "cell_type": "code", 2180 | "execution_count": null, 2181 | "id": "e9b3d914-8508-41c9-8b1f-1a789e919711", 2182 | "metadata": {}, 2183 | "outputs": [], 2184 | "source": [ 2185 | "##########################\n", 2186 | "# Play around with these #\n", 2187 | "##########################\n", 2188 | "num_samples = 100000\n", 2189 | "num_timesteps = 500\n", 2190 | "num_marginals = 5\n", 2191 | "assert num_timesteps % (num_marginals - 1) == 0\n", 2192 | "\n", 2193 | "##########################################\n", 2194 | "# Construct conditional probability path #\n", 2195 | "##########################################\n", 2196 | "path = LinearConditionalProbabilityPath(\n", 2197 | " p_simple = Gaussian.isotropic(dim=2, std=1.0),\n", 2198 | " p_data = CheckerboardSampleable(device, grid_size=4)\n", 2199 | ").to(device)\n", 2200 | "z = path.p_data.sample(1) # (1,2)\n", 2201 | "\n", 2202 | "##############\n", 2203 | "# Setup plots #\n", 2204 | "##############\n", 2205 | "\n", 2206 | "fig, axes = plt.subplots(3, num_marginals, figsize=(6 * num_marginals, 6 * 3))\n", 2207 | "axes = axes.reshape(3, num_marginals)\n", 2208 | "scale = 6.0\n", 2209 | "\n", 2210 | "\n", 2211 | "#####################################################################\n", 2212 | "# Graph conditional probability paths using sample_conditional_path #\n", 2213 | "#####################################################################\n", 2214 | "ts = torch.linspace(0.0, 1.0, num_marginals).to(device)\n", 2215 | "for idx, t in enumerate(ts):\n", 2216 | " zz = z.expand(num_samples, -1)\n", 2217 | " tt = t.view(1,1).expand(num_samples,1)\n", 2218 | " xts = path.sample_conditional_path(zz, tt)\n", 2219 | " percentile = min(99 + 2 * torch.sin(t).item(), 100)\n", 2220 | " hist2d_samples(samples=xts.cpu(), ax=axes[0, idx], bins=300, scale=scale, percentile=percentile, alpha=1.0)\n", 2221 | " axes[0, idx].set_xlim(-scale, scale)\n", 2222 | " axes[0, idx].set_ylim(-scale, scale)\n", 
2223 | " axes[0, idx].set_xticks([])\n", 2224 | " axes[0, idx].set_yticks([])\n", 2225 | " axes[0, idx].set_title(f'$t={t.item():.2f}$', fontsize=15)\n", 2226 | "axes[0, 0].set_ylabel(\"Conditional (from Ground-Truth)\", fontsize=20)\n", 2227 | "\n", 2228 | "# Plot z\n", 2229 | "axes[0,-1].scatter(z[:,0].cpu(), z[:,1].cpu(), marker='*', color='red', s=200, label='z',zorder=20)\n", 2230 | "axes[0,-1].legend()\n", 2231 | "\n", 2232 | "######################################################################\n", 2233 | "# Graph conditional probability paths using conditional_vector_field #\n", 2234 | "######################################################################\n", 2235 | "ode = ConditionalVectorFieldODE(path, z)\n", 2236 | "simulator = EulerSimulator(ode)\n", 2237 | "ts = torch.linspace(0,1,num_timesteps).to(device)\n", 2238 | "record_every_idxs = record_every(len(ts), len(ts) // (num_marginals - 1))\n", 2239 | "x0 = path.p_simple.sample(num_samples)\n", 2240 | "xts = simulator.simulate_with_trajectory(x0, ts.view(1,-1,1).expand(num_samples,-1,1))\n", 2241 | "xts = xts[:,record_every_idxs,:]\n", 2242 | "for idx in range(xts.shape[1]):\n", 2243 | " xx = xts[:,idx,:]\n", 2244 | " tt = ts[record_every_idxs[idx]]\n", 2245 | " percentile = min(99 + 2 * torch.sin(tt).item(), 100)\n", 2246 | " hist2d_samples(samples=xx.cpu(), ax=axes[1, idx], bins=300, scale=scale, percentile=percentile, alpha=1.0)\n", 2247 | " axes[1, idx].set_xlim(-scale, scale)\n", 2248 | " axes[1, idx].set_ylim(-scale, scale)\n", 2249 | " axes[1, idx].set_xticks([])\n", 2250 | " axes[1, idx].set_yticks([])\n", 2251 | " axes[1, idx].set_title(f'$t={tt.item():.2f}$', fontsize=15)\n", 2252 | "axes[1, 0].set_ylabel(\"Conditional (from ODE)\", fontsize=20)\n", 2253 | "\n", 2254 | "# Plot z\n", 2255 | "axes[1,-1].scatter(z[:,0].cpu(), z[:,1].cpu(), marker='*', color='red', s=200, label='z',zorder=20)\n", 2256 | "axes[1,-1].legend()\n", 2257 | "\n", 2258 | 
"##################################################################\n", 2259 | "# Graph the marginal probability path using sample_marginal_path #\n", 2260 | "##################################################################\n", 2261 | "ts = torch.linspace(0.0, 1.0, num_marginals).to(device)\n", 2262 | "for idx, t in enumerate(ts):\n", 2263 | " zz = z.expand(num_samples, -1)\n", 2264 | " tt = t.view(1,1).expand(num_samples,1)\n", 2265 | " xts = path.sample_marginal_path(tt)\n", 2266 | " hist2d_samples(samples=xts.cpu(), ax=axes[2, idx], bins=300, scale=scale, percentile=99, alpha=1.0)\n", 2267 | " axes[2, idx].set_xlim(-scale, scale)\n", 2268 | " axes[2, idx].set_ylim(-scale, scale)\n", 2269 | " axes[2, idx].set_xticks([])\n", 2270 | " axes[2, idx].set_yticks([])\n", 2271 | " axes[2, idx].set_title(f'$t={t.item():.2f}$', fontsize=15)\n", 2272 | "axes[2, 0].set_ylabel(\"Marginal\", fontsize=20)\n", 2273 | "\n", 2274 | "plt.show()" 2275 | ] 2276 | }, 2277 | { 2278 | "cell_type": "markdown", 2279 | "id": "aaf5f81d-bd38-4222-a14e-553a31afa964", 2280 | "metadata": {}, 2281 | "source": [ 2282 | "### Part 4.2: Flow Matching with Linear Probability Paths\n", 2283 | "Now, let's train a flow matching model using the linear conditional probability path! 
**Remember, the loss should converge, but not necessarily to zero!**" 2284 | ] 2285 | }, 2286 | { 2287 | "cell_type": "code", 2288 | "execution_count": null, 2289 | "id": "bc9dfb02-8c65-436d-ad70-a6fee48ba035", 2290 | "metadata": {}, 2291 | "outputs": [], 2292 | "source": [ 2293 | "# Construct conditional probability path\n", 2294 | "path = LinearConditionalProbabilityPath(\n", 2295 | " p_simple = Gaussian.isotropic(dim=2, std=1.0),\n", 2296 | " p_data = CheckerboardSampleable(device, grid_size=4)\n", 2297 | ").to(device)\n", 2298 | "\n", 2299 | "# Construct learnable vector field\n", 2300 | "linear_flow_model = MLPVectorField(dim=2, hiddens=[64,64,64,64])\n", 2301 | "\n", 2302 | "# Construct trainer\n", 2303 | "trainer = ConditionalFlowMatchingTrainer(path, linear_flow_model)\n", 2304 | "losses = trainer.train(num_epochs=10000, device=device, lr=1e-3, batch_size=2000)" 2305 | ] 2306 | }, 2307 | { 2308 | "cell_type": "code", 2309 | "execution_count": null, 2310 | "id": "5ebeb708-9829-4276-afe9-db7a930e741a", 2311 | "metadata": {}, 2312 | "outputs": [], 2313 | "source": [ 2314 | "##########################\n", 2315 | "# Play around With These #\n", 2316 | "##########################\n", 2317 | "num_samples = 50000\n", 2318 | "num_marginals = 5\n", 2319 | "\n", 2320 | "##############\n", 2321 | "# Setup Plots #\n", 2322 | "##############\n", 2323 | "\n", 2324 | "fig, axes = plt.subplots(2, num_marginals, figsize=(6 * num_marginals, 6 * 2))\n", 2325 | "axes = axes.reshape(2, num_marginals)\n", 2326 | "scale = 6.0\n", 2327 | "\n", 2328 | "###########################\n", 2329 | "# Graph Ground-Truth Marginals #\n", 2330 | "###########################\n", 2331 | "ts = torch.linspace(0.0, 1.0, num_marginals).to(device)\n", 2332 | "for idx, t in enumerate(ts):\n", 2333 | " tt = t.view(1,1).expand(num_samples,1)\n", 2334 | " xts = path.sample_marginal_path(tt)\n", 2335 | " hist2d_samples(samples=xts.cpu(), ax=axes[0, idx], bins=200, scale=scale, percentile=99, 
alpha=1.0)\n", 2336 | " axes[0, idx].set_xlim(-scale, scale)\n", 2337 | " axes[0, idx].set_ylim(-scale, scale)\n", 2338 | " axes[0, idx].set_xticks([])\n", 2339 | " axes[0, idx].set_yticks([])\n", 2340 | " axes[0, idx].set_title(f'$t={t.item():.2f}$', fontsize=15)\n", 2341 | "axes[0, 0].set_ylabel(\"Ground Truth\", fontsize=20)\n", 2342 | "\n", 2343 | "###############################################\n", 2344 | "# Graph Marginals of Learned Vector Field #\n", 2345 | "###############################################\n", 2346 | "ode = LearnedVectorFieldODE(linear_flow_model)\n", 2347 | "simulator = EulerSimulator(ode)\n", 2348 | "ts = torch.linspace(0,1,100).to(device)\n", 2349 | "record_every_idxs = record_every(len(ts), len(ts) // (num_marginals - 1))\n", 2350 | "x0 = path.p_simple.sample(num_samples)\n", 2351 | "xts = simulator.simulate_with_trajectory(x0, ts.view(1,-1,1).expand(num_samples,-1,1))\n", 2352 | "xts = xts[:,record_every_idxs,:]\n", 2353 | "for idx in range(xts.shape[1]):\n", 2354 | " xx = xts[:,idx,:]\n", 2355 | " hist2d_samples(samples=xx.cpu(), ax=axes[1, idx], bins=200, scale=scale, percentile=99, alpha=1.0)\n", 2356 | " axes[1, idx].set_xlim(-scale, scale)\n", 2357 | " axes[1, idx].set_ylim(-scale, scale)\n", 2358 | " axes[1, idx].set_xticks([])\n", 2359 | " axes[1, idx].set_yticks([])\n", 2360 | " tt = ts[record_every_idxs[idx]]\n", 2361 | " axes[1, idx].set_title(f'$t={tt.item():.2f}$', fontsize=15)\n", 2362 | "axes[1, 0].set_ylabel(\"Learned\", fontsize=20) \n", 2363 | "\n", 2364 | "plt.show()" 2365 | ] 2366 | }, 2367 | { 2368 | "cell_type": "markdown", 2369 | "id": "98e4271f-495d-4164-8183-8619c08af3d7", 2370 | "metadata": {}, 2371 | "source": [ 2372 | "### Problem 4.3: Bridging Between Arbitrary Source and Target\n", 2373 | "Notice that in our construction of the linear probability path, there is no need for $p_{\\text{simple}}$ to be a Gaussian. Let's try setting it to another distribution!" 
2374 | ] 2375 | }, 2376 | { 2377 | "cell_type": "code", 2378 | "execution_count": null, 2379 | "id": "ccec7698-4263-4c2c-a69e-e4a0b7a38724", 2380 | "metadata": {}, 2381 | "outputs": [], 2382 | "source": [ 2383 | "# Construct conditional probability path\n", 2384 | "path = LinearConditionalProbabilityPath(\n", 2385 | " p_simple = CirclesSampleable(device),\n", 2386 | " p_data = CheckerboardSampleable(device, grid_size=4)\n", 2387 | ").to(device)\n", 2388 | "\n", 2389 | "# Construct learnable vector field\n", 2390 | "bridging_flow_model = MLPVectorField(dim=2, hiddens=[100,100,100,100])\n", 2391 | "\n", 2392 | "# Construct trainer\n", 2393 | "trainer = ConditionalFlowMatchingTrainer(path, bridging_flow_model)\n", 2394 | "losses = trainer.train(num_epochs=20000, device=device, lr=1e-3, batch_size=2000)" 2395 | ] 2396 | }, 2397 | { 2398 | "cell_type": "code", 2399 | "execution_count": null, 2400 | "id": "9e7b2d3f-f647-42a8-a8a1-d7c183efb90e", 2401 | "metadata": {}, 2402 | "outputs": [], 2403 | "source": [ 2404 | "##########################\n", 2405 | "# Play around With These #\n", 2406 | "##########################\n", 2407 | "num_samples = 30000\n", 2408 | "num_marginals = 5\n", 2409 | "\n", 2410 | "\n", 2411 | "##############\n", 2412 | "# Setup Plots #\n", 2413 | "##############\n", 2414 | "\n", 2415 | "fig, axes = plt.subplots(2, num_marginals, figsize=(6 * num_marginals, 6 * 2))\n", 2416 | "axes = axes.reshape(2, num_marginals)\n", 2417 | "scale = 6.0\n", 2418 | "\n", 2419 | "\n", 2420 | "###########################\n", 2421 | "# Graph Ground-Truth Marginals #\n", 2422 | "###########################\n", 2423 | "ts = torch.linspace(0.0, 1.0, num_marginals).to(device)\n", 2424 | "for idx, t in enumerate(ts):\n", 2425 | " tt = t.view(1,1).expand(num_samples,1)\n", 2426 | " xts = path.sample_marginal_path(tt)\n", 2427 | " hist2d_samples(samples=xts.cpu(), ax=axes[0, idx], bins=200, scale=scale, percentile=99, alpha=1.0)\n", 2428 | " axes[0, idx].set_xlim(-scale, 
scale)\n", 2429 | " axes[0, idx].set_ylim(-scale, scale)\n", 2430 | " axes[0, idx].set_xticks([])\n", 2431 | " axes[0, idx].set_yticks([])\n", 2432 | " axes[0, idx].set_title(f'$t={t.item():.2f}$', fontsize=15)\n", 2433 | "axes[0, 0].set_ylabel(\"Ground Truth\", fontsize=20)\n", 2434 | "\n", 2435 | "###############################################\n", 2436 | "# Graph Learned Marginals #\n", 2437 | "###############################################\n", 2438 | "ode = LearnedVectorFieldODE(bridging_flow_model)\n", 2439 | "simulator = EulerSimulator(ode)\n", 2440 | "ts = torch.linspace(0,1,200).to(device)\n", 2441 | "record_every_idxs = record_every(len(ts), len(ts) // (num_marginals - 1))\n", 2442 | "x0 = path.p_simple.sample(num_samples)\n", 2443 | "xts = simulator.simulate_with_trajectory(x0, ts.view(1,-1,1).expand(num_samples,-1,1))\n", 2444 | "xts = xts[:,record_every_idxs,:]\n", 2445 | "for idx in range(xts.shape[1]):\n", 2446 | " xx = xts[:,idx,:]\n", 2447 | " hist2d_samples(samples=xx.cpu(), ax=axes[1, idx], bins=200, scale=scale, percentile=99, alpha=1.0)\n", 2448 | " axes[1, idx].set_xlim(-scale, scale)\n", 2449 | " axes[1, idx].set_ylim(-scale, scale)\n", 2450 | " axes[1, idx].set_xticks([])\n", 2451 | " axes[1, idx].set_yticks([])\n", 2452 | " tt = ts[record_every_idxs[idx]]\n", 2453 | " axes[1, idx].set_title(f'$t={tt.item():.2f}$', fontsize=15)\n", 2454 | "axes[1, 0].set_ylabel(\"Learned\", fontsize=20)\n", 2455 | "\n", 2456 | "plt.show()" 2457 | ] 2458 | }, 2459 | { 2460 | "cell_type": "markdown", 2461 | "id": "3d7bb3ca-939b-4a62-8870-ba9b2e103ad1", 2462 | "metadata": {}, 2463 | "source": [ 2464 | "**Your job**: Play around with the choice of $p_{\\text{simple}}$ and $p_{\\text{data}}$. Any observations?" 
2465 | ] 2466 | } 2467 | ], 2468 | "metadata": { 2469 | "kernelspec": { 2470 | "display_name": "mtds", 2471 | "language": "python", 2472 | "name": "mtds" 2473 | }, 2474 | "language_info": { 2475 | "codemirror_mode": { 2476 | "name": "ipython", 2477 | "version": 3 2478 | }, 2479 | "file_extension": ".py", 2480 | "mimetype": "text/x-python", 2481 | "name": "python", 2482 | "nbconvert_exporter": "python", 2483 | "pygments_lexer": "ipython3", 2484 | "version": "3.9.20" 2485 | } 2486 | }, 2487 | "nbformat": 4, 2488 | "nbformat_minor": 5 2489 | } 2490 | --------------------------------------------------------------------------------