├── .gitignore ├── Bayesian Deep Learning part1.ipynb ├── Bayesian Deep Learning part2.ipynb ├── README.md ├── mlss2019bdl ├── __init__.py ├── bdl │ ├── __init__.py │ ├── base.py │ ├── bernoulli.py │ └── gaussian.py ├── dataset.py ├── flex.py └── plotting.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # General 2 | .DS_Store 3 | .AppleDouble 4 | .LSOverride 5 | 6 | # Icon must end with two \r 7 | Icon 8 | 9 | 10 | # Thumbnails 11 | ._* 12 | 13 | # Files that might appear in the root of a volume 14 | .DocumentRevisions-V100 15 | .fseventsd 16 | .Spotlight-V100 17 | .TemporaryItems 18 | .Trashes 19 | .VolumeIcon.icns 20 | .com.apple.timemachine.donotpresent 21 | 22 | # Directories potentially created on remote AFP share 23 | .AppleDB 24 | .AppleDesktop 25 | Network Trash Folder 26 | Temporary Items 27 | .apdisk 28 | 29 | # Byte-compiled / optimized / DLL files 30 | __pycache__/ 31 | *.py[cod] 32 | *$py.class 33 | 34 | # C extensions 35 | *.so 36 | 37 | # Distribution / packaging 38 | .Python 39 | build/ 40 | develop-eggs/ 41 | dist/ 42 | downloads/ 43 | eggs/ 44 | .eggs/ 45 | lib/ 46 | lib64/ 47 | parts/ 48 | sdist/ 49 | var/ 50 | wheels/ 51 | *.egg-info/ 52 | .installed.cfg 53 | *.egg 54 | MANIFEST 55 | 56 | # PyInstaller 57 | # Usually these files are written by a python script from a template 58 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 59 | *.manifest 60 | *.spec 61 | 62 | # Installer logs 63 | pip-log.txt 64 | pip-delete-this-directory.txt 65 | 66 | # Unit test / coverage reports 67 | htmlcov/ 68 | .tox/ 69 | .coverage 70 | .coverage.* 71 | .cache 72 | nosetests.xml 73 | coverage.xml 74 | *.cover 75 | .hypothesis/ 76 | .pytest_cache/ 77 | 78 | # Translations 79 | *.mo 80 | *.pot 81 | 82 | # Django stuff: 83 | *.log 84 | local_settings.py 85 | db.sqlite3 86 | 87 | # Flask stuff: 88 | instance/ 89 | .webassets-cache 90 | 91 | # Scrapy stuff: 92 | .scrapy 93 | 94 | # Sphinx documentation 95 | docs/_build/ 96 | 97 | # PyBuilder 98 | target/ 99 | 100 | # Jupyter Notebook 101 | .ipynb_checkpoints 102 | 103 | # pyenv 104 | .python-version 105 | 106 | # celery beat schedule file 107 | celerybeat-schedule 108 | 109 | # SageMath parsed files 110 | *.sage.py 111 | 112 | # Environments 113 | .env 114 | .venv 115 | env/ 116 | venv/ 117 | ENV/ 118 | env.bak/ 119 | venv.bak/ 120 | 121 | # Spyder project settings 122 | .spyderproject 123 | .spyproject 124 | 125 | # Rope project settings 126 | .ropeproject 127 | 128 | # mkdocs documentation 129 | /site 130 | 131 | # mypy 132 | .mypy_cache/ 133 | -------------------------------------------------------------------------------- /Bayesian Deep Learning part1.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# MLSS2019: Bayesian Deep Learning" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "In this tutorial we will learn what basic building blocks are needed\n", 15 | "to endow (deep) neural networks with uncertainty estimates." 16 | ] 17 | }, 18 | { 19 | "cell_type": "markdown", 20 | "metadata": {}, 21 | "source": [ 22 | "The plan of the tutorial\n", 23 | "1. [Setup and imports](#Setup-and-imports)\n", 24 | "2. [Easy uncertainty in networks](#Easy-uncertainty-in-networks)\n", 25 | " 1. 
[Bayesification via dropout and weight decay](#Bayesification-via-dropout-and-weight-decay)\n", 26 | " 2. [Implementing function sampling with the DropoutLinear Layer](#Implementing-function-sampling-with-the-DropoutLinear-Layer)\n", 27 | " 3. [Implementing-DropoutLinear](#Implementing-DropoutLinear)\n", 28 | " 4. [Comparing sample functions to point-estimates](#Comparing-sample-functions-to-point-estimates)\n", 29 | "3. [(optional) Dropout $2$-d Convolutional layer](#(optional)-Dropout-$2$-d-Convolutional-layer)\n", 30 | "4. [(optional) A brief reminder on Bayesian and Variational Inference](#(optional)-A-brief-reminder-on-Bayesian-and-Variational-Inference)" 31 | ] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "metadata": {}, 36 | "source": [ 37 | "**(note)**\n", 38 | "* to view documentation on something type in `something?` (with one question mark)\n", 39 | "* to view code of something type in `something??` (with two question marks)." 40 | ] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "metadata": {}, 45 | "source": [ 46 | "
" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "metadata": {}, 52 | "source": [ 53 | "## Setup and imports" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "metadata": {}, 59 | "source": [ 60 | "In this section we import necessary modules and functions and\n", 61 | "define the computational device." 62 | ] 63 | }, 64 | { 65 | "cell_type": "markdown", 66 | "metadata": {}, 67 | "source": [ 68 | "First, we install some boilerplate service code for this tutorial." 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "!pip install -q --upgrade git+https://github.com/ivannz/mlss2019-bayesian-deep-learning.git" 78 | ] 79 | }, 80 | { 81 | "cell_type": "markdown", 82 | "metadata": {}, 83 | "source": [ 84 | "Next, numpy for computing, matplotlib for plotting and tqdm for progress bars." 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": null, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "import tqdm\n", 94 | "import numpy as np\n", 95 | "\n", 96 | "%matplotlib inline\n", 97 | "import matplotlib.pyplot as plt" 98 | ] 99 | }, 100 | { 101 | "cell_type": "markdown", 102 | "metadata": {}, 103 | "source": [ 104 | "For deep learning stuff will be using [pytorch](https://pytorch.org/).\n", 105 | "\n", 106 | "If you are unfamiliar with it, it is basically like `numpy` with autograd,\n", 107 | "stricter data type enforcement, native GPU support, and tools for building\n", 108 | "training and serializing models.\n", 109 | "\n", 110 | "\n", 111 | "There are good introductory tutorials on `pytorch`, like this\n", 112 | "[one](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html)." 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": null, 118 | "metadata": {}, 119 | "outputs": [], 120 | "source": [ 121 | "import torch\n", 122 | "import torch.nn.functional as F\n", 123 | "\n", 124 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")" 125 | ] 126 | }, 127 | { 128 | "cell_type": "markdown", 129 | "metadata": {}, 130 | "source": [ 131 | "We will need some functionality from scikit" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": null, 137 | "metadata": {}, 138 | "outputs": [], 139 | "source": [ 140 | "from sklearn.metrics import confusion_matrix" 141 | ] 142 | }, 143 | { 144 | "cell_type": "markdown", 145 | "metadata": {}, 146 | "source": [ 147 | "Next we import the boilerplate code.\n", 148 | "\n", 149 | "* a procedure that implements a minibatch SGD **fit** loop\n", 150 | "* a function, that **evaluates** the model on the provided dataset" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": null, 156 | "metadata": {}, 157 | "outputs": [], 158 | "source": [ 159 | "from mlss2019bdl import fit" 160 | ] 161 | }, 162 | { 163 | "cell_type": "markdown", 164 | "metadata": {}, 165 | "source": [ 166 | "```python\n", 167 | "# pseudocode\n", 168 | "def fit(model, dataset, criterion, ...):\n", 169 | " for epoch in epochs:\n", 170 | " for batch in dataset:\n", 171 | " loss = criterion(model, batch) # forward pass\n", 172 | "\n", 173 | " grad = loss.backward() # gradient via back propagation\n", 174 | "\n", 175 | " adam_step(grad)\n", 176 | "```" 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": null, 182 | "metadata": {}, 183 | "outputs": [], 184 | "source": [ 185 | "from mlss2019bdl import predict" 186 | ] 187 | }, 188 | { 189 | 
"cell_type": "markdown", 190 | "metadata": {}, 191 | "source": [ 192 | "```python\n", 193 | "# pseudocode\n", 194 | "def predict(model, dataset, ...):\n", 195 | " for input_batch in dataset:\n", 196 | " output.append(model(input_batch)) # forward pass\n", 197 | " \n", 198 | " return concatenate(output)\n", 199 | "```" 200 | ] 201 | }, 202 | { 203 | "cell_type": "markdown", 204 | "metadata": {}, 205 | "source": [ 206 | "
" 207 | ] 208 | }, 209 | { 210 | "cell_type": "markdown", 211 | "metadata": {}, 212 | "source": [ 213 | "## Easy uncertainty in networks" 214 | ] 215 | }, 216 | { 217 | "cell_type": "markdown", 218 | "metadata": {}, 219 | "source": [ 220 | "Generate the initial small dataset $S_0 = (x_i, y_i)_{i=1}^{m_0}$\n", 221 | "with $y_i = g(x_i)$, $x_i$ on a regular-spaced grid, and $\n", 222 | "g\n", 223 | " \\colon \\mathbb{R} \\to \\mathbb{R}\n", 224 | " \\colon x \\mapsto \\tfrac{x^2}4 + \\sin \\frac\\pi2 x\n", 225 | "$.\n", 226 | "" 231 | ] 232 | }, 233 | { 234 | "cell_type": "code", 235 | "execution_count": null, 236 | "metadata": {}, 237 | "outputs": [], 238 | "source": [ 239 | "from mlss2019bdl import dataset_from_numpy\n", 240 | "\n", 241 | "X_train = np.linspace(-6.0, +6.0, num=20)[:, np.newaxis]\n", 242 | "y_train = np.sin(X_train * np.pi / 2) + 0.25 * X_train**2\n", 243 | "\n", 244 | "train = dataset_from_numpy(X_train, y_train, device=device)" 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": null, 250 | "metadata": {}, 251 | "outputs": [], 252 | "source": [ 253 | "X_domain = np.linspace(-10., +10., num=251)[:, np.newaxis]\n", 254 | "\n", 255 | "domain = dataset_from_numpy(X_domain, device=device)" 256 | ] 257 | }, 258 | { 259 | "cell_type": "markdown", 260 | "metadata": {}, 261 | "source": [ 262 | "Suppose we have the following model: a 3-layer fully connected\n", 263 | "network with LeakyReLU activations." 264 | ] 265 | }, 266 | { 267 | "cell_type": "code", 268 | "execution_count": null, 269 | "metadata": {}, 270 | "outputs": [], 271 | "source": [ 272 | "from torch.nn import Linear, Sequential\n", 273 | "from torch.nn import LeakyReLU\n", 274 | "\n", 275 | "\n", 276 | "model = Sequential(\n", 277 | " Linear(1, 512, bias=True),\n", 278 | " LeakyReLU(),\n", 279 | "\n", 280 | " Linear(512, 512, bias=True),\n", 281 | " LeakyReLU(),\n", 282 | "\n", 283 | " Linear(512, 1, bias=True),\n", 284 | ")\n", 285 | "\n", 286 | "model.to(device)" 287 | ] 288 | }, 289 | { 290 | "cell_type": "markdown", 291 | "metadata": {}, 292 | "source": [ 293 | "
" 294 | ] 295 | }, 296 | { 297 | "cell_type": "markdown", 298 | "metadata": {}, 299 | "source": [ 300 | "We fit our model on `train` using MSE loss and $\\ell_2$ penalty on\n", 301 | "weights (`weight_decay`):\n", 302 | "$$\n", 303 | " \\tfrac1{2 m} \\|f_\\omega(x) - y\\|_2^2 + \\lambda \\|\\omega\\|_2^2\n", 304 | " \\,, $$\n", 305 | "where $\\omega$ are all the learnable parameters of the network $f_\\omega$." 306 | ] 307 | }, 308 | { 309 | "cell_type": "markdown", 310 | "metadata": {}, 311 | "source": [ 312 | "
" 313 | ] 314 | }, 315 | { 316 | "cell_type": "markdown", 317 | "metadata": {}, 318 | "source": [ 319 | "Fit, ..." 320 | ] 321 | }, 322 | { 323 | "cell_type": "code", 324 | "execution_count": null, 325 | "metadata": { 326 | "scrolled": false 327 | }, 328 | "outputs": [], 329 | "source": [ 330 | "fit(model, train, criterion=\"mse\", n_epochs=2000, verbose=True, weight_decay=1e-3)" 331 | ] 332 | }, 333 | { 334 | "cell_type": "markdown", 335 | "metadata": {}, 336 | "source": [ 337 | "..., compute the predictions, ..." 338 | ] 339 | }, 340 | { 341 | "cell_type": "code", 342 | "execution_count": null, 343 | "metadata": {}, 344 | "outputs": [], 345 | "source": [ 346 | "y_pred = predict(model, domain)" 347 | ] 348 | }, 349 | { 350 | "cell_type": "markdown", 351 | "metadata": {}, 352 | "source": [ 353 | "..., and plot them." 354 | ] 355 | }, 356 | { 357 | "cell_type": "code", 358 | "execution_count": null, 359 | "metadata": {}, 360 | "outputs": [], 361 | "source": [ 362 | "fig, ax = plt.subplots(1, 1, figsize=(12, 5))\n", 363 | "\n", 364 | "ax.scatter(X_train, y_train, c=\"black\", s=40, label=\"train\")\n", 365 | "\n", 366 | "ax.plot(X_domain, y_pred.numpy(), c=\"C0\", lw=2, label=\"prediction\")\n", 367 | "\n", 368 | "plt.legend();" 369 | ] 370 | }, 371 | { 372 | "cell_type": "markdown", 373 | "metadata": {}, 374 | "source": [ 375 | "This model seems to fit the train set adequately well. However, there is no\n", 376 | "way to assess how confident this model is with respect to its predictions.\n", 377 | "Indeed, the prediction $\\hat{y}_x = f_\\omega(x)$ is is a deterministic function\n", 378 | "of the input $x$ and the learnt parameters $\\omega$." 379 | ] 380 | }, 381 | { 382 | "cell_type": "markdown", 383 | "metadata": {}, 384 | "source": [ 385 | "
" 386 | ] 387 | }, 388 | { 389 | "cell_type": "markdown", 390 | "metadata": {}, 391 | "source": [ 392 | "### `Bayesification` via dropout and weight decay" 393 | ] 394 | }, 395 | { 396 | "cell_type": "markdown", 397 | "metadata": {}, 398 | "source": [ 399 | "One inexpensive way to make any network into a stochastic function of its\n", 400 | "input is to add dropout before any parameterized layer like `linear`\n", 401 | "or `convolutional`, [Hinton et al. 2012](https://arxiv.org/abs/1207.0580).\n", 402 | "Essentially, dropout applies a Bernoulli mask to the features of the input.\n", 403 | "\n", 404 | "In [Gal, Y. (2016)](http://www.cs.ox.ac.uk/people/yarin.gal/website/thesis/thesis.pdf)\n", 405 | "it has been shown that a simple, somewhat ad-hoc approach of\n", 406 | "adding uncertainty quantification to networks through dropout,\n", 407 | "coupled with $\\ell_2$ weight penalty, is a special case of Variational Inference." 408 | ] 409 | }, 410 | { 411 | "cell_type": "markdown", 412 | "metadata": {}, 413 | "source": [ 414 | "For input\n", 415 | "$\n", 416 | " x\\in \\mathbb{R}^{[\\mathrm{in}]}\n", 417 | "$ the dropout layer acts like this:\n", 418 | "\n", 419 | "$$\n", 420 | " y_j = x_j \\, m_j\n", 421 | " \\,, $$\n", 422 | "\n", 423 | "where $m\\in \\mathbb{R}^{[\\mathrm{in}]}$ with $\n", 424 | "m_j \\sim \\pi_p(m_j)\n", 425 | " = \\mathcal{Ber}\\bigl(\\bigl\\{0, \\tfrac1{1-p}\\bigr\\}, 1-p\\bigr)\n", 426 | "$,\n", 427 | "i.e. equals $\\tfrac1{1-p}$ with probability $1-p$ and $0$ otherwise." 428 | ] 429 | }, 430 | { 431 | "cell_type": "markdown", 432 | "metadata": {}, 433 | "source": [ 434 | "#### (task) Always Active Dropout" 435 | ] 436 | }, 437 | { 438 | "cell_type": "markdown", 439 | "metadata": {}, 440 | "source": [ 441 | "Useful methods:\n", 442 | "* `torch.rand(d1, ..., dn)` -- draw $d_1\\times \\ldots \\times d_n$ tensor of uniform rv-s\n", 443 | "* `torch.rand_like(other)` -- draw a tensor of uniform rv-s with the shape, data type and device as `other`\n", 444 | "\n", 445 | "\n", 446 | "* `torch.bernoulli(pi)` -- draw tensor $t$ with independent $\n", 447 | "t_\\alpha \\sim \\mathcal{Ber}\\bigl(\\{0, 1\\}, \\pi_\\alpha\\bigr)\n", 448 | "$ for each index $\\alpha$\n", 449 | "* `torch.full((d1, ..., dn), v)` -- a $d_1\\times \\ldots \\times d_n$ tensor with the same value $v$\n", 450 | "\n", 451 | "\n", 452 | "* `Tensor.to(other)` -- assume move `Tensor` to the device of the `other` and cast to its data type." 453 | ] 454 | }, 455 | { 456 | "cell_type": "code", 457 | "execution_count": null, 458 | "metadata": {}, 459 | "outputs": [], 460 | "source": [ 461 | "from torch.nn import Module\n", 462 | "\n", 463 | "class ActiveDropout(Module):\n", 464 | " # all building blocks of networks are inherited from Module!\n", 465 | "\n", 466 | " def __init__(self, p=0.5):\n", 467 | " super().__init__() # init the base class\n", 468 | "\n", 469 | " self.p = p\n", 470 | "\n", 471 | " def forward(self, input):\n", 472 | " ## Exercise: implement feature dropout on input\n", 473 | " # self.p - contains the specified dropout rate\n", 474 | " \n", 475 | " mask = torch.rand_like(input) > self.p\n", 476 | " return input * mask.to(input) / (1 - self.p)\n", 477 | "\n", 478 | " # prob = torch.full_like(input, 1 - self.p)\n", 479 | " # return input * torch.bernoulli(prob) / prob\n", 480 | "\n", 481 | " # return F.dropout(input, self.p, True)\n", 482 | "\n", 483 | " pass" 484 | ] 485 | }, 486 | { 487 | "cell_type": "markdown", 488 | "metadata": {}, 489 | "source": [ 490 | "
" 491 | ] 492 | }, 493 | { 494 | "cell_type": "markdown", 495 | "metadata": {}, 496 | "source": [ 497 | "#### (task) Rebuilding the model" 498 | ] 499 | }, 500 | { 501 | "cell_type": "markdown", 502 | "metadata": {}, 503 | "source": [ 504 | "Let's recreate the model above with this freshly minted dropout layer.\n", 505 | "Then fit and plot it's prediction uncertainty due to forward pass stochasticity." 506 | ] 507 | }, 508 | { 509 | "cell_type": "code", 510 | "execution_count": null, 511 | "metadata": {}, 512 | "outputs": [], 513 | "source": [ 514 | "def build_model(p=0.5):\n", 515 | " \"\"\"Build a model with dropout layers' rate set to `p`.\"\"\"\n", 516 | "\n", 517 | " return Sequential(\n", 518 | " ## Exercise: Use `ActiveDropout` before linear layers of our\n", 519 | " # first network. Note that dropping out inputs is not a good idea\n", 520 | "\n", 521 | " Linear(1, 512, bias=True),\n", 522 | " LeakyReLU(),\n", 523 | "\n", 524 | " ActiveDropout(p),\n", 525 | " Linear(512, 512, bias=True),\n", 526 | " LeakyReLU(),\n", 527 | "\n", 528 | " ActiveDropout(p),\n", 529 | " Linear(512, 1, bias=True),\n", 530 | "\n", 531 | " # pass\n", 532 | " )" 533 | ] 534 | }, 535 | { 536 | "cell_type": "markdown", 537 | "metadata": {}, 538 | "source": [ 539 | "
" 540 | ] 541 | }, 542 | { 543 | "cell_type": "code", 544 | "execution_count": null, 545 | "metadata": {}, 546 | "outputs": [], 547 | "source": [ 548 | "model = build_model(p=0.5)\n", 549 | "\n", 550 | "model.to(device)\n", 551 | "\n", 552 | "fit(model, train, criterion=\"mse\", n_epochs=2000, verbose=True,\n", 553 | " weight_decay=1e-3)" 554 | ] 555 | }, 556 | { 557 | "cell_type": "markdown", 558 | "metadata": {}, 559 | "source": [ 560 | "
" 561 | ] 562 | }, 563 | { 564 | "cell_type": "markdown", 565 | "metadata": {}, 566 | "source": [ 567 | "#### Sampling the random output" 568 | ] 569 | }, 570 | { 571 | "cell_type": "markdown", 572 | "metadata": {}, 573 | "source": [ 574 | "Let's take the test sample $\\tilde{S} = (\\tilde{x}_i)_{i=1}^m \\in \\mathcal{X}$\n", 575 | "and repeat the stochastic forward pass $B$ times at each $x\\in \\tilde{S}$:\n", 576 | "\n", 577 | "* for $b = 1 .. B$ do:\n", 578 | "\n", 579 | " 1. draw $y_{bi} \\sim f_\\omega(\\tilde{x}_i)$ for $i = 1 .. m$." 580 | ] 581 | }, 582 | { 583 | "cell_type": "code", 584 | "execution_count": null, 585 | "metadata": {}, 586 | "outputs": [], 587 | "source": [ 588 | "def point_estimate(model, dataset, n_samples=1, verbose=False):\n", 589 | " \"\"\"Draw pointwise samples with stochastic forward pass.\"\"\"\n", 590 | "\n", 591 | " outputs = []\n", 592 | " for sample in tqdm.tqdm(range(n_samples), disable=not verbose):\n", 593 | "\n", 594 | " outputs.append(predict(model, dataset))\n", 595 | "\n", 596 | " return torch.stack(outputs, dim=0)\n", 597 | "\n", 598 | "\n", 599 | "samples = point_estimate(model, domain, n_samples=101, verbose=True)" 600 | ] 601 | }, 602 | { 603 | "cell_type": "markdown", 604 | "metadata": {}, 605 | "source": [ 606 | "
" 607 | ] 608 | }, 609 | { 610 | "cell_type": "markdown", 611 | "metadata": {}, 612 | "source": [ 613 | "The approximate $95\\%$ confidence band of predictions is..." 614 | ] 615 | }, 616 | { 617 | "cell_type": "code", 618 | "execution_count": null, 619 | "metadata": {}, 620 | "outputs": [], 621 | "source": [ 622 | "fig, ax = plt.subplots(1, 1, figsize=(12, 5))\n", 623 | "ax.scatter(X_train, y_train, c=\"black\", s=40, label=\"train\")\n", 624 | "\n", 625 | "mean, std = samples.mean(dim=0).numpy(), samples.std(dim=0).numpy()\n", 626 | "ax.plot(X_domain, mean + 1.96 * std, c=\"k\")\n", 627 | "ax.plot(X_domain, mean - 1.96 * std, c=\"k\");" 628 | ] 629 | }, 630 | { 631 | "cell_type": "markdown", 632 | "metadata": {}, 633 | "source": [ 634 | "
" 635 | ] 636 | }, 637 | { 638 | "cell_type": "markdown", 639 | "metadata": {}, 640 | "source": [ 641 | "### Implementing function sampling with the DropoutLinear Layer" 642 | ] 643 | }, 644 | { 645 | "cell_type": "markdown", 646 | "metadata": {}, 647 | "source": [ 648 | "Let's inspect the draws $y_{bi}$ as $B$ functional samples:\n", 649 | "$(x_i, y_{bi})_{i=1}^m$ - the $b$-th sample path. Below we\n", 650 | "plot $5$ random paths." 651 | ] 652 | }, 653 | { 654 | "cell_type": "code", 655 | "execution_count": null, 656 | "metadata": {}, 657 | "outputs": [], 658 | "source": [ 659 | "samples = point_estimate(model, domain, n_samples=101, verbose=True)\n", 660 | "\n", 661 | "fig, ax = plt.subplots(1, 1, figsize=(12, 5))\n", 662 | "\n", 663 | "ax.scatter(X_train, y_train, c=\"black\", s=40, label=\"train\")\n", 664 | "ax.plot(X_domain[:, 0], samples[:5, :, 0].numpy().T, c=\"C0\", lw=1, alpha=0.25);" 665 | ] 666 | }, 667 | { 668 | "cell_type": "markdown", 669 | "metadata": {}, 670 | "source": [ 671 | "It is clear that they are very erratic!" 672 | ] 673 | }, 674 | { 675 | "cell_type": "markdown", 676 | "metadata": {}, 677 | "source": [ 678 | "Computing stochastic forward passes with a new mask each time is equivalent\n", 679 | "to drawing new **independent** prediction from for each point $x\\in \\tilde{S}$,\n", 680 | "without considering that, in fact, at adjacent points the predictions should\n", 681 | "be correlated. If we were interested in uncertainty at some particular point,\n", 682 | "this would be okay: **fast and simple**." 683 | ] 684 | }, 685 | { 686 | "cell_type": "markdown", 687 | "metadata": {}, 688 | "source": [ 689 | "However, if we are interested in the uncertainty of an integral **path-dependent**\n", 690 | "measure of the whole estimated function, or are doing **optimization** of\n", 691 | "the unknown true function taking estimation uncertainty into account, then\n", 692 | "this clearly erratic behaviour of paths is undesirable. Ex. see\n", 693 | "[blog: Gal, Y. 2016](http://www.cs.ox.ac.uk/people/yarin.gal/website/blog_2248.html)" 694 | ] 695 | }, 696 | { 697 | "cell_type": "markdown", 698 | "metadata": {}, 699 | "source": [ 700 | "
" 701 | ] 702 | }, 703 | { 704 | "cell_type": "markdown", 705 | "metadata": {}, 706 | "source": [ 707 | "We need to implement some extra functionality on top of `pytorch`,\n", 708 | "in order to draw realizations from the induced distribution over\n", 709 | "functions, defined by a network, i.e. $\n", 710 | "\\bigl\\{\n", 711 | " f_\\omega\\colon \\mathcal{X}\\to\\mathcal{Y}\n", 712 | "\\bigr\\}_{\\omega \\sim q(\\omega)}\n", 713 | "$\n", 714 | "where $q(\\omega)$ is a distribution over the parameters." 715 | ] 716 | }, 717 | { 718 | "cell_type": "markdown", 719 | "metadata": {}, 720 | "source": [ 721 | "One of the design approaches is to allow layers\n", 722 | "to cache random draws of their parameters for reuse\n", 723 | "in all subsequent forward passes, until this is no\n", 724 | "longer needed." 725 | ] 726 | }, 727 | { 728 | "cell_type": "markdown", 729 | "metadata": {}, 730 | "source": [ 731 | "#### Freeze/unfreeze interface" 732 | ] 733 | }, 734 | { 735 | "cell_type": "markdown", 736 | "metadata": {}, 737 | "source": [ 738 | "This is a base **trait-class** `FreezableWeight` that adds interface\n", 739 | "for freezing and unfreezing layer's random **weight** parameter." 740 | ] 741 | }, 742 | { 743 | "cell_type": "code", 744 | "execution_count": null, 745 | "metadata": {}, 746 | "outputs": [], 747 | "source": [ 748 | "class FreezableWeight(Module):\n", 749 | " def __init__(self):\n", 750 | " super().__init__()\n", 751 | " self.unfreeze()\n", 752 | "\n", 753 | " def unfreeze(self):\n", 754 | " self.register_buffer(\"frozen_weight\", None)\n", 755 | "\n", 756 | " def is_frozen(self):\n", 757 | " \"\"\"Check if a frozen weight is available.\"\"\"\n", 758 | " return isinstance(self.frozen_weight, torch.Tensor)\n", 759 | "\n", 760 | " def freeze(self):\n", 761 | " \"\"\"Sample from the parameter distribution and freeze.\"\"\"\n", 762 | " raise NotImplementedError()" 763 | ] 764 | }, 765 | { 766 | "cell_type": "markdown", 767 | "metadata": {}, 768 | "source": [ 769 | "Next, we declare a pair of functions:\n", 770 | "* `freeze()` instructs each compatible layer of the model to **sample and freeze** its randomness\n", 771 | "* `unfreeze()` requests the layers to **undo** this" 772 | ] 773 | }, 774 | { 775 | "cell_type": "code", 776 | "execution_count": null, 777 | "metadata": {}, 778 | "outputs": [], 779 | "source": [ 780 | "def freeze(model):\n", 781 | " for layer in model.modules():\n", 782 | " if isinstance(layer, FreezableWeight):\n", 783 | " layer.freeze()\n", 784 | "\n", 785 | " return model" 786 | ] 787 | }, 788 | { 789 | "cell_type": "code", 790 | "execution_count": null, 791 | "metadata": {}, 792 | "outputs": [], 793 | "source": [ 794 | "def unfreeze(model):\n", 795 | " for layer in model.modules():\n", 796 | " if isinstance(layer, FreezableWeight):\n", 797 | " layer.unfreeze()\n", 798 | "\n", 799 | " return model" 800 | ] 801 | }, 802 | { 803 | "cell_type": "markdown", 804 | "metadata": {}, 805 | "source": [ 806 | "
" 807 | ] 808 | }, 809 | { 810 | "cell_type": "markdown", 811 | "metadata": {}, 812 | "source": [ 813 | "#### (task) Sampling realizations" 814 | ] 815 | }, 816 | { 817 | "cell_type": "markdown", 818 | "metadata": {}, 819 | "source": [ 820 | "The algorithm to sample a random function is:\n", 821 | "* for $b = 1... B$ do:\n", 822 | "\n", 823 | " 1. draw an independent realization $f_b\\colon \\mathcal{X} \\to \\mathcal{Y}$\n", 824 | " with from the process $\\{f_\\omega\\}_{\\omega \\sim q(\\omega)}$\n", 825 | " 2. get $\\hat{y}_{bi} = f_b(\\tilde{x}_i)$ for $i=1 .. m$\n" 826 | ] 827 | }, 828 | { 829 | "cell_type": "code", 830 | "execution_count": null, 831 | "metadata": {}, 832 | "outputs": [], 833 | "source": [ 834 | "def sample_function(model, dataset, n_samples=1, verbose=False):\n", 835 | " \"\"\"Draw a realization of a random function.\"\"\"\n", 836 | "\n", 837 | " ## Exercise: code a function similar to `point_estimate()`,\n", 838 | " ## that collects the predictions from `frozen` models. Don't\n", 839 | " ## forget to unfreeze before returning.\n", 840 | "\n", 841 | " outputs = []\n", 842 | " for _ in tqdm.tqdm(range(n_samples), disable=not verbose):\n", 843 | " freeze(model)\n", 844 | "\n", 845 | " outputs.append(predict(model, dataset))\n", 846 | "\n", 847 | " unfreeze(model)\n", 848 | "\n", 849 | " return torch.stack(outputs, dim=0)\n", 850 | "\n", 851 | " pass" 852 | ] 853 | }, 854 | { 855 | "cell_type": "markdown", 856 | "metadata": {}, 857 | "source": [ 858 | "**(note)** although the internal loop in both functions looks\n", 859 | "similar they, conceptually the functions differ:\n", 860 | "\n", 861 | "```python\n", 862 | "def point_estimate(f, S):\n", 863 | " for x in S:\n", 864 | " for w from f.q: # different w for different x\n", 865 | " yield f(x, w)\n", 866 | "\n", 867 | "\n", 868 | "def sample_function(f, S):\n", 869 | " for w from f.q:\n", 870 | " for x in S: # same w for different x (thanks to freeze)\n", 871 | " yield f(x, w)\n", 872 | "```\n", 873 | "" 874 | ] 875 | }, 876 | { 877 | "cell_type": "markdown", 878 | "metadata": {}, 879 | "source": [ 880 | "
" 881 | ] 882 | }, 883 | { 884 | "cell_type": "markdown", 885 | "metadata": {}, 886 | "source": [ 887 | "### Implementing `DropoutLinear`" 888 | ] 889 | }, 890 | { 891 | "cell_type": "markdown", 892 | "metadata": {}, 893 | "source": [ 894 | "Now we will merge `ActiveDropout` and `Linear` layers into one, which\n", 895 | "\n", 896 | "1. (on forward pass) **drops out** the inputs, if necessary, and **applies** the linear (affine) transform\n", 897 | "2. (on freeze) **randomly zeros** columns in a copy of the the weight matrix $W$\n", 898 | "\n", 899 | "Preferably, we will try to preserve interface, so that the resulting\n", 900 | "object is backwards compatible with `Linear`." 901 | ] 902 | }, 903 | { 904 | "cell_type": "markdown", 905 | "metadata": {}, 906 | "source": [ 907 | "This way we would be able to draw realizations from the induced\n", 908 | "distribution over functions defined by the network $\n", 909 | "\\bigl\\{\n", 910 | " f_\\omega\\colon \\mathcal{X}\\to\\mathcal{Y}\n", 911 | "\\bigr\\}_{\\omega \\sim q(\\omega)}\n", 912 | "$\n", 913 | "where $q(\\omega)$ a distribution over the network parameters." 914 | ] 915 | }, 916 | { 917 | "cell_type": "markdown", 918 | "metadata": {}, 919 | "source": [ 920 | "
" 921 | ] 922 | }, 923 | { 924 | "cell_type": "markdown", 925 | "metadata": {}, 926 | "source": [ 927 | "#### (task) Fused dropout-linear operation" 928 | ] 929 | }, 930 | { 931 | "cell_type": "markdown", 932 | "metadata": {}, 933 | "source": [ 934 | "On the inputs into a linear layer dropout acts like this: for input\n", 935 | "$\n", 936 | " x\\in \\mathbb{R}^{[\\mathrm{in}]}\n", 937 | "$ and layer weights $\n", 938 | " W\\in \\mathbb{R}^{[\\mathrm{out}] \\times [\\mathrm{in}]}\n", 939 | "$\n", 940 | "and bias $\n", 941 | " b\\in \\mathbb{R}^{[\\mathrm{out}]}\n", 942 | "$ the resulting effect is\n", 943 | "\n", 944 | "$$\n", 945 | " \\tilde{x} = x \\odot m\n", 946 | " \\,, \\\\\n", 947 | " y = \\tilde{x} W^\\top + b\n", 948 | "% = b + \\sum_i x_i m_i W_i\n", 949 | " \\,, $$\n", 950 | "\n", 951 | "where $\\odot$ is the elementwise product and $m\\in \\mathbb{R}^{[\\mathrm{in}]}$\n", 952 | "with $m_j \\sim \\pi_p(m_j) = \\mathcal{Ber}\\bigl(\\bigl\\{0, \\tfrac1{1-p}\\bigr\\}, 1-p\\bigr)$,\n", 953 | "i.e. equals $\\tfrac1{1-p}$ with probability $1-p$ and $0$ otherwise." 954 | ] 955 | }, 956 | { 957 | "cell_type": "markdown", 958 | "metadata": {}, 959 | "source": [ 960 | "Let\n", 961 | "$\n", 962 | " x\\in \\mathbb{R}^{[\\mathrm{in}]}\n", 963 | "$, $\n", 964 | " W\\in \\mathbb{R}^{[\\mathrm{out}] \\times [\\mathrm{in}]}\n", 965 | "$\n", 966 | "and $\n", 967 | " b\\in \\mathbb{R}^{[\\mathrm{out}]}\n", 968 | "$. Let's use the following `torch`'s functions:\n", 969 | "\n", 970 | "* `F.dropout(x, p, on/off)` -- independent Bernoulli dropout $x\\mapsto x\\odot m$\n", 971 | " for $m\\sim \\mathcal{Ber}\\bigl(\\bigl\\{0, \\tfrac1{1-p}\\bigr\\}, 1-p\\bigr)$\n", 972 | "\n", 973 | "* `F.linear(x, W, b)` -- affine transformation $x \\mapsto x W^\\top + b$\n", 974 | "\n", 975 | "**(note)** the `.weight` of a linear layer in `pytorch` is an $\n", 976 | "{\n", 977 | " [\\mathrm{out}]\n", 978 | " \\times [\\mathrm{in}]\n", 979 | "}\n", 980 | "$ matrix.\n", 981 | "\n", 982 | "" 986 | ] 987 | }, 988 | { 989 | "cell_type": "code", 990 | "execution_count": null, 991 | "metadata": {}, 992 | "outputs": [], 993 | "source": [ 994 | "def DropoutLinear_forward(self, input):\n", 995 | " ## Exercise: If not frozen, then apply always active dropout,\n", 996 | " # then linear transformation. If frozen, apply the transform\n", 997 | " # using the frozen weight\n", 998 | "\n", 999 | " # linear with frozen weight\n", 1000 | " if self.is_frozen():\n", 1001 | " return F.linear(input, self.frozen_weight, self.bias)\n", 1002 | "\n", 1003 | " # stochastic pass as in `ActiveDropout` + Linear\n", 1004 | " input = F.dropout(input, self.p, True)\n", 1005 | "\n", 1006 | " return F.linear(input, self.weight, self.bias)\n", 1007 | " # return super().forward(F.dropout(input, self.p, True))\n", 1008 | "\n", 1009 | " pass" 1010 | ] 1011 | }, 1012 | { 1013 | "cell_type": "markdown", 1014 | "metadata": {}, 1015 | "source": [ 1016 | "
" 1017 | ] 1018 | }, 1019 | { 1020 | "cell_type": "markdown", 1021 | "metadata": {}, 1022 | "source": [ 1023 | "#### Parameter freezer for our custom layer" 1024 | ] 1025 | }, 1026 | { 1027 | "cell_type": "markdown", 1028 | "metadata": {}, 1029 | "source": [ 1030 | "For input\n", 1031 | "$\n", 1032 | " x\\in \\mathbb{R}^{[\\mathrm{in}]}\n", 1033 | "$ and a layer parameters $\n", 1034 | " W\\in \\mathbb{R}^{[\\mathrm{out}] \\times [\\mathrm{in}]}\n", 1035 | "$\n", 1036 | "and $\n", 1037 | " b\\in \\mathbb{R}^{[\\mathrm{out}]}\n", 1038 | "$ the effect in `DropoutLinear` is\n", 1039 | "\n", 1040 | "$$\n", 1041 | " y_j\n", 1042 | " = \\bigl[(x \\odot m) W^\\top + b\\bigr]_j\n", 1043 | " = b_j + \\sum_i x_i m_i W_{ji}\n", 1044 | " = b_j + \\sum_i x_i \\breve{W}_{ji}\n", 1045 | " \\,, $$\n", 1046 | "\n", 1047 | "where the each column of $\\breve{W}_i$ is, independently, either\n", 1048 | "$\\mathbf{0} \\in \\mathbb{R}^{[\\mathrm{out}]}$ with probability $p$ or\n", 1049 | "some (learnable) vector in $\\mathbb{R}^{[\\mathrm{out}]}$\n", 1050 | "\n", 1051 | "$$\n", 1052 | " \\breve{W}_i \\sim\n", 1053 | "\\begin{cases}\n", 1054 | " \\mathbf{0}\n", 1055 | " & \\text{ w. prob } p \\,, \\\\\n", 1056 | " \\tfrac1{1-p} M_i\n", 1057 | " & \\text{ w. prob } 1-p \\,.\n", 1058 | "\\end{cases}\n", 1059 | "$$\n", 1060 | "\n", 1061 | "Thus the multiplicative effect of the random mask $m$ on $x$ can be\n", 1062 | "equivalently seen as a random **on/off** switch effect on the\n", 1063 | "**columns** of the matrix $W$." 1064 | ] 1065 | }, 1066 | { 1067 | "cell_type": "code", 1068 | "execution_count": null, 1069 | "metadata": {}, 1070 | "outputs": [], 1071 | "source": [ 1072 | "def DropoutLinear_freeze(self):\n", 1073 | " \"\"\"Apply dropout with rate `p` to columns of `weight` and freeze it.\"\"\"\n", 1074 | " # we leverage torch's broadcasting semantics and draw a one-row\n", 1075 | " # mask binary mask, that we later multiply the weight by.\n", 1076 | "\n", 1077 | " # let's draw the new weight\n", 1078 | " with torch.no_grad():\n", 1079 | " prob = torch.full_like(self.weight[:1, :], 1 - self.p)\n", 1080 | " feature_mask = torch.bernoulli(prob) / prob\n", 1081 | "\n", 1082 | " frozen_weight = self.weight * feature_mask\n", 1083 | "\n", 1084 | " # and store it\n", 1085 | " self.register_buffer(\"frozen_weight\", frozen_weight)" 1086 | ] 1087 | }, 1088 | { 1089 | "cell_type": "markdown", 1090 | "metadata": {}, 1091 | "source": [ 1092 | "
" 1093 | ] 1094 | }, 1095 | { 1096 | "cell_type": "markdown", 1097 | "metadata": {}, 1098 | "source": [ 1099 | "Assemble the blocks into a layer" 1100 | ] 1101 | }, 1102 | { 1103 | "cell_type": "code", 1104 | "execution_count": null, 1105 | "metadata": {}, 1106 | "outputs": [], 1107 | "source": [ 1108 | "class DropoutLinear(Linear, FreezableWeight):\n", 1109 | " \"\"\"Linear layer with dropout on inputs.\"\"\"\n", 1110 | " def __init__(self, in_features, out_features, bias=True, p=0.5):\n", 1111 | " super().__init__(in_features, out_features, bias=bias)\n", 1112 | "\n", 1113 | " self.p = p\n", 1114 | "\n", 1115 | " forward = DropoutLinear_forward\n", 1116 | "\n", 1117 | " freeze = DropoutLinear_freeze" 1118 | ] 1119 | }, 1120 | { 1121 | "cell_type": "markdown", 1122 | "metadata": {}, 1123 | "source": [ 1124 | "
" 1125 | ] 1126 | }, 1127 | { 1128 | "cell_type": "markdown", 1129 | "metadata": {}, 1130 | "source": [ 1131 | "### Comparing sample functions to point-estimates " 1132 | ] 1133 | }, 1134 | { 1135 | "cell_type": "markdown", 1136 | "metadata": {}, 1137 | "source": [ 1138 | "Let's rewrite the model builder function:" 1139 | ] 1140 | }, 1141 | { 1142 | "cell_type": "code", 1143 | "execution_count": null, 1144 | "metadata": {}, 1145 | "outputs": [], 1146 | "source": [ 1147 | "def build_model(p=0.5):\n", 1148 | " \"\"\"Build a model with the custom layer and dropout rate set to `p`.\"\"\"\n", 1149 | "\n", 1150 | " return Sequential(\n", 1151 | " ## Exercise: Plug-in `DropoutLinear` layer into our second network.\n", 1152 | "\n", 1153 | " Linear(1, 512, bias=True),\n", 1154 | " LeakyReLU(),\n", 1155 | "\n", 1156 | " DropoutLinear(512, 512, bias=True , p=p),\n", 1157 | " LeakyReLU(),\n", 1158 | "\n", 1159 | " DropoutLinear(512, 1, bias=True, p=p),\n", 1160 | "\n", 1161 | " # pass\n", 1162 | " )" 1163 | ] 1164 | }, 1165 | { 1166 | "cell_type": "markdown", 1167 | "metadata": {}, 1168 | "source": [ 1169 | "Let's create a new instance and retrain the model." 1170 | ] 1171 | }, 1172 | { 1173 | "cell_type": "code", 1174 | "execution_count": null, 1175 | "metadata": {}, 1176 | "outputs": [], 1177 | "source": [ 1178 | "model = build_model(p=0.5)\n", 1179 | "model.to(device)\n", 1180 | "\n", 1181 | "fit(model, train, criterion=\"mse\", n_epochs=2000, verbose=True, weight_decay=1e-3)" 1182 | ] 1183 | }, 1184 | { 1185 | "cell_type": "markdown", 1186 | "metadata": {}, 1187 | "source": [ 1188 | "... and obtain two estimates: pointwise and functional." 1189 | ] 1190 | }, 1191 | { 1192 | "cell_type": "code", 1193 | "execution_count": null, 1194 | "metadata": {}, 1195 | "outputs": [], 1196 | "source": [ 1197 | "samples_pe = point_estimate(model, domain, n_samples=51, verbose=True)\n", 1198 | "samples_sf = sample_function(model, domain, n_samples=51, verbose=True)\n", 1199 | "\n", 1200 | "samples_pe.shape, samples_sf.shape" 1201 | ] 1202 | }, 1203 | { 1204 | "cell_type": "markdown", 1205 | "metadata": {}, 1206 | "source": [ 1207 | "```python\n", 1208 | "(torch.Size([51, 251, 1]), torch.Size([51, 251, 1]))\n", 1209 | "```" 1210 | ] 1211 | }, 1212 | { 1213 | "cell_type": "markdown", 1214 | "metadata": {}, 1215 | "source": [ 1216 | "
" 1217 | ] 1218 | }, 1219 | { 1220 | "cell_type": "markdown", 1221 | "metadata": {}, 1222 | "source": [ 1223 | "Let's compare **point estimates**\n", 1224 | "with **function sampling**." 1225 | ] 1226 | }, 1227 | { 1228 | "cell_type": "code", 1229 | "execution_count": null, 1230 | "metadata": {}, 1231 | "outputs": [], 1232 | "source": [ 1233 | "fig, ax = plt.subplots(1, 1, figsize=(12, 5))\n", 1234 | "\n", 1235 | "ax.plot(X_domain[:, 0], samples_pe[:10, :, 0].numpy().T,\n", 1236 | " c=\"C1\", lw=1, alpha=0.5)\n", 1237 | "\n", 1238 | "ax.plot(X_domain[:, 0], samples_sf[:10, :, 0].numpy().T,\n", 1239 | " c=\"C0\", lw=2, alpha=0.5)\n", 1240 | "\n", 1241 | "ax.scatter(X_train, y_train, c=\"black\", s=40,\n", 1242 | " label=\"train\", zorder=+10);" 1243 | ] 1244 | }, 1245 | { 1246 | "cell_type": "code", 1247 | "execution_count": null, 1248 | "metadata": {}, 1249 | "outputs": [], 1250 | "source": [ 1251 | "fig, ax = plt.subplots(1, 1, figsize=(12, 5))\n", 1252 | "\n", 1253 | "ax.scatter(X_train, y_train, c=\"black\", s=40, label=\"train\")\n", 1254 | "\n", 1255 | "mean, std = samples_sf.mean(dim=0).numpy(), samples_sf.std(dim=0).numpy()\n", 1256 | "ax.plot(X_domain, mean + 1.96 * std, c=\"C0\")\n", 1257 | "ax.plot(X_domain, mean - 1.96 * std, c=\"C0\");\n", 1258 | "\n", 1259 | "mean, std = samples_pe.mean(dim=0).numpy(), samples_pe.std(dim=0).numpy()\n", 1260 | "ax.plot(X_domain, mean + 1.96 * std, c=\"C1\")\n", 1261 | "ax.plot(X_domain, mean - 1.96 * std, c=\"C1\");" 1262 | ] 1263 | }, 1264 | { 1265 | "cell_type": "markdown", 1266 | "metadata": {}, 1267 | "source": [ 1268 | "Pros of `point-estimate`:\n", 1269 | "* uses stochastic forward passes -- no need to for extra code and classes\n", 1270 | "\n", 1271 | "Cons of `point-estimate`:\n", 1272 | "* samples from the predictive distribution at adjacent inputs are independent" 1273 | ] 1274 | }, 1275 | { 1276 | "cell_type": "markdown", 1277 | "metadata": {}, 1278 | "source": [ 1279 | "
" 1280 | ] 1281 | }, 1282 | { 1283 | "cell_type": "markdown", 1284 | "metadata": {}, 1285 | "source": [ 1286 | "**(note)**\n", 1287 | "The parameter distribution of the layer we've built is\n", 1288 | "\n", 1289 | "$$\n", 1290 | " q(\\omega\\mid \\theta)\n", 1291 | " = \\prod_i q(\\omega_i\\mid \\theta_i)\n", 1292 | " = \\prod_i \\bigl\\{\n", 1293 | " p \\delta_{\\mathbf{0}} (\\omega_i)\n", 1294 | " + (1 - p) \\delta_{\\tfrac1{1-p} \\theta_i}(\\omega_i)\n", 1295 | " \\bigr\\}\n", 1296 | " \\,, $$\n", 1297 | "\n", 1298 | "where $\\omega_i$ is the $i$-th column of $\\omega$, $\\delta_a$ is a\n", 1299 | "**point-mass** distribution at $a$, and $\\theta$ is the learnt\n", 1300 | "approximate posterior mean of $\\omega$." 1301 | ] 1302 | }, 1303 | { 1304 | "cell_type": "markdown", 1305 | "metadata": {}, 1306 | "source": [ 1307 | "Under benign assumptions and certain relaxations\n", 1308 | "[Gal, Y. 2016 (eq. (6.3) p.109, Prop. 4 p.149)](http://www.cs.ox.ac.uk/people/yarin.gal/website/thesis/thesis.pdf)\n", 1309 | "has shown that a deep network with dropout rate $p$\n", 1310 | "and $\\ell_2$ weight penalty (`weight_decay`) performs (doubly)\n", 1311 | "**stochastic variational inference** with the following stochastic\n", 1312 | "approximate **evidence lower bound**: for the dataset $D = (x_i, y_i)_i$\n", 1313 | "of size $N = \\lvert D \\rvert$ and random batches $B$ of size\n", 1314 | "$\\lvert B \\rvert = m$\n", 1315 | "\n", 1316 | "$$\n", 1317 | " \\frac1{N} \\Bigl( \\underbrace{\n", 1318 | " \\mathbb{E}_{\\omega\\sim q(\\omega\\mid \\theta)} \\log p(D \\mid \\omega)\n", 1319 | " - KL\\bigl(q(\\omega\\mid \\theta) \\big\\| \\pi(\\omega) \\bigr)\n", 1320 | " }_{ELBO(\\theta)} \\Bigr)\n", 1321 | " \\approx \\frac1{\\lvert B \\rvert}\n", 1322 | " \\sum_{i\\in B} \\log p(y_i \\mid x_i, \\omega^{(1)}_i, \\ldots, \\omega^{(L)}_i)\n", 1323 | " - \\sum_{l=1}^L\n", 1324 | " \\frac{1-p^{(l)}}{2 s^2 N} \\|\\theta^{(l)}\\|_2^2\n", 1325 | "% - [\\mathrm{in}_{(l)}] \\, \\mathbb{H}(\\mathcal{Ber}(p^{(l)}))\n", 1326 | "% + \\mathrm{const}\n", 1327 | "\\,, $$\n", 1328 | "where $\\omega_i^{(l)}$ are independently drawn from $q(\\omega \\mid \\theta)$\n", 1329 | "(one random draw per element in $B$) and $s^2$ is the prior variance." 1330 | ] 1331 | }, 1332 | { 1333 | "cell_type": "markdown", 1334 | "metadata": {}, 1335 | "source": [ 1336 | "Thus `weight_decay` should be decreasing with $p$ and $N$:\n", 1337 | "$$ \\lambda = \\frac{1-p}{2 s^2 N} \\,. $$" 1338 | ] 1339 | }, 1340 | { 1341 | "cell_type": "markdown", 1342 | "metadata": {}, 1343 | "source": [ 1344 | "
" 1345 | ] 1346 | }, 1347 | { 1348 | "cell_type": "markdown", 1349 | "metadata": {}, 1350 | "source": [ 1351 | "#### Question(s) (to ponder in your spare time)\n", 1352 | "\n", 1353 | "* what happens to the confidence bands, when you increase the number\n", 1354 | " of path-wise and pointwise samples?\n", 1355 | "\n", 1356 | "* what will happen if you change the dropout rate $p$ and keep `n_epochs` at 2000?\n", 1357 | "\n", 1358 | "* what happens if for $p=\\tfrac12$ we use much less `n_epochs`?\n", 1359 | "\n", 1360 | "* how does different settings of `weight_decay` affect the bands?\n", 1361 | "\n", 1362 | "Try to rebuild the model with different $p \\in (0, 1)$ using `build_model(p)`, use\n", 1363 | "`fit(..., n_epochs=...)`, and then plot the predictive bands." 1364 | ] 1365 | }, 1366 | { 1367 | "cell_type": "code", 1368 | "execution_count": null, 1369 | "metadata": {}, 1370 | "outputs": [], 1371 | "source": [ 1372 | "from mlss2019bdl.plotting import plot1d_bands\n", 1373 | "\n", 1374 | "# model = fit(build_model(p=...), train, n_epochs=..., weight_decay=..., criterion=\"mse\")\n", 1375 | "# plot1d_bands(sample_function(model, domain, n_samples=101), c=\"C0\")" 1376 | ] 1377 | }, 1378 | { 1379 | "cell_type": "markdown", 1380 | "metadata": {}, 1381 | "source": [ 1382 | "
" 1383 | ] 1384 | }, 1385 | { 1386 | "cell_type": "code", 1387 | "execution_count": null, 1388 | "metadata": {}, 1389 | "outputs": [], 1390 | "source": [ 1391 | "model_a = fit(build_model(p=0.15), train, criterion=\"mse\", n_epochs=2000, weight_decay=1e-3)\n", 1392 | "\n", 1393 | "model_z = fit(build_model(p=0.75), train, criterion=\"mse\", n_epochs=2000, weight_decay=1e-3)" 1394 | ] 1395 | }, 1396 | { 1397 | "cell_type": "code", 1398 | "execution_count": null, 1399 | "metadata": {}, 1400 | "outputs": [], 1401 | "source": [ 1402 | "fig = plt.figure(figsize=(12, 5))\n", 1403 | "\n", 1404 | "samples_a = sample_function(model_a, domain, n_samples=101)\n", 1405 | "samples_z = sample_function(model_z, domain, n_samples=101)\n", 1406 | "\n", 1407 | "plot1d_bands(X_domain, samples_a.transpose(0, 2), c=\"r\")\n", 1408 | "plot1d_bands(X_domain, samples_z.transpose(0, 2), c=\"b\")" 1409 | ] 1410 | }, 1411 | { 1412 | "cell_type": "code", 1413 | "execution_count": null, 1414 | "metadata": {}, 1415 | "outputs": [], 1416 | "source": [ 1417 | "model_a = fit(build_model(p=0.50), train, criterion=\"mse\", n_epochs=20, weight_decay=1e-3)\n", 1418 | "\n", 1419 | "model_z = fit(build_model(p=0.50), train, criterion=\"mse\", n_epochs=200, weight_decay=1e-3)" 1420 | ] 1421 | }, 1422 | { 1423 | "cell_type": "code", 1424 | "execution_count": null, 1425 | "metadata": {}, 1426 | "outputs": [], 1427 | "source": [ 1428 | "fig = plt.figure(figsize=(12, 5))\n", 1429 | "\n", 1430 | "samples_a = sample_function(model_a, domain, n_samples=101)\n", 1431 | "samples_z = sample_function(model_z, domain, n_samples=101)\n", 1432 | "\n", 1433 | "plot1d_bands(X_domain, samples_a.transpose(0, 2), c=\"r\")\n", 1434 | "plot1d_bands(X_domain, samples_z.transpose(0, 2), c=\"b\")" 1435 | ] 1436 | }, 1437 | { 1438 | "cell_type": "code", 1439 | "execution_count": null, 1440 | "metadata": {}, 1441 | "outputs": [], 1442 | "source": [ 1443 | "model_a = fit(build_model(p=0.50), train, criterion=\"mse\", n_epochs=2000, weight_decay=1e-5)\n", 1444 | "\n", 1445 | "model_z = fit(build_model(p=0.50), train, criterion=\"mse\", n_epochs=2000, weight_decay=1e-1)" 1446 | ] 1447 | }, 1448 | { 1449 | "cell_type": "code", 1450 | "execution_count": null, 1451 | "metadata": {}, 1452 | "outputs": [], 1453 | "source": [ 1454 | "fig = plt.figure(figsize=(12, 5))\n", 1455 | "\n", 1456 | "samples_a = sample_function(model_a, domain, n_samples=101)\n", 1457 | "\n", 1458 | "samples_z = sample_function(model_z, domain, n_samples=101)\n", 1459 | "\n", 1460 | "plot1d_bands(X_domain, samples_a.transpose(0, 2), c=\"r\")\n", 1461 | "plot1d_bands(X_domain, samples_z.transpose(0, 2), c=\"b\")" 1462 | ] 1463 | }, 1464 | { 1465 | "cell_type": "code", 1466 | "execution_count": null, 1467 | "metadata": {}, 1468 | "outputs": [], 1469 | "source": [ 1470 | "model_a = fit(build_model(p=0.10), train, criterion=\"mse\", n_epochs=2000, weight_decay=1e-3)\n", 1471 | "\n", 1472 | "model_z = fit(build_model(p=0.90), train, criterion=\"mse\", n_epochs=2000, weight_decay=1e-4)" 1473 | ] 1474 | }, 1475 | { 1476 | "cell_type": "code", 1477 | "execution_count": null, 1478 | "metadata": {}, 1479 | "outputs": [], 1480 | "source": [ 1481 | "fig = plt.figure(figsize=(12, 5))\n", 1482 | "\n", 1483 | "samples_a = sample_function(model_a, domain, n_samples=101)\n", 1484 | "\n", 1485 | "samples_z = sample_function(model_z, domain, n_samples=101)\n", 1486 | "\n", 1487 | "plot1d_bands(X_domain, samples_a.transpose(0, 2), c=\"r\")\n", 1488 | "plot1d_bands(X_domain, samples_z.transpose(0, 2), 
c=\"b\")" 1489 | ] 1490 | }, 1491 | { 1492 | "cell_type": "markdown", 1493 | "metadata": {}, 1494 | "source": [ 1495 | "
" 1496 | ] 1497 | }, 1498 | { 1499 | "cell_type": "markdown", 1500 | "metadata": {}, 1501 | "source": [ 1502 | "### (optional) Dropout $2$-d Convolutional layer" 1503 | ] 1504 | }, 1505 | { 1506 | "cell_type": "markdown", 1507 | "metadata": {}, 1508 | "source": [ 1509 | "Typically, in convolutional neural networks the dropout acts upon the feature\n", 1510 | "(channel) information and not on the spatial dimensions. Thus entire channels\n", 1511 | "are dropped out and for $\n", 1512 | " x \\in \\mathbb{R}^{\n", 1513 | " [\\mathrm{in}]\n", 1514 | " \\times h\n", 1515 | " \\times w}\n", 1516 | "$ and $\n", 1517 | " y \\in \\mathbb{R}^{\n", 1518 | " [\\mathrm{out}]\n", 1519 | " \\times h'\n", 1520 | " \\times w'}\n", 1521 | "$ the full effect of the `Dropout+Conv2d` layer is\n", 1522 | "\n", 1523 | "$$\n", 1524 | " y_{lij} = ((x \\odot m) \\ast W_l)_{ij} + b_l\n", 1525 | " = b_l + \\sum_k \\sum_{pq} x_{k i_p j_q} m_k W_{lkpq}\n", 1526 | " \\,, \\tag{conv-2d} $$\n", 1527 | " \n", 1528 | "where i.i.d $m_k \\sim \\mathcal{Ber}\\bigl(\\bigl\\{0, \\tfrac1{1-p}\\bigr\\}, 1-p\\bigr)$,\n", 1529 | "and indices $i_p$ and $j_q$ represent the spatial location in $x$ that correspond\n", 1530 | "to the $p$ and $q$ elements in the kernel $\n", 1531 | " W\\in \\mathbb{R}^{\n", 1532 | " [\\mathrm{out}]\n", 1533 | " \\times [\\mathrm{in}]\n", 1534 | " \\times h\n", 1535 | " \\times w}\n", 1536 | "$ relative to $(i, j)$ coordinates in $y$.\n", 1537 | "The exact values of $i_p$ and $j_q$ depend on the configuration of the\n", 1538 | "convolutional layer, e.g. stride, kernel size and dilation.\n", 1539 | "\n", 1540 | "**(note)** Informative illustrations on the effects of convolution\n", 1541 | "parameters can be found in [Convolution arithmetic](https://github.com/vdumoulin/conv_arithmetic) \n", 1542 | "repo." 1543 | ] 1544 | }, 1545 | { 1546 | "cell_type": "markdown", 1547 | "metadata": {}, 1548 | "source": [ 1549 | "
" 1550 | ] 1551 | }, 1552 | { 1553 | "cell_type": "markdown", 1554 | "metadata": {}, 1555 | "source": [ 1556 | "## (optional) A brief reminder on Bayesian and Variational Inference" 1557 | ] 1558 | }, 1559 | { 1560 | "cell_type": "markdown", 1561 | "metadata": {}, 1562 | "source": [ 1563 | "Bayesian Inference is a principled framework of reasoning about uncertainty.\n", 1564 | "\n", 1565 | "In Bayesian Inference (**BI**) we *assume* that the observation\n", 1566 | "data $D$ follows a *model* $m$ with data generating distribution\n", 1567 | "$p(D\\mid m, \\omega)$ *governed by unknown parameters* $\\omega$.\n", 1568 | "The goal of **BI** is to reason about the model and/or its parameters,\n", 1569 | "and new data given the observed data $D$ and our assumptions, i.e\n", 1570 | "to seek the **posterior** parameter and predictive distributions:\n", 1571 | "\n", 1572 | "$$\\begin{align}\n", 1573 | " p(d \\mid D, m)\n", 1574 | " % &= \\mathbb{E}_{\n", 1575 | " % \\omega \\sim p(\\omega \\mid D, m)\n", 1576 | " % } p(d \\mid D, \\omega, m)\n", 1577 | " &= \\int p(d \\mid D, \\omega, m) p(\\omega \\mid D, m) d\\omega\n", 1578 | " \\,, \\\\\n", 1579 | " p(\\omega \\mid D, m)\n", 1580 | " &= \\frac{p(D\\mid \\omega, m) \\, \\pi(\\omega \\mid m)}{p(D\\mid m)}\n", 1581 | " \\,.\n", 1582 | "\\end{align}\n", 1583 | "$$\n", 1584 | "\n", 1585 | "* the **prior** distribution $\\pi(\\omega \\mid m)$ reflects our belief\n", 1586 | " before having made the observations\n", 1587 | "\n", 1588 | "* the data distribution $p(D \\mid \\omega, m)$ reflects our assumptions\n", 1589 | " about the data generating process, and determines the parameter\n", 1590 | " **likelihood** (Gaussian, Categorical, Poisson)" 1591 | ] 1592 | }, 1593 | { 1594 | "cell_type": "markdown", 1595 | "metadata": {}, 1596 | "source": [ 1597 | "Unless the distributions and likelihoods are conjugate, posterior in\n", 1598 | "Bayesian inference is typically intractable and it is common to resort\n", 1599 | "to **Variational Inference** or **Monte Carlo** approximations." 1600 | ] 1601 | }, 1602 | { 1603 | "cell_type": "markdown", 1604 | "metadata": {}, 1605 | "source": [ 1606 | "This key idea of this approach is to seek an approximation $q(\\omega)$\n", 1607 | "to the intractable posterior $p(\\omega \\mid D, m)$, via a variational\n", 1608 | "optimization problem over some tractable family of distributions $\\mathcal{Q}$:\n", 1609 | "\n", 1610 | "$$\n", 1611 | " q^*(\\omega)\n", 1612 | " \\in \\arg \\min_{q\\in \\mathcal{Q}} \\mathrm{KL}(q(\\omega) \\| p(\\omega \\mid D, m))\n", 1613 | " \\,, $$\n", 1614 | "\n", 1615 | "where the Kullback-Leibler divergence between $P$ and $Q$ ($P\\ll Q$)\n", 1616 | "with densities $p$ and $q$, respectively, is given by\n", 1617 | "\n", 1618 | "$$\n", 1619 | " \\mathrm{KL}(q(\\omega) \\| p(\\omega))\n", 1620 | "% = \\mathbb{E}_{\\omega \\sim Q} \\log \\tfrac{dQ}{dP}(\\omega)\n", 1621 | " = \\mathbb{E}_{\\omega \\sim q(\\omega)}\n", 1622 | " \\log \\tfrac{q(\\omega)}{p(\\omega)}\n", 1623 | " \\,. \\tag{kl-div} $$\n", 1624 | "\n", 1625 | "\n", 1626 | "Note that the family of variational approximations $\\mathcal{Q}$ can be\n", 1627 | "structured **arbitrarily**: point-mass, products, mixture, dependent on\n", 1628 | "input, having mixed hierarchical structure, -- any valid distribution." 1629 | ] 1630 | }, 1631 | { 1632 | "cell_type": "markdown", 1633 | "metadata": {}, 1634 | "source": [ 1635 | "Although computing the divergence w.r.t. 
the unknown posterior\n", 1636 | "is still hard and intractable, it is possible to do away with it\n", 1637 | "through the following identity, which is based on the Bayes rule.\n", 1638 | "\n", 1639 | "For **any** $q(\\omega) \\ll p(\\omega \\mid D; \\phi)$ and any model $m$\n", 1640 | "\n", 1641 | "$$\n", 1642 | "\\begin{align}\n", 1643 | " \\overbrace{\n", 1644 | " \\log p(D \\mid m)\n", 1645 | " }^{\\text{evidence}}\n", 1646 | " &= \\overbrace{\n", 1647 | " \\mathbb{E}_{\\omega \\sim q} \\log p(D\\mid \\omega, m)\n", 1648 | " }^{\\text{expected conditional likelihood}}\n", 1649 | " - \\overbrace{\n", 1650 | " \\mathrm{KL}(q(\\omega)\\| \\pi(\\omega \\mid m))\n", 1651 | " }^{\\text{proximity to prior belief}}\n", 1652 | " \\\\\n", 1653 | " &+ \\underbrace{\n", 1654 | " \\mathrm{KL}(q(\\omega)\\| p(\\omega \\mid D, m))\n", 1655 | " }_{\\text{posterior approximation}}\n", 1656 | "\\end{align}\n", 1657 | " \\,. \\tag{master-identity}\n", 1658 | "$$" 1659 | ] 1660 | }, 1661 | { 1662 | "cell_type": "markdown", 1663 | "metadata": {}, 1664 | "source": [ 1665 | "Instead of minimizing the divergence of the approximation from the posterior,\n", 1666 | "we maximize the **Evidence Lower Bound** with respect to $q(\\omega)$:\n", 1667 | "\n", 1668 | "$$\n", 1669 | " q^* \\in\n", 1670 | " \\arg\\max_{q\\in Q}\n", 1671 | " \\mathcal{L}(q) = \n", 1672 | " \\mathbb{E}_{\\omega \\sim q} \\log p(D\\mid \\omega, m)\n", 1673 | " - \\mathrm{KL}(q(\\omega)\\| \\pi(\\omega \\mid m))\n", 1674 | " \\,. \\tag{max-ELBO} $$\n", 1675 | "\n", 1676 | "* the expected $\\log$-likelihood favours $q$ that place their mass on\n", 1677 | "parameters $\\omega$ that explain $D$ under the specified model $m$.\n", 1678 | "\n", 1679 | "* the negative KL-divergence discourages the approximation $q$\n", 1680 | "from straying too far away from to the prior belief $\\pi$ under $m$." 1681 | ] 1682 | }, 1683 | { 1684 | "cell_type": "markdown", 1685 | "metadata": {}, 1686 | "source": [ 1687 | "We usually consider the following setup (conditioning on model $m$ is omitted):\n", 1688 | "* the likelihood factorizes $\n", 1689 | "p(D \\mid \\omega)\n", 1690 | " = \\prod_i p(y_i, x_i \\mid \\omega)\n", 1691 | " \\propto \\prod_i p(y_i \\mid x_i, \\omega)\n", 1692 | "$\n", 1693 | "for $D = (x_i, y_i)_{i=1}^N$\n", 1694 | "\n", 1695 | "* the approximation is parameterized by $\\theta$: $q(\\omega\\mid \\theta)$\n", 1696 | "\n", 1697 | "* the prior on $\\omega$ itself depends on hyper-parameters $\\lambda$, that\n", 1698 | " can be fixed, or variable ($\\pi(\\omega \\mid \\lambda)$)." 1699 | ] 1700 | }, 1701 | { 1702 | "cell_type": "markdown", 1703 | "metadata": {}, 1704 | "source": [ 1705 | "In this case the variational objective (evidence lower bound)\n", 1706 | "\n", 1707 | "$$\n", 1708 | " \\log p(D\\mid \\lambda )\n", 1709 | " \\geq \\mathcal{L}(\\theta, \\lambda)\n", 1710 | " = \\mathbb{E}_{\\omega \\sim q(\\omega \\mid \\theta)}\n", 1711 | " \\sum_i \\log p_\\phi(y_i \\mid x_i, \\omega)\n", 1712 | " - KL(q(\\omega \\mid \\theta) \\| \\pi(\\omega \\mid \\lambda))\n", 1713 | " $$\n", 1714 | "\n", 1715 | "is maximized with respect to $\\theta$ (to approximate the posterior)." 1716 | ] 1717 | }, 1718 | { 1719 | "cell_type": "markdown", 1720 | "metadata": {}, 1721 | "source": [ 1722 | "Priors can be\n", 1723 | "* *subjective*, i.e. reflecting prior beliefs (but not arbitrary),\n", 1724 | "* *objective*, i.e. reflecting our lack of knowledge,\n", 1725 | "* *empirical*, i.e. 
learnt from data (we also optimize over hyper-parameters $\\lambda$)" 1726 | ] 1727 | }, 1728 | { 1729 | "cell_type": "markdown", 1730 | "metadata": {}, 1731 | "source": [ 1732 | "The stochastic variant of ELBO is formed by randomly batching\n", 1733 | "the dataset $D$:\n", 1734 | "\n", 1735 | "$$\n", 1736 | " \\mathcal{L}(\\theta, \\lambda)\n", 1737 | " \\approx \\mathcal{L}_\\mathrm{SGVB}(\\theta, \\lambda)\n", 1738 | " = \\lvert D \\rvert \\biggl(\n", 1739 | " \\tfrac1{\\lvert B \\rvert}\n", 1740 | " \\sum_{b \\in B} \\mathbb{E}_{\\omega \\sim q(\\omega \\mid \\theta)}\n", 1741 | " \\log p(y_b \\mid x_b, \\omega)\n", 1742 | " \\biggr)\n", 1743 | " - KL(q(\\omega \\mid \\theta) \\| \\pi(\\omega \\mid \\lambda))\n", 1744 | " \\,. $$\n", 1745 | "\n", 1746 | "* Stochastic optimization follows noisy unbiased gradient estimates, which are\n", 1747 | "usually cheap, allow escaping from local optima, and optimize the objective in\n", 1748 | "expectation." 1749 | ] 1750 | }, 1751 | { 1752 | "cell_type": "markdown", 1753 | "metadata": {}, 1754 | "source": [ 1755 | "In order to get a gradient of $\n", 1756 | " F_\\theta = \\mathbb{E}_{\\omega \\sim q(\\omega \\mid \\theta)} f(\\omega)\n", 1757 | "$ w.r.t $\\theta$ we use either:\n", 1758 | "\n", 1759 | "###### (REINFORCE)\n", 1760 | "$\n", 1761 | "\\nabla_\\theta F_\\theta\n", 1762 | " = \\mathbb{E}_{\\omega \\sim q(\\omega \\mid \\theta)}\n", 1763 | " (f(\\omega) - b_\\theta) \\nabla_\\theta \\log q(\\omega \\mid \\theta)\n", 1764 | "$\n", 1765 | "* for some $b_\\theta$ that is used to control variance\n", 1766 | "\n", 1767 | "###### (reparameterization)\n", 1768 | "$\n", 1769 | "\\nabla_\\theta F_\\theta\n", 1770 | " = \\nabla_\\theta \\mathbb{E}_{\\varepsilon \\sim q(\\varepsilon)}\n", 1771 | " f(g(\\theta; \\varepsilon))\n", 1772 | " = \\mathbb{E}_{\\varepsilon \\sim q(\\varepsilon)}\n", 1773 | " \\nabla_\\theta g(\\theta; \\varepsilon)\n", 1774 | " \\nabla_\\omega f(\\omega) \\big\\vert_{\\omega = g(\\theta; \\varepsilon)}\n", 1775 | "$\n", 1776 | "* when there are $q$ and differentiable $g$ such that sampling from\n", 1777 | "$q(\\omega \\mid \\theta)$ is equivalent to $\\omega = g(\\theta; \\varepsilon)$\n", 1778 | "with $\\varepsilon \\sim q(\\varepsilon)$." 1779 | ] 1780 | }, 1781 | { 1782 | "cell_type": "markdown", 1783 | "metadata": {}, 1784 | "source": [ 1785 | "The variational approximation might yield high dimensional integrals,\n", 1786 | "which are slow/prohibitive to compute. To make the computations faster\n", 1787 | "without foregoing much of the precision, we may use Monte Carlo methods:\n", 1788 | "\n", 1789 | "$$\n", 1790 | " \\mathbb{E}_{\\omega \\sim q(\\omega\\mid \\theta)} \\, f(\\omega)\n", 1791 | " \\overset{\\text{MC}}{\\approx}\n", 1792 | " \\frac1{\\lvert \\mathcal{W}\\rvert}\n", 1793 | " \\sum_{\\omega \\in \\mathcal{W}} f(\\omega)\n", 1794 | " \\,,\n", 1795 | "$$\n", 1796 | "\n", 1797 | "where $\\mathcal{W} = (\\omega_b)_{b=1}^B$ is a sample of independent draws\n", 1798 | "from $q(\\omega\\mid \\theta)$." 
1799 | ] 1800 | }, 1801 | { 1802 | "cell_type": "markdown", 1803 | "metadata": {}, 1804 | "source": [ 1805 | "If we also approximate the expectation in the gradient of ELBO\n", 1806 | "via Monte Carlo we get **doubly stochastic variational objective**:\n", 1807 | "\n", 1808 | "$$\n", 1809 | " \\nabla_\\theta \\mathcal{L}_\\mathrm{DSVB}(\\theta, \\lambda)\n", 1810 | " \\approx\n", 1811 | " \\lvert D \\rvert \\biggl(\n", 1812 | " \\tfrac1{\\lvert B \\rvert}\n", 1813 | " \\sum_{b \\in B}\n", 1814 | " \\mathop{gradient}(x_b, y_b)\n", 1815 | " \\biggr)\n", 1816 | " - \\nabla_\\theta KL(q(\\omega \\mid \\theta) \\| \\pi(\\omega \\mid \\lambda))\n", 1817 | " \\,, $$\n", 1818 | "\n", 1819 | "where `gradient` is $\n", 1820 | " \\nabla_\\theta\n", 1821 | " \\mathbb{E}_{\\omega \\sim q(\\omega \\mid \\theta)}\n", 1822 | " \\log p(y \\mid x, \\omega)\n", 1823 | "$ using one of the approaches above, typically approximated using\n", 1824 | "one independent draw of $\\omega$ per $b\\in B$." 1825 | ] 1826 | }, 1827 | { 1828 | "cell_type": "markdown", 1829 | "metadata": {}, 1830 | "source": [ 1831 | "We can use a similar sampling approach to compute the gradient of the divergence term." 1832 | ] 1833 | }, 1834 | { 1835 | "cell_type": "markdown", 1836 | "metadata": {}, 1837 | "source": [ 1838 | "A good overview of Bayesian Inference can be found at [bdl101.ml](http://bdl101.ml/),\n", 1839 | "in [this lecture](http://mlg.eng.cam.ac.uk/zoubin/talks/lect1bayes.pdf),\n", 1840 | "[this paper](https://arxiv.org/abs/1206.7051.pdf), or\n", 1841 | "[this review](https://arxiv.org/abs/1601.00670.pdf),\n", 1842 | "among other great resources. It is also possible to consult\n", 1843 | "the references at [wiki](https://en.wikipedia.org/wiki/Bayesian_inference)." 1844 | ] 1845 | }, 1846 | { 1847 | "cell_type": "markdown", 1848 | "metadata": {}, 1849 | "source": [ 1850 | "We can estimate the divergence term in the ELBO\n", 1851 | "with Monte Carlo, or, for example, for the predictive distribution\n", 1852 | "we have\n", 1853 | "\n", 1854 | "$$\n", 1855 | "\\begin{align}\n", 1856 | " \\mathbb{E}_{y\\sim p(y\\mid x, D, m)} \\, g(y)\n", 1857 | " &\\overset{\\text{BI}}{=}\n", 1858 | " \\mathbb{E}_{\\omega\\sim p(\\omega \\mid D, m)}\n", 1859 | " \\mathbb{E}_{y\\sim p(y\\mid x, \\omega, D, m)} \\, g(y) \n", 1860 | " \\\\\n", 1861 | " &\\overset{\\text{VI}}{\\approx}\n", 1862 | " \\mathbb{E}_{\\omega\\sim q(\\omega)}\n", 1863 | " \\mathbb{E}_{y\\sim p(y\\mid x, \\omega, D, m)} \\, g(y)\n", 1864 | " \\\\\n", 1865 | " &\\overset{\\text{MC}}{\\approx}\n", 1866 | "% \\hat{\\mathbb{E}}_{\\omega \\sim \\mathcal{W}}\n", 1867 | "% \\mathbb{E}_{y\\sim p(y\\mid x, \\omega, D, m)} \\, g(y)\n", 1868 | " \\frac1{\\lvert \\mathcal{W}\\rvert} \\sum_{\\omega \\in \\mathcal{W}}\n", 1869 | " \\mathbb{E}_{y\\sim p(y\\mid x, \\omega, D, m)} \\, g(y)\n", 1870 | " \\,,\n", 1871 | "\\end{align}\n", 1872 | "$$\n", 1873 | "\n", 1874 | "where $\\mathcal{W} = (\\omega_b)_{b=1}^B \\sim q(\\omega)$\n", 1875 | "-- iid samples from the variational approximation." 1876 | ] 1877 | }, 1878 | { 1879 | "cell_type": "markdown", 1880 | "metadata": {}, 1881 | "source": [ 1882 | "
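Putting the pieces of this section together, below is a rough sketch of a single doubly stochastic gradient step for a classifier whose forward pass draws $\omega \sim q(\omega \mid \theta)$ on every call (e.g. via reparameterization). The names `model` and `kl_fn`, and the choice of cross-entropy, are illustrative placeholders rather than part of the tutorial's library.

```python
import torch.nn.functional as F

def sgvb_step(model, optimizer, X_batch, y_batch, n_total, kl_fn):
    """One doubly stochastic gradient step on the negative ELBO.

    Assumes `model(X)` returns class logits computed with a fresh draw of
    w ~ q(w | theta), and `kl_fn(model)` returns KL(q(w | theta) || pi(w)).
    """
    optimizer.zero_grad()

    # one draw of w per minibatch: a 1-sample MC estimate of E_q log p(y | x, w)
    expected_loglik = -F.cross_entropy(model(X_batch), y_batch, reduction="mean")

    # rescale the minibatch term to the full dataset size, subtract the KL once
    elbo = n_total * expected_loglik - kl_fn(model)

    (-elbo).backward()        # maximizing the ELBO == minimizing its negation
    optimizer.step()
    return float(elbo)
```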
" 1883 | ] 1884 | } 1885 | ], 1886 | "metadata": { 1887 | "kernelspec": { 1888 | "display_name": "Python 3", 1889 | "language": "python", 1890 | "name": "python3" 1891 | }, 1892 | "language_info": { 1893 | "codemirror_mode": { 1894 | "name": "ipython", 1895 | "version": 3 1896 | }, 1897 | "file_extension": ".py", 1898 | "mimetype": "text/x-python", 1899 | "name": "python", 1900 | "nbconvert_exporter": "python", 1901 | "pygments_lexer": "ipython3", 1902 | "version": "3.7.4" 1903 | } 1904 | }, 1905 | "nbformat": 4, 1906 | "nbformat_minor": 2 1907 | } 1908 | -------------------------------------------------------------------------------- /Bayesian Deep Learning part2.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# MLSS2019: Bayesian Deep Learning" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "In this tutorial we will uncertainty estimation can be\n", 15 | "used in active learning or expert-in-the-loop pipelines." 16 | ] 17 | }, 18 | { 19 | "cell_type": "markdown", 20 | "metadata": {}, 21 | "source": [ 22 | "The plan of the tutorial\n", 23 | "1. [Imports and definitions](#Imports-and-definitions)\n", 24 | "2. [Bayesian Active Learning with images](#Bayesian-Active-Learning-with-images)\n", 25 | " 1. [The model](#The-model)\n", 26 | " 2. [the Acquisition Function](#the-Acquisition-Function)\n", 27 | " 3. [Data and the Oracle](#Data-and-the-Oracle)\n", 28 | " 4. [the Active Learning loop](#the-Active-Learning-loop)\n", 29 | " 5. [The baseline](#The-baseline)\n", 30 | "3. [Bayesian Active Learning by Disagreement](#Bayesian-Active-Learning-by-Disagreement)\n", 31 | " 1. [Points of improvement: batch-vs-single](#Points-of-improvement:-batch-vs-single)\n", 32 | " 2. [Points of improvement: bias](#Points-of-improvement:-bias)\n" 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "metadata": {}, 38 | "source": [ 39 | "**(note)**\n", 40 | "* to view documentation on something type in `something?` (with one question mark)\n", 41 | "* to view code of something type in `something??` (with two question marks)." 42 | ] 43 | }, 44 | { 45 | "cell_type": "markdown", 46 | "metadata": {}, 47 | "source": [ 48 | "
" 49 | ] 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "metadata": {}, 54 | "source": [ 55 | "## Imports and definitions" 56 | ] 57 | }, 58 | { 59 | "cell_type": "markdown", 60 | "metadata": {}, 61 | "source": [ 62 | "In this section we import necessary modules and functions and\n", 63 | "define the computational device." 64 | ] 65 | }, 66 | { 67 | "cell_type": "markdown", 68 | "metadata": {}, 69 | "source": [ 70 | "First, we install some boilerplate service code for this tutorial." 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": null, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "!pip install -q --upgrade git+https://github.com/ivannz/mlss2019-bayesian-deep-learning.git" 80 | ] 81 | }, 82 | { 83 | "cell_type": "markdown", 84 | "metadata": {}, 85 | "source": [ 86 | "Next, numpy for computing, matplotlib for plotting and tqdm for progress bars." 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": null, 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "import tqdm\n", 96 | "import numpy as np\n", 97 | "\n", 98 | "%matplotlib inline\n", 99 | "import matplotlib.pyplot as plt" 100 | ] 101 | }, 102 | { 103 | "cell_type": "markdown", 104 | "metadata": {}, 105 | "source": [ 106 | "For deep learning stuff will be using [pytorch](https://pytorch.org/).\n", 107 | "\n", 108 | "If you are unfamiliar with it, it is basically like `numpy` with autograd,\n", 109 | "native GPU support, and tools for building training and serializing models.\n", 110 | "\n", 111 | "\n", 112 | "There are good introductory tutorials on `pytorch`, like this\n", 113 | "[one](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html)." 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": null, 119 | "metadata": {}, 120 | "outputs": [], 121 | "source": [ 122 | "import torch\n", 123 | "import torch.nn.functional as F\n", 124 | "\n", 125 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")" 126 | ] 127 | }, 128 | { 129 | "cell_type": "markdown", 130 | "metadata": {}, 131 | "source": [ 132 | "Next we import the boilerplate code.\n", 133 | "\n", 134 | "* a procedure that implements a minibatch SGD **fit** loop\n", 135 | "* a function, that **evaluates** the model on the provided dataset" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": null, 141 | "metadata": {}, 142 | "outputs": [], 143 | "source": [ 144 | "from mlss2019bdl import fit, predict" 145 | ] 146 | }, 147 | { 148 | "cell_type": "markdown", 149 | "metadata": {}, 150 | "source": [ 151 | "The algorithm to sample a random function is:\n", 152 | "* for $b = 1... B$ do:\n", 153 | "\n", 154 | " 1. draw an independent realization $f_b\\colon \\mathcal{X} \\to \\mathcal{Y}$\n", 155 | " with from the process $\\{f_\\omega\\}_{\\omega \\sim q(\\omega)}$\n", 156 | " 2. get $\\hat{y}_{bi} = f_b(\\tilde{x}_i)$ for $i=1 .. 
m$\n" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": null, 162 | "metadata": {}, 163 | "outputs": [], 164 | "source": [ 165 | "from mlss2019bdl.bdl import freeze, unfreeze\n", 166 | "\n", 167 | "def sample_function(model, dataset, n_draws=1, verbose=False):\n", 168 | " \"\"\"Draw a realization of a random function.\"\"\"\n", 169 | " outputs = []\n", 170 | " for _ in tqdm.tqdm(range(n_draws), disable=not verbose):\n", 171 | " freeze(model)\n", 172 | "\n", 173 | " outputs.append(predict(model, dataset))\n", 174 | "\n", 175 | " unfreeze(model)\n", 176 | "\n", 177 | " return torch.stack(outputs, dim=0)" 178 | ] 179 | }, 180 | { 181 | "cell_type": "markdown", 182 | "metadata": {}, 183 | "source": [ 184 | "Sample the class probabilities $p(y_x = k \\mid x, \\omega, m)$\n", 185 | "with $\\omega \\sim q(\\omega)$ by a model that **outputs raw class\n", 186 | "logit scores**." 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": null, 192 | "metadata": {}, 193 | "outputs": [], 194 | "source": [ 195 | "def sample_proba(model, dataset, n_draws=1):\n", 196 | " logits = sample_function(model, dataset, n_draws=n_draws)\n", 197 | "\n", 198 | " return F.softmax(logits, dim=-1)" 199 | ] 200 | }, 201 | { 202 | "cell_type": "markdown", 203 | "metadata": {}, 204 | "source": [ 205 | "Get the predictive posterior class probabilities\n", 206 | "$$\n", 207 | "p(y_x = k \\mid x, m)\n", 208 | "% = \\mathbb{E}_{\\omega \\sim q(\\omega)}\n", 209 | "% p(y_x = k \\mid x, \\omega, m)\n", 210 | " \\approx \\frac1{\\lvert \\mathcal{W} \\rvert}\n", 211 | " \\sum_{\\omega \\in \\mathcal{W}}\n", 212 | " p(y_x = k \\mid x, \\omega, m)\n", 213 | " \\,, $$\n", 214 | "with $\\mathcal{W}$ -- iid draws from $q(\\omega)$." 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": null, 220 | "metadata": {}, 221 | "outputs": [], 222 | "source": [ 223 | "def predict_proba(model, dataset, n_draws=1):\n", 224 | " proba = sample_proba(model, dataset, n_draws=n_draws)\n", 225 | "\n", 226 | " return proba.mean(dim=0)" 227 | ] 228 | }, 229 | { 230 | "cell_type": "markdown", 231 | "metadata": {}, 232 | "source": [ 233 | "Gat the maximum a posteriori class label **(MAP)**: $\n", 234 | "\\hat{y}_x\n", 235 | " = \\arg \\max_k \\mathbb{E}_{\\omega \\sim q(\\omega)}\n", 236 | " p(y_x = k \\mid x, \\omega, m)\n", 237 | "$" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": null, 243 | "metadata": {}, 244 | "outputs": [], 245 | "source": [ 246 | "def predict_label(model, dataset, n_draws=1):\n", 247 | " proba = predict_proba(model, dataset, n_draws=n_draws)\n", 248 | "\n", 249 | " return proba.argmax(dim=-1)" 250 | ] 251 | }, 252 | { 253 | "cell_type": "markdown", 254 | "metadata": {}, 255 | "source": [ 256 | "We will need some functionality from scikit" 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": null, 262 | "metadata": {}, 263 | "outputs": [], 264 | "source": [ 265 | "from sklearn.metrics import confusion_matrix\n", 266 | "\n", 267 | "def evaluate(model, dataset, n_draws=1):\n", 268 | " assert isinstance(dataset, TensorDataset)\n", 269 | "\n", 270 | " predicted = predict_label(model, dataset, n_draws=n_draws)\n", 271 | "\n", 272 | " target = dataset.tensors[1].cpu().numpy()\n", 273 | " return confusion_matrix(target, predicted.cpu().numpy())" 274 | ] 275 | }, 276 | { 277 | "cell_type": "markdown", 278 | "metadata": {}, 279 | "source": [ 280 | "A function to plot images in a small dataset. 
" 281 | ] 282 | }, 283 | { 284 | "cell_type": "code", 285 | "execution_count": null, 286 | "metadata": {}, 287 | "outputs": [], 288 | "source": [ 289 | "from mlss2019bdl.flex import plot\n", 290 | "from torch.utils.data import TensorDataset\n", 291 | "from IPython.display import clear_output\n", 292 | "\n", 293 | "def display(images, n_col=None, title=None, figsize=None, refresh=False):\n", 294 | " if isinstance(images, TensorDataset):\n", 295 | " images, targets = images.tensors\n", 296 | " \n", 297 | " if refresh:\n", 298 | " clear_output(True)\n", 299 | "\n", 300 | " fig, ax = plt.subplots(1, 1, figsize=figsize)\n", 301 | " plot(ax, images, n_col=n_col, cmap=plt.cm.bone)\n", 302 | " if title is not None:\n", 303 | " ax.set_title(title)\n", 304 | "\n", 305 | " plt.show()\n", 306 | " plt.close()" 307 | ] 308 | }, 309 | { 310 | "cell_type": "markdown", 311 | "metadata": {}, 312 | "source": [ 313 | "
" 314 | ] 315 | }, 316 | { 317 | "cell_type": "markdown", 318 | "metadata": {}, 319 | "source": [ 320 | "## Bayesian Active Learning with images" 321 | ] 322 | }, 323 | { 324 | "cell_type": "markdown", 325 | "metadata": {}, 326 | "source": [ 327 | "* Data labelling is costly and time consuming\n", 328 | "* unlabeled instances are essentially free\n", 329 | "\n", 330 | "**Goal** Achieve high performance with fewer labels by\n", 331 | "identifying the best instances to learn from" 332 | ] 333 | }, 334 | { 335 | "cell_type": "markdown", 336 | "metadata": {}, 337 | "source": [ 338 | "Essential blocks of active learning:\n", 339 | "\n", 340 | "* a **model** $m$ capable of quantifying uncertainty (preferably a Bayesian model)\n", 341 | "* an **acquisition function** $a\\colon \\mathcal{M} \\times \\mathcal{X}^* \\to \\mathbb{R}$\n", 342 | " that for any finite set of inputs $S\\subset \\mathcal{X}$ quantifies their usefulness\n", 343 | " to the model $m\\in \\mathcal{M}$\n", 344 | "* a labelling **oracle**, e.g. a human expert" 345 | ] 346 | }, 347 | { 348 | "cell_type": "markdown", 349 | "metadata": {}, 350 | "source": [ 351 | "### The model" 352 | ] 353 | }, 354 | { 355 | "cell_type": "markdown", 356 | "metadata": {}, 357 | "source": [ 358 | "We reuse the `DropoutLinear` from the first part." 359 | ] 360 | }, 361 | { 362 | "cell_type": "code", 363 | "execution_count": null, 364 | "metadata": {}, 365 | "outputs": [], 366 | "source": [ 367 | "from torch.nn import Module, Sequential\n", 368 | "from torch.nn import AvgPool2d, LeakyReLU\n", 369 | "from torch.nn import Linear, Conv2d\n", 370 | "\n", 371 | "from mlss2019bdl.bdl import DropoutLinear, DropoutConv2d\n", 372 | "\n", 373 | "class MNISTModel(Module):\n", 374 | " def __init__(self, p=0.5):\n", 375 | " super().__init__()\n", 376 | "\n", 377 | " self.head = Sequential(\n", 378 | " Conv2d(1, 32, 3, 1),\n", 379 | " LeakyReLU(),\n", 380 | " DropoutConv2d(32, 64, 3, 1, p=p),\n", 381 | " LeakyReLU(),\n", 382 | " AvgPool2d(2),\n", 383 | " )\n", 384 | "\n", 385 | " self.tail = Sequential(\n", 386 | " DropoutLinear(12 * 12 * 64, 128, p=p),\n", 387 | " LeakyReLU(),\n", 388 | " DropoutLinear(128, 10, p=p),\n", 389 | " )\n", 390 | "\n", 391 | " def forward(self, input):\n", 392 | " \"\"\"Take images and compute their class logits.\"\"\"\n", 393 | " x = self.head(input)\n", 394 | " return self.tail(x.flatten(1))" 395 | ] 396 | }, 397 | { 398 | "cell_type": "markdown", 399 | "metadata": {}, 400 | "source": [ 401 | "
" 402 | ] 403 | }, 404 | { 405 | "cell_type": "markdown", 406 | "metadata": {}, 407 | "source": [ 408 | "### the Acquisition Function" 409 | ] 410 | }, 411 | { 412 | "cell_type": "markdown", 413 | "metadata": {}, 414 | "source": [ 415 | "There are many acquisition criteria (borrowed from [Gal17a](http://proceedings.mlr.press/v70/gal17a.html)):\n", 416 | "* Classification\n", 417 | " * Posterior predictive entropy\n", 418 | " * Posterior Mutual Information\n", 419 | " * Variance ratios\n", 420 | " * BALD\n", 421 | "\n", 422 | "* Regression\n", 423 | " * predictive variance\n", 424 | "\n", 425 | "... and there is always the baseline **random acquisition**" 426 | ] 427 | }, 428 | { 429 | "cell_type": "code", 430 | "execution_count": null, 431 | "metadata": {}, 432 | "outputs": [], 433 | "source": [ 434 | "random_state = np.random.RandomState(812_760_351)\n", 435 | "\n", 436 | "def random_acquisition(dataset, model, n_request=1, n_draws=1):\n", 437 | " indices = random_state.choice(len(dataset), size=n_request)\n", 438 | "\n", 439 | " return torch.from_numpy(indices).to(device)" 440 | ] 441 | }, 442 | { 443 | "cell_type": "markdown", 444 | "metadata": {}, 445 | "source": [ 446 | "
" 447 | ] 448 | }, 449 | { 450 | "cell_type": "markdown", 451 | "metadata": {}, 452 | "source": [ 453 | "### Data and the Oracle" 454 | ] 455 | }, 456 | { 457 | "cell_type": "markdown", 458 | "metadata": {}, 459 | "source": [ 460 | "Prepare the datasets from the `train` part of\n", 461 | "[MNIST](http://yann.lecun.com/exdb/mnist/)\n", 462 | "(or [Kuzushiji-MNIST](https://github.com/rois-codh/kmnist)):\n", 463 | "* ($\\mathcal{S}_\\mathrm{train}$) initial **training**: $30$ images\n", 464 | "* ($\\mathcal{S}_\\mathrm{valid}$) our **validation**:\n", 465 | " $5000$ images, stratified\n", 466 | "* ($\\mathcal{S}_\\mathrm{pool}$) acquisition **pool**:\n", 467 | " $5000$ of the unused images, skewed to class $0$\n", 468 | "\n", 469 | "The true test sample of MNIST is in $\\mathcal{S}_\\mathrm{test}$ -- we\n", 470 | "will use it to evaluate the final performance." 471 | ] 472 | }, 473 | { 474 | "cell_type": "code", 475 | "execution_count": null, 476 | "metadata": {}, 477 | "outputs": [], 478 | "source": [ 479 | "from mlss2019bdl.dataset import get_dataset\n", 480 | "\n", 481 | "S_train, S_pool, S_valid, S_test = get_dataset(\n", 482 | " n_train=30,\n", 483 | " n_valid=5000,\n", 484 | " n_pool=5000,\n", 485 | " name=\"MNIST\", # \"KMNIST\"\n", 486 | " path=\"./data\",\n", 487 | " random_state=722_257_201)" 488 | ] 489 | }, 490 | { 491 | "cell_type": "markdown", 492 | "metadata": {}, 493 | "source": [ 494 | "* `query_oracle(ix, D)` **request** the instances in `D` at the specified\n", 495 | " indices `ix` into a dataset and **remove** from them from `D`\n", 496 | "\n", 497 | "* `merge(*datasets, [out=])` merge the datasets, creting a new one, or replacing `out`" 498 | ] 499 | }, 500 | { 501 | "cell_type": "code", 502 | "execution_count": null, 503 | "metadata": {}, 504 | "outputs": [], 505 | "source": [ 506 | "from mlss2019bdl.dataset import collect as query_oracle" 507 | ] 508 | }, 509 | { 510 | "cell_type": "markdown", 511 | "metadata": {}, 512 | "source": [ 513 | "
" 514 | ] 515 | }, 516 | { 517 | "cell_type": "markdown", 518 | "metadata": {}, 519 | "source": [ 520 | "### the Active Learning loop" 521 | ] 522 | }, 523 | { 524 | "cell_type": "markdown", 525 | "metadata": {}, 526 | "source": [ 527 | "1. fit $m$ on $\\mathcal{S}_{\\mathrm{labelled}}$\n", 528 | "\n", 529 | "\n", 530 | "2. get exact (or approximate) $$\n", 531 | " \\mathcal{S}^* \\in \\arg \\max\\limits_{S \\subseteq \\mathcal{S}_\\mathrm{unlabelled}}\n", 532 | " a(m, S)\n", 533 | "$$ satisfying **budget constraints** and **without** access to targets\n", 534 | "(constraints, like $\\lvert S \\rvert \\leq \\ell$ or other economically motivated ones).\n", 535 | "\n", 536 | "\n", 537 | "3. request the **oracle** to provide labels for each $x\\in \\mathcal{S}^*$\n", 538 | "\n", 539 | "\n", 540 | "4. update $\n", 541 | "\\mathcal{S}_{\\mathrm{labelled}}\n", 542 | " \\leftarrow \\mathcal{S}^*\n", 543 | " \\cup \\mathcal{S}_{\\mathrm{labelled}}\n", 544 | "$ and goto 1." 545 | ] 546 | }, 547 | { 548 | "cell_type": "code", 549 | "execution_count": null, 550 | "metadata": {}, 551 | "outputs": [], 552 | "source": [ 553 | "import copy\n", 554 | "from mlss2019bdl.dataset import merge\n", 555 | "\n", 556 | "def active_learn(S_train,\n", 557 | " S_pool,\n", 558 | " S_valid,\n", 559 | " acquire_fn,\n", 560 | " n_budget=150,\n", 561 | " n_max_request=3,\n", 562 | " n_draws=11,\n", 563 | " n_epochs=200,\n", 564 | " p=0.5,\n", 565 | " weight_decay=1e-2):\n", 566 | "\n", 567 | " model = MNISTModel(p=p).to(device)\n", 568 | "\n", 569 | " scores, balances = [], []\n", 570 | " S_train, S_pool = copy.deepcopy(S_train), copy.deepcopy(S_pool)\n", 571 | " while True:\n", 572 | " # 1. fit on train\n", 573 | " l2_reg = weight_decay * (1 - p) / max(len(S_train), 1)\n", 574 | "\n", 575 | " model = fit(model, S_train, batch_size=32, criterion=\"cross_entropy\",\n", 576 | " weight_decay=l2_reg, n_epochs=n_epochs)\n", 577 | "\n", 578 | "\n", 579 | " # (optional) keep track of scores and plot the train dataset\n", 580 | " scores.append(evaluate(model, S_valid, n_draws))\n", 581 | " balances.append(np.bincount(S_train.tensors[1], minlength=10))\n", 582 | "\n", 583 | " accuracy = scores[-1].diagonal().sum() / scores[-1].sum()\n", 584 | " title = f\"(n_train) {len(S_train)} (Acc.) {accuracy:.1%}\"\n", 585 | " display(S_train, n_col=30, figsize=(15, 5), title=title, refresh=True)\n", 586 | "\n", 587 | "\n", 588 | " # 2-3. request new data from pool, if within budget\n", 589 | " n_request = min(n_budget - len(S_train), n_max_request)\n", 590 | " if n_request <= 0:\n", 591 | " break\n", 592 | "\n", 593 | " indices = acquire_fn(S_pool, model, n_request=n_request, n_draws=n_draws)\n", 594 | "\n", 595 | " # 4. update the train dataset\n", 596 | " S_requested = query_oracle(indices, S_pool)\n", 597 | " S_train = merge(S_train, S_requested)\n", 598 | "\n", 599 | " return model, S_train, np.stack(scores, axis=0), np.stack(balances, axis=0)" 600 | ] 601 | }, 602 | { 603 | "cell_type": "markdown", 604 | "metadata": {}, 605 | "source": [ 606 | "* `collect(ix, D)` **collect** the instances in `D` at the specified\n", 607 | " indices `ix` into a dataset and **remove** from them from `D`\n", 608 | "\n", 609 | "* `merge(*datasets, [out=])` merge the datasets, creting a new one, or replacing `out`" 610 | ] 611 | }, 612 | { 613 | "cell_type": "markdown", 614 | "metadata": {}, 615 | "source": [ 616 | "
" 617 | ] 618 | }, 619 | { 620 | "cell_type": "markdown", 621 | "metadata": {}, 622 | "source": [ 623 | "### The baseline" 624 | ] 625 | }, 626 | { 627 | "cell_type": "markdown", 628 | "metadata": {}, 629 | "source": [ 630 | "How powerful will our model with random acquisition get under a total budget of $150$ images?" 631 | ] 632 | }, 633 | { 634 | "cell_type": "code", 635 | "execution_count": null, 636 | "metadata": { 637 | "scrolled": false 638 | }, 639 | "outputs": [], 640 | "source": [ 641 | "baseline = active_learn(\n", 642 | " S_train,\n", 643 | " S_pool,\n", 644 | " S_valid,\n", 645 | " random_acquisition,\n", 646 | " n_draws=21,\n", 647 | " n_budget=150,\n", 648 | " n_max_request=3,\n", 649 | " n_epochs=200,\n", 650 | ")" 651 | ] 652 | }, 653 | { 654 | "cell_type": "markdown", 655 | "metadata": {}, 656 | "source": [ 657 | "Let's see the dynamics of the accuracy ..." 658 | ] 659 | }, 660 | { 661 | "cell_type": "code", 662 | "execution_count": null, 663 | "metadata": {}, 664 | "outputs": [], 665 | "source": [ 666 | "def accuracy(scores):\n", 667 | " tp = scores.diagonal(axis1=-2, axis2=-1)\n", 668 | " return tp.sum(-1) / scores.sum((-2, -1))" 669 | ] 670 | }, 671 | { 672 | "cell_type": "code", 673 | "execution_count": null, 674 | "metadata": {}, 675 | "outputs": [], 676 | "source": [ 677 | "model_rand, train_rand, scores_rand, balances_rand = baseline\n", 678 | "\n", 679 | "fig, ax = plt.subplots(1, 1, figsize=(12, 7))\n", 680 | "ax.plot(accuracy(scores_rand), label='Accuracy (random)', lw=2)\n", 681 | "\n", 682 | "ax.legend()\n", 683 | "plt.show()" 684 | ] 685 | }, 686 | { 687 | "cell_type": "markdown", 688 | "metadata": {}, 689 | "source": [ 690 | "..., and the frequency of each class in $\\mathcal{S}_\\mathrm{train}$." 691 | ] 692 | }, 693 | { 694 | "cell_type": "code", 695 | "execution_count": null, 696 | "metadata": {}, 697 | "outputs": [], 698 | "source": [ 699 | "fig, ax = plt.subplots(1, 1, figsize=(12, 7))\n", 700 | "\n", 701 | "lines = ax.plot(balances_rand, lw=2)\n", 702 | "plt.legend(lines, list(range(10)), ncol=2);" 703 | ] 704 | }, 705 | { 706 | "cell_type": "markdown", 707 | "metadata": {}, 708 | "source": [ 709 | "
" 710 | ] 711 | }, 712 | { 713 | "cell_type": "markdown", 714 | "metadata": {}, 715 | "source": [ 716 | "## Bayesian Active Learning by Disagreement" 717 | ] 718 | }, 719 | { 720 | "cell_type": "markdown", 721 | "metadata": {}, 722 | "source": [ 723 | "Bayesian Active Learning by Disagreement, or **BALD** criterion, is\n", 724 | "based on the posterior mutual information between model's predictions\n", 725 | "$y_x$ at some point $x$ and its parameters $\\omega$:\n", 726 | "\n", 727 | "$$\\begin{align}\n", 728 | " a(m, S)\n", 729 | " &= \\sum_{x\\in S} a(m, \\{x\\})\n", 730 | " \\\\\n", 731 | " a(m, \\{x\\})\n", 732 | " &= \\mathbb{I}(y_x; \\omega \\mid x, m, D)\n", 733 | "\\end{align}\n", 734 | " \\,, \\tag{bald} $$\n", 735 | "\n", 736 | "with the [**Mutual Information**](https://en.wikipedia.org/wiki/Mutual_information#Relation_to_Kullback%E2%80%93Leibler_divergence)\n", 737 | "(**MI**)\n", 738 | "$$\n", 739 | " \\mathbb{I}(y_x; \\omega \\mid x, m, D)\n", 740 | " = \\mathbb{H}\\bigl(\n", 741 | " \\mathbb{E}_{\\omega \\sim q(\\omega\\mid m, D)}\n", 742 | " p(y_x \\,\\mid\\, x, \\omega, m, D)\n", 743 | " \\bigr)\n", 744 | " - \\mathbb{E}_{\\omega \\sim q(\\omega\\mid m, D)}\n", 745 | " \\mathbb{H}\\bigl(\n", 746 | " p(y_x \\,\\mid\\, x, \\omega, m, D)\n", 747 | " \\bigr)\n", 748 | " \\,, \\tag{mi} $$\n", 749 | "\n", 750 | "and the [(differential) **entropy**](https://en.wikipedia.org/wiki/Differential_entropy#Differential_entropies_for_various_distributions)\n", 751 | "(all densities and/or probability mass functions can be conditional):\n", 752 | "\n", 753 | "$$\n", 754 | " \\mathbb{H}(p(y))\n", 755 | " = - \\mathbb{E}_{y\\sim p} \\log p(y)\n", 756 | " \\,. $$" 757 | ] 758 | }, 759 | { 760 | "cell_type": "markdown", 761 | "metadata": {}, 762 | "source": [ 763 | "
" 764 | ] 765 | }, 766 | { 767 | "cell_type": "markdown", 768 | "metadata": {}, 769 | "source": [ 770 | "#### (task) Implementing the acquisition function" 771 | ] 772 | }, 773 | { 774 | "cell_type": "markdown", 775 | "metadata": {}, 776 | "source": [ 777 | "Note that $a(m, S)$ is additively separable in $S$, i.e.\n", 778 | "equals $\\sum_{x\\in S} a(m, \\{x\\})$. This implies\n", 779 | "\n", 780 | "$$\n", 781 | "\\begin{align}\n", 782 | " \\max_{S \\subseteq \\mathcal{S}_\\mathrm{unlabelled}} a(m, S)\n", 783 | " &= \\max_{z \\in \\mathcal{S}_\\mathrm{unlabelled}}\n", 784 | " \\max_{F \\in \\mathcal{S}_\\mathrm{unlabelled} \\setminus \\{z\\}}\n", 785 | " \\sum_{x\\in F \\cup \\{x\\}} a(m, \\{x\\})\n", 786 | " \\\\\n", 787 | " &= \\max_{z \\in \\mathcal{S}_\\mathrm{unlabelled}}\n", 788 | " a(m, \\{z\\})\n", 789 | " + \\max_{F \\in \\mathcal{S}_\\mathrm{unlabelled} \\setminus \\{z\\}}\n", 790 | " \\sum_{x\\in F} a(m, \\{x\\})\n", 791 | "\\end{align}\n", 792 | " \\,. $$" 793 | ] 794 | }, 795 | { 796 | "cell_type": "markdown", 797 | "metadata": {}, 798 | "source": [ 799 | "Therefore selecting the $\\ell$ `most interesting` points from\n", 800 | "$\\mathcal{S}_\\mathrm{unlabelled}$ is trivial." 801 | ] 802 | }, 803 | { 804 | "cell_type": "markdown", 805 | "metadata": {}, 806 | "source": [ 807 | "The acquisition function that we implement has interface\n", 808 | "identical to `random_acquisition` but uses BALD to choose\n", 809 | "instances." 810 | ] 811 | }, 812 | { 813 | "cell_type": "code", 814 | "execution_count": null, 815 | "metadata": {}, 816 | "outputs": [], 817 | "source": [ 818 | "def BALD_acquisition(dataset, model, n_request=1, n_draws=1):\n", 819 | " ## Exercise: implement BALD\n", 820 | "\n", 821 | " proba = sample_proba(model, dataset, n_draws=n_draws)\n", 822 | "\n", 823 | " mi = mutual_information(proba)\n", 824 | "\n", 825 | " return mi.argsort(descending=True)[:n_request]\n", 826 | "\n", 827 | " pass" 828 | ] 829 | }, 830 | { 831 | "cell_type": "markdown", 832 | "metadata": {}, 833 | "source": [ 834 | "
" 835 | ] 836 | }, 837 | { 838 | "cell_type": "markdown", 839 | "metadata": {}, 840 | "source": [ 841 | "#### (task) implementing entropy" 842 | ] 843 | }, 844 | { 845 | "cell_type": "markdown", 846 | "metadata": {}, 847 | "source": [ 848 | "For categorical (discrete) random variables $y \\sim \\mathcal{Cat}(\\mathbf{p})$,\n", 849 | "$\\mathbf{p} \\in \\{ \\mu \\in [0, 1]^d \\colon \\sum_k \\mu_k = 1\\}$, the entropy is\n", 850 | "\n", 851 | "$$\n", 852 | " \\mathbb{H}(p(y))\n", 853 | " = - \\mathbb{E}_{y\\sim p(y)} \\log p(y)\n", 854 | " = - \\sum_k p_k \\log p_k\n", 855 | " \\,. $$" 856 | ] 857 | }, 858 | { 859 | "cell_type": "markdown", 860 | "metadata": {}, 861 | "source": [ 862 | "**(note)** although in calculus $0 \\cdot \\log 0 = 0$ (because\n", 863 | "$\\lim_{p\\downarrow 0} p \\cdot \\log p = 0$), in floating point\n", 864 | "arithmetic $0 \\cdot \\log 0 = \\mathrm{NaN}$. So you need to add\n", 865 | "some **really tiny float number** to the argument of $\\log$." 866 | ] 867 | }, 868 | { 869 | "cell_type": "code", 870 | "execution_count": null, 871 | "metadata": {}, 872 | "outputs": [], 873 | "source": [ 874 | "def categorical_entropy(proba):\n", 875 | " \"\"\"Compute the entropy along the last dimension.\"\"\"\n", 876 | "\n", 877 | " ## Exercise: the probabilities sum to one along the last axis.\n", 878 | " # Please, compute their entropy.\n", 879 | "\n", 880 | " return - torch.kl_div(torch.tensor(0.).to(proba), proba).sum(dim=-1)\n", 881 | "\n", 882 | " return - torch.sum(proba * torch.log(proba + 1e-20), dim=-1)\n", 883 | "\n", 884 | " pass" 885 | ] 886 | }, 887 | { 888 | "cell_type": "markdown", 889 | "metadata": {}, 890 | "source": [ 891 | "
" 892 | ] 893 | }, 894 | { 895 | "cell_type": "markdown", 896 | "metadata": {}, 897 | "source": [ 898 | "#### (task) implementing mutual information" 899 | ] 900 | }, 901 | { 902 | "cell_type": "markdown", 903 | "metadata": {}, 904 | "source": [ 905 | "Consider a tensor $p_{bik}$ of probabilities $p(y_{x_i}=k \\mid x_i, \\omega_b, m, D)$\n", 906 | "with $\\omega_b \\sim q(\\omega \\mid m, D)$ with $\\mathcal{W} = (\\omega_b)_{b=1}^B$\n", 907 | "being iid draws from $q(\\omega \\mid m, D)$.\n", 908 | "\n", 909 | "Let's implement a procedure that computes the Monte Carlo estimate of the\n", 910 | "posterior predictive distribution, its **entropy** and **mutual information**\n", 911 | "\n", 912 | "$$\n", 913 | " \\mathbb{I}_\\mathrm{MC}(y_x; \\omega \\mid x, m, D)\n", 914 | " = \\mathbb{H}\\bigl(\n", 915 | " \\hat{p}(y_x\\mid x, m, D)\n", 916 | " \\bigr)\n", 917 | " - \\frac1{\\lvert \\mathcal{W} \\rvert} \\sum_{\\omega\\in \\mathcal{W}}\n", 918 | " \\mathbb{H}\\bigl(\n", 919 | " p(y_x \\,\\mid\\, x, \\omega, m, D)\n", 920 | " \\bigr)\n", 921 | " \\,, \\tag{mi-mc} $$\n", 922 | "where\n", 923 | "$$\n", 924 | "\\hat{p}(y_x\\mid x, m, D)\n", 925 | " = \\frac1{\\lvert \\mathcal{W} \\rvert} \\sum_{\\omega\\in \\mathcal{W}}\n", 926 | " \\,p(y_x \\mid x, \\omega, m, D)\n", 927 | " \\,. $$" 928 | ] 929 | }, 930 | { 931 | "cell_type": "code", 932 | "execution_count": null, 933 | "metadata": {}, 934 | "outputs": [], 935 | "source": [ 936 | "def mutual_information(proba):\n", 937 | " ## Exercise: compute a Monte Carlo estimator of the predictive\n", 938 | " ## distribution, its entropy and MI `H E_w p(., w) - E_w H p(., w)`\n", 939 | "\n", 940 | " entropy_expected = categorical_entropy(proba.mean(dim=0))\n", 941 | " expected_entropy = categorical_entropy(proba).mean(dim=0)\n", 942 | "\n", 943 | " return entropy_expected - expected_entropy\n", 944 | "\n", 945 | " pass" 946 | ] 947 | }, 948 | { 949 | "cell_type": "markdown", 950 | "metadata": {}, 951 | "source": [ 952 | "
" 953 | ] 954 | }, 955 | { 956 | "cell_type": "markdown", 957 | "metadata": {}, 958 | "source": [ 959 | "How powerful will our model with **BALD** acquisition, if we can afford no more than $150$ images?" 960 | ] 961 | }, 962 | { 963 | "cell_type": "code", 964 | "execution_count": null, 965 | "metadata": { 966 | "scrolled": false 967 | }, 968 | "outputs": [], 969 | "source": [ 970 | "bald_results = active_learn(\n", 971 | " S_train,\n", 972 | " S_pool,\n", 973 | " S_valid,\n", 974 | " BALD_acquisition,\n", 975 | " n_draws=21,\n", 976 | " n_budget=150,\n", 977 | " n_max_request=3,\n", 978 | " n_epochs=200,\n", 979 | ")" 980 | ] 981 | }, 982 | { 983 | "cell_type": "markdown", 984 | "metadata": {}, 985 | "source": [ 986 | "Let's see the dynamics of the accuracy ..." 987 | ] 988 | }, 989 | { 990 | "cell_type": "code", 991 | "execution_count": null, 992 | "metadata": {}, 993 | "outputs": [], 994 | "source": [ 995 | "model_bald, train_bald, scores_bald, balances_bald = bald_results\n", 996 | "\n", 997 | "fig, ax = plt.subplots(1, 1, figsize=(12, 7))\n", 998 | "\n", 999 | "ax.plot(accuracy(scores_rand), label='Accuracy (random)', lw=2)\n", 1000 | "ax.plot(accuracy(scores_bald), label='Accuracy (BALD)', lw=2)\n", 1001 | "\n", 1002 | "ax.legend()\n", 1003 | "plt.show()" 1004 | ] 1005 | }, 1006 | { 1007 | "cell_type": "markdown", 1008 | "metadata": {}, 1009 | "source": [ 1010 | "..., and the frequency of each class in $\\mathcal{S}_\\mathrm{train}$." 1011 | ] 1012 | }, 1013 | { 1014 | "cell_type": "code", 1015 | "execution_count": null, 1016 | "metadata": {}, 1017 | "outputs": [], 1018 | "source": [ 1019 | "fig, ax = plt.subplots(1, 1, figsize=(12, 7))\n", 1020 | "\n", 1021 | "lines = ax.plot(balances_bald, lw=2)\n", 1022 | "plt.legend(lines, list(range(10)), ncol=2);" 1023 | ] 1024 | }, 1025 | { 1026 | "cell_type": "markdown", 1027 | "metadata": {}, 1028 | "source": [ 1029 | "
" 1030 | ] 1031 | }, 1032 | { 1033 | "cell_type": "markdown", 1034 | "metadata": {}, 1035 | "source": [ 1036 | "#### Class performance" 1037 | ] 1038 | }, 1039 | { 1040 | "cell_type": "markdown", 1041 | "metadata": {}, 1042 | "source": [ 1043 | "The *one-versus-rest* precision / recall scores on\n", 1044 | "$\\mathcal{S}_\\mathrm{valid}$. For binary classification:\n", 1045 | "\n", 1046 | "$$ \\begin{align}\n", 1047 | "\\mathrm{Precision}\n", 1048 | " &= \\frac{\\mathrm{TP}}{\\mathrm{TP} + \\mathrm{FP}}\n", 1049 | " \\approx \\mathbb{P}(y = 1 \\mid \\hat{y} = 1)\n", 1050 | " \\,, \\\\\n", 1051 | "\\mathrm{Recall}\n", 1052 | " &= \\frac{\\mathrm{TP}}{\\mathrm{TP} + \\mathrm{FN}}\n", 1053 | " \\approx \\mathbb{P}(\\hat{y} = 1 \\mid y = 1)\n", 1054 | " \\,.\n", 1055 | "\\end{align}$$" 1056 | ] 1057 | }, 1058 | { 1059 | "cell_type": "code", 1060 | "execution_count": null, 1061 | "metadata": {}, 1062 | "outputs": [], 1063 | "source": [ 1064 | "import pandas as pd\n", 1065 | "\n", 1066 | "def pr_scores(score_matrix):\n", 1067 | " tp = score_matrix.diagonal(axis1=-2, axis2=-1)\n", 1068 | " fp, fn = score_matrix.sum(axis=-2) - tp, score_matrix.sum(axis=-1) - tp\n", 1069 | " \n", 1070 | " return pd.DataFrame({\n", 1071 | " \"precision\": {l: f\"{p:.2%}\" for l, p in enumerate(tp / (tp + fp))},\n", 1072 | " \"recall\": {l: f\"{p:.2%}\" for l, p in enumerate(tp / (tp + fn))},\n", 1073 | " })" 1074 | ] 1075 | }, 1076 | { 1077 | "cell_type": "markdown", 1078 | "metadata": {}, 1079 | "source": [ 1080 | "Let's see the performance on the test set" 1081 | ] 1082 | }, 1083 | { 1084 | "cell_type": "code", 1085 | "execution_count": null, 1086 | "metadata": {}, 1087 | "outputs": [], 1088 | "source": [ 1089 | "scores = {}\n", 1090 | "scores[\"rand\"] = evaluate(model_rand, S_test, n_draws=21)\n", 1091 | "scores[\"bald\"] = evaluate(model_bald, S_test, n_draws=21)" 1092 | ] 1093 | }, 1094 | { 1095 | "cell_type": "markdown", 1096 | "metadata": {}, 1097 | "source": [ 1098 | "
" 1099 | ] 1100 | }, 1101 | { 1102 | "cell_type": "code", 1103 | "execution_count": null, 1104 | "metadata": {}, 1105 | "outputs": [], 1106 | "source": [ 1107 | "df = pd.concat({\n", 1108 | " name: pr_scores(score)\n", 1109 | " for name, score in scores.items()\n", 1110 | "}, axis=1).T\n", 1111 | "\n", 1112 | "df.swaplevel().sort_index()" 1113 | ] 1114 | }, 1115 | { 1116 | "cell_type": "markdown", 1117 | "metadata": {}, 1118 | "source": [ 1119 | "
" 1120 | ] 1121 | }, 1122 | { 1123 | "cell_type": "markdown", 1124 | "metadata": {}, 1125 | "source": [ 1126 | "#### Question(s) (to work on in your spare time)\n", 1127 | "\n", 1128 | "* Run the experiments on the `KMNIST` dataset\n", 1129 | "\n", 1130 | "* Replicate figure 1 from [Gat et al. (2017): p. 4](http://proceedings.mlr.press/v70/gal17a.html).\n", 1131 | " You will need to re-run each experiment several times $11$, recording\n", 1132 | " the accuracy dynamics of each, then compare the mean and $25\\%$-$75\\%$\n", 1133 | " quantiles as they evolve with the size of the training sample." 1134 | ] 1135 | }, 1136 | { 1137 | "cell_type": "markdown", 1138 | "metadata": {}, 1139 | "source": [ 1140 | "
" 1141 | ] 1142 | }, 1143 | { 1144 | "cell_type": "markdown", 1145 | "metadata": {}, 1146 | "source": [ 1147 | "### (optional) Points of improvement: batch-vs-single" 1148 | ] 1149 | }, 1150 | { 1151 | "cell_type": "markdown", 1152 | "metadata": {}, 1153 | "source": [ 1154 | "A drawback of the `pointwise` top-$\\ell$ procedure above is that, although\n", 1155 | "it acquires individually informative instances, altogether they might end\n", 1156 | "up **being** `jointly poorly informative`. This can be corrected, if we\n", 1157 | "would seek the highest mutual information among finite sets $\n", 1158 | "S \\subseteq \\mathcal{S}_\\mathrm{unlabelled}\n", 1159 | "$ of size $\\ell$." 1160 | ] 1161 | }, 1162 | { 1163 | "cell_type": "markdown", 1164 | "metadata": {}, 1165 | "source": [ 1166 | "Such acquisition function is called **batch-BALD**\n", 1167 | "([Kirsch et al.; 2019](https://arxiv.org/abs/1906.08158.pdf)):\n", 1168 | "\n", 1169 | "$$\\begin{align}\n", 1170 | " a(m, S)\n", 1171 | " &= \\mathbb{I}\\bigl((y_x)_{x\\in S}; \\omega \\mid S, m \\bigr)\n", 1172 | " = \\mathbb{H} \\bigl(\n", 1173 | " \\mathbb{E}_{\\omega \\sim q(\\omega\\mid m)} p\\bigl((y_x)_{x\\in S}\\mid S, \\omega, m \\bigr)\n", 1174 | " \\bigr)\n", 1175 | " - \\mathbb{E}_{\\omega \\sim q(\\omega\\mid m)} H\\bigl(\n", 1176 | " p\\bigl((y_x)_{x\\in S}\\mid S, \\omega, m \\bigr)\n", 1177 | " \\bigr)\n", 1178 | "\\end{align}\n", 1179 | " \\,. \\tag{batch-bald} $$" 1180 | ] 1181 | }, 1182 | { 1183 | "cell_type": "markdown", 1184 | "metadata": {}, 1185 | "source": [ 1186 | "This criterion requires combinatorially growing number of computations and\n", 1187 | "memory, however there are working solutions like random sampling of subsets\n", 1188 | "$\\mathcal{S}$ from $\\mathcal{S}_\\mathrm{unlabelled}$ or greedily maximizing\n", 1189 | "of this **submodular** criterion." 1190 | ] 1191 | }, 1192 | { 1193 | "cell_type": "markdown", 1194 | "metadata": {}, 1195 | "source": [ 1196 | "
" 1197 | ] 1198 | }, 1199 | { 1200 | "cell_type": "markdown", 1201 | "metadata": {}, 1202 | "source": [ 1203 | "### (optional) Points of improvement: bias" 1204 | ] 1205 | }, 1206 | { 1207 | "cell_type": "markdown", 1208 | "metadata": {}, 1209 | "source": [ 1210 | "The first term in the **MC** estimate of the mutual information is the\n", 1211 | "so-called **plug-in** estimator of the entropy:\n", 1212 | "\n", 1213 | "$$\n", 1214 | " \\hat{H}\n", 1215 | " = \\mathbb{H}(\\hat{p}) = - \\sum_k \\hat{p}_k \\log \\hat{p}_k\n", 1216 | " \\,, $$\n", 1217 | "\n", 1218 | "where $\\hat{p}_k = \\tfrac1B \\sum_b p_{bk}$ is the full sample estimator\n", 1219 | "of the probabilities." 1220 | ] 1221 | }, 1222 | { 1223 | "cell_type": "markdown", 1224 | "metadata": {}, 1225 | "source": [ 1226 | "It is known that this plug-in estimate is biased\n", 1227 | "(see [blog: Nowozin, 2015](http://www.nowozin.net/sebastian/blog/estimating-discrete-entropy-part-1.html)\n", 1228 | "and references therein, also this [notebook](https://colab.research.google.com/drive/1z9ZDNM6NFmuFnU28d8UO0Qymbd2LiNJW)). \n", 1229 | "In order to correct for small-sample bias we can use\n", 1230 | "[jackknife resampling](https://en.wikipedia.org/wiki/Jackknife_resampling).\n", 1231 | "It derives an estimate of the finite sample bias from the leave-one-out\n", 1232 | "estimators of the entropy and is relatively computationally cheap\n", 1233 | "(see [blog: Nowozin, 2015](http://www.nowozin.net/sebastian/blog/estimating-discrete-entropy-part-2.html),\n", 1234 | "[Miller, R. G. (1974)](http://www.math.ntu.edu.tw/~hchen/teaching/LargeSample/references/Miller74jackknife.pdf) and these [notes](http://people.bu.edu/aimcinto/jackknife.pdf)).\n", 1235 | "\n", 1236 | "The jackknife correction of a plug-in estimator $\\mathbb{H}(\\cdot)$\n", 1237 | "is computed thus: given a sample $(p_b)_{b=1}^B$ with $p_b$ -- discrete distribution on $1..K$\n", 1238 | "* for each $b=1.. B$\n", 1239 | " * get the leave-one-out estimator: $\\hat{p}_k^{-b} = \\tfrac1{B-1} \\sum_{j\\neq b} p_{jk}$\n", 1240 | " * compute the plug-in entropy estimator: $\\hat{H}_{-b} = \\mathbb{H}(\\hat{p}^{-b})$\n", 1241 | "* then compute the bias-corrected entropy estimator $\n", 1242 | "\\hat{H}_J\n", 1243 | " = \\hat{H} + (B - 1) \\bigl\\{\n", 1244 | " \\hat{H} - \\tfrac1B \\sum_b \\hat{H}^{-b}\n", 1245 | " \\bigr\\}\n", 1246 | "$" 1247 | ] 1248 | }, 1249 | { 1250 | "cell_type": "markdown", 1251 | "metadata": {}, 1252 | "source": [ 1253 | "**(note)** when we knock the $i$-th data point out of the sample mean\n", 1254 | "$\\mu = \\tfrac1n \\sum_i x_i$ and recompute the mean $\\mu_{-i}$ we get\n", 1255 | "the following relation\n", 1256 | "$$ \\mu_{-i}\n", 1257 | " = \\frac1{n-1} \\sum_{j\\neq i} x_j\n", 1258 | " = \\frac{n}{n-1} \\mu - \\tfrac1{n-1} x_i\n", 1259 | " = \\mu + \\frac{\\mu - x_i}{n-1}\n", 1260 | " \\,. $$\n", 1261 | "This makes it possible to quickly compute leave-one-out estimators of\n", 1262 | "discrete probability distribution." 1263 | ] 1264 | }, 1265 | { 1266 | "cell_type": "markdown", 1267 | "metadata": {}, 1268 | "source": [ 1269 | "#### (task*) Unbiased estimator of entropy and mutual information\n", 1270 | "\n", 1271 | "Try to efficiently implement a bias-corrected acquisition\n", 1272 | "function, and see it is worth the effort." 
1273 | ] 1274 | }, 1275 | { 1276 | "cell_type": "code", 1277 | "execution_count": null, 1278 | "metadata": {}, 1279 | "outputs": [], 1280 | "source": [ 1281 | "def BALD_jknf_acquisition(dataset, model, n_request=1, n_draws=1):\n", 1282 | " proba = sample_proba(model, dataset, n_draws=n_draws)\n", 1283 | "\n", 1284 | " ## Exercise: MC estimate of the predictive distribution, entropy and MI\n", 1285 | " ## mutual information `H E_w p(., w) - E_w H p(., w)` with jackknife\n", 1286 | " ## correction.\n", 1287 | "\n", 1288 | " # plug-in estimate of entropy \n", 1289 | " proba_avg = proba.mean(dim=0)\n", 1290 | " entropy_expected = categorical_entropy(proba_avg)\n", 1291 | "\n", 1292 | " # jackknife correction\n", 1293 | " proba_loo = proba_avg + (proba_avg - proba) / (len(proba) - 1)\n", 1294 | " expected_entropy_loo = categorical_entropy(proba_loo).mean(dim=0)\n", 1295 | " entropy_expected += (len(proba) - 1) * (entropy_expected - expected_entropy_loo)\n", 1296 | "\n", 1297 | " mi = entropy_expected - categorical_entropy(proba).mean(dim=0)\n", 1298 | "\n", 1299 | " return mi.argsort(descending=True)[:n_request]" 1300 | ] 1301 | }, 1302 | { 1303 | "cell_type": "markdown", 1304 | "metadata": {}, 1305 | "source": [ 1306 | "
" 1307 | ] 1308 | }, 1309 | { 1310 | "cell_type": "markdown", 1311 | "metadata": {}, 1312 | "source": [ 1313 | "Let's see ..." 1314 | ] 1315 | }, 1316 | { 1317 | "cell_type": "code", 1318 | "execution_count": null, 1319 | "metadata": { 1320 | "scrolled": false 1321 | }, 1322 | "outputs": [], 1323 | "source": [ 1324 | "jknf_results = active_learn(\n", 1325 | " S_train,\n", 1326 | " S_pool,\n", 1327 | " S_valid,\n", 1328 | " BALD_jknf_acquisition,\n", 1329 | " n_draws=21,\n", 1330 | " n_budget=150,\n", 1331 | " n_max_request=3,\n", 1332 | " n_epochs=200,\n", 1333 | ")" 1334 | ] 1335 | }, 1336 | { 1337 | "cell_type": "code", 1338 | "execution_count": null, 1339 | "metadata": {}, 1340 | "outputs": [], 1341 | "source": [ 1342 | "fig, ax = plt.subplots(1, 1, figsize=(12, 7))\n", 1343 | "\n", 1344 | "model_jknf, train_jknf, scores_jknf, balances_jknf = jknf_results\n", 1345 | "ax.plot(accuracy(scores_rand), label='Accuracy (random)', lw=2)\n", 1346 | "ax.plot(accuracy(scores_bald), label='Accuracy (BALD)', lw=2)\n", 1347 | "ax.plot(accuracy(scores_jknf), label='Accuracy (BALD-jknf)', lw=2)\n", 1348 | "\n", 1349 | "ax.legend()\n", 1350 | "plt.show()" 1351 | ] 1352 | }, 1353 | { 1354 | "cell_type": "code", 1355 | "execution_count": null, 1356 | "metadata": {}, 1357 | "outputs": [], 1358 | "source": [ 1359 | "fig, ax = plt.subplots(1, 1, figsize=(12, 7))\n", 1360 | "\n", 1361 | "lines = ax.plot(balances_jknf, lw=2)\n", 1362 | "plt.legend(lines, list(range(10)), ncol=2);" 1363 | ] 1364 | }, 1365 | { 1366 | "cell_type": "markdown", 1367 | "metadata": {}, 1368 | "source": [ 1369 | "
" 1370 | ] 1371 | } 1372 | ], 1373 | "metadata": { 1374 | "kernelspec": { 1375 | "display_name": "Python 3", 1376 | "language": "python", 1377 | "name": "python3" 1378 | }, 1379 | "language_info": { 1380 | "codemirror_mode": { 1381 | "name": "ipython", 1382 | "version": 3 1383 | }, 1384 | "file_extension": ".py", 1385 | "mimetype": "text/x-python", 1386 | "name": "python", 1387 | "nbconvert_exporter": "python", 1388 | "pygments_lexer": "ipython3", 1389 | "version": "3.7.3" 1390 | } 1391 | }, 1392 | "nbformat": 4, 1393 | "nbformat_minor": 2 1394 | } 1395 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MLSS2019: Bayesian Deep Learning 2 | 3 | ## Installation: colab 4 | 5 | In Google colab there is no need to clone the repo or preinstall anything -- 6 | all jupyter runtimes come with the basic packages like numpy, scipy, and 7 | matplotlib and deep learning libraries keras, tensorflow, and pytorch. 8 | 9 | The only step to make is to change the runtime type to GPU in 10 | **Edit > Notebook settings or Runtime>Change runtime type** by selecting 11 | **GPU as Hardware accelerator**. 12 | 13 | 14 | ## Installation: local install 15 | 16 | Please make sure that you have the following packages installed: 17 | * tqdm 18 | * numpy 19 | * torch >= 1.1 20 | 21 | The most convenient way to ensure this is use Anaconda with python 3.7. 22 | 23 | When all prerequisites have been met, please, clone this repository and 24 | install it with: 25 | 26 | ```bash 27 | git clone https://github.com/ivannz/mlss2019-bayesian-deep-learning.git 28 | 29 | cd mlss2019-bayesian-deep-learning 30 | 31 | pip install --editable . 32 | ``` 33 | 34 | This will install the necessary service python code that will make the 35 | seminar much more concise and, hopefully, your learning experience better. 36 | 37 | 38 | ## Versions 39 | 40 | The version presented at MLSS Moscow Aug 26 - Sep 5, 2019, can also be found 41 | in the [MLSS2019](https://github.com/mlss-skoltech/) repo. Here it sits under 42 | the tag `mlss2019-Aug-30`. 43 | -------------------------------------------------------------------------------- /mlss2019bdl/__init__.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | import torch 3 | 4 | import torch.nn.functional as F 5 | 6 | from torch.utils.data import TensorDataset, DataLoader 7 | 8 | 9 | def dataset_from_numpy(*ndarrays, device=None, dtype=torch.float32): 10 | """Create :class:`TensorDataset` from the passed :class:`numpy.ndarray`-s. 11 | 12 | Each returned tensor in the TensorDataset and :attr:`ndarray` share 13 | the same memory, unless a type cast or device transfer took place. 14 | Modifications to any tensor in the dataset will be reflected in respective 15 | :attr:`ndarray` and vice versa. 16 | 17 | Each returned tensor in the dataset is not resizable. 18 | 19 | See Also 20 | -------- 21 | torch.from_numpy : create a tensor from an ndarray. 
22 | """ 23 | tensors = map(torch.from_numpy, ndarrays) 24 | 25 | return TensorDataset(*[t.to(device, dtype) for t in tensors]) 26 | 27 | 28 | default_criteria = { 29 | "cross_entropy": 30 | lambda model, X, y: F.cross_entropy(model(X), y, reduction="mean"), 31 | "mse": 32 | lambda model, X, y: 0.5 * F.mse_loss(model(X), y, reduction="mean"), 33 | } 34 | 35 | 36 | def fit(model, dataset, criterion="mse", batch_size=32, 37 | n_epochs=1, weight_decay=0, verbose=False): 38 | """Fit the model with SGD (Adam) on the specified dataset and criterion. 39 | 40 | This bare minimum of a fit loop creates a minibatch generator from 41 | the `dataset` with batches of size `batch_size`. On each batch it 42 | computes the backward pass through the `criterion` and the `model` 43 | and updates the `model`-s parameters with the Adam optimizer step. 44 | The loop passes through the dataset `n_epochs` times. It does not 45 | output any running debugging information, except for a progress bar. 46 | 47 | The criterion can be either "mse" for mean sqaured error, "nll" for 48 | negative loglikelihood (categorical), or a callable taking `model, X, y` 49 | as arguments. 50 | """ 51 | if len(dataset) <= 0 or batch_size <= 0: 52 | return model 53 | 54 | criterion = default_criteria.get(criterion, criterion) 55 | assert callable(criterion) 56 | 57 | # get the model's device 58 | device = next(model.parameters()).device 59 | 60 | # an optimizer for model's parameters 61 | optim = torch.optim.Adam(model.parameters(), lr=2e-3, 62 | weight_decay=weight_decay) 63 | 64 | # stochastic minibatch generator for the training loop 65 | feed = DataLoader(dataset, shuffle=True, batch_size=batch_size) 66 | for epoch in tqdm.tqdm(range(n_epochs), disable=not verbose): 67 | 68 | model.train() 69 | 70 | for X, y in feed: 71 | # forward pass through the criterion (batch-average loss) 72 | loss = criterion(model, X.to(device), y.to(device)) 73 | 74 | # get gradients with backward pass 75 | optim.zero_grad() 76 | loss.backward() 77 | 78 | # SGD update 79 | optim.step() 80 | 81 | return model 82 | 83 | 84 | def predict(model, dataset, batch_size=512): 85 | """Get model's output on the dataset. 86 | 87 | This straightforward function switches the model into `evaluation` 88 | regime, computes the forward pass on the `dataset` (in batches of 89 | size `batch_size`) and stacks the results into a tensor on the `cpu`. 90 | It temporarily disables `autograd` to gain some speed-up. 
91 |     """
92 |     model.eval()
93 | 
94 |     # get the model's device
95 |     device = next(model.parameters()).device
96 | 
97 |     # batch generator for the evaluation loop
98 |     feed = DataLoader(dataset, batch_size=batch_size, shuffle=False)
99 | 
100 |     # compute and collect the outputs
101 |     with torch.no_grad():
102 |         return torch.cat([
103 |             model(X.to(device)).cpu() for X, *rest in feed
104 |         ], dim=0)
105 | 
--------------------------------------------------------------------------------
/mlss2019bdl/bdl/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import freeze, unfreeze, penalties
2 | 
3 | 
4 | from .bernoulli import DropoutLinear, DropoutConv2d
5 | from .gaussian import GaussianLinearARD, GaussianConv2dARD
--------------------------------------------------------------------------------
/mlss2019bdl/bdl/base.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | from torch.nn import Module
4 | 
5 | 
6 | class FreezableWeight(Module):
7 |     def __init__(self):
8 |         super().__init__()
9 |         self.unfreeze()
10 | 
11 |     def unfreeze(self):
12 |         self.register_buffer("frozen_weight", None)
13 | 
14 |     def is_frozen(self):
15 |         """Check if a frozen weight is available."""
16 |         return isinstance(self.frozen_weight, torch.Tensor)
17 | 
18 |     def freeze(self):
19 |         """Sample from the distribution and freeze."""
20 |         raise NotImplementedError()
21 | 
22 | 
23 | def freeze(module):
24 |     for mod in module.modules():
25 |         if isinstance(mod, FreezableWeight):
26 |             mod.freeze()
27 | 
28 |     return module # return self
29 | 
30 | 
31 | def unfreeze(module):
32 |     for mod in module.modules():
33 |         if isinstance(mod, FreezableWeight):
34 |             mod.unfreeze()
35 | 
36 |     return module # return self
37 | 
38 | 
39 | class PenalizedWeight(Module):
40 |     def penalty(self):
41 |         raise NotImplementedError()
42 | 
43 | 
44 | def penalties(module):
45 |     for mod in module.modules():
46 |         if isinstance(mod, PenalizedWeight):
47 |             yield mod.penalty()
48 | 
--------------------------------------------------------------------------------
/mlss2019bdl/bdl/bernoulli.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | 
4 | from torch.nn import Linear, Conv2d
5 | 
6 | from .base import FreezableWeight, PenalizedWeight
7 | 
8 | 
9 | class DropoutLinear(Linear, FreezableWeight):
10 |     """Linear layer with dropout on inputs."""
11 |     def __init__(self, in_features, out_features, bias=True, p=0.5):
12 |         super().__init__(in_features, out_features, bias=bias)
13 | 
14 |         self.p = p
15 | 
16 |     def forward(self, input):
17 |         if self.is_frozen():
18 |             return F.linear(input, self.frozen_weight, self.bias)
19 | 
20 |         return super().forward(F.dropout(input, self.p, True))
21 | 
22 |     def freeze(self):
23 |         # let's draw the new weight
24 |         with torch.no_grad():
25 |             prob = torch.full_like(self.weight[:1, :], 1 - self.p)
26 |             feature_mask = torch.bernoulli(prob) / prob
27 | 
28 |             frozen_weight = self.weight * feature_mask
29 | 
30 |         # and store it
31 |         self.register_buffer("frozen_weight", frozen_weight)
32 | 
33 | 
34 | class DropoutConv2d(Conv2d, FreezableWeight):
35 |     """2d Convolutional layer with dropout on input features."""
36 |     def __init__(self, in_channels, out_channels, kernel_size, stride=1,
37 |                  padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros',
38 |                  p=0.5):
39 | 
40 |         super().__init__(in_channels, out_channels, kernel_size, stride=stride,
41 |                          padding=padding, dilation=dilation, groups=groups,
42 |                          bias=bias, padding_mode=padding_mode)
43 | 
44 |         self.p = p
45 | 
46 |     def forward(self, input):
47 |         """Apply feature dropout and then forward pass through the convolution."""
48 |         if self.is_frozen():
49 |             return F.conv2d(input, self.frozen_weight, self.bias, self.stride,
50 |                             self.padding, self.dilation, self.groups)
51 | 
52 |         return super().forward(F.dropout2d(input, self.p, True))
53 | 
54 |     def freeze(self):
55 |         """Sample the weight from the parameter distribution and freeze it."""
56 |         prob = torch.full_like(self.weight[:1, :, :1, :1], 1 - self.p)
57 |         feature_mask = torch.bernoulli(prob) / prob
58 | 
59 |         with torch.no_grad():
60 |             frozen_weight = self.weight * feature_mask
61 | 
62 |         self.register_buffer("frozen_weight", frozen_weight)
63 | 
64 | 
--------------------------------------------------------------------------------
/mlss2019bdl/bdl/gaussian.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | 
4 | from torch.nn import Linear, Conv2d
5 | 
6 | from .base import FreezableWeight, PenalizedWeight
7 | 
8 | 
9 | class BaseGaussianLinear(Linear, FreezableWeight, PenalizedWeight):
10 |     """Linear layer with Gaussian Mean Field weight distribution."""
11 |     def __init__(self, in_features, out_features, bias=True):
12 |         super().__init__(in_features, out_features, bias=bias)
13 | 
14 |         self.log_sigma2 = torch.nn.Parameter(
15 |             torch.Tensor(*self.weight.shape))
16 | 
17 |         self.reset_variational_parameters()
18 | 
19 |     def reset_variational_parameters(self):
20 |         self.log_sigma2.data.normal_(-5, 0.1) # from arxiv:1811.00596
21 | 
22 |     def forward(self, input):
23 |         """Forward pass for the linear layer with the local reparameterization trick."""
24 | 
25 |         if self.is_frozen():
26 |             return F.linear(input, self.frozen_weight, self.bias)
27 | 
28 |         s2 = F.linear(input * input, torch.exp(self.log_sigma2), None)
29 | 
30 |         mu = super().forward(input)
31 |         return mu + torch.randn_like(s2) * torch.sqrt(torch.clamp(s2, 1e-8))
32 | 
33 |     def freeze(self):
34 | 
35 |         with torch.no_grad():
36 |             stdev = torch.exp(0.5 * self.log_sigma2)
37 |             weight = torch.normal(self.weight, std=stdev)
38 | 
39 |         self.register_buffer("frozen_weight", weight)
40 | 
41 | 
42 | class BaseGaussianConv2d(Conv2d, PenalizedWeight, FreezableWeight):
43 |     """Convolutional layer with Gaussian Mean Field weight distribution."""
44 | 
45 |     def __init__(self, in_channels, out_channels, kernel_size, stride=1,
46 |                  padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'):
47 |         super().__init__(in_channels, out_channels, kernel_size, stride=stride,
48 |                          padding=padding, dilation=dilation, groups=groups,
49 |                          bias=bias, padding_mode=padding_mode)
50 | 
51 |         self.log_sigma2 = torch.nn.Parameter(
52 |             torch.Tensor(*self.weight.shape))
53 | 
54 |         self.reset_variational_parameters()
55 | 
56 |     reset_variational_parameters = BaseGaussianLinear.reset_variational_parameters
57 | 
58 |     def forward(self, input):
59 |         """Forward pass with the local reparameterization trick."""
60 |         if self.is_frozen():
61 |             return F.conv2d(input, self.frozen_weight, self.bias, self.stride,
62 |                             self.padding, self.dilation, self.groups)
63 | 
64 |         s2 = F.conv2d(input * input, torch.exp(self.log_sigma2), None,
65 |                       self.stride, self.padding, self.dilation, self.groups)
66 | 
67 |         mu = super().forward(input)
68 |         return mu + torch.randn_like(s2) * torch.sqrt(torch.clamp(s2, 1e-8))
69 | 
70 |     freeze = BaseGaussianLinear.freeze
71 | 
72 | 
73 | class GaussianLinearARD(BaseGaussianLinear):
74 |     def penalty(self):
75 |         # compute \tfrac12 \log (1 + \tfrac{\mu_{ji}^2}{\sigma_{ji}^2})
76 |         log_weight2 = 2 * torch.log(torch.abs(self.weight) + 1e-20)
77 | 
78 |         # `softplus` is $x \mapsto \log(1 + e^x)$
79 |         return 0.5 * torch.sum(F.softplus(log_weight2 - self.log_sigma2))
80 | 
81 | 
82 | class GaussianConv2dARD(BaseGaussianConv2d):
83 |     penalty = GaussianLinearARD.penalty
84 | 
--------------------------------------------------------------------------------
/mlss2019bdl/dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | import torch
4 | from torch.utils.data import TensorDataset
5 | from torchvision import datasets
6 | 
7 | from sklearn.utils import check_random_state
8 | from sklearn.model_selection import train_test_split
9 | 
10 | 
11 | def get_data(name, path="./data", train=True):
12 |     if name == "MNIST":
13 |         dataset = datasets.MNIST(path, train=train, download=True)
14 |     elif name == "KMNIST":
15 |         dataset = datasets.KMNIST(path, train=train, download=True)
16 | 
17 |     images = dataset.data.float().unsqueeze(1)
18 |     return TensorDataset(images / 255., dataset.targets)
19 | 
20 | 
21 | def get_dataset(n_train=20, n_valid=5000, n_pool=5000,
22 |                 name="MNIST", path="./data", random_state=None):
23 |     random_state = check_random_state(random_state)
24 | 
25 |     dataset = get_data(name, path, train=True)
26 |     S_test = get_data(name, path, train=False)
27 | 
28 |     # create an imbalanced class label distribution for the train
29 |     targets = dataset.tensors[-1].cpu().numpy()
30 | 
31 |     # split the dataset into validation and train
32 |     ix_all = np.r_[:len(targets)]
33 |     ix_train, ix_valid = train_test_split(
34 |         ix_all, stratify=targets, shuffle=True,
35 |         train_size=max(n_train, 1), test_size=max(n_valid, 1),
36 |         random_state=random_state)
37 | 
38 |     # prepare the datasets: pool, train and validation
39 |     if n_train < 1:
40 |         ix_train = np.r_[:0]
41 |     S_train = TensorDataset(*dataset[ix_train])
42 | 
43 |     if n_valid < 1:
44 |         ix_valid = np.r_[:0]
45 |     S_valid = TensorDataset(*dataset[ix_valid])
46 | 
47 |     # prepare the pool
48 |     ix_pool = np.delete(ix_all, np.r_[ix_train, ix_valid])
49 | 
50 |     # we want to have lots of boring/useless examples in the pool
51 |     labels, share = (1, 2, 3, 4, 5, 6, 7, 8, 9), 0.95
52 |     pool_targets, dropped = targets[ix_pool], []
53 | 
54 |     # deplete the pool of each class
55 |     for label in labels:
56 |         ix_cls = np.flatnonzero(pool_targets == label)
57 |         n_kept = int(share * len(ix_cls))
58 | 
59 |         # pick examples at random to drop
60 |         ix_cls = random_state.permutation(ix_cls)
61 |         dropped.append(ix_cls[:n_kept])
62 | 
63 |     ix_pool = np.delete(ix_pool, np.concatenate(dropped))
64 | 
65 |     # select at most `n_pool` examples
66 |     if n_pool > 0:
67 |         ix_pool = random_state.permutation(ix_pool)[:n_pool]
68 |     S_pool = TensorDataset(*dataset[ix_pool])
69 | 
70 |     return S_train, S_pool, S_valid, S_test
71 | 
72 | 
73 | def collect(indices, dataset):
74 |     """Collect the specified samples from the dataset and remove."""
75 |     assert len(dataset) > 0
76 | 
77 |     mask = torch.zeros(len(dataset), dtype=torch.bool)
78 |     mask[indices.long()] = True
79 | 
80 |     collected = TensorDataset(*dataset[mask])
81 | 
82 |     dataset.tensors = dataset[~mask]
83 | 
84 |     return collected
85 | 
86 | 
87 | def merge(*datasets, out=None):
88 |     # Classes derived from Dataset support appending via
89 |     # `+` (__add__), but this breaks slicing.
90 | 
91 |     data = [d.tensors for d in datasets if d is not None and d.tensors]
92 |     assert all(len(data[0]) == len(d) for d in data)
93 | 
94 |     tensors = [torch.cat(tup, dim=0) for tup in zip(*data)]
95 | 
96 |     if isinstance(out, TensorDataset):
97 |         out.tensors = tensors
98 |         return out
99 | 
100 |     return TensorDataset(*tensors)
101 | 
--------------------------------------------------------------------------------
/mlss2019bdl/flex.py:
--------------------------------------------------------------------------------
1 | """Handy plotting procedures for small 2d images."""
2 | import numpy as np
3 | 
4 | from torch import Tensor
5 | from math import sqrt
6 | 
7 | 
8 | def get_dimensions(n_samples, height, width,
9 |                    n_row=None, n_col=None, aspect=(16, 9)):
10 |     """Get the dimensions that aesthetically conform to the aspect ratio."""
11 |     if n_row is None and n_col is None:
12 |         ratio = (width * aspect[1]) / (height * aspect[0])
13 |         n_row = int(sqrt(n_samples * ratio))
14 | 
15 |     if n_row is None:
16 |         n_row = (n_samples + n_col - 1) // n_col
17 | 
18 |     elif n_col is None:
19 |         n_col = (n_samples + n_row - 1) // n_row
20 | 
21 |     return n_row, n_col
22 | 
23 | 
24 | def setup_canvas(ax, height, width, n_row, n_col):
25 |     """Setup the ticks and labels for the canvas."""
26 |     # A pair of index arrays
27 |     row_index, col_index = np.r_[:n_row], np.r_[:n_col]
28 | 
29 |     # Setup major ticks to the seams between images and disable labels
30 |     ax.set_yticks((row_index[:-1] + 1) * height - 0.5, minor=False)
31 |     ax.set_xticks((col_index[:-1] + 1) * width - 0.5, minor=False)
32 | 
33 |     ax.set_yticklabels([], minor=False)
34 |     ax.set_xticklabels([], minor=False)
35 | 
36 |     # Set minor ticks so that they are exactly between the major ones
37 |     ax.set_yticks((row_index + 0.5) * height, minor=True)
38 |     ax.set_xticks((col_index + 0.5) * width, minor=True)
39 | 
40 |     # ... and make their labels into i-j coordinates
41 |     ax.set_yticklabels([f"{i:d}" for i in row_index], minor=True)
42 |     ax.set_xticklabels([f"{j:d}" for j in col_index], minor=True)
43 | 
44 |     # Orient tick marks outward
45 |     ax.tick_params(axis="both", which="both", direction="out")
46 |     return ax
47 | 
48 | 
49 | def arrange(n_row, n_col, data, fill_value=0):
50 |     """Create a grid and populate it with images."""
51 |     n_samples, height, width, *color = data.shape
52 |     grid = np.full((n_row * height, n_col * width, *color),
53 |                    fill_value, dtype=data.dtype)
54 | 
55 |     for k in range(min(n_samples, n_col * n_row)):
56 |         i, j = (k // n_col) * height, (k % n_col) * width
57 |         grid[i:i + height, j:j + width] = data[k]
58 | 
59 |     return grid
60 | 
61 | 
62 | def to_hwc(images, format):
63 |     assert format in ("chw", "hwc"), f"Unrecognized format `{format}`."
64 | 
65 |     if images.ndim == 3:
66 |         return images[..., np.newaxis]
67 | 
68 |     assert images.ndim == 4, f"Images must be Nx{'x'.join(format.upper())}."
69 | 
70 |     if format == "chw":
71 |         return images.transpose(0, 2, 3, 1)
72 | 
73 |     elif format == "hwc":
74 |         return images
75 | 
76 | 
77 | def plot(ax, images, *, n_col=None, n_row=None, format="chw", **kwargs):
78 |     """Plot images in the numpy array on the specified matplotlib Axis."""
79 |     if isinstance(images, Tensor):
80 |         images = images.data.cpu().numpy()
81 | 
82 |     images = to_hwc(images, format)
83 | 
84 |     n_samples, height, width, *color = images.shape
85 |     if n_samples < 1:
86 |         return None
87 | 
88 |     n_row, n_col = get_dimensions(n_samples, height, width, n_row, n_col)
89 |     ax = setup_canvas(ax, height, width, n_row, n_col)
90 | 
91 |     image = arrange(n_row, n_col, images)
92 |     return ax.imshow(image.squeeze(), **kwargs, origin="upper")
93 | 
--------------------------------------------------------------------------------
/mlss2019bdl/plotting.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | 
3 | from torch import Tensor
4 | from numpy import asarray
5 | 
6 | 
7 | def darker(color, a=0.5):
8 |     """Adapted from this stackoverflow question_.
9 | 
10 |     .. _question: https://stackoverflow.com/questions/37765197/
11 |     """
12 |     from matplotlib.colors import to_rgb
13 |     from colorsys import rgb_to_hls, hls_to_rgb
14 | 
15 |     h, l, s = rgb_to_hls(*to_rgb(color))
16 |     return hls_to_rgb(h, max(0, min(a * l, 1)), s)
17 | 
18 | 
19 | def canvas1d(*, figsize=(12, 5)):
20 |     """Setup canvas for 1d function plot."""
21 |     fig, ax = plt.subplots(1, 1, figsize=figsize)
22 | 
23 |     fig.patch.set_alpha(1.0)
24 |     ax.set_xlim(-7, +7)
25 |     ax.set_ylim(-7, +9)
26 | 
27 |     return fig, ax
28 | 
29 | 
30 | def to_numpy(tensor):
31 |     if isinstance(tensor, Tensor):
32 |         tensor = tensor.data.cpu().numpy()
33 | 
34 |     return asarray(tensor).squeeze()
35 | 
36 | 
37 | def plot1d(X, y, bands, ax=None, **kwargs):
38 |     X, y = to_numpy(X), to_numpy(y)
39 |     assert y.ndim == 2 and X.ndim == 1
40 | 
41 |     ax = plt.gca() if ax is None else ax
42 | 
43 |     # plot the predictive mean with the specified colour
44 |     y_mean, y_std = y.mean(axis=-1), y.std(axis=-1)
45 |     line, = ax.plot(X, y_mean, **kwargs)
46 | 
47 |     # plot paths or bands with a lighter color and slightly behind the mean
48 |     color, zorder = darker(line.get_color(), 1.25), line.get_zorder()
49 |     if bands is None:
50 |         ax.plot(X, y, c=color, alpha=0.08, zorder=zorder - 1)
51 | 
52 |     else:
53 |         for band in sorted(bands):
54 |             ax.fill_between(X, y_mean + band * y_std, y_mean - band * y_std,
55 |                             color=color, zorder=zorder-1,
56 |                             alpha=0.4 / len(bands))
57 | 
58 |     return line
59 | 
60 | 
61 | def plot1d_bands(X, y, ax=None, **kwargs):
62 |     # return plot1d(X, y, bands=(0.5, 1.0, 1.5, 2.0), ax=ax, **kwargs)
63 |     return plot1d(X, y, bands=(1.96,), ax=ax, **kwargs)
64 | 
65 | 
66 | def plot1d_paths(X, y, ax=None, **kwargs):
67 |     return plot1d(X, y, bands=None, ax=ax, **kwargs)
68 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from distutils.core import setup
2 | 
3 | setup(
4 |     name="mlss2019bdl",
5 |     version="0.2",
6 |     description="""MLSS2019 Tutorial on Bayesian Active Learning""",
7 |     license="MIT License",
8 |     author="Ivan Nazarov, Yarin Gal",
9 |     author_email="ivan.nazarov@skolkovotech.ru",
10 |     packages=[
11 |         "mlss2019bdl",
12 |         "mlss2019bdl.bdl",
13 |     ],
14 |     install_requires=[
15 |         "numpy",
16 |         "tqdm",
17 |         "matplotlib",
18 |         "torch",
19 |         "torchvision",
20 |     ]
21 | )
22 | 
--------------------------------------------------------------------------------
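
A minimal usage sketch (not a file from the repository; the two-layer architecture, dropout rate, input grid and number of samples are illustrative assumptions). `freeze` draws one dropout mask per `FreezableWeight` layer and pins it, so repeated forward passes evaluate the same sampled function; `unfreeze` restores the stochastic behaviour:

import torch
from torch.nn import Sequential, ReLU

from mlss2019bdl.bdl import DropoutLinear, freeze, unfreeze

# a toy regression network built from the freezable dropout layers
model = Sequential(
    DropoutLinear(1, 64, p=0.25),
    ReLU(),
    DropoutLinear(64, 1, p=0.25),
)

X = torch.linspace(-7, 7, 201).unsqueeze(-1)  # 1-d input grid, shape (201, 1)

samples = []
with torch.no_grad():
    for _ in range(25):
        freeze(model)             # draw and pin one weight realization
        samples.append(model(X))  # evaluate the frozen sample function

unfreeze(model)                   # back to stochastic forward passes

y = torch.cat(samples, dim=-1)    # (201, 25): one column per sampled function

The resulting `(n_points, n_samples)` array matches the layout that `plot1d_paths` in mlss2019bdl/plotting.py expects for its `y` argument, so the sampled functions can be drawn directly onto a `canvas1d()` axis.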
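
For the Gaussian mean-field layers, training adds the per-layer `penalty()` terms — each approximately ½·log(1 + μ²/σ²) summed over the layer's weights — to the data loss, with `penalties` collecting them from every `PenalizedWeight` submodule. A sketch of one such training step under assumed layer sizes, stand-in batch tensors and a placeholder trade-off coefficient `C`:

import torch
import torch.nn.functional as F
from torch.nn import Sequential, ReLU

from mlss2019bdl.bdl import GaussianLinearARD, penalties

model = Sequential(
    GaussianLinearARD(784, 256),
    ReLU(),
    GaussianLinearARD(256, 10),
)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

X = torch.randn(32, 784)         # stand-in batch of flattened images
y = torch.randint(0, 10, (32,))  # stand-in class labels

# negative log-likelihood plus the collected layer penalties; the
# coefficient C balances data fit against the KL-like regularizer
C = 1e-2
loss = F.cross_entropy(model(X), y) + C * sum(penalties(model))

optim.zero_grad()
loss.backward()
optim.step()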
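
In an active-learning loop, `collect` from mlss2019bdl/dataset.py moves chosen examples out of the pool (shrinking it in place) and `merge` appends them to the training set. A hypothetical acquisition step, in which random scores merely stand in for a real acquisition criterion such as predictive entropy or BALD:

import torch

from mlss2019bdl.dataset import get_dataset, collect, merge

# downloads MNIST into ./data on first use
S_train, S_pool, S_valid, S_test = get_dataset(
    n_train=20, n_valid=5000, n_pool=5000, name="MNIST", random_state=42)

# score the pool somehow (here: uniformly at random) and pick the top 5
scores = torch.rand(len(S_pool))
indices = scores.topk(5).indices

# remove the chosen examples from the pool and append them to the train set
acquired = collect(indices, S_pool)
S_train = merge(S_train, acquired, out=S_train)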