├── 1d_ode.gif ├── README.md ├── cnf_density.gif ├── linear_system_of_odes.gif ├── neural_ode_solvers.ipynb └── neural_ode_solvers_presentation.pdf /1d_ode.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MStypulkowski/neural_ode_solvers_workshop/5a7e52537f25feb213f1c6f8bc75d095ca4d1fa7/1d_ode.gif -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Neural ODE solvers 2 | Presentation and notebook with four assignments on neural ODE solvers. 3 | 4 | The materials were designed for AI Tech spring school held in May 2022 at Politechnika Gdańska, Poland. 5 | 6 | ## Presentation 7 | Slides from the presentation are available in the `neural_ode_solvers_presentation.pdf` file. They are divided into sections: 8 | 1. Introduction. 9 | 2. Numerical ODE solvers. 10 | 3. Neural network as an ODE (covering the adjoint method). 11 | 4. Continuous Normalizing Flow (CNF). 12 | 13 | ## Notebook 14 | The `neural_ode_solvers.ipynb` file contains solved exercises with descriptions, instructions, and unit tests. You can find four models on the following topics: 15 | 16 | 1. Linear system of two ODEs - stationary dynamic function. 17 | 18 | ![Linear system](linear_system_of_odes.gif) 19 | 20 | 2. 1D ODE - non-stationary dynamic function. 21 | 22 | ![1D ODE](1d_ode.gif) 23 | 24 | 3. MNIST classification. 25 | Classifier achieves ~97% accuracy. 26 | 27 | 4. Continuous Normalizing Flow (CNF). 28 | 29 | ![CNF density](cnf_density.gif) 30 | -------------------------------------------------------------------------------- /cnf_density.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MStypulkowski/neural_ode_solvers_workshop/5a7e52537f25feb213f1c6f8bc75d095ca4d1fa7/cnf_density.gif -------------------------------------------------------------------------------- /linear_system_of_odes.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MStypulkowski/neural_ode_solvers_workshop/5a7e52537f25feb213f1c6f8bc75d095ca4d1fa7/linear_system_of_odes.gif -------------------------------------------------------------------------------- /neural_ode_solvers.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "LHd4UeFSFITW" 7 | }, 8 | "source": [ 9 | "# Neural ODE solvers\n", 10 | "In this notebook, we will take a look at four examples of using neural ODE solvers. We will get familiar with `torchdiffeq` library and use it to build models for:\n", 11 | "\n", 12 | "1. Linear system of ODEs with time-independent dynamic function\n", 13 | "2. Non-homogeneous ODE with time-dependent dynamic function\n", 14 | "3. MNIST classification\n", 15 | "4. Continuous Normalizing Flow (CNF)\n", 16 | "\n", 17 | "Have fun!\n", 18 | "\n" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "metadata": { 25 | "id": "ec4LZKyZVOcK" 26 | }, 27 | "outputs": [], 28 | "source": [ 29 | "import os\n", 30 | "from tqdm.notebook import tqdm\n", 31 | "from collections import OrderedDict\n", 32 | "\n", 33 | "import numpy as np\n", 34 | "import torch\n", 35 | "import torch.nn as nn\n", 36 | "import torch.nn.functional as F\n", 37 | "from torch.utils.data import DataLoader\n", 38 | "from torchvision.datasets import MNIST\n", 39 | "from torchvision.transforms import Compose, ToTensor, Normalize\n", 40 | "\n", 41 | "import matplotlib.pyplot as plt\n", 42 | "from matplotlib.cm import get_cmap\n", 43 | "from matplotlib.animation import FuncAnimation, PillowWriter\n", 44 | "from IPython.display import clear_output\n", 45 | "\n", 46 | "!pip install torchdiffeq\n", 47 | "from torchdiffeq import odeint_adjoint\n", 48 | "from torchdiffeq import odeint as standard_odeint\n", 49 | "\n", 50 | "device = 'cpu'" 51 | ] 52 | }, 53 | { 54 | "cell_type": "markdown", 55 | "metadata": { 56 | "id": "zjcwf1_YC0V3" 57 | }, 58 | "source": [ 59 | "# Utils\n", 60 | "Some utils for plotting and loging. Don't bother, just run." 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "metadata": { 67 | "id": "DmSDg_W90j_g" 68 | }, 69 | "outputs": [], 70 | "source": [ 71 | "plt.rcParams.update({'font.size': 15})\n", 72 | "\n", 73 | "log_dir = './logs'\n", 74 | "if not os.path.exists(log_dir):\n", 75 | " os.mkdir(log_dir)" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": null, 81 | "metadata": { 82 | "id": "6nAhm4D7C1re" 83 | }, 84 | "outputs": [], 85 | "source": [ 86 | "def make_linear_plot(xs, ys, title, xlabel, ylabel, figsize=(8, 8), markers='-', markersize='10', legend_labels=None):\n", 87 | " plt.figure(figsize=figsize)\n", 88 | " plt.title(title)\n", 89 | " plt.xlabel(xlabel)\n", 90 | " plt.ylabel(ylabel)\n", 91 | "\n", 92 | " if isinstance(xs, list):\n", 93 | " assert len(xs) == len(ys)\n", 94 | " for x, y, marker, label in zip(xs, ys, markers, legend_labels):\n", 95 | " plt.plot(x, y, marker, markersize=markersize, label=label)\n", 96 | " plt.legend() \n", 97 | " else:\n", 98 | " plt.plot(xs, ys, markers, markersize=markersize)\n", 99 | " plt.show()" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": null, 105 | "metadata": { 106 | "id": "pwVTqnkI2I6Y" 107 | }, 108 | "outputs": [], 109 | "source": [ 110 | "def make_gif(gt, predicted, title, xlim, ylim, ts=None):\n", 111 | " # predicted in the shape of (n_frames, n_points, 2)\n", 112 | " fig, ax = plt.subplots(figsize=(8, 8))\n", 113 | "\n", 114 | " def animate(i):\n", 115 | " ax.clear()\n", 116 | " ax.set_xlim(xlim)\n", 117 | " ax.set_ylim(ylim)\n", 118 | " if ts is None:\n", 119 | " line1, = ax.plot(gt[:, 0, 0], gt[:, 0, 1], '-', markersize='10', label='GT')\n", 120 | " line2, = ax.plot(predicted[i, :, 0], predicted[i, :, 1], '--', markersize='10', label='Predicted')\n", 121 | " else:\n", 122 | " line1, = ax.plot(ts, gt, '-', markersize='10', label='GT')\n", 123 | " line2, = ax.plot(ts, predicted[i, :], '--', markersize='10', label='Predicted')\n", 124 | " ax.legend()\n", 125 | " return line1, line2 \n", 126 | " \n", 127 | " ani = FuncAnimation(fig, animate, frames=len(predicted)) \n", 128 | " ani.save(os.path.join(log_dir, '_'.join(title.lower().split(' ')) + '.gif'), dpi=300, writer=PillowWriter(fps=10))" 129 | ] 130 | }, 131 | { 132 | "cell_type": "markdown", 133 | "metadata": { 134 | "id": "hkD22tf3a37X" 135 | }, 136 | "source": [ 137 | "# Simple ODEs" 138 | ] 139 | }, 140 | { 141 | "cell_type": "markdown", 142 | "metadata": { 143 | "id": "-uDXUI_1t4DF" 144 | }, 145 | "source": [ 146 | "## Linear system of ODEs - time-independent dynamic function\n", 147 | "First, we will play with some simple linear system of ODEs to get familiar with the `torchdiffeq` library.\n", 148 | "\n", 149 | "Consider a system of ODEs:\n", 150 | "\n", 151 | "\\begin{align}\n", 152 | "&x_1'(t) = a x_1(t) + b x_2(t) \\\\\n", 153 | "&x_2'(t) = c x_1(t) + d x_2(t) \n", 154 | "\\end{align}\n", 155 | "\n", 156 | "It can be written in a vectorized (horizontal) version:\n", 157 | "$$ x'(t) = x(t)A$$\n", 158 | "where $x(t) = [x_1(t), x_2(t)]$.\n", 159 | "\n", 160 | "Now, let's say that we have points from trajectory of $x(t)$. Our job is to create a neural network that will aproximate our dynamic function. We need to make 3 steps:\n", 161 | "1. Create dataset. You are given ground-truth gradients, so you need to integrate over them to get the position:\n", 162 | "$$ x(T) = x(0) + \\int_0^T x'(t) dt $$\n", 163 | "Hint: sum instead of integral is sufficient.\n", 164 | "2. Define neural network that only uses $x$ as an argument (no time dependency).\n", 165 | "3. Optimize it.\n", 166 | "\n" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": null, 172 | "metadata": { 173 | "id": "O_xZEckKueAf" 174 | }, 175 | "outputs": [], 176 | "source": [ 177 | "x0 = torch.tensor([[2., 0.]]).to(device)\n", 178 | "ts = torch.linspace(0., 10., 1000).to(device)\n", 179 | "A_gt = torch.tensor([[-0.1, 2.0], [-2.0, -0.5]]).to(device)" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": null, 185 | "metadata": { 186 | "id": "WVJ3E_JaiKaj" 187 | }, 188 | "outputs": [], 189 | "source": [ 190 | "def generate_2d_data(x0, ts, A_gt):\n", 191 | " \"\"\"\n", 192 | " Generate data for every time step in ts.\n", 193 | " Use ground-truth matrix A_gt.\n", 194 | " \"\"\"\n", 195 | " # TODO\n", 196 | " dt = ts[1] - ts[0]\n", 197 | " xs = [x0]\n", 198 | " for t in ts[1:]:\n", 199 | " xs.append(xs[-1] + torch.mm(xs[-1], A_gt) * dt)\n", 200 | " return torch.stack(xs)" 201 | ] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "execution_count": null, 206 | "metadata": { 207 | "id": "_0YyEGfgZyUG" 208 | }, 209 | "outputs": [], 210 | "source": [ 211 | "x_gt = generate_2d_data(x0, ts, A_gt)\n", 212 | "make_linear_plot(x_gt[:, 0, 0].cpu(), x_gt[:, 0, 1].cpu(), '2D ODE', 'x1', 'x2')" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": null, 218 | "metadata": { 219 | "id": "r-mPC30GsufL" 220 | }, 221 | "outputs": [], 222 | "source": [ 223 | "class ODEFunc(nn.Module):\n", 224 | " def __init__(self, hid_dim=64):\n", 225 | " super(ODEFunc, self).__init__()\n", 226 | " \"\"\"\n", 227 | " Define neural network:\n", 228 | " Linear(2, hid_dim) -> SiLU -> Linear(hid_dim, 2)\n", 229 | " \"\"\"\n", 230 | " # TODO\n", 231 | " self.fc1 = nn.Linear(2, hid_dim)\n", 232 | " self.fc2 = nn.Linear(hid_dim, 2)\n", 233 | " self.activation = nn.SiLU()\n", 234 | "\n", 235 | " def forward(self, t, x):\n", 236 | " \"\"\"\n", 237 | " Make the forward pass, don't use t at all.\n", 238 | " \"\"\"\n", 239 | " # TODO\n", 240 | " x = self.activation(self.fc1(x))\n", 241 | " x = self.fc2(x)\n", 242 | " return x" 243 | ] 244 | }, 245 | { 246 | "cell_type": "code", 247 | "execution_count": null, 248 | "metadata": { 249 | "id": "zlzuPUAyvYPe" 250 | }, 251 | "outputs": [], 252 | "source": [ 253 | "torch.manual_seed(420)\n", 254 | "odefunc_test = ODEFunc()\n", 255 | "x = torch.randn(5, 2)\n", 256 | "with torch.no_grad():\n", 257 | " out = odefunc_test(None, x)\n", 258 | "print(out)\n", 259 | "\n", 260 | "# expected output:\n", 261 | "# tensor([[-0.4450, 0.0515],\n", 262 | "# [ 0.0216, 0.2244],\n", 263 | "# [-0.1507, 0.0914],\n", 264 | "# [ 0.0363, 0.0892],\n", 265 | "# [ 0.0484, 0.0241]])" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": null, 271 | "metadata": { 272 | "id": "TGC6cb7WtZFX" 273 | }, 274 | "outputs": [], 275 | "source": [ 276 | "torch.manual_seed(420)\n", 277 | "odefunc = ODEFunc().to(device)\n", 278 | "optimizer = torch.optim.Adam(odefunc.parameters(), lr=1e-2)\n", 279 | "mse_loss = torch.nn.MSELoss()\n", 280 | "\n", 281 | "history = []\n", 282 | "for epoch in range(301):\n", 283 | " \"\"\"\n", 284 | " Calculate predicted values x_pred and compare them to\n", 285 | " ground-truths using MSE loss function.\n", 286 | " Use standard_odeint method. \n", 287 | " It takes a nn.Module dynamic function, initial point, and integration times as its inputs.\n", 288 | " \"\"\"\n", 289 | " # TODO\n", 290 | " x_pred = standard_odeint(odefunc, x0, ts, rtol=1e-6, atol=1e-6)\n", 291 | " loss = mse_loss(x_pred, x_gt)\n", 292 | "\n", 293 | " optimizer.zero_grad()\n", 294 | " loss.backward()\n", 295 | " optimizer.step()\n", 296 | " with torch.no_grad():\n", 297 | " history.append(x_pred[:, 0, :].detach().cpu())\n", 298 | " if epoch % 10 == 0: \n", 299 | " make_linear_plot([x_gt[:, 0, 0].cpu(), x_pred[:, 0, 0].cpu()], \n", 300 | " [x_gt[:, 0, 1].cpu(), x_pred[:, 0, 1].cpu()], \n", 301 | " f'Epoch {epoch} MSE {loss.item():.4f}', 'x1', 'x2', \n", 302 | " markers=['-', '--'], legend_labels=['GT', 'Predicted'])\n", 303 | " clear_output(wait=True)" 304 | ] 305 | }, 306 | { 307 | "cell_type": "code", 308 | "execution_count": null, 309 | "metadata": { 310 | "id": "QQhRv-GN1xG-" 311 | }, 312 | "outputs": [], 313 | "source": [ 314 | "make_gif(x_gt.cpu(), torch.stack(history).cpu(), 'Linear system of ODEs', [-1.5, 2], [-1.5, 2])" 315 | ] 316 | }, 317 | { 318 | "cell_type": "markdown", 319 | "metadata": { 320 | "id": "sGrzdky5tzaf" 321 | }, 322 | "source": [ 323 | "## Non-homogeneous ODE - time-dependent dynamic function\n", 324 | "Now, let's bring another factor to the table - time. Let's consider an initial value problem of the form:\n", 325 | "$$ x'(t) = e^{-t} (\\sin(2t) + \\cos(2t)), $$\n", 326 | "$$ x(0) = 1. $$\n", 327 | "\n", 328 | "We know (using some justified guessing or Wolfram) that there is an analytical solution to this problem:\n", 329 | "$$ x(t) = \\frac{1}{5} e^{-t} (8 e^t + \\sin(2t) - 3 \\cos(2t)). $$\n", 330 | "\n", 331 | "Let's find out if we can create a model that learns the solution! Here are some steps that we need to follow:\n", 332 | "\n", 333 | "1. Generate ground truth data. This time we will do it using analytical solution.\n", 334 | "2. Create a neural network that uses both $x$ and $t$ as its arguments. We will do it using `ConcatLinear` layer.\n", 335 | "3. Optimize the model.\n", 336 | "\n" 337 | ] 338 | }, 339 | { 340 | "cell_type": "code", 341 | "execution_count": null, 342 | "metadata": { 343 | "id": "hFhgCwAQt61r" 344 | }, 345 | "outputs": [], 346 | "source": [ 347 | "def generate_1d_data(t):\n", 348 | " return 1/5 * torch.exp(-t) * (8 * torch.exp(t) + torch.sin(2 * t) - 3 * torch.cos(2 * t))" 349 | ] 350 | }, 351 | { 352 | "cell_type": "code", 353 | "execution_count": null, 354 | "metadata": { 355 | "id": "rsaY4M5GzfiG" 356 | }, 357 | "outputs": [], 358 | "source": [ 359 | "x0 = torch.tensor([[1.0]]).to(device)\n", 360 | "ts = torch.linspace(0., 5., 1000).to(device)\n", 361 | "\n", 362 | "x_gt = generate_1d_data(ts).unsqueeze(-1)\n", 363 | "make_linear_plot(ts.cpu(), x_gt[:, 0].cpu(), '1D ODE', 't', 'x')" 364 | ] 365 | }, 366 | { 367 | "cell_type": "code", 368 | "execution_count": null, 369 | "metadata": { 370 | "id": "QVBKuJSkG9jt" 371 | }, 372 | "outputs": [], 373 | "source": [ 374 | "class ConcatLinear(nn.Module):\n", 375 | " def __init__(self, in_dim, out_dim):\n", 376 | " super(ConcatLinear, self).__init__()\n", 377 | " \"\"\"\n", 378 | " Define a Linear layer that concatenates t and x at its input. \n", 379 | " Don't use nonlinearities.\n", 380 | " \"\"\"\n", 381 | " # TODO\n", 382 | " self.fc = nn.Linear(in_dim + 1, out_dim)\n", 383 | "\n", 384 | " def forward(self, t, x):\n", 385 | " \"\"\"\n", 386 | " 1. Expand dimensions of scalar t, so that it matches 2nd dimension of x, e.g.:\n", 387 | " for x of shape (5, 2), t should be (5, 1) to enable concatenation.\n", 388 | " 2. Concatenate x and t.\n", 389 | " 3. Pass it through the Linear layer. \n", 390 | " \"\"\"\n", 391 | " # TODO\n", 392 | " t = torch.ones_like(x[:, 0]).unsqueeze(-1) * t\n", 393 | " tx = torch.cat([t, x], dim=1)\n", 394 | " tx = self.fc(tx)\n", 395 | " return tx" 396 | ] 397 | }, 398 | { 399 | "cell_type": "code", 400 | "execution_count": null, 401 | "metadata": { 402 | "id": "DbyEt85n_Au4" 403 | }, 404 | "outputs": [], 405 | "source": [ 406 | "torch.manual_seed(420)\n", 407 | "concat_linear_test = ConcatLinear(3, 2)\n", 408 | "x = torch.randn(5, 3)\n", 409 | "t = torch.tensor(1)\n", 410 | "with torch.no_grad():\n", 411 | " out = concat_linear_test(t, x)\n", 412 | "print(out)\n", 413 | "\n", 414 | "# expected output:\n", 415 | "# tensor([[-0.4596, -1.2096],\n", 416 | "# [ 0.1936, -0.2271],\n", 417 | "# [ 0.0486, 0.4682],\n", 418 | "# [-0.7598, -0.8102],\n", 419 | "# [ 0.5184, -0.8245]])" 420 | ] 421 | }, 422 | { 423 | "cell_type": "code", 424 | "execution_count": null, 425 | "metadata": { 426 | "id": "KdW3l2NIFPcg" 427 | }, 428 | "outputs": [], 429 | "source": [ 430 | "class ODEFunc(nn.Module):\n", 431 | " def __init__(self, hid_dim=32):\n", 432 | " super(ODEFunc, self).__init__()\n", 433 | " \"\"\"\n", 434 | " Define neural network using 4 * ConcatLinear layers with SiLU activations. \n", 435 | " Input and output are 1D.\n", 436 | " Don't use the activation after last layer.\n", 437 | " \"\"\"\n", 438 | " # TODO\n", 439 | " self.fc1 = ConcatLinear(1, hid_dim)\n", 440 | " self.fc2 = ConcatLinear(hid_dim, hid_dim)\n", 441 | " self.fc3 = ConcatLinear(hid_dim, hid_dim)\n", 442 | " self.fc4 = ConcatLinear(hid_dim, 1)\n", 443 | " self.activation = nn.SiLU()\n", 444 | "\n", 445 | " def forward(self, t, x):\n", 446 | " \"\"\"\n", 447 | " Remember to use both t and x.\n", 448 | " \"\"\"\n", 449 | " # TODO\n", 450 | " x = self.activation(self.fc1(t, x))\n", 451 | " x = self.activation(self.fc2(t, x))\n", 452 | " x = self.activation(self.fc3(t, x))\n", 453 | " x = self.fc4(t, x)\n", 454 | " return x" 455 | ] 456 | }, 457 | { 458 | "cell_type": "code", 459 | "execution_count": null, 460 | "metadata": { 461 | "id": "yGf-udWRAiWY" 462 | }, 463 | "outputs": [], 464 | "source": [ 465 | "torch.manual_seed(420)\n", 466 | "odefunc_test = ODEFunc()\n", 467 | "x = torch.randn(5, 1)\n", 468 | "t = torch.tensor(1)\n", 469 | "with torch.no_grad():\n", 470 | " out = odefunc_test(t, x)\n", 471 | "print(out)\n", 472 | "\n", 473 | "# expected output:\n", 474 | "# tensor([[0.2415],\n", 475 | "# [0.2396],\n", 476 | "# [0.2399],\n", 477 | "# [0.2401],\n", 478 | "# [0.2431]])" 479 | ] 480 | }, 481 | { 482 | "cell_type": "code", 483 | "execution_count": null, 484 | "metadata": { 485 | "id": "ksz7we1yHEU2" 486 | }, 487 | "outputs": [], 488 | "source": [ 489 | "torch.manual_seed(420)\n", 490 | "odefunc = ODEFunc().to(device)\n", 491 | "optimizer = torch.optim.Adam(odefunc.parameters(), lr=1e-2)\n", 492 | "mse_loss = torch.nn.MSELoss()\n", 493 | "\n", 494 | "history = []\n", 495 | "for i in range(351):\n", 496 | " \"\"\"\n", 497 | " Calculate x_pred and MSE loss.\n", 498 | " \"\"\"\n", 499 | " # TODO\n", 500 | " x_pred = standard_odeint(odefunc, x0, ts, rtol=1e-6, atol=1e-6).squeeze(-1)\n", 501 | " loss = mse_loss(x_pred, x_gt)\n", 502 | "\n", 503 | " optimizer.zero_grad()\n", 504 | " loss.backward()\n", 505 | " optimizer.step()\n", 506 | "\n", 507 | " \n", 508 | " with torch.no_grad():\n", 509 | " history.append(x_pred[:, 0].detach().cpu())\n", 510 | " if i % 10 == 0:\n", 511 | " make_linear_plot([ts.cpu(), ts.cpu()], [x_gt[:, 0].cpu(), x_pred[:, 0].cpu()], \n", 512 | " f'Epoch {i} MSE {loss.item():.4f}', 't', 'x', \n", 513 | " markers=['-', '--'], legend_labels=['GT', 'Predicted'])\n", 514 | " clear_output(wait=True)" 515 | ] 516 | }, 517 | { 518 | "cell_type": "code", 519 | "execution_count": null, 520 | "metadata": { 521 | "id": "eGAWR_M_CD09" 522 | }, 523 | "outputs": [], 524 | "source": [ 525 | "make_gif(x_gt.cpu(), torch.stack(history).cpu(), '1D ODE', [0, 5], [1, 1.9], ts=ts.cpu())" 526 | ] 527 | }, 528 | { 529 | "cell_type": "markdown", 530 | "metadata": { 531 | "id": "ou1ebEQwwWDf" 532 | }, 533 | "source": [ 534 | "# MNIST classification\n", 535 | "Now, let's classify MNIST digits! A classic problem tackled from a different angle.\n", 536 | "\n", 537 | "We want to predict label $y=x(T)$ for image $x_0 = x(0)$. Let's define change of dynamics between input and output:\n", 538 | "$$ \\frac{dx(t)}{dt} = f(x(t), t, \\theta)$$\n", 539 | "\n", 540 | "Note, that dynamic function preserves the dimensionality of the input. To make the classification, we need to add one linear layer on top of it!" 541 | ] 542 | }, 543 | { 544 | "cell_type": "code", 545 | "execution_count": null, 546 | "metadata": { 547 | "id": "mo01u5cDuMsz" 548 | }, 549 | "outputs": [], 550 | "source": [ 551 | "class ConcatLinear(nn.Module):\n", 552 | " def __init__(self, in_dim, out_dim):\n", 553 | " super(ConcatLinear, self).__init__()\n", 554 | " \"\"\"\n", 555 | " copy-paste from previous section\n", 556 | " \"\"\"\n", 557 | " self.fc = nn.Linear(in_dim + 1, out_dim)\n", 558 | "\n", 559 | " def forward(self, t, x):\n", 560 | " t = torch.ones_like(x[:, 0]).unsqueeze(-1) * t\n", 561 | " tx = torch.cat([t, x], dim=1)\n", 562 | " tx = self.fc(tx)\n", 563 | " return tx" 564 | ] 565 | }, 566 | { 567 | "cell_type": "code", 568 | "execution_count": null, 569 | "metadata": { 570 | "id": "tQ5oklOkGsBz" 571 | }, 572 | "outputs": [], 573 | "source": [ 574 | "class ResNetBlock(nn.Module):\n", 575 | " def __init__(self, dim):\n", 576 | " super(ResNetBlock, self).__init__()\n", 577 | " \"\"\"\n", 578 | " Implement block of Residual Network.\n", 579 | " Use 2 * ConcatLinear with and SiLU activation.\n", 580 | " x - - - - -\n", 581 | " | |\n", 582 | " SiLU(CL) |\n", 583 | " | |\n", 584 | " CL |\n", 585 | " | |\n", 586 | " + <- - - - \n", 587 | " | \n", 588 | " SiLU \n", 589 | " \"\"\"\n", 590 | " # TODO\n", 591 | " self.fc1 = ConcatLinear(dim, dim)\n", 592 | " self.fc2 = ConcatLinear(dim, dim)\n", 593 | " self.activation = nn.SiLU()\n", 594 | "\n", 595 | " def forward(self, t, x):\n", 596 | " \"\"\"\n", 597 | " Implement the forward pass.\n", 598 | " Don't forget to use both t and x, and add skip connection.\n", 599 | " Use nonlineraity at the end.\n", 600 | " \"\"\"\n", 601 | " # TODO\n", 602 | " _x = x\n", 603 | " _x = self.activation(self.fc1(t, _x))\n", 604 | " _x = self.fc2(t, _x)\n", 605 | " _x = self.activation(_x + x)\n", 606 | " return _x" 607 | ] 608 | }, 609 | { 610 | "cell_type": "code", 611 | "execution_count": null, 612 | "metadata": { 613 | "id": "ehzmDjhVZk4J" 614 | }, 615 | "outputs": [], 616 | "source": [ 617 | "torch.manual_seed(420)\n", 618 | "resnet_test = ResNetBlock(2)\n", 619 | "x = torch.randn(4, 2)\n", 620 | "t = torch.tensor(1)\n", 621 | "with torch.no_grad():\n", 622 | " out = resnet_test(t, x)\n", 623 | "print(out)\n", 624 | "\n", 625 | "# expected output:\n", 626 | "# tensor([[-0.2231, -0.2664],\n", 627 | "# [ 0.0701, -0.2782],\n", 628 | "# [-0.2701, -0.2697],\n", 629 | "# [ 0.0243, -0.2676]])" 630 | ] 631 | }, 632 | { 633 | "cell_type": "code", 634 | "execution_count": null, 635 | "metadata": { 636 | "id": "bNaOpPFSHbbA" 637 | }, 638 | "outputs": [], 639 | "source": [ 640 | "class ODEFunc(nn.Module):\n", 641 | " def __init__(self, in_dim, hid_dim=64, n_resnet_blocks=2):\n", 642 | " super(ODEFunc, self).__init__()\n", 643 | " \"\"\"\n", 644 | " Let's define our dynamic function but this time in more scalable way.\n", 645 | " Implement NN as:\n", 646 | " ConcatLinear -> n_resnet_blocks * ResNetBlock -> ConcatLinear\n", 647 | " Use SiLU ativation everywhere but after last layer.\n", 648 | " Note that out_dim == in_dim.\n", 649 | " \"\"\"\n", 650 | " # TODO\n", 651 | " self.fc_in = ConcatLinear(in_dim, hid_dim)\n", 652 | "\n", 653 | " self.resnet_blocks = []\n", 654 | " for _ in range(n_resnet_blocks):\n", 655 | " self.resnet_blocks.append(ResNetBlock(hid_dim))\n", 656 | " self.resnet_blocks = nn.ModuleList(self.resnet_blocks)\n", 657 | "\n", 658 | " self.fc_out = ConcatLinear(hid_dim, in_dim)\n", 659 | "\n", 660 | " self.activation = nn.SiLU()\n", 661 | "\n", 662 | " def forward(self, t, x):\n", 663 | " \"\"\"\n", 664 | " Implement forward pass.\n", 665 | " \"\"\"\n", 666 | " # TODO\n", 667 | " x = self.activation(self.fc_in(t, x))\n", 668 | " for resnet_block in self.resnet_blocks:\n", 669 | " x = resnet_block(t, x)\n", 670 | " x = self.fc_out(t, x)\n", 671 | " return x" 672 | ] 673 | }, 674 | { 675 | "cell_type": "code", 676 | "execution_count": null, 677 | "metadata": { 678 | "id": "s1QKw1sgg0Gr" 679 | }, 680 | "outputs": [], 681 | "source": [ 682 | "torch.manual_seed(420)\n", 683 | "odefunc_test = ODEFunc(2, hid_dim=10, n_resnet_blocks=3)\n", 684 | "x = torch.randn(3, 2)\n", 685 | "t = torch.tensor(1)\n", 686 | "with torch.no_grad():\n", 687 | " out = odefunc_test(t, x)\n", 688 | "print(out)\n", 689 | "\n", 690 | "# expected output:\n", 691 | "# tensor([[0.2312, 0.4042],\n", 692 | "# [0.4999, 0.2859],\n", 693 | "# [0.3722, 0.3528]])" 694 | ] 695 | }, 696 | { 697 | "cell_type": "code", 698 | "execution_count": null, 699 | "metadata": { 700 | "id": "zHiw8TuzHbpX" 701 | }, 702 | "outputs": [], 703 | "source": [ 704 | "class ODEBlock(nn.Module):\n", 705 | " def __init__(self, odefunc, odeint=odeint_adjoint, rtol=1e-3, atol=1e-3):\n", 706 | " super(ODEBlock, self).__init__()\n", 707 | " \"\"\"\n", 708 | " We will use ODEBlock to wrap everything related to ODE solver.\n", 709 | " \"\"\"\n", 710 | " self.odefunc = odefunc\n", 711 | " self.odeint = odeint\n", 712 | " self.rtol = rtol\n", 713 | " self.atol = atol\n", 714 | " self.integration_times = torch.tensor([0., 1.])\n", 715 | "\n", 716 | " def forward(self, x):\n", 717 | " \"\"\"\n", 718 | " Calculate output from self.odeint (adjoint method).\n", 719 | " Return only the output at time t=1.\n", 720 | " \"\"\"\n", 721 | " # TODO\n", 722 | " integration_times = self.integration_times.to(x)\n", 723 | " out = self.odeint(self.odefunc, x, integration_times, rtol=self.rtol, atol=self.atol)\n", 724 | " return out[-1]" 725 | ] 726 | }, 727 | { 728 | "cell_type": "code", 729 | "execution_count": null, 730 | "metadata": { 731 | "id": "da9tID4bjpWB" 732 | }, 733 | "outputs": [], 734 | "source": [ 735 | "torch.manual_seed(420)\n", 736 | "odefunc_test = ODEFunc(2, hid_dim=10, n_resnet_blocks=3)\n", 737 | "odeblock_test = ODEBlock(odefunc_test, odeint=standard_odeint, rtol=1e-4, atol=1e-4)\n", 738 | "x = torch.randn(3, 2)\n", 739 | "with torch.no_grad():\n", 740 | " out = odeblock_test(x)\n", 741 | "print(out)\n", 742 | "\n", 743 | "# expected output:\n", 744 | "# tensor([[ 1.3183, -0.3704],\n", 745 | "# [-1.8987, 0.3295],\n", 746 | "# [-0.0468, 0.5624]])" 747 | ] 748 | }, 749 | { 750 | "cell_type": "code", 751 | "execution_count": null, 752 | "metadata": { 753 | "id": "ZthWtKCRt52T" 754 | }, 755 | "outputs": [], 756 | "source": [ 757 | "class ODEClassifier(nn.Module):\n", 758 | " def __init__(self, in_dim, out_dim, hid_dim=64, n_resnet_blocks=2, odeint_method=odeint_adjoint, rtol=1e-3, atol=1e-3):\n", 759 | " super(ODEClassifier, self).__init__()\n", 760 | " \"\"\"\n", 761 | " Define classifier. We need three things:\n", 762 | " 1. ODEFunc\n", 763 | " 2. ODEBlock\n", 764 | " 3. Linear layer to get right dimensions for classification.\n", 765 | " \"\"\"\n", 766 | " self.odefunc = ODEFunc(in_dim, hid_dim=hid_dim, n_resnet_blocks=n_resnet_blocks)\n", 767 | " self.ode_block = ODEBlock(self.odefunc, odeint=odeint_method, rtol=rtol, atol=atol)\n", 768 | " self.fc_out = nn.Linear(in_dim, out_dim)\n", 769 | "\n", 770 | " def forward(self, x):\n", 771 | " \"\"\"\n", 772 | " Implement forward pass through ODEBlock and Linear layer.\n", 773 | " \"\"\"\n", 774 | " # TODO\n", 775 | " x = self.ode_block(x)\n", 776 | " x = self.fc_out(x)\n", 777 | " return x" 778 | ] 779 | }, 780 | { 781 | "cell_type": "code", 782 | "execution_count": null, 783 | "metadata": { 784 | "id": "kK7qjj5gmf4t" 785 | }, 786 | "outputs": [], 787 | "source": [ 788 | "bsz = 2048\n", 789 | "\n", 790 | "in_dim = 28 * 28\n", 791 | "hid_dim = 64\n", 792 | "out_dim = 10\n", 793 | "n_resnet_blocks = 3\n", 794 | "\n", 795 | "odeint_method = odeint_adjoint\n", 796 | "rtol = 1e-3\n", 797 | "atol = 1e-3\n", 798 | "\n", 799 | "lr = 1e-3\n", 800 | "\n", 801 | "n_epochs = 11" 802 | ] 803 | }, 804 | { 805 | "cell_type": "code", 806 | "execution_count": null, 807 | "metadata": { 808 | "id": "YGiGG9UAgk8x" 809 | }, 810 | "outputs": [], 811 | "source": [ 812 | "data_train = MNIST('./data', train=True, download=True, transform=Compose([ToTensor(), Normalize((0.1307,), (0.3081,))]))\n", 813 | "data_test = MNIST('./data', train=False, download=True, transform=Compose([ToTensor(), Normalize((0.1307,), (0.3081,))]))\n", 814 | "\n", 815 | "dataloader_train = DataLoader(data_train, batch_size=bsz, shuffle=True)\n", 816 | "dataloader_test = DataLoader(data_test, batch_size=bsz, shuffle=True)\n", 817 | "\n", 818 | "print(f'Loaded MNIST train split with {len(data_train)} samples')\n", 819 | "print(f'Loaded MNIST test split with {len(data_test)} samples')" 820 | ] 821 | }, 822 | { 823 | "cell_type": "code", 824 | "execution_count": null, 825 | "metadata": { 826 | "id": "-kvFL5C6m1wL" 827 | }, 828 | "outputs": [], 829 | "source": [ 830 | "for i in range(3):\n", 831 | " x, y = data_train[i]\n", 832 | " plt.imshow(x[0], cmap='gray')\n", 833 | " plt.title(str(y))\n", 834 | " plt.axis('off')\n", 835 | " plt.show()" 836 | ] 837 | }, 838 | { 839 | "cell_type": "code", 840 | "execution_count": null, 841 | "metadata": { 842 | "id": "sirAjXT1YZdX" 843 | }, 844 | "outputs": [], 845 | "source": [ 846 | "torch.manual_seed(420)\n", 847 | "model = ODEClassifier(in_dim, out_dim, hid_dim=hid_dim, n_resnet_blocks=n_resnet_blocks, odeint_method=odeint_adjoint, rtol=rtol, atol=atol).to(device)\n", 848 | "optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n", 849 | "ce_loss = nn.CrossEntropyLoss()" 850 | ] 851 | }, 852 | { 853 | "cell_type": "code", 854 | "execution_count": null, 855 | "metadata": { 856 | "id": "X1-B-z0UnaQ4" 857 | }, 858 | "outputs": [], 859 | "source": [ 860 | "for epoch in range(n_epochs):\n", 861 | " pbar = tqdm(dataloader_train, desc=f'Epoch {epoch}')\n", 862 | " for x, y in pbar:\n", 863 | " \"\"\"\n", 864 | " Make the classification.\n", 865 | " Hint: use F.one_hot().\n", 866 | " \"\"\"\n", 867 | " # TODO\n", 868 | " x = x.reshape(-1, in_dim).to(device)\n", 869 | " y = y.to(device)\n", 870 | " one_hots = F.one_hot(y, out_dim).float()\n", 871 | "\n", 872 | " logits = model(x)\n", 873 | " loss = ce_loss(logits, one_hots)\n", 874 | "\n", 875 | " optimizer.zero_grad()\n", 876 | " loss.backward()\n", 877 | " optimizer.step()\n", 878 | "\n", 879 | " pbar.set_postfix({'CE Loss': f'{loss.item():.4f}'})\n", 880 | " \n", 881 | " if epoch % 1 == 0:\n", 882 | " with torch.no_grad():\n", 883 | " accuracy_sum = 0\n", 884 | " for x, y in tqdm(dataloader_test, desc='Test'):\n", 885 | " x = x.reshape(-1, in_dim).to(device)\n", 886 | " y = y.to(device)\n", 887 | " logits = model(x)\n", 888 | " accuracy_sum += (logits.argmax(1) == y).sum()\n", 889 | " print(f'Test accuracy after {epoch} epochs {accuracy_sum / len(data_test) * 100 :.2f}%')" 890 | ] 891 | }, 892 | { 893 | "cell_type": "markdown", 894 | "metadata": { 895 | "id": "aR7SRDsQ9W3X" 896 | }, 897 | "source": [ 898 | "# Continuous Normalizing Flow (CNF)\n", 899 | "We will work on synthetic 2D moons dataset. Our goal is to create CNF that will learn the distribution of the data.\n", 900 | "\n", 901 | "Let's recall the three main components of CNF:\n", 902 | "1. Training:\n", 903 | "$$ z_0 = z_1 + \\int_1^0 f(z(t), t) dt $$\n", 904 | "2. Loss function:\n", 905 | "$$ \\log p(z_1) = \\log p(z_0) - \\int_0^1 \\text{tr} \\bigg( \\frac{f(z(t), t)}{dz(t)} \\bigg) dt $$\n", 906 | "3. Sampling:\n", 907 | "$$ z_1 = z_0 + \\int_0^1 f(z(t), t) dt $$" 908 | ] 909 | }, 910 | { 911 | "cell_type": "code", 912 | "execution_count": null, 913 | "metadata": { 914 | "id": "YNM-JiQYxU6x" 915 | }, 916 | "outputs": [], 917 | "source": [ 918 | "def generate_moons(width=1.0):\n", 919 | " moon1 = [\n", 920 | " [r * np.cos(a) - 2.5, r * np.sin(a) - 1.0]\n", 921 | " for r in np.arange(5 - width, 5 + width, 0.1 * width)\n", 922 | " for a in np.arange(0, np.pi, 0.01)\n", 923 | " ]\n", 924 | " moon2 = [\n", 925 | " [r * np.cos(a) + 2.5, r * np.sin(a) + 1.0]\n", 926 | " for r in np.arange(5 - width, 5 + width, 0.1 * width)\n", 927 | " for a in np.arange(np.pi, 2 * np.pi, 0.01)\n", 928 | " ]\n", 929 | " points = torch.tensor(moon1 + moon2)\n", 930 | " points += torch.rand(points.shape) * width\n", 931 | " return points.float()" 932 | ] 933 | }, 934 | { 935 | "cell_type": "code", 936 | "execution_count": null, 937 | "metadata": { 938 | "id": "fMDdqHIMNq9k" 939 | }, 940 | "outputs": [], 941 | "source": [ 942 | "data = generate_moons(0.5)\n", 943 | "data = (data - data.mean(0)) / data.std(0)\n", 944 | "print(data.shape)\n", 945 | "plt.scatter(data[:, 0], data[:, 1], s=1)" 946 | ] 947 | }, 948 | { 949 | "cell_type": "code", 950 | "execution_count": null, 951 | "metadata": { 952 | "id": "mdoh9bl9OhaE" 953 | }, 954 | "outputs": [], 955 | "source": [ 956 | "def normal_logprob(z):\n", 957 | "\t\"\"\"\n", 958 | "\tLog-probability of standard Gaussian distribution.\n", 959 | "\t\"\"\"\n", 960 | "\treturn (-np.log(2 * np.pi) - 0.5 * z**2).sum(1, keepdim=True)" 961 | ] 962 | }, 963 | { 964 | "cell_type": "code", 965 | "execution_count": null, 966 | "metadata": { 967 | "id": "lnzU667URTWn" 968 | }, 969 | "outputs": [], 970 | "source": [ 971 | "class ConcatLinear(nn.Module):\n", 972 | " def __init__(self, in_dim, out_dim):\n", 973 | " super(ConcatLinear, self).__init__()\n", 974 | " \"\"\"\n", 975 | " copy-paste from previous section\n", 976 | " \"\"\"\n", 977 | " # TODO\n", 978 | " self.fc = nn.Linear(in_dim + 1, out_dim)\n", 979 | "\n", 980 | " def forward(self, t, x):\n", 981 | " # TODO\n", 982 | " t = torch.ones_like(x[:, 0]).unsqueeze(-1) * t\n", 983 | " tx = torch.cat([t, x], dim=1)\n", 984 | " tx = self.fc(tx)\n", 985 | " return tx" 986 | ] 987 | }, 988 | { 989 | "cell_type": "code", 990 | "execution_count": null, 991 | "metadata": { 992 | "id": "leSAUEoURaJJ" 993 | }, 994 | "outputs": [], 995 | "source": [ 996 | "class ResNetBlock(nn.Module):\n", 997 | " def __init__(self, dim):\n", 998 | " super(ResNetBlock, self).__init__()\n", 999 | " \"\"\"\n", 1000 | " copy-paste from previous section\n", 1001 | " \"\"\"\n", 1002 | " # TODO\n", 1003 | " self.fc1 = ConcatLinear(dim, dim)\n", 1004 | " self.fc2 = ConcatLinear(dim, dim)\n", 1005 | " self.activation = nn.SiLU()\n", 1006 | "\n", 1007 | " def forward(self, t, x):\n", 1008 | " # TODO\n", 1009 | " _x = x\n", 1010 | " _x = self.activation(self.fc1(t, _x))\n", 1011 | " _x = self.fc2(t, _x)\n", 1012 | " _x = self.activation(_x + x)\n", 1013 | " return _x" 1014 | ] 1015 | }, 1016 | { 1017 | "cell_type": "code", 1018 | "execution_count": null, 1019 | "metadata": { 1020 | "id": "4DZKObG3RdiZ" 1021 | }, 1022 | "outputs": [], 1023 | "source": [ 1024 | "class ODEFunc(nn.Module):\n", 1025 | " def __init__(self, in_dim, hid_dim=64, n_resnet_blocks=2):\n", 1026 | " super(ODEFunc, self).__init__()\n", 1027 | " \"\"\"\n", 1028 | " copy-paste from previous section\n", 1029 | " \"\"\"\n", 1030 | " # TODO\n", 1031 | " self.fc_in = ConcatLinear(in_dim, hid_dim)\n", 1032 | "\n", 1033 | " self.resnet_blocks = []\n", 1034 | " for _ in range(n_resnet_blocks):\n", 1035 | " self.resnet_blocks.append(ResNetBlock(hid_dim))\n", 1036 | " self.resnet_blocks = nn.ModuleList(self.resnet_blocks)\n", 1037 | "\n", 1038 | " self.fc_out = ConcatLinear(hid_dim, in_dim)\n", 1039 | "\n", 1040 | " self.activation = nn.SiLU()\n", 1041 | "\n", 1042 | " def forward(self, t, states):\n", 1043 | " \"\"\"\n", 1044 | " Implement forward pass of CNF.\n", 1045 | " states is a tuple (z, divergence). \n", 1046 | " We don't need the divergence as an input\n", 1047 | " but we need to calculate and return it for the purpose of integration.\n", 1048 | " We have two steps:\n", 1049 | " 1. Calculate dz by passing through all of the layers.\n", 1050 | " 2. Calculate -trace(df/dz) using torch.autograd.grad().\n", 1051 | " \"\"\"\n", 1052 | " # TODO\n", 1053 | " z = states[0]\n", 1054 | "\n", 1055 | " with torch.set_grad_enabled(True):\n", 1056 | " z.requires_grad_(True)\n", 1057 | "\n", 1058 | " dz = self.activation(self.fc_in(t, z))\n", 1059 | " for resnet_block in self.resnet_blocks:\n", 1060 | " dz = resnet_block(t, dz)\n", 1061 | " dz = self.fc_out(t, dz)\n", 1062 | "\n", 1063 | " divergence = 0.\n", 1064 | " for i in range(z.shape[1]):\n", 1065 | " divergence += torch.autograd.grad(dz[:, i].sum(), z, create_graph=True)[0][:, i]\n", 1066 | "\n", 1067 | " return dz, -divergence.reshape(-1, 1)" 1068 | ] 1069 | }, 1070 | { 1071 | "cell_type": "code", 1072 | "execution_count": null, 1073 | "metadata": { 1074 | "id": "6mjZzKBmTJmC" 1075 | }, 1076 | "outputs": [], 1077 | "source": [ 1078 | "torch.manual_seed(420)\n", 1079 | "odefunc_test = ODEFunc(2, hid_dim=10, n_resnet_blocks=3)\n", 1080 | "states = (torch.randn(3, 2), torch.randn(3, 1))\n", 1081 | "t = torch.tensor(1.)\n", 1082 | "out = odefunc_test(t, states)\n", 1083 | "print(out[0])\n", 1084 | "print(out[1])\n", 1085 | "\n", 1086 | "# expected output:\n", 1087 | "# tensor([[0.2312, 0.4042],\n", 1088 | "# [0.4999, 0.2859],\n", 1089 | "# [0.3722, 0.3528]], grad_fn=)\n", 1090 | "# tensor([[0.0713],\n", 1091 | "# [0.1680],\n", 1092 | "# [0.1117]], grad_fn=))" 1093 | ] 1094 | }, 1095 | { 1096 | "cell_type": "code", 1097 | "execution_count": null, 1098 | "metadata": { 1099 | "id": "BwCk9xt8Rz7Y" 1100 | }, 1101 | "outputs": [], 1102 | "source": [ 1103 | "class CNF(nn.Module):\n", 1104 | " def __init__(self, in_dim, hid_dim=64, n_resnet_blocks=3, odeint=odeint_adjoint, rtol=1e-3, atol=1e-3):\n", 1105 | " super(CNF, self).__init__()\n", 1106 | " \"\"\"\n", 1107 | " Now, let's wrap everything into CNF class.\n", 1108 | " \"\"\"\n", 1109 | " self.odefunc = ODEFunc(in_dim, hid_dim=hid_dim, n_resnet_blocks=n_resnet_blocks)\n", 1110 | " self.odeint = odeint\n", 1111 | " self.rtol = rtol\n", 1112 | " self.atol = atol\n", 1113 | "\n", 1114 | " def forward(self, z, dlogpz=None, integration_times=None, reverse=False):\n", 1115 | " \"\"\"\n", 1116 | " Implement forward pass for CNF\n", 1117 | " \"\"\"\n", 1118 | " # TODO\n", 1119 | " if dlogpz is None:\n", 1120 | " dlogpz = torch.zeros(z.shape[0], 1).to(z)\n", 1121 | " if integration_times is None:\n", 1122 | " integration_times = torch.tensor([0., 1.]).to(z)\n", 1123 | " if reverse:\n", 1124 | " integration_times = integration_times.flip(-1)\n", 1125 | " \n", 1126 | " states = self.odeint(self.odefunc, (z, dlogpz), integration_times, rtol=self.rtol, atol=self.atol)\n", 1127 | "\n", 1128 | " if len(integration_times) == 2:\n", 1129 | " states = tuple(s[1] for s in states)\n", 1130 | " z, dlogpz = states\n", 1131 | "\n", 1132 | " return states" 1133 | ] 1134 | }, 1135 | { 1136 | "cell_type": "code", 1137 | "execution_count": null, 1138 | "metadata": { 1139 | "id": "6wbSWcfhXMsN" 1140 | }, 1141 | "outputs": [], 1142 | "source": [ 1143 | "torch.manual_seed(420)\n", 1144 | "cnf_test = CNF(2)\n", 1145 | "z = torch.randn(3, 2)\n", 1146 | "dlogpz = torch.zeros(3, 1)\n", 1147 | "\n", 1148 | "with torch.no_grad():\n", 1149 | " z, dlogpz = cnf_test(z, dlogpz=dlogpz)\n", 1150 | "print(z)\n", 1151 | "print(dlogpz)\n", 1152 | "\n", 1153 | "# expected output:\n", 1154 | "# tensor([[ 0.3526, -0.1600],\n", 1155 | "# [-0.2969, 0.5003],\n", 1156 | "# [ 2.0203, 0.4865]])\n", 1157 | "# tensor([[ 0.0166],\n", 1158 | "# [ 0.0777],\n", 1159 | "# [-0.0174]])" 1160 | ] 1161 | }, 1162 | { 1163 | "cell_type": "code", 1164 | "execution_count": null, 1165 | "metadata": { 1166 | "id": "x5KZcKaRJjmK" 1167 | }, 1168 | "outputs": [], 1169 | "source": [ 1170 | "in_dim = 2\n", 1171 | "hid_dim = 64\n", 1172 | "n_resnet_blocks = 3\n", 1173 | "\n", 1174 | "odeint_method = odeint_adjoint\n", 1175 | "rtol = 1e-3\n", 1176 | "atol = 1e-3\n", 1177 | "\n", 1178 | "lr = 1e-2\n", 1179 | "\n", 1180 | "n_epochs = 251\n", 1181 | "n_test_samples = 10000" 1182 | ] 1183 | }, 1184 | { 1185 | "cell_type": "code", 1186 | "execution_count": null, 1187 | "metadata": { 1188 | "id": "icDEUgwGLkkl" 1189 | }, 1190 | "outputs": [], 1191 | "source": [ 1192 | "torch.manual_seed(420)\n", 1193 | "model = CNF(in_dim, \n", 1194 | " hid_dim=hid_dim, \n", 1195 | " n_resnet_blocks=n_resnet_blocks, \n", 1196 | " odeint=odeint_method, \n", 1197 | " rtol=rtol, \n", 1198 | " atol=atol).to(device)\n", 1199 | "optimizer = torch.optim.Adam(model.parameters(), lr=lr)" 1200 | ] 1201 | }, 1202 | { 1203 | "cell_type": "code", 1204 | "execution_count": null, 1205 | "metadata": { 1206 | "id": "FbS7L4PTOomh" 1207 | }, 1208 | "outputs": [], 1209 | "source": [ 1210 | "data = torch.tensor(data).to(device)\n", 1211 | "for epoch in range(n_epochs):\n", 1212 | " \"\"\"\n", 1213 | " Fill the blanks in the training loop.\n", 1214 | " Calculate z0 and dlogpz using CNF's forward pass.\n", 1215 | " Calculate logpz0 using normal_logprob().\n", 1216 | " Calculate loss function.\n", 1217 | " \"\"\"\n", 1218 | " z1 = data + 1e-3 * torch.randn_like(data)\n", 1219 | "\n", 1220 | " # TODO\n", 1221 | " z0, dlogpz = model(z1)\n", 1222 | "\n", 1223 | " logpz0 = normal_logprob(z0)\n", 1224 | " logpz1 = logpz0 - dlogpz\n", 1225 | " loss = -torch.mean(logpz1)\n", 1226 | "\n", 1227 | " optimizer.zero_grad()\n", 1228 | " loss.backward()\n", 1229 | " optimizer.step()\n", 1230 | "\n", 1231 | " print(f'Epoch {epoch} Loss {loss.item():.4f}')\n", 1232 | " if epoch % 10 == 0:\n", 1233 | " with torch.no_grad():\n", 1234 | " z0 = torch.randn(n_test_samples, in_dim).to(device)\n", 1235 | " samples = model(z0, reverse=True)[0]\n", 1236 | " plt.figure(figsize=(5, 5))\n", 1237 | " plt.scatter(samples[:, 0].cpu(), samples[:, 1].cpu(), alpha=0.2)\n", 1238 | " plt.axis('off')\n", 1239 | " plt.show()" 1240 | ] 1241 | }, 1242 | { 1243 | "cell_type": "code", 1244 | "execution_count": null, 1245 | "metadata": { 1246 | "id": "gg7ybvLmNgJo" 1247 | }, 1248 | "outputs": [], 1249 | "source": [ 1250 | "with torch.no_grad():\n", 1251 | " integration_times = torch.arange(0., 1.01, 0.01).to(device)\n", 1252 | "\n", 1253 | " x = np.arange(-3, 3, 0.02)\n", 1254 | " y = np.arange(-3, 3, 0.02)\n", 1255 | " X, Y = np.meshgrid(x, y)\n", 1256 | " X, Y = torch.tensor(X).float().to(device), torch.tensor(Y).float().to(device)\n", 1257 | " z_grid = torch.cat([X.reshape(-1, 1), Y.reshape(-1, 1)], 1)\n", 1258 | "\n", 1259 | " dlogpz_grid = normal_logprob(z_grid)\n", 1260 | "\n", 1261 | " z1s, dlogpz1s = model(z_grid, dlogpz=dlogpz_grid, integration_times=integration_times, reverse=True)\n", 1262 | "\n", 1263 | "fig, ax = plt.subplots(figsize=(8, 8))\n", 1264 | "\n", 1265 | "def animate(i):\n", 1266 | " t = integration_times[i]\n", 1267 | " z1 = z1s[i].reshape(X.shape[0], X.shape[1], 2).cpu()\n", 1268 | " dlogpz1 = dlogpz1s[i].reshape(X.shape).cpu()\n", 1269 | "\n", 1270 | " ax.clear()\n", 1271 | " ax.set_xlim([-3, 3])\n", 1272 | " ax.set_ylim([-3, 3])\n", 1273 | " \n", 1274 | " line = ax.pcolormesh(z1[:, :, 0], z1[:, :, 1], dlogpz1.exp())\n", 1275 | " cmap = get_cmap(None)\n", 1276 | " ax.set_facecolor(cmap(0.))\n", 1277 | " ax.get_xaxis().set_ticks([])\n", 1278 | " ax.get_yaxis().set_ticks([])\n", 1279 | "\n", 1280 | " return line\n", 1281 | " \n", 1282 | "ani = FuncAnimation(fig, animate, frames=len(integration_times)) \n", 1283 | "ani.save(os.path.join(log_dir, 'cnf_density.gif'), dpi=300, writer=PillowWriter(fps=10))" 1284 | ] 1285 | }, 1286 | { 1287 | "cell_type": "code", 1288 | "execution_count": null, 1289 | "metadata": { 1290 | "id": "LZW-J5jawDCv" 1291 | }, 1292 | "outputs": [], 1293 | "source": [ 1294 | "" 1295 | ] 1296 | } 1297 | ], 1298 | "metadata": { 1299 | "accelerator": "GPU", 1300 | "colab": { 1301 | "collapsed_sections": [ 1302 | "zjcwf1_YC0V3" 1303 | ], 1304 | "name": "neural_ode_solvers.ipynb", 1305 | "provenance": [] 1306 | }, 1307 | "kernelspec": { 1308 | "display_name": "Python 3", 1309 | "name": "python3" 1310 | }, 1311 | "language_info": { 1312 | "name": "python" 1313 | } 1314 | }, 1315 | "nbformat": 4, 1316 | "nbformat_minor": 0 1317 | } -------------------------------------------------------------------------------- /neural_ode_solvers_presentation.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MStypulkowski/neural_ode_solvers_workshop/5a7e52537f25feb213f1c6f8bc75d095ca4d1fa7/neural_ode_solvers_presentation.pdf --------------------------------------------------------------------------------