├── .gitignore
├── Bayesian Deep Learning part1.ipynb
├── Bayesian Deep Learning part2.ipynb
├── README.md
├── mlss2019bdl
│   ├── __init__.py
│   ├── bdl
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── bernoulli.py
│   │   └── gaussian.py
│   ├── dataset.py
│   ├── flex.py
│   └── plotting.py
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # General
2 | .DS_Store
3 | .AppleDouble
4 | .LSOverride
5 |
6 | # Icon must end with two \r
7 | Icon
8 |
9 |
10 | # Thumbnails
11 | ._*
12 |
13 | # Files that might appear in the root of a volume
14 | .DocumentRevisions-V100
15 | .fseventsd
16 | .Spotlight-V100
17 | .TemporaryItems
18 | .Trashes
19 | .VolumeIcon.icns
20 | .com.apple.timemachine.donotpresent
21 |
22 | # Directories potentially created on remote AFP share
23 | .AppleDB
24 | .AppleDesktop
25 | Network Trash Folder
26 | Temporary Items
27 | .apdisk
28 |
29 | # Byte-compiled / optimized / DLL files
30 | __pycache__/
31 | *.py[cod]
32 | *$py.class
33 |
34 | # C extensions
35 | *.so
36 |
37 | # Distribution / packaging
38 | .Python
39 | build/
40 | develop-eggs/
41 | dist/
42 | downloads/
43 | eggs/
44 | .eggs/
45 | lib/
46 | lib64/
47 | parts/
48 | sdist/
49 | var/
50 | wheels/
51 | *.egg-info/
52 | .installed.cfg
53 | *.egg
54 | MANIFEST
55 |
56 | # PyInstaller
57 | # Usually these files are written by a python script from a template
58 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
59 | *.manifest
60 | *.spec
61 |
62 | # Installer logs
63 | pip-log.txt
64 | pip-delete-this-directory.txt
65 |
66 | # Unit test / coverage reports
67 | htmlcov/
68 | .tox/
69 | .coverage
70 | .coverage.*
71 | .cache
72 | nosetests.xml
73 | coverage.xml
74 | *.cover
75 | .hypothesis/
76 | .pytest_cache/
77 |
78 | # Translations
79 | *.mo
80 | *.pot
81 |
82 | # Django stuff:
83 | *.log
84 | local_settings.py
85 | db.sqlite3
86 |
87 | # Flask stuff:
88 | instance/
89 | .webassets-cache
90 |
91 | # Scrapy stuff:
92 | .scrapy
93 |
94 | # Sphinx documentation
95 | docs/_build/
96 |
97 | # PyBuilder
98 | target/
99 |
100 | # Jupyter Notebook
101 | .ipynb_checkpoints
102 |
103 | # pyenv
104 | .python-version
105 |
106 | # celery beat schedule file
107 | celerybeat-schedule
108 |
109 | # SageMath parsed files
110 | *.sage.py
111 |
112 | # Environments
113 | .env
114 | .venv
115 | env/
116 | venv/
117 | ENV/
118 | env.bak/
119 | venv.bak/
120 |
121 | # Spyder project settings
122 | .spyderproject
123 | .spyproject
124 |
125 | # Rope project settings
126 | .ropeproject
127 |
128 | # mkdocs documentation
129 | /site
130 |
131 | # mypy
132 | .mypy_cache/
133 |
--------------------------------------------------------------------------------
/Bayesian Deep Learning part1.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# MLSS2019: Bayesian Deep Learning"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "In this tutorial we will learn what basic building blocks are needed\n",
15 | "to endow (deep) neural networks with uncertainty estimates."
16 | ]
17 | },
18 | {
19 | "cell_type": "markdown",
20 | "metadata": {},
21 | "source": [
22 | "The plan of the tutorial\n",
23 | "1. [Setup and imports](#Setup-and-imports)\n",
24 | "2. [Easy uncertainty in networks](#Easy-uncertainty-in-networks)\n",
25 | " 1. [Bayesification via dropout and weight decay](#Bayesification-via-dropout-and-weight-decay)\n",
26 | " 2. [Implementing function sampling with the DropoutLinear Layer](#Implementing-function-sampling-with-the-DropoutLinear-Layer)\n",
27 | " 3. [Implementing DropoutLinear](#Implementing-DropoutLinear)\n",
28 | " 4. [Comparing sample functions to point-estimates](#Comparing-sample-functions-to-point-estimates)\n",
29 | "3. [(optional) Dropout $2$-d Convolutional layer](#(optional)-Dropout-$2$-d-Convolutional-layer)\n",
30 | "4. [(optional) A brief reminder on Bayesian and Variational Inference](#(optional)-A-brief-reminder-on-Bayesian-and-Variational-Inference)"
31 | ]
32 | },
33 | {
34 | "cell_type": "markdown",
35 | "metadata": {},
36 | "source": [
37 | "**(note)**\n",
38 | "* to view documentation on something type in `something?` (with one question mark)\n",
39 | "* to view code of something type in `something??` (with two question marks)."
40 | ]
41 | },
42 | {
43 | "cell_type": "markdown",
44 | "metadata": {},
45 | "source": [
46 | "<br>"
47 | ]
48 | },
49 | {
50 | "cell_type": "markdown",
51 | "metadata": {},
52 | "source": [
53 | "## Setup and imports"
54 | ]
55 | },
56 | {
57 | "cell_type": "markdown",
58 | "metadata": {},
59 | "source": [
60 | "In this section we import necessary modules and functions and\n",
61 | "define the computational device."
62 | ]
63 | },
64 | {
65 | "cell_type": "markdown",
66 | "metadata": {},
67 | "source": [
68 | "First, we install some boilerplate service code for this tutorial."
69 | ]
70 | },
71 | {
72 | "cell_type": "code",
73 | "execution_count": null,
74 | "metadata": {},
75 | "outputs": [],
76 | "source": [
77 | "!pip install -q --upgrade git+https://github.com/ivannz/mlss2019-bayesian-deep-learning.git"
78 | ]
79 | },
80 | {
81 | "cell_type": "markdown",
82 | "metadata": {},
83 | "source": [
84 | "Next, numpy for computing, matplotlib for plotting and tqdm for progress bars."
85 | ]
86 | },
87 | {
88 | "cell_type": "code",
89 | "execution_count": null,
90 | "metadata": {},
91 | "outputs": [],
92 | "source": [
93 | "import tqdm\n",
94 | "import numpy as np\n",
95 | "\n",
96 | "%matplotlib inline\n",
97 | "import matplotlib.pyplot as plt"
98 | ]
99 | },
100 | {
101 | "cell_type": "markdown",
102 | "metadata": {},
103 | "source": [
104 | "For deep learning we will be using [pytorch](https://pytorch.org/).\n",
105 | "\n",
106 | "If you are unfamiliar with it, it is basically like `numpy` with autograd,\n",
107 | "stricter data type enforcement, native GPU support, and tools for building,\n",
108 | "training, and serializing models.\n",
109 | "\n",
110 | "\n",
111 | "There are good introductory tutorials on `pytorch`, like this\n",
112 | "[one](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html)."
113 | ]
114 | },
115 | {
116 | "cell_type": "code",
117 | "execution_count": null,
118 | "metadata": {},
119 | "outputs": [],
120 | "source": [
121 | "import torch\n",
122 | "import torch.nn.functional as F\n",
123 | "\n",
124 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
125 | ]
126 | },
127 | {
128 | "cell_type": "markdown",
129 | "metadata": {},
130 | "source": [
131 | "We will need some functionality from scikit-learn."
132 | ]
133 | },
134 | {
135 | "cell_type": "code",
136 | "execution_count": null,
137 | "metadata": {},
138 | "outputs": [],
139 | "source": [
140 | "from sklearn.metrics import confusion_matrix"
141 | ]
142 | },
143 | {
144 | "cell_type": "markdown",
145 | "metadata": {},
146 | "source": [
147 | "Next we import the boilerplate code.\n",
148 | "\n",
149 | "* a procedure that implements a minibatch SGD **fit** loop\n",
150 | "* a function that **evaluates** the model on the provided dataset"
151 | ]
152 | },
153 | {
154 | "cell_type": "code",
155 | "execution_count": null,
156 | "metadata": {},
157 | "outputs": [],
158 | "source": [
159 | "from mlss2019bdl import fit"
160 | ]
161 | },
162 | {
163 | "cell_type": "markdown",
164 | "metadata": {},
165 | "source": [
166 | "```python\n",
167 | "# pseudocode\n",
168 | "def fit(model, dataset, criterion, ...):\n",
169 | " for epoch in epochs:\n",
170 | " for batch in dataset:\n",
171 | " loss = criterion(model, batch) # forward pass\n",
172 | "\n",
173 | " grad = loss.backward() # gradient via back propagation\n",
174 | "\n",
175 | " adam_step(grad)\n",
176 | "```"
177 | ]
178 | },
179 | {
180 | "cell_type": "code",
181 | "execution_count": null,
182 | "metadata": {},
183 | "outputs": [],
184 | "source": [
185 | "from mlss2019bdl import predict"
186 | ]
187 | },
188 | {
189 | "cell_type": "markdown",
190 | "metadata": {},
191 | "source": [
192 | "```python\n",
193 | "# pseudocode\n",
194 | "def predict(model, dataset, ...):\n",
195 | " for input_batch in dataset:\n",
196 | " output.append(model(input_batch)) # forward pass\n",
197 | " \n",
198 | " return concatenate(output)\n",
199 | "```"
200 | ]
201 | },
202 | {
203 | "cell_type": "markdown",
204 | "metadata": {},
205 | "source": [
206 | "<br>"
207 | ]
208 | },
209 | {
210 | "cell_type": "markdown",
211 | "metadata": {},
212 | "source": [
213 | "## Easy uncertainty in networks"
214 | ]
215 | },
216 | {
217 | "cell_type": "markdown",
218 | "metadata": {},
219 | "source": [
220 | "Generate the initial small dataset $S_0 = (x_i, y_i)_{i=1}^{m_0}$\n",
221 | "with $y_i = g(x_i)$, $x_i$ on a regularly spaced grid, and $\n",
222 | "g\n",
223 | " \\colon \\mathbb{R} \\to \\mathbb{R}\n",
224 | " \\colon x \\mapsto \\tfrac{x^2}4 + \\sin \\frac\\pi2 x\n",
225 | "$.\n",
226 | ""
231 | ]
232 | },
233 | {
234 | "cell_type": "code",
235 | "execution_count": null,
236 | "metadata": {},
237 | "outputs": [],
238 | "source": [
239 | "from mlss2019bdl import dataset_from_numpy\n",
240 | "\n",
241 | "X_train = np.linspace(-6.0, +6.0, num=20)[:, np.newaxis]\n",
242 | "y_train = np.sin(X_train * np.pi / 2) + 0.25 * X_train**2\n",
243 | "\n",
244 | "train = dataset_from_numpy(X_train, y_train, device=device)"
245 | ]
246 | },
247 | {
248 | "cell_type": "code",
249 | "execution_count": null,
250 | "metadata": {},
251 | "outputs": [],
252 | "source": [
253 | "X_domain = np.linspace(-10., +10., num=251)[:, np.newaxis]\n",
254 | "\n",
255 | "domain = dataset_from_numpy(X_domain, device=device)"
256 | ]
257 | },
258 | {
259 | "cell_type": "markdown",
260 | "metadata": {},
261 | "source": [
262 | "Suppose we have the following model: a 3-layer fully connected\n",
263 | "network with LeakyReLU activations."
264 | ]
265 | },
266 | {
267 | "cell_type": "code",
268 | "execution_count": null,
269 | "metadata": {},
270 | "outputs": [],
271 | "source": [
272 | "from torch.nn import Linear, Sequential\n",
273 | "from torch.nn import LeakyReLU\n",
274 | "\n",
275 | "\n",
276 | "model = Sequential(\n",
277 | " Linear(1, 512, bias=True),\n",
278 | " LeakyReLU(),\n",
279 | "\n",
280 | " Linear(512, 512, bias=True),\n",
281 | " LeakyReLU(),\n",
282 | "\n",
283 | " Linear(512, 1, bias=True),\n",
284 | ")\n",
285 | "\n",
286 | "model.to(device)"
287 | ]
288 | },
289 | {
290 | "cell_type": "markdown",
291 | "metadata": {},
292 | "source": [
293 | "<br>"
294 | ]
295 | },
296 | {
297 | "cell_type": "markdown",
298 | "metadata": {},
299 | "source": [
300 | "We fit our model on `train` using MSE loss and $\\ell_2$ penalty on\n",
301 | "weights (`weight_decay`):\n",
302 | "$$\n",
303 | " \\tfrac1{2 m} \\|f_\\omega(x) - y\\|_2^2 + \\lambda \\|\\omega\\|_2^2\n",
304 | " \\,, $$\n",
305 | "where $\\omega$ are all the learnable parameters of the network $f_\\omega$."
306 | ]
307 | },
308 | {
309 | "cell_type": "markdown",
310 | "metadata": {},
311 | "source": [
312 | "<br>"
313 | ]
314 | },
315 | {
316 | "cell_type": "markdown",
317 | "metadata": {},
318 | "source": [
319 | "Fit, ..."
320 | ]
321 | },
322 | {
323 | "cell_type": "code",
324 | "execution_count": null,
325 | "metadata": {
326 | "scrolled": false
327 | },
328 | "outputs": [],
329 | "source": [
330 | "fit(model, train, criterion=\"mse\", n_epochs=2000, verbose=True, weight_decay=1e-3)"
331 | ]
332 | },
333 | {
334 | "cell_type": "markdown",
335 | "metadata": {},
336 | "source": [
337 | "..., compute the predictions, ..."
338 | ]
339 | },
340 | {
341 | "cell_type": "code",
342 | "execution_count": null,
343 | "metadata": {},
344 | "outputs": [],
345 | "source": [
346 | "y_pred = predict(model, domain)"
347 | ]
348 | },
349 | {
350 | "cell_type": "markdown",
351 | "metadata": {},
352 | "source": [
353 | "..., and plot them."
354 | ]
355 | },
356 | {
357 | "cell_type": "code",
358 | "execution_count": null,
359 | "metadata": {},
360 | "outputs": [],
361 | "source": [
362 | "fig, ax = plt.subplots(1, 1, figsize=(12, 5))\n",
363 | "\n",
364 | "ax.scatter(X_train, y_train, c=\"black\", s=40, label=\"train\")\n",
365 | "\n",
366 | "ax.plot(X_domain, y_pred.numpy(), c=\"C0\", lw=2, label=\"prediction\")\n",
367 | "\n",
368 | "plt.legend();"
369 | ]
370 | },
371 | {
372 | "cell_type": "markdown",
373 | "metadata": {},
374 | "source": [
375 | "This model seems to fit the train set adequately well. However, there is no\n",
376 | "way to assess how confident this model is with respect to its predictions.\n",
377 | "Indeed, the prediction $\\hat{y}_x = f_\\omega(x)$ is a deterministic function\n",
378 | "of the input $x$ and the learnt parameters $\\omega$."
379 | ]
380 | },
381 | {
382 | "cell_type": "markdown",
383 | "metadata": {},
384 | "source": [
385 | "<br>"
386 | ]
387 | },
388 | {
389 | "cell_type": "markdown",
390 | "metadata": {},
391 | "source": [
392 | "### `Bayesification` via dropout and weight decay"
393 | ]
394 | },
395 | {
396 | "cell_type": "markdown",
397 | "metadata": {},
398 | "source": [
399 | "One inexpensive way to make any network into a stochastic function of its\n",
400 | "input is to add dropout before any parameterized layer like `linear`\n",
401 | "or `convolutional`, [Hinton et al. 2012](https://arxiv.org/abs/1207.0580).\n",
402 | "Essentially, dropout applies a Bernoulli mask to the features of the input.\n",
403 | "\n",
404 | "In [Gal, Y. (2016)](http://www.cs.ox.ac.uk/people/yarin.gal/website/thesis/thesis.pdf)\n",
405 | "it has been shown that a simple, somewhat ad-hoc approach of\n",
406 | "adding uncertainty quantification to networks through dropout,\n",
407 | "coupled with $\\ell_2$ weight penalty, is a special case of Variational Inference."
408 | ]
409 | },
410 | {
411 | "cell_type": "markdown",
412 | "metadata": {},
413 | "source": [
414 | "For input\n",
415 | "$\n",
416 | " x\\in \\mathbb{R}^{[\\mathrm{in}]}\n",
417 | "$ the dropout layer acts like this:\n",
418 | "\n",
419 | "$$\n",
420 | " y_j = x_j \\, m_j\n",
421 | " \\,, $$\n",
422 | "\n",
423 | "where $m\\in \\mathbb{R}^{[\\mathrm{in}]}$ with $\n",
424 | "m_j \\sim \\pi_p(m_j)\n",
425 | " = \\mathcal{Ber}\\bigl(\\bigl\\{0, \\tfrac1{1-p}\\bigr\\}, 1-p\\bigr)\n",
426 | "$,\n",
427 | "i.e. equals $\\tfrac1{1-p}$ with probability $1-p$ and $0$ otherwise."
428 | ]
429 | },
430 | {
431 | "cell_type": "markdown",
432 | "metadata": {},
433 | "source": [
434 | "#### (task) Always Active Dropout"
435 | ]
436 | },
437 | {
438 | "cell_type": "markdown",
439 | "metadata": {},
440 | "source": [
441 | "Useful methods:\n",
442 | "* `torch.rand(d1, ..., dn)` -- draw $d_1\\times \\ldots \\times d_n$ tensor of uniform rv-s\n",
443 | "* `torch.rand_like(other)` -- draw a tensor of uniform rv-s with the same shape, data type and device as `other`\n",
444 | "\n",
445 | "\n",
446 | "* `torch.bernoulli(pi)` -- draw tensor $t$ with independent $\n",
447 | "t_\\alpha \\sim \\mathcal{Ber}\\bigl(\\{0, 1\\}, \\pi_\\alpha\\bigr)\n",
448 | "$ for each index $\\alpha$\n",
449 | "* `torch.full((d1, ..., dn), v)` -- a $d_1\\times \\ldots \\times d_n$ tensor with the same value $v$\n",
450 | "\n",
451 | "\n",
452 | "* `Tensor.to(other)` -- move `Tensor` to the device of `other` and cast to its data type."
453 | ]
454 | },
455 | {
456 | "cell_type": "code",
457 | "execution_count": null,
458 | "metadata": {},
459 | "outputs": [],
460 | "source": [
461 | "from torch.nn import Module\n",
462 | "\n",
463 | "class ActiveDropout(Module):\n",
464 | " # all building blocks of networks are inherited from Module!\n",
465 | "\n",
466 | " def __init__(self, p=0.5):\n",
467 | " super().__init__() # init the base class\n",
468 | "\n",
469 | " self.p = p\n",
470 | "\n",
471 | " def forward(self, input):\n",
472 | " ## Exercise: implement feature dropout on input\n",
473 | " # self.p - contains the specified dropout rate\n",
474 | " \n",
475 | " mask = torch.rand_like(input) > self.p\n",
476 | " return input * mask.to(input) / (1 - self.p)\n",
477 | "\n",
478 | " # prob = torch.full_like(input, 1 - self.p)\n",
479 | " # return input * torch.bernoulli(prob) / prob\n",
480 | "\n",
481 | " # return F.dropout(input, self.p, True)\n",
482 | "\n",
483 | " pass"
484 | ]
485 | },
486 | {
487 | "cell_type": "markdown",
488 | "metadata": {},
489 | "source": [
490 | "<br>"
491 | ]
492 | },
493 | {
494 | "cell_type": "markdown",
495 | "metadata": {},
496 | "source": [
497 | "#### (task) Rebuilding the model"
498 | ]
499 | },
500 | {
501 | "cell_type": "markdown",
502 | "metadata": {},
503 | "source": [
504 | "Let's recreate the model above with this freshly minted dropout layer.\n",
505 | "Then fit and plot its prediction uncertainty due to forward pass stochasticity."
506 | ]
507 | },
508 | {
509 | "cell_type": "code",
510 | "execution_count": null,
511 | "metadata": {},
512 | "outputs": [],
513 | "source": [
514 | "def build_model(p=0.5):\n",
515 | " \"\"\"Build a model with dropout layers' rate set to `p`.\"\"\"\n",
516 | "\n",
517 | " return Sequential(\n",
518 | " ## Exercise: Use `ActiveDropout` before linear layers of our\n",
519 | " # first network. Note that dropping out inputs is not a good idea\n",
520 | "\n",
521 | " Linear(1, 512, bias=True),\n",
522 | " LeakyReLU(),\n",
523 | "\n",
524 | " ActiveDropout(p),\n",
525 | " Linear(512, 512, bias=True),\n",
526 | " LeakyReLU(),\n",
527 | "\n",
528 | " ActiveDropout(p),\n",
529 | " Linear(512, 1, bias=True),\n",
530 | "\n",
531 | " # pass\n",
532 | " )"
533 | ]
534 | },
535 | {
536 | "cell_type": "markdown",
537 | "metadata": {},
538 | "source": [
539 | "<br>"
540 | ]
541 | },
542 | {
543 | "cell_type": "code",
544 | "execution_count": null,
545 | "metadata": {},
546 | "outputs": [],
547 | "source": [
548 | "model = build_model(p=0.5)\n",
549 | "\n",
550 | "model.to(device)\n",
551 | "\n",
552 | "fit(model, train, criterion=\"mse\", n_epochs=2000, verbose=True,\n",
553 | " weight_decay=1e-3)"
554 | ]
555 | },
556 | {
557 | "cell_type": "markdown",
558 | "metadata": {},
559 | "source": [
560 | "<br>"
561 | ]
562 | },
563 | {
564 | "cell_type": "markdown",
565 | "metadata": {},
566 | "source": [
567 | "#### Sampling the random output"
568 | ]
569 | },
570 | {
571 | "cell_type": "markdown",
572 | "metadata": {},
573 | "source": [
574 | "Let's take the test sample $\\tilde{S} = (\\tilde{x}_i)_{i=1}^m \\subset \\mathcal{X}$\n",
575 | "and repeat the stochastic forward pass $B$ times at each $x\\in \\tilde{S}$:\n",
576 | "\n",
577 | "* for $b = 1 .. B$ do:\n",
578 | "\n",
579 | " 1. draw $y_{bi} \\sim f_\\omega(\\tilde{x}_i)$ for $i = 1 .. m$."
580 | ]
581 | },
582 | {
583 | "cell_type": "code",
584 | "execution_count": null,
585 | "metadata": {},
586 | "outputs": [],
587 | "source": [
588 | "def point_estimate(model, dataset, n_samples=1, verbose=False):\n",
589 | " \"\"\"Draw pointwise samples with stochastic forward pass.\"\"\"\n",
590 | "\n",
591 | " outputs = []\n",
592 | " for sample in tqdm.tqdm(range(n_samples), disable=not verbose):\n",
593 | "\n",
594 | " outputs.append(predict(model, dataset))\n",
595 | "\n",
596 | " return torch.stack(outputs, dim=0)\n",
597 | "\n",
598 | "\n",
599 | "samples = point_estimate(model, domain, n_samples=101, verbose=True)"
600 | ]
601 | },
602 | {
603 | "cell_type": "markdown",
604 | "metadata": {},
605 | "source": [
606 | "<br>"
607 | ]
608 | },
609 | {
610 | "cell_type": "markdown",
611 | "metadata": {},
612 | "source": [
613 | "The approximate $95\\%$ confidence band of predictions is..."
614 | ]
615 | },
616 | {
617 | "cell_type": "code",
618 | "execution_count": null,
619 | "metadata": {},
620 | "outputs": [],
621 | "source": [
622 | "fig, ax = plt.subplots(1, 1, figsize=(12, 5))\n",
623 | "ax.scatter(X_train, y_train, c=\"black\", s=40, label=\"train\")\n",
624 | "\n",
625 | "mean, std = samples.mean(dim=0).numpy(), samples.std(dim=0).numpy()\n",
626 | "ax.plot(X_domain, mean + 1.96 * std, c=\"k\")\n",
627 | "ax.plot(X_domain, mean - 1.96 * std, c=\"k\");"
628 | ]
629 | },
630 | {
631 | "cell_type": "markdown",
632 | "metadata": {},
633 | "source": [
634 | "<br>"
635 | ]
636 | },
637 | {
638 | "cell_type": "markdown",
639 | "metadata": {},
640 | "source": [
641 | "### Implementing function sampling with the DropoutLinear Layer"
642 | ]
643 | },
644 | {
645 | "cell_type": "markdown",
646 | "metadata": {},
647 | "source": [
648 | "Let's inspect the draws $y_{bi}$ as $B$ functional samples:\n",
649 | "$(x_i, y_{bi})_{i=1}^m$ - the $b$-th sample path. Below we\n",
650 | "plot $5$ random paths."
651 | ]
652 | },
653 | {
654 | "cell_type": "code",
655 | "execution_count": null,
656 | "metadata": {},
657 | "outputs": [],
658 | "source": [
659 | "samples = point_estimate(model, domain, n_samples=101, verbose=True)\n",
660 | "\n",
661 | "fig, ax = plt.subplots(1, 1, figsize=(12, 5))\n",
662 | "\n",
663 | "ax.scatter(X_train, y_train, c=\"black\", s=40, label=\"train\")\n",
664 | "ax.plot(X_domain[:, 0], samples[:5, :, 0].numpy().T, c=\"C0\", lw=1, alpha=0.25);"
665 | ]
666 | },
667 | {
668 | "cell_type": "markdown",
669 | "metadata": {},
670 | "source": [
671 | "It is clear that they are very erratic!"
672 | ]
673 | },
674 | {
675 | "cell_type": "markdown",
676 | "metadata": {},
677 | "source": [
678 | "Computing stochastic forward passes with a new mask each time is equivalent\n",
679 | "to drawing a new **independent** prediction for each point $x\\in \\tilde{S}$,\n",
680 | "without considering that, in fact, at adjacent points the predictions should\n",
681 | "be correlated. If we were interested in uncertainty at some particular point,\n",
682 | "this would be okay: **fast and simple**."
683 | ]
684 | },
685 | {
686 | "cell_type": "markdown",
687 | "metadata": {},
688 | "source": [
689 | "However, if we are interested in the uncertainty of an integral **path-dependent**\n",
690 | "measure of the whole estimated function, or are doing **optimization** of\n",
691 | "the unknown true function taking estimation uncertainty into account, then\n",
692 | "this clearly erratic behaviour of paths is undesirable. See, e.g.,\n",
693 | "[blog: Gal, Y. 2016](http://www.cs.ox.ac.uk/people/yarin.gal/website/blog_2248.html)"
694 | ]
695 | },
696 | {
697 | "cell_type": "markdown",
698 | "metadata": {},
699 | "source": [
700 | "<br>"
701 | ]
702 | },
703 | {
704 | "cell_type": "markdown",
705 | "metadata": {},
706 | "source": [
707 | "We need to implement some extra functionality on top of `pytorch`,\n",
708 | "in order to draw realizations from the induced distribution over\n",
709 | "functions defined by the network, i.e. $\n",
710 | "\\bigl\\{\n",
711 | " f_\\omega\\colon \\mathcal{X}\\to\\mathcal{Y}\n",
712 | "\\bigr\\}_{\\omega \\sim q(\\omega)}\n",
713 | "$\n",
714 | "where $q(\\omega)$ is a distribution over the parameters."
715 | ]
716 | },
717 | {
718 | "cell_type": "markdown",
719 | "metadata": {},
720 | "source": [
721 | "One of the design approaches is to allow layers\n",
722 | "to cache random draws of their parameters for reuse\n",
723 | "in all subsequent forward passes, until this is no\n",
724 | "longer needed."
725 | ]
726 | },
727 | {
728 | "cell_type": "markdown",
729 | "metadata": {},
730 | "source": [
731 | "#### Freeze/unfreeze interface"
732 | ]
733 | },
734 | {
735 | "cell_type": "markdown",
736 | "metadata": {},
737 | "source": [
738 | "This is a base **trait-class** `FreezableWeight` that adds an interface\n",
739 | "for freezing and unfreezing a layer's random **weight** parameter."
740 | ]
741 | },
742 | {
743 | "cell_type": "code",
744 | "execution_count": null,
745 | "metadata": {},
746 | "outputs": [],
747 | "source": [
748 | "class FreezableWeight(Module):\n",
749 | " def __init__(self):\n",
750 | " super().__init__()\n",
751 | " self.unfreeze()\n",
752 | "\n",
753 | " def unfreeze(self):\n",
754 | " self.register_buffer(\"frozen_weight\", None)\n",
755 | "\n",
756 | " def is_frozen(self):\n",
757 | " \"\"\"Check if a frozen weight is available.\"\"\"\n",
758 | " return isinstance(self.frozen_weight, torch.Tensor)\n",
759 | "\n",
760 | " def freeze(self):\n",
761 | " \"\"\"Sample from the parameter distribution and freeze.\"\"\"\n",
762 | " raise NotImplementedError()"
763 | ]
764 | },
765 | {
766 | "cell_type": "markdown",
767 | "metadata": {},
768 | "source": [
769 | "Next, we declare a pair of functions:\n",
770 | "* `freeze()` instructs each compatible layer of the model to **sample and freeze** its randomness\n",
771 | "* `unfreeze()` requests the layers to **undo** this"
772 | ]
773 | },
774 | {
775 | "cell_type": "code",
776 | "execution_count": null,
777 | "metadata": {},
778 | "outputs": [],
779 | "source": [
780 | "def freeze(model):\n",
781 | " for layer in model.modules():\n",
782 | " if isinstance(layer, FreezableWeight):\n",
783 | " layer.freeze()\n",
784 | "\n",
785 | " return model"
786 | ]
787 | },
788 | {
789 | "cell_type": "code",
790 | "execution_count": null,
791 | "metadata": {},
792 | "outputs": [],
793 | "source": [
794 | "def unfreeze(model):\n",
795 | " for layer in model.modules():\n",
796 | " if isinstance(layer, FreezableWeight):\n",
797 | " layer.unfreeze()\n",
798 | "\n",
799 | " return model"
800 | ]
801 | },
802 | {
803 | "cell_type": "markdown",
804 | "metadata": {},
805 | "source": [
806 | "<br>"
807 | ]
808 | },
809 | {
810 | "cell_type": "markdown",
811 | "metadata": {},
812 | "source": [
813 | "#### (task) Sampling realizations"
814 | ]
815 | },
816 | {
817 | "cell_type": "markdown",
818 | "metadata": {},
819 | "source": [
820 | "The algorithm to sample a random function is:\n",
821 | "* for $b = 1... B$ do:\n",
822 | "\n",
823 | " 1. draw an independent realization $f_b\\colon \\mathcal{X} \\to \\mathcal{Y}$\n",
824 | " from the process $\\{f_\\omega\\}_{\\omega \\sim q(\\omega)}$\n",
825 | " 2. get $\\hat{y}_{bi} = f_b(\\tilde{x}_i)$ for $i=1 .. m$\n"
826 | ]
827 | },
828 | {
829 | "cell_type": "code",
830 | "execution_count": null,
831 | "metadata": {},
832 | "outputs": [],
833 | "source": [
834 | "def sample_function(model, dataset, n_samples=1, verbose=False):\n",
835 | " \"\"\"Draw a realization of a random function.\"\"\"\n",
836 | "\n",
837 | " ## Exercise: code a function similar to `point_estimate()`,\n",
838 | " ## that collects the predictions from `frozen` models. Don't\n",
839 | " ## forget to unfreeze before returning.\n",
840 | "\n",
841 | " outputs = []\n",
842 | " for _ in tqdm.tqdm(range(n_samples), disable=not verbose):\n",
843 | " freeze(model)\n",
844 | "\n",
845 | " outputs.append(predict(model, dataset))\n",
846 | "\n",
847 | " unfreeze(model)\n",
848 | "\n",
849 | " return torch.stack(outputs, dim=0)\n",
850 | "\n",
851 | " pass"
852 | ]
853 | },
854 | {
855 | "cell_type": "markdown",
856 | "metadata": {},
857 | "source": [
858 | "**(note)** although the internal loop in both functions looks\n",
859 | "similar, conceptually the functions differ:\n",
860 | "\n",
861 | "```python\n",
862 | "def point_estimate(f, S):\n",
863 | " for x in S:\n",
864 | " for w from f.q: # different w for different x\n",
865 | " yield f(x, w)\n",
866 | "\n",
867 | "\n",
868 | "def sample_function(f, S):\n",
869 | " for w from f.q:\n",
870 | " for x in S: # same w for different x (thanks to freeze)\n",
871 | " yield f(x, w)\n",
872 | "```\n",
873 | ""
874 | ]
875 | },
876 | {
877 | "cell_type": "markdown",
878 | "metadata": {},
879 | "source": [
880 | "<br>"
881 | ]
882 | },
883 | {
884 | "cell_type": "markdown",
885 | "metadata": {},
886 | "source": [
887 | "### Implementing `DropoutLinear`"
888 | ]
889 | },
890 | {
891 | "cell_type": "markdown",
892 | "metadata": {},
893 | "source": [
894 | "Now we will merge `ActiveDropout` and `Linear` layers into one, which\n",
895 | "\n",
896 | "1. (on forward pass) **drops out** the inputs, if necessary, and **applies** the linear (affine) transform\n",
897 | "2. (on freeze) **randomly zeros** columns in a copy of the weight matrix $W$\n",
898 | "\n",
899 | "Preferably, we will try to preserve the interface, so that the resulting\n",
900 | "object is backwards compatible with `Linear`."
901 | ]
902 | },
903 | {
904 | "cell_type": "markdown",
905 | "metadata": {},
906 | "source": [
907 | "This way we would be able to draw realizations from the induced\n",
908 | "distribution over functions defined by the network $\n",
909 | "\\bigl\\{\n",
910 | " f_\\omega\\colon \\mathcal{X}\\to\\mathcal{Y}\n",
911 | "\\bigr\\}_{\\omega \\sim q(\\omega)}\n",
912 | "$\n",
913 | "where $q(\\omega)$ is a distribution over the network parameters."
914 | ]
915 | },
916 | {
917 | "cell_type": "markdown",
918 | "metadata": {},
919 | "source": [
920 | "<br>"
921 | ]
922 | },
923 | {
924 | "cell_type": "markdown",
925 | "metadata": {},
926 | "source": [
927 | "#### (task) Fused dropout-linear operation"
928 | ]
929 | },
930 | {
931 | "cell_type": "markdown",
932 | "metadata": {},
933 | "source": [
934 | "On the inputs into a linear layer dropout acts like this: for input\n",
935 | "$\n",
936 | " x\\in \\mathbb{R}^{[\\mathrm{in}]}\n",
937 | "$ and layer weights $\n",
938 | " W\\in \\mathbb{R}^{[\\mathrm{out}] \\times [\\mathrm{in}]}\n",
939 | "$\n",
940 | "and bias $\n",
941 | " b\\in \\mathbb{R}^{[\\mathrm{out}]}\n",
942 | "$ the resulting effect is\n",
943 | "\n",
944 | "$$\n",
945 | " \\tilde{x} = x \\odot m\n",
946 | " \\,, \\\\\n",
947 | " y = \\tilde{x} W^\\top + b\n",
948 | "% = b + \\sum_i x_i m_i W_i\n",
949 | " \\,, $$\n",
950 | "\n",
951 | "where $\\odot$ is the elementwise product and $m\\in \\mathbb{R}^{[\\mathrm{in}]}$\n",
952 | "with $m_j \\sim \\pi_p(m_j) = \\mathcal{Ber}\\bigl(\\bigl\\{0, \\tfrac1{1-p}\\bigr\\}, 1-p\\bigr)$,\n",
953 | "i.e. equals $\\tfrac1{1-p}$ with probability $1-p$ and $0$ otherwise."
954 | ]
955 | },
956 | {
957 | "cell_type": "markdown",
958 | "metadata": {},
959 | "source": [
960 | "Let\n",
961 | "$\n",
962 | " x\\in \\mathbb{R}^{[\\mathrm{in}]}\n",
963 | "$, $\n",
964 | " W\\in \\mathbb{R}^{[\\mathrm{out}] \\times [\\mathrm{in}]}\n",
965 | "$\n",
966 | "and $\n",
967 | " b\\in \\mathbb{R}^{[\\mathrm{out}]}\n",
968 | "$. Let's use the following `torch` functions:\n",
969 | "\n",
970 | "* `F.dropout(x, p, on/off)` -- independent Bernoulli dropout $x\\mapsto x\\odot m$\n",
971 | " for $m\\sim \\mathcal{Ber}\\bigl(\\bigl\\{0, \\tfrac1{1-p}\\bigr\\}, 1-p\\bigr)$\n",
972 | "\n",
973 | "* `F.linear(x, W, b)` -- affine transformation $x \\mapsto x W^\\top + b$\n",
974 | "\n",
975 | "**(note)** the `.weight` of a linear layer in `pytorch` is an $\n",
976 | "{\n",
977 | " [\\mathrm{out}]\n",
978 | " \\times [\\mathrm{in}]\n",
979 | "}\n",
980 | "$ matrix.\n",
981 | "\n",
982 | ""
986 | ]
987 | },
988 | {
989 | "cell_type": "code",
990 | "execution_count": null,
991 | "metadata": {},
992 | "outputs": [],
993 | "source": [
994 | "def DropoutLinear_forward(self, input):\n",
995 | " ## Exercise: If not frozen, then apply always active dropout,\n",
996 | " # then linear transformation. If frozen, apply the transform\n",
997 | " # using the frozen weight\n",
998 | "\n",
999 | " # linear with frozen weight\n",
1000 | " if self.is_frozen():\n",
1001 | " return F.linear(input, self.frozen_weight, self.bias)\n",
1002 | "\n",
1003 | " # stochastic pass as in `ActiveDropout` + Linear\n",
1004 | " input = F.dropout(input, self.p, True)\n",
1005 | "\n",
1006 | " return F.linear(input, self.weight, self.bias)\n",
1007 | " # return super().forward(F.dropout(input, self.p, True))\n",
1008 | "\n",
1009 | " pass"
1010 | ]
1011 | },
1012 | {
1013 | "cell_type": "markdown",
1014 | "metadata": {},
1015 | "source": [
1016 | "<br>"
1017 | ]
1018 | },
1019 | {
1020 | "cell_type": "markdown",
1021 | "metadata": {},
1022 | "source": [
1023 | "#### Parameter freezer for our custom layer"
1024 | ]
1025 | },
1026 | {
1027 | "cell_type": "markdown",
1028 | "metadata": {},
1029 | "source": [
1030 | "For input\n",
1031 | "$\n",
1032 | " x\\in \\mathbb{R}^{[\\mathrm{in}]}\n",
1033 | "$ and a layer parameters $\n",
1034 | " W\\in \\mathbb{R}^{[\\mathrm{out}] \\times [\\mathrm{in}]}\n",
1035 | "$\n",
1036 | "and $\n",
1037 | " b\\in \\mathbb{R}^{[\\mathrm{out}]}\n",
1038 | "$ the effect in `DropoutLinear` is\n",
1039 | "\n",
1040 | "$$\n",
1041 | " y_j\n",
1042 | " = \\bigl[(x \\odot m) W^\\top + b\\bigr]_j\n",
1043 | " = b_j + \\sum_i x_i m_i W_{ji}\n",
1044 | " = b_j + \\sum_i x_i \\breve{W}_{ji}\n",
1045 | " \\,, $$\n",
1046 | "\n",
1047 | "where each column $\\breve{W}_i$ is, independently, either\n",
1048 | "$\\mathbf{0} \\in \\mathbb{R}^{[\\mathrm{out}]}$ with probability $p$ or\n",
1049 | "some (learnable) vector in $\\mathbb{R}^{[\\mathrm{out}]}$\n",
1050 | "\n",
1051 | "$$\n",
1052 | " \\breve{W}_i \\sim\n",
1053 | "\\begin{cases}\n",
1054 | " \\mathbf{0}\n",
1055 | " & \\text{ w. prob } p \\,, \\\\\n",
1056 | " \\tfrac1{1-p} M_i\n",
1057 | " & \\text{ w. prob } 1-p \\,.\n",
1058 | "\\end{cases}\n",
1059 | "$$\n",
1060 | "\n",
1061 | "Thus the multiplicative effect of the random mask $m$ on $x$ can be\n",
1062 | "equivalently seen as a random **on/off** switch effect on the\n",
1063 | "**columns** of the matrix $W$."
1064 | ]
1065 | },
1066 | {
1067 | "cell_type": "code",
1068 | "execution_count": null,
1069 | "metadata": {},
1070 | "outputs": [],
1071 | "source": [
1072 | "def DropoutLinear_freeze(self):\n",
1073 | " \"\"\"Apply dropout with rate `p` to columns of `weight` and freeze it.\"\"\"\n",
1074 | " # we leverage torch's broadcasting semantics and draw a one-row\n",
1075 | " # binary mask, which we later multiply the weight by.\n",
1076 | "\n",
1077 | " # let's draw the new weight\n",
1078 | " with torch.no_grad():\n",
1079 | " prob = torch.full_like(self.weight[:1, :], 1 - self.p)\n",
1080 | " feature_mask = torch.bernoulli(prob) / prob\n",
1081 | "\n",
1082 | " frozen_weight = self.weight * feature_mask\n",
1083 | "\n",
1084 | " # and store it\n",
1085 | " self.register_buffer(\"frozen_weight\", frozen_weight)"
1086 | ]
1087 | },
1088 | {
1089 | "cell_type": "markdown",
1090 | "metadata": {},
1091 | "source": [
1092 | "<br>"
1093 | ]
1094 | },
1095 | {
1096 | "cell_type": "markdown",
1097 | "metadata": {},
1098 | "source": [
1099 | "Assemble the blocks into a layer"
1100 | ]
1101 | },
1102 | {
1103 | "cell_type": "code",
1104 | "execution_count": null,
1105 | "metadata": {},
1106 | "outputs": [],
1107 | "source": [
1108 | "class DropoutLinear(Linear, FreezableWeight):\n",
1109 | " \"\"\"Linear layer with dropout on inputs.\"\"\"\n",
1110 | " def __init__(self, in_features, out_features, bias=True, p=0.5):\n",
1111 | " super().__init__(in_features, out_features, bias=bias)\n",
1112 | "\n",
1113 | " self.p = p\n",
1114 | "\n",
1115 | " forward = DropoutLinear_forward\n",
1116 | "\n",
1117 | " freeze = DropoutLinear_freeze"
1118 | ]
1119 | },
1120 | {
1121 | "cell_type": "markdown",
1122 | "metadata": {},
1123 | "source": [
1124 | "<br>"
1125 | ]
1126 | },
1127 | {
1128 | "cell_type": "markdown",
1129 | "metadata": {},
1130 | "source": [
1131 | "### Comparing sample functions to point-estimates "
1132 | ]
1133 | },
1134 | {
1135 | "cell_type": "markdown",
1136 | "metadata": {},
1137 | "source": [
1138 | "Let's rewrite the model builder function:"
1139 | ]
1140 | },
1141 | {
1142 | "cell_type": "code",
1143 | "execution_count": null,
1144 | "metadata": {},
1145 | "outputs": [],
1146 | "source": [
1147 | "def build_model(p=0.5):\n",
1148 | " \"\"\"Build a model with the custom layer and dropout rate set to `p`.\"\"\"\n",
1149 | "\n",
1150 | " return Sequential(\n",
1151 | " ## Exercise: Plug-in `DropoutLinear` layer into our second network.\n",
1152 | "\n",
1153 | " Linear(1, 512, bias=True),\n",
1154 | " LeakyReLU(),\n",
1155 | "\n",
1156 | " DropoutLinear(512, 512, bias=True , p=p),\n",
1157 | " LeakyReLU(),\n",
1158 | "\n",
1159 | " DropoutLinear(512, 1, bias=True, p=p),\n",
1160 | "\n",
1161 | " # pass\n",
1162 | " )"
1163 | ]
1164 | },
1165 | {
1166 | "cell_type": "markdown",
1167 | "metadata": {},
1168 | "source": [
1169 | "Let's create a new instance and retrain the model."
1170 | ]
1171 | },
1172 | {
1173 | "cell_type": "code",
1174 | "execution_count": null,
1175 | "metadata": {},
1176 | "outputs": [],
1177 | "source": [
1178 | "model = build_model(p=0.5)\n",
1179 | "model.to(device)\n",
1180 | "\n",
1181 | "fit(model, train, criterion=\"mse\", n_epochs=2000, verbose=True, weight_decay=1e-3)"
1182 | ]
1183 | },
1184 | {
1185 | "cell_type": "markdown",
1186 | "metadata": {},
1187 | "source": [
1188 | "... and obtain two estimates: pointwise and functional."
1189 | ]
1190 | },
1191 | {
1192 | "cell_type": "code",
1193 | "execution_count": null,
1194 | "metadata": {},
1195 | "outputs": [],
1196 | "source": [
1197 | "samples_pe = point_estimate(model, domain, n_samples=51, verbose=True)\n",
1198 | "samples_sf = sample_function(model, domain, n_samples=51, verbose=True)\n",
1199 | "\n",
1200 | "samples_pe.shape, samples_sf.shape"
1201 | ]
1202 | },
1203 | {
1204 | "cell_type": "markdown",
1205 | "metadata": {},
1206 | "source": [
1207 | "```python\n",
1208 | "(torch.Size([51, 251, 1]), torch.Size([51, 251, 1]))\n",
1209 | "```"
1210 | ]
1211 | },
1212 | {
1213 | "cell_type": "markdown",
1214 | "metadata": {},
1215 | "source": [
1216 | "<br>"
1217 | ]
1218 | },
1219 | {
1220 | "cell_type": "markdown",
1221 | "metadata": {},
1222 | "source": [
1223 | "Let's compare **point estimates**\n",
1224 | "with **function sampling**."
1225 | ]
1226 | },
1227 | {
1228 | "cell_type": "code",
1229 | "execution_count": null,
1230 | "metadata": {},
1231 | "outputs": [],
1232 | "source": [
1233 | "fig, ax = plt.subplots(1, 1, figsize=(12, 5))\n",
1234 | "\n",
1235 | "ax.plot(X_domain[:, 0], samples_pe[:10, :, 0].numpy().T,\n",
1236 | " c=\"C1\", lw=1, alpha=0.5)\n",
1237 | "\n",
1238 | "ax.plot(X_domain[:, 0], samples_sf[:10, :, 0].numpy().T,\n",
1239 | " c=\"C0\", lw=2, alpha=0.5)\n",
1240 | "\n",
1241 | "ax.scatter(X_train, y_train, c=\"black\", s=40,\n",
1242 | " label=\"train\", zorder=+10);"
1243 | ]
1244 | },
1245 | {
1246 | "cell_type": "code",
1247 | "execution_count": null,
1248 | "metadata": {},
1249 | "outputs": [],
1250 | "source": [
1251 | "fig, ax = plt.subplots(1, 1, figsize=(12, 5))\n",
1252 | "\n",
1253 | "ax.scatter(X_train, y_train, c=\"black\", s=40, label=\"train\")\n",
1254 | "\n",
1255 | "mean, std = samples_sf.mean(dim=0).numpy(), samples_sf.std(dim=0).numpy()\n",
1256 | "ax.plot(X_domain, mean + 1.96 * std, c=\"C0\")\n",
1257 | "ax.plot(X_domain, mean - 1.96 * std, c=\"C0\");\n",
1258 | "\n",
1259 | "mean, std = samples_pe.mean(dim=0).numpy(), samples_pe.std(dim=0).numpy()\n",
1260 | "ax.plot(X_domain, mean + 1.96 * std, c=\"C1\")\n",
1261 | "ax.plot(X_domain, mean - 1.96 * std, c=\"C1\");"
1262 | ]
1263 | },
1264 | {
1265 | "cell_type": "markdown",
1266 | "metadata": {},
1267 | "source": [
1268 | "Pros of `point-estimate`:\n",
1269 | "* uses stochastic forward passes -- no need for extra code and classes\n",
1270 | "\n",
1271 | "Cons of `point-estimate`:\n",
1272 | "* samples from the predictive distribution at adjacent inputs are independent"
1273 | ]
1274 | },
1275 | {
1276 | "cell_type": "markdown",
1277 | "metadata": {},
1278 | "source": [
1279 | "<br>"
1280 | ]
1281 | },
1282 | {
1283 | "cell_type": "markdown",
1284 | "metadata": {},
1285 | "source": [
1286 | "**(note)**\n",
1287 | "The parameter distribution of the layer we've built is\n",
1288 | "\n",
1289 | "$$\n",
1290 | " q(\\omega\\mid \\theta)\n",
1291 | " = \\prod_i q(\\omega_i\\mid \\theta_i)\n",
1292 | " = \\prod_i \\bigl\\{\n",
1293 | " p \\delta_{\\mathbf{0}} (\\omega_i)\n",
1294 | " + (1 - p) \\delta_{\\tfrac1{1-p} \\theta_i}(\\omega_i)\n",
1295 | " \\bigr\\}\n",
1296 | " \\,, $$\n",
1297 | "\n",
1298 | "where $\\omega_i$ is the $i$-th column of $\\omega$, $\\delta_a$ is a\n",
1299 | "**point-mass** distribution at $a$, and $\\theta$ is the learnt\n",
1300 | "approximate posterior mean of $\\omega$."
1301 | ]
1302 | },
1303 | {
1304 | "cell_type": "markdown",
1305 | "metadata": {},
1306 | "source": [
1307 | "Under benign assumptions and certain relaxations\n",
1308 | "[Gal, Y. 2016 (eq. (6.3) p.109, Prop. 4 p.149)](http://www.cs.ox.ac.uk/people/yarin.gal/website/thesis/thesis.pdf)\n",
1309 | "has shown that a deep network with dropout rate $p$\n",
1310 | "and $\\ell_2$ weight penalty (`weight_decay`) performs (doubly)\n",
1311 | "**stochastic variational inference** with the following stochastic\n",
1312 | "approximate **evidence lower bound**: for the dataset $D = (x_i, y_i)_i$\n",
1313 | "of size $N = \\lvert D \\rvert$ and random batches $B$ of size\n",
1314 | "$\\lvert B \\rvert = m$\n",
1315 | "\n",
1316 | "$$\n",
1317 | " \\frac1{N} \\Bigl( \\underbrace{\n",
1318 | " \\mathbb{E}_{\\omega\\sim q(\\omega\\mid \\theta)} \\log p(D \\mid \\omega)\n",
1319 | " - KL\\bigl(q(\\omega\\mid \\theta) \\big\\| \\pi(\\omega) \\bigr)\n",
1320 | " }_{ELBO(\\theta)} \\Bigr)\n",
1321 | " \\approx \\frac1{\\lvert B \\rvert}\n",
1322 | " \\sum_{i\\in B} \\log p(y_i \\mid x_i, \\omega^{(1)}_i, \\ldots, \\omega^{(L)}_i)\n",
1323 | " - \\sum_{l=1}^L\n",
1324 | " \\frac{1-p^{(l)}}{2 s^2 N} \\|\\theta^{(l)}\\|_2^2\n",
1325 | "% - [\\mathrm{in}_{(l)}] \\, \\mathbb{H}(\\mathcal{Ber}(p^{(l)}))\n",
1326 | "% + \\mathrm{const}\n",
1327 | "\\,, $$\n",
1328 | "where $\\omega_i^{(l)}$ are independently drawn from $q(\\omega \\mid \\theta)$\n",
1329 | "(one random draw per element in $B$) and $s^2$ is the prior variance."
1330 | ]
1331 | },
1332 | {
1333 | "cell_type": "markdown",
1334 | "metadata": {},
1335 | "source": [
1336 | "Thus `weight_decay` should be decreasing with $p$ and $N$:\n",
1337 | "$$ \\lambda = \\frac{1-p}{2 s^2 N} \\,. $$"
1338 | ]
1339 | },
1340 | {
1341 | "cell_type": "markdown",
1342 | "metadata": {},
1343 | "source": [
1344 | "<br>"
1345 | ]
1346 | },
1347 | {
1348 | "cell_type": "markdown",
1349 | "metadata": {},
1350 | "source": [
1351 | "#### Question(s) (to ponder in your spare time)\n",
1352 | "\n",
1353 | "* what happens to the confidence bands, when you increase the number\n",
1354 | " of path-wise and pointwise samples?\n",
1355 | "\n",
1356 | "* what will happen if you change the dropout rate $p$ and keep `n_epochs` at 2000?\n",
1357 | "\n",
1358 | "* what happens if, for $p=\\tfrac12$, we use far fewer `n_epochs`?\n",
1359 | "\n",
1360 | "* how do different settings of `weight_decay` affect the bands?\n",
1361 | "\n",
1362 | "Try to rebuild the model with different $p \\in (0, 1)$ using `build_model(p)`, use\n",
1363 | "`fit(..., n_epochs=...)`, and then plot the predictive bands."
1364 | ]
1365 | },
1366 | {
1367 | "cell_type": "code",
1368 | "execution_count": null,
1369 | "metadata": {},
1370 | "outputs": [],
1371 | "source": [
1372 | "from mlss2019bdl.plotting import plot1d_bands\n",
1373 | "\n",
1374 | "# model = fit(build_model(p=...), train, n_epochs=..., weight_decay=..., criterion=\"mse\")\n",
1375 | "# plot1d_bands(sample_function(model, domain, n_samples=101), c=\"C0\")"
1376 | ]
1377 | },
1378 | {
1379 | "cell_type": "markdown",
1380 | "metadata": {},
1381 | "source": [
1382 | "<br>"
1383 | ]
1384 | },
1385 | {
1386 | "cell_type": "code",
1387 | "execution_count": null,
1388 | "metadata": {},
1389 | "outputs": [],
1390 | "source": [
1391 | "model_a = fit(build_model(p=0.15), train, criterion=\"mse\", n_epochs=2000, weight_decay=1e-3)\n",
1392 | "\n",
1393 | "model_z = fit(build_model(p=0.75), train, criterion=\"mse\", n_epochs=2000, weight_decay=1e-3)"
1394 | ]
1395 | },
1396 | {
1397 | "cell_type": "code",
1398 | "execution_count": null,
1399 | "metadata": {},
1400 | "outputs": [],
1401 | "source": [
1402 | "fig = plt.figure(figsize=(12, 5))\n",
1403 | "\n",
1404 | "samples_a = sample_function(model_a, domain, n_samples=101)\n",
1405 | "samples_z = sample_function(model_z, domain, n_samples=101)\n",
1406 | "\n",
1407 | "plot1d_bands(X_domain, samples_a.transpose(0, 2), c=\"r\")\n",
1408 | "plot1d_bands(X_domain, samples_z.transpose(0, 2), c=\"b\")"
1409 | ]
1410 | },
1411 | {
1412 | "cell_type": "code",
1413 | "execution_count": null,
1414 | "metadata": {},
1415 | "outputs": [],
1416 | "source": [
1417 | "model_a = fit(build_model(p=0.50), train, criterion=\"mse\", n_epochs=20, weight_decay=1e-3)\n",
1418 | "\n",
1419 | "model_z = fit(build_model(p=0.50), train, criterion=\"mse\", n_epochs=200, weight_decay=1e-3)"
1420 | ]
1421 | },
1422 | {
1423 | "cell_type": "code",
1424 | "execution_count": null,
1425 | "metadata": {},
1426 | "outputs": [],
1427 | "source": [
1428 | "fig = plt.figure(figsize=(12, 5))\n",
1429 | "\n",
1430 | "samples_a = sample_function(model_a, domain, n_samples=101)\n",
1431 | "samples_z = sample_function(model_z, domain, n_samples=101)\n",
1432 | "\n",
1433 | "plot1d_bands(X_domain, samples_a.transpose(0, 2), c=\"r\")\n",
1434 | "plot1d_bands(X_domain, samples_z.transpose(0, 2), c=\"b\")"
1435 | ]
1436 | },
1437 | {
1438 | "cell_type": "code",
1439 | "execution_count": null,
1440 | "metadata": {},
1441 | "outputs": [],
1442 | "source": [
1443 | "model_a = fit(build_model(p=0.50), train, criterion=\"mse\", n_epochs=2000, weight_decay=1e-5)\n",
1444 | "\n",
1445 | "model_z = fit(build_model(p=0.50), train, criterion=\"mse\", n_epochs=2000, weight_decay=1e-1)"
1446 | ]
1447 | },
1448 | {
1449 | "cell_type": "code",
1450 | "execution_count": null,
1451 | "metadata": {},
1452 | "outputs": [],
1453 | "source": [
1454 | "fig = plt.figure(figsize=(12, 5))\n",
1455 | "\n",
1456 | "samples_a = sample_function(model_a, domain, n_samples=101)\n",
1457 | "\n",
1458 | "samples_z = sample_function(model_z, domain, n_samples=101)\n",
1459 | "\n",
1460 | "plot1d_bands(X_domain, samples_a.transpose(0, 2), c=\"r\")\n",
1461 | "plot1d_bands(X_domain, samples_z.transpose(0, 2), c=\"b\")"
1462 | ]
1463 | },
1464 | {
1465 | "cell_type": "code",
1466 | "execution_count": null,
1467 | "metadata": {},
1468 | "outputs": [],
1469 | "source": [
1470 | "model_a = fit(build_model(p=0.10), train, criterion=\"mse\", n_epochs=2000, weight_decay=1e-3)\n",
1471 | "\n",
1472 | "model_z = fit(build_model(p=0.90), train, criterion=\"mse\", n_epochs=2000, weight_decay=1e-4)"
1473 | ]
1474 | },
1475 | {
1476 | "cell_type": "code",
1477 | "execution_count": null,
1478 | "metadata": {},
1479 | "outputs": [],
1480 | "source": [
1481 | "fig = plt.figure(figsize=(12, 5))\n",
1482 | "\n",
1483 | "samples_a = sample_function(model_a, domain, n_samples=101)\n",
1484 | "\n",
1485 | "samples_z = sample_function(model_z, domain, n_samples=101)\n",
1486 | "\n",
1487 | "plot1d_bands(X_domain, samples_a.transpose(0, 2), c=\"r\")\n",
1488 | "plot1d_bands(X_domain, samples_z.transpose(0, 2), c=\"b\")"
1489 | ]
1490 | },
1491 | {
1492 | "cell_type": "markdown",
1493 | "metadata": {},
1494 | "source": [
1495 | "<br>"
1496 | ]
1497 | },
1498 | {
1499 | "cell_type": "markdown",
1500 | "metadata": {},
1501 | "source": [
1502 | "### (optional) Dropout $2$-d Convolutional layer"
1503 | ]
1504 | },
1505 | {
1506 | "cell_type": "markdown",
1507 | "metadata": {},
1508 | "source": [
1509 | "Typically, in convolutional neural networks the dropout acts upon the feature\n",
1510 | "(channel) information and not on the spatial dimensions. Thus entire channels\n",
1511 | "are dropped out and for $\n",
1512 | " x \\in \\mathbb{R}^{\n",
1513 | " [\\mathrm{in}]\n",
1514 | " \\times h\n",
1515 | " \\times w}\n",
1516 | "$ and $\n",
1517 | " y \\in \\mathbb{R}^{\n",
1518 | " [\\mathrm{out}]\n",
1519 | " \\times h'\n",
1520 | " \\times w'}\n",
1521 | "$ the full effect of the `Dropout+Conv2d` layer is\n",
1522 | "\n",
1523 | "$$\n",
1524 | " y_{lij} = ((x \\odot m) \\ast W_l)_{ij} + b_l\n",
1525 | " = b_l + \\sum_k \\sum_{pq} x_{k i_p j_q} m_k W_{lkpq}\n",
1526 | " \\,, \\tag{conv-2d} $$\n",
1527 | " \n",
1528 | "where i.i.d $m_k \\sim \\mathcal{Ber}\\bigl(\\bigl\\{0, \\tfrac1{1-p}\\bigr\\}, 1-p\\bigr)$,\n",
1529 | "and indices $i_p$ and $j_q$ represent the spatial location in $x$ that correspond\n",
1530 | "to the $p$ and $q$ elements in the kernel $\n",
1531 | " W\\in \\mathbb{R}^{\n",
1532 | " [\\mathrm{out}]\n",
1533 | " \\times [\\mathrm{in}]\n",
1534 | " \\times h\n",
1535 | " \\times w}\n",
1536 | "$ relative to $(i, j)$ coordinates in $y$.\n",
1537 | "The exact values of $i_p$ and $j_q$ depend on the configuration of the\n",
1538 | "convolutional layer, e.g. stride, kernel size and dilation.\n",
1539 | "\n",
1540 | "**(note)** Informative illustrations on the effects of convolution\n",
1541 | "parameters can be found in [Convolution arithmetic](https://github.com/vdumoulin/conv_arithmetic) \n",
1542 | "repo."
1543 | ]
1544 | },
1545 | {
1546 | "cell_type": "markdown",
1547 | "metadata": {},
1548 | "source": [
1549 | "<br>"
1550 | ]
1551 | },
1552 | {
1553 | "cell_type": "markdown",
1554 | "metadata": {},
1555 | "source": [
1556 | "## (optional) A brief reminder on Bayesian and Variational Inference"
1557 | ]
1558 | },
1559 | {
1560 | "cell_type": "markdown",
1561 | "metadata": {},
1562 | "source": [
1563 | "Bayesian Inference is a principled framework of reasoning about uncertainty.\n",
1564 | "\n",
1565 | "In Bayesian Inference (**BI**) we *assume* that the observation\n",
1566 | "data $D$ follows a *model* $m$ with data generating distribution\n",
1567 | "$p(D\\mid m, \\omega)$ *governed by unknown parameters* $\\omega$.\n",
1568 | "The goal of **BI** is to reason about the model and/or its parameters,\n",
1569 | "and new data given the observed data $D$ and our assumptions, i.e\n",
1570 | "and new data given the observed data $D$ and our assumptions, i.e.\n",
1571 | "\n",
1572 | "$$\\begin{align}\n",
1573 | " p(d \\mid D, m)\n",
1574 | " % &= \\mathbb{E}_{\n",
1575 | " % \\omega \\sim p(\\omega \\mid D, m)\n",
1576 | " % } p(d \\mid D, \\omega, m)\n",
1577 | " &= \\int p(d \\mid D, \\omega, m) p(\\omega \\mid D, m) d\\omega\n",
1578 | " \\,, \\\\\n",
1579 | " p(\\omega \\mid D, m)\n",
1580 | " &= \\frac{p(D\\mid \\omega, m) \\, \\pi(\\omega \\mid m)}{p(D\\mid m)}\n",
1581 | " \\,.\n",
1582 | "\\end{align}\n",
1583 | "$$\n",
1584 | "\n",
1585 | "* the **prior** distribution $\\pi(\\omega \\mid m)$ reflects our belief\n",
1586 | " before having made the observations\n",
1587 | "\n",
1588 | "* the data distribution $p(D \\mid \\omega, m)$ reflects our assumptions\n",
1589 | " about the data generating process, and determines the parameter\n",
1590 | " **likelihood** (Gaussian, Categorical, Poisson)"
1591 | ]
1592 | },
1593 | {
1594 | "cell_type": "markdown",
1595 | "metadata": {},
1596 | "source": [
1597 | "Unless the prior and the likelihood are conjugate, the posterior in\n",
1598 | "Bayesian inference is typically intractable, and it is common to resort\n",
1599 | "to **Variational Inference** or **Monte Carlo** approximations."
1600 | ]
1601 | },
1602 | {
1603 | "cell_type": "markdown",
1604 | "metadata": {},
1605 | "source": [
1606 | "The key idea of this approach is to seek an approximation $q(\\omega)$\n",
1607 | "to the intractable posterior $p(\\omega \\mid D, m)$, via a variational\n",
1608 | "optimization problem over some tractable family of distributions $\\mathcal{Q}$:\n",
1609 | "\n",
1610 | "$$\n",
1611 | " q^*(\\omega)\n",
1612 | " \\in \\arg \\min_{q\\in \\mathcal{Q}} \\mathrm{KL}(q(\\omega) \\| p(\\omega \\mid D, m))\n",
1613 | " \\,, $$\n",
1614 | "\n",
1615 | "where the Kullback-Leibler divergence between $P$ and $Q$ ($P\\ll Q$)\n",
1616 | "with densities $p$ and $q$, respectively, is given by\n",
1617 | "\n",
1618 | "$$\n",
1619 | " \\mathrm{KL}(q(\\omega) \\| p(\\omega))\n",
1620 | "% = \\mathbb{E}_{\\omega \\sim Q} \\log \\tfrac{dQ}{dP}(\\omega)\n",
1621 | " = \\mathbb{E}_{\\omega \\sim q(\\omega)}\n",
1622 | " \\log \\tfrac{q(\\omega)}{p(\\omega)}\n",
1623 | " \\,. \\tag{kl-div} $$\n",
1624 | "\n",
1625 | "\n",
1626 | "Note that the family of variational approximations $\\mathcal{Q}$ can be\n",
1627 | "structured **arbitrarily**: point-mass, products, mixture, dependent on\n",
1628 | "input, having mixed hierarchical structure -- any valid distribution."
1629 | ]
1630 | },
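{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "For a concrete, if trivial, example of the divergence above: `torch.distributions`\n",
  "provides the analytic KL between common distributions, while for richer\n",
  "approximations one would estimate it by Monte Carlo. The two Gaussians below are\n",
  "made up purely for illustration."
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "from torch.distributions import Normal, kl_divergence\n",
  "\n",
  "q = Normal(loc=torch.tensor(0.5), scale=torch.tensor(0.7))\n",
  "p_prior = Normal(loc=torch.tensor(0.0), scale=torch.tensor(1.0))\n",
  "\n",
  "# analytic KL(q || p); an MC estimate would average log q(w) - log p(w) over w ~ q\n",
  "w = q.sample((10000,))\n",
  "kl_divergence(q, p_prior), (q.log_prob(w) - p_prior.log_prob(w)).mean()"
 ]
},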
1631 | {
1632 | "cell_type": "markdown",
1633 | "metadata": {},
1634 | "source": [
1635 | "Although computing the divergence w.r.t. the unknown posterior\n",
1636 | "is still hard and intractable, it is possible to do away with it\n",
1637 | "through the following identity, which is based on Bayes' rule.\n",
1638 | "\n",
1639 | "For **any** $q(\\omega) \\ll p(\\omega \\mid D, m)$ and any model $m$\n",
1640 | "\n",
1641 | "$$\n",
1642 | "\\begin{align}\n",
1643 | " \\overbrace{\n",
1644 | " \\log p(D \\mid m)\n",
1645 | " }^{\\text{evidence}}\n",
1646 | " &= \\overbrace{\n",
1647 | " \\mathbb{E}_{\\omega \\sim q} \\log p(D\\mid \\omega, m)\n",
1648 | " }^{\\text{expected conditional likelihood}}\n",
1649 | " - \\overbrace{\n",
1650 | " \\mathrm{KL}(q(\\omega)\\| \\pi(\\omega \\mid m))\n",
1651 | " }^{\\text{proximity to prior belief}}\n",
1652 | " \\\\\n",
1653 | " &+ \\underbrace{\n",
1654 | " \\mathrm{KL}(q(\\omega)\\| p(\\omega \\mid D, m))\n",
1655 | " }_{\\text{posterior approximation}}\n",
1656 | "\\end{align}\n",
1657 | " \\,. \\tag{master-identity}\n",
1658 | "$$"
1659 | ]
1660 | },
1661 | {
1662 | "cell_type": "markdown",
1663 | "metadata": {},
1664 | "source": [
1665 | "Instead of minimizing the divergence of the approximation from the posterior,\n",
1666 | "we maximize the **Evidence Lower Bound** with respect to $q(\\omega)$:\n",
1667 | "\n",
1668 | "$$\n",
1669 | " q^* \\in\n",
1670 | " \\arg\\max_{q\\in Q}\n",
1671 | " \\mathcal{L}(q) = \n",
1672 | " \\mathbb{E}_{\\omega \\sim q} \\log p(D\\mid \\omega, m)\n",
1673 | " - \\mathrm{KL}(q(\\omega)\\| \\pi(\\omega \\mid m))\n",
1674 | " \\,. \\tag{max-ELBO} $$\n",
1675 | "\n",
1676 | "* the expected $\\log$-likelihood favours $q$ that place their mass on\n",
1677 | "parameters $\\omega$ that explain $D$ under the specified model $m$.\n",
1678 | "\n",
1679 | "* the negative KL-divergence discourages the approximation $q$\n",
1680 | "from straying too far away from the prior belief $\\pi$ under $m$."
1681 | ]
1682 | },
1683 | {
1684 | "cell_type": "markdown",
1685 | "metadata": {},
1686 | "source": [
1687 | "We usually consider the following setup (conditioning on model $m$ is omitted):\n",
1688 | "* the likelihood factorizes $\n",
1689 | "p(D \\mid \\omega)\n",
1690 | " = \\prod_i p(y_i, x_i \\mid \\omega)\n",
1691 | " \\propto \\prod_i p(y_i \\mid x_i, \\omega)\n",
1692 | "$\n",
1693 | "for $D = (x_i, y_i)_{i=1}^N$\n",
1694 | "\n",
1695 | "* the approximation is parameterized by $\\theta$: $q(\\omega\\mid \\theta)$\n",
1696 | "\n",
1697 | "* the prior on $\\omega$ itself depends on hyper-parameters $\\lambda$, that\n",
1698 | " can be fixed, or variable ($\\pi(\\omega \\mid \\lambda)$)."
1699 | ]
1700 | },
1701 | {
1702 | "cell_type": "markdown",
1703 | "metadata": {},
1704 | "source": [
1705 | "In this case the variational objective (evidence lower bound)\n",
1706 | "\n",
1707 | "$$\n",
1708 | " \\log p(D\\mid \\lambda )\n",
1709 | " \\geq \\mathcal{L}(\\theta, \\lambda)\n",
1710 | " = \\mathbb{E}_{\\omega \\sim q(\\omega \\mid \\theta)}\n",
1711 | " \\sum_i \\log p(y_i \\mid x_i, \\omega)\n",
1712 | " - KL(q(\\omega \\mid \\theta) \\| \\pi(\\omega \\mid \\lambda))\n",
1713 | " $$\n",
1714 | "\n",
1715 | "is maximized with respect to $\\theta$ (to approximate the posterior)."
1716 | ]
1717 | },
1718 | {
1719 | "cell_type": "markdown",
1720 | "metadata": {},
1721 | "source": [
1722 | "Priors can be\n",
1723 | "* *subjective*, i.e. reflecting prior beliefs (but not arbitrary),\n",
1724 | "* *objective*, i.e. reflecting our lack of knowledge,\n",
1725 | "* *empirical*, i.e. learnt from data (we also optimize over hyper-parameters $\\lambda$)"
1726 | ]
1727 | },
1728 | {
1729 | "cell_type": "markdown",
1730 | "metadata": {},
1731 | "source": [
1732 | "The stochastic variant of ELBO is formed by randomly batching\n",
1733 | "the dataset $D$:\n",
1734 | "\n",
1735 | "$$\n",
1736 | " \\mathcal{L}(\\theta, \\lambda)\n",
1737 | " \\approx \\mathcal{L}_\\mathrm{SGVB}(\\theta, \\lambda)\n",
1738 | " = \\lvert D \\rvert \\biggl(\n",
1739 | " \\tfrac1{\\lvert B \\rvert}\n",
1740 | " \\sum_{b \\in B} \\mathbb{E}_{\\omega \\sim q(\\omega \\mid \\theta)}\n",
1741 | " \\log p(y_b \\mid x_b, \\omega)\n",
1742 | " \\biggr)\n",
1743 | " - KL(q(\\omega \\mid \\theta) \\| \\pi(\\omega \\mid \\lambda))\n",
1744 | " \\,. $$\n",
1745 | "\n",
1746 | "* Stochastic optimization follows noisy unbiased gradient estimates, which are\n",
1747 | "usually cheap, allow escaping from local optima, and optimize the objective in\n",
1748 | "expectation."
1749 | ]
1750 | },
1751 | {
1752 | "cell_type": "markdown",
1753 | "metadata": {},
1754 | "source": [
1755 | "In order to get a gradient of $\n",
1756 | " F_\\theta = \\mathbb{E}_{\\omega \\sim q(\\omega \\mid \\theta)} f(\\omega)\n",
1757 | "$ w.r.t $\\theta$ we use either:\n",
1758 | "\n",
1759 | "###### (REINFORCE)\n",
1760 | "$\n",
1761 | "\\nabla_\\theta F_\\theta\n",
1762 | " = \\mathbb{E}_{\\omega \\sim q(\\omega \\mid \\theta)}\n",
1763 | " (f(\\omega) - b_\\theta) \\nabla_\\theta \\log q(\\omega \\mid \\theta)\n",
1764 | "$\n",
1765 | "* for some $b_\\theta$ that is used to control variance\n",
1766 | "\n",
1767 | "###### (reparameterization)\n",
1768 | "$\n",
1769 | "\\nabla_\\theta F_\\theta\n",
1770 | " = \\nabla_\\theta \\mathbb{E}_{\\varepsilon \\sim q(\\varepsilon)}\n",
1771 | " f(g(\\theta; \\varepsilon))\n",
1772 | " = \\mathbb{E}_{\\varepsilon \\sim q(\\varepsilon)}\n",
1773 | " \\nabla_\\theta g(\\theta; \\varepsilon)\n",
1774 | " \\nabla_\\omega f(\\omega) \\big\\vert_{\\omega = g(\\theta; \\varepsilon)}\n",
1775 | "$\n",
1776 | "* when there are $q$ and differentiable $g$ such that sampling from\n",
1777 | "$q(\\omega \\mid \\theta)$ is equivalent to $\\omega = g(\\theta; \\varepsilon)$\n",
1778 | "with $\\varepsilon \\sim q(\\varepsilon)$."
1779 | ]
1780 | },
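{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of the reparameterization estimator for a Gaussian\n",
"$q(\\omega \\mid \\theta) = \\mathcal{N}(\\mu, \\sigma^2)$: writing\n",
"$\\omega = \\mu + \\sigma \\varepsilon$ with $\\varepsilon \\sim \\mathcal{N}(0, 1)$\n",
"lets autograd propagate $\\nabla_\\theta$ through the sample\n",
"(the variable names are illustrative)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"mu = torch.tensor(0.5, requires_grad=True)\n",
"log_sigma = torch.tensor(-1., requires_grad=True)\n",
"\n",
"def f(w):                          # any differentiable function of the parameter\n",
"    return (w - 2.) ** 2\n",
"\n",
"# omega = g(theta; eps) = mu + sigma * eps,  eps ~ N(0, 1)\n",
"eps = torch.randn(1000)\n",
"w = mu + torch.exp(log_sigma) * eps\n",
"\n",
"# MC estimate of E_q f(omega); backprop yields the reparameterization gradient\n",
"f(w).mean().backward()\n",
"\n",
"print(mu.grad, log_sigma.grad)     # noisy estimates of the gradient w.r.t. theta"
]
},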
1781 | {
1782 | "cell_type": "markdown",
1783 | "metadata": {},
1784 | "source": [
1785 | "The expectations under the variational approximation are high-dimensional integrals,\n",
1786 | "which are slow or prohibitive to compute exactly. To make the computations faster\n",
1787 | "without forgoing much of the precision, we may use Monte Carlo methods:\n",
1788 | "\n",
1789 | "$$\n",
1790 | " \\mathbb{E}_{\\omega \\sim q(\\omega\\mid \\theta)} \\, f(\\omega)\n",
1791 | " \\overset{\\text{MC}}{\\approx}\n",
1792 | " \\frac1{\\lvert \\mathcal{W}\\rvert}\n",
1793 | " \\sum_{\\omega \\in \\mathcal{W}} f(\\omega)\n",
1794 | " \\,,\n",
1795 | "$$\n",
1796 | "\n",
1797 | "where $\\mathcal{W} = (\\omega_b)_{b=1}^B$ is a sample of independent draws\n",
1798 | "from $q(\\omega\\mid \\theta)$."
1799 | ]
1800 | },
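{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick, purely illustrative sanity check: for $q = \\mathcal{N}(0, 1)$ and\n",
"$f(\\omega) = \\omega^2$ the Monte Carlo average should be close to\n",
"$\\mathbb{E}\\, f(\\omega) = 1$."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"torch.manual_seed(0)\n",
"w = torch.randn(10_000)     # a sample W of draws from q(w) = N(0, 1)\n",
"print((w ** 2).mean())      # MC estimate of E f(w) = 1, up to sampling noise"
]
},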
1801 | {
1802 | "cell_type": "markdown",
1803 | "metadata": {},
1804 | "source": [
1805 | "If we also approximate the expectation in the gradient of ELBO\n",
1806 | "via Monte Carlo we get **doubly stochastic variational objective**:\n",
1807 | "\n",
1808 | "$$\n",
1809 | " \\nabla_\\theta \\mathcal{L}_\\mathrm{DSVB}(\\theta, \\lambda)\n",
1810 | " \\approx\n",
1811 | " \\lvert D \\rvert \\biggl(\n",
1812 | " \\tfrac1{\\lvert B \\rvert}\n",
1813 | " \\sum_{b \\in B}\n",
1814 | " \\mathop{gradient}(x_b, y_b)\n",
1815 | " \\biggr)\n",
1816 | " - \\nabla_\\theta KL(q(\\omega \\mid \\theta) \\| \\pi(\\omega \\mid \\lambda))\n",
1817 | " \\,, $$\n",
1818 | "\n",
1819 | "where `gradient` is $\n",
1820 | " \\nabla_\\theta\n",
1821 | " \\mathbb{E}_{\\omega \\sim q(\\omega \\mid \\theta)}\n",
1822 | " \\log p(y \\mid x, \\omega)\n",
1823 | "$ using one of the approaches above, typically approximated using\n",
1824 | "one independent draw of $\\omega$ per $b\\in B$."
1825 | ]
1826 | },
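{
"cell_type": "markdown",
"metadata": {},
"source": [
"A sketch of one doubly stochastic training step in `pytorch`: the model,\n",
"the `kl_penalty` callable and the optimizer are placeholders, not the tutorial's\n",
"actual code. One draw of $\\omega$ enters the minibatch log-likelihood, the\n",
"result is rescaled by $\\lvert D \\rvert$, and autograd supplies the gradients\n",
"of both terms."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"def training_step(model, optim, X_batch, y_batch, n_total, kl_penalty):\n",
"    \"\"\"One doubly stochastic gradient step on the negative ELBO (a sketch).\n",
"\n",
"    `model(X)` is assumed to sample w ~ q(w | theta) internally (e.g. via\n",
"    dropout or the reparameterization trick), and `kl_penalty(model)` to\n",
"    return KL(q || pi) for the current variational parameters.\n",
"    \"\"\"\n",
"    # minibatch average of log p(y_b | x_b, w), a single draw of w\n",
"    loglik = - F.cross_entropy(model(X_batch), y_batch, reduction=\"mean\")\n",
"\n",
"    # negative stochastic estimate of the ELBO\n",
"    loss = - (n_total * loglik - kl_penalty(model))\n",
"\n",
"    optim.zero_grad()\n",
"    loss.backward()\n",
"    optim.step()\n",
"\n",
"    return loss.item()"
]
},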
1827 | {
1828 | "cell_type": "markdown",
1829 | "metadata": {},
1830 | "source": [
1831 | "We can use a similar sampling approach to compute the gradient of the divergence term."
1832 | ]
1833 | },
1834 | {
1835 | "cell_type": "markdown",
1836 | "metadata": {},
1837 | "source": [
1838 | "A good overview of Bayesian Inference can be found at [bdl101.ml](http://bdl101.ml/),\n",
1839 | "in [this lecture](http://mlg.eng.cam.ac.uk/zoubin/talks/lect1bayes.pdf),\n",
1840 | "[this paper](https://arxiv.org/abs/1206.7051.pdf), or\n",
1841 | "[this review](https://arxiv.org/abs/1601.00670.pdf),\n",
1842 | "among other great resources. It is also possible to consult\n",
1843 | "the references at [wiki](https://en.wikipedia.org/wiki/Bayesian_inference)."
1844 | ]
1845 | },
1846 | {
1847 | "cell_type": "markdown",
1848 | "metadata": {},
1849 | "source": [
1850 | "We can likewise estimate the divergence term in the ELBO\n",
1851 | "with Monte Carlo. For example, for the predictive distribution\n",
1852 | "we have\n",
1853 | "\n",
1854 | "$$\n",
1855 | "\\begin{align}\n",
1856 | " \\mathbb{E}_{y\\sim p(y\\mid x, D, m)} \\, g(y)\n",
1857 | " &\\overset{\\text{BI}}{=}\n",
1858 | " \\mathbb{E}_{\\omega\\sim p(\\omega \\mid D, m)}\n",
1859 | " \\mathbb{E}_{y\\sim p(y\\mid x, \\omega, D, m)} \\, g(y) \n",
1860 | " \\\\\n",
1861 | " &\\overset{\\text{VI}}{\\approx}\n",
1862 | " \\mathbb{E}_{\\omega\\sim q(\\omega)}\n",
1863 | " \\mathbb{E}_{y\\sim p(y\\mid x, \\omega, D, m)} \\, g(y)\n",
1864 | " \\\\\n",
1865 | " &\\overset{\\text{MC}}{\\approx}\n",
1866 | "% \\hat{\\mathbb{E}}_{\\omega \\sim \\mathcal{W}}\n",
1867 | "% \\mathbb{E}_{y\\sim p(y\\mid x, \\omega, D, m)} \\, g(y)\n",
1868 | " \\frac1{\\lvert \\mathcal{W}\\rvert} \\sum_{\\omega \\in \\mathcal{W}}\n",
1869 | " \\mathbb{E}_{y\\sim p(y\\mid x, \\omega, D, m)} \\, g(y)\n",
1870 | " \\,,\n",
1871 | "\\end{align}\n",
1872 | "$$\n",
1873 | "\n",
1874 | "where $\\mathcal{W} = (\\omega_b)_{b=1}^B \\sim q(\\omega)$\n",
1875 | "-- iid samples from the variational approximation."
1876 | ]
1877 | },
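{
"cell_type": "markdown",
"metadata": {},
"source": [
"These three steps (Bayes rule, the variational swap, Monte Carlo) translate\n",
"directly into code. The sketch below is schematic: `sample_weights` and\n",
"`inner_expectation` are placeholder callables, not functions of this tutorial."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def predictive_expectation(model, sample_weights, inner_expectation, x, n_draws=32):\n",
"    \"\"\"MC estimate of E_{y ~ p(y | x, D)} g(y) under q(w) (a sketch).\n",
"\n",
"    `sample_weights(model)` is assumed to draw w ~ q(w) and load it into the\n",
"    model; `inner_expectation(model, x)` to return E_{y ~ p(y | x, w)} g(y)\n",
"    for the currently loaded w.\n",
"    \"\"\"\n",
"    draws = []\n",
"    for _ in range(n_draws):\n",
"        sample_weights(model)                      # w ~ q(w)\n",
"        draws.append(inner_expectation(model, x))  # E_{y | x, w} g(y)\n",
"\n",
"    return torch.stack(draws).mean(dim=0)          # average over the sample W"
]
},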
1878 | {
1879 | "cell_type": "markdown",
1880 | "metadata": {},
1881 | "source": [
1882 | "
"
1883 | ]
1884 | }
1885 | ],
1886 | "metadata": {
1887 | "kernelspec": {
1888 | "display_name": "Python 3",
1889 | "language": "python",
1890 | "name": "python3"
1891 | },
1892 | "language_info": {
1893 | "codemirror_mode": {
1894 | "name": "ipython",
1895 | "version": 3
1896 | },
1897 | "file_extension": ".py",
1898 | "mimetype": "text/x-python",
1899 | "name": "python",
1900 | "nbconvert_exporter": "python",
1901 | "pygments_lexer": "ipython3",
1902 | "version": "3.7.4"
1903 | }
1904 | },
1905 | "nbformat": 4,
1906 | "nbformat_minor": 2
1907 | }
1908 |
--------------------------------------------------------------------------------
/Bayesian Deep Learning part2.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# MLSS2019: Bayesian Deep Learning"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "In this tutorial we will see how uncertainty estimation can be\n",
15 | "used in active learning or expert-in-the-loop pipelines."
16 | ]
17 | },
18 | {
19 | "cell_type": "markdown",
20 | "metadata": {},
21 | "source": [
22 | "The plan of the tutorial\n",
23 | "1. [Imports and definitions](#Imports-and-definitions)\n",
24 | "2. [Bayesian Active Learning with images](#Bayesian-Active-Learning-with-images)\n",
25 | " 1. [The model](#The-model)\n",
26 | " 2. [the Acquisition Function](#the-Acquisition-Function)\n",
27 | " 3. [Data and the Oracle](#Data-and-the-Oracle)\n",
28 | " 4. [the Active Learning loop](#the-Active-Learning-loop)\n",
29 | " 5. [The baseline](#The-baseline)\n",
30 | "3. [Bayesian Active Learning by Disagreement](#Bayesian-Active-Learning-by-Disagreement)\n",
31 | " 1. [Points of improvement: batch-vs-single](#Points-of-improvement:-batch-vs-single)\n",
32 | " 2. [Points of improvement: bias](#Points-of-improvement:-bias)\n"
33 | ]
34 | },
35 | {
36 | "cell_type": "markdown",
37 | "metadata": {},
38 | "source": [
39 | "**(note)**\n",
40 | "* to view documentation on something type in `something?` (with one question mark)\n",
41 | "* to view code of something type in `something??` (with two question marks)."
42 | ]
43 | },
44 | {
45 | "cell_type": "markdown",
46 | "metadata": {},
47 | "source": [
48 | "
"
49 | ]
50 | },
51 | {
52 | "cell_type": "markdown",
53 | "metadata": {},
54 | "source": [
55 | "## Imports and definitions"
56 | ]
57 | },
58 | {
59 | "cell_type": "markdown",
60 | "metadata": {},
61 | "source": [
62 | "In this section we import necessary modules and functions and\n",
63 | "define the computational device."
64 | ]
65 | },
66 | {
67 | "cell_type": "markdown",
68 | "metadata": {},
69 | "source": [
70 | "First, we install some boilerplate service code for this tutorial."
71 | ]
72 | },
73 | {
74 | "cell_type": "code",
75 | "execution_count": null,
76 | "metadata": {},
77 | "outputs": [],
78 | "source": [
79 | "!pip install -q --upgrade git+https://github.com/ivannz/mlss2019-bayesian-deep-learning.git"
80 | ]
81 | },
82 | {
83 | "cell_type": "markdown",
84 | "metadata": {},
85 | "source": [
86 | "Next, numpy for computing, matplotlib for plotting and tqdm for progress bars."
87 | ]
88 | },
89 | {
90 | "cell_type": "code",
91 | "execution_count": null,
92 | "metadata": {},
93 | "outputs": [],
94 | "source": [
95 | "import tqdm\n",
96 | "import numpy as np\n",
97 | "\n",
98 | "%matplotlib inline\n",
99 | "import matplotlib.pyplot as plt"
100 | ]
101 | },
102 | {
103 | "cell_type": "markdown",
104 | "metadata": {},
105 | "source": [
106 | "For deep learning stuff we will be using [pytorch](https://pytorch.org/).\n",
107 | "\n",
108 | "If you are unfamiliar with it, it is basically like `numpy` with autograd,\n",
109 | "native GPU support, and tools for building, training, and serializing models.\n",
110 | "\n",
111 | "\n",
112 | "There are good introductory tutorials on `pytorch`, like this\n",
113 | "[one](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html)."
114 | ]
115 | },
116 | {
117 | "cell_type": "code",
118 | "execution_count": null,
119 | "metadata": {},
120 | "outputs": [],
121 | "source": [
122 | "import torch\n",
123 | "import torch.nn.functional as F\n",
124 | "\n",
125 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
126 | ]
127 | },
128 | {
129 | "cell_type": "markdown",
130 | "metadata": {},
131 | "source": [
132 | "Next we import the boilerplate code.\n",
133 | "\n",
134 | "* a procedure that implements a minibatch SGD **fit** loop\n",
135 | "* a function that **evaluates** the model on the provided dataset"
136 | ]
137 | },
138 | {
139 | "cell_type": "code",
140 | "execution_count": null,
141 | "metadata": {},
142 | "outputs": [],
143 | "source": [
144 | "from mlss2019bdl import fit, predict"
145 | ]
146 | },
147 | {
148 | "cell_type": "markdown",
149 | "metadata": {},
150 | "source": [
151 | "The algorithm to sample a random function is:\n",
152 | "* for $b = 1... B$ do:\n",
153 | "\n",
154 | " 1. draw an independent realization $f_b\\colon \\mathcal{X} \\to \\mathcal{Y}$\n",
155 | "     from the process $\\{f_\\omega\\}_{\\omega \\sim q(\\omega)}$\n",
156 | " 2. get $\\hat{y}_{bi} = f_b(\\tilde{x}_i)$ for $i=1 .. m$\n"
157 | ]
158 | },
159 | {
160 | "cell_type": "code",
161 | "execution_count": null,
162 | "metadata": {},
163 | "outputs": [],
164 | "source": [
165 | "from mlss2019bdl.bdl import freeze, unfreeze\n",
166 | "\n",
167 | "def sample_function(model, dataset, n_draws=1, verbose=False):\n",
168 | " \"\"\"Draw a realization of a random function.\"\"\"\n",
169 | " outputs = []\n",
170 | " for _ in tqdm.tqdm(range(n_draws), disable=not verbose):\n",
171 | " freeze(model)\n",
172 | "\n",
173 | " outputs.append(predict(model, dataset))\n",
174 | "\n",
175 | " unfreeze(model)\n",
176 | "\n",
177 | " return torch.stack(outputs, dim=0)"
178 | ]
179 | },
180 | {
181 | "cell_type": "markdown",
182 | "metadata": {},
183 | "source": [
184 | "Sample the class probabilities $p(y_x = k \\mid x, \\omega, m)$\n",
185 | "with $\\omega \\sim q(\\omega)$ from a model that **outputs raw class\n",
186 | "logit scores**."
187 | ]
188 | },
189 | {
190 | "cell_type": "code",
191 | "execution_count": null,
192 | "metadata": {},
193 | "outputs": [],
194 | "source": [
195 | "def sample_proba(model, dataset, n_draws=1):\n",
196 | " logits = sample_function(model, dataset, n_draws=n_draws)\n",
197 | "\n",
198 | " return F.softmax(logits, dim=-1)"
199 | ]
200 | },
201 | {
202 | "cell_type": "markdown",
203 | "metadata": {},
204 | "source": [
205 | "Get the predictive posterior class probabilities\n",
206 | "$$\n",
207 | "p(y_x = k \\mid x, m)\n",
208 | "% = \\mathbb{E}_{\\omega \\sim q(\\omega)}\n",
209 | "% p(y_x = k \\mid x, \\omega, m)\n",
210 | " \\approx \\frac1{\\lvert \\mathcal{W} \\rvert}\n",
211 | " \\sum_{\\omega \\in \\mathcal{W}}\n",
212 | " p(y_x = k \\mid x, \\omega, m)\n",
213 | " \\,, $$\n",
214 | "with $\\mathcal{W}$ -- iid draws from $q(\\omega)$."
215 | ]
216 | },
217 | {
218 | "cell_type": "code",
219 | "execution_count": null,
220 | "metadata": {},
221 | "outputs": [],
222 | "source": [
223 | "def predict_proba(model, dataset, n_draws=1):\n",
224 | " proba = sample_proba(model, dataset, n_draws=n_draws)\n",
225 | "\n",
226 | " return proba.mean(dim=0)"
227 | ]
228 | },
229 | {
230 | "cell_type": "markdown",
231 | "metadata": {},
232 | "source": [
233 | "Get the maximum a posteriori class label **(MAP)**: $\n",
234 | "\\hat{y}_x\n",
235 | " = \\arg \\max_k \\mathbb{E}_{\\omega \\sim q(\\omega)}\n",
236 | " p(y_x = k \\mid x, \\omega, m)\n",
237 | "$"
238 | ]
239 | },
240 | {
241 | "cell_type": "code",
242 | "execution_count": null,
243 | "metadata": {},
244 | "outputs": [],
245 | "source": [
246 | "def predict_label(model, dataset, n_draws=1):\n",
247 | " proba = predict_proba(model, dataset, n_draws=n_draws)\n",
248 | "\n",
249 | " return proba.argmax(dim=-1)"
250 | ]
251 | },
252 | {
253 | "cell_type": "markdown",
254 | "metadata": {},
255 | "source": [
256 | "We will need some functionality from scikit-learn"
257 | ]
258 | },
259 | {
260 | "cell_type": "code",
261 | "execution_count": null,
262 | "metadata": {},
263 | "outputs": [],
264 | "source": [
265 | "from sklearn.metrics import confusion_matrix\n",
266 | "\n",
267 | "def evaluate(model, dataset, n_draws=1):\n",
268 | " assert isinstance(dataset, TensorDataset)\n",
269 | "\n",
270 | " predicted = predict_label(model, dataset, n_draws=n_draws)\n",
271 | "\n",
272 | " target = dataset.tensors[1].cpu().numpy()\n",
273 | " return confusion_matrix(target, predicted.cpu().numpy())"
274 | ]
275 | },
276 | {
277 | "cell_type": "markdown",
278 | "metadata": {},
279 | "source": [
280 | "A function to plot images in a small dataset. "
281 | ]
282 | },
283 | {
284 | "cell_type": "code",
285 | "execution_count": null,
286 | "metadata": {},
287 | "outputs": [],
288 | "source": [
289 | "from mlss2019bdl.flex import plot\n",
290 | "from torch.utils.data import TensorDataset\n",
291 | "from IPython.display import clear_output\n",
292 | "\n",
293 | "def display(images, n_col=None, title=None, figsize=None, refresh=False):\n",
294 | " if isinstance(images, TensorDataset):\n",
295 | " images, targets = images.tensors\n",
296 | " \n",
297 | " if refresh:\n",
298 | " clear_output(True)\n",
299 | "\n",
300 | " fig, ax = plt.subplots(1, 1, figsize=figsize)\n",
301 | " plot(ax, images, n_col=n_col, cmap=plt.cm.bone)\n",
302 | " if title is not None:\n",
303 | " ax.set_title(title)\n",
304 | "\n",
305 | " plt.show()\n",
306 | " plt.close()"
307 | ]
308 | },
309 | {
310 | "cell_type": "markdown",
311 | "metadata": {},
312 | "source": [
313 | "
"
314 | ]
315 | },
316 | {
317 | "cell_type": "markdown",
318 | "metadata": {},
319 | "source": [
320 | "## Bayesian Active Learning with images"
321 | ]
322 | },
323 | {
324 | "cell_type": "markdown",
325 | "metadata": {},
326 | "source": [
327 | "* Data labelling is costly and time consuming\n",
328 | "* unlabeled instances are essentially free\n",
329 | "\n",
330 | "**Goal** Achieve high performance with fewer labels by\n",
331 | "identifying the best instances to learn from"
332 | ]
333 | },
334 | {
335 | "cell_type": "markdown",
336 | "metadata": {},
337 | "source": [
338 | "Essential blocks of active learning:\n",
339 | "\n",
340 | "* a **model** $m$ capable of quantifying uncertainty (preferably a Bayesian model)\n",
341 | "* an **acquisition function** $a\\colon \\mathcal{M} \\times \\mathcal{X}^* \\to \\mathbb{R}$\n",
342 | " that for any finite set of inputs $S\\subset \\mathcal{X}$ quantifies their usefulness\n",
343 | " to the model $m\\in \\mathcal{M}$\n",
344 | "* a labelling **oracle**, e.g. a human expert"
345 | ]
346 | },
347 | {
348 | "cell_type": "markdown",
349 | "metadata": {},
350 | "source": [
351 | "### The model"
352 | ]
353 | },
354 | {
355 | "cell_type": "markdown",
356 | "metadata": {},
357 | "source": [
358 | "We reuse the `DropoutLinear` and `DropoutConv2d` layers from the first part."
359 | ]
360 | },
361 | {
362 | "cell_type": "code",
363 | "execution_count": null,
364 | "metadata": {},
365 | "outputs": [],
366 | "source": [
367 | "from torch.nn import Module, Sequential\n",
368 | "from torch.nn import AvgPool2d, LeakyReLU\n",
369 | "from torch.nn import Linear, Conv2d\n",
370 | "\n",
371 | "from mlss2019bdl.bdl import DropoutLinear, DropoutConv2d\n",
372 | "\n",
373 | "class MNISTModel(Module):\n",
374 | " def __init__(self, p=0.5):\n",
375 | " super().__init__()\n",
376 | "\n",
377 | " self.head = Sequential(\n",
378 | " Conv2d(1, 32, 3, 1),\n",
379 | " LeakyReLU(),\n",
380 | " DropoutConv2d(32, 64, 3, 1, p=p),\n",
381 | " LeakyReLU(),\n",
382 | " AvgPool2d(2),\n",
383 | " )\n",
384 | "\n",
385 | " self.tail = Sequential(\n",
386 | " DropoutLinear(12 * 12 * 64, 128, p=p),\n",
387 | " LeakyReLU(),\n",
388 | " DropoutLinear(128, 10, p=p),\n",
389 | " )\n",
390 | "\n",
391 | " def forward(self, input):\n",
392 | " \"\"\"Take images and compute their class logits.\"\"\"\n",
393 | " x = self.head(input)\n",
394 | " return self.tail(x.flatten(1))"
395 | ]
396 | },
397 | {
398 | "cell_type": "markdown",
399 | "metadata": {},
400 | "source": [
401 | "
"
402 | ]
403 | },
404 | {
405 | "cell_type": "markdown",
406 | "metadata": {},
407 | "source": [
408 | "### the Acquisition Function"
409 | ]
410 | },
411 | {
412 | "cell_type": "markdown",
413 | "metadata": {},
414 | "source": [
415 | "There are many acquisition criteria (borrowed from [Gal17a](http://proceedings.mlr.press/v70/gal17a.html)):\n",
416 | "* Classification\n",
417 | " * Posterior predictive entropy\n",
418 | " * Posterior Mutual Information\n",
419 | " * Variance ratios\n",
420 | " * BALD\n",
421 | "\n",
422 | "* Regression\n",
423 | " * predictive variance\n",
424 | "\n",
425 | "... and there is always the baseline **random acquisition**"
426 | ]
427 | },
428 | {
429 | "cell_type": "code",
430 | "execution_count": null,
431 | "metadata": {},
432 | "outputs": [],
433 | "source": [
434 | "random_state = np.random.RandomState(812_760_351)\n",
435 | "\n",
436 | "def random_acquisition(dataset, model, n_request=1, n_draws=1):\n",
437 | "    indices = random_state.choice(len(dataset), size=n_request, replace=False)\n",
438 | "\n",
439 | " return torch.from_numpy(indices).to(device)"
440 | ]
441 | },
442 | {
443 | "cell_type": "markdown",
444 | "metadata": {},
445 | "source": [
446 | "
"
447 | ]
448 | },
449 | {
450 | "cell_type": "markdown",
451 | "metadata": {},
452 | "source": [
453 | "### Data and the Oracle"
454 | ]
455 | },
456 | {
457 | "cell_type": "markdown",
458 | "metadata": {},
459 | "source": [
460 | "Prepare the datasets from the `train` part of\n",
461 | "[MNIST](http://yann.lecun.com/exdb/mnist/)\n",
462 | "(or [Kuzushiji-MNIST](https://github.com/rois-codh/kmnist)):\n",
463 | "* ($\\mathcal{S}_\\mathrm{train}$) initial **training**: $30$ images\n",
464 | "* ($\\mathcal{S}_\\mathrm{valid}$) our **validation**:\n",
465 | " $5000$ images, stratified\n",
466 | "* ($\\mathcal{S}_\\mathrm{pool}$) acquisition **pool**:\n",
467 | " $5000$ of the unused images, skewed to class $0$\n",
468 | "\n",
469 | "The true test sample of MNIST is in $\\mathcal{S}_\\mathrm{test}$ -- we\n",
470 | "will use it to evaluate the final performance."
471 | ]
472 | },
473 | {
474 | "cell_type": "code",
475 | "execution_count": null,
476 | "metadata": {},
477 | "outputs": [],
478 | "source": [
479 | "from mlss2019bdl.dataset import get_dataset\n",
480 | "\n",
481 | "S_train, S_pool, S_valid, S_test = get_dataset(\n",
482 | " n_train=30,\n",
483 | " n_valid=5000,\n",
484 | " n_pool=5000,\n",
485 | " name=\"MNIST\", # \"KMNIST\"\n",
486 | " path=\"./data\",\n",
487 | " random_state=722_257_201)"
488 | ]
489 | },
490 | {
491 | "cell_type": "markdown",
492 | "metadata": {},
493 | "source": [
494 | "* `query_oracle(ix, D)` **requests** the instances of `D` at the specified\n",
495 | "  indices `ix`, collecting them into a new dataset and **removing** them from `D`\n",
496 | "\n",
497 | "* `merge(*datasets, [out=])` merges the datasets, creating a new one, or replacing `out`"
498 | ]
499 | },
500 | {
501 | "cell_type": "code",
502 | "execution_count": null,
503 | "metadata": {},
504 | "outputs": [],
505 | "source": [
506 | "from mlss2019bdl.dataset import collect as query_oracle"
507 | ]
508 | },
509 | {
510 | "cell_type": "markdown",
511 | "metadata": {},
512 | "source": [
513 | "
"
514 | ]
515 | },
516 | {
517 | "cell_type": "markdown",
518 | "metadata": {},
519 | "source": [
520 | "### the Active Learning loop"
521 | ]
522 | },
523 | {
524 | "cell_type": "markdown",
525 | "metadata": {},
526 | "source": [
527 | "1. fit $m$ on $\\mathcal{S}_{\\mathrm{labelled}}$\n",
528 | "\n",
529 | "\n",
530 | "2. get exact (or approximate) $$\n",
531 | " \\mathcal{S}^* \\in \\arg \\max\\limits_{S \\subseteq \\mathcal{S}_\\mathrm{unlabelled}}\n",
532 | " a(m, S)\n",
533 | "$$ satisfying **budget constraints** and **without** access to targets\n",
534 | "(constraints, like $\\lvert S \\rvert \\leq \\ell$ or other economically motivated ones).\n",
535 | "\n",
536 | "\n",
537 | "3. request the **oracle** to provide labels for each $x\\in \\mathcal{S}^*$\n",
538 | "\n",
539 | "\n",
540 | "4. update $\n",
541 | "\\mathcal{S}_{\\mathrm{labelled}}\n",
542 | " \\leftarrow \\mathcal{S}^*\n",
543 | " \\cup \\mathcal{S}_{\\mathrm{labelled}}\n",
544 | "$ and goto 1."
545 | ]
546 | },
547 | {
548 | "cell_type": "code",
549 | "execution_count": null,
550 | "metadata": {},
551 | "outputs": [],
552 | "source": [
553 | "import copy\n",
554 | "from mlss2019bdl.dataset import merge\n",
555 | "\n",
556 | "def active_learn(S_train,\n",
557 | " S_pool,\n",
558 | " S_valid,\n",
559 | " acquire_fn,\n",
560 | " n_budget=150,\n",
561 | " n_max_request=3,\n",
562 | " n_draws=11,\n",
563 | " n_epochs=200,\n",
564 | " p=0.5,\n",
565 | " weight_decay=1e-2):\n",
566 | "\n",
567 | " model = MNISTModel(p=p).to(device)\n",
568 | "\n",
569 | " scores, balances = [], []\n",
570 | " S_train, S_pool = copy.deepcopy(S_train), copy.deepcopy(S_pool)\n",
571 | " while True:\n",
572 | " # 1. fit on train\n",
573 | " l2_reg = weight_decay * (1 - p) / max(len(S_train), 1)\n",
574 | "\n",
575 | " model = fit(model, S_train, batch_size=32, criterion=\"cross_entropy\",\n",
576 | " weight_decay=l2_reg, n_epochs=n_epochs)\n",
577 | "\n",
578 | "\n",
579 | " # (optional) keep track of scores and plot the train dataset\n",
580 | " scores.append(evaluate(model, S_valid, n_draws))\n",
581 | " balances.append(np.bincount(S_train.tensors[1], minlength=10))\n",
582 | "\n",
583 | " accuracy = scores[-1].diagonal().sum() / scores[-1].sum()\n",
584 | " title = f\"(n_train) {len(S_train)} (Acc.) {accuracy:.1%}\"\n",
585 | " display(S_train, n_col=30, figsize=(15, 5), title=title, refresh=True)\n",
586 | "\n",
587 | "\n",
588 | " # 2-3. request new data from pool, if within budget\n",
589 | " n_request = min(n_budget - len(S_train), n_max_request)\n",
590 | " if n_request <= 0:\n",
591 | " break\n",
592 | "\n",
593 | " indices = acquire_fn(S_pool, model, n_request=n_request, n_draws=n_draws)\n",
594 | "\n",
595 | " # 4. update the train dataset\n",
596 | " S_requested = query_oracle(indices, S_pool)\n",
597 | " S_train = merge(S_train, S_requested)\n",
598 | "\n",
599 | " return model, S_train, np.stack(scores, axis=0), np.stack(balances, axis=0)"
600 | ]
601 | },
612 | {
613 | "cell_type": "markdown",
614 | "metadata": {},
615 | "source": [
616 | "
"
617 | ]
618 | },
619 | {
620 | "cell_type": "markdown",
621 | "metadata": {},
622 | "source": [
623 | "### The baseline"
624 | ]
625 | },
626 | {
627 | "cell_type": "markdown",
628 | "metadata": {},
629 | "source": [
630 | "How powerful will our model with random acquisition get under a total budget of $150$ images?"
631 | ]
632 | },
633 | {
634 | "cell_type": "code",
635 | "execution_count": null,
636 | "metadata": {
637 | "scrolled": false
638 | },
639 | "outputs": [],
640 | "source": [
641 | "baseline = active_learn(\n",
642 | " S_train,\n",
643 | " S_pool,\n",
644 | " S_valid,\n",
645 | " random_acquisition,\n",
646 | " n_draws=21,\n",
647 | " n_budget=150,\n",
648 | " n_max_request=3,\n",
649 | " n_epochs=200,\n",
650 | ")"
651 | ]
652 | },
653 | {
654 | "cell_type": "markdown",
655 | "metadata": {},
656 | "source": [
657 | "Let's see the dynamics of the accuracy ..."
658 | ]
659 | },
660 | {
661 | "cell_type": "code",
662 | "execution_count": null,
663 | "metadata": {},
664 | "outputs": [],
665 | "source": [
666 | "def accuracy(scores):\n",
667 | " tp = scores.diagonal(axis1=-2, axis2=-1)\n",
668 | " return tp.sum(-1) / scores.sum((-2, -1))"
669 | ]
670 | },
671 | {
672 | "cell_type": "code",
673 | "execution_count": null,
674 | "metadata": {},
675 | "outputs": [],
676 | "source": [
677 | "model_rand, train_rand, scores_rand, balances_rand = baseline\n",
678 | "\n",
679 | "fig, ax = plt.subplots(1, 1, figsize=(12, 7))\n",
680 | "ax.plot(accuracy(scores_rand), label='Accuracy (random)', lw=2)\n",
681 | "\n",
682 | "ax.legend()\n",
683 | "plt.show()"
684 | ]
685 | },
686 | {
687 | "cell_type": "markdown",
688 | "metadata": {},
689 | "source": [
690 | "..., and the frequency of each class in $\\mathcal{S}_\\mathrm{train}$."
691 | ]
692 | },
693 | {
694 | "cell_type": "code",
695 | "execution_count": null,
696 | "metadata": {},
697 | "outputs": [],
698 | "source": [
699 | "fig, ax = plt.subplots(1, 1, figsize=(12, 7))\n",
700 | "\n",
701 | "lines = ax.plot(balances_rand, lw=2)\n",
702 | "plt.legend(lines, list(range(10)), ncol=2);"
703 | ]
704 | },
705 | {
706 | "cell_type": "markdown",
707 | "metadata": {},
708 | "source": [
709 | "
"
710 | ]
711 | },
712 | {
713 | "cell_type": "markdown",
714 | "metadata": {},
715 | "source": [
716 | "## Bayesian Active Learning by Disagreement"
717 | ]
718 | },
719 | {
720 | "cell_type": "markdown",
721 | "metadata": {},
722 | "source": [
723 | "Bayesian Active Learning by Disagreement, or **BALD** criterion, is\n",
724 | "based on the posterior mutual information between model's predictions\n",
725 | "$y_x$ at some point $x$ and its parameters $\\omega$:\n",
726 | "\n",
727 | "$$\\begin{align}\n",
728 | " a(m, S)\n",
729 | " &= \\sum_{x\\in S} a(m, \\{x\\})\n",
730 | " \\\\\n",
731 | " a(m, \\{x\\})\n",
732 | " &= \\mathbb{I}(y_x; \\omega \\mid x, m, D)\n",
733 | "\\end{align}\n",
734 | " \\,, \\tag{bald} $$\n",
735 | "\n",
736 | "with the [**Mutual Information**](https://en.wikipedia.org/wiki/Mutual_information#Relation_to_Kullback%E2%80%93Leibler_divergence)\n",
737 | "(**MI**)\n",
738 | "$$\n",
739 | " \\mathbb{I}(y_x; \\omega \\mid x, m, D)\n",
740 | " = \\mathbb{H}\\bigl(\n",
741 | " \\mathbb{E}_{\\omega \\sim q(\\omega\\mid m, D)}\n",
742 | " p(y_x \\,\\mid\\, x, \\omega, m, D)\n",
743 | " \\bigr)\n",
744 | " - \\mathbb{E}_{\\omega \\sim q(\\omega\\mid m, D)}\n",
745 | " \\mathbb{H}\\bigl(\n",
746 | " p(y_x \\,\\mid\\, x, \\omega, m, D)\n",
747 | " \\bigr)\n",
748 | " \\,, \\tag{mi} $$\n",
749 | "\n",
750 | "and the [(differential) **entropy**](https://en.wikipedia.org/wiki/Differential_entropy#Differential_entropies_for_various_distributions)\n",
751 | "(all densities and/or probability mass functions can be conditional):\n",
752 | "\n",
753 | "$$\n",
754 | " \\mathbb{H}(p(y))\n",
755 | " = - \\mathbb{E}_{y\\sim p} \\log p(y)\n",
756 | " \\,. $$"
757 | ]
758 | },
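{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a small self-contained illustration (not one of the exercises) of what\n",
"the mutual information rewards: weight draws that are individually confident\n",
"but disagree with each other."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def entropy(p):\n",
"    # entropy along the last axis; the tiny constant guards against log(0)\n",
"    return - (p * torch.log(p + 1e-20)).sum(dim=-1)\n",
"\n",
"# two weight draws, one input point, two classes\n",
"agree = torch.tensor([[[0.9, 0.1]], [[0.9, 0.1]]])     # confident and consistent\n",
"disagree = torch.tensor([[[0.9, 0.1]], [[0.1, 0.9]]])  # confident but contradictory\n",
"\n",
"for name, p in [(\"agree\", agree), (\"disagree\", disagree)]:\n",
"    mi = entropy(p.mean(dim=0)) - entropy(p).mean(dim=0)\n",
"    print(name, mi)  # ~0 under agreement, large under disagreement"
]
},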
759 | {
760 | "cell_type": "markdown",
761 | "metadata": {},
762 | "source": [
763 | "
"
764 | ]
765 | },
766 | {
767 | "cell_type": "markdown",
768 | "metadata": {},
769 | "source": [
770 | "#### (task) Implementing the acquisition function"
771 | ]
772 | },
773 | {
774 | "cell_type": "markdown",
775 | "metadata": {},
776 | "source": [
777 | "Note that $a(m, S)$ is additively separable in $S$, i.e.\n",
778 | "equals $\\sum_{x\\in S} a(m, \\{x\\})$. This implies\n",
779 | "\n",
780 | "$$\n",
781 | "\\begin{align}\n",
782 | " \\max_{S \\subseteq \\mathcal{S}_\\mathrm{unlabelled}} a(m, S)\n",
783 | " &= \\max_{z \\in \\mathcal{S}_\\mathrm{unlabelled}}\n",
784 | "        \\max_{F \\subseteq \\mathcal{S}_\\mathrm{unlabelled} \\setminus \\{z\\}}\n",
785 | "    \\sum_{x\\in F \\cup \\{z\\}} a(m, \\{x\\})\n",
786 | "    \\\\\n",
787 | "    &= \\max_{z \\in \\mathcal{S}_\\mathrm{unlabelled}}\n",
788 | "        a(m, \\{z\\})\n",
789 | "        + \\max_{F \\subseteq \\mathcal{S}_\\mathrm{unlabelled} \\setminus \\{z\\}}\n",
790 | " \\sum_{x\\in F} a(m, \\{x\\})\n",
791 | "\\end{align}\n",
792 | " \\,. $$"
793 | ]
794 | },
795 | {
796 | "cell_type": "markdown",
797 | "metadata": {},
798 | "source": [
799 | "Therefore selecting the $\\ell$ most interesting points from\n",
800 | "$\\mathcal{S}_\\mathrm{unlabelled}$ is trivial: just take the points with the $\\ell$ largest values of $a(m, \\{x\\})$."
801 | ]
802 | },
803 | {
804 | "cell_type": "markdown",
805 | "metadata": {},
806 | "source": [
807 | "The acquisition function that we implement has interface\n",
808 | "identical to `random_acquisition` but uses BALD to choose\n",
809 | "instances."
810 | ]
811 | },
812 | {
813 | "cell_type": "code",
814 | "execution_count": null,
815 | "metadata": {},
816 | "outputs": [],
817 | "source": [
818 | "def BALD_acquisition(dataset, model, n_request=1, n_draws=1):\n",
819 | " ## Exercise: implement BALD\n",
820 | "\n",
821 | " proba = sample_proba(model, dataset, n_draws=n_draws)\n",
822 | "\n",
823 | " mi = mutual_information(proba)\n",
824 | "\n",
825 | " return mi.argsort(descending=True)[:n_request]\n",
826 | "\n",
827 | " pass"
828 | ]
829 | },
830 | {
831 | "cell_type": "markdown",
832 | "metadata": {},
833 | "source": [
834 | "
"
835 | ]
836 | },
837 | {
838 | "cell_type": "markdown",
839 | "metadata": {},
840 | "source": [
841 | "#### (task) implementing entropy"
842 | ]
843 | },
844 | {
845 | "cell_type": "markdown",
846 | "metadata": {},
847 | "source": [
848 | "For categorical (discrete) random variables $y \\sim \\mathcal{Cat}(\\mathbf{p})$,\n",
849 | "$\\mathbf{p} \\in \\{ \\mu \\in [0, 1]^d \\colon \\sum_k \\mu_k = 1\\}$, the entropy is\n",
850 | "\n",
851 | "$$\n",
852 | " \\mathbb{H}(p(y))\n",
853 | " = - \\mathbb{E}_{y\\sim p(y)} \\log p(y)\n",
854 | " = - \\sum_k p_k \\log p_k\n",
855 | " \\,. $$"
856 | ]
857 | },
858 | {
859 | "cell_type": "markdown",
860 | "metadata": {},
861 | "source": [
862 | "**(note)** although in calculus $0 \\cdot \\log 0 = 0$ (because\n",
863 | "$\\lim_{p\\downarrow 0} p \\cdot \\log p = 0$), in floating point\n",
864 | "arithmetic $0 \\cdot \\log 0 = \\mathrm{NaN}$. So you need to add\n",
865 | "some **really tiny float number** to the argument of $\\log$."
866 | ]
867 | },
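{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick demonstration of this note (illustrative only):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"p = torch.tensor([1., 0.])\n",
"print(- (p * torch.log(p)).sum())          # nan: 0 * log(0) is undefined here\n",
"print(- (p * torch.log(p + 1e-20)).sum())  # ~0, the correct entropy of [1, 0]"
]
},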
868 | {
869 | "cell_type": "code",
870 | "execution_count": null,
871 | "metadata": {},
872 | "outputs": [],
873 | "source": [
874 | "def categorical_entropy(proba):\n",
875 | " \"\"\"Compute the entropy along the last dimension.\"\"\"\n",
876 | "\n",
877 | " ## Exercise: the probabilities sum to one along the last axis.\n",
878 | " # Please, compute their entropy.\n",
879 | "\n",
880 | "    # the tiny constant inside the log avoids 0 * log 0 = NaN\n",
881 | "    return - torch.sum(proba * torch.log(proba + 1e-20), dim=-1)\n",
882 | "\n",
883 | "    # (torch.kl_div could be used too, but mind its default 'mean' reduction and exact zeros)\n",
884 | "    pass"
885 | ]
886 | },
887 | {
888 | "cell_type": "markdown",
889 | "metadata": {},
890 | "source": [
891 | "
"
892 | ]
893 | },
894 | {
895 | "cell_type": "markdown",
896 | "metadata": {},
897 | "source": [
898 | "#### (task) implementing mutual information"
899 | ]
900 | },
901 | {
902 | "cell_type": "markdown",
903 | "metadata": {},
904 | "source": [
905 | "Consider a tensor $p_{bik}$ of probabilities $p(y_{x_i}=k \\mid x_i, \\omega_b, m, D)$\n",
906 | "where $\\mathcal{W} = (\\omega_b)_{b=1}^B$ are\n",
907 | "iid draws from $q(\\omega \\mid m, D)$.\n",
908 | "\n",
909 | "Let's implement a procedure that computes the Monte Carlo estimate of the\n",
910 | "posterior predictive distribution, its **entropy** and **mutual information**\n",
911 | "\n",
912 | "$$\n",
913 | " \\mathbb{I}_\\mathrm{MC}(y_x; \\omega \\mid x, m, D)\n",
914 | " = \\mathbb{H}\\bigl(\n",
915 | " \\hat{p}(y_x\\mid x, m, D)\n",
916 | " \\bigr)\n",
917 | " - \\frac1{\\lvert \\mathcal{W} \\rvert} \\sum_{\\omega\\in \\mathcal{W}}\n",
918 | " \\mathbb{H}\\bigl(\n",
919 | " p(y_x \\,\\mid\\, x, \\omega, m, D)\n",
920 | " \\bigr)\n",
921 | " \\,, \\tag{mi-mc} $$\n",
922 | "where\n",
923 | "$$\n",
924 | "\\hat{p}(y_x\\mid x, m, D)\n",
925 | " = \\frac1{\\lvert \\mathcal{W} \\rvert} \\sum_{\\omega\\in \\mathcal{W}}\n",
926 | " \\,p(y_x \\mid x, \\omega, m, D)\n",
927 | " \\,. $$"
928 | ]
929 | },
930 | {
931 | "cell_type": "code",
932 | "execution_count": null,
933 | "metadata": {},
934 | "outputs": [],
935 | "source": [
936 | "def mutual_information(proba):\n",
937 | " ## Exercise: compute a Monte Carlo estimator of the predictive\n",
938 | " ## distribution, its entropy and MI `H E_w p(., w) - E_w H p(., w)`\n",
939 | "\n",
940 | " entropy_expected = categorical_entropy(proba.mean(dim=0))\n",
941 | " expected_entropy = categorical_entropy(proba).mean(dim=0)\n",
942 | "\n",
943 | " return entropy_expected - expected_entropy\n",
944 | "\n",
945 | " pass"
946 | ]
947 | },
948 | {
949 | "cell_type": "markdown",
950 | "metadata": {},
951 | "source": [
952 | "
"
953 | ]
954 | },
955 | {
956 | "cell_type": "markdown",
957 | "metadata": {},
958 | "source": [
959 | "How powerful will our model get with **BALD** acquisition, if we can afford no more than $150$ images?"
960 | ]
961 | },
962 | {
963 | "cell_type": "code",
964 | "execution_count": null,
965 | "metadata": {
966 | "scrolled": false
967 | },
968 | "outputs": [],
969 | "source": [
970 | "bald_results = active_learn(\n",
971 | " S_train,\n",
972 | " S_pool,\n",
973 | " S_valid,\n",
974 | " BALD_acquisition,\n",
975 | " n_draws=21,\n",
976 | " n_budget=150,\n",
977 | " n_max_request=3,\n",
978 | " n_epochs=200,\n",
979 | ")"
980 | ]
981 | },
982 | {
983 | "cell_type": "markdown",
984 | "metadata": {},
985 | "source": [
986 | "Let's see the dynamics of the accuracy ..."
987 | ]
988 | },
989 | {
990 | "cell_type": "code",
991 | "execution_count": null,
992 | "metadata": {},
993 | "outputs": [],
994 | "source": [
995 | "model_bald, train_bald, scores_bald, balances_bald = bald_results\n",
996 | "\n",
997 | "fig, ax = plt.subplots(1, 1, figsize=(12, 7))\n",
998 | "\n",
999 | "ax.plot(accuracy(scores_rand), label='Accuracy (random)', lw=2)\n",
1000 | "ax.plot(accuracy(scores_bald), label='Accuracy (BALD)', lw=2)\n",
1001 | "\n",
1002 | "ax.legend()\n",
1003 | "plt.show()"
1004 | ]
1005 | },
1006 | {
1007 | "cell_type": "markdown",
1008 | "metadata": {},
1009 | "source": [
1010 | "..., and the frequency of each class in $\\mathcal{S}_\\mathrm{train}$."
1011 | ]
1012 | },
1013 | {
1014 | "cell_type": "code",
1015 | "execution_count": null,
1016 | "metadata": {},
1017 | "outputs": [],
1018 | "source": [
1019 | "fig, ax = plt.subplots(1, 1, figsize=(12, 7))\n",
1020 | "\n",
1021 | "lines = ax.plot(balances_bald, lw=2)\n",
1022 | "plt.legend(lines, list(range(10)), ncol=2);"
1023 | ]
1024 | },
1025 | {
1026 | "cell_type": "markdown",
1027 | "metadata": {},
1028 | "source": [
1029 | "
"
1030 | ]
1031 | },
1032 | {
1033 | "cell_type": "markdown",
1034 | "metadata": {},
1035 | "source": [
1036 | "#### Class performance"
1037 | ]
1038 | },
1039 | {
1040 | "cell_type": "markdown",
1041 | "metadata": {},
1042 | "source": [
1043 | "The *one-versus-rest* precision / recall scores on\n",
1044 | "$\\mathcal{S}_\\mathrm{test}$. For binary classification:\n",
1045 | "\n",
1046 | "$$ \\begin{align}\n",
1047 | "\\mathrm{Precision}\n",
1048 | " &= \\frac{\\mathrm{TP}}{\\mathrm{TP} + \\mathrm{FP}}\n",
1049 | " \\approx \\mathbb{P}(y = 1 \\mid \\hat{y} = 1)\n",
1050 | " \\,, \\\\\n",
1051 | "\\mathrm{Recall}\n",
1052 | " &= \\frac{\\mathrm{TP}}{\\mathrm{TP} + \\mathrm{FN}}\n",
1053 | " \\approx \\mathbb{P}(\\hat{y} = 1 \\mid y = 1)\n",
1054 | " \\,.\n",
1055 | "\\end{align}$$"
1056 | ]
1057 | },
1058 | {
1059 | "cell_type": "code",
1060 | "execution_count": null,
1061 | "metadata": {},
1062 | "outputs": [],
1063 | "source": [
1064 | "import pandas as pd\n",
1065 | "\n",
1066 | "def pr_scores(score_matrix):\n",
1067 | " tp = score_matrix.diagonal(axis1=-2, axis2=-1)\n",
1068 | " fp, fn = score_matrix.sum(axis=-2) - tp, score_matrix.sum(axis=-1) - tp\n",
1069 | " \n",
1070 | " return pd.DataFrame({\n",
1071 | " \"precision\": {l: f\"{p:.2%}\" for l, p in enumerate(tp / (tp + fp))},\n",
1072 | " \"recall\": {l: f\"{p:.2%}\" for l, p in enumerate(tp / (tp + fn))},\n",
1073 | " })"
1074 | ]
1075 | },
1076 | {
1077 | "cell_type": "markdown",
1078 | "metadata": {},
1079 | "source": [
1080 | "Let's see the performance on the test set"
1081 | ]
1082 | },
1083 | {
1084 | "cell_type": "code",
1085 | "execution_count": null,
1086 | "metadata": {},
1087 | "outputs": [],
1088 | "source": [
1089 | "scores = {}\n",
1090 | "scores[\"rand\"] = evaluate(model_rand, S_test, n_draws=21)\n",
1091 | "scores[\"bald\"] = evaluate(model_bald, S_test, n_draws=21)"
1092 | ]
1093 | },
1094 | {
1095 | "cell_type": "markdown",
1096 | "metadata": {},
1097 | "source": [
1098 | "
"
1099 | ]
1100 | },
1101 | {
1102 | "cell_type": "code",
1103 | "execution_count": null,
1104 | "metadata": {},
1105 | "outputs": [],
1106 | "source": [
1107 | "df = pd.concat({\n",
1108 | " name: pr_scores(score)\n",
1109 | " for name, score in scores.items()\n",
1110 | "}, axis=1).T\n",
1111 | "\n",
1112 | "df.swaplevel().sort_index()"
1113 | ]
1114 | },
1115 | {
1116 | "cell_type": "markdown",
1117 | "metadata": {},
1118 | "source": [
1119 | "
"
1120 | ]
1121 | },
1122 | {
1123 | "cell_type": "markdown",
1124 | "metadata": {},
1125 | "source": [
1126 | "#### Question(s) (to work on in your spare time)\n",
1127 | "\n",
1128 | "* Run the experiments on the `KMNIST` dataset\n",
1129 | "\n",
1130 | "* Replicate figure 1 from [Gal et al. (2017): p. 4](http://proceedings.mlr.press/v70/gal17a.html).\n",
1131 | "  You will need to re-run each experiment several (say, $11$) times, recording\n",
1132 | " the accuracy dynamics of each, then compare the mean and $25\\%$-$75\\%$\n",
1133 | " quantiles as they evolve with the size of the training sample."
1134 | ]
1135 | },
1136 | {
1137 | "cell_type": "markdown",
1138 | "metadata": {},
1139 | "source": [
1140 | "
"
1141 | ]
1142 | },
1143 | {
1144 | "cell_type": "markdown",
1145 | "metadata": {},
1146 | "source": [
1147 | "### (optional) Points of improvement: batch-vs-single"
1148 | ]
1149 | },
1150 | {
1151 | "cell_type": "markdown",
1152 | "metadata": {},
1153 | "source": [
1154 | "A drawback of the pointwise top-$\\ell$ procedure above is that, although\n",
1155 | "it acquires individually informative instances, together they might end\n",
1156 | "up being **jointly poorly informative**. This can be corrected if we\n",
1157 | "instead seek the highest mutual information among finite sets $\n",
1158 | "S \\subseteq \\mathcal{S}_\\mathrm{unlabelled}\n",
1159 | "$ of size $\\ell$."
1160 | ]
1161 | },
1162 | {
1163 | "cell_type": "markdown",
1164 | "metadata": {},
1165 | "source": [
1166 | "Such an acquisition function is called **batch-BALD**\n",
1167 | "([Kirsch et al.; 2019](https://arxiv.org/abs/1906.08158.pdf)):\n",
1168 | "\n",
1169 | "$$\\begin{align}\n",
1170 | " a(m, S)\n",
1171 | " &= \\mathbb{I}\\bigl((y_x)_{x\\in S}; \\omega \\mid S, m \\bigr)\n",
1172 | " = \\mathbb{H} \\bigl(\n",
1173 | " \\mathbb{E}_{\\omega \\sim q(\\omega\\mid m)} p\\bigl((y_x)_{x\\in S}\\mid S, \\omega, m \\bigr)\n",
1174 | " \\bigr)\n",
1175 | "    - \\mathbb{E}_{\\omega \\sim q(\\omega\\mid m)} \\mathbb{H}\\bigl(\n",
1176 | " p\\bigl((y_x)_{x\\in S}\\mid S, \\omega, m \\bigr)\n",
1177 | " \\bigr)\n",
1178 | "\\end{align}\n",
1179 | " \\,. \\tag{batch-bald} $$"
1180 | ]
1181 | },
1182 | {
1183 | "cell_type": "markdown",
1184 | "metadata": {},
1185 | "source": [
1186 | "This criterion requires a combinatorially growing amount of computation and\n",
1187 | "memory; however, there are working solutions, like random sampling of subsets\n",
1188 | "$S$ from $\\mathcal{S}_\\mathrm{unlabelled}$ or greedy maximization\n",
1189 | "of this **submodular** criterion, as sketched below."
1190 | ]
1191 | },
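{
"cell_type": "markdown",
"metadata": {},
"source": [
"For intuition only, here is a naive sketch of the greedy variant for tiny batch\n",
"sizes (it is **not** the tutorial's code and not the estimator of Kirsch et al.):\n",
"it enumerates all $K^{\\lvert S \\rvert}$ joint label configurations exactly, so\n",
"it is exponential in the batch size. `proba` has the same layout as the output\n",
"of `sample_proba`: draws $\\times$ points $\\times$ classes, so it could be run\n",
"on the sampled predictions for a small random subset of $\\mathcal{S}_\\mathrm{pool}$."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import itertools\n",
"import torch\n",
"\n",
"def joint_mutual_information(proba):\n",
"    \"\"\"Exact (but exponential) batch-MI for a small candidate batch.\n",
"\n",
"    proba : tensor (n_draws, n_points, n_classes) of p(y_i = k | x_i, w_b).\n",
"    \"\"\"\n",
"    n_draws, n_points, n_classes = proba.shape\n",
"\n",
"    # E_w H((y_i)_i | w) = E_w sum_i H(y_i | w), by conditional independence\n",
"    conditional = - (proba * torch.log(proba + 1e-20)).sum(-1).sum(-1).mean()\n",
"\n",
"    # entropy of the joint predictive p((y_i)_i) = E_w prod_i p(y_i | w)\n",
"    joint, rows = torch.tensor(0.), list(range(n_points))\n",
"    for config in itertools.product(range(n_classes), repeat=n_points):\n",
"        p = proba[:, rows, list(config)].prod(dim=-1).mean()\n",
"        joint = joint - p * torch.log(p + 1e-20)\n",
"\n",
"    return joint - conditional\n",
"\n",
"def greedy_batch_bald(proba, n_request=3):\n",
"    \"\"\"Greedily grow the batch, re-scoring the joint MI at every step.\"\"\"\n",
"    selected = []\n",
"    for _ in range(n_request):\n",
"        candidates = [i for i in range(proba.shape[1]) if i not in selected]\n",
"        scores = [joint_mutual_information(proba[:, selected + [i]])\n",
"                  for i in candidates]\n",
"        selected.append(candidates[int(torch.stack(scores).argmax())])\n",
"\n",
"    return selected"
]
},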
1192 | {
1193 | "cell_type": "markdown",
1194 | "metadata": {},
1195 | "source": [
1196 | "
"
1197 | ]
1198 | },
1199 | {
1200 | "cell_type": "markdown",
1201 | "metadata": {},
1202 | "source": [
1203 | "### (optional) Points of improvement: bias"
1204 | ]
1205 | },
1206 | {
1207 | "cell_type": "markdown",
1208 | "metadata": {},
1209 | "source": [
1210 | "The first term in the **MC** estimate of the mutual information is the\n",
1211 | "so-called **plug-in** estimator of the entropy:\n",
1212 | "\n",
1213 | "$$\n",
1214 | " \\hat{H}\n",
1215 | " = \\mathbb{H}(\\hat{p}) = - \\sum_k \\hat{p}_k \\log \\hat{p}_k\n",
1216 | " \\,, $$\n",
1217 | "\n",
1218 | "where $\\hat{p}_k = \\tfrac1B \\sum_b p_{bk}$ is the full sample estimator\n",
1219 | "of the probabilities."
1220 | ]
1221 | },
1222 | {
1223 | "cell_type": "markdown",
1224 | "metadata": {},
1225 | "source": [
1226 | "It is known that this plug-in estimate is biased\n",
1227 | "(see [blog: Nowozin, 2015](http://www.nowozin.net/sebastian/blog/estimating-discrete-entropy-part-1.html)\n",
1228 | "and references therein, also this [notebook](https://colab.research.google.com/drive/1z9ZDNM6NFmuFnU28d8UO0Qymbd2LiNJW)). \n",
1229 | "In order to correct for small-sample bias we can use\n",
1230 | "[jackknife resampling](https://en.wikipedia.org/wiki/Jackknife_resampling).\n",
1231 | "It derives an estimate of the finite sample bias from the leave-one-out\n",
1232 | "estimators of the entropy and is relatively computationally cheap\n",
1233 | "(see [blog: Nowozin, 2015](http://www.nowozin.net/sebastian/blog/estimating-discrete-entropy-part-2.html),\n",
1234 | "[Miller, R. G. (1974)](http://www.math.ntu.edu.tw/~hchen/teaching/LargeSample/references/Miller74jackknife.pdf) and these [notes](http://people.bu.edu/aimcinto/jackknife.pdf)).\n",
1235 | "\n",
1236 | "The jackknife correction of a plug-in estimator $\\mathbb{H}(\\cdot)$\n",
1237 | "is computed thus: given a sample $(p_b)_{b=1}^B$ with $p_b$ -- discrete distribution on $1..K$\n",
1238 | "* for each $b=1.. B$\n",
1239 | " * get the leave-one-out estimator: $\\hat{p}_k^{-b} = \\tfrac1{B-1} \\sum_{j\\neq b} p_{jk}$\n",
1240 | " * compute the plug-in entropy estimator: $\\hat{H}_{-b} = \\mathbb{H}(\\hat{p}^{-b})$\n",
1241 | "* then compute the bias-corrected entropy estimator $\n",
1242 | "\\hat{H}_J\n",
1243 | " = \\hat{H} + (B - 1) \\bigl\\{\n",
1244 | "        \\hat{H} - \\tfrac1B \\sum_b \\hat{H}_{-b}\n",
1245 | " \\bigr\\}\n",
1246 | "$"
1247 | ]
1248 | },
1249 | {
1250 | "cell_type": "markdown",
1251 | "metadata": {},
1252 | "source": [
1253 | "**(note)** when we knock the $i$-th data point out of the sample mean\n",
1254 | "$\\mu = \\tfrac1n \\sum_i x_i$ and recompute the mean $\\mu_{-i}$ we get\n",
1255 | "the following relation\n",
1256 | "$$ \\mu_{-i}\n",
1257 | " = \\frac1{n-1} \\sum_{j\\neq i} x_j\n",
1258 | " = \\frac{n}{n-1} \\mu - \\tfrac1{n-1} x_i\n",
1259 | " = \\mu + \\frac{\\mu - x_i}{n-1}\n",
1260 | " \\,. $$\n",
1261 | "This makes it possible to quickly compute the leave-one-out estimators of\n",
1262 | "a discrete probability distribution, as the check below illustrates."
1263 | ]
1264 | },
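{
"cell_type": "markdown",
"metadata": {},
"source": [
"A short numerical check of this identity (illustrative only):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"torch.manual_seed(0)\n",
"x = torch.randn(7)\n",
"mu = x.mean()\n",
"\n",
"# direct leave-one-out means vs. the shortcut mu + (mu - x_i) / (n - 1)\n",
"direct = torch.stack([torch.cat([x[:i], x[i + 1:]]).mean() for i in range(len(x))])\n",
"shortcut = mu + (mu - x) / (len(x) - 1)\n",
"print(torch.allclose(direct, shortcut))  # True"
]
},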
1265 | {
1266 | "cell_type": "markdown",
1267 | "metadata": {},
1268 | "source": [
1269 | "#### (task*) Bias-corrected estimators of entropy and mutual information\n",
1270 | "\n",
1271 | "Try to efficiently implement a bias-corrected acquisition\n",
1272 | "function, and see whether it is worth the effort."
1273 | ]
1274 | },
1275 | {
1276 | "cell_type": "code",
1277 | "execution_count": null,
1278 | "metadata": {},
1279 | "outputs": [],
1280 | "source": [
1281 | "def BALD_jknf_acquisition(dataset, model, n_request=1, n_draws=1):\n",
1282 | " proba = sample_proba(model, dataset, n_draws=n_draws)\n",
1283 | "\n",
1284 | " ## Exercise: MC estimate of the predictive distribution, entropy and MI\n",
1285 | " ## mutual information `H E_w p(., w) - E_w H p(., w)` with jackknife\n",
1286 | " ## correction.\n",
1287 | "\n",
1288 | " # plug-in estimate of entropy \n",
1289 | " proba_avg = proba.mean(dim=0)\n",
1290 | " entropy_expected = categorical_entropy(proba_avg)\n",
1291 | "\n",
1292 | " # jackknife correction\n",
1293 | " proba_loo = proba_avg + (proba_avg - proba) / (len(proba) - 1)\n",
1294 | " expected_entropy_loo = categorical_entropy(proba_loo).mean(dim=0)\n",
1295 | " entropy_expected += (len(proba) - 1) * (entropy_expected - expected_entropy_loo)\n",
1296 | "\n",
1297 | " mi = entropy_expected - categorical_entropy(proba).mean(dim=0)\n",
1298 | "\n",
1299 | " return mi.argsort(descending=True)[:n_request]"
1300 | ]
1301 | },
1302 | {
1303 | "cell_type": "markdown",
1304 | "metadata": {},
1305 | "source": [
1306 | "
"
1307 | ]
1308 | },
1309 | {
1310 | "cell_type": "markdown",
1311 | "metadata": {},
1312 | "source": [
1313 | "Let's see ..."
1314 | ]
1315 | },
1316 | {
1317 | "cell_type": "code",
1318 | "execution_count": null,
1319 | "metadata": {
1320 | "scrolled": false
1321 | },
1322 | "outputs": [],
1323 | "source": [
1324 | "jknf_results = active_learn(\n",
1325 | " S_train,\n",
1326 | " S_pool,\n",
1327 | " S_valid,\n",
1328 | " BALD_jknf_acquisition,\n",
1329 | " n_draws=21,\n",
1330 | " n_budget=150,\n",
1331 | " n_max_request=3,\n",
1332 | " n_epochs=200,\n",
1333 | ")"
1334 | ]
1335 | },
1336 | {
1337 | "cell_type": "code",
1338 | "execution_count": null,
1339 | "metadata": {},
1340 | "outputs": [],
1341 | "source": [
1342 | "fig, ax = plt.subplots(1, 1, figsize=(12, 7))\n",
1343 | "\n",
1344 | "model_jknf, train_jknf, scores_jknf, balances_jknf = jknf_results\n",
1345 | "ax.plot(accuracy(scores_rand), label='Accuracy (random)', lw=2)\n",
1346 | "ax.plot(accuracy(scores_bald), label='Accuracy (BALD)', lw=2)\n",
1347 | "ax.plot(accuracy(scores_jknf), label='Accuracy (BALD-jknf)', lw=2)\n",
1348 | "\n",
1349 | "ax.legend()\n",
1350 | "plt.show()"
1351 | ]
1352 | },
1353 | {
1354 | "cell_type": "code",
1355 | "execution_count": null,
1356 | "metadata": {},
1357 | "outputs": [],
1358 | "source": [
1359 | "fig, ax = plt.subplots(1, 1, figsize=(12, 7))\n",
1360 | "\n",
1361 | "lines = ax.plot(balances_jknf, lw=2)\n",
1362 | "plt.legend(lines, list(range(10)), ncol=2);"
1363 | ]
1364 | },
1365 | {
1366 | "cell_type": "markdown",
1367 | "metadata": {},
1368 | "source": [
1369 | "
"
1370 | ]
1371 | }
1372 | ],
1373 | "metadata": {
1374 | "kernelspec": {
1375 | "display_name": "Python 3",
1376 | "language": "python",
1377 | "name": "python3"
1378 | },
1379 | "language_info": {
1380 | "codemirror_mode": {
1381 | "name": "ipython",
1382 | "version": 3
1383 | },
1384 | "file_extension": ".py",
1385 | "mimetype": "text/x-python",
1386 | "name": "python",
1387 | "nbconvert_exporter": "python",
1388 | "pygments_lexer": "ipython3",
1389 | "version": "3.7.3"
1390 | }
1391 | },
1392 | "nbformat": 4,
1393 | "nbformat_minor": 2
1394 | }
1395 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MLSS2019: Bayesian Deep Learning
2 |
3 | ## Installation: colab
4 |
5 | In Google Colab there is no need to clone the repo or preinstall anything --
6 | all jupyter runtimes come with the basic packages like numpy, scipy, and
7 | matplotlib, as well as the deep learning libraries keras, tensorflow, and pytorch.
8 |
9 | The only step required is to change the runtime type to GPU in
10 | **Edit > Notebook settings** or **Runtime > Change runtime type** by selecting
11 | **GPU** as the **Hardware accelerator**.
12 |
13 |
14 | ## Installation: local install
15 |
16 | Please make sure that you have the following packages installed:
17 | * tqdm
18 | * numpy
19 | * torch >= 1.1
20 |
21 | The most convenient way to ensure this is to use Anaconda with Python 3.7.
22 |
23 | When all prerequisites have been met, please, clone this repository and
24 | install it with:
25 |
26 | ```bash
27 | git clone https://github.com/ivannz/mlss2019-bayesian-deep-learning.git
28 |
29 | cd mlss2019-bayesian-deep-learning
30 |
31 | pip install --editable .
32 | ```
33 |
34 | This will install the necessary service python code that will make the
35 | seminar much more concise and, hopefully, your learning experience better.
36 |
37 |
38 | ## Versions
39 |
40 | The version presented at MLSS Moscow Aug 26 - Sep 5, 2019, can also be found
41 | in the [MLSS2019](https://github.com/mlss-skoltech/) repo. Here it sits under
42 | the tag `mlss2019-Aug-30`.
43 |
--------------------------------------------------------------------------------
/mlss2019bdl/__init__.py:
--------------------------------------------------------------------------------
1 | import tqdm
2 | import torch
3 |
4 | import torch.nn.functional as F
5 |
6 | from torch.utils.data import TensorDataset, DataLoader
7 |
8 |
9 | def dataset_from_numpy(*ndarrays, device=None, dtype=torch.float32):
10 | """Create :class:`TensorDataset` from the passed :class:`numpy.ndarray`-s.
11 |
12 | Each returned tensor in the TensorDataset and :attr:`ndarray` share
13 | the same memory, unless a type cast or device transfer took place.
14 | Modifications to any tensor in the dataset will be reflected in respective
15 | :attr:`ndarray` and vice versa.
16 |
17 | Each returned tensor in the dataset is not resizable.
18 |
19 | See Also
20 | --------
21 | torch.from_numpy : create a tensor from an ndarray.
22 | """
23 | tensors = map(torch.from_numpy, ndarrays)
24 |
25 | return TensorDataset(*[t.to(device, dtype) for t in tensors])
26 |
27 |
28 | default_criteria = {
29 | "cross_entropy":
30 | lambda model, X, y: F.cross_entropy(model(X), y, reduction="mean"),
31 | "mse":
32 | lambda model, X, y: 0.5 * F.mse_loss(model(X), y, reduction="mean"),
33 | }
34 |
35 |
36 | def fit(model, dataset, criterion="mse", batch_size=32,
37 | n_epochs=1, weight_decay=0, verbose=False):
38 | """Fit the model with SGD (Adam) on the specified dataset and criterion.
39 |
40 | This bare minimum of a fit loop creates a minibatch generator from
41 | the `dataset` with batches of size `batch_size`. On each batch it
42 | computes the backward pass through the `criterion` and the `model`
43 | and updates the `model`-s parameters with the Adam optimizer step.
44 | The loop passes through the dataset `n_epochs` times. It does not
45 | output any running debugging information, except for a progress bar.
46 |
47 |     The criterion can be either "mse" for mean squared error, "cross_entropy"
48 |     for the categorical cross-entropy (negative log-likelihood), or a callable
49 |     taking `model, X, y` as arguments.
50 | """
51 | if len(dataset) <= 0 or batch_size <= 0:
52 | return model
53 |
54 | criterion = default_criteria.get(criterion, criterion)
55 | assert callable(criterion)
56 |
57 | # get the model's device
58 | device = next(model.parameters()).device
59 |
60 | # an optimizer for model's parameters
61 | optim = torch.optim.Adam(model.parameters(), lr=2e-3,
62 | weight_decay=weight_decay)
63 |
64 | # stochastic minibatch generator for the training loop
65 | feed = DataLoader(dataset, shuffle=True, batch_size=batch_size)
66 | for epoch in tqdm.tqdm(range(n_epochs), disable=not verbose):
67 |
68 | model.train()
69 |
70 | for X, y in feed:
71 | # forward pass through the criterion (batch-average loss)
72 | loss = criterion(model, X.to(device), y.to(device))
73 |
74 | # get gradients with backward pass
75 | optim.zero_grad()
76 | loss.backward()
77 |
78 | # SGD update
79 | optim.step()
80 |
81 | return model
82 |
83 |
84 | def predict(model, dataset, batch_size=512):
85 | """Get model's output on the dataset.
86 |
87 | This straightforward function switches the model into `evaluation`
88 | regime, computes the forward pass on the `dataset` (in batches of
89 | size `batch_size`) and stacks the results into a tensor on the `cpu`.
90 | It temporarily disables `autograd` to gain some speed-up.
91 | """
92 | model.eval()
93 |
94 | # get the model's device
95 | device = next(model.parameters()).device
96 |
97 | # batch generator for the evaluation loop
98 | feed = DataLoader(dataset, batch_size=batch_size, shuffle=False)
99 |
100 | # compute and collect the outputs
101 | with torch.no_grad():
102 | return torch.cat([
103 | model(X.to(device)).cpu() for X, *rest in feed
104 | ], dim=0)
105 |
--------------------------------------------------------------------------------
/mlss2019bdl/bdl/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import freeze, unfreeze, penalties
2 |
3 |
4 | from .bernoulli import DropoutLinear, DropoutConv2d
5 | from .gaussian import GaussianLinearARD, GaussianConv2dARD
--------------------------------------------------------------------------------
/mlss2019bdl/bdl/base.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch.nn import Module
4 |
5 |
6 | class FreezableWeight(Module):
7 | def __init__(self):
8 | super().__init__()
9 | self.unfreeze()
10 |
11 | def unfreeze(self):
12 | self.register_buffer("frozen_weight", None)
13 |
14 | def is_frozen(self):
15 | """Check if a frozen weight is available."""
16 | return isinstance(self.frozen_weight, torch.Tensor)
17 |
18 | def freeze(self):
19 | """Sample from the distribution and freeze."""
20 | raise NotImplementedError()
21 |
22 |
23 | def freeze(module):
24 | for mod in module.modules():
25 | if isinstance(mod, FreezableWeight):
26 | mod.freeze()
27 |
28 | return module # return self
29 |
30 |
31 | def unfreeze(module):
32 | for mod in module.modules():
33 | if isinstance(mod, FreezableWeight):
34 | mod.unfreeze()
35 |
36 | return module # return self
37 |
38 |
39 | class PenalizedWeight(Module):
40 | def penalty(self):
41 | raise NotImplementedError()
42 |
43 |
44 | def penalties(module):
45 | for mod in module.modules():
46 | if isinstance(mod, PenalizedWeight):
47 | yield mod.penalty()
48 |
--------------------------------------------------------------------------------
/mlss2019bdl/bdl/bernoulli.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | from torch.nn import Linear, Conv2d
5 |
6 | from .base import FreezableWeight, PenalizedWeight
7 |
8 |
9 | class DropoutLinear(Linear, FreezableWeight):
10 | """Linear layer with dropout on inputs."""
11 | def __init__(self, in_features, out_features, bias=True, p=0.5):
12 | super().__init__(in_features, out_features, bias=bias)
13 |
14 | self.p = p
15 |
16 | def forward(self, input):
17 | if self.is_frozen():
18 | return F.linear(input, self.frozen_weight, self.bias)
19 |
20 | return super().forward(F.dropout(input, self.p, True))
21 |
22 | def freeze(self):
23 | # let's draw the new weight
24 | with torch.no_grad():
25 | prob = torch.full_like(self.weight[:1, :], 1 - self.p)
26 | feature_mask = torch.bernoulli(prob) / prob
27 |
28 | frozen_weight = self.weight * feature_mask
29 |
30 | # and store it
31 | self.register_buffer("frozen_weight", frozen_weight)
32 |
33 |
34 | class DropoutConv2d(Conv2d, FreezableWeight):
35 | """2d Convolutional layer with dropout on input features."""
36 | def __init__(self, in_channels, out_channels, kernel_size, stride=1,
37 | padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros',
38 | p=0.5):
39 |
40 | super().__init__(in_channels, out_channels, kernel_size, stride=stride,
41 | padding=padding, dilation=dilation, groups=groups,
42 | bias=bias, padding_mode=padding_mode)
43 |
44 | self.p = p
45 |
46 | def forward(self, input):
47 | """Apply feature dropout and then forward pass through the convolution."""
48 | if self.is_frozen():
49 | return F.conv2d(input, self.frozen_weight, self.bias, self.stride,
50 | self.padding, self.dilation, self.groups)
51 |
52 | return super().forward(F.dropout2d(input, self.p, True))
53 |
54 | def freeze(self):
55 | """Sample the weight from the parameter distribution and freeze it."""
56 | prob = torch.full_like(self.weight[:1, :, :1, :1], 1 - self.p)
57 | feature_mask = torch.bernoulli(prob) / prob
58 |
59 | with torch.no_grad():
60 | frozen_weight = self.weight * feature_mask
61 |
62 | self.register_buffer("frozen_weight", frozen_weight)
63 |
64 |
--------------------------------------------------------------------------------
/mlss2019bdl/bdl/gaussian.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | from torch.nn import Linear, Conv2d
5 |
6 | from .base import FreezableWeight, PenalizedWeight
7 |
8 |
9 | class BaseGaussianLinear(Linear, FreezableWeight, PenalizedWeight):
10 | """Linear layer with Gaussian Mean Field weight distribution."""
11 | def __init__(self, in_features, out_features, bias=True):
12 | super().__init__(in_features, out_features, bias=bias)
13 |
14 | self.log_sigma2 = torch.nn.Parameter(
15 | torch.Tensor(*self.weight.shape))
16 |
17 | self.reset_variational_parameters()
18 |
19 | def reset_variational_parameters(self):
20 | self.log_sigma2.data.normal_(-5, 0.1) # from arxiv:1811.00596
21 |
22 | def forward(self, input):
23 | """Forward pass for the linear layer with the local reparameterization trick."""
24 |
25 | if self.is_frozen():
26 | return F.linear(input, self.frozen_weight, self.bias)
27 |
28 | s2 = F.linear(input * input, torch.exp(self.log_sigma2), None)
29 |
30 | mu = super().forward(input)
31 | return mu + torch.randn_like(s2) * torch.sqrt(torch.clamp(s2, 1e-8))
32 |
33 | def freeze(self):
34 |
35 | with torch.no_grad():
36 | stdev = torch.exp(0.5 * self.log_sigma2)
37 | weight = torch.normal(self.weight, std=stdev)
38 |
39 | self.register_buffer("frozen_weight", weight)
40 |
41 |
42 | class BaseGaussianConv2d(Conv2d, PenalizedWeight, FreezableWeight):
43 | """Convolutional layer with Gaussian Mean Field weight distribution."""
44 |
45 | def __init__(self, in_channels, out_channels, kernel_size, stride=1,
46 | padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'):
47 | super().__init__(in_channels, out_channels, kernel_size, stride=stride,
48 | padding=padding, dilation=dilation, groups=groups,
49 | bias=bias, padding_mode=padding_mode)
50 |
51 | self.log_sigma2 = torch.nn.Parameter(
52 | torch.Tensor(*self.weight.shape))
53 |
54 | self.reset_variational_parameters()
55 |
56 | reset_variational_parameters = BaseGaussianLinear.reset_variational_parameters
57 |
58 | def forward(self, input):
59 | """Forward pass with the local reparameterization trick."""
60 | if self.is_frozen():
61 | return F.conv2d(input, self.frozen_weight, self.bias, self.stride,
62 | self.padding, self.dilation, self.groups)
63 |
64 | s2 = F.conv2d(input * input, torch.exp(self.log_sigma2), None,
65 | self.stride, self.padding, self.dilation, self.groups)
66 |
67 | mu = super().forward(input)
68 | return mu + torch.randn_like(s2) * torch.sqrt(torch.clamp(s2, 1e-8))
69 |
70 | freeze = BaseGaussianLinear.freeze
71 |
72 |
73 | class GaussianLinearARD(BaseGaussianLinear):
74 | def penalty(self):
75 |         # compute \tfrac12 \log (1 + \tfrac{\mu_{ji}^2}{\sigma_{ji}^2})
76 | log_weight2 = 2 * torch.log(torch.abs(self.weight) + 1e-20)
77 |
78 | # `softplus` is $x \mapsto \log(1 + e^x)$
79 | return 0.5 * torch.sum(F.softplus(log_weight2 - self.log_sigma2))
80 |
81 |
82 | class GaussianConv2dARD(BaseGaussianConv2d):
83 | penalty = GaussianLinearARD.penalty
84 |
--------------------------------------------------------------------------------
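
The ARD penalty above is each layer's contribution to the KL term of the variational objective, so training pairs the usual data loss with `sum(penalties(model))`. A hedged sketch of one training step under that reading; the architecture, learning rate, and the 1/n_train weighting of the penalty are assumptions for illustration, not anything prescribed by this module:

import torch
import torch.nn.functional as F
from torch.nn import Sequential, ReLU

from mlss2019bdl.bdl.base import penalties
from mlss2019bdl.bdl.gaussian import GaussianLinearARD

model = Sequential(
    GaussianLinearARD(784, 128), ReLU(),
    GaussianLinearARD(128, 10),
)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

def training_step(X, y, n_train):
    # negative log-likelihood of the minibatch (local reparameterization inside)
    nll = F.cross_entropy(model(X), y)

    # KL-style penalty, spread over the assumed training-set size
    kl = sum(penalties(model)) / n_train

    loss = nll + kl
    optim.zero_grad()
    loss.backward()
    optim.step()
    return float(loss)
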
/mlss2019bdl/dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch
4 | from torch.utils.data import TensorDataset
5 | from torchvision import datasets
6 |
7 | from sklearn.utils import check_random_state
8 | from sklearn.model_selection import train_test_split
9 |
10 |
11 | def get_data(name, path="./data", train=True):
12 | if name == "MNIST":
13 | dataset = datasets.MNIST(path, train=train, download=True)
14 | elif name == "KMNIST":
15 | dataset = datasets.KMNIST(path, train=train, download=True)
16 |
17 | images = dataset.data.float().unsqueeze(1)
18 | return TensorDataset(images / 255., dataset.targets)
19 |
20 |
21 | def get_dataset(n_train=20, n_valid=5000, n_pool=5000,
22 | name="MNIST", path="./data", random_state=None):
23 | random_state = check_random_state(random_state)
24 |
25 | dataset = get_data(name, path, train=True)
26 | S_test = get_data(name, path, train=False)
27 |
28 |     # extract the class labels: used for the stratified split and for skewing the pool below
29 | targets = dataset.tensors[-1].cpu().numpy()
30 |
31 |     # split the dataset into train and validation
32 | ix_all = np.r_[:len(targets)]
33 | ix_train, ix_valid = train_test_split(
34 | ix_all, stratify=targets, shuffle=True,
35 | train_size=max(n_train, 1), test_size=max(n_valid, 1),
36 | random_state=random_state)
37 |
38 | # prepare the datasets: pool, train and validation
39 | if n_train < 1:
40 | ix_train = np.r_[:0]
41 | S_train = TensorDataset(*dataset[ix_train])
42 |
43 | if n_valid < 1:
44 | ix_valid = np.r_[:0]
45 | S_valid = TensorDataset(*dataset[ix_valid])
46 |
47 | # prepare the pool
48 | ix_pool = np.delete(ix_all, np.r_[ix_train, ix_valid])
49 |
50 |     # deplete classes 1-9 so that the pool is dominated by uninformative (mostly class-0) examples
51 | labels, share = (1, 2, 3, 4, 5, 6, 7, 8, 9), 0.95
52 | pool_targets, dropped = targets[ix_pool], []
53 |
54 | # deplete the pool of each class
55 | for label in labels:
56 | ix_cls = np.flatnonzero(pool_targets == label)
57 |         n_dropped = int(share * len(ix_cls))
58 |
59 |         # pick examples at random to drop
60 |         ix_cls = random_state.permutation(ix_cls)
61 |         dropped.append(ix_cls[:n_dropped])
62 |
63 | ix_pool = np.delete(ix_pool, np.concatenate(dropped))
64 |
65 | # select at most `n_pool` examples
66 | if n_pool > 0:
67 | ix_pool = random_state.permutation(ix_pool)[:n_pool]
68 | S_pool = TensorDataset(*dataset[ix_pool])
69 |
70 | return S_train, S_pool, S_valid, S_test
71 |
72 |
73 | def collect(indices, dataset):
74 | """Collect the specified samples from the dataset and remove."""
75 | assert len(dataset) > 0
76 |
77 | mask = torch.zeros(len(dataset), dtype=torch.bool)
78 | mask[indices.long()] = True
79 |
80 | collected = TensorDataset(*dataset[mask])
81 |
82 | dataset.tensors = dataset[~mask]
83 |
84 | return collected
85 |
86 |
87 | def merge(*datasets, out=None):
88 | # Classes derived from Dataset support appending via
89 | # `+` (__add__), but this breaks slicing.
90 |
91 | data = [d.tensors for d in datasets if d is not None and d.tensors]
92 | assert all(len(data[0]) == len(d) for d in data)
93 |
94 | tensors = [torch.cat(tup, dim=0) for tup in zip(*data)]
95 |
96 | if isinstance(out, TensorDataset):
97 | out.tensors = tensors
98 | return out
99 |
100 | return TensorDataset(*tensors)
101 |
--------------------------------------------------------------------------------
/mlss2019bdl/flex.py:
--------------------------------------------------------------------------------
1 | """Handy plotting procedures for small 2d images."""
2 | import numpy as np
3 |
4 | from torch import Tensor
5 | from math import sqrt
6 |
7 |
8 | def get_dimensions(n_samples, height, width,
9 | n_row=None, n_col=None, aspect=(16, 9)):
10 | """Get the dimensions that aesthetically conform to the aspect ratio."""
11 | if n_row is None and n_col is None:
12 | ratio = (width * aspect[1]) / (height * aspect[0])
13 |         n_row = max(1, int(sqrt(n_samples * ratio)))  # at least one row, avoiding a zero division below
14 |
15 | if n_row is None:
16 | n_row = (n_samples + n_col - 1) // n_col
17 |
18 | elif n_col is None:
19 | n_col = (n_samples + n_row - 1) // n_row
20 |
21 | return n_row, n_col
22 |
23 |
24 | def setup_canvas(ax, height, width, n_row, n_col):
25 | """Setup the ticks and labels for the canvas."""
26 | # A pair of index arrays
27 | row_index, col_index = np.r_[:n_row], np.r_[:n_col]
28 |
29 | # Setup major ticks to the seams between images and disable labels
30 | ax.set_yticks((row_index[:-1] + 1) * height - 0.5, minor=False)
31 | ax.set_xticks((col_index[:-1] + 1) * width - 0.5, minor=False)
32 |
33 | ax.set_yticklabels([], minor=False)
34 | ax.set_xticklabels([], minor=False)
35 |
36 | # Set minor ticks so that they are exactly between the major ones
37 | ax.set_yticks((row_index + 0.5) * height, minor=True)
38 | ax.set_xticks((col_index + 0.5) * width, minor=True)
39 |
40 | # ... and make their labels into i-j coordinates
41 | ax.set_yticklabels([f"{i:d}" for i in row_index], minor=True)
42 | ax.set_xticklabels([f"{j:d}" for j in col_index], minor=True)
43 |
44 | # Orient tick marks outward
45 | ax.tick_params(axis="both", which="both", direction="out")
46 | return ax
47 |
48 |
49 | def arrange(n_row, n_col, data, fill_value=0):
50 | """Create a grid and populate it with images."""
51 | n_samples, height, width, *color = data.shape
52 | grid = np.full((n_row * height, n_col * width, *color),
53 | fill_value, dtype=data.dtype)
54 |
55 | for k in range(min(n_samples, n_col * n_row)):
56 | i, j = (k // n_col) * height, (k % n_col) * width
57 | grid[i:i + height, j:j + width] = data[k]
58 |
59 | return grid
60 |
61 |
62 | def to_hwc(images, format):
63 | assert format in ("chw", "hwc"), f"Unrecognized format `{format}`."
64 |
65 | if images.ndim == 3:
66 | return images[..., np.newaxis]
67 |
68 | assert images.ndim == 4, f"Images must be Nx{'x'.join(format.upper())}."
69 |
70 | if format == "chw":
71 | return images.transpose(0, 2, 3, 1)
72 |
73 | elif format == "hwc":
74 | return images
75 |
76 |
77 | def plot(ax, images, *, n_col=None, n_row=None, format="chw", **kwargs):
78 | """Plot images in the numpy array on the specified matplotlib Axis."""
79 | if isinstance(images, Tensor):
80 | images = images.data.cpu().numpy()
81 |
82 | images = to_hwc(images, format)
83 |
84 | n_samples, height, width, *color = images.shape
85 | if n_samples < 1:
86 | return None
87 |
88 | n_row, n_col = get_dimensions(n_samples, height, width, n_row, n_col)
89 | ax = setup_canvas(ax, height, width, n_row, n_col)
90 |
91 | image = arrange(n_row, n_col, images)
92 | return ax.imshow(image.squeeze(), **kwargs, origin="upper")
93 |
--------------------------------------------------------------------------------
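
`plot` takes a batch of images (a torch Tensor or a numpy array), arranges them on a single grid whose shape roughly follows a 16:9 aspect ratio, and draws the grid on a given matplotlib axis with per-image minor-tick coordinates. A short usage sketch with random images standing in for real data:

import numpy as np
import matplotlib.pyplot as plt

from mlss2019bdl.flex import plot

# a fake batch of 24 single-channel 28x28 images in NCHW ("chw") layout
images = np.random.rand(24, 1, 28, 28)

fig, ax = plt.subplots(figsize=(8, 5))
plot(ax, images, format="chw", cmap="gray")
plt.show()
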
/mlss2019bdl/plotting.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 |
3 | from torch import Tensor
4 | from numpy import asarray
5 |
6 |
7 | def darker(color, a=0.5):
8 | """Adapted from this stackoverflow question_.
9 |
10 | .. _question: https://stackoverflow.com/questions/37765197/
11 | """
12 | from matplotlib.colors import to_rgb
13 | from colorsys import rgb_to_hls, hls_to_rgb
14 |
15 | h, l, s = rgb_to_hls(*to_rgb(color))
16 | return hls_to_rgb(h, max(0, min(a * l, 1)), s)
17 |
18 |
19 | def canvas1d(*, figsize=(12, 5)):
20 | """Setup canvas for 1d function plot."""
21 | fig, ax = plt.subplots(1, 1, figsize=figsize)
22 |
23 | fig.patch.set_alpha(1.0)
24 | ax.set_xlim(-7, +7)
25 | ax.set_ylim(-7, +9)
26 |
27 | return fig, ax
28 |
29 |
30 | def to_numpy(tensor):
31 | if isinstance(tensor, Tensor):
32 | tensor = tensor.data.cpu().numpy()
33 |
34 | return asarray(tensor).squeeze()
35 |
36 |
37 | def plot1d(X, y, bands, ax=None, **kwargs):
38 | X, y = to_numpy(X), to_numpy(y)
39 | assert y.ndim == 2 and X.ndim == 1
40 |
41 | ax = plt.gca() if ax is None else ax
42 |
43 | # plot the predictive mean with the specified colour
44 | y_mean, y_std = y.mean(axis=-1), y.std(axis=-1)
45 | line, = ax.plot(X, y_mean, **kwargs)
46 |
47 | # plot paths or bands with a lighter color and slightly behind the mean
48 | color, zorder = darker(line.get_color(), 1.25), line.get_zorder()
49 | if bands is None:
50 | ax.plot(X, y, c=color, alpha=0.08, zorder=zorder - 1)
51 |
52 | else:
53 | for band in sorted(bands):
54 | ax.fill_between(X, y_mean + band * y_std, y_mean - band * y_std,
55 | color=color, zorder=zorder-1,
56 | alpha=0.4 / len(bands))
57 |
58 | return line
59 |
60 |
61 | def plot1d_bands(X, y, ax=None, **kwargs):
62 | # return plot1d(X, y, bands=(0.5, 1.0, 1.5, 2.0), ax=ax, **kwargs)
63 | return plot1d(X, y, bands=(1.96,), ax=ax, **kwargs)
64 |
65 |
66 | def plot1d_paths(X, y, ax=None, **kwargs):
67 | return plot1d(X, y, bands=None, ax=ax, **kwargs)
68 |
--------------------------------------------------------------------------------
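
`plot1d` and its two wrappers expect a 1d grid of inputs X of shape (n_points,) and a matrix y of shape (n_points, n_samples), one column per posterior sample: the sample mean is drawn as a line, and the spread either as faint sample paths or as a +/-1.96 standard-deviation band. A quick sketch with synthetic samples in place of model output:

import numpy as np
import matplotlib.pyplot as plt

from mlss2019bdl.plotting import canvas1d, plot1d_bands, plot1d_paths

X = np.linspace(-7, 7, 201)

# fake "posterior samples": noisy sinusoids, one column per sample
y = np.sin(X)[:, None] + 0.25 * np.random.randn(len(X), 50)

fig, ax = canvas1d()
plot1d_bands(X, y, ax=ax, color="C0", label="mean and 1.96 std band")
plot1d_paths(X, y, ax=ax, color="C1", label="sample paths")
ax.legend()
plt.show()
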
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup  # distutils.core.setup would silently ignore `install_requires`
2 |
3 | setup(
4 | name="mlss2019bdl",
5 | version="0.2",
6 | description="""MLSS2019 Tutorial on Bayesian Active Learning""",
7 | license="MIT License",
8 | author="Ivan Nazarov, Yarin Gal",
9 | author_email="ivan.nazarov@skolkovotech.ru",
10 | packages=[
11 | "mlss2019bdl",
12 | "mlss2019bdl.bdl",
13 | ],
14 | install_requires=[
15 | "numpy",
16 | "tqdm",
17 | "matplotlib",
18 | "torch",
19 | "torchvision",
20 | ]
21 | )
22 |
--------------------------------------------------------------------------------
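
With setuptools providing `install_requires`, the package and its dependencies can be installed for the tutorial in the usual way, e.g. with an editable install (`pip install -e .`) from the repository root, after which the notebooks can import `mlss2019bdl`, `mlss2019bdl.bdl`, and the helpers above directly.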