├── .gitattributes
├── .gitignore
├── MLSS FULL SCHEDULE - 8.22.19.pdf
├── README.md
├── bayesian_deep_learning
│   ├── .gitignore
│   ├── Bayesian Deep Learning part1-soln.ipynb
│   ├── Bayesian Deep Learning part1.ipynb
│   ├── Bayesian Deep Learning part2-soln.ipynb
│   ├── Bayesian Deep Learning part2.ipynb
│   ├── README.md
│   ├── mlss2019bdl
│   │   ├── __init__.py
│   │   ├── bdl
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── bernoulli.py
│   │   │   └── gaussian.py
│   │   ├── dataset.py
│   │   ├── flex.py
│   │   └── plotting.py
│   └── setup.py
├── causality
│   ├── exercises-answers.pdf
│   └── exercises-tutorial.pdf
├── geometric_techniques_in_ML
│   ├── MANIFEST.in
│   ├── riemannian_opt_for_ml_solution.ipynb
│   ├── riemannian_opt_for_ml_task.ipynb
│   ├── riemannian_opt_gmm_embeddings.ipynb
│   ├── riemannian_opt_text_preprocessing.ipynb
│   ├── riemannianoptimization
│   │   ├── __init__.py
│   │   ├── data
│   │   │   ├── kbvt_lfpw_v1_train.csv
│   │   │   └── tsne_result_training_part.csv
│   │   └── tutorial_helpers.py
│   └── setup.py
├── img
│   ├── img0.png
│   ├── img1.png
│   ├── img2.png
│   ├── img3.png
│   ├── img4.png
│   ├── img5.png
│   ├── img6.png
│   ├── img7.png
│   └── img8.png
├── kernels
│   ├── README.md
│   ├── dril-heuristic.png
│   ├── probability_testing
│   │   ├── __init__.py
│   │   ├── data
│   │   │   ├── almost_simple.npz
│   │   │   ├── blobs.npz
│   │   │   ├── blobs2.npz
│   │   │   ├── blobs_single.npz
│   │   │   ├── gan-samples.npz
│   │   │   ├── hsic.npz
│   │   │   ├── simple.npz
│   │   │   ├── stopwords-english.txt
│   │   │   ├── stopwords-french.txt
│   │   │   └── transcripts.tar.bz2
│   │   └── support
│   │       ├── __init__.py
│   │       ├── __pycache__
│   │       │   ├── __init__.cpython-36.pyc
│   │       │   ├── kernels.cpython-36.pyc
│   │       │   ├── mmd.cpython-36.pyc
│   │       │   └── utils.cpython-36.pyc
│   │       ├── kernels.py
│   │       ├── mmd.py
│   │       └── utils.py
│   ├── setup.py
│   ├── solutions_testing.ipynb
│   └── testing.ipynb
└── optimal_transport_tutorial
    ├── MANIFEST.in
    ├── Opt_transport_1_Introduction_to_POT_and_S._solutions.ipynb
    ├── Opt_transport_1_Introduction_to_POT_and_S.ipynb
    ├── Opt_transport_2_Optimal_Transport_for_Mac.ipynb
    ├── Opt_transport_2_Optimal_Transport_for_Mac_solutions.ipynb
    ├── optimaltransport
    │   ├── __init__.py
    │   └── data
    │       ├── croissants.pickle
    │       ├── schiele.jpg
    │       ├── schiele2.jpg
    │       └── texts.pickle
    └── setup.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.DS_Store
--------------------------------------------------------------------------------
/MLSS FULL SCHEDULE - 8.22.19.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/MLSS FULL SCHEDULE - 8.22.19.pdf
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MLSS 2019 Skoltech tutorials
2 | This is the official repository for the Machine Learning Summer School 2019, which is taking place at the Skolkovo Institute of Science and Technology (Skoltech), Moscow, from 26.08 to 06.09.
3 |
4 | This repository will contain all of the materials needed for MLSS tutorials.
5 |
6 | ## The list of currently published tutorials (will be updated over time):
7 | * DAY-1 (26.08): François-Pierre Paty, Marco Cuturi - Optimal Transport: https://github.com/mlss-skoltech/tutorials/tree/master/optimal_transport_tutorial
8 | * DAY-2 (27.08): Alexey Artemov, Justin Solomon - Geometric Techniques in ML: https://github.com/mlss-skoltech/tutorials/tree/master/geometric_techniques_in_ML
9 | * DAY-4 (29.08): Yermek Kapushev, Arthur Gretton - Kernels: https://github.com/mlss-skoltech/tutorials/tree/master/kernels
10 | * [Updated 28.08]DAY-5 (30.08): Joris Mooij - Causality: https://github.com/mlss-skoltech/tutorials/tree/master/causality
11 | * [Updated 30.08]DAY-5 (30.08): Ivan Nazarov, Yarin Gal - Bayesian Deep Learning: https://github.com/mlss-skoltech/tutorials/tree/master/bayesian_deep_learning
12 |
13 | # Running the tutorials on Google Colaboratory:
14 | Most of the tutorials were created using Jupyter Notebooks. In order to reduce the time spent on installing various software, we have made sure that all of the tutorials are Google Colaboratory friendly.
15 |
16 | Colaboratory is a free Jupyter notebook environment that requires no setup and runs entirely in the cloud. With Colaboratory you can write and execute code, save and share your analyses, and access powerful computing resources, all for free from your browser. All of the notebooks already contain all the setup needed for each particular tutorial, so you will only need to run the first several cells.
17 |
18 | Here are the instructions on how to open the notebooks in Colaboratory (tested on Google Chrome, version 76.0):
19 | * First go to https://colab.research.google.com/github/mlss-skoltech/
20 | * In the pop-up window, sign in to your GitHub account
21 | 
22 | * In the opened window, choose the notebook corresponding to the tutorial
23 | 
24 | * The selected notebook will open; now make sure that you are signed in to your Google account
25 | 
26 | * Try to run the first cell; you will get the following message:
27 | 
28 | Press ```RUN ANYWAY```
29 | * For the message ```Reset all runtimes``` press ```YES```
30 | 
31 |
32 | In order to download all the material for the tutorial, make sure you run the cells containing the following code first (all of these cells are already added to the notebooks with the right paths):
33 | * For downloading the GitHub subdirectory containing the tutorial:
34 |
35 | ```!pip install --upgrade git+https://github.com/mlss-skoltech/tutorials.git#subdirectory=```
36 |
37 | * For declaring the data files' path:
38 | ```
39 | import pkg_resources
40 | DATA_PATH = pkg_resources.resource_filename('name_of_the_installed_tutorial_package', 'data/')
41 | ```
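For example, for the optimal transport tutorial the installed package is `optimaltransport`, whose `data/` folder holds the files used in the notebooks (see the repository tree above), so a hypothetical call would look like:
```
import pkg_resources
DATA_PATH = pkg_resources.resource_filename('optimaltransport', 'data/')
```
The exact package name for each tutorial is already filled in inside the corresponding notebook.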
42 | # Using GPU with Google Colaboratory:
43 | Sometimes, for computationally demanding tasks, you will need to use a GPU instead of the default CPU. To do this, follow these steps:
44 | * Go to ```Edit->Notebook Settings```
45 | 
46 | * In the ```Hardware accelerator``` field choose ```GPU```
47 | 
48 | 
49 |
50 | # Saving and downloading the notebooks
51 | You can save your notebook to your Google Drive or simply download it: go to ```File->Save a copy in Drive``` or ```File->Download .ipynb```.
52 | 
53 |
54 |
55 |
56 | If you would like to see more tutorials regarding Google Colaboratory, have a look at this notebook: https://colab.research.google.com/notebooks/welcome.ipynb
57 |
58 | # Contact
59 | If you have any questions/suggestions regarding this GitHub repository or have found any bugs, please write to me at N.Mazyavkina@skoltech.ru
60 |
61 |
--------------------------------------------------------------------------------
/bayesian_deep_learning/.gitignore:
--------------------------------------------------------------------------------
1 | # General
2 | .DS_Store
3 | .AppleDouble
4 | .LSOverride
5 |
6 | # Icon must end with two \r
7 | Icon
8 |
9 |
10 | # Thumbnails
11 | ._*
12 |
13 | # Files that might appear in the root of a volume
14 | .DocumentRevisions-V100
15 | .fseventsd
16 | .Spotlight-V100
17 | .TemporaryItems
18 | .Trashes
19 | .VolumeIcon.icns
20 | .com.apple.timemachine.donotpresent
21 |
22 | # Directories potentially created on remote AFP share
23 | .AppleDB
24 | .AppleDesktop
25 | Network Trash Folder
26 | Temporary Items
27 | .apdisk
28 |
29 | # Byte-compiled / optimized / DLL files
30 | __pycache__/
31 | *.py[cod]
32 | *$py.class
33 |
34 | # C extensions
35 | *.so
36 |
37 | # Distribution / packaging
38 | .Python
39 | build/
40 | develop-eggs/
41 | dist/
42 | downloads/
43 | eggs/
44 | .eggs/
45 | lib/
46 | lib64/
47 | parts/
48 | sdist/
49 | var/
50 | wheels/
51 | *.egg-info/
52 | .installed.cfg
53 | *.egg
54 | MANIFEST
55 |
56 | # PyInstaller
57 | # Usually these files are written by a python script from a template
58 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
59 | *.manifest
60 | *.spec
61 |
62 | # Installer logs
63 | pip-log.txt
64 | pip-delete-this-directory.txt
65 |
66 | # Unit test / coverage reports
67 | htmlcov/
68 | .tox/
69 | .coverage
70 | .coverage.*
71 | .cache
72 | nosetests.xml
73 | coverage.xml
74 | *.cover
75 | .hypothesis/
76 | .pytest_cache/
77 |
78 | # Translations
79 | *.mo
80 | *.pot
81 |
82 | # Django stuff:
83 | *.log
84 | local_settings.py
85 | db.sqlite3
86 |
87 | # Flask stuff:
88 | instance/
89 | .webassets-cache
90 |
91 | # Scrapy stuff:
92 | .scrapy
93 |
94 | # Sphinx documentation
95 | docs/_build/
96 |
97 | # PyBuilder
98 | target/
99 |
100 | # Jupyter Notebook
101 | .ipynb_checkpoints
102 |
103 | # pyenv
104 | .python-version
105 |
106 | # celery beat schedule file
107 | celerybeat-schedule
108 |
109 | # SageMath parsed files
110 | *.sage.py
111 |
112 | # Environments
113 | .env
114 | .venv
115 | env/
116 | venv/
117 | ENV/
118 | env.bak/
119 | venv.bak/
120 |
121 | # Spyder project settings
122 | .spyderproject
123 | .spyproject
124 |
125 | # Rope project settings
126 | .ropeproject
127 |
128 | # mkdocs documentation
129 | /site
130 |
131 | # mypy
132 | .mypy_cache/
133 |
--------------------------------------------------------------------------------
/bayesian_deep_learning/Bayesian Deep Learning part2.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# MLSS2019: Bayesian Deep Learning"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "In this tutorial we will uncertainty estimation can be\n",
15 | "used in active learning or expert-in-the-loop pipelines."
16 | ]
17 | },
18 | {
19 | "cell_type": "markdown",
20 | "metadata": {},
21 | "source": [
22 | "The plan of the tutorial\n",
23 | "1. [Imports and definitions](#Imports-and-definitions)\n",
24 | "2. [Bayesian Active Learning with images](#Bayesian-Active-Learning-with-images)\n",
25 | " 1. [The model](#The-model)\n",
26 | " 2. [the Acquisition Function](#the-Acquisition-Function)\n",
27 | " 3. [Data and the Oracle](#Data-and-the-Oracle)\n",
28 | " 4. [the Active Learning loop](#the-Active-Learning-loop)\n",
29 | " 5. [The baseline](#The-baseline)\n",
30 | "3. [Bayesian Active Learning by Disagreement](#Bayesian-Active-Learning-by-Disagreement)\n",
31 | " 1. [Points of improvement: batch-vs-single](#Points-of-improvement:-batch-vs-single)\n",
32 | " 2. [Points of improvement: bias](#Points-of-improvement:-bias)\n"
33 | ]
34 | },
35 | {
36 | "cell_type": "markdown",
37 | "metadata": {},
38 | "source": [
39 | "**(note)**\n",
40 | "* to view documentation on something type in `something?` (with one question mark)\n",
41 | "* to view code of something type in `something??` (with two question marks)."
42 | ]
43 | },
44 | {
45 | "cell_type": "markdown",
46 | "metadata": {},
47 | "source": [
48 |     ""
49 | ]
50 | },
51 | {
52 | "cell_type": "markdown",
53 | "metadata": {},
54 | "source": [
55 | "## Imports and definitions"
56 | ]
57 | },
58 | {
59 | "cell_type": "markdown",
60 | "metadata": {},
61 | "source": [
62 | "In this section we import necessary modules and functions and\n",
63 | "define the computational device."
64 | ]
65 | },
66 | {
67 | "cell_type": "markdown",
68 | "metadata": {},
69 | "source": [
70 | "First, we install some boilerplate service code for this tutorial."
71 | ]
72 | },
73 | {
74 | "cell_type": "code",
75 | "execution_count": null,
76 | "metadata": {},
77 | "outputs": [],
78 | "source": [
79 | "!pip install -q --upgrade git+https://github.com/mlss-skoltech/tutorials.git#subdirectory=bayesian_deep_learning"
80 | ]
81 | },
82 | {
83 | "cell_type": "markdown",
84 | "metadata": {},
85 | "source": [
86 | "Next, numpy for computing, matplotlib for plotting and tqdm for progress bars."
87 | ]
88 | },
89 | {
90 | "cell_type": "code",
91 | "execution_count": null,
92 | "metadata": {},
93 | "outputs": [],
94 | "source": [
95 | "import tqdm\n",
96 | "import numpy as np\n",
97 | "\n",
98 | "%matplotlib inline\n",
99 | "import matplotlib.pyplot as plt"
100 | ]
101 | },
102 | {
103 | "cell_type": "markdown",
104 | "metadata": {},
105 | "source": [
106 | "For deep learning stuff will be using [pytorch](https://pytorch.org/).\n",
107 | "\n",
108 | "If you are unfamiliar with it, it is basically like `numpy` with autograd,\n",
109 | "native GPU support, and tools for building training and serializing models.\n",
110 | "\n",
111 | "\n",
112 | "There are good introductory tutorials on `pytorch`, like this\n",
113 | "[one](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html)."
114 | ]
115 | },
116 | {
117 | "cell_type": "code",
118 | "execution_count": null,
119 | "metadata": {},
120 | "outputs": [],
121 | "source": [
122 | "import torch\n",
123 | "import torch.nn.functional as F\n",
124 | "\n",
125 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
126 | ]
127 | },
128 | {
129 | "cell_type": "markdown",
130 | "metadata": {},
131 | "source": [
132 | "Next we import the boilerplate code.\n",
133 | "\n",
134 | "* a procedure that implements a minibatch SGD **fit** loop\n",
135 | "* a function, that **evaluates** the model on the provided dataset"
136 | ]
137 | },
138 | {
139 | "cell_type": "code",
140 | "execution_count": null,
141 | "metadata": {},
142 | "outputs": [],
143 | "source": [
144 | "from mlss2019bdl import fit, predict"
145 | ]
146 | },
147 | {
148 | "cell_type": "markdown",
149 | "metadata": {},
150 | "source": [
151 | "The algorithm to sample a random function is:\n",
152 | "* for $b = 1... B$ do:\n",
153 | "\n",
154 | " 1. draw an independent realization $f_b\\colon \\mathcal{X} \\to \\mathcal{Y}$\n",
155 | " with from the process $\\{f_\\omega\\}_{\\omega \\sim q(\\omega)}$\n",
156 | " 2. get $\\hat{y}_{bi} = f_b(\\tilde{x}_i)$ for $i=1 .. m$\n"
157 | ]
158 | },
159 | {
160 | "cell_type": "code",
161 | "execution_count": null,
162 | "metadata": {},
163 | "outputs": [],
164 | "source": [
165 | "from mlss2019bdl.bdl import freeze, unfreeze\n",
166 | "\n",
167 | "def sample_function(model, dataset, n_draws=1, verbose=False):\n",
168 | " \"\"\"Draw a realization of a random function.\"\"\"\n",
169 | " outputs = []\n",
170 | " for _ in tqdm.tqdm(range(n_draws), disable=not verbose):\n",
171 | " freeze(model)\n",
172 | "\n",
173 | " outputs.append(predict(model, dataset))\n",
174 | "\n",
175 | " unfreeze(model)\n",
176 | "\n",
177 | " return torch.stack(outputs, dim=0)"
178 | ]
179 | },
180 | {
181 | "cell_type": "markdown",
182 | "metadata": {},
183 | "source": [
184 | "Sample the class probabilities $p(y_x = k \\mid x, \\omega, m)$\n",
185 | "with $\\omega \\sim q(\\omega)$ by a model that **outputs raw class\n",
186 | "logit scores**."
187 | ]
188 | },
189 | {
190 | "cell_type": "code",
191 | "execution_count": null,
192 | "metadata": {},
193 | "outputs": [],
194 | "source": [
195 | "def sample_proba(model, dataset, n_draws=1):\n",
196 | " logits = sample_function(model, dataset, n_draws=n_draws)\n",
197 | "\n",
198 | " return F.softmax(logits, dim=-1)"
199 | ]
200 | },
201 | {
202 | "cell_type": "markdown",
203 | "metadata": {},
204 | "source": [
205 | "Get the predictive posterior class probabilities\n",
206 | "$$\n",
207 | "p(y_x = k \\mid x, m)\n",
208 | "% = \\mathbb{E}_{\\omega \\sim q(\\omega)}\n",
209 | "% p(y_x = k \\mid x, \\omega, m)\n",
210 | " \\approx \\frac1{\\lvert \\mathcal{W} \\rvert}\n",
211 | " \\sum_{\\omega \\in \\mathcal{W}}\n",
212 | " p(y_x = k \\mid x, \\omega, m)\n",
213 | " \\,, $$\n",
214 | "with $\\mathcal{W}$ -- iid draws from $q(\\omega)$."
215 | ]
216 | },
217 | {
218 | "cell_type": "code",
219 | "execution_count": null,
220 | "metadata": {},
221 | "outputs": [],
222 | "source": [
223 | "def predict_proba(model, dataset, n_draws=1):\n",
224 | " proba = sample_proba(model, dataset, n_draws=n_draws)\n",
225 | "\n",
226 | " return proba.mean(dim=0)"
227 | ]
228 | },
229 | {
230 | "cell_type": "markdown",
231 | "metadata": {},
232 | "source": [
233 | "Gat the maximum a posteriori class label **(MAP)**: $\n",
234 | "\\hat{y}_x\n",
235 | " = \\arg \\max_k \\mathbb{E}_{\\omega \\sim q(\\omega)}\n",
236 | " p(y_x = k \\mid x, \\omega, m)\n",
237 | "$"
238 | ]
239 | },
240 | {
241 | "cell_type": "code",
242 | "execution_count": null,
243 | "metadata": {},
244 | "outputs": [],
245 | "source": [
246 | "def predict_label(model, dataset, n_draws=1):\n",
247 | " proba = predict_proba(model, dataset, n_draws=n_draws)\n",
248 | "\n",
249 | " return proba.argmax(dim=-1)"
250 | ]
251 | },
252 | {
253 | "cell_type": "markdown",
254 | "metadata": {},
255 | "source": [
256 | "We will need some functionality from scikit"
257 | ]
258 | },
259 | {
260 | "cell_type": "code",
261 | "execution_count": null,
262 | "metadata": {},
263 | "outputs": [],
264 | "source": [
265 | "from sklearn.metrics import confusion_matrix\n",
266 | "\n",
267 | "def evaluate(model, dataset, n_draws=1):\n",
268 | " assert isinstance(dataset, TensorDataset)\n",
269 | "\n",
270 | " predicted = predict_label(model, dataset, n_draws=n_draws)\n",
271 | "\n",
272 | " target = dataset.tensors[1].cpu().numpy()\n",
273 | " return confusion_matrix(target, predicted.cpu().numpy())"
274 | ]
275 | },
276 | {
277 | "cell_type": "markdown",
278 | "metadata": {},
279 | "source": [
280 | "A function to plot images in a small dataset. "
281 | ]
282 | },
283 | {
284 | "cell_type": "code",
285 | "execution_count": null,
286 | "metadata": {},
287 | "outputs": [],
288 | "source": [
289 | "from mlss2019bdl.flex import plot\n",
290 | "from torch.utils.data import TensorDataset\n",
291 | "from IPython.display import clear_output\n",
292 | "\n",
293 | "def display(images, n_col=None, title=None, figsize=None, refresh=False):\n",
294 | " if isinstance(images, TensorDataset):\n",
295 | " images, targets = images.tensors\n",
296 | " \n",
297 | " if refresh:\n",
298 | " clear_output(True)\n",
299 | "\n",
300 | " fig, ax = plt.subplots(1, 1, figsize=figsize)\n",
301 | " plot(ax, images, n_col=n_col, cmap=plt.cm.bone)\n",
302 | " if title is not None:\n",
303 | " ax.set_title(title)\n",
304 | "\n",
305 | " plt.show()\n",
306 | " plt.close()"
307 | ]
308 | },
309 | {
310 | "cell_type": "markdown",
311 | "metadata": {},
312 | "source": [
313 |     ""
314 | ]
315 | },
316 | {
317 | "cell_type": "markdown",
318 | "metadata": {},
319 | "source": [
320 | "## Bayesian Active Learning with images"
321 | ]
322 | },
323 | {
324 | "cell_type": "markdown",
325 | "metadata": {},
326 | "source": [
327 | "* Data labelling is costly and time consuming\n",
328 | "* unlabeled instances are essentially free\n",
329 | "\n",
330 | "**Goal** Achieve high performance with fewer labels by\n",
331 | "identifying the best instances to learn from"
332 | ]
333 | },
334 | {
335 | "cell_type": "markdown",
336 | "metadata": {},
337 | "source": [
338 | "Essential blocks of active learning:\n",
339 | "\n",
340 | "* a **model** $m$ capable of quantifying uncertainty (preferably a Bayesian model)\n",
341 | "* an **acquisition function** $a\\colon \\mathcal{M} \\times \\mathcal{X}^* \\to \\mathbb{R}$\n",
342 | " that for any finite set of inputs $S\\subset \\mathcal{X}$ quantifies their usefulness\n",
343 | " to the model $m\\in \\mathcal{M}$\n",
344 | "* a labelling **oracle**, e.g. a human expert"
345 | ]
346 | },
347 | {
348 | "cell_type": "markdown",
349 | "metadata": {},
350 | "source": [
351 | "### The model"
352 | ]
353 | },
354 | {
355 | "cell_type": "markdown",
356 | "metadata": {},
357 | "source": [
358 | "We reuse the `DropoutLinear` from the first part."
359 | ]
360 | },
361 | {
362 | "cell_type": "code",
363 | "execution_count": null,
364 | "metadata": {},
365 | "outputs": [],
366 | "source": [
367 | "from torch.nn import Module, Sequential\n",
368 | "from torch.nn import AvgPool2d, LeakyReLU\n",
369 | "from torch.nn import Linear, Conv2d\n",
370 | "\n",
371 | "from mlss2019bdl.bdl import DropoutLinear, DropoutConv2d\n",
372 | "\n",
373 | "class MNISTModel(Module):\n",
374 | " def __init__(self, p=0.5):\n",
375 | " super().__init__()\n",
376 | "\n",
377 | " self.head = Sequential(\n",
378 | " Conv2d(1, 32, 3, 1),\n",
379 | " LeakyReLU(),\n",
380 | " DropoutConv2d(32, 64, 3, 1, p=p),\n",
381 | " LeakyReLU(),\n",
382 | " AvgPool2d(2),\n",
383 | " )\n",
384 | "\n",
385 | " self.tail = Sequential(\n",
386 | " DropoutLinear(12 * 12 * 64, 128, p=p),\n",
387 | " LeakyReLU(),\n",
388 | " DropoutLinear(128, 10, p=p),\n",
389 | " )\n",
390 | "\n",
391 | " def forward(self, input):\n",
392 | " \"\"\"Take images and compute their class logits.\"\"\"\n",
393 | " x = self.head(input)\n",
394 | " return self.tail(x.flatten(1))"
395 | ]
396 | },
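  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**(note)** a quick sanity check of the shapes (a sketch, assuming the standard\n",
    "$1 \times 28 \times 28$ MNIST inputs): the two $3 \times 3$ convolutions shrink an\n",
    "image to $24 \times 24$ and the average pooling halves it to $12 \times 12$, which\n",
    "is where the `12 * 12 * 64` input features of the first `DropoutLinear` come from."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# a batch with one blank image should produce 10 class logits\n",
    "with torch.no_grad():\n",
    "    logits = MNISTModel()(torch.zeros(1, 1, 28, 28))\n",
    "\n",
    "assert logits.shape == (1, 10)"
   ]
  },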
397 | {
398 | "cell_type": "markdown",
399 | "metadata": {},
400 | "source": [
401 |     ""
402 | ]
403 | },
404 | {
405 | "cell_type": "markdown",
406 | "metadata": {},
407 | "source": [
408 | "### the Acquisition Function"
409 | ]
410 | },
411 | {
412 | "cell_type": "markdown",
413 | "metadata": {},
414 | "source": [
415 | "There are many acquisition criteria (borrowed from [Gal17a](http://proceedings.mlr.press/v70/gal17a.html)):\n",
416 | "* Classification\n",
417 | " * Posterior predictive entropy\n",
418 | " * Posterior Mutual Information\n",
419 | " * Variance ratios\n",
420 | " * BALD\n",
421 | "\n",
422 | "* Regression\n",
423 | " * predictive variance\n",
424 | "\n",
425 | "... and there is always the baseline **random acquisition**"
426 | ]
427 | },
428 | {
429 | "cell_type": "code",
430 | "execution_count": null,
431 | "metadata": {},
432 | "outputs": [],
433 | "source": [
434 | "random_state = np.random.RandomState(812_760_351)\n",
435 | "\n",
436 | "def random_acquisition(dataset, model, n_request=1, n_draws=1):\n",
437 | " indices = random_state.choice(len(dataset), size=n_request)\n",
438 | "\n",
439 | " return torch.from_numpy(indices).to(device)"
440 | ]
441 | },
442 | {
443 | "cell_type": "markdown",
444 | "metadata": {},
445 | "source": [
446 |     ""
447 | ]
448 | },
449 | {
450 | "cell_type": "markdown",
451 | "metadata": {},
452 | "source": [
453 | "### Data and the Oracle"
454 | ]
455 | },
456 | {
457 | "cell_type": "markdown",
458 | "metadata": {},
459 | "source": [
460 | "Prepare the datasets from the `train` part of\n",
461 | "[MNIST](http://yann.lecun.com/exdb/mnist/)\n",
462 | "(or [Kuzushiji-MNIST](https://github.com/rois-codh/kmnist)):\n",
463 | "* ($\\mathcal{S}_\\mathrm{train}$) initial **training**: $30$ images\n",
464 | "* ($\\mathcal{S}_\\mathrm{valid}$) our **validation**:\n",
465 | " $5000$ images, stratified\n",
466 | "* ($\\mathcal{S}_\\mathrm{pool}$) acquisition **pool**:\n",
467 | " $5000$ of the unused images, skewed to class $0$\n",
468 | "\n",
469 | "The true test sample of MNIST is in $\\mathcal{S}_\\mathrm{test}$ -- we\n",
470 | "will use it to evaluate the final performance."
471 | ]
472 | },
473 | {
474 | "cell_type": "code",
475 | "execution_count": null,
476 | "metadata": {},
477 | "outputs": [],
478 | "source": [
479 | "from mlss2019bdl.dataset import get_dataset\n",
480 | "\n",
481 | "S_train, S_pool, S_valid, S_test = get_dataset(\n",
482 | " n_train=30,\n",
483 | " n_valid=5000,\n",
484 | " n_pool=5000,\n",
485 | " name=\"MNIST\", # \"KMNIST\"\n",
486 | " path=\"./data\",\n",
487 | " random_state=722_257_201)"
488 | ]
489 | },
490 | {
491 | "cell_type": "markdown",
492 | "metadata": {},
493 | "source": [
494 | "* `query_oracle(ix, D)` **request** the instances in `D` at the specified\n",
495 | " indices `ix` into a dataset and **remove** from them from `D`\n",
496 | "\n",
497 | "* `merge(*datasets, [out=])` merge the datasets, creting a new one, or replacing `out`"
498 | ]
499 | },
500 | {
501 | "cell_type": "code",
502 | "execution_count": null,
503 | "metadata": {},
504 | "outputs": [],
505 | "source": [
506 | "from mlss2019bdl.dataset import collect as query_oracle"
507 | ]
508 | },
509 | {
510 | "cell_type": "markdown",
511 | "metadata": {},
512 | "source": [
513 |     ""
514 | ]
515 | },
516 | {
517 | "cell_type": "markdown",
518 | "metadata": {},
519 | "source": [
520 | "### the Active Learning loop"
521 | ]
522 | },
523 | {
524 | "cell_type": "markdown",
525 | "metadata": {},
526 | "source": [
527 | "1. fit $m$ on $\\mathcal{S}_{\\mathrm{labelled}}$\n",
528 | "\n",
529 | "\n",
530 | "2. get exact (or approximate) $$\n",
531 | " \\mathcal{S}^* \\in \\arg \\max\\limits_{S \\subseteq \\mathcal{S}_\\mathrm{unlabelled}}\n",
532 | " a(m, S)\n",
533 | "$$ satisfying **budget constraints** and **without** access to targets\n",
534 | "(constraints, like $\\lvert S \\rvert \\leq \\ell$ or other economically motivated ones).\n",
535 | "\n",
536 | "\n",
537 | "3. request the **oracle** to provide labels for each $x\\in \\mathcal{S}^*$\n",
538 | "\n",
539 | "\n",
540 | "4. update $\n",
541 | "\\mathcal{S}_{\\mathrm{labelled}}\n",
542 | " \\leftarrow \\mathcal{S}^*\n",
543 | " \\cup \\mathcal{S}_{\\mathrm{labelled}}\n",
544 | "$ and goto 1."
545 | ]
546 | },
547 | {
548 | "cell_type": "code",
549 | "execution_count": null,
550 | "metadata": {},
551 | "outputs": [],
552 | "source": [
553 | "import copy\n",
554 | "from mlss2019bdl.dataset import merge\n",
555 | "\n",
556 | "def active_learn(S_train,\n",
557 | " S_pool,\n",
558 | " S_valid,\n",
559 | " acquire_fn,\n",
560 | " n_budget=150,\n",
561 | " n_max_request=3,\n",
562 | " n_draws=11,\n",
563 | " n_epochs=200,\n",
564 | " p=0.5,\n",
565 | " weight_decay=1e-2):\n",
566 | "\n",
567 | " model = MNISTModel(p=p).to(device)\n",
568 | "\n",
569 | " scores, balances = [], []\n",
570 | " S_train, S_pool = copy.deepcopy(S_train), copy.deepcopy(S_pool)\n",
571 | " while True:\n",
572 | " # 1. fit on train\n",
573 | " l2_reg = weight_decay * (1 - p) / max(len(S_train), 1)\n",
574 | "\n",
575 | " model = fit(model, S_train, batch_size=32, criterion=\"cross_entropy\",\n",
576 | " weight_decay=l2_reg, n_epochs=n_epochs)\n",
577 | "\n",
578 | "\n",
579 | " # (optional) keep track of scores and plot the train dataset\n",
580 | " scores.append(evaluate(model, S_valid, n_draws))\n",
581 | " balances.append(np.bincount(S_train.tensors[1], minlength=10))\n",
582 | "\n",
583 | " accuracy = scores[-1].diagonal().sum() / scores[-1].sum()\n",
584 | " title = f\"(n_train) {len(S_train)} (Acc.) {accuracy:.1%}\"\n",
585 | " display(S_train, n_col=30, figsize=(15, 5), title=title, refresh=True)\n",
586 | "\n",
587 | "\n",
588 | " # 2-3. request new data from pool, if within budget\n",
589 | " n_request = min(n_budget - len(S_train), n_max_request)\n",
590 | " if n_request <= 0:\n",
591 | " break\n",
592 | "\n",
593 | " indices = acquire_fn(S_pool, model, n_request=n_request, n_draws=n_draws)\n",
594 | "\n",
595 | " # 4. update the train dataset\n",
596 | " S_requested = query_oracle(indices, S_pool)\n",
597 | " S_train = merge(S_train, S_requested)\n",
598 | "\n",
599 | " return model, S_train, np.stack(scores, axis=0), np.stack(balances, axis=0)"
600 | ]
601 | },
602 | {
603 | "cell_type": "markdown",
604 | "metadata": {},
605 | "source": [
606 | "* `collect(ix, D)` **collect** the instances in `D` at the specified\n",
607 | " indices `ix` into a dataset and **remove** from them from `D`\n",
608 | "\n",
609 | "* `merge(*datasets, [out=])` merge the datasets, creting a new one, or replacing `out`"
610 | ]
611 | },
612 | {
613 | "cell_type": "markdown",
614 | "metadata": {},
615 | "source": [
616 |     ""
617 | ]
618 | },
619 | {
620 | "cell_type": "markdown",
621 | "metadata": {},
622 | "source": [
623 | "### The baseline"
624 | ]
625 | },
626 | {
627 | "cell_type": "markdown",
628 | "metadata": {},
629 | "source": [
630 | "How powerful will our model with random acquisition get under a total budget of $150$ images?"
631 | ]
632 | },
633 | {
634 | "cell_type": "code",
635 | "execution_count": null,
636 | "metadata": {
637 | "scrolled": false
638 | },
639 | "outputs": [],
640 | "source": [
641 | "baseline = active_learn(\n",
642 | " S_train,\n",
643 | " S_pool,\n",
644 | " S_valid,\n",
645 | " random_acquisition,\n",
646 | " n_draws=21,\n",
647 | " n_budget=150,\n",
648 | " n_max_request=3,\n",
649 | " n_epochs=200,\n",
650 | ")"
651 | ]
652 | },
653 | {
654 | "cell_type": "markdown",
655 | "metadata": {},
656 | "source": [
657 | "Let's see the dynamics of the accuracy ..."
658 | ]
659 | },
660 | {
661 | "cell_type": "code",
662 | "execution_count": null,
663 | "metadata": {},
664 | "outputs": [],
665 | "source": [
666 | "def accuracy(scores):\n",
667 | " tp = scores.diagonal(axis1=-2, axis2=-1)\n",
668 | " return tp.sum(-1) / scores.sum((-2, -1))"
669 | ]
670 | },
671 | {
672 | "cell_type": "code",
673 | "execution_count": null,
674 | "metadata": {},
675 | "outputs": [],
676 | "source": [
677 | "model_rand, train_rand, scores_rand, balances_rand = baseline\n",
678 | "\n",
679 | "fig, ax = plt.subplots(1, 1, figsize=(12, 7))\n",
680 | "ax.plot(accuracy(scores_rand), label='Accuracy (random)', lw=2)\n",
681 | "\n",
682 | "ax.legend()\n",
683 | "plt.show()"
684 | ]
685 | },
686 | {
687 | "cell_type": "markdown",
688 | "metadata": {},
689 | "source": [
690 | "..., and the frequency of each class in $\\mathcal{S}_\\mathrm{train}$."
691 | ]
692 | },
693 | {
694 | "cell_type": "code",
695 | "execution_count": null,
696 | "metadata": {},
697 | "outputs": [],
698 | "source": [
699 | "fig, ax = plt.subplots(1, 1, figsize=(12, 7))\n",
700 | "\n",
701 | "lines = ax.plot(balances_rand, lw=2)\n",
702 | "plt.legend(lines, list(range(10)), ncol=2);"
703 | ]
704 | },
705 | {
706 | "cell_type": "markdown",
707 | "metadata": {},
708 | "source": [
709 |     ""
710 | ]
711 | },
712 | {
713 | "cell_type": "markdown",
714 | "metadata": {},
715 | "source": [
716 | "## Bayesian Active Learning by Disagreement"
717 | ]
718 | },
719 | {
720 | "cell_type": "markdown",
721 | "metadata": {},
722 | "source": [
723 | "Bayesian Active Learning by Disagreement, or **BALD** criterion, is\n",
724 | "based on the posterior mutual information between model's predictions\n",
725 | "$y_x$ at some point $x$ and its parameters $\\omega$:\n",
726 | "\n",
727 | "$$\\begin{align}\n",
728 | " a(m, S)\n",
729 | " &= \\sum_{x\\in S} a(m, \\{x\\})\n",
730 | " \\\\\n",
731 | " a(m, \\{x\\})\n",
732 | " &= \\mathbb{I}(y_x; \\omega \\mid x, m, D)\n",
733 | "\\end{align}\n",
734 | " \\,, \\tag{bald} $$\n",
735 | "\n",
736 | "with the [**Mutual Information**](https://en.wikipedia.org/wiki/Mutual_information#Relation_to_Kullback%E2%80%93Leibler_divergence)\n",
737 | "(**MI**)\n",
738 | "$$\n",
739 | " \\mathbb{I}(y_x; \\omega \\mid x, m, D)\n",
740 | " = \\mathbb{H}\\bigl(\n",
741 | " \\mathbb{E}_{\\omega \\sim q(\\omega\\mid m, D)}\n",
742 | " p(y_x \\,\\mid\\, x, \\omega, m, D)\n",
743 | " \\bigr)\n",
744 | " - \\mathbb{E}_{\\omega \\sim q(\\omega\\mid m, D)}\n",
745 | " \\mathbb{H}\\bigl(\n",
746 | " p(y_x \\,\\mid\\, x, \\omega, m, D)\n",
747 | " \\bigr)\n",
748 | " \\,, \\tag{mi} $$\n",
749 | "\n",
750 | "and the [(differential) **entropy**](https://en.wikipedia.org/wiki/Differential_entropy#Differential_entropies_for_various_distributions)\n",
751 | "(all densities and/or probability mass functions can be conditional):\n",
752 | "\n",
753 | "$$\n",
754 | " \\mathbb{H}(p(y))\n",
755 | " = - \\mathbb{E}_{y\\sim p} \\log p(y)\n",
756 | " \\,. $$"
757 | ]
758 | },
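  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**(example)** for a binary $y_x$ and two draws $\\omega_1, \\omega_2$ with\n",
    "$p(y_x = 1 \\mid x, \\omega_1) = 0.9$ and $p(y_x = 1 \\mid x, \\omega_2) = 0.1$,\n",
    "the averaged prediction is $(0.5, 0.5)$: the first term of (mi) is\n",
    "$\\log 2 \\approx 0.69$, each draw's entropy is $\\approx 0.33$, so\n",
    "$\\mathbb{I} \\approx 0.37$ -- the draws disagree, hence labelling $x$ is\n",
    "informative about $\\omega$. If instead both draws give $(0.5, 0.5)$, both\n",
    "terms equal $\\log 2$ and $\\mathbb{I} = 0$ -- the model is uncertain about\n",
    "$y_x$, but querying $x$ tells us nothing new about $\\omega$."
   ]
  },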
759 | {
760 | "cell_type": "markdown",
761 | "metadata": {},
762 | "source": [
763 |     ""
764 | ]
765 | },
766 | {
767 | "cell_type": "markdown",
768 | "metadata": {},
769 | "source": [
770 | "#### (task) Implementing the acquisition function"
771 | ]
772 | },
773 | {
774 | "cell_type": "markdown",
775 | "metadata": {},
776 | "source": [
777 | "Note that $a(m, S)$ is additively separable in $S$, i.e.\n",
778 | "equals $\\sum_{x\\in S} a(m, \\{x\\})$. This implies\n",
779 | "\n",
780 | "$$\n",
781 | "\\begin{align}\n",
782 | " \\max_{S \\subseteq \\mathcal{S}_\\mathrm{unlabelled}} a(m, S)\n",
783 | " &= \\max_{z \\in \\mathcal{S}_\\mathrm{unlabelled}}\n",
784 | " \\max_{F \\in \\mathcal{S}_\\mathrm{unlabelled} \\setminus \\{z\\}}\n",
785 | " \\sum_{x\\in F \\cup \\{x\\}} a(m, \\{x\\})\n",
786 | " \\\\\n",
787 | " &= \\max_{z \\in \\mathcal{S}_\\mathrm{unlabelled}}\n",
788 | " a(m, \\{z\\})\n",
789 | " + \\max_{F \\in \\mathcal{S}_\\mathrm{unlabelled} \\setminus \\{z\\}}\n",
790 | " \\sum_{x\\in F} a(m, \\{x\\})\n",
791 | "\\end{align}\n",
792 | " \\,. $$"
793 | ]
794 | },
795 | {
796 | "cell_type": "markdown",
797 | "metadata": {},
798 | "source": [
799 | "Therefore selecting the $\\ell$ `most interesting` points from\n",
800 | "$\\mathcal{S}_\\mathrm{unlabelled}$ is trivial."
801 | ]
802 | },
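  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**(hint)** once the per-instance scores $a(m, \{x\})$ have been computed for the\n",
    "whole pool, the $\ell$ highest-scoring indices can be taken with `torch.topk`,\n",
    "e.g. (a sketch with a made-up score vector):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sketch: pick the three highest-scoring indices from random per-instance scores\n",
    "fake_scores = torch.rand(10)\n",
    "\n",
    "torch.topk(fake_scores, k=3).indices"
   ]
  },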
803 | {
804 | "cell_type": "markdown",
805 | "metadata": {},
806 | "source": [
807 | "The acquisition function that we implement has interface\n",
808 | "identical to `random_acquisition` but uses BALD to choose\n",
809 | "instances."
810 | ]
811 | },
812 | {
813 | "cell_type": "code",
814 | "execution_count": null,
815 | "metadata": {},
816 | "outputs": [],
817 | "source": [
818 | "def BALD_acquisition(dataset, model, n_request=1, n_draws=1):\n",
819 | " proba = sample_proba(model, dataset, n_draws=n_draws)\n",
820 | "\n",
821 | " ## Exercise: implement BALD\n",
822 | "\n",
823 | " pass"
824 | ]
825 | },
826 | {
827 | "cell_type": "markdown",
828 | "metadata": {},
829 | "source": [
830 |     ""
831 | ]
832 | },
833 | {
834 | "cell_type": "markdown",
835 | "metadata": {},
836 | "source": [
837 | "#### (task) implementing entropy"
838 | ]
839 | },
840 | {
841 | "cell_type": "markdown",
842 | "metadata": {},
843 | "source": [
844 | "For categorical (discrete) random variables $y \\sim \\mathcal{Cat}(\\mathbf{p})$,\n",
845 | "$\\mathbf{p} \\in \\{ \\mu \\in [0, 1]^d \\colon \\sum_k \\mu_k = 1\\}$, the entropy is\n",
846 | "\n",
847 | "$$\n",
848 | " \\mathbb{H}(p(y))\n",
849 | " = - \\mathbb{E}_{y\\sim p(y)} \\log p(y)\n",
850 | " = - \\sum_k p_k \\log p_k\n",
851 | " \\,. $$"
852 | ]
853 | },
854 | {
855 | "cell_type": "markdown",
856 | "metadata": {},
857 | "source": [
858 | "**(note)** although in calculus $0 \\cdot \\log 0 = 0$ (because\n",
859 | "$\\lim_{p\\downarrow 0} p \\cdot \\log p = 0$), in floating point\n",
860 | "arithmetic $0 \\cdot \\log 0 = \\mathrm{NaN}$. So you need to add\n",
861 | "some **really tiny float number** to the argument of $\\log$."
862 | ]
863 | },
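  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**(illustration)** the note above in action on a toy probability vector:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "p = torch.tensor([0.0, 0.5, 0.5])\n",
    "\n",
    "print((p * torch.log(p)).sum())          # nan, since 0 * log(0) = nan in floating point\n",
    "print((p * torch.log(p + 1e-20)).sum())  # approx. -0.6931, i.e. minus the entropy"
   ]
  },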
864 | {
865 | "cell_type": "code",
866 | "execution_count": null,
867 | "metadata": {},
868 | "outputs": [],
869 | "source": [
870 | "def categorical_entropy(proba):\n",
871 | " \"\"\"Compute the entropy along the last dimension.\"\"\"\n",
872 | "\n",
873 | " ## Exercise: the probabilities sum to one along the last axis.\n",
874 | " # Please, compute their entropy.\n",
875 | "\n",
876 | " pass"
877 | ]
878 | },
879 | {
880 | "cell_type": "markdown",
881 | "metadata": {},
882 | "source": [
883 |     ""
884 | ]
885 | },
886 | {
887 | "cell_type": "markdown",
888 | "metadata": {},
889 | "source": [
890 | "#### (task) implementing mutual information"
891 | ]
892 | },
893 | {
894 | "cell_type": "markdown",
895 | "metadata": {},
896 | "source": [
897 | "Consider a tensor $p_{bik}$ of probabilities $p(y_{x_i}=k \\mid x_i, \\omega_b, m, D)$\n",
898 | "with $\\omega_b \\sim q(\\omega \\mid m, D)$ with $\\mathcal{W} = (\\omega_b)_{b=1}^B$\n",
899 | "being iid draws from $q(\\omega \\mid m, D)$.\n",
900 | "\n",
901 | "Let's implement a procedure that computes the Monte Carlo estimate of the\n",
902 | "posterior predictive distribution, its **entropy** and **mutual information**\n",
903 | "\n",
904 | "$$\n",
905 | " \\mathbb{I}_\\mathrm{MC}(y_x; \\omega \\mid x, m, D)\n",
906 | " = \\mathbb{H}\\bigl(\n",
907 | " \\hat{p}(y_x\\mid x, m, D)\n",
908 | " \\bigr)\n",
909 | " - \\frac1{\\lvert \\mathcal{W} \\rvert} \\sum_{\\omega\\in \\mathcal{W}}\n",
910 | " \\mathbb{H}\\bigl(\n",
911 | " p(y_x \\,\\mid\\, x, \\omega, m, D)\n",
912 | " \\bigr)\n",
913 | " \\,, \\tag{mi-mc} $$\n",
914 | "where\n",
915 | "$$\n",
916 | "\\hat{p}(y_x\\mid x, m, D)\n",
917 | " = \\frac1{\\lvert \\mathcal{W} \\rvert} \\sum_{\\omega\\in \\mathcal{W}}\n",
918 | " \\,p(y_x \\mid x, \\omega, m, D)\n",
919 | " \\,. $$"
920 | ]
921 | },
922 | {
923 | "cell_type": "code",
924 | "execution_count": null,
925 | "metadata": {},
926 | "outputs": [],
927 | "source": [
928 | "def mutual_information(proba):\n",
929 | " ## Exercise: compute a Monte Carlo estimator of the predictive\n",
930 | " ## distribution, its entropy and MI `H E_w p(., w) - E_w H p(., w)`\n",
931 | "\n",
932 | " pass"
933 | ]
934 | },
935 | {
936 | "cell_type": "markdown",
937 | "metadata": {},
938 | "source": [
939 |     ""
940 | ]
941 | },
942 | {
943 | "cell_type": "markdown",
944 | "metadata": {},
945 | "source": [
946 | "How powerful will our model with **BALD** acquisition, if we can afford no more than $150$ images?"
947 | ]
948 | },
949 | {
950 | "cell_type": "code",
951 | "execution_count": null,
952 | "metadata": {
953 | "scrolled": false
954 | },
955 | "outputs": [],
956 | "source": [
957 | "bald_results = active_learn(\n",
958 | " S_train,\n",
959 | " S_pool,\n",
960 | " S_valid,\n",
961 | " BALD_acquisition,\n",
962 | " n_draws=21,\n",
963 | " n_budget=150,\n",
964 | " n_max_request=3,\n",
965 | " n_epochs=200,\n",
966 | ")"
967 | ]
968 | },
969 | {
970 | "cell_type": "markdown",
971 | "metadata": {},
972 | "source": [
973 | "Let's see the dynamics of the accuracy ..."
974 | ]
975 | },
976 | {
977 | "cell_type": "code",
978 | "execution_count": null,
979 | "metadata": {},
980 | "outputs": [],
981 | "source": [
982 | "model_bald, train_bald, scores_bald, balances_bald = bald_results\n",
983 | "\n",
984 | "fig, ax = plt.subplots(1, 1, figsize=(12, 7))\n",
985 | "\n",
986 | "ax.plot(accuracy(scores_rand), label='Accuracy (random)', lw=2)\n",
987 | "ax.plot(accuracy(scores_bald), label='Accuracy (BALD)', lw=2)\n",
988 | "\n",
989 | "ax.legend()\n",
990 | "plt.show()"
991 | ]
992 | },
993 | {
994 | "cell_type": "markdown",
995 | "metadata": {},
996 | "source": [
997 | "..., and the frequency of each class in $\\mathcal{S}_\\mathrm{train}$."
998 | ]
999 | },
1000 | {
1001 | "cell_type": "code",
1002 | "execution_count": null,
1003 | "metadata": {},
1004 | "outputs": [],
1005 | "source": [
1006 | "fig, ax = plt.subplots(1, 1, figsize=(12, 7))\n",
1007 | "\n",
1008 | "lines = ax.plot(balances_bald, lw=2)\n",
1009 | "plt.legend(lines, list(range(10)), ncol=2);"
1010 | ]
1011 | },
1012 | {
1013 | "cell_type": "markdown",
1014 | "metadata": {},
1015 | "source": [
1016 |     ""
1017 | ]
1018 | },
1019 | {
1020 | "cell_type": "markdown",
1021 | "metadata": {},
1022 | "source": [
1023 | "#### Class performance"
1024 | ]
1025 | },
1026 | {
1027 | "cell_type": "markdown",
1028 | "metadata": {},
1029 | "source": [
1030 | "The *one-versus-rest* precision / recall scores on\n",
1031 | "$\\mathcal{S}_\\mathrm{valid}$. For binary classification:\n",
1032 | "\n",
1033 | "$$ \\begin{align}\n",
1034 | "\\mathrm{Precision}\n",
1035 | " &= \\frac{\\mathrm{TP}}{\\mathrm{TP} + \\mathrm{FP}}\n",
1036 | " \\approx \\mathbb{P}(y = 1 \\mid \\hat{y} = 1)\n",
1037 | " \\,, \\\\\n",
1038 | "\\mathrm{Recall}\n",
1039 | " &= \\frac{\\mathrm{TP}}{\\mathrm{TP} + \\mathrm{FN}}\n",
1040 | " \\approx \\mathbb{P}(\\hat{y} = 1 \\mid y = 1)\n",
1041 | " \\,.\n",
1042 | "\\end{align}$$"
1043 | ]
1044 | },
1045 | {
1046 | "cell_type": "code",
1047 | "execution_count": null,
1048 | "metadata": {},
1049 | "outputs": [],
1050 | "source": [
1051 | "import pandas as pd\n",
1052 | "\n",
1053 | "def pr_scores(score_matrix):\n",
1054 | " tp = score_matrix.diagonal(axis1=-2, axis2=-1)\n",
1055 | " fp, fn = score_matrix.sum(axis=-2) - tp, score_matrix.sum(axis=-1) - tp\n",
1056 | " \n",
1057 | " return pd.DataFrame({\n",
1058 | " \"precision\": {l: f\"{p:.2%}\" for l, p in enumerate(tp / (tp + fp))},\n",
1059 | " \"recall\": {l: f\"{p:.2%}\" for l, p in enumerate(tp / (tp + fn))},\n",
1060 | " })"
1061 | ]
1062 | },
1063 | {
1064 | "cell_type": "markdown",
1065 | "metadata": {},
1066 | "source": [
1067 | "Let's see the performance on the test set"
1068 | ]
1069 | },
1070 | {
1071 | "cell_type": "code",
1072 | "execution_count": null,
1073 | "metadata": {},
1074 | "outputs": [],
1075 | "source": [
1076 | "scores = {}\n",
1077 | "scores[\"rand\"] = evaluate(model_rand, S_test, n_draws=21)\n",
1078 | "scores[\"bald\"] = evaluate(model_bald, S_test, n_draws=21)"
1079 | ]
1080 | },
1081 | {
1082 | "cell_type": "markdown",
1083 | "metadata": {},
1084 | "source": [
1085 |     ""
1086 | ]
1087 | },
1088 | {
1089 | "cell_type": "code",
1090 | "execution_count": null,
1091 | "metadata": {},
1092 | "outputs": [],
1093 | "source": [
1094 | "df = pd.concat({\n",
1095 | " name: pr_scores(score)\n",
1096 | " for name, score in scores.items()\n",
1097 | "}, axis=1).T\n",
1098 | "\n",
1099 | "df.swaplevel().sort_index()"
1100 | ]
1101 | },
1102 | {
1103 | "cell_type": "markdown",
1104 | "metadata": {},
1105 | "source": [
1106 |     ""
1107 | ]
1108 | },
1109 | {
1110 | "cell_type": "markdown",
1111 | "metadata": {},
1112 | "source": [
1113 | "#### Question(s) (to work on in your spare time)\n",
1114 | "\n",
1115 | "* Run the experiments on the `KMNIST` dataset\n",
1116 | "\n",
1117 | "* Replicate figure 1 from [Gat et al. (2017): p. 4](http://proceedings.mlr.press/v70/gal17a.html).\n",
1118 | " You will need to re-run each experiment several times $11$, recording\n",
1119 | " the accuracy dynamics of each, then compare the mean and $25\\%$-$75\\%$\n",
1120 | " quantiles as they evolve with the size of the training sample."
1121 | ]
1122 | },
1123 | {
1124 | "cell_type": "markdown",
1125 | "metadata": {},
1126 | "source": [
1127 |     ""
1128 | ]
1129 | },
1130 | {
1131 | "cell_type": "markdown",
1132 | "metadata": {},
1133 | "source": [
1134 | "### (optional) Points of improvement: batch-vs-single"
1135 | ]
1136 | },
1137 | {
1138 | "cell_type": "markdown",
1139 | "metadata": {},
1140 | "source": [
1141 | "A drawback of the `pointwise` top-$\\ell$ procedure above is that, although\n",
1142 | "it acquires individually informative instances, altogether they might end\n",
1143 | "up **being** `jointly poorly informative`. This can be corrected, if we\n",
1144 | "would seek the highest mutual information among finite sets $\n",
1145 | "S \\subseteq \\mathcal{S}_\\mathrm{unlabelled}\n",
1146 | "$ of size $\\ell$."
1147 | ]
1148 | },
1149 | {
1150 | "cell_type": "markdown",
1151 | "metadata": {},
1152 | "source": [
1153 | "Such acquisition function is called **batch-BALD**\n",
1154 | "([Kirsch et al.; 2019](https://arxiv.org/abs/1906.08158.pdf)):\n",
1155 | "\n",
1156 | "$$\\begin{align}\n",
1157 | " a(m, S)\n",
1158 | " &= \\mathbb{I}\\bigl((y_x)_{x\\in S}; \\omega \\mid S, m \\bigr)\n",
1159 | " = \\mathbb{H} \\bigl(\n",
1160 | " \\mathbb{E}_{\\omega \\sim q(\\omega\\mid m)} p\\bigl((y_x)_{x\\in S}\\mid S, \\omega, m \\bigr)\n",
1161 | " \\bigr)\n",
1162 | " - \\mathbb{E}_{\\omega \\sim q(\\omega\\mid m)} H\\bigl(\n",
1163 | " p\\bigl((y_x)_{x\\in S}\\mid S, \\omega, m \\bigr)\n",
1164 | " \\bigr)\n",
1165 | "\\end{align}\n",
1166 | " \\,. \\tag{batch-bald} $$"
1167 | ]
1168 | },
1169 | {
1170 | "cell_type": "markdown",
1171 | "metadata": {},
1172 | "source": [
1173 | "This criterion requires combinatorially growing number of computations and\n",
1174 | "memory, however there are working solutions like random sampling of subsets\n",
1175 | "$\\mathcal{S}$ from $\\mathcal{S}_\\mathrm{unlabelled}$ or greedily maximizing\n",
1176 | "of this **submodular** criterion."
1177 | ]
1178 | },
1179 | {
1180 | "cell_type": "markdown",
1181 | "metadata": {},
1182 | "source": [
1183 |     ""
1184 | ]
1185 | },
1186 | {
1187 | "cell_type": "markdown",
1188 | "metadata": {},
1189 | "source": [
1190 | "### (optional) Points of improvement: bias"
1191 | ]
1192 | },
1193 | {
1194 | "cell_type": "markdown",
1195 | "metadata": {},
1196 | "source": [
1197 | "The first term in the **MC** estimate of the mutual information is the\n",
1198 | "so-called **plug-in** estimator of the entropy:\n",
1199 | "\n",
1200 | "$$\n",
1201 | " \\hat{H}\n",
1202 | " = \\mathbb{H}(\\hat{p}) = - \\sum_k \\hat{p}_k \\log \\hat{p}_k\n",
1203 | " \\,, $$\n",
1204 | "\n",
1205 | "where $\\hat{p}_k = \\tfrac1B \\sum_b p_{bk}$ is the full sample estimator\n",
1206 | "of the probabilities."
1207 | ]
1208 | },
1209 | {
1210 | "cell_type": "markdown",
1211 | "metadata": {},
1212 | "source": [
1213 | "It is known that this plug-in estimate is biased\n",
1214 | "(see [blog: Nowozin, 2015](http://www.nowozin.net/sebastian/blog/estimating-discrete-entropy-part-1.html)\n",
1215 | "and references therein, also this [notebook](https://colab.research.google.com/drive/1z9ZDNM6NFmuFnU28d8UO0Qymbd2LiNJW)). \n",
1216 | "In order to correct for small-sample bias we can use\n",
1217 | "[jackknife resampling](https://en.wikipedia.org/wiki/Jackknife_resampling).\n",
1218 | "It derives an estimate of the finite sample bias from the leave-one-out\n",
1219 | "estimators of the entropy and is relatively computationally cheap\n",
1220 | "(see [blog: Nowozin, 2015](http://www.nowozin.net/sebastian/blog/estimating-discrete-entropy-part-2.html),\n",
1221 | "[Miller, R. G. (1974)](http://www.math.ntu.edu.tw/~hchen/teaching/LargeSample/references/Miller74jackknife.pdf) and these [notes](http://people.bu.edu/aimcinto/jackknife.pdf)).\n",
1222 | "\n",
1223 | "The jackknife correction of a plug-in estimator $\\mathbb{H}(\\cdot)$\n",
1224 | "is computed thus: given a sample $(p_b)_{b=1}^B$ with $p_b$ -- discrete distribution on $1..K$\n",
1225 | "* for each $b=1.. B$\n",
1226 | " * get the leave-one-out estimator: $\\hat{p}_k^{-b} = \\tfrac1{B-1} \\sum_{j\\neq b} p_{jk}$\n",
1227 | " * compute the plug-in entropy estimator: $\\hat{H}_{-b} = \\mathbb{H}(\\hat{p}^{-b})$\n",
1228 | "* then compute the bias-corrected entropy estimator $\n",
1229 | "\\hat{H}_J\n",
1230 | " = \\hat{H} + (B - 1) \\bigl\\{\n",
1231 | " \\hat{H} - \\tfrac1B \\sum_b \\hat{H}^{-b}\n",
1232 | " \\bigr\\}\n",
1233 | "$"
1234 | ]
1235 | },
1236 | {
1237 | "cell_type": "markdown",
1238 | "metadata": {},
1239 | "source": [
1240 | "**(note)** when we knock the $i$-th data point out of the sample mean\n",
1241 | "$\\mu = \\tfrac1n \\sum_i x_i$ and recompute the mean $\\mu_{-i}$ we get\n",
1242 | "the following relation\n",
1243 | "$$ \\mu_{-i}\n",
1244 | " = \\frac1{n-1} \\sum_{j\\neq i} x_j\n",
1245 | " = \\frac{n}{n-1} \\mu - \\tfrac1{n-1} x_i\n",
1246 | " = \\mu + \\frac{\\mu - x_i}{n-1}\n",
1247 | " \\,. $$\n",
1248 | "This makes it possible to quickly compute leave-one-out estimators of\n",
1249 | "discrete probability distribution."
1250 | ]
1251 | },
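  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**(illustration)** a quick numeric check of this identity on a toy sample:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = torch.tensor([1., 2., 3., 10.])\n",
    "n, mu = len(x), x.mean()\n",
    "\n",
    "# leave-one-out means computed directly, and via the update formula above\n",
    "mu_loo_direct = torch.stack([torch.cat([x[:i], x[i+1:]]).mean() for i in range(n)])\n",
    "mu_loo_update = mu + (mu - x) / (n - 1)\n",
    "\n",
    "assert torch.allclose(mu_loo_direct, mu_loo_update)"
   ]
  },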
1252 | {
1253 | "cell_type": "markdown",
1254 | "metadata": {},
1255 | "source": [
1256 | "#### (task*) Unbiased estimator of entropy and mutual information\n",
1257 | "\n",
1258 | "Try to efficiently implement a bias-corrected acquisition\n",
1259 | "function, and see it is worth the effort."
1260 | ]
1261 | },
1262 | {
1263 | "cell_type": "code",
1264 | "execution_count": null,
1265 | "metadata": {},
1266 | "outputs": [],
1267 | "source": [
1268 | "def BALD_jknf_acquisition(dataset, model, n_request=1, n_draws=1):\n",
1269 | " proba = sample_proba(model, dataset, n_draws=n_draws)\n",
1270 | "\n",
1271 | " ## Exercise: MC estimate of the predictive distribution, entropy and MI\n",
1272 | " ## mutual information `H E_w p(., w) - E_w H p(., w)` with jackknife\n",
1273 | " ## correction.\n",
1274 | "\n",
1275 | " pass"
1276 | ]
1277 | },
1278 | {
1279 | "cell_type": "markdown",
1280 | "metadata": {},
1281 | "source": [
1282 |     ""
1283 | ]
1284 | },
1285 | {
1286 | "cell_type": "markdown",
1287 | "metadata": {},
1288 | "source": [
1289 | "Let's see ..."
1290 | ]
1291 | },
1292 | {
1293 | "cell_type": "code",
1294 | "execution_count": null,
1295 | "metadata": {
1296 | "scrolled": false
1297 | },
1298 | "outputs": [],
1299 | "source": [
1300 | "jknf_results = active_learn(\n",
1301 | " S_train,\n",
1302 | " S_pool,\n",
1303 | " S_valid,\n",
1304 | " BALD_jknf_acquisition,\n",
1305 | " n_draws=21,\n",
1306 | " n_budget=150,\n",
1307 | " n_max_request=3,\n",
1308 | " n_epochs=200,\n",
1309 | ")"
1310 | ]
1311 | },
1312 | {
1313 | "cell_type": "code",
1314 | "execution_count": null,
1315 | "metadata": {},
1316 | "outputs": [],
1317 | "source": [
1318 | "fig, ax = plt.subplots(1, 1, figsize=(12, 7))\n",
1319 | "\n",
1320 | "model_jknf, train_jknf, scores_jknf, balances_jknf = jknf_results\n",
1321 | "ax.plot(accuracy(scores_rand), label='Accuracy (random)', lw=2)\n",
1322 | "ax.plot(accuracy(scores_bald), label='Accuracy (BALD)', lw=2)\n",
1323 | "ax.plot(accuracy(scores_jknf), label='Accuracy (BALD-jknf)', lw=2)\n",
1324 | "\n",
1325 | "ax.legend()\n",
1326 | "plt.show()"
1327 | ]
1328 | },
1329 | {
1330 | "cell_type": "code",
1331 | "execution_count": null,
1332 | "metadata": {},
1333 | "outputs": [],
1334 | "source": [
1335 | "fig, ax = plt.subplots(1, 1, figsize=(12, 7))\n",
1336 | "\n",
1337 | "lines = ax.plot(balances_jknf, lw=2)\n",
1338 | "plt.legend(lines, list(range(10)), ncol=2);"
1339 | ]
1340 | },
1341 | {
1342 | "cell_type": "markdown",
1343 | "metadata": {},
1344 | "source": [
1345 |     ""
1346 | ]
1347 | }
1348 | ],
1349 | "metadata": {
1350 | "kernelspec": {
1351 | "display_name": "Python 3",
1352 | "language": "python",
1353 | "name": "python3"
1354 | },
1355 | "language_info": {
1356 | "codemirror_mode": {
1357 | "name": "ipython",
1358 | "version": 3
1359 | },
1360 | "file_extension": ".py",
1361 | "mimetype": "text/x-python",
1362 | "name": "python",
1363 | "nbconvert_exporter": "python",
1364 | "pygments_lexer": "ipython3",
1365 | "version": "3.7.2"
1366 | }
1367 | },
1368 | "nbformat": 4,
1369 | "nbformat_minor": 2
1370 | }
1371 |
--------------------------------------------------------------------------------
/bayesian_deep_learning/README.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/bayesian_deep_learning/README.md
--------------------------------------------------------------------------------
/bayesian_deep_learning/mlss2019bdl/__init__.py:
--------------------------------------------------------------------------------
1 | import tqdm
2 | import torch
3 |
4 | import torch.nn.functional as F
5 |
6 | from torch.utils.data import TensorDataset, DataLoader
7 |
8 |
9 | def dataset_from_numpy(*ndarrays, device=None, dtype=torch.float32):
10 | """Create :class:`TensorDataset` from the passed :class:`numpy.ndarray`-s.
11 |
12 | Each returned tensor in the TensorDataset and :attr:`ndarray` share
13 | the same memory, unless a type cast or device transfer took place.
14 | Modifications to any tensor in the dataset will be reflected in respective
15 | :attr:`ndarray` and vice versa.
16 |
17 | Each returned tensor in the dataset is not resizable.
18 |
19 | See Also
20 | --------
21 | torch.from_numpy : create a tensor from an ndarray.
22 | """
23 | tensors = map(torch.from_numpy, ndarrays)
24 |
25 | return TensorDataset(*[t.to(device, dtype) for t in tensors])
26 |
27 |
28 | default_criteria = {
29 | "cross_entropy":
30 | lambda model, X, y: F.cross_entropy(model(X), y, reduction="mean"),
31 | "mse":
32 | lambda model, X, y: 0.5 * F.mse_loss(model(X), y, reduction="mean"),
33 | }
34 |
35 |
36 | def fit(model, dataset, criterion="mse", batch_size=32,
37 | n_epochs=1, weight_decay=0, verbose=False):
38 | """Fit the model with SGD (Adam) on the specified dataset and criterion.
39 |
40 | This bare minimum of a fit loop creates a minibatch generator from
41 | the `dataset` with batches of size `batch_size`. On each batch it
42 | computes the backward pass through the `criterion` and the `model`
43 | and updates the `model`-s parameters with the Adam optimizer step.
44 | The loop passes through the dataset `n_epochs` times. It does not
45 | output any running debugging information, except for a progress bar.
46 |
47 |     The criterion can be either "mse" for mean squared error, "cross_entropy"
48 |     for the categorical negative log-likelihood, or a callable taking `model, X, y`
49 | as arguments.
50 | """
51 | if len(dataset) <= 0 or batch_size <= 0:
52 | return model
53 |
54 | criterion = default_criteria.get(criterion, criterion)
55 | assert callable(criterion)
56 |
57 | # get the model's device
58 | device = next(model.parameters()).device
59 |
60 | # an optimizer for model's parameters
61 | optim = torch.optim.Adam(model.parameters(), lr=2e-3,
62 | weight_decay=weight_decay)
63 |
64 | # stochastic minibatch generator for the training loop
65 | feed = DataLoader(dataset, shuffle=True, batch_size=batch_size)
66 | for epoch in tqdm.tqdm(range(n_epochs), disable=not verbose):
67 |
68 | model.train()
69 |
70 | for X, y in feed:
71 | # forward pass through the criterion (batch-average loss)
72 | loss = criterion(model, X.to(device), y.to(device))
73 |
74 | # get gradients with backward pass
75 | optim.zero_grad()
76 | loss.backward()
77 |
78 | # SGD update
79 | optim.step()
80 |
81 | return model
82 |
83 |
84 | def predict(model, dataset, batch_size=512):
85 | """Get model's output on the dataset.
86 |
87 | This straightforward function switches the model into `evaluation`
88 | regime, computes the forward pass on the `dataset` (in batches of
89 | size `batch_size`) and stacks the results into a tensor on the `cpu`.
90 | It temporarily disables `autograd` to gain some speed-up.
91 | """
92 | model.eval()
93 |
94 | # get the model's device
95 | device = next(model.parameters()).device
96 |
97 | # batch generator for the evaluation loop
98 | feed = DataLoader(dataset, batch_size=batch_size, shuffle=False)
99 |
100 | # compute and collect the outputs
101 | with torch.no_grad():
102 | return torch.cat([
103 | model(X.to(device)).cpu() for X, *rest in feed
104 | ], dim=0)
105 |
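# A minimal usage sketch (illustration only, not used by the tutorials): fit a
# small regression model on random data and collect its predictions.
#
#     import numpy as np, torch
#     from mlss2019bdl import dataset_from_numpy, fit, predict
#
#     X = np.random.randn(128, 3).astype(np.float32)
#     y = X @ np.array([[1.0], [-2.0], [0.5]], dtype=np.float32)
#
#     train = dataset_from_numpy(X, y)
#     model = fit(torch.nn.Linear(3, 1), train, criterion="mse", n_epochs=10)
#     outputs = predict(model, train)  # a [128, 1] tensor on the cpu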
--------------------------------------------------------------------------------
/bayesian_deep_learning/mlss2019bdl/bdl/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import freeze, unfreeze
2 |
3 |
4 | from .bernoulli import DropoutLinear, DropoutConv2d
5 |
--------------------------------------------------------------------------------
/bayesian_deep_learning/mlss2019bdl/bdl/base.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch.nn import Module
4 |
5 |
6 | class FreezableWeight(Module):
7 | def __init__(self):
8 | super().__init__()
9 | self.unfreeze()
10 |
11 | def unfreeze(self):
12 | self.register_buffer("frozen_weight", None)
13 |
14 | def is_frozen(self):
15 | """Check if a frozen weight is available."""
16 | return isinstance(self.frozen_weight, torch.Tensor)
17 |
18 | def freeze(self):
19 | """Sample from the distribution and freeze."""
20 | raise NotImplementedError()
21 |
22 |
23 | def freeze(module):
24 | for mod in module.modules():
25 | if isinstance(mod, FreezableWeight):
26 | mod.freeze()
27 |
28 | return module # return self
29 |
30 |
31 | def unfreeze(module):
32 | for mod in module.modules():
33 | if isinstance(mod, FreezableWeight):
34 | mod.unfreeze()
35 |
36 | return module # return self
37 |
38 |
39 | class PenalizedWeight(Module):
40 | def penalty(self):
41 | raise NotImplementedError()
42 |
43 |
44 | def penalties(module):
45 | for mod in module.modules():
46 | if isinstance(mod, PenalizedWeight):
47 | yield mod.penalty()
48 |
--------------------------------------------------------------------------------
/bayesian_deep_learning/mlss2019bdl/bdl/bernoulli.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | from torch.nn import Linear, Conv2d
5 |
6 | from .base import FreezableWeight, PenalizedWeight
7 |
8 |
9 | class DropoutLinear(Linear, FreezableWeight):
10 | """Linear layer with dropout on inputs."""
11 | def __init__(self, in_features, out_features, bias=True, p=0.5):
12 | super().__init__(in_features, out_features, bias=bias)
13 |
14 | self.p = p
15 |
16 | def forward(self, input):
17 | if self.is_frozen():
18 | return F.linear(input, self.frozen_weight, self.bias)
19 |
20 | return super().forward(F.dropout(input, self.p, True))
21 |
22 | def freeze(self):
23 | # let's draw the new weight
24 | with torch.no_grad():
25 | prob = torch.full_like(self.weight[:1, :], 1 - self.p)
26 | feature_mask = torch.bernoulli(prob) / prob
27 |
28 | frozen_weight = self.weight * feature_mask
29 |
30 | # and store it
31 | self.register_buffer("frozen_weight", frozen_weight)
32 |
33 |
34 | class DropoutConv2d(Conv2d, FreezableWeight):
35 | """2d Convolutional layer with dropout on input features."""
36 | def __init__(self, in_channels, out_channels, kernel_size, stride=1,
37 | padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros',
38 | p=0.5):
39 |
40 | super().__init__(in_channels, out_channels, kernel_size, stride=stride,
41 | padding=padding, dilation=dilation, groups=groups,
42 | bias=bias, padding_mode=padding_mode)
43 |
44 | self.p = p
45 |
46 | def forward(self, input):
47 | """Apply feature dropout and then forward pass through the convolution."""
48 | if self.is_frozen():
49 | return F.conv2d(input, self.frozen_weight, self.bias, self.stride,
50 | self.padding, self.dilation, self.groups)
51 |
52 | return super().forward(F.dropout2d(input, self.p, True))
53 |
54 | def freeze(self):
55 | """Sample the weight from the parameter distribution and freeze it."""
56 | prob = torch.full_like(self.weight[:1, :, :1, :1], 1 - self.p)
57 | feature_mask = torch.bernoulli(prob) / prob
58 |
59 | with torch.no_grad():
60 | frozen_weight = self.weight * feature_mask
61 |
62 | self.register_buffer("frozen_weight", frozen_weight)
63 |
64 |
--------------------------------------------------------------------------------
/bayesian_deep_learning/mlss2019bdl/bdl/gaussian.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | from torch.nn import Linear, Conv2d
5 |
6 | from .base import FreezableWeight, PenalizedWeight
7 |
8 |
9 | class BaseGaussianLinear(Linear, FreezableWeight, PenalizedWeight):
10 | """Linear layer with Gaussian Mean Field weight distribution."""
11 | def __init__(self, in_features, out_features, bias=True):
12 | super().__init__(in_features, out_features, bias=bias)
13 |
14 | self.log_sigma2 = torch.nn.Parameter(
15 | torch.Tensor(*self.weight.shape))
16 |
17 | self.reset_variational_parameters()
18 |
19 | def reset_variational_parameters(self):
20 | self.log_sigma2.data.normal_(-5, 0.1) # from arxiv:1811.00596
21 |
22 | def forward(self, input):
23 | """Forward pass for the linear layer with the local reparameterization trick."""
24 |
25 | if self.is_frozen():
26 | return F.linear(input, self.frozen_weight, self.bias)
27 |
28 | s2 = F.linear(input * input, torch.exp(self.log_sigma2), None)
29 |
30 | return torch.normal(super().forward(input), torch.sqrt(s2 + 1e-20))
31 |
32 | def freeze(self):
33 |
34 | with torch.no_grad():
35 | stdev = torch.exp(0.5 * self.log_sigma2)
36 | weight = torch.normal(self.weight, std=stdev)
37 |
38 | self.register_buffer("frozen_weight", weight)
39 |
40 |
41 | class BaseGaussianConv2d(Conv2d, PenalizedWeight, FreezableWeight):
42 | """Convolutional layer with Gaussian Mean Field weight distribution."""
43 |
44 | def __init__(self, in_channels, out_channels, kernel_size, stride=1,
45 | padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'):
46 | super().__init__(in_channels, out_channels, kernel_size, stride=stride,
47 | padding=padding, dilation=dilation, groups=groups,
48 | bias=bias, padding_mode=padding_mode)
49 |
50 | self.log_sigma2 = torch.nn.Parameter(
51 | torch.Tensor(*self.weight.shape))
52 |
53 | self.reset_variational_parameters()
54 |
55 | reset_variational_parameters = BaseGaussianLinear.reset_variational_parameters
56 |
57 | def forward(self, input):
58 | """Forward pass with the local reparameterization trick."""
59 | if self.is_frozen():
60 | return F.conv2d(input, self.frozen_weight, self.bias, self.stride,
61 | self.padding, self.dilation, self.groups)
62 |
63 | s2 = F.conv2d(input * input, torch.exp(self.log_sigma2), None,
64 | self.stride, self.padding, self.dilation, self.groups)
65 |
66 | return torch.normal(super().forward(input), torch.sqrt(s2 + 1e-20))
67 |
68 | freeze = BaseGaussianLinear.freeze
69 |
70 |
71 | class GaussianLinearARD(BaseGaussianLinear):
72 | def penalty(self):
73 |         # compute \tfrac12 \log (1 + \tfrac{\mu_{ji}^2}{\sigma_{ji}^2})
74 | log_weight2 = 2 * torch.log(torch.abs(self.weight) + 1e-20)
75 |
76 | # `softplus` is $x \mapsto \log(1 + e^x)$
77 | return 0.5 * torch.sum(F.softplus(log_weight2 - self.log_sigma2))
78 |
79 |
80 | class GaussianConv2dARD(BaseGaussianConv2d):
81 | penalty = GaussianLinearARD.penalty
82 |
--------------------------------------------------------------------------------
/bayesian_deep_learning/mlss2019bdl/dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch
4 | from torch.utils.data import TensorDataset
5 | from torchvision import datasets
6 |
7 | from sklearn.utils import check_random_state
8 | from sklearn.model_selection import train_test_split
9 |
10 |
11 | def get_data(name, path="./data", train=True):
12 | if name == "MNIST":
13 | dataset = datasets.MNIST(path, train=train, download=True)
14 | elif name == "KMNIST":
15 | dataset = datasets.KMNIST(path, train=train, download=True)
16 |
17 | images = dataset.data.float().unsqueeze(1)
18 | return TensorDataset(images / 255., dataset.targets)
19 |
20 |
21 | def get_dataset(n_train=20, n_valid=5000, n_pool=5000,
22 | name="MNIST", path="./data", random_state=None):
23 | random_state = check_random_state(random_state)
24 |
25 | dataset = get_data(name, path, train=True)
26 | S_test = get_data(name, path, train=False)
27 |
28 |     # create an imbalanced class label distribution for the training set
29 | targets = dataset.tensors[-1].cpu().numpy()
30 |
31 |     # split the dataset into validation and train
32 | ix_all = np.r_[:len(targets)]
33 | ix_train, ix_valid = train_test_split(
34 | ix_all, stratify=targets, shuffle=True,
35 | train_size=max(n_train, 1), test_size=max(n_valid, 1),
36 | random_state=random_state)
37 |
38 | # prepare the datasets: pool, train and validation
39 | if n_train < 1:
40 | ix_train = np.r_[:0]
41 | S_train = TensorDataset(*dataset[ix_train])
42 |
43 | if n_valid < 1:
44 | ix_valid = np.r_[:0]
45 | S_valid = TensorDataset(*dataset[ix_valid])
46 |
47 | # prepare the pool
48 | ix_pool = np.delete(ix_all, np.r_[ix_train, ix_valid])
49 |
50 | # we want to have lots of boring/useless examples in the pool
51 | labels, share = (1, 2, 3, 4, 5, 6, 7, 8, 9), 0.95
52 | pool_targets, dropped = targets[ix_pool], []
53 |
54 | # deplete the pool of each class
55 | for label in labels:
56 | ix_cls = np.flatnonzero(pool_targets == label)
57 |         n_dropped = int(share * len(ix_cls))
58 |
59 | # pick examples at random to drop
60 | ix_cls = random_state.permutation(ix_cls)
61 |         dropped.append(ix_cls[:n_dropped])
62 |
63 | ix_pool = np.delete(ix_pool, np.concatenate(dropped))
64 |
65 | # select at most `n_pool` examples
66 | if n_pool > 0:
67 | ix_pool = random_state.permutation(ix_pool)[:n_pool]
68 | S_pool = TensorDataset(*dataset[ix_pool])
69 |
70 | return S_train, S_pool, S_valid, S_test
71 |
72 |
73 | def collect(indices, dataset):
74 | """Collect the specified samples from the dataset and remove."""
75 | assert len(dataset) > 0
76 |
77 | mask = torch.zeros(len(dataset), dtype=torch.uint8)
78 | mask[indices] = True
79 |
80 | collected = TensorDataset(*dataset[mask])
81 |
82 | dataset.tensors = dataset[~mask]
83 |
84 | return collected
85 |
86 |
87 | def merge(*datasets, out=None):
88 | # Classes derived from Dataset support appending via
89 | # `+` (__add__), but this breaks slicing.
90 |
91 | data = [d.tensors for d in datasets if d is not None and d.tensors]
92 | assert all(len(data[0]) == len(d) for d in data)
93 |
94 | tensors = [torch.cat(tup, dim=0) for tup in zip(*data)]
95 |
96 | if isinstance(out, TensorDataset):
97 | out.tensors = tensors
98 | return out
99 |
100 | return TensorDataset(*tensors)
101 |
--------------------------------------------------------------------------------
/bayesian_deep_learning/mlss2019bdl/flex.py:
--------------------------------------------------------------------------------
1 | """Handy plotting procedures for small 2d images."""
2 | import numpy as np
3 |
4 | from torch import Tensor
5 | from math import sqrt
6 |
7 |
8 | def get_dimensions(n_samples, height, width,
9 | n_row=None, n_col=None, aspect=(16, 9)):
10 | """Get the dimensions that aesthetically conform to the aspect ratio."""
11 | if n_row is None and n_col is None:
12 | ratio = (width * aspect[1]) / (height * aspect[0])
13 | n_row = int(sqrt(n_samples * ratio))
14 |
15 | if n_row is None:
16 | n_row = (n_samples + n_col - 1) // n_col
17 |
18 | elif n_col is None:
19 | n_col = (n_samples + n_row - 1) // n_row
20 |
21 | return n_row, n_col
22 |
23 |
24 | def setup_canvas(ax, height, width, n_row, n_col):
25 | """Setup the ticks and labels for the canvas."""
26 | # A pair of index arrays
27 | row_index, col_index = np.r_[:n_row], np.r_[:n_col]
28 |
29 | # Setup major ticks to the seams between images and disable labels
30 | ax.set_yticks((row_index[:-1] + 1) * height - 0.5, minor=False)
31 | ax.set_xticks((col_index[:-1] + 1) * width - 0.5, minor=False)
32 |
33 | ax.set_yticklabels([], minor=False)
34 | ax.set_xticklabels([], minor=False)
35 |
36 | # Set minor ticks so that they are exactly between the major ones
37 | ax.set_yticks((row_index + 0.5) * height, minor=True)
38 | ax.set_xticks((col_index + 0.5) * width, minor=True)
39 |
40 | # ... and make their labels into i-j coordinates
41 | ax.set_yticklabels([f"{i:d}" for i in row_index], minor=True)
42 | ax.set_xticklabels([f"{j:d}" for j in col_index], minor=True)
43 |
44 | # Orient tick marks outward
45 | ax.tick_params(axis="both", which="both", direction="out")
46 | return ax
47 |
48 |
49 | def arrange(n_row, n_col, data, fill_value=0):
50 | """Create a grid and populate it with images."""
51 | n_samples, height, width, *color = data.shape
52 | grid = np.full((n_row * height, n_col * width, *color),
53 | fill_value, dtype=data.dtype)
54 |
55 | for k in range(min(n_samples, n_col * n_row)):
56 | i, j = (k // n_col) * height, (k % n_col) * width
57 | grid[i:i + height, j:j + width] = data[k]
58 |
59 | return grid
60 |
61 |
62 | def to_hwc(images, format):
63 | assert format in ("chw", "hwc"), f"Unrecognized format `{format}`."
64 |
65 | if images.ndim == 3:
66 | return images[..., np.newaxis]
67 |
68 | assert images.ndim == 4, f"Images must be Nx{'x'.join(format.upper())}."
69 |
70 | if format == "chw":
71 | return images.transpose(0, 2, 3, 1)
72 |
73 | elif format == "hwc":
74 | return images
75 |
76 |
77 | def plot(ax, images, *, n_col=None, n_row=None, format="chw", **kwargs):
78 | """Plot images in the numpy array on the specified matplotlib Axis."""
79 | if isinstance(images, Tensor):
80 | images = images.data.cpu().numpy()
81 |
82 | images = to_hwc(images, format)
83 |
84 | n_samples, height, width, *color = images.shape
85 | if n_samples < 1:
86 | return None
87 |
88 | n_row, n_col = get_dimensions(n_samples, height, width, n_row, n_col)
89 | ax = setup_canvas(ax, height, width, n_row, n_col)
90 |
91 | image = arrange(n_row, n_col, images)
92 | return ax.imshow(image.squeeze(), **kwargs, origin="upper")
93 |
--------------------------------------------------------------------------------
/bayesian_deep_learning/mlss2019bdl/plotting.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 |
3 | from torch import Tensor
4 | from numpy import asarray
5 |
6 |
7 | def darker(color, a=0.5):
8 | """Adapted from this stackoverflow question_.
9 |
10 | .. _question: https://stackoverflow.com/questions/37765197/
11 | """
12 | from matplotlib.colors import to_rgb
13 | from colorsys import rgb_to_hls, hls_to_rgb
14 |
15 | h, l, s = rgb_to_hls(*to_rgb(color))
16 | return hls_to_rgb(h, max(0, min(a * l, 1)), s)
17 |
18 |
19 | def canvas1d(*, figsize=(12, 5)):
20 | """Setup canvas for 1d function plot."""
21 | fig, ax = plt.subplots(1, 1, figsize=figsize)
22 |
23 | fig.patch.set_alpha(1.0)
24 | ax.set_xlim(-7, +7)
25 | ax.set_ylim(-7, +9)
26 |
27 | return fig, ax
28 |
29 |
30 | def to_numpy(tensor):
31 | if isinstance(tensor, Tensor):
32 | tensor = tensor.data.cpu().numpy()
33 |
34 | return asarray(tensor).squeeze()
35 |
36 |
37 | def plot1d(X, y, bands, ax=None, **kwargs):
38 | X, y = to_numpy(X), to_numpy(y)
39 | assert y.ndim == 2 and X.ndim == 1
40 |
41 | ax = plt.gca() if ax is None else ax
42 |
43 | # plot the predictive mean with the specified colour
44 | y_mean, y_std = y.mean(axis=-1), y.std(axis=-1)
45 | line, = ax.plot(X, y_mean, **kwargs)
46 |
47 | # plot paths or bands with a lighter color and slightly behind the mean
48 | color, zorder = darker(line.get_color(), 1.25), line.get_zorder()
49 | if bands is None:
50 | ax.plot(X, y, c=color, alpha=0.08, zorder=zorder - 1)
51 |
52 | else:
53 | for band in sorted(bands):
54 | ax.fill_between(X, y_mean + band * y_std, y_mean - band * y_std,
55 | color=color, zorder=zorder-1,
56 | alpha=0.4 / len(bands))
57 |
58 | return line
59 |
60 |
61 | def plot1d_bands(X, y, ax=None, **kwargs):
62 | # return plot1d(X, y, bands=(0.5, 1.0, 1.5, 2.0), ax=ax, **kwargs)
63 | return plot1d(X, y, bands=(1.96,), ax=ax, **kwargs)
64 |
65 |
66 | def plot1d_paths(X, y, ax=None, **kwargs):
67 | return plot1d(X, y, bands=None, ax=ax, **kwargs)
68 |
--------------------------------------------------------------------------------
/bayesian_deep_learning/setup.py:
--------------------------------------------------------------------------------
1 | from distutils.core import setup
2 |
3 | setup(
4 | name="mlss2019bdl",
5 | version="0.2",
6 | description="""Service code for MLSS2019 Tutorial on Bayesian Deep Learning""",
7 | license="MIT License",
8 | author="Ivan Nazarov, Yarin Gal",
9 | author_email="ivan.nazarov@skolkovotech.ru",
10 | packages=[
11 | "mlss2019bdl",
12 | "mlss2019bdl.bdl",
13 | ],
14 | install_requires=[
15 | "numpy",
16 | "tqdm",
17 | "matplotlib",
18 | "torch",
19 | "torchvision",
20 | ]
21 | )
22 |
--------------------------------------------------------------------------------
/causality/exercises-answers.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/causality/exercises-answers.pdf
--------------------------------------------------------------------------------
/causality/exercises-tutorial.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/causality/exercises-tutorial.pdf
--------------------------------------------------------------------------------
/geometric_techniques_in_ML/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include data/*.csv
2 |
--------------------------------------------------------------------------------
/geometric_techniques_in_ML/riemannian_opt_gmm_embeddings.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "kernelspec": {
6 | "display_name": "Python 2",
7 | "language": "python",
8 | "name": "python2"
9 | },
10 | "language_info": {
11 | "codemirror_mode": {
12 | "name": "ipython",
13 | "version": 3
14 | },
15 | "file_extension": ".py",
16 | "mimetype": "text/x-python",
17 | "name": "python",
18 | "nbconvert_exporter": "python",
19 | "pygments_lexer": "ipython3",
20 | "version": "3.6.5"
21 | },
22 | "colab": {
23 | "name": "riemannian_opt_gmm_embeddings.ipynb",
24 | "version": "0.3.2",
25 | "provenance": []
26 | }
27 | },
28 | "cells": [
29 | {
30 | "cell_type": "markdown",
31 | "metadata": {
32 | "id": "FnIQw3vrPEhl",
33 | "colab_type": "text"
34 | },
35 | "source": [
36 | "This is a tutorial notebook on Riemannian optimization for machine learning, prepared for the Machine Learning Summer School 2019 (MLSS-2019, http://mlss2019.skoltech.ru) in Moscow, Russia, Skoltech (http://skoltech.ru).\n",
37 | "\n",
38 | "Copyright 2019 by Alexey Artemov and ADASE 3DDL Team. Special thanks to Alexey Zaytsev for a valuable contribution."
39 | ]
40 | },
41 | {
42 | "cell_type": "markdown",
43 | "metadata": {
44 | "id": "8MD6cSJaPEhn",
45 | "colab_type": "text"
46 | },
47 | "source": [
48 | "## Index"
49 | ]
50 | },
51 | {
52 | "cell_type": "markdown",
53 | "metadata": {
54 | "id": "WNw_x4zrPEhn",
55 | "colab_type": "text"
56 | },
57 | "source": [
58 | "1. [Generate a toy dataset](#Generate-a-toy-dataset).\n",
59 | "2. [Use Riemannian optimization to obtain GMM estimates](#Use-Riemannian-optimization-to-obtain-GMM-estimates).\n",
60 | "3. [GMM with real-world data using Riemannian optimization](#GMM-with-real-world-data-using-Riemannian-optimization)."
61 | ]
62 | },
63 | {
64 | "cell_type": "markdown",
65 | "metadata": {
66 | "id": "h69X9iE7PEho",
67 | "colab_type": "text"
68 | },
69 | "source": [
70 | "## Riemannian Optimisation with `pymanopt` for inference in Gaussian mixture models"
71 | ]
72 | },
73 | {
74 | "cell_type": "markdown",
75 | "metadata": {
76 | "id": "Fvi8wp7kPEhp",
77 | "colab_type": "text"
78 | },
79 | "source": [
80 | "This notebook is the second in the series of two notebooks on Riemannian optimization and is based heavily on the [official mixture of Gaussian notebook](https://github.com/pymanopt/pymanopt/blob/master/examples/MoG.ipynb) from `pymanopt` docs. \n",
81 | "\n",
82 | "For the basic introduction, see the first part `riemannian_opt_for_ml.ipynb`."
83 | ]
84 | },
85 | {
86 | "cell_type": "markdown",
87 | "metadata": {
88 | "id": "kUbxow6wPEhp",
89 | "colab_type": "text"
90 | },
91 | "source": [
92 | "Install the necessary libraries"
93 | ]
94 | },
95 | {
96 | "cell_type": "code",
97 | "metadata": {
98 | "id": "6RqFzcLpPEhq",
99 | "colab_type": "code",
100 | "colab": {}
101 | },
102 | "source": [
103 | "!pip install --upgrade git+https://github.com/mlss-skoltech/tutorials.git#subdirectory=geometric_techniques_in_ML"
104 | ],
105 | "execution_count": 0,
106 | "outputs": []
107 | },
108 | {
109 | "cell_type": "code",
110 | "metadata": {
111 | "id": "sOGvTsCrPEhu",
112 | "colab_type": "code",
113 | "colab": {}
114 | },
115 | "source": [
116 | "!pip install pymanopt autograd\n",
117 | "!pip install scipy==1.2.1 -U"
118 | ],
119 | "execution_count": 0,
120 | "outputs": []
121 | },
122 | {
123 | "cell_type": "code",
124 | "metadata": {
125 | "id": "cZDQc1LfQBRi",
126 | "colab_type": "code",
127 | "colab": {}
128 | },
129 | "source": [
130 | "import pkg_resources\n",
131 | "\n",
132 | "DATA_PATH = pkg_resources.resource_filename('riemannianoptimization', 'data/')"
133 | ],
134 | "execution_count": 0,
135 | "outputs": []
136 | },
137 | {
138 | "cell_type": "markdown",
139 | "metadata": {
140 | "id": "rgkSzxeVPEhw",
141 | "colab_type": "text"
142 | },
143 | "source": [
144 | "### Generate a toy dataset"
145 | ]
146 | },
147 | {
148 | "cell_type": "markdown",
149 | "metadata": {
150 | "id": "Q9auRDUvPEhx",
151 | "colab_type": "text"
152 | },
153 | "source": [
154 | "The Mixture of Gaussians (MoG) model assumes that datapoints $\\mathbf{x}_i\\in\\mathbb{R}^d$ follow a distribution described by the following probability density function:\n",
155 | "$$\n",
156 | "p(\\mathbf{x}) = \\sum_{m=1}^M \\pi_m p_\\mathcal{N}(\\mathbf{x};\\mathbf{\\mu}_m,\\mathbf{\\Sigma}_m)\n",
157 | "$$ \n",
158 | "\n",
159 | "where $\\pi_m$ is the probability that the data point belongs to the $m^\\text{th}$ mixture component and $p_\\mathcal{N}(\\mathbf{x};\\mathbf{\\mu}_m,\\mathbf{\\Sigma}_m)$ is the probability density function of a [multivariate Gaussian distribution](https://en.wikipedia.org/wiki/Multivariate_normal_distribution) with mean $\\mathbf{\\mu}_m \\in \\mathbb{R}^d$ and [positive semi-definite](https://en.wikipedia.org/wiki/Definiteness_of_a_matrix) (PSD) covariance matrix $\\mathbf{\\Sigma}_m \\in \\{\\mathbf{M}\\in\\mathbb{R}^{d\\times d}: \\mathbf{M}\\succeq 0\\}$.\n",
160 | "\n",
161 | "As an example consider the mixture of three Gaussians with means\n",
162 | "$$\n",
163 | "\\mathbf{\\mu}_1 = \\begin{bmatrix} -4 \\\\ 1 \\end{bmatrix},\n",
164 | "\\quad\n",
165 | "\\mathbf{\\mu}_2 = \\begin{bmatrix} 0 \\\\ 0 \\end{bmatrix},\n",
166 | "\\quad\n",
167 | "\\mathbf{\\mu}_3 = \\begin{bmatrix} 2 \\\\ -1 \\end{bmatrix},\n",
168 | "$$\n",
169 | "covariances\n",
170 | "$$\\mathbf{\\Sigma}_1 = \\begin{bmatrix} 3 & 0 \\\\ 0 & 1 \\end{bmatrix},\n",
171 | "\\mathbf{\\Sigma}_2 = \\begin{bmatrix} 1 & 1 \\\\ 1 & 3 \\end{bmatrix},\n",
172 | "\\mathbf{\\Sigma}_3 = \\begin{bmatrix} 0.5 & 0 \\\\ 0 & 0.5 \\end{bmatrix}$$\n",
173 | "and mixture probability vector $\\pi=\\left[0.1, 0.6, 0.3\\right]$.\n",
174 | "Let's generate $N=1000$ samples of that MoG model and scatter plot the samples:"
175 | ]
176 | },
177 | {
178 | "cell_type": "markdown",
179 | "metadata": {
180 | "id": "QuYDX2UKPEhx",
181 | "colab_type": "text"
182 | },
183 | "source": [
184 | "Generate a synthetic dataset of $M=3$ Gaussian distributions, w"
185 | ]
186 | },
187 | {
188 | "cell_type": "code",
189 | "metadata": {
190 | "id": "ambBqXq6PEhy",
191 | "colab_type": "code",
192 | "colab": {}
193 | },
194 | "source": [
195 | "import numpy as np\n",
196 | "np.set_printoptions(precision=2)\n",
197 | "\n",
198 | "toy_n_points = 1000 # Number of data\n",
199 | "toy_dim = 2 # Dimension of data\n",
200 | "toy_components = 3 # Number of clusters \n",
201 | "\n",
202 | "# mixture parameters\n",
203 | "toy_pi = [0.1, 0.6, 0.3]\n",
204 | "toy_mus = [np.array([-4, 1]),\n",
205 | " np.array([0, 0]),\n",
206 | " np.array([2, -1])]\n",
207 | "toy_sigmas = [np.array([[3, 0],[0, 1]]),\n",
208 | " np.array([[1, 1.], [1, 3]]),\n",
209 | " .5 * np.eye(2)]\n",
210 | "\n",
211 | "# select which component work in each case\n",
212 | "components = np.random.choice(toy_components, size=toy_n_points, p=toy_pi)\n",
213 | "\n",
214 | "# prepare data\n",
215 | "samples = np.zeros((toy_n_points, toy_dim))\n",
216 | "\n",
217 | "# for each component, generate all needed samples\n",
218 | "for k in range(toy_components):\n",
219 | " # indices of current component in X\n",
220 | " indices = (k == components)\n",
221 | " # number of those occurrences\n",
222 | " n_k = indices.sum()\n",
223 | " if n_k > 0:\n",
224 | " samples[indices] = np.random.multivariate_normal(toy_mus[k], toy_sigmas[k], n_k)"
225 | ],
226 | "execution_count": 0,
227 | "outputs": []
228 | },
229 | {
230 | "cell_type": "markdown",
231 | "metadata": {
232 | "id": "rrePOBAiPEh1",
233 | "colab_type": "text"
234 | },
235 | "source": [
236 | "The following is a bunch of helper functions for visualizations."
237 | ]
238 | },
239 | {
240 | "cell_type": "code",
241 | "metadata": {
242 | "id": "IEM2wHcMPEh2",
243 | "colab_type": "code",
244 | "colab": {}
245 | },
246 | "source": [
247 | "import numpy as np\n",
248 | "import matplotlib.pyplot as plt\n",
249 | "from matplotlib import cm # Colormaps\n",
250 | "\n",
251 | "\n",
252 | "def multivariate_normal(x, d, mean, covariance):\n",
253 | " \"\"\"pdf of the multivariate normal distribution.\"\"\"\n",
254 | " x_m = x - mean\n",
255 | " pdf = (1. / (np.sqrt((2 * np.pi)**d * np.linalg.det(covariance))) * \n",
256 | " np.exp(-(np.linalg.solve(covariance, x_m).T.dot(x_m)) / 2))\n",
257 | " return pdf\n",
258 | "\n",
259 | "\n",
260 | "# Plot bivariate distribution\n",
261 | "def generate_surface(mean, covariance, d):\n",
262 | " \"\"\"Helper function to generate density surface.\"\"\"\n",
263 | " nb_of_x = 100 # grid size\n",
264 | " # choose limits adaptively\n",
265 | "# mu1, mu2 = mean[:, 0]\n",
266 | "# sigmasq1, sigmasq2 = covariance[0, 0], covariance[1, 1]\n",
267 | "# min_x1 = mu1 - 3. * np.sqrt(sigmasq1)\n",
268 | "# max_x1 = mu1 + 3. * np.sqrt(sigmasq1)\n",
269 | "# min_x2 = mu2 - 3. * np.sqrt(sigmasq2)\n",
270 | "# max_x2 = mu2 + 3. * np.sqrt(sigmasq2)\n",
271 | "# print(min_x1, max_x1)\n",
272 | "# print(min_x2, max_x2)\n",
273 | " min_x1, max_x1 = -4, 4\n",
274 | " min_x2, max_x2 = -4, 4\n",
275 | " x1s = np.linspace(min_x1, max_x1, num=nb_of_x)\n",
276 | " x2s = np.linspace(min_x2, max_x2, num=nb_of_x)\n",
277 | " x1, x2 = np.meshgrid(x1s, x2s) # Generate grid\n",
278 | " pdf = np.zeros((nb_of_x, nb_of_x))\n",
279 | " \n",
280 | " # Fill the cost matrix for each combination of weights\n",
281 | " for i in range(nb_of_x):\n",
282 | " for j in range(nb_of_x):\n",
283 | " pdf[i,j] = multivariate_normal(\n",
284 | " np.matrix([[x1[i,j]], [x2[i,j]]]), \n",
285 | " d, mean, covariance)\n",
286 | " return x1, x2, pdf # x1, x2, pdf(x1,x2)\n",
287 | "\n",
288 | "\n",
289 | "def plot_gaussian(mu, sigma, ax):\n",
290 | " bivariate_mean = np.matrix(mu) # Mean\n",
291 | " bivariate_covariance = np.matrix(sigma) # Covariance\n",
292 | " x1, x2, p = generate_surface(\n",
293 | " bivariate_mean, bivariate_covariance, d=2) \n",
294 | " # Plot bivariate distribution\n",
295 | " con = ax.contour(x1, x2, p, 10, cmap=cm.hot)\n",
296 | " # ax2.axis([-2.5, 2.5, -1.5, 3.5])\n",
297 | " ax.set_aspect('equal')"
298 | ],
299 | "execution_count": 0,
300 | "outputs": []
301 | },
302 | {
303 | "cell_type": "code",
304 | "metadata": {
305 | "id": "kIPDqDjLPEh4",
306 | "colab_type": "code",
307 | "colab": {}
308 | },
309 | "source": [
310 | "fig = plt.figure(figsize=(8,8))\n",
311 | "ax = fig.gca()\n",
312 | "\n",
313 | "for mu, sigma in zip(toy_mus, toy_sigmas):\n",
314 | " mu = np.matrix(mu).T # plot_gaussian requires mu to be a column vector\n",
315 | " plot_gaussian(mu, sigma, ax)\n",
316 | "\n",
317 | "colors = ['r', 'g', 'b', 'c', 'm']\n",
318 | "for i in range(toy_components):\n",
319 | " indices = (i == components)\n",
320 | " ax.scatter(samples[indices, 0], samples[indices, 1], alpha=.4, color=colors[i % toy_components])\n"
321 | ],
322 | "execution_count": 0,
323 | "outputs": []
324 | },
325 | {
326 | "cell_type": "markdown",
327 | "metadata": {
328 | "id": "Zblgje-JPEh9",
329 | "colab_type": "text"
330 | },
331 | "source": [
332 | "### Use Riemannian optimization to obtain GMM estimates"
333 | ]
334 | },
335 | {
336 | "cell_type": "markdown",
337 | "metadata": {
338 | "id": "6l5-6FsOPEh-",
339 | "colab_type": "text"
340 | },
341 | "source": [
342 | "Given a data sample the de facto standard method to infer the parameters is the [expectation maximisation](https://en.wikipedia.org/wiki/Expectation-maximization_algorithm) (EM) algorithm that, in alternating so-called E and M steps, maximises the log-likelihood of the data.\n",
343 | "\n",
344 | "In [arXiv:1506.07677](http://arxiv.org/pdf/1506.07677v1.pdf) Hosseini and Sra propose Riemannian optimisation as a powerful counterpart to EM. Importantly, they introduce a reparameterisation that leaves local optima of the log-likelihood unchanged while resulting in a geodesically convex optimisation problem over a product manifold $\\prod_{m=1}^M\\mathcal{PD}^{(d+1)\\times(d+1)}$ of manifolds of $(d+1)\\times(d+1)$ positive definite matrices.\n",
345 | "The proposed method is on par with EM and shows less variability in running times.\n",
346 | "\n",
347 | "The reparameterised optimisation problem for augmented data points $\\mathbf{y}_i=[\\mathbf{x}_i\\ 1]$ can be stated as follows:\n",
348 | "\n",
349 | "$$\\min_{(S_1, ..., S_m, \\nu_1, ..., \\nu_{m-1}) \\in \\prod_{m=1}^M \\mathcal{PD}^{(d+1)\\times(d+1)}\\times\\mathbb{R}^{M-1}}\n",
350 | "-\\sum_{n=1}^N\\log\\left(\n",
351 | "\\sum_{m=1}^M \\frac{\\exp(\\nu_m)}{\\sum_{k=1}^M\\exp(\\nu_k)}\n",
352 | "q_\\mathcal{N}(\\mathbf{y}_n;\\mathbf{S}_m)\n",
353 | "\\right)$$\n",
354 | "\n",
355 | "where\n",
356 | "\n",
357 | "* $\\mathcal{PD}^{(d+1)\\times(d+1)}$ is the manifold of positive definite\n",
358 | "$(d+1)\\times(d+1)$ matrices\n",
359 | "* $\\mathcal{\\nu}_m = \\log\\left(\\frac{\\alpha_m}{\\alpha_M}\\right), \\ m=1, ..., M-1$ and $\\nu_M=0$\n",
360 | "* $q_\\mathcal{N}(\\mathbf{y}_n;\\mathbf{S}_m) =\n",
361 | "2\\pi\\exp\\left(\\frac{1}{2}\\right)\n",
362 | "|\\operatorname{det}(\\mathbf{S}_m)|^{-\\frac{1}{2}}(2\\pi)^{-\\frac{d+1}{2}}\n",
363 | "\\exp\\left(-\\frac{1}{2}\\mathbf{y}_i^\\top\\mathbf{S}_m^{-1}\\mathbf{y}_i\\right)$\n",
364 | "\n",
365 | "**Optimisation problems like this can easily be solved using Pymanopt – even without the need to differentiate the cost function manually!**\n",
366 | "\n",
367 | "So let's infer the parameters of our toy example by Riemannian optimisation using Pymanopt:"
368 | ]
369 | },
370 | {
371 | "cell_type": "code",
372 | "metadata": {
373 | "id": "iHYOYVfMPEh_",
374 | "colab_type": "code",
375 | "colab": {}
376 | },
377 | "source": [
378 | "import pymanopt as opt\n",
379 | "import pymanopt.solvers as solvers\n",
380 | "import pymanopt.manifolds as manifolds"
381 | ],
382 | "execution_count": 0,
383 | "outputs": []
384 | },
385 | {
386 | "cell_type": "code",
387 | "metadata": {
388 | "id": "UYYpl8X4PEiB",
389 | "colab_type": "code",
390 | "colab": {}
391 | },
392 | "source": [
393 | "import autograd.numpy as np # import here to avoid errors\n",
394 | "from autograd.scipy.misc import logsumexp\n",
395 | "\n",
396 | "# (1) Instantiate the manifold\n",
397 | "manifold = manifolds.Product([\n",
398 | " manifolds.PositiveDefinite(toy_dim + 1, k=toy_components), \n",
399 | " manifolds.Euclidean(toy_components - 1)\n",
400 | "])\n",
401 | "\n",
402 | "# (2) Define cost function\n",
403 | "# The parameters must be contained in a list theta.\n",
404 | "def cost(theta):\n",
405 | " # Unpack parameters\n",
406 | " nu = np.concatenate([theta[1], [0]], axis=0)\n",
407 | " \n",
408 | " S = theta[0]\n",
409 | " logdetS = np.expand_dims(np.linalg.slogdet(S)[1], 1)\n",
410 | " y = np.concatenate([samples.T, np.ones((1, len(samples)))], axis=0)\n",
411 | "\n",
412 | " # Calculate log_q\n",
413 | " y = np.expand_dims(y, 0)\n",
414 | " \n",
415 | " # 'Probability' of y belonging to each cluster\n",
416 | " log_q = -0.5 * (np.sum(y * np.linalg.solve(S, y), axis=1) + logdetS)\n",
417 | "\n",
418 | " alpha = np.exp(nu)\n",
419 | " alpha = alpha / np.sum(alpha)\n",
420 | " alpha = np.expand_dims(alpha, 1)\n",
421 | " \n",
422 | " loglikvec = logsumexp(np.log(alpha) + log_q, axis=0)\n",
423 | " return -np.sum(loglikvec)\n",
424 | "\n",
425 | "\n",
426 | "problem = opt.Problem(manifold=manifold, cost=cost, verbosity=2)\n",
427 | "\n",
428 | "# (3) Instantiate a Pymanopt solver\n",
429 | "solver = solvers.SteepestDescent()\n",
430 | "\n",
431 | "# let Pymanopt do the rest\n",
432 | "Xopt = solver.solve(problem)"
433 | ],
434 | "execution_count": 0,
435 | "outputs": []
436 | },
437 | {
438 | "cell_type": "markdown",
439 | "metadata": {
440 | "id": "3yvE_pHbPEiE",
441 | "colab_type": "text"
442 | },
443 | "source": [
444 | "Once Pymanopt has finished the optimisation we can obtain the inferred parameters as follows:"
445 | ]
446 | },
447 | {
448 | "cell_type": "code",
449 | "metadata": {
450 | "id": "KIxF4cPzPEiF",
451 | "colab_type": "code",
452 | "colab": {}
453 | },
454 | "source": [
455 | "def extract_gaussian_parameters(Xopt, n=1):\n",
456 | " params, probas = Xopt\n",
457 | " \n",
458 | " mus, sigmas = [], []\n",
459 | " \n",
460 | " for p in params:\n",
461 | " mu = p[0:2,2:3]\n",
462 | " sigma = p[:2, :2] - mu.dot(mu.T)\n",
463 | " mus.append(mu)\n",
464 | " sigmas.append(sigma)\n",
465 | " \n",
466 | " pis = np.exp(np.concatenate([probas, [0]], axis=0))\n",
467 | " pis = pis / np.sum(pis)\n",
468 | " \n",
469 | " return mus, sigmas, pis"
470 | ],
471 | "execution_count": 0,
472 | "outputs": []
473 | },
474 | {
475 | "cell_type": "code",
476 | "metadata": {
477 | "id": "g7H7_-WMPEiH",
478 | "colab_type": "code",
479 | "colab": {}
480 | },
481 | "source": [
482 | "toy_mus_opt, toy_sigmas_opt, toy_pis_opt = extract_gaussian_parameters(Xopt, n=3)\n",
483 | "toy_mus_opt, toy_sigmas_opt, toy_pis_opt"
484 | ],
485 | "execution_count": 0,
486 | "outputs": []
487 | },
488 | {
489 | "cell_type": "code",
490 | "metadata": {
491 | "id": "21rmQP1iPEiJ",
492 | "colab_type": "code",
493 | "colab": {}
494 | },
495 | "source": [
496 | "fig = plt.figure(figsize=(8,8))\n",
497 | "ax = fig.gca()\n",
498 | "\n",
499 | "for mu, sigma in zip(toy_mus_opt, toy_sigmas_opt):\n",
500 | " plot_gaussian(mu, sigma, ax)\n",
501 | "\n",
502 | "colors = ['r', 'g', 'b', 'c', 'm']\n",
503 | "for i in range(toy_components):\n",
504 | " indices = (i == components)\n",
505 | " ax.scatter(samples[indices, 0], samples[indices, 1], alpha=.4, color=colors[i % toy_components])\n"
506 | ],
507 | "execution_count": 0,
508 | "outputs": []
509 | },
510 | {
511 | "cell_type": "markdown",
512 | "metadata": {
513 | "id": "WMPEQeFgPEiL",
514 | "colab_type": "text"
515 | },
516 | "source": [
517 | "And convince ourselves that the inferred parameters are close to the ground truth parameters."
518 | ]
519 | },
520 | {
521 | "cell_type": "markdown",
522 | "metadata": {
523 | "id": "WdO9uKI3PEiM",
524 | "colab_type": "text"
525 | },
526 | "source": [
527 | "### GMM with real-world data using Riemannian optimization"
528 | ]
529 | },
530 | {
531 | "cell_type": "markdown",
532 | "metadata": {
533 | "id": "0o9UHzqXPEiN",
534 | "colab_type": "text"
535 | },
536 | "source": [
537 | "Certain real-world datasets can be sufficiently closely modelled by the GMM. One instance might be low-dimensional word embeddings. An accompanying notebook `riemannian_opt_text_preprocessing.ipynb` details how these data were obtained. "
538 | ]
539 | },
540 | {
541 | "cell_type": "code",
542 | "metadata": {
543 | "id": "u9g3ModaPEiN",
544 | "colab_type": "code",
545 | "colab": {}
546 | },
547 | "source": [
548 | "import pandas as pd\n",
549 | "df = pd.read_csv(DATA_PATH + 'tsne_result_training_part.csv', index_col=0)\n",
550 | "df"
551 | ],
552 | "execution_count": 0,
553 | "outputs": []
554 | },
555 | {
556 | "cell_type": "code",
557 | "metadata": {
558 | "id": "qD4oYVAkPEiP",
559 | "colab_type": "code",
560 | "colab": {}
561 | },
562 | "source": [
563 | "samples = df[['x', 'y']].values"
564 | ],
565 | "execution_count": 0,
566 | "outputs": []
567 | },
568 | {
569 | "cell_type": "code",
570 | "metadata": {
571 | "id": "amEL6JivPEiR",
572 | "colab_type": "code",
573 | "colab": {}
574 | },
575 | "source": [
576 | "plt.scatter(samples[:, 0], samples[:, 1], alpha=0.5)"
577 | ],
578 | "execution_count": 0,
579 | "outputs": []
580 | },
581 | {
582 | "cell_type": "markdown",
583 | "metadata": {
584 | "id": "QIiQiVMBPEiT",
585 | "colab_type": "text"
586 | },
587 | "source": [
588 | "For the optimization to be a little more stable, we standardize the data."
589 | ]
590 | },
591 | {
592 | "cell_type": "code",
593 | "metadata": {
594 | "id": "4EFQFPxXPEiU",
595 | "colab_type": "code",
596 | "colab": {}
597 | },
598 | "source": [
599 | "from sklearn.preprocessing import StandardScaler\n",
600 | "samples = StandardScaler().fit_transform(samples)\n",
601 | "\n",
602 | "plt.scatter(samples[:, 0], samples[:, 1], alpha=0.5)"
603 | ],
604 | "execution_count": 0,
605 | "outputs": []
606 | },
607 | {
608 | "cell_type": "markdown",
609 | "metadata": {
610 | "id": "V-7O_l5mPEiW",
611 | "colab_type": "text"
612 | },
613 | "source": [
614 | "Use pretty much the same codes as above, changing the number of components and sample size accordingly."
615 | ]
616 | },
617 | {
618 | "cell_type": "code",
619 | "metadata": {
620 | "id": "s49QjNaPPEiX",
621 | "colab_type": "code",
622 | "colab": {}
623 | },
624 | "source": [
625 | "real_components = 4\n",
626 | "real_dim = 2\n",
627 | "real_points = len(samples)"
628 | ],
629 | "execution_count": 0,
630 | "outputs": []
631 | },
632 | {
633 | "cell_type": "code",
634 | "metadata": {
635 | "id": "9AItZaTnPEie",
636 | "colab_type": "code",
637 | "colab": {}
638 | },
639 | "source": [
640 | "import autograd.numpy as np # import here to avoid errors\n",
641 | "from autograd.scipy.misc import logsumexp\n",
642 | "\n",
643 | "# (1) Instantiate the manifold\n",
644 | "manifold = manifolds.Product([\n",
645 | " manifolds.PositiveDefinite(real_dim + 1, k=real_components), \n",
646 | " manifolds.Euclidean(real_components - 1)\n",
647 | "])\n",
648 | "\n",
649 | "# (2) Define cost function\n",
650 | "# The parameters must be contained in a list theta.\n",
651 | "def cost(theta):\n",
652 | " # Unpack parameters\n",
653 | " nu = np.concatenate([theta[1], [0]], axis=0)\n",
654 | " \n",
655 | " S = theta[0]\n",
656 | " logdetS = np.expand_dims(np.linalg.slogdet(S)[1], 1)\n",
657 | " y = np.concatenate([samples.T, np.ones((1, real_points))], axis=0)\n",
658 | "\n",
659 | " # Calculate log_q\n",
660 | " y = np.expand_dims(y, 0)\n",
661 | " \n",
662 | " # 'Probability' of y belonging to each cluster\n",
663 | " log_q = -0.5 * (np.sum(y * np.linalg.solve(S, y), axis=1) + logdetS)\n",
664 | "\n",
665 | " alpha = np.exp(nu)\n",
666 | " alpha = alpha / np.sum(alpha)\n",
667 | " alpha = np.expand_dims(alpha, 1)\n",
668 | " \n",
669 | " loglikvec = logsumexp(np.log(alpha) + log_q, axis=0)\n",
670 | " return -np.sum(loglikvec)\n",
671 | "\n",
672 | "\n",
673 | "problem = opt.Problem(manifold=manifold, cost=cost, verbosity=2)\n",
674 | "\n",
675 | "# (3) Instantiate a Pymanopt solver\n",
676 | "solver = solvers.SteepestDescent()\n",
677 | "\n",
678 | "# let Pymanopt do the rest\n",
679 | "Xopt = solver.solve(problem)"
680 | ],
681 | "execution_count": 0,
682 | "outputs": []
683 | },
684 | {
685 | "cell_type": "code",
686 | "metadata": {
687 | "id": "hwiwvxgSPEih",
688 | "colab_type": "code",
689 | "colab": {}
690 | },
691 | "source": [
692 | "real_mus_opt, real_sigmas_opt, real_pis_opt = extract_gaussian_parameters(Xopt, n=3)\n",
693 | "real_mus_opt, real_sigmas_opt, real_pis_opt"
694 | ],
695 | "execution_count": 0,
696 | "outputs": []
697 | },
698 | {
699 | "cell_type": "code",
700 | "metadata": {
701 | "scrolled": false,
702 | "id": "sJjPaye6PEij",
703 | "colab_type": "code",
704 | "colab": {}
705 | },
706 | "source": [
707 | "fig = plt.figure(figsize=(8,8))\n",
708 | "ax = fig.gca()\n",
709 | "\n",
710 | "for mu, sigma in zip(real_mus_opt, real_sigmas_opt):\n",
711 | " plot_gaussian(mu, sigma, ax)\n",
712 | "\n",
713 | "ax.scatter(samples[:, 0], samples[:, 1], alpha=0.5)"
714 | ],
715 | "execution_count": 0,
716 | "outputs": []
717 | },
718 | {
719 | "cell_type": "markdown",
720 | "metadata": {
721 | "id": "RLge4dRLPEil",
722 | "colab_type": "text"
723 | },
724 | "source": [
725 | "Et voilà – this was a brief demonstration of how to do inference for MoG models by performing Manifold optimisation using Pymanopt."
726 | ]
727 | },
728 | {
729 | "cell_type": "markdown",
730 | "metadata": {
731 | "id": "VPVLD88NPEil",
732 | "colab_type": "text"
733 | },
734 | "source": [
735 | "**TODO HOMEWORK** add riemannian optimization in M-step to speed up EM"
736 | ]
737 | },
738 | {
739 | "cell_type": "code",
740 | "metadata": {
741 | "id": "63OLLMvtPEim",
742 | "colab_type": "code",
743 | "colab": {}
744 | },
745 | "source": [
746 | ""
747 | ],
748 | "execution_count": 0,
749 | "outputs": []
750 | }
751 | ]
752 | }
--------------------------------------------------------------------------------
/geometric_techniques_in_ML/riemannianoptimization/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/geometric_techniques_in_ML/riemannianoptimization/__init__.py
--------------------------------------------------------------------------------
/geometric_techniques_in_ML/riemannianoptimization/tutorial_helpers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import matplotlib.pyplot as plt
4 |
5 |
6 |
7 | # hardcode landmark indexes
8 | left_brow = np.array([1, 5, 3, 6, 1]) - 1
9 | right_brow = np.array([4, 7, 2, 8, 4]) - 1
10 |
11 | left_eye = np.array([9, 13, 11, 14, 9]) - 1
12 | right_eye = np.array([12, 15, 10, 16, 12]) - 1
13 |
14 | nosetip = np.array([19, 21, 20, 22, 19]) - 1
15 | mouth_outerlip = np.array([23, 25, 24, 28, 23]) - 1
16 | mouth_innerlip = np.array([23, 26, 24, 27, 23]) - 1
17 |
18 | face_outer = np.array([29, 33, 31, 35, 32, 34, 30]) - 1
19 |
20 | CONTOURS = [left_brow, right_brow, left_eye, right_eye,
21 | nosetip, mouth_outerlip, mouth_innerlip, face_outer]
22 |
23 |
24 | def get_contours(image):
25 | for contour_idx in CONTOURS:
26 | x = image[contour_idx, 0]
27 | y = -image[contour_idx, 1]
28 | yield x, y
29 |
30 |
31 |
32 | def plot_landmarks(landmarks, ax=None, draw_landmark_id=False, draw_contours=True, draw_landmarks=True,
33 | alpha=1, color_landmarks='red', color_contour='orange', get_contour_handles=False):
34 | """Plots landmarks, connecting them appropriately.
35 |
36 | landmarks: ndarray of shape either [35, 2] or [70,]
37 | ax: axis (created if None)
38 | """
39 |     if ax is None:
40 | f = plt.figure(figsize=(8, 8))
41 | ax = f.gca()
42 |
43 | ax.tick_params(
44 | axis='both', # changes apply to both axes
45 | which='both', # both major and minor ticks are affected
46 | bottom=False, # ticks along the bottom edge are off
47 | top=False, # ticks along the top edge are off
48 | left=False,
49 | right=False,
50 | labelbottom=False,
51 | labeltop=False,
52 | labelleft=False,
53 | labelright=False)
54 |
55 | if landmarks.shape == (70,):
56 | landmarks = landmarks.reshape((35, 2))
57 |
58 | contour_handles = []
59 | if draw_contours:
60 |
61 | def _plot_landmark(landmarks, idx, color):
62 | h, = ax.plot(landmarks[idx, 0], -landmarks[idx, 1], color=color)
63 | return h
64 |
65 | contour_handles = [
66 | _plot_landmark(landmarks, idx, color_contour)
67 | for idx in CONTOURS
68 | ]
69 |
70 | if draw_landmarks:
71 | ax.scatter(landmarks[:, 0], -landmarks[:, 1], s=20, color=color_landmarks, alpha=alpha)
72 |
73 | if draw_landmark_id:
74 | for i in range(35):
75 | ax.text(s=str(i + 1), x=landmarks[i, 0], y=-landmarks[i, 1])
76 |
77 | if get_contour_handles:
78 | return contour_handles
79 |
80 |
81 |
82 | def load_data(data_path):
83 | df = pd.read_csv(data_path + 'kbvt_lfpw_v1_train.csv', delimiter='\t')
84 |
85 | # We don't need all of the columns -- only the ones with landmarks
86 | columns_to_include = [col for col in df.columns.tolist()
87 | if col.endswith('_x') or col.endswith('_y')]
88 | print('Selecting the following columns from the dataset: {}'.format('\n'.join(columns_to_include)))
89 |
90 | # select only averaged predictions
91 | data = df[columns_to_include][df['worker'] == 'average']
92 | landmarks = data.values
93 |
94 | print('\n\n The resulting dataset has shape {}'.format(landmarks.shape))
95 |
96 | return landmarks
97 |
98 |
99 | def prepare_html_for_scatter_plot(projected_shapes):
100 | xs = '[' + ','.join(map(str, projected_shapes[:, 0])) + ']'
101 | ys = '[' + ','.join(map(str, projected_shapes[:, 1])) + ']'
102 | return f'x: {xs}, y: {ys}'
103 |
104 |
105 | def prepare_html_for_landmarks(landmarks_for_one_sample):
106 | landmarks = landmarks_for_one_sample.reshape(35, 2)
107 | xs = '[' + ',"nan",'.join(','.join(map(str, landmarks[idx, 0])) for idx in CONTOURS) + ']'
108 | ys = '[' + ',"nan",'.join(','.join(map(str, -landmarks[idx, 1])) for idx in CONTOURS) + ']'
109 | return f'[{{x: {xs}, y: {ys}, type: "scatter", mode: "lines+markers", line: {{width: 1, color: "orange"}}, marker: {{size: 3, color: "red"}} }}]'
110 |
111 |
112 | def prepare_html_for_all_landmarks(landmarks):
113 | return '[' + ','.join(map(prepare_html_for_landmarks, landmarks)) + ']'
114 |
115 |
116 | def prepare_html_for_visualization(projected_shapes, landmarks, scatterplot_size=[700, 700], annotation_size=[100, 100], floating_annotation=True):
117 | scatter_data = prepare_html_for_scatter_plot(projected_shapes)
118 | scatter_width = str(scatterplot_size[0])
119 | scatter_height = str(scatterplot_size[1])
120 |
121 | annotation_data = prepare_html_for_all_landmarks(landmarks)
122 | annotation_width = str(annotation_size[0])
123 | annotation_height = str(annotation_size[1])
124 |
125 | html = '''
126 |
127 |
128 |
129 |
130 |
131 |
132 |
133 | '''
192 | return html
193 |
194 |
195 |
196 |
197 | __all__ = [
198 | "load_data",
199 | "plot_landmarks",
200 | "prepare_html_for_visualization",
201 | ]
--------------------------------------------------------------------------------
/geometric_techniques_in_ML/setup.py:
--------------------------------------------------------------------------------
1 | from distutils.core import setup
2 |
3 |
4 | setup(
5 | name="riemannianoptimization",
6 | version="0.1",
7 | include_package_data=True,
8 | packages=[
9 | "riemannianoptimization",
10 | ]
11 |
12 | )
--------------------------------------------------------------------------------
/img/img0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/img/img0.png
--------------------------------------------------------------------------------
/img/img1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/img/img1.png
--------------------------------------------------------------------------------
/img/img2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/img/img2.png
--------------------------------------------------------------------------------
/img/img3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/img/img3.png
--------------------------------------------------------------------------------
/img/img4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/img/img4.png
--------------------------------------------------------------------------------
/img/img5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/img/img5.png
--------------------------------------------------------------------------------
/img/img6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/img/img6.png
--------------------------------------------------------------------------------
/img/img7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/img/img7.png
--------------------------------------------------------------------------------
/img/img8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/img/img8.png
--------------------------------------------------------------------------------
/kernels/README.md:
--------------------------------------------------------------------------------
1 | This is the practical component of the [Machine Learning Summer School, Moscow 2019](https://mlss2019.skoltech.ru/) session on kernels, focusing on hypothesis testing with kernel statistics.
2 |
3 | The materials here were most recently prepared by
4 | [Dougal Sutherland](http://www.gatsby.ucl.ac.uk/~dougals/)
5 | with consultation from [Arthur Gretton](http://www.gatsby.ucl.ac.uk/~gretton/),
6 | updated from [a previous course](https://github.com/dougalsutherland/ds3-kernels/),
7 | and based in large part on [earlier materials](https://github.com/karlnapf/ds3_kernel_testing)
8 | by [Heiko Strathmann](http://herrstrathmann.de/).
9 |
10 | We'll cover, in varying levels of detail, the following topics:
11 |
12 | - Two-sample testing with the kernel Maximum Mean Discrepancy (MMD).
13 | - Basic concepts of hypothesis testing, including permutation tests.
14 | - Computing kernel values.
15 | - Estimators for the MMD.
16 | - Learning an appropriate kernel function.
17 | - Independence testing with the Hilbert-Schmidt Independence Criterion.
18 |
19 |
20 | ## Dependencies
21 |
22 | ### Colab
23 |
24 | This notebook is [available on Google Colab](https://colab.research.google.com/github/dougalsutherland/mlss-testing/blob/built/testing.ipynb). You don't have to set anything up yourself and it runs on cloud resources, so this is probably the easiest option if you trust that your network connection is going to be reasonably reliable. Make a copy to your own Google Drive to save your progress, and to use a GPU, click Runtime -> Change runtime type -> Hardware accelerator -> GPU. Everything you need is already installed on Colab; use a Python 3 notebook.
25 |
26 | ### Local setup
27 |
28 | Run `check_imports.py` to see if everything you need is installed and downloaded. If that works, you're set; otherwise, read on.
29 |
30 |
31 | #### Files
32 | There are a few Python files and some data files in the repository. By far the easiest thing to do is just put them all in the same directory:
33 |
34 | ```
35 | git clone https://github.com/dougalsutherland/mlss-testing
36 | ```
37 |
38 | #### Python version
39 | This notebook requires Python 3.6+. Python 3.0 was released in 2008, and it's time to stop living in the past; most important Python projects [are dropping support for Python 2 this year](https://python3statement.org/). If you've never used Python 3 before, don't worry! It's almost the same; for the purposes of this notebook, you probably only need to know that you should write `print("hi")` since it's a function call now, and you can write `A @ B` instead of `A.dot(B)`.
40 |
41 | #### Python packages
42 |
43 | The main things we use are PyTorch and Jupyter. If you already have those set up, you should be fine; just additionally make sure you also have (with `conda install` or `pip install`) `seaborn`, `tqdm`, and `scikit-learn`. We import everything right at the start, so if that runs you shouldn't hit any surprises later on.
44 |
45 | If you don't already have a setup you're happy with, we recommend the `conda` package manager - start by installing [miniconda](https://docs.conda.io/en/latest/miniconda.html). Then you can create an environment with everything you need as:
46 |
47 | ```bash
48 | conda create --name mlss-testing --override-channels -c pytorch -c defaults --strict-channel-priority python=3 notebook ipywidgets numpy scipy scikit-learn pytorch=1.1 torchvision matplotlib seaborn tqdm
49 | conda activate mlss-testing
50 |
51 | git clone https://github.com/dougalsutherland/mlss-testing
52 | cd mlss-testing
53 | python check_imports.py
54 | jupyter notebook
55 | ```
56 |
57 | (If you have an old conda setup, you can use `source activate` instead of `conda activate`, but it's better to [switch to the new style of activation](https://conda.io/projects/conda/en/latest/release-notes.html#recommended-change-to-enable-conda-in-your-shell). This won't matter for this tutorial, but it's general good practice.)
58 |
59 | (You can make your life easier when using jupyter notebooks with multiple kernels by installing `nb_conda_kernels`, but as long as you install and run `jupyter` from inside the env it will also be fine.)
60 |
61 |
62 | ## PyTorch
63 |
64 | We're going to use PyTorch in this tutorial, even though we're not doing a ton of "deep learning." (The CPU version will be fine, though a GPU might let you get slightly better performance in some of the "advanced" sections.)
65 |
66 | If you haven't used PyTorch before, don't worry! The API is unfortunately a little different from NumPy (and TensorFlow), but it's pretty easy to get used to; you can refer to [a cheat sheet vs NumPy](https://github.com/wkentaro/pytorch-for-numpy-users/blob/master/README.md) as well as the docs: [tensor methods](https://pytorch.org/docs/stable/tensors.html) and [the `torch` namespace](https://pytorch.org/docs/stable/torch.html#torch.eq). Feel free to ask if you have trouble figuring something out.
67 |
68 | You can convert a `torch.Tensor` to a `numpy.ndarray` with [`t.numpy()`](https://pytorch.org/docs/stable/tensors.html#torch.Tensor.numpy), and vice versa with [`torch.as_tensor()`](https://pytorch.org/docs/stable/torch.html#torch.as_tensor). (These share data when possible.) Doing this breaks PyTorch's ability to track gradients through these objects, but it's okay for things we won't need to take derivatives of. If you have a one-element tensor, you can get a regular Python number out of it with [`t.item()`](https://pytorch.org/docs/stable/tensors.html#torch.Tensor.item).
69 |
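70 | For example, here is a minimal sketch of these conversions (the tensor values are arbitrary and only for illustration):
71 | 
72 | ```python
73 | import torch
74 | 
75 | t = torch.arange(6, dtype=torch.float32).reshape(2, 3)
76 | 
77 | a = t.numpy()            # numpy.ndarray sharing memory with `t`; no gradient tracking
78 | t2 = torch.as_tensor(a)  # back to a torch.Tensor, again sharing memory where possible
79 | 
80 | loss = (t ** 2).mean()   # a one-element tensor
81 | print(loss.item())       # plain Python float, here 55 / 6 ≈ 9.17
82 | ```
83 | 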
--------------------------------------------------------------------------------
/kernels/dril-heuristic.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/kernels/dril-heuristic.png
--------------------------------------------------------------------------------
/kernels/probability_testing/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/kernels/probability_testing/__init__.py
--------------------------------------------------------------------------------
/kernels/probability_testing/data/almost_simple.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/kernels/probability_testing/data/almost_simple.npz
--------------------------------------------------------------------------------
/kernels/probability_testing/data/blobs.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/kernels/probability_testing/data/blobs.npz
--------------------------------------------------------------------------------
/kernels/probability_testing/data/blobs2.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/kernels/probability_testing/data/blobs2.npz
--------------------------------------------------------------------------------
/kernels/probability_testing/data/blobs_single.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/kernels/probability_testing/data/blobs_single.npz
--------------------------------------------------------------------------------
/kernels/probability_testing/data/gan-samples.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/kernels/probability_testing/data/gan-samples.npz
--------------------------------------------------------------------------------
/kernels/probability_testing/data/hsic.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/kernels/probability_testing/data/hsic.npz
--------------------------------------------------------------------------------
/kernels/probability_testing/data/simple.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/kernels/probability_testing/data/simple.npz
--------------------------------------------------------------------------------
/kernels/probability_testing/data/stopwords-english.txt:
--------------------------------------------------------------------------------
1 | i
2 | me
3 | my
4 | myself
5 | we
6 | our
7 | ours
8 | ourselves
9 | you
10 | you're
11 | you've
12 | you'll
13 | you'd
14 | your
15 | yours
16 | yourself
17 | yourselves
18 | he
19 | him
20 | his
21 | himself
22 | she
23 | she's
24 | her
25 | hers
26 | herself
27 | it
28 | it's
29 | its
30 | itself
31 | they
32 | them
33 | their
34 | theirs
35 | themselves
36 | what
37 | which
38 | who
39 | whom
40 | this
41 | that
42 | that'll
43 | these
44 | those
45 | am
46 | is
47 | are
48 | was
49 | were
50 | be
51 | been
52 | being
53 | have
54 | has
55 | had
56 | having
57 | do
58 | does
59 | did
60 | doing
61 | a
62 | an
63 | the
64 | and
65 | but
66 | if
67 | or
68 | because
69 | as
70 | until
71 | while
72 | of
73 | at
74 | by
75 | for
76 | with
77 | about
78 | against
79 | between
80 | into
81 | through
82 | during
83 | before
84 | after
85 | above
86 | below
87 | to
88 | from
89 | up
90 | down
91 | in
92 | out
93 | on
94 | off
95 | over
96 | under
97 | again
98 | further
99 | then
100 | once
101 | here
102 | there
103 | when
104 | where
105 | why
106 | how
107 | all
108 | any
109 | both
110 | each
111 | few
112 | more
113 | most
114 | other
115 | some
116 | such
117 | no
118 | nor
119 | not
120 | only
121 | own
122 | same
123 | so
124 | than
125 | too
126 | very
127 | s
128 | t
129 | can
130 | will
131 | just
132 | don
133 | don't
134 | should
135 | should've
136 | now
137 | d
138 | ll
139 | m
140 | o
141 | re
142 | ve
143 | y
144 | ain
145 | aren
146 | aren't
147 | couldn
148 | couldn't
149 | didn
150 | didn't
151 | doesn
152 | doesn't
153 | hadn
154 | hadn't
155 | hasn
156 | hasn't
157 | haven
158 | haven't
159 | isn
160 | isn't
161 | ma
162 | mightn
163 | mightn't
164 | mustn
165 | mustn't
166 | needn
167 | needn't
168 | shan
169 | shan't
170 | shouldn
171 | shouldn't
172 | wasn
173 | wasn't
174 | weren
175 | weren't
176 | won
177 | won't
178 | wouldn
179 | wouldn't
180 |
--------------------------------------------------------------------------------
/kernels/probability_testing/data/stopwords-french.txt:
--------------------------------------------------------------------------------
1 | au
2 | aux
3 | avec
4 | ce
5 | ces
6 | dans
7 | de
8 | des
9 | du
10 | elle
11 | en
12 | et
13 | eux
14 | il
15 | je
16 | la
17 | le
18 | leur
19 | lui
20 | ma
21 | mais
22 | me
23 | même
24 | mes
25 | moi
26 | mon
27 | ne
28 | nos
29 | notre
30 | nous
31 | on
32 | ou
33 | par
34 | pas
35 | pour
36 | qu
37 | que
38 | qui
39 | sa
40 | se
41 | ses
42 | son
43 | sur
44 | ta
45 | te
46 | tes
47 | toi
48 | ton
49 | tu
50 | un
51 | une
52 | vos
53 | votre
54 | vous
55 | c
56 | d
57 | j
58 | l
59 | à
60 | m
61 | n
62 | s
63 | t
64 | y
65 | été
66 | étée
67 | étées
68 | étés
69 | étant
70 | étante
71 | étants
72 | étantes
73 | suis
74 | es
75 | est
76 | sommes
77 | êtes
78 | sont
79 | serai
80 | seras
81 | sera
82 | serons
83 | serez
84 | seront
85 | serais
86 | serait
87 | serions
88 | seriez
89 | seraient
90 | étais
91 | était
92 | étions
93 | étiez
94 | étaient
95 | fus
96 | fut
97 | fûmes
98 | fûtes
99 | furent
100 | sois
101 | soit
102 | soyons
103 | soyez
104 | soient
105 | fusse
106 | fusses
107 | fût
108 | fussions
109 | fussiez
110 | fussent
111 | ayant
112 | ayante
113 | ayantes
114 | ayants
115 | eu
116 | eue
117 | eues
118 | eus
119 | ai
120 | as
121 | avons
122 | avez
123 | ont
124 | aurai
125 | auras
126 | aura
127 | aurons
128 | aurez
129 | auront
130 | aurais
131 | aurait
132 | aurions
133 | auriez
134 | auraient
135 | avais
136 | avait
137 | avions
138 | aviez
139 | avaient
140 | eut
141 | eûmes
142 | eûtes
143 | eurent
144 | aie
145 | aies
146 | ait
147 | ayons
148 | ayez
149 | aient
150 | eusse
151 | eusses
152 | eût
153 | eussions
154 | eussiez
155 | eussent
156 |
--------------------------------------------------------------------------------
/kernels/probability_testing/data/transcripts.tar.bz2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/kernels/probability_testing/data/transcripts.tar.bz2
--------------------------------------------------------------------------------
/kernels/probability_testing/support/__init__.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa
2 | import sys
3 | assert sys.version_info >= (3, 6)
4 |
5 | from .kernels import LazyKernel
6 | from .mmd import mmd2_u_stat_variance
7 | from .utils import as_tensors, maybe_squeeze, pil_grid
8 |
--------------------------------------------------------------------------------
/kernels/probability_testing/support/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/kernels/probability_testing/support/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/kernels/probability_testing/support/__pycache__/kernels.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/kernels/probability_testing/support/__pycache__/kernels.cpython-36.pyc
--------------------------------------------------------------------------------
/kernels/probability_testing/support/__pycache__/mmd.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/kernels/probability_testing/support/__pycache__/mmd.cpython-36.pyc
--------------------------------------------------------------------------------
/kernels/probability_testing/support/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/kernels/probability_testing/support/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/kernels/probability_testing/support/kernels.py:
--------------------------------------------------------------------------------
1 | """
2 | Some probably over-engineered infrastructure for lazily computing kernel
3 | matrices, allowing for various sums / means / etc used by MMD-related estimators.
4 | """
5 | from copy import copy
6 | from functools import wraps
7 |
8 | import numpy as np
9 | import torch
10 |
11 | from .utils import as_tensors
12 |
13 |
14 | def _cache(f):
15 | # Only works when the function takes no or simple arguments!
16 | @wraps(f)
17 | def wrapper(self, *args):
18 | key = (f.__name__,) + tuple(args)
19 | if key in self._cache:
20 | return self._cache[key]
21 | self._cache[key] = val = f(self, *args)
22 | return val
23 |
24 | return wrapper
25 |
26 |
27 | ################################################################################
28 | # Kernel base class
29 |
30 | _name_map = {"X": 0, "Y": 1, "Z": 2}
31 |
32 |
33 | class LazyKernel(torch.nn.Module):
34 | """
35 | Base class that allows computing kernel matrices among a bunch of datasets,
36 | only computing the matrices when we use them.
37 |
38 | Constructor arguments:
39 | - A bunch of matrices we'll compute the kernel among.
40 | 2d tensors, with second dimension agreeing, or None;
41 | None is a special value meaning to use the first entry X.
42 | (This is more efficient than passing the same tensor again.)
43 |
44 | Access the results with:
45 | - K[0, 1] to get the Tensor between parts 0 and 1.
46 | - K.XX, K.XY, K.ZY, etc: shortcuts, with X=0, Y=1, Z=2.
47 | - K.matrix(0, 1) or K.XY_m: returns a Matrix subclass (see below).
48 | """
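    # Example (an illustrative sketch; ``RBFKernel``, ``X`` and ``Y`` below are
    # hypothetical, assuming 2d tensors and a subclass implementing _compute_one):
    #
    #     class RBFKernel(LazyKernel):
    #         def _compute_one(self, a, b):
    #             return torch.exp(-((a - b) ** 2).sum())
    #
    #     K = RBFKernel(X, Y)        # X: [n, d] tensor, Y: [m, d] tensor
    #     K.XY                       # kernel matrix between parts 0 (X) and 1 (Y)
    #     K.XX_m.offdiag_mean()      # cached off-diagonal mean of the X-X matrix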
49 |
50 | def __init__(self, X, *rest):
51 | super().__init__()
52 | self._cache = {}
53 | if not hasattr(self, "const_diagonal"):
54 | self.const_diagonal = False
55 |
56 | # want to use pytorch buffer for parts
57 | # but can't assign a list to those, so munge some names
58 | X, *rest = as_tensors(X, *rest)
59 | if len(X.shape) < 2:
60 | raise ValueError(
61 | "LazyKernel expects inputs to be at least 2d. "
62 | "If your data is 1d, make it [n, 1] with X[:, np.newaxis]."
63 | )
64 |
65 | self.register_buffer("_part_0", X)
66 | self.n_parts = 1
67 | for p in rest:
68 | self.append_part(p)
69 |
70 | @property
71 | def X(self):
72 | return self._part_0
73 |
74 | def _part(self, i):
75 | return self._buffers[f"_part_{i}"]
76 |
77 | def part(self, i):
78 | p = self._part(i)
79 | return self.X if p is None else p
80 |
81 | def n(self, i):
82 | return self.part(i).shape[0]
83 |
84 | @property
85 | def ns(self):
86 | return [self.n(i) for i in range(self.n_parts)]
87 |
88 | @property
89 | def parts(self):
90 | return [self.part(i) for i in range(self.n_parts)]
91 |
92 | @property
93 | def dtype(self):
94 | return self.X.dtype
95 |
96 | @property
97 | def device(self):
98 | return self.X.device
99 |
100 | def __repr__(self):
101 | return f"<{type(self).__name__}({', '.join(str(n) for n in self.ns)})>"
102 |
103 | def _compute(self, A, B):
104 | """
105 | Compute the kernel matrix between A and B.
106 |
107 | Might get called with A = X, B = X, or A = X, B = Y, etc.
108 |
109 | Should return a tensor of shape [A.shape[0], B.shape[0]].
110 |
111 | This default, slow, version calls self._compute_one(a, b) in a loop.
112 | If you override this, you don't need to implement _compute_one at all.
113 |
114 | If you implement _precompute, this gets added to the signature here:
115 | self._compute(A, *self._precompute(A), B, *self._precompute(B)).
116 | The default _precompute returns an empty tuple, so it's _compute(A, B),
117 | but if you make a _precompute that returns [A_squared, A_cubed] then it's
118 | self._compute(A, A_squared, A_cubed, B, B_squared, B_cubed).
119 | """
120 | return torch.stack(
121 | [
122 | torch.stack([torch.as_tensor(self._compute_one(a, b)) for b in B])
123 | for a in A
124 | ]
125 | )
126 |
127 | def _compute_one(self, a, b):
128 | raise NotImplementedError(
129 | f"{type(self).__name__}: need to implement _compute or _compute_one"
130 | )
131 |
132 | def _precompute(self, A):
133 | """
134 | Compute something extra for each part A.
135 |
136 | Can be used to share computation between kernel(X, X) and kernel(X, Y).
137 |
138 | We end up calling basically (but with caching)
139 | self._compute(A, *self._precompute(A), B, *self._precompute(B))
140 | This default _precompute returns an empty tuple, so it's
141 | self._compute(A, B)
142 | But if you return [A_squared], it'd be
143 | self._compute(A, A_squared, B, B_squared)
144 | and so on.
145 | """
146 | return ()
147 |
148 | @_cache
149 | def _precompute_i(self, i):
150 | p = self._part(i)
151 | if p is None:
152 | return self._precompute_i(0)
153 | return self._precompute(p)
154 |
155 | @_cache
156 | def __getitem__(self, k):
157 | try:
158 | i, j = k
159 | except ValueError:
160 | raise KeyError("You should index kernels with pairs")
161 |
162 | A = self._part(i)
163 | if A is None:
164 | return self[0, j]
165 |
166 | B = self._part(j)
167 | if B is None:
168 | return self[i, 0]
169 |
170 | if i > j:
171 | return self[j, i].t()
172 |
173 | A_info = self._precompute_i(i)
174 | B_info = self._precompute_i(j)
175 | return self._compute(A, *A_info, B, *B_info)
176 |
177 | @_cache
178 | def matrix(self, i, j):
179 | if self._part(i) is None:
180 | return self.matrix(0, j)
181 |
182 | if self._part(j) is None:
183 | return self.matrix(i, 0)
184 |
185 | k = self[i, j]
186 | if i == j:
187 | return as_matrix(k, const_diagonal=self.const_diagonal, symmetric=True)
188 | else:
189 | return as_matrix(k)
190 |
191 | @_cache
192 | def joint(self, *inds):
193 | if not inds:
194 | return self.joint(*range(self.n_parts))
195 | return torch.cat([torch.cat([self[i, j] for j in inds], 1) for i in inds], 0)
196 |
197 | @_cache
198 | def joint_m(self, *inds):
199 | if not inds:
200 | return self.joint_m(*range(self.n_parts))
201 | return as_matrix(
202 | self.joint(*inds), const_diagonal=self.const_diagonal, symmetric=True
203 | )
204 |
205 | def __getattr__(self, name):
206 | # self.X, self.Y, self.Z
207 | if name in _name_map:
208 | i = _name_map[name]
209 | if i < self.n_parts:
210 | return self.part(i)
211 | else:
212 | raise AttributeError(f"have {self.n_parts} parts, asked for {i}")
213 |
214 | # self.XX, self.XY, self.YZ, etc; also self.XX_m
215 | ret_matrix = False
216 | if len(name) == 4 and name.endswith("_m"):
217 | ret_matrix = True
218 | name = name[:2]
219 |
220 | if len(name) == 2:
221 | i = _name_map.get(name[0], np.inf)
222 | j = _name_map.get(name[1], np.inf)
223 | if i < self.n_parts and j < self.n_parts:
224 | return self.matrix(i, j) if ret_matrix else self[i, j]
225 | else:
226 | raise AttributeError(f"have {self.n_parts} parts, asked for {i}, {j}")
227 |
228 | return super().__getattr__(name)
229 |
230 | def _invalidate_cache(self, i):
231 | for k in list(self._cache.keys()):
232 | if (
233 | i in k[1:]
234 | or any(isinstance(arg, tuple) and i in arg for arg in k[1:])
235 | or k in [("joint",), ("joint_m",)]
236 | ):
237 | del self._cache[k]
238 |
239 | def drop_last_part(self):
240 | assert self.n_parts >= 2
241 | i = self.n_parts - 1
242 | self._invalidate_cache(i)
243 | del self._buffers[f"_part_{i}"]
244 | self.n_parts -= 1
245 |
246 | def change_part(self, i, new):
247 | assert i < self.n_parts
248 | if new is not None and new.shape[1:] != self.X.shape[1:]:
249 | raise ValueError(f"X has shape {self.X.shape}, new entry has {new.shape}")
250 | self._invalidate_cache(i)
251 | self._buffers[f"_part_{i}"] = new
252 |
253 | def append_part(self, new):
254 | if new is not None and new.shape[1:] != self.X.shape[1:]:
255 | raise ValueError(f"X has shape {self.X.shape}, new entry has {new.shape}")
256 | self._buffers[f"_part_{self.n_parts}"] = new
257 | self.n_parts += 1
258 |
259 | def __copy__(self):
260 | """
261 | Doesn't deep-copy the data tensors, but copies dictionaries so that
262 | change_part/etc don't affect the original.
263 | """
264 | cls = self.__class__
265 | result = cls.__new__(cls)
266 | to_copy = {"_cache", "_buffers", "_parameters", "_modules"}
267 | result.__dict__.update(
268 | {k: v.copy() if k in to_copy else v for k, v in self.__dict__.items()}
269 | )
270 | return result
271 |
272 | def _apply(self, fn): # used in to(), cuda(), etc
273 | super()._apply(fn)
274 | for key, val in self._cache.items():
275 | if val is not None:
276 | self._cache[key] = fn(val)
277 | return self
278 |
279 | def as_tensors(self, *args, **kwargs):
280 | "Helper that makes everything a tensor with self.X's type."
281 | kwargs.setdefault("device", self.X.device)
282 | kwargs.setdefault("dtype", self.X.dtype)
283 | return tuple(None if r is None else torch.as_tensor(r, **kwargs) for r in args)
284 |
285 |
286 | ################################################################################
287 | # Matrix wrappers that cache sums / etc. Including various subclasses; see
288 | # as_matrix() to pick between them appropriately.
289 |
290 | # TODO: could support a matrix transpose that shares the cache appropriately
291 |
292 |
293 | class Matrix:
294 | def __init__(self, M, const_diagonal=False):
295 | self.mat = M = torch.as_tensor(M)
296 | self.m, self.n = self.shape = M.shape
297 | self._cache = {}
298 |
299 | @_cache
300 | def row_sums(self):
301 | return self.mat.sum(0)
302 |
303 | @_cache
304 | def col_sums(self):
305 | return self.mat.sum(1)
306 |
307 | @_cache
308 | def row_sums_sq_sum(self):
309 | sums = self.row_sums()
310 | return sums @ sums
311 |
312 | @_cache
313 | def col_sums_sq_sum(self):
314 | sums = self.col_sums()
315 | return sums @ sums
316 |
317 | @_cache
318 | def sum(self):
319 | if "row_sums" in self._cache:
320 | return self.row_sums().sum()
321 | elif "col_sums" in self._cache:
322 | return self.col_sums().sum()
323 | else:
324 | return self.mat.sum()
325 |
326 | def mean(self):
327 | return self.sum() / (self.m * self.n)
328 |
329 | @_cache
330 | def sq_sum(self):
331 | flat = self.mat.view(-1)
332 | return flat @ flat
333 |
334 | def __repr__(self):
335 | return f"<{type(self).__name__}, {self.m} by {self.n}>"
336 |
337 |
338 | class SquareMatrix(Matrix):
339 | def __init__(self, M):
340 | super().__init__(M)
341 | assert self.m == self.n
342 |
343 | @_cache
344 | def diagonal(self):
345 | return self.mat.diagonal()
346 |
347 | @_cache
348 | def trace(self):
349 | return self.mat.trace()
350 |
351 | @_cache
352 | def sq_trace(self):
353 | diag = self.diagonal()
354 | return diag @ diag
355 |
356 | @_cache
357 | def offdiag_row_sums(self):
358 | return self.row_sums() - self.diagonal()
359 |
360 | @_cache
361 | def offdiag_col_sums(self):
362 | return self.col_sums() - self.diagonal()
363 |
364 | @_cache
365 | def offdiag_row_sums_sq_sum(self):
366 | sums = self.offdiag_row_sums()
367 | return sums @ sums
368 |
369 | @_cache
370 | def offdiag_col_sums_sq_sum(self):
371 | sums = self.offdiag_col_sums()
372 | return sums @ sums
373 |
374 | @_cache
375 | def offdiag_sum(self):
376 | return self.offdiag_row_sums().sum()
377 |
378 | def offdiag_mean(self):
379 | return self.offdiag_sum() / (self.n * (self.n - 1))
380 |
381 | @_cache
382 | def offdiag_sq_sum(self):
383 | return self.sq_sum() - self.sq_trace()
384 |
385 |
386 | class SymmetricMatrix(SquareMatrix):
387 | def col_sums(self):
388 | return self.row_sums()
389 |
390 | def sums(self):
391 | return self.row_sums()
392 |
393 | def offdiag_col_sums(self):
394 | return self.offdiag_row_sums()
395 |
396 | def offdiag_sums(self):
397 | return self.offdiag_row_sums()
398 |
399 | def col_sums_sq_sum(self):
400 | return self.row_sums_sq_sum()
401 |
402 | def sums_sq_sum(self):
403 | return self.row_sums_sq_sum()
404 |
405 | def offdiag_col_sums_sq_sum(self):
406 | return self.offdiag_row_sums_sq_sum()
407 |
408 | def offdiag_sums_sq_sum(self):
409 | return self.offdiag_row_sums_sq_sum()
410 |
411 |
412 | class ConstDiagMatrix(SquareMatrix):
413 | def __init__(self, M, diag_value):
414 | super().__init__(M)
415 | self.diag_value = diag_value
416 |
417 | @_cache
418 | def diagonal(self):
419 | return self.mat.new_full((1,), self.diag_value)
420 |
421 | def trace(self):
422 | return self.n * self.diag_value
423 |
424 | def sq_trace(self):
425 | return self.n * (self.diag_value ** 2)
426 |
427 |
428 | class SymmetricConstDiagMatrix(ConstDiagMatrix, SymmetricMatrix):
429 | pass
430 |
431 |
432 | def as_matrix(M, const_diagonal=False, symmetric=False):
433 | if symmetric:
434 | if const_diagonal is not False:
435 | return SymmetricConstDiagMatrix(M, diag_value=const_diagonal)
436 | else:
437 | return SymmetricMatrix(M)
438 | elif const_diagonal is not False:
439 | return ConstDiagMatrix(M, diag_value=const_diagonal)
440 | elif M.shape[0] == M.shape[1]:
441 | return SquareMatrix(M)
442 | else:
443 | return Matrix(M)
444 |
--------------------------------------------------------------------------------
/kernels/probability_testing/support/mmd.py:
--------------------------------------------------------------------------------
1 | def mmd2_u_stat_variance(K, inds=(0, 1)):
2 | """
3 | Estimate MMD variance with estimator from https://arxiv.org/abs/1906.02104.
4 |
5 | K should be a LazyKernel; we'll compare the parts in inds,
6 | default (0, 1) to use K.XX, K.XY, K.YY.
7 | """
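    # Illustrative sketch of a typical call (``SomeKernel``, ``X`` and ``Y`` are
    # hypothetical; the two samples compared must have the same size m):
    #
    #     K = SomeKernel(X, Y)
    #     var = mmd2_u_stat_variance(K)   # variance estimate for the MMD^2 U-statistic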
8 | i, j = inds
9 |
10 | m = K.n(i)
11 | assert K.n(j) == m
12 |
13 | XX = K.matrix(i, i)
14 | XY = K.matrix(i, j)
15 | YY = K.matrix(j, j)
16 |
17 | mm = m * m
18 | mmm = mm * m
19 | m1 = m - 1
20 | m1_m1 = m1 * m1
21 | m1_m1_m1 = m1_m1 * m1
22 | m2 = m - 2
23 | mdown2 = m * m1
24 | mdown3 = mdown2 * m2
25 | mdown4 = mdown3 * (m - 3)
26 | twom3 = 2 * m - 3
27 |
28 | return (
29 | (4 / mdown4) * (XX.offdiag_sums_sq_sum() + YY.offdiag_sums_sq_sum())
30 | + (4 * (mm - m - 1) / (mmm * m1_m1))
31 | * (XY.row_sums_sq_sum() + XY.col_sums_sq_sum())
32 | - (8 / (mm * (mm - 3 * m + 2)))
33 | * (XX.offdiag_sums() @ XY.col_sums() + YY.offdiag_sums() @ XY.row_sums())
34 | + 8 / (mm * mdown3) * ((XX.offdiag_sum() + YY.offdiag_sum()) * XY.sum())
35 | - (2 * twom3 / (mdown2 * mdown4)) * (XX.offdiag_sum() + YY.offdiag_sum())
36 | - (4 * twom3 / (mmm * m1_m1_m1)) * XY.sum() ** 2
37 | - (2 / (m * (mmm - 6 * mm + 11 * m - 6)))
38 | * (XX.offdiag_sq_sum() + YY.offdiag_sq_sum())
39 | + (4 * m2 / (mm * m1_m1_m1)) * XY.sq_sum()
40 | )
41 |
--------------------------------------------------------------------------------
/kernels/probability_testing/support/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 |
4 |
5 | def as_tensors(X, *rest):
6 | "Calls as_tensor on a bunch of args, all of the first's device and dtype."
7 | X = torch.as_tensor(X)
8 | return [X] + [
9 | None if r is None else torch.as_tensor(r, device=X.device, dtype=X.dtype)
10 | for r in rest
11 | ]
12 |
13 |
14 | def pil_grid(X, **kwargs):
15 | return torchvision.transforms.ToPILImage()(torchvision.utils.make_grid(X, **kwargs))
16 |
17 |
18 | def maybe_squeeze(X, dim):
19 | "Like torch.squeeze, but don't crash if dim already doesn't exist."
20 | return torch.squeeze(X, dim) if dim < len(X.shape) else X
21 |
--------------------------------------------------------------------------------
/kernels/setup.py:
--------------------------------------------------------------------------------
1 | from distutils.core import setup
2 |
3 |
4 | setup(
5 | name="probability_testing",
6 | version="0.1",
7 | include_package_data=True,
8 | packages=[
9 | "probability_testing", "probability_testing.support",
10 | ]
11 |
12 | )
13 |
--------------------------------------------------------------------------------
/optimal_transport_tutorial/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include data/*.pickle
2 | include data/*.jpg
--------------------------------------------------------------------------------
/optimal_transport_tutorial/Opt_transport_1_Introduction_to_POT_and_S.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "colab_type": "text",
7 | "id": "KV6ChN5Nt4fj"
8 | },
9 | "source": [
10 | "# MLSS 2019: Optimal Transport and Wasserstein Distances"
11 | ]
12 | },
13 | {
14 | "cell_type": "markdown",
15 | "metadata": {
16 | "colab_type": "text",
17 | "id": "Zb-hmuept4fl"
18 | },
19 | "source": [
20 | "The goal of this first practical session is to introduce computational optimal transport (OT) in Python. You will familiarize yourself with OT by:\n",
21 | "1. using the Python library POT (Python Optimal Transport),\n",
22 | "2. coding Sinkhorn's algorithm.\n",
23 | "\n",
24 | "In the second practical session, we will use optimal transport as a nice geometrical tool in machine learning."
25 | ]
26 | },
27 | {
28 | "cell_type": "markdown",
29 | "metadata": {
30 | "colab_type": "text",
31 | "id": "sbBHs3xPgCrm"
32 | },
33 | "source": [
34 | "We are going to use Google Collab to run this notebook. In order to install all the necessary files run the following cells:"
35 | ]
36 | },
37 | {
38 | "cell_type": "code",
39 | "execution_count": null,
40 | "metadata": {
41 | "colab": {},
42 | "colab_type": "code",
43 | "id": "zdbamnHGu7Tw"
44 | },
45 | "outputs": [],
46 | "source": [
47 | "import os\n",
48 | "!pip install --upgrade git+https://github.com/mlss-skoltech/tutorials.git#subdirectory=optimal_transport_tutorial"
49 | ]
50 | },
51 | {
52 | "cell_type": "code",
53 | "execution_count": null,
54 | "metadata": {
55 | "colab": {},
56 | "colab_type": "code",
57 | "id": "wGfIhryst4fm"
58 | },
59 | "outputs": [],
60 | "source": [
61 | "# Check your installation by importing POT\n",
62 | "!pip install pot\n",
63 | "import ot"
64 | ]
65 | },
66 | {
67 | "cell_type": "markdown",
68 | "metadata": {
69 | "colab_type": "text",
70 | "id": "Ahr5RgE4gtgS"
71 | },
72 | "source": [
73 | "Declare ```DATA_PATH``` as a path to the data from the tutorial package"
74 | ]
75 | },
76 | {
77 | "cell_type": "code",
78 | "execution_count": null,
79 | "metadata": {
80 | "colab": {},
81 | "colab_type": "code",
82 | "id": "SPIus7WThTpU"
83 | },
84 | "outputs": [],
85 | "source": [
86 | "import pkg_resources\n",
87 | "\n",
88 | "DATA_PATH = pkg_resources.resource_filename('optimaltransport', 'data/')"
89 | ]
90 | },
91 | {
92 | "cell_type": "markdown",
93 | "metadata": {
94 | "colab_type": "text",
95 | "id": "rRk-69jhgCrt"
96 | },
97 | "source": [
98 | "If you are running this notebook locally, make sure to clone the tutorial repository:\n",
99 | "\n",
100 | "```!pip install --upgrade git+https://github.com/mlss-skoltech/tutorials.git#subdirectory=optimal_transport_tutorial```\n",
101 | "\n",
102 | "\n",
103 | "\n",
104 | "And install the following package:\n",
105 | "\n",
106 | "* Install with pip: ```bash pip install pot```\n",
107 | "* Install with conda: ```bash conda install -c conda-forge pot ```"
108 | ]
109 | },
110 | {
111 | "cell_type": "code",
112 | "execution_count": null,
113 | "metadata": {
114 | "colab": {},
115 | "colab_type": "code",
116 | "id": "T1BG78Dtt4fp"
117 | },
118 | "outputs": [],
119 | "source": [
120 | "import numpy as np\n",
121 | "import matplotlib.pyplot as plt"
122 | ]
123 | },
124 | {
125 | "cell_type": "markdown",
126 | "metadata": {
127 | "colab_type": "text",
128 | "id": "_8RMdsSot4fr"
129 | },
130 | "source": [
131 | "## 1. Solving Exact OT: Linear Programming"
132 | ]
133 | },
134 | {
135 | "cell_type": "markdown",
136 | "metadata": {
137 | "colab_type": "text",
138 | "id": "tzPSVXint4fr"
139 | },
140 | "source": [
141 | "### Reminders on Optimal Transport"
142 | ]
143 | },
144 | {
145 | "cell_type": "markdown",
146 | "metadata": {
147 | "colab_type": "text",
148 | "id": "I67_BnZOt4fs"
149 | },
150 | "source": [
151 | "Optimal Transport is a theory that allows us to compare two (weighted) points clouds $(X, a)$ and $(Y, b)$, where $X \\in \\mathbb{R}^{n \\times d}$ and $Y \\in \\mathbb{R}^{m \\times d}$ are the locations of the $n$ (resp. $m$) points in dimension $d$, and $a \\in \\mathbb{R}^n$, $b \\in \\mathbb{R}^m$ are the weights. We ask that the total weights sum to one, i.e. $\\sum_{i=1}^n a_i = \\sum_{j=1}^m b_j = 1$."
152 | ]
153 | },
154 | {
155 | "cell_type": "markdown",
156 | "metadata": {
157 | "colab_type": "text",
158 | "id": "03yipBjGt4fs"
159 | },
160 | "source": [
161 | "The basic idea of Optimal Transport is to \"transport\" the mass located at points $X$ to the mass located at points $Y$.\n",
162 | "\n",
163 | "Let us denote by $\\mathcal{U}(a,b) = \\left\\{ P \\in \\mathbb{R}^{n \\times m} \\,|\\, P \\geq 0, \\sum_{j=1}^m P_{ij} = a_i, \\sum_{i=1}^n P_{ij} = b_j\\right\\}$ the set of admissible transport plans.\n",
164 | "\n",
165 | "If $P \\in \\mathcal{U}(a,b)$, the quantity $P_{ij} \\geq 0$ should be regarded as the mass transported from point $X_i$ to point $Y_j$. For this reason, it is called a *transport plan*.\n",
166 | "\n",
167 | "We will also consider a *cost matrix* $C \\in \\mathbb{R}^{n \\times m}$. The quantity $C_{ij}$ should be regarded as the cost paid for transporting one unit of mass from $X_i$ to $Y_j$. This cost is usually computed using the positions $X_i$ and $Y_j$, for example $C_{ij} = \\|X_i - Y_j\\|$ or $C_{ij} = \\|X_i - Y_j\\|^2$.\n",
168 | "\n",
169 | "Then transporting mass according to $P \\in \\mathcal{U}(a,b)$ has a total cost of $\\sum_{ij} P_{ij} C_{ij}$."
170 | ]
171 | },
172 | {
173 | "cell_type": "markdown",
174 | "metadata": {
175 | "colab_type": "text",
176 | "id": "icSHTV5Ut4ft"
177 | },
178 | "source": [
179 | "In \"Optimal Transport\", there is the word _Optimal_. Indeed, we want to find a transport plan $P \\in \\mathcal{U}(a,b)$ that will minimize its total cost. In other words, we want to solve\n",
180 | "$$\n",
181 | " \\min_{P \\in \\mathcal{U}(a,b)} \\sum_{ij} C_{ij }P_{ij}.\n",
182 | "$$"
183 | ]
184 | },
185 | {
186 | "cell_type": "markdown",
187 | "metadata": {
188 | "colab_type": "text",
189 | "id": "pSsukYhWt4fu"
190 | },
191 | "source": [
192 | "This problem is a Linear Program: the objective function is linear in the variable $P$, and the constraints are linear in $P$. We can thus solve this problem using classical Linear Programming algorithms, such as the simplex algorithm."
193 | ]
194 | },
195 | {
196 | "cell_type": "markdown",
197 | "metadata": {
198 | "colab_type": "text",
199 | "id": "ckmQk8t9t4fv"
200 | },
201 | "source": [
202 | "If $P^*$ is a solution to the Optimal Transport problem, we will say that $P^*$ is an optimal transport plan between $(X, a)$ and $(Y, b)$, and that $\\sum_{ij} P^*_{ij} C_{ij}$ is the optimal transport distance between $(X, a)$ and $(Y, b)$: it is the minimal amount of \"energy\" that is necessary to transport the initial mass located at points $X$ to the target mass located at points $Y$."
203 | ]
204 | },
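  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before turning to the croissant example, here is a quick illustration (a minimal sketch on synthetic data) of solving such a linear program with POT's `ot.emd`:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "import ot\n",
    "\n",
    "a = np.array([0.5, 0.5])            # source weights, summing to one\n",
    "b = np.array([0.25, 0.25, 0.5])     # target weights, summing to one\n",
    "C = np.array([[0., 1., 2.],\n",
    "              [1., 0., 1.]])        # cost matrix, shape (2, 3)\n",
    "P = ot.emd(a, b, C)                 # optimal transport plan, shape (2, 3)\n",
    "cost = (P * C).sum()                # optimal transport cost\n",
    "```"
   ]
  },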
205 | {
206 | "cell_type": "markdown",
207 | "metadata": {
208 | "colab_type": "text",
209 | "id": "fRt9uBnWt4fw"
210 | },
211 | "source": [
212 | "### Computing Optimal \"Croissant\" Transport using POT"
213 | ]
214 | },
215 | {
216 | "cell_type": "markdown",
217 | "metadata": {
218 | "colab_type": "text",
219 | "id": "27EET8bAt4fw"
220 | },
221 | "source": [
222 | "We will solve the Bakery/Cafés problem of transporting croissants from a number of Bakeries to Cafés in Moscow.\n",
223 | "\n",
224 | "We use fictional positions, production and sale numbers (that both sum to the same value).\n",
225 | "\n",
226 | "We have acess to the position of Bakeries $X \\in \\mathbb{R}^{8 \\times 2}$ and their respective production $a \\in \\mathbb{R}^8$ which describe the source point cloud. The Cafés where the croissants are sold are defined by their position $Y \\in \\mathbb{R}^{5 \\times 2}$ and $b \\in \\mathbb{R}^{5}$."
227 | ]
228 | },
229 | {
230 | "cell_type": "code",
231 | "execution_count": null,
232 | "metadata": {
233 | "colab": {},
234 | "colab_type": "code",
235 | "id": "ACOcSFXqt4fx"
236 | },
237 | "outputs": [],
238 | "source": [
239 | "# Load the data\n",
240 | "import pickle\n",
241 | "\n",
242 | "with open(DATA_PATH + 'croissants.pickle', 'rb') as file:\n",
243 | " croissants = pickle.load(file)\n",
244 | "\n",
245 | "X = croissants['bakery_pos']\n",
246 | "a = croissants['bakery_prod']\n",
247 | "Y = croissants['cafe_pos']\n",
248 | "b = croissants['cafe_prod']\n",
249 | "\n",
250 | "print('Bakery productions =', a)\n",
251 | "print('Café sales =', b)\n",
252 | "print('Total number of croissants =', a.sum())"
253 | ]
254 | },
255 | {
256 | "cell_type": "code",
257 | "execution_count": null,
258 | "metadata": {
259 | "colab": {},
260 | "colab_type": "code",
261 | "id": "W2hrrI_Xt4fz"
262 | },
263 | "outputs": [],
264 | "source": [
265 | "plt.figure(figsize=(8,8))\n",
266 | "plt.scatter(X[:,0], X[:,1], s=10*a, c='r', edgecolors='k', label='Bakeries')\n",
267 | "plt.scatter(Y[:,0], Y[:,1], s=10*b, c='b', edgecolors='k', label='Cafés')\n",
268 | "plt.legend(fontsize=20)\n",
269 | "plt.axis('off')\n",
270 | "plt.title('Moscow Bakeries and Cafés', fontsize=25)\n",
271 | "plt.show()"
272 | ]
273 | },
274 | {
275 | "cell_type": "markdown",
276 | "metadata": {
277 | "colab_type": "text",
278 | "id": "yAIYAif9t4f1"
279 | },
280 | "source": [
281 | "Let us now compute the cost matrix $C \\in \\mathbb{R}^{n \\times m}$. Here, we will use two different costs: $\\ell_1$ and $\\ell_2$ costs."
282 | ]
283 | },
284 | {
285 | "cell_type": "code",
286 | "execution_count": null,
287 | "metadata": {
288 | "colab": {},
289 | "colab_type": "code",
290 | "id": "77sLokEtt4f1"
291 | },
292 | "outputs": [],
293 | "source": [
294 | "C_1 = np.zeros((8,5)) # TODO: contains the l1 distances\n",
295 | "C_2 = np.zeros((8,5)) # TODO: contains the l2 distances"
296 | ]
297 | },
298 | {
299 | "cell_type": "markdown",
300 | "metadata": {
301 | "colab_type": "text",
302 | "id": "Cp968Q4ht4f3"
303 | },
304 | "source": [
305 | "We can now compute the Optimal Transport plan to transport the croissants from the bakeries to the cafés, for the two different costs."
306 | ]
307 | },
308 | {
309 | "cell_type": "code",
310 | "execution_count": null,
311 | "metadata": {
312 | "colab": {},
313 | "colab_type": "code",
314 | "id": "maFZKc0-t4f5"
315 | },
316 | "outputs": [],
317 | "source": [
318 | "optimal_plan_1 = ot.emd() # TODO: compute the exact OT plan using function ot.emd\n",
319 | "print(optimal_plan_1)\n",
320 | "optimal_cost_1 = # TODO: compute the OT cost for the l1 ground cost\n",
321 | "print('1-Wasserstein distance =', optimal_cost_1)\n",
322 | "print('')\n",
323 | "\n",
324 | "optimal_plan_2 = ot.emd() # TODO: compute the exact OT plan using function ot.emd\n",
325 | "print(optimal_plan_2)\n",
326 | "optimal_cost_2 = # TODO: compute the OT cost for the l2 ground cost\n",
327 | "print('2-Wasserstein distance =', np.sqrt(optimal_cost_2))"
328 | ]
329 | },
330 | {
331 | "cell_type": "code",
332 | "execution_count": null,
333 | "metadata": {
334 | "colab": {},
335 | "colab_type": "code",
336 | "id": "IMYthtNvt4f6"
337 | },
338 | "outputs": [],
339 | "source": [
340 | "fig = plt.figure(figsize=(17,8))\n",
341 | "\n",
342 | "ax = fig.add_subplot(1, 2, 1)\n",
343 | "ax.scatter(X[:,0], X[:,1], s=10*a, c='r', edgecolors='k', label='Bakeries')\n",
344 | "ax.scatter(Y[:,0], Y[:,1], s=10*b, c='b', edgecolors='k', label='Cafés')\n",
345 | "# TODO: plot a line between Bakery i and Café j whenever some croissants are transported between i and j\n",
346 | "ax.axis('off')\n",
347 | "ax.set_title('$\\ell_1$ cost', fontsize=30)\n",
348 | "\n",
349 | "ax = fig.add_subplot(1, 2, 2)\n",
350 | "ax.scatter(X[:,0], X[:,1], s=10*a, c='r', edgecolors='k', label='Bakeries')\n",
351 | "ax.scatter(Y[:,0], Y[:,1], s=10*b, c='b', edgecolors='k', label='Cafés')\n",
352 | "# TODO: plot a line between Bakery i and Café j whenever some croissants are transported between i and j\n",
353 | "ax.axis('off')\n",
354 | "ax.set_title('$\\ell_2$ cost', fontsize=30)\n",
355 | "\n",
356 | "plt.legend(fontsize=20)\n",
357 | "plt.show()"
358 | ]
359 | },
360 | {
361 | "cell_type": "markdown",
362 | "metadata": {
363 | "colab_type": "text",
364 | "id": "eWnTSM2-t4f9"
365 | },
366 | "source": [
367 | "## 2. Sinkhorn Algorithm for Entropy Regularized Optimal Transport"
368 | ]
369 | },
370 | {
371 | "cell_type": "markdown",
372 | "metadata": {
373 | "colab_type": "text",
374 | "id": "OS6IMu8Bt4f-"
375 | },
376 | "source": [
377 | "### Reminders on Sinkhorn Algorithm"
378 | ]
379 | },
380 | {
381 | "cell_type": "markdown",
382 | "metadata": {
383 | "colab_type": "text",
384 | "id": "eQncY7VGt4f-"
385 | },
386 | "source": [
387 | "In real applications, and especially in Machine Learning, we often have to deal with huge numbers of points. In this case, the linear programming algorithms which have cubic complexity will take too much time to run.\n",
388 | "\n",
389 | "That's why in practise, among other reasons, people minimize another criterion given by\n",
390 | "$$\n",
391 | " \\min_{P \\in \\mathcal{U}(a,b)} \\langle C, P \\rangle + \\epsilon \\sum_{ij} P_{ij} [ \\log(P_{ij}) - 1].\n",
392 | "$$\n",
393 | "When $\\epsilon$ is sufficiently small, we can consider that a solution to the above problem (often refered to as \"Entropy-regularized Optimal Transport\") is a good approximation of a real optimal transport plan."
394 | ]
395 | },
396 | {
397 | "cell_type": "markdown",
398 | "metadata": {
399 | "colab_type": "text",
400 | "id": "7L05M4snt4f_"
401 | },
402 | "source": [
403 | "In order to solve this problem, one can remark that the optimality conditions imply that a solution $P_\\epsilon^*$ necessarily is of the form $P_\\epsilon^* = \\text{diag}(u) \\, K \\, \\text{diag}(v)$, where $K = \\exp(-C/\\epsilon)$ and $u,v$ are two non-negative vectors."
404 | ]
405 | },
406 | {
407 | "cell_type": "markdown",
408 | "metadata": {
409 | "colab_type": "text",
410 | "id": "4NxnPr1-t4gA"
411 | },
412 | "source": [
413 | "$P_\\epsilon^*$ should verify the constraints, i.e. $P_\\epsilon^* \\in \\mathcal{U}(a,b)$, so that\n",
414 | "$$\n",
415 | " P_\\epsilon^* 1_m = a \\text{ and } (P_\\epsilon^*)^T 1_n = b\n",
416 | "$$\n",
417 | "which can be rewritten as\n",
418 | "$$\n",
419 | " u \\odot (Kv) = a \\text{ and } v \\odot (K^T u) = b\n",
420 | "$$\n",
421 | "\n",
422 | "Then Sinkhorn's algorithm alternate between the resolution of these two equations, and reads\n",
423 | "$$\n",
424 | " u \\leftarrow \\frac{a}{Kv} \\text{ and } v \\leftarrow \\frac{b}{K^T u}\n",
425 | "$$"
426 | ]
427 | },
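  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In matrix form, these updates are a direct transcription of the equations above (all divisions are elementwise). Here is a minimal illustrative sketch; the exercise below asks you to turn this idea into the `sinkhorn` function:\n",
    "\n",
    "```python\n",
    "def sinkhorn_sketch(a, b, C, epsilon=0.1, max_iters=100):\n",
    "    K = np.exp(-C / epsilon)            # Gibbs kernel\n",
    "    u = np.ones(len(a))\n",
    "    v = np.ones(len(b))\n",
    "    for _ in range(max_iters):\n",
    "        u = a / (K @ v)                 # enforce the row marginals:    u * (K v) = a\n",
    "        v = b / (K.T @ u)               # enforce the column marginals: v * (K^T u) = b\n",
    "    return u[:, None] * K * v[None, :]  # = diag(u) K diag(v)\n",
    "```"
   ]
  },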
428 | {
429 | "cell_type": "code",
430 | "execution_count": null,
431 | "metadata": {
432 | "colab": {},
433 | "colab_type": "code",
434 | "id": "a4A4wbg3t4gA"
435 | },
436 | "outputs": [],
437 | "source": [
438 | "def sinkhorn(a, b, C, epsilon=0.1, max_iters=100):\n",
439 | " \"\"\"Run Sinnkhorn's algorithm\"\"\"\n",
440 | " \n",
441 | " # TODO: Compute the kernel matrix K\n",
442 | " \n",
443 | " # TODO: Alternate projections\n",
444 | " \n",
445 | " return # TODO"
446 | ]
447 | },
448 | {
449 | "cell_type": "code",
450 | "execution_count": null,
451 | "metadata": {
452 | "colab": {},
453 | "colab_type": "code",
454 | "id": "8M5FuBFnt4gC"
455 | },
456 | "outputs": [],
457 | "source": [
458 | "np.round(sinkhorn(a, b, C_2/C_2.max(), epsilon=0.01), 2)"
459 | ]
460 | },
461 | {
462 | "cell_type": "code",
463 | "execution_count": null,
464 | "metadata": {
465 | "colab": {},
466 | "colab_type": "code",
467 | "id": "inUdN4eIt4gD"
468 | },
469 | "outputs": [],
470 | "source": [
471 | "optimal_plan_2"
472 | ]
473 | },
474 | {
475 | "cell_type": "markdown",
476 | "metadata": {
477 | "colab_type": "text",
478 | "id": "mhQOXGjEt4gF"
479 | },
480 | "source": [
481 | "We first show that this algorithm is consistent with classical optimal transport, using the \"croissant\" transport example."
482 | ]
483 | },
484 | {
485 | "cell_type": "code",
486 | "execution_count": null,
487 | "metadata": {
488 | "colab": {},
489 | "colab_type": "code",
490 | "id": "mWJnrT_Ot4gG"
491 | },
492 | "outputs": [],
493 | "source": [
494 | "plan_diff = []\n",
495 | "distance_diff = []\n",
496 | "for epsilon in np.linspace(0.01, 1, 100):\n",
497 | " optimal_plan_sinkhorn = # TODO: compute OT plan using Sinkhorn, with regularization strength epsilon\n",
498 | " optimal_cost_sinkhorn = # TODO: compute OT distance using Sinkhorn\n",
499 | " plan_diff.append() # TODO: compute the Frobenius distance between the exact OT plan and the Sinkhorn OT plan\n",
500 | " distance_diff.append() # TODO: compute the error between exact OT and Sinkhorn values (in %)"
501 | ]
502 | },
503 | {
504 | "cell_type": "code",
505 | "execution_count": null,
506 | "metadata": {
507 | "colab": {},
508 | "colab_type": "code",
509 | "id": "baCU3-hMt4gH"
510 | },
511 | "outputs": [],
512 | "source": [
513 | "plt.figure(figsize=(16,5))\n",
514 | "plt.loglog(np.linspace(0.01, 1, 100), plan_diff, lw=4)\n",
515 | "plt.xlabel('Regularization Strength $\\epsilon$', fontsize=25)\n",
516 | "plt.ylabel('$||P^* - P_\\epsilon^*||_F$', fontsize=25)\n",
517 | "plt.xticks(fontsize=20)\n",
518 | "plt.yticks(fontsize=20)\n",
519 | "plt.grid(ls='--')\n",
520 | "plt.show()"
521 | ]
522 | },
523 | {
524 | "cell_type": "code",
525 | "execution_count": null,
526 | "metadata": {
527 | "colab": {},
528 | "colab_type": "code",
529 | "id": "H7pedasSt4gJ"
530 | },
531 | "outputs": [],
532 | "source": [
533 | "plt.figure(figsize=(16,5))\n",
534 | "plt.loglog(np.linspace(0.01, 1, 100), distance_diff, lw=4)\n",
535 | "plt.xlabel('Regularization Strength $\\epsilon$', fontsize=25)\n",
536 | "plt.ylabel('Error in %', fontsize=25)\n",
537 | "plt.xticks(fontsize=20)\n",
538 | "plt.yticks(fontsize=20)\n",
539 | "plt.grid(ls='--')\n",
540 | "plt.show()"
541 | ]
542 | },
543 | {
544 | "cell_type": "markdown",
545 | "metadata": {
546 | "colab_type": "text",
547 | "id": "2fwnkgjpt4gL"
548 | },
549 | "source": [
550 | "Let us now compare the running time for sinkhorn and classical optimal transport algorithm on more data."
551 | ]
552 | },
553 | {
554 | "cell_type": "code",
555 | "execution_count": null,
556 | "metadata": {
557 | "colab": {},
558 | "colab_type": "code",
559 | "id": "hOfRY9Wut4gM"
560 | },
561 | "outputs": [],
562 | "source": [
563 | "n = 1000\n",
564 | "m = 1000\n",
565 | "d = 2\n",
566 | "\n",
567 | "X = np.random.randn(n,d)\n",
568 | "Y = np.random.randn(m,d)\n",
569 | "\n",
570 | "a = np.ones(n)\n",
571 | "b = np.ones(m)\n",
572 | "\n",
573 | "C = np.zeros((n,m))\n",
574 | "# TODO: compute the cost matrix (using l2 ground distance)"
575 | ]
576 | },
577 | {
578 | "cell_type": "markdown",
579 | "metadata": {
580 | "colab_type": "text",
581 | "id": "7jgwaUChtPAf"
582 | },
583 | "source": [
584 | "Because of Google Colab set up the time measuring can be unreliable, in order to get more certain results try running the code locally"
585 | ]
586 | },
587 | {
588 | "cell_type": "code",
589 | "execution_count": null,
590 | "metadata": {
591 | "colab": {
592 | "base_uri": "https://localhost:8080/",
593 | "height": 170
594 | },
595 | "colab_type": "code",
596 | "id": "SiRL17XYt4gP",
597 | "outputId": "e16c4fc2-1e3b-4c63-ca57-cfc1afac0613"
598 | },
599 | "outputs": [],
600 | "source": [
601 | "%time ot.emd(a,b,C)"
602 | ]
603 | },
604 | {
605 | "cell_type": "code",
606 | "execution_count": null,
607 | "metadata": {
608 | "colab": {
609 | "base_uri": "https://localhost:8080/",
610 | "height": 272
611 | },
612 | "colab_type": "code",
613 | "id": "uyt7TZBzpEDC",
614 | "outputId": "03a76a69-2668-400a-aab4-741f03ae0a0b"
615 | },
616 | "outputs": [],
617 | "source": [
618 | "%time sinkhorn(a,b,C)"
619 | ]
620 | },
621 | {
622 | "cell_type": "markdown",
623 | "metadata": {
624 | "colab_type": "text",
625 | "id": "OTbGsBvPt4gU"
626 | },
627 | "source": [
628 | "We see that sinkhorn is faster. What is even more interesting is that sinkhorn can be parallelerized on GPUs, giving further acceleration. Of course, Sinkhorn algorithm is not computing the exact optimal transport plan any more."
629 | ]
630 | },
631 | {
632 | "cell_type": "markdown",
633 | "metadata": {},
634 | "source": [
635 | "## 3. Optimal Transport in Dimension 1\n",
636 | "\n",
637 | "In dimension $d=1$, computing OT boils down to sorting the points. You will check this fact, and discuss the influence of the regularization strength $\\epsilon$."
638 | ]
639 | },
640 | {
641 | "cell_type": "code",
642 | "execution_count": null,
643 | "metadata": {},
644 | "outputs": [],
645 | "source": [
646 | "n = 4\n",
647 | "m = 4\n",
648 | "\n",
649 | "X = np.random.uniform(size=n)\n",
650 | "Y = np.random.uniform(size=m)\n",
651 | "\n",
652 | "a = np.ones(n)\n",
653 | "b = np.ones(m)\n",
654 | "\n",
655 | "plt.figure(figsize=(17,4))\n",
656 | "plt.scatter(X, np.zeros(n), s=200*a, c='r')\n",
657 | "plt.scatter(Y, np.zeros(m), s=200*b, c='b')\n",
658 | "for i in range(n):\n",
659 | " plt.gca().annotate(str(i+1), xy=(X[i],0.005), size=30, color='r', ha='center')\n",
660 | "for j in range(m):\n",
661 | " plt.gca().annotate(str(j+1), xy=(Y[j],0.005), size=30, color='b', ha='center')\n",
662 | "plt.axis('off')\n",
663 | "plt.show()"
664 | ]
665 | },
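  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For a cost that is increasing and convex in the distance (such as $|x - y|$ or $|x - y|^2$), with $n = m$ points and uniform weights, one optimal plan simply matches the $i$-th smallest source point to the $i$-th smallest target point. A sketch of this sorted matching, to compare against POT and Sinkhorn in the exercise below:\n",
    "\n",
    "```python\n",
    "P_sorted = np.zeros((n, m))\n",
    "P_sorted[np.argsort(X), np.argsort(Y)] = 1.0   # i-th smallest x -> i-th smallest y\n",
    "```"
   ]
  },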
666 | {
667 | "cell_type": "code",
668 | "execution_count": null,
669 | "metadata": {},
670 | "outputs": [],
671 | "source": [
672 | "# TODO: Compute the OT plan using sorting, POT, and Sinkhorn. Discuss the results and the running times."
673 | ]
674 | }
675 | ],
676 | "metadata": {
677 | "accelerator": "GPU",
678 | "colab": {
679 | "name": "MLSS 1 Introduction to POT and Sinkhorn Algorithm (student version).ipynb",
680 | "provenance": [],
681 | "version": "0.3.2"
682 | },
683 | "kernelspec": {
684 | "display_name": "Python 3",
685 | "language": "python",
686 | "name": "python3"
687 | },
688 | "language_info": {
689 | "codemirror_mode": {
690 | "name": "ipython",
691 | "version": 3
692 | },
693 | "file_extension": ".py",
694 | "mimetype": "text/x-python",
695 | "name": "python",
696 | "nbconvert_exporter": "python",
697 | "pygments_lexer": "ipython3",
698 | "version": "3.7.1"
699 | }
700 | },
701 | "nbformat": 4,
702 | "nbformat_minor": 1
703 | }
704 |
--------------------------------------------------------------------------------
/optimal_transport_tutorial/Opt_transport_2_Optimal_Transport_for_Mac.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "colab_type": "text",
7 | "id": "57KL8eDKwAkr"
8 | },
9 | "source": [
10 | "# MLSS 2019: Optimal Transport for Machine Learning"
11 | ]
12 | },
13 | {
14 | "cell_type": "markdown",
15 | "metadata": {
16 | "colab_type": "text",
17 | "id": "axBioaovwAkv"
18 | },
19 | "source": [
20 | "In this second practical session, we will apply OT in two different Machine Learning applications:\n",
21 | "1. Color Transfer\n",
22 | "2. Document Clustering\n",
23 | "\n",
24 | "In Color Transfer, we will mainly be interested in the optimal transport plan itself, while in Document Clustering, we will be interested in the value of the Optimal Transport / Wasserstein distance."
25 | ]
26 | },
27 | {
28 | "cell_type": "markdown",
29 | "metadata": {
30 | "colab_type": "text",
31 | "id": "oCUSMwp1wMZQ"
32 | },
33 | "source": [
34 | "We are going to use Google Collab to run this notebook. In order to install all the necessary files run the following cells:"
35 | ]
36 | },
37 | {
38 | "cell_type": "code",
39 | "execution_count": null,
40 | "metadata": {
41 | "colab": {},
42 | "colab_type": "code",
43 | "id": "33pD5ciywNmm"
44 | },
45 | "outputs": [],
46 | "source": [
47 | "import os\n",
48 | "!pip install --upgrade git+https://github.com/mlss-skoltech/tutorials.git#subdirectory=optimal_transport_tutorial"
49 | ]
50 | },
51 | {
52 | "cell_type": "code",
53 | "execution_count": null,
54 | "metadata": {
55 | "colab": {},
56 | "colab_type": "code",
57 | "id": "u7bJKDk6wyam"
58 | },
59 | "outputs": [],
60 | "source": [
61 | "# Check your installation by importing POT\n",
62 | "!pip install pot\n",
63 | "import ot"
64 | ]
65 | },
66 | {
67 | "cell_type": "markdown",
68 | "metadata": {
69 | "colab_type": "text",
70 | "id": "jE5dwh5jw4mK"
71 | },
72 | "source": [
73 | "Declare ```DATA_PATH``` as a path to the data from the tutorial package"
74 | ]
75 | },
76 | {
77 | "cell_type": "code",
78 | "execution_count": null,
79 | "metadata": {
80 | "colab": {},
81 | "colab_type": "code",
82 | "id": "h_urSV_Nw7gz"
83 | },
84 | "outputs": [],
85 | "source": [
86 | "import pkg_resources\n",
87 | "\n",
88 | "DATA_PATH = pkg_resources.resource_filename('optimaltransport', 'data/')"
89 | ]
90 | },
91 | {
92 | "cell_type": "markdown",
93 | "metadata": {
94 | "colab_type": "text",
95 | "id": "xdjSuC2JxA0g"
96 | },
97 | "source": [
98 | "If you are running this notebook locally, make sure to clone the tutorial repository:\n",
99 | "\n",
100 | "```!pip install --upgrade git+https://github.com/mlss-skoltech/tutorials.git#subdirectory=optimal_transport_tutorial```\n",
101 | "\n",
102 | "\n",
103 | "\n",
104 | "And install the following package:\n",
105 | "\n",
106 | "* Install with pip: ```bash pip install pot```\n",
107 | "* Install with conda: ```bash conda install -c conda-forge pot ```"
108 | ]
109 | },
110 | {
111 | "cell_type": "code",
112 | "execution_count": null,
113 | "metadata": {
114 | "colab": {},
115 | "colab_type": "code",
116 | "id": "y9eqVdkrwAkw"
117 | },
118 | "outputs": [],
119 | "source": [
120 | "import numpy as np\n",
121 | "import ot"
122 | ]
123 | },
124 | {
125 | "cell_type": "markdown",
126 | "metadata": {
127 | "colab_type": "text",
128 | "id": "d4TDmWaawAk0"
129 | },
130 | "source": [
131 | "## 1. Color Transfer"
132 | ]
133 | },
134 | {
135 | "cell_type": "markdown",
136 | "metadata": {
137 | "colab_type": "text",
138 | "id": "KsTBGYN3wAk0"
139 | },
140 | "source": [
141 | "Given a source and a target image, the goal of color transfer is to transform the colors of the source image so that it looks similar to the target image color palette. In the end, we want to find a \"color mapping\", giving for each color of the source image a new color. This can be done by computing the optimal transport plan between the two images, seen as point clouds in the RGB space."
142 | ]
143 | },
144 | {
145 | "cell_type": "code",
146 | "execution_count": null,
147 | "metadata": {
148 | "colab": {},
149 | "colab_type": "code",
150 | "id": "VzfbrFDFwAk1"
151 | },
152 | "outputs": [],
153 | "source": [
154 | "# For plotting\n",
155 | "import matplotlib.pyplot as plt\n",
156 | "from matplotlib.pyplot import imread\n",
157 | "from mpl_toolkits.mplot3d import Axes3D"
158 | ]
159 | },
160 | {
161 | "cell_type": "code",
162 | "execution_count": null,
163 | "metadata": {
164 | "colab": {},
165 | "colab_type": "code",
166 | "id": "13BAZ2KpwAk8"
167 | },
168 | "outputs": [],
169 | "source": [
170 | "# Load the images\n",
171 | "I1 = imread(DATA_PATH + 'schiele.jpg').astype(np.float64) / 256\n",
172 | "I2 = imread(DATA_PATH + 'schiele2.jpg').astype(np.float64) / 256\n",
173 | "\n",
174 | "fig = plt.figure(figsize=(17, 30))\n",
175 | "\n",
176 | "ax = fig.add_subplot(1, 2, 1)\n",
177 | "ax.imshow(I1)\n",
178 | "ax.set_title('Landscape', fontsize=25)\n",
179 | "ax.axis('off')\n",
180 | "\n",
181 | "ax = fig.add_subplot(1, 2, 2)\n",
182 | "ax.imshow(I2)\n",
183 | "ax.set_title('Portrait', fontsize=25)\n",
184 | "ax.axis('off')\n",
185 | "\n",
186 | "plt.show()"
187 | ]
188 | },
189 | {
190 | "cell_type": "markdown",
191 | "metadata": {
192 | "colab_type": "text",
193 | "id": "pSPpFriLwAlA"
194 | },
195 | "source": [
196 | "We will need to work with \"matrices\" instead of images. Since there are 3 colors, images have shape `(Width, Height, 3)`, and the corresponding matrices will have shape `(Width*Height, 3)`."
197 | ]
198 | },
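  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For example, a minimal sketch of this reshaping with NumPy (the cell below asks you to fill in exactly this kind of operation):\n",
    "\n",
    "```python\n",
    "def im2mat_sketch(I):\n",
    "    return I.reshape(-1, I.shape[-1])   # (Width, Height, 3) -> (Width*Height, 3)\n",
    "\n",
    "def mat2im_sketch(X, shape):\n",
    "    return X.reshape(shape)             # (Width*Height, 3) -> (Width, Height, 3)\n",
    "```"
   ]
  },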
199 | {
200 | "cell_type": "code",
201 | "execution_count": null,
202 | "metadata": {
203 | "colab": {},
204 | "colab_type": "code",
205 | "id": "emal4z2EwAlB"
206 | },
207 | "outputs": [],
208 | "source": [
209 | "def im2mat(I):\n",
210 | " '''Convert image I to matrix.'''\n",
211 | " return # TODO: reshape\n",
212 | "\n",
213 | "def mat2im(X, shape):\n",
214 | " '''Convert matrix X to image with shape 'shape'.'''\n",
215 | " return # TODO: reshape\n",
216 | "\n",
217 | "X1 = im2mat(I1)\n",
218 | "X2 = im2mat(I2)"
219 | ]
220 | },
221 | {
222 | "cell_type": "markdown",
223 | "metadata": {
224 | "colab_type": "text",
225 | "id": "KprLOWg7wAlE"
226 | },
227 | "source": [
228 | "Real images have way too many different colors, so we will need to subsample them. In order to do this, we use K-means over all the colors, and keep only the computed centroids. Note that using Mini Batch K-Means will speed the computations up."
229 | ]
230 | },
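  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sketch of the idea with scikit-learn (assuming the matrix `X1` from above and some number of clusters `nbsamples`, as in the cell below):\n",
    "\n",
    "```python\n",
    "from sklearn.cluster import MiniBatchKMeans\n",
    "\n",
    "kmeans1 = MiniBatchKMeans(n_clusters=nbsamples).fit(X1)\n",
    "X1_sampled = kmeans1.cluster_centers_   # (nbsamples, 3) representative colors\n",
    "```"
   ]
  },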
231 | {
232 | "cell_type": "code",
233 | "execution_count": null,
234 | "metadata": {
235 | "colab": {},
236 | "colab_type": "code",
237 | "id": "ypyN5TMjwAlF"
238 | },
239 | "outputs": [],
240 | "source": [
241 | "from sklearn.cluster import MiniBatchKMeans\n",
242 | "\n",
243 | "# Size of the subsampled point clouds\n",
244 | "nbsamples = 1000\n",
245 | "\n",
246 | "kmeans1 = # TODO: Mini Batch K-Means for X1\n",
247 | "X1_sampled = # TODO: get the centroids\n",
248 | "\n",
249 | "kmeans2 = # TODO: Mini Batch K-Means for X2\n",
250 | "X2_sampled = # TODO: get the centroids"
251 | ]
252 | },
253 | {
254 | "cell_type": "markdown",
255 | "metadata": {
256 | "colab_type": "text",
257 | "id": "xYguNfDhwAlL"
258 | },
259 | "source": [
260 | "Each image is represented by its \"matrix\", i.e. is seen as a point cloud $X \\in \\mathbb{R}^{N\\times3}$ in the RGB color space, identified with $\\mathbb{R}^3$. "
261 | ]
262 | },
263 | {
264 | "cell_type": "code",
265 | "execution_count": null,
266 | "metadata": {
267 | "colab": {},
268 | "colab_type": "code",
269 | "id": "jHszq7CPwAlM"
270 | },
271 | "outputs": [],
272 | "source": [
273 | "def showImageAsPointCloud(X, Y):\n",
274 | " '''Show the color palette associated with images X and Y.'''\n",
275 | " fig = plt.figure(figsize=(17,8))\n",
276 | " ax = fig.add_subplot(121, projection='3d')\n",
277 | " ax.set_xlim(0,1)\n",
278 | " ax.scatter(X[:,0], X[:,1], X[:,2], c=X, s=10, marker='o', alpha=0.6)\n",
279 | " ax.set_xlabel('R',fontsize=22)\n",
280 | " ax.set_xticklabels([])\n",
281 | " ax.set_ylim(0,1)\n",
282 | " ax.set_ylabel('G',fontsize=22)\n",
283 | " ax.set_yticklabels([])\n",
284 | " ax.set_zlim(0,1)\n",
285 | " ax.set_zlabel('B',fontsize=22)\n",
286 | " ax.set_zticklabels([])\n",
287 | " ax.set_title('Landscape Color Palette', fontsize=20)\n",
288 | " ax.grid('off')\n",
289 | " \n",
290 | " ax = fig.add_subplot(122, projection='3d')\n",
291 | " ax.set_xlim(0,1)\n",
292 | " ax.scatter(Y[:,0], Y[:,1], Y[:,2], c=Y, s=10, marker='o', alpha=0.6)\n",
293 | " ax.set_xlabel('R',fontsize=22)\n",
294 | " ax.set_xticklabels([])\n",
295 | " ax.set_ylim(0,1)\n",
296 | " ax.set_ylabel('G',fontsize=22)\n",
297 | " ax.set_yticklabels([])\n",
298 | " ax.set_zlim(0,1)\n",
299 | " ax.set_zlabel('B',fontsize=22)\n",
300 | " ax.set_zticklabels([])\n",
301 | " ax.set_title('Portrait Color Palette', fontsize=20)\n",
302 | " ax.grid('off')\n",
303 | " \n",
304 | " plt.show()"
305 | ]
306 | },
307 | {
308 | "cell_type": "code",
309 | "execution_count": null,
310 | "metadata": {
311 | "colab": {},
312 | "colab_type": "code",
313 | "id": "d4r_n9j5wAlO"
314 | },
315 | "outputs": [],
316 | "source": [
317 | "showImageAsPointCloud(X1_sampled, X2_sampled)"
318 | ]
319 | },
320 | {
321 | "cell_type": "markdown",
322 | "metadata": {
323 | "colab_type": "text",
324 | "id": "bWtBge5KwAlR"
325 | },
326 | "source": [
327 | "In order to compute the optimal transport plans between the two point clouds, we have to compute the corresponding cost matrix. In the following, we will always consider the squared distance, _i.e._ $C_{ij} = \\|X_i - Y_j\\|^2$."
328 | ]
329 | },
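  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One possible sketch for the cell below, using SciPy's pairwise distances with the squared Euclidean metric:\n",
    "\n",
    "```python\n",
    "from scipy.spatial.distance import cdist\n",
    "\n",
    "C = cdist(X1_sampled, X2_sampled, metric='sqeuclidean')   # shape (nbsamples, nbsamples)\n",
    "```"
   ]
  },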
330 | {
331 | "cell_type": "code",
332 | "execution_count": null,
333 | "metadata": {
334 | "colab": {},
335 | "colab_type": "code",
336 | "id": "AV-S1C6-wAlR"
337 | },
338 | "outputs": [],
339 | "source": [
340 | "C = # TODO: compute the cost matrix using l2 ground distance"
341 | ]
342 | },
343 | {
344 | "cell_type": "markdown",
345 | "metadata": {
346 | "colab_type": "text",
347 | "id": "hg6CSfzBwAlV"
348 | },
349 | "source": [
350 | "### Landscape with Portrait colors"
351 | ]
352 | },
353 | {
354 | "cell_type": "markdown",
355 | "metadata": {
356 | "colab_type": "text",
357 | "id": "wexKcWriwAlV"
358 | },
359 | "source": [
360 | "Here, the goal is to transfer the colors of the portrait to the landscape. We will compute the exact Optimal Transport Plan, as well as the Entropy Regularized Optimal Transport plans."
361 | ]
362 | },
363 | {
364 | "cell_type": "code",
365 | "execution_count": null,
366 | "metadata": {
367 | "colab": {},
368 | "colab_type": "code",
369 | "id": "LrILq63TwAlW"
370 | },
371 | "outputs": [],
372 | "source": [
373 | "regs = [0.01, 0.1, 0.5]\n",
374 | "OT_plans = [] # Contains the OT plans for regularization strengths : 0, 0.01, 0.1, 0.5\n",
375 | "OT_plans.append(ot.emd()) # TODO: OT plan for exact OT\n",
376 | "for reg in regs:\n",
377 | " OT_plans.append() # TODO: OT plan for regularization strength = reg"
378 | ]
379 | },
380 | {
381 | "cell_type": "code",
382 | "execution_count": null,
383 | "metadata": {
384 | "colab": {},
385 | "colab_type": "code",
386 | "id": "VSLLPM4AwAld"
387 | },
388 | "outputs": [],
389 | "source": [
390 | "def colorTransfer(OT_plan, kmeans1, kmeans2, shape):\n",
391 | " '''Return the color-transfered image of shape \"shape\".'''\n",
392 | " return # TODO"
393 | ]
394 | },
395 | {
396 | "cell_type": "code",
397 | "execution_count": null,
398 | "metadata": {
399 | "colab": {},
400 | "colab_type": "code",
401 | "id": "IUJHdGJKwAlh"
402 | },
403 | "outputs": [],
404 | "source": [
405 | "fig = plt.figure(figsize=(17, 20))\n",
406 | "\n",
407 | "ax = fig.add_subplot(1, 2, 1)\n",
408 | "ax.imshow(I1)\n",
409 | "ax.set_title('Source Image', fontsize=20)\n",
410 | "ax.axis('off')\n",
411 | "\n",
412 | "ax = fig.add_subplot(1, 2, 2)\n",
413 | "I = colorTransfer(OT_plans[0], kmeans1, kmeans2, I1.shape)\n",
414 | "ax.imshow(I)\n",
415 | "ax.set_title('Reg = 0', fontsize=20)\n",
416 | "ax.axis('off')\n",
417 | "\n",
418 | "plt.show()\n",
419 | "\n",
420 | "fig = plt.figure(figsize=(17, 20))\n",
421 | "for i in range(3):\n",
422 | " ax = fig.add_subplot(2, 3, i+1)\n",
423 | " I = colorTransfer(OT_plans[i+1], kmeans1, kmeans2, I1.shape)\n",
424 | " ax.imshow(I)\n",
425 | " ax.set_title('Reg = '+str(regs[i]), fontsize=20)\n",
426 | " ax.axis('off')\n",
427 | "\n",
428 | "plt.show()"
429 | ]
430 | },
431 | {
432 | "cell_type": "markdown",
433 | "metadata": {
434 | "colab_type": "text",
435 | "id": "zR5JREbywAlk"
436 | },
437 | "source": [
438 | "### Portait with Landscape colors"
439 | ]
440 | },
441 | {
442 | "cell_type": "markdown",
443 | "metadata": {
444 | "colab_type": "text",
445 | "id": "vEuyDVKswAll"
446 | },
447 | "source": [
448 | "We now transfer the colors of the landscape to the portrait."
449 | ]
450 | },
451 | {
452 | "cell_type": "code",
453 | "execution_count": null,
454 | "metadata": {
455 | "colab": {},
456 | "colab_type": "code",
457 | "id": "4y4w6yoYwAlm"
458 | },
459 | "outputs": [],
460 | "source": [
461 | "C = # TODO"
462 | ]
463 | },
464 | {
465 | "cell_type": "code",
466 | "execution_count": null,
467 | "metadata": {
468 | "colab": {},
469 | "colab_type": "code",
470 | "id": "hkKbTE74wAlr"
471 | },
472 | "outputs": [],
473 | "source": [
474 | "regs = [0.01, 0.03, 0.1]\n",
475 | "OT_plans = []\n",
476 | "OT_plans.append(ot.emd()) # TODO\n",
477 | "for reg in regs:\n",
478 | " OT_plans.append() # TODO"
479 | ]
480 | },
481 | {
482 | "cell_type": "code",
483 | "execution_count": null,
484 | "metadata": {
485 | "colab": {},
486 | "colab_type": "code",
487 | "id": "wrM8E7_4wAlv"
488 | },
489 | "outputs": [],
490 | "source": [
491 | "def colorTransfer(OT_plan, kmeans1, kmeans2, shape):\n",
492 | " return # TODO"
493 | ]
494 | },
495 | {
496 | "cell_type": "code",
497 | "execution_count": null,
498 | "metadata": {
499 | "colab": {},
500 | "colab_type": "code",
501 | "id": "vMIVlsDGwAly"
502 | },
503 | "outputs": [],
504 | "source": [
505 | "fig = plt.figure(figsize=(17, 20))\n",
506 | "\n",
507 | "ax = fig.add_subplot(1, 2, 1)\n",
508 | "ax.imshow(I2)\n",
509 | "ax.set_title('Source Image', fontsize=20)\n",
510 | "ax.axis('off')\n",
511 | "\n",
512 | "ax = fig.add_subplot(1, 2, 2)\n",
513 | "I = colorTransfer(OT_plans[0], kmeans1, kmeans2, I2.shape)\n",
514 | "ax.imshow(I)\n",
515 | "ax.set_title('Reg = 0', fontsize=20)\n",
516 | "ax.axis('off')\n",
517 | "\n",
518 | "plt.show()\n",
519 | "\n",
520 | "fig = plt.figure(figsize=(17, 20))\n",
521 | "for i in range(3):\n",
522 | " ax = fig.add_subplot(2, 3, i+1)\n",
523 | " I = colorTransfer(OT_plans[i+1], kmeans1, kmeans2, I2.shape)\n",
524 | " ax.imshow(I)\n",
525 | " ax.set_title('Reg = '+str(regs[i]), fontsize=20)\n",
526 | " ax.axis('off')\n",
527 | "\n",
528 | "plt.show()"
529 | ]
530 | },
531 | {
532 | "cell_type": "markdown",
533 | "metadata": {
534 | "colab_type": "text",
535 | "id": "eNFlM3RywAl1"
536 | },
537 | "source": [
538 | "## 2. Document Clustering"
539 | ]
540 | },
541 | {
542 | "cell_type": "markdown",
543 | "metadata": {
544 | "colab_type": "text",
545 | "id": "cq9danJpwAl2"
546 | },
547 | "source": [
548 | "We would likle to classify several text documents. In order to do this, we will:\n",
549 | "1. Transform each text into a point cloud\n",
550 | "2. Compute the Optimal Transport distances between each pair of point clouds\n",
551 | "3. Use MDS to plot the different clusters in 2 dimensions"
552 | ]
553 | },
554 | {
555 | "cell_type": "markdown",
556 | "metadata": {
557 | "colab_type": "text",
558 | "id": "iLjCFnUHwAl2"
559 | },
560 | "source": [
561 | "### Load the Data and Preprocessing\n",
562 | "We consider seven movie scenarios. We transformed each of them into a point cloud using the following steps:\n",
563 | "1. Keep only the words among the $2.000 - 20.000$ most common words\n",
564 | "2. Each remaining word is transformed into a $300$-dimensional vector using word2vec\n",
565 | "3. Each word is given a weight proportional to its frequency\n",
566 | "\n",
567 | "The variable `texts` is a list of tuples. Each tuple represents a movie, and contains two parts:\n",
568 | "1. A matrix $X \\in \\mathbb{R}^{n \\times 300}$ where $n$ is the number of different words, containing the position of the points\n",
569 | "2. A vector $a \\in \\mathbb{R}^n$ containing the corresponding weights"
570 | ]
571 | },
572 | {
573 | "cell_type": "code",
574 | "execution_count": null,
575 | "metadata": {
576 | "colab": {},
577 | "colab_type": "code",
578 | "id": "SwMrZHa3wAl5"
579 | },
580 | "outputs": [],
581 | "source": [
582 | "import pickle\n",
583 | "\n",
584 | "with open(DATA_PATH + 'texts.pickle', 'rb') as file:\n",
585 | " texts = pickle.load(file)\n",
586 | "\n",
587 | "movies = ['DUNKIRK', 'GRAVITY', 'INTERSTELLAR', 'KILL BILL VOL.1', 'KILL BILL VOL.2', 'THE MARTIAN', 'TITANIC']"
588 | ]
589 | },
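A quick sanity check of the structure described above (a sketch; it assumes the tuples in `texts` come in the same order as the `movies` list):

```python
for name, (X, a) in zip(movies, texts):
    print(f'{name}: {X.shape[0]} words, {X.shape[1]}-dim vectors, weights sum to {a.sum():.3f}')
```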
590 | {
591 | "cell_type": "markdown",
592 | "metadata": {
593 | "colab_type": "text",
594 | "id": "Q9kWV4ltwAl8"
595 | },
596 | "source": [
597 | "### Compute the OT distances"
598 | ]
599 | },
600 | {
601 | "cell_type": "code",
602 | "execution_count": null,
603 | "metadata": {
604 | "colab": {},
605 | "colab_type": "code",
606 | "id": "hCLVUY9xwAl9"
607 | },
608 | "outputs": [],
609 | "source": [
610 | "# Set regularization strength\n",
611 | "reg = 0.1"
612 | ]
613 | },
614 | {
615 | "cell_type": "code",
616 | "execution_count": null,
617 | "metadata": {
618 | "colab": {},
619 | "colab_type": "code",
620 | "id": "yv4xFXtHwAmB"
621 | },
622 | "outputs": [],
623 | "source": [
624 | "def costMatrix(i,j):\n",
625 | " '''Return the cost matrix C between movies number i and j.'''\n",
626 | " X = texts[i][0]\n",
627 | " Y = texts[j][0]\n",
628 | " \n",
629 | " return # TODO"
630 | ]
631 | },
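A possible completion of `costMatrix` (a sketch): the ground cost between two movies is the matrix of pairwise distances between their word vectors, again obtained with `ot.dist` (Euclidean here; squared Euclidean would also work).

```python
import ot

def costMatrix(i, j):
    '''Cost matrix between the word clouds of movies i and j (sketch).'''
    X = texts[i][0]  # (n_i, 300) word vectors of movie i
    Y = texts[j][0]  # (n_j, 300) word vectors of movie j
    return ot.dist(X, Y, metric='euclidean')
```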
632 | {
633 | "cell_type": "code",
634 | "execution_count": null,
635 | "metadata": {
636 | "colab": {},
637 | "colab_type": "code",
638 | "id": "T9jSqIywwAmD"
639 | },
640 | "outputs": [],
641 | "source": [
642 | "#this cell will take approximately 1 minute to compute in Google Colaboratory after you complete it\n",
643 | "OT_distances = np.zeros((7,7))\n",
644 | "# TODO: compute the OT distance (using Sinkhorn algorithm ot.sinkhorn) between all the pairs of scenarios"
645 | ]
646 | },
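One way the distance matrix could be filled (a sketch): for every pair of movies, build the cost matrix, run `ot.sinkhorn` with the word weights stored in `texts` (re-normalized to sum to one as a precaution), and take the transport cost of the resulting plan. The diagonal is left at zero, which matches how the MDS cell below treats it. `reg` is the regularization strength set in an earlier cell.

```python
import numpy as np
import ot

n = len(movies)
for i in range(n):
    for j in range(i + 1, n):
        C = costMatrix(i, j)
        a = texts[i][1] / texts[i][1].sum()  # precautionary normalization
        b = texts[j][1] / texts[j][1].sum()
        plan = ot.sinkhorn(a, b, C, reg)     # reg = 0.1, set above
        d = np.sum(plan * C)                 # transport cost of the Sinkhorn plan
        OT_distances[i, j] = OT_distances[j, i] = d
```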
647 | {
648 | "cell_type": "code",
649 | "execution_count": null,
650 | "metadata": {},
651 | "outputs": [],
652 | "source": [
653 | "for film in movies:\n",
654 | " print('The film most similar to', film, 'is', # TODO)"
655 | ]
656 | },
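A sketch of the lookup above, assuming the diagonal of `OT_distances` was left at zero (so each film is its own nearest neighbour and the second-smallest distance points to the most similar other film):

```python
import numpy as np

for i, film in enumerate(movies):
    closest = movies[np.argsort(OT_distances[i])[1]]  # index 0 is the film itself
    print('The film most similar to', film, 'is', closest)
```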
657 | {
658 | "cell_type": "markdown",
659 | "metadata": {
660 | "colab_type": "text",
661 | "id": "QPJaNctTwAmF"
662 | },
663 | "source": [
664 | "### Plot the MDS projection"
665 | ]
666 | },
667 | {
668 | "cell_type": "code",
669 | "execution_count": null,
670 | "metadata": {
671 | "colab": {},
672 | "colab_type": "code",
673 | "id": "GFeOuUdbwAmF"
674 | },
675 | "outputs": [],
676 | "source": [
677 | "from sklearn.manifold import MDS\n",
678 | "embedding = MDS(n_components=2, dissimilarity='precomputed')\n",
679 | "dis = OT_distances - OT_distances[OT_distances>0].min()\n",
680 | "np.fill_diagonal(dis, 0.)\n",
681 | "embedding = embedding.fit(dis)\n",
682 | "X = embedding.embedding_\n",
683 | "\n",
684 | "import matplotlib.pyplot as plt\n",
685 | "plt.figure(figsize=(17,6))\n",
686 | "plt.scatter(X[:,0], X[:,1], alpha=0.)\n",
687 | "plt.axis('equal')\n",
688 | "plt.axis('off')\n",
689 | "c = {'KILL BILL VOL.1':'red', 'KILL BILL VOL.2':'red', 'TITANIC':'blue', 'DUNKIRK':'blue', 'GRAVITY':'black', 'INTERSTELLAR':'black', 'THE MARTIAN':'black'}\n",
690 | "for film in movies:\n",
691 | " i = movies.index(film)\n",
692 | " plt.gca().annotate(film, X[i], size=30, ha='center', color=c[film], weight=\"bold\", alpha=0.7)\n",
693 | "plt.show()"
694 | ]
695 | },
696 | {
697 | "cell_type": "code",
698 | "execution_count": null,
699 | "metadata": {
700 | "colab": {},
701 | "colab_type": "code",
702 | "id": "EVruQxj1zSfW"
703 | },
704 | "outputs": [],
705 | "source": []
706 | }
707 | ],
708 | "metadata": {
709 | "colab": {
710 | "name": "MLSS 2 Optimal Transport for Machine Learning (student version).ipynb",
711 | "provenance": [],
712 | "version": "0.3.2"
713 | },
714 | "kernelspec": {
715 | "display_name": "Python 3",
716 | "language": "python",
717 | "name": "python3"
718 | },
719 | "language_info": {
720 | "codemirror_mode": {
721 | "name": "ipython",
722 | "version": 3
723 | },
724 | "file_extension": ".py",
725 | "mimetype": "text/x-python",
726 | "name": "python",
727 | "nbconvert_exporter": "python",
728 | "pygments_lexer": "ipython3",
729 | "version": "3.7.1"
730 | }
731 | },
732 | "nbformat": 4,
733 | "nbformat_minor": 1
734 | }
735 |
--------------------------------------------------------------------------------
/optimal_transport_tutorial/optimaltransport/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/optimal_transport_tutorial/optimaltransport/__init__.py
--------------------------------------------------------------------------------
/optimal_transport_tutorial/optimaltransport/data/croissants.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/optimal_transport_tutorial/optimaltransport/data/croissants.pickle
--------------------------------------------------------------------------------
/optimal_transport_tutorial/optimaltransport/data/schiele.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/optimal_transport_tutorial/optimaltransport/data/schiele.jpg
--------------------------------------------------------------------------------
/optimal_transport_tutorial/optimaltransport/data/schiele2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/optimal_transport_tutorial/optimaltransport/data/schiele2.jpg
--------------------------------------------------------------------------------
/optimal_transport_tutorial/optimaltransport/data/texts.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/optimal_transport_tutorial/optimaltransport/data/texts.pickle
--------------------------------------------------------------------------------
/optimal_transport_tutorial/setup.py:
--------------------------------------------------------------------------------
1 | from distutils.core import setup
2 |
3 |
4 | setup(
5 | name="optimaltransport",
6 | version="0.1",
7 | include_package_data=True,
8 | packages=[
9 | "optimaltransport",
10 | ]
11 |
12 | )
--------------------------------------------------------------------------------