├── .gitattributes ├── .gitignore ├── MLSS FULL SCHEDULE - 8.22.19.pdf ├── README.md ├── bayesian_deep_learning ├── .gitignore ├── Bayesian Deep Learning part1-soln.ipynb ├── Bayesian Deep Learning part1.ipynb ├── Bayesian Deep Learning part2-soln.ipynb ├── Bayesian Deep Learning part2.ipynb ├── README.md ├── mlss2019bdl │ ├── __init__.py │ ├── bdl │ │ ├── __init__.py │ │ ├── base.py │ │ ├── bernoulli.py │ │ └── gaussian.py │ ├── dataset.py │ ├── flex.py │ └── plotting.py └── setup.py ├── causality ├── exercises-answers.pdf └── exercises-tutorial.pdf ├── geometric_techniques_in_ML ├── MANIFEST.in ├── riemannian_opt_for_ml_solution.ipynb ├── riemannian_opt_for_ml_task.ipynb ├── riemannian_opt_gmm_embeddings.ipynb ├── riemannian_opt_text_preprocessing.ipynb ├── riemannianoptimization │ ├── __init__.py │ ├── data │ │ ├── kbvt_lfpw_v1_train.csv │ │ └── tsne_result_training_part.csv │ └── tutorial_helpers.py └── setup.py ├── img ├── img0.png ├── img1.png ├── img2.png ├── img3.png ├── img4.png ├── img5.png ├── img6.png ├── img7.png └── img8.png ├── kernels ├── README.md ├── dril-heuristic.png ├── probability_testing │ ├── __init__.py │ ├── data │ │ ├── almost_simple.npz │ │ ├── blobs.npz │ │ ├── blobs2.npz │ │ ├── blobs_single.npz │ │ ├── gan-samples.npz │ │ ├── hsic.npz │ │ ├── simple.npz │ │ ├── stopwords-english.txt │ │ ├── stopwords-french.txt │ │ └── transcripts.tar.bz2 │ └── support │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── kernels.cpython-36.pyc │ │ ├── mmd.cpython-36.pyc │ │ └── utils.cpython-36.pyc │ │ ├── kernels.py │ │ ├── mmd.py │ │ └── utils.py ├── setup.py ├── solutions_testing.ipynb └── testing.ipynb └── optimal_transport_tutorial ├── MANIFEST.in ├── Opt_transport_1_Introduction_to_POT_and_S._solutions.ipynb ├── Opt_transport_1_Introduction_to_POT_and_S.ipynb ├── Opt_transport_2_Optimal_Transport_for_Mac.ipynb ├── Opt_transport_2_Optimal_Transport_for_Mac_solutions.ipynb ├── optimaltransport ├── __init__.py └── data │ ├── croissants.pickle │ ├── schiele.jpg │ ├── schiele2.jpg │ └── texts.pickle └── setup.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.DS_Store -------------------------------------------------------------------------------- /MLSS FULL SCHEDULE - 8.22.19.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/MLSS FULL SCHEDULE - 8.22.19.pdf -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MLSS 2019 Skoltech tutorials 2 | This is the official repository for Machine Learning Summer School 2019, which is taking place at Skoltech Institute of Science and Technology, Moscow, from 26.08 - 06.09. 3 | 4 | This repository will contain all of the materials needed for MLSS tutorials. 
5 | 6 | ## The list of the current tutorials published (will be updated with time): 7 | * DAY-1 (26.08): François-Pierre Paty, Marco Cuturi - Optimal Transport: https://github.com/mlss-skoltech/tutorials/tree/master/optimal_transport_tutorial 8 | * DAY-2 (27.08): Alexey Artemov, Justin Solomon - Geometric Techniques in ML: https://github.com/mlss-skoltech/tutorials/tree/master/geometric_techniques_in_ML 9 | * DAY-4 (29.08): Yermek Kapushev, Arthur Gretton - Kernels: https://github.com/mlss-skoltech/tutorials/tree/master/kernels 10 | * [Updated 28.08] DAY-5 (30.08): Joris Mooij - Causality: https://github.com/mlss-skoltech/tutorials/tree/master/causality 11 | * [Updated 30.08] DAY-5 (30.08): Ivan Nazarov, Yarin Gal - Bayesian Deep Learning: https://github.com/mlss-skoltech/tutorials/tree/master/bayesian_deep_learning 12 | 13 | # Running the tutorials on Google Colaboratory: 14 | Most of the tutorials were created using Jupyter Notebooks. In order to reduce the time spent on installing various software, we have made sure that all of the tutorials are Google Colaboratory friendly. 15 | 16 | Colaboratory is a free Jupyter notebook environment that requires no setup and runs entirely in the cloud. With Colaboratory you can write and execute code, save and share your analyses, and access powerful computing resources, all for free from your browser. All of the notebooks already contain all the set-ups needed for each particular tutorial, so you will just be required to run the first several cells. 17 | 18 | Here are the instructions on how to open the notebooks in Colaboratory (tested on Google Chrome, version 76.0.): 19 | * First go to https://colab.research.google.com/github/mlss-skoltech/ 20 | * In the pop-up window, sign in to your GitHub account 21 | ![image0](/img/img0.png) 22 | * In the opened window, choose the notebook corresponding to the tutorial 23 | ![image1](/img/img1.png) 24 | * The selected notebook will open; now make sure that you are signed in to your Google account 25 | ![image2](/img/img2.png) 26 | * Try to run the first cell; you will get the following message: 27 | ![image3](/img/img3.png) 28 | Press ```RUN ANYWAY``` 29 | * For the message ```Reset all runtimes``` press ```YES``` 30 | ![image4](/img/img4.png) 31 | 32 | In order to download all the material for the tutorial, make sure you run the cells containing the following code first (all of these cells are already added to the notebooks with the right paths): 33 | * For downloading the GitHub subdirectory containing the tutorial: 34 | 35 | ```!pip install --upgrade git+https://github.com/mlss-skoltech/tutorials.git#subdirectory=``` 36 | 37 | * For declaring the data files' path: 38 | ``` 39 | import pkg_resources 40 | DATA_PATH = pkg_resources.resource_filename('name_of_the_installed_tutorial_package', 'data/') 41 | ``` 42 | # Using GPU with Google Colaboratory: 43 | Sometimes, for computationally hard tasks, you will need to use a GPU instead of the default CPU; to do this, follow these steps: 44 | * Go to ```Edit->Notebook Settings``` 45 | ![image5](/img/img5.png) 46 | * In the ```Hardware accelerator``` field choose ```GPU``` 47 | ![image6](/img/img6.png) 48 | ![image7](/img/img7.png) 49 | 50 | # Saving and downloading the notebooks 51 | You can save your notebook in your Google Drive or simply download it; for that, go to ```File->Save a copy in Drive``` or ```File->Download.ipynb```.
52 | ![image8](/img/img8.png) 53 | 54 | 55 | 56 | If you would like to see more tutorials regarding Google Colaboratory have a look at this notebook: https://colab.research.google.com/notebooks/welcome.ipynb 57 | 58 | # Contact 59 | If you have any questions/suggestions regarding this githup repository or have found any bugs, please write to me at N.Mazyavkina@skoltech.ru 60 | 61 | -------------------------------------------------------------------------------- /bayesian_deep_learning/.gitignore: -------------------------------------------------------------------------------- 1 | # General 2 | .DS_Store 3 | .AppleDouble 4 | .LSOverride 5 | 6 | # Icon must end with two \r 7 | Icon 8 | 9 | 10 | # Thumbnails 11 | ._* 12 | 13 | # Files that might appear in the root of a volume 14 | .DocumentRevisions-V100 15 | .fseventsd 16 | .Spotlight-V100 17 | .TemporaryItems 18 | .Trashes 19 | .VolumeIcon.icns 20 | .com.apple.timemachine.donotpresent 21 | 22 | # Directories potentially created on remote AFP share 23 | .AppleDB 24 | .AppleDesktop 25 | Network Trash Folder 26 | Temporary Items 27 | .apdisk 28 | 29 | # Byte-compiled / optimized / DLL files 30 | __pycache__/ 31 | *.py[cod] 32 | *$py.class 33 | 34 | # C extensions 35 | *.so 36 | 37 | # Distribution / packaging 38 | .Python 39 | build/ 40 | develop-eggs/ 41 | dist/ 42 | downloads/ 43 | eggs/ 44 | .eggs/ 45 | lib/ 46 | lib64/ 47 | parts/ 48 | sdist/ 49 | var/ 50 | wheels/ 51 | *.egg-info/ 52 | .installed.cfg 53 | *.egg 54 | MANIFEST 55 | 56 | # PyInstaller 57 | # Usually these files are written by a python script from a template 58 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 59 | *.manifest 60 | *.spec 61 | 62 | # Installer logs 63 | pip-log.txt 64 | pip-delete-this-directory.txt 65 | 66 | # Unit test / coverage reports 67 | htmlcov/ 68 | .tox/ 69 | .coverage 70 | .coverage.* 71 | .cache 72 | nosetests.xml 73 | coverage.xml 74 | *.cover 75 | .hypothesis/ 76 | .pytest_cache/ 77 | 78 | # Translations 79 | *.mo 80 | *.pot 81 | 82 | # Django stuff: 83 | *.log 84 | local_settings.py 85 | db.sqlite3 86 | 87 | # Flask stuff: 88 | instance/ 89 | .webassets-cache 90 | 91 | # Scrapy stuff: 92 | .scrapy 93 | 94 | # Sphinx documentation 95 | docs/_build/ 96 | 97 | # PyBuilder 98 | target/ 99 | 100 | # Jupyter Notebook 101 | .ipynb_checkpoints 102 | 103 | # pyenv 104 | .python-version 105 | 106 | # celery beat schedule file 107 | celerybeat-schedule 108 | 109 | # SageMath parsed files 110 | *.sage.py 111 | 112 | # Environments 113 | .env 114 | .venv 115 | env/ 116 | venv/ 117 | ENV/ 118 | env.bak/ 119 | venv.bak/ 120 | 121 | # Spyder project settings 122 | .spyderproject 123 | .spyproject 124 | 125 | # Rope project settings 126 | .ropeproject 127 | 128 | # mkdocs documentation 129 | /site 130 | 131 | # mypy 132 | .mypy_cache/ 133 | -------------------------------------------------------------------------------- /bayesian_deep_learning/Bayesian Deep Learning part2.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# MLSS2019: Bayesian Deep Learning" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "In this tutorial we will uncertainty estimation can be\n", 15 | "used in active learning or expert-in-the-loop pipelines." 
16 | ] 17 | }, 18 | { 19 | "cell_type": "markdown", 20 | "metadata": {}, 21 | "source": [ 22 | "The plan of the tutorial\n", 23 | "1. [Imports and definitions](#Imports-and-definitions)\n", 24 | "2. [Bayesian Active Learning with images](#Bayesian-Active-Learning-with-images)\n", 25 | " 1. [The model](#The-model)\n", 26 | " 2. [the Acquisition Function](#the-Acquisition-Function)\n", 27 | " 3. [Data and the Oracle](#Data-and-the-Oracle)\n", 28 | " 4. [the Active Learning loop](#the-Active-Learning-loop)\n", 29 | " 5. [The baseline](#The-baseline)\n", 30 | "3. [Bayesian Active Learning by Disagreement](#Bayesian-Active-Learning-by-Disagreement)\n", 31 | " 1. [Points of improvement: batch-vs-single](#Points-of-improvement:-batch-vs-single)\n", 32 | " 2. [Points of improvement: bias](#Points-of-improvement:-bias)\n" 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "metadata": {}, 38 | "source": [ 39 | "**(note)**\n", 40 | "* to view documentation on something type in `something?` (with one question mark)\n", 41 | "* to view code of something type in `something??` (with two question marks)." 42 | ] 43 | }, 44 | { 45 | "cell_type": "markdown", 46 | "metadata": {}, 47 | "source": [ 48 | "
" 49 | ] 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "metadata": {}, 54 | "source": [ 55 | "## Imports and definitions" 56 | ] 57 | }, 58 | { 59 | "cell_type": "markdown", 60 | "metadata": {}, 61 | "source": [ 62 | "In this section we import necessary modules and functions and\n", 63 | "define the computational device." 64 | ] 65 | }, 66 | { 67 | "cell_type": "markdown", 68 | "metadata": {}, 69 | "source": [ 70 | "First, we install some boilerplate service code for this tutorial." 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": null, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "!pip install -q --upgrade git+https://github.com/mlss-skoltech/tutorials.git#subdirectory=bayesian_deep_learning" 80 | ] 81 | }, 82 | { 83 | "cell_type": "markdown", 84 | "metadata": {}, 85 | "source": [ 86 | "Next, numpy for computing, matplotlib for plotting and tqdm for progress bars." 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": null, 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "import tqdm\n", 96 | "import numpy as np\n", 97 | "\n", 98 | "%matplotlib inline\n", 99 | "import matplotlib.pyplot as plt" 100 | ] 101 | }, 102 | { 103 | "cell_type": "markdown", 104 | "metadata": {}, 105 | "source": [ 106 | "For deep learning stuff will be using [pytorch](https://pytorch.org/).\n", 107 | "\n", 108 | "If you are unfamiliar with it, it is basically like `numpy` with autograd,\n", 109 | "native GPU support, and tools for building training and serializing models.\n", 110 | "\n", 111 | "\n", 112 | "There are good introductory tutorials on `pytorch`, like this\n", 113 | "[one](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html)." 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": null, 119 | "metadata": {}, 120 | "outputs": [], 121 | "source": [ 122 | "import torch\n", 123 | "import torch.nn.functional as F\n", 124 | "\n", 125 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")" 126 | ] 127 | }, 128 | { 129 | "cell_type": "markdown", 130 | "metadata": {}, 131 | "source": [ 132 | "Next we import the boilerplate code.\n", 133 | "\n", 134 | "* a procedure that implements a minibatch SGD **fit** loop\n", 135 | "* a function, that **evaluates** the model on the provided dataset" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": null, 141 | "metadata": {}, 142 | "outputs": [], 143 | "source": [ 144 | "from mlss2019bdl import fit, predict" 145 | ] 146 | }, 147 | { 148 | "cell_type": "markdown", 149 | "metadata": {}, 150 | "source": [ 151 | "The algorithm to sample a random function is:\n", 152 | "* for $b = 1... B$ do:\n", 153 | "\n", 154 | " 1. draw an independent realization $f_b\\colon \\mathcal{X} \\to \\mathcal{Y}$\n", 155 | " with from the process $\\{f_\\omega\\}_{\\omega \\sim q(\\omega)}$\n", 156 | " 2. get $\\hat{y}_{bi} = f_b(\\tilde{x}_i)$ for $i=1 .. 
m$\n" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": null, 162 | "metadata": {}, 163 | "outputs": [], 164 | "source": [ 165 | "from mlss2019bdl.bdl import freeze, unfreeze\n", 166 | "\n", 167 | "def sample_function(model, dataset, n_draws=1, verbose=False):\n", 168 | " \"\"\"Draw a realization of a random function.\"\"\"\n", 169 | " outputs = []\n", 170 | " for _ in tqdm.tqdm(range(n_draws), disable=not verbose):\n", 171 | " freeze(model)\n", 172 | "\n", 173 | " outputs.append(predict(model, dataset))\n", 174 | "\n", 175 | " unfreeze(model)\n", 176 | "\n", 177 | " return torch.stack(outputs, dim=0)" 178 | ] 179 | }, 180 | { 181 | "cell_type": "markdown", 182 | "metadata": {}, 183 | "source": [ 184 | "Sample the class probabilities $p(y_x = k \\mid x, \\omega, m)$\n", 185 | "with $\\omega \\sim q(\\omega)$ by a model that **outputs raw class\n", 186 | "logit scores**." 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": null, 192 | "metadata": {}, 193 | "outputs": [], 194 | "source": [ 195 | "def sample_proba(model, dataset, n_draws=1):\n", 196 | " logits = sample_function(model, dataset, n_draws=n_draws)\n", 197 | "\n", 198 | " return F.softmax(logits, dim=-1)" 199 | ] 200 | }, 201 | { 202 | "cell_type": "markdown", 203 | "metadata": {}, 204 | "source": [ 205 | "Get the predictive posterior class probabilities\n", 206 | "$$\n", 207 | "p(y_x = k \\mid x, m)\n", 208 | "% = \\mathbb{E}_{\\omega \\sim q(\\omega)}\n", 209 | "% p(y_x = k \\mid x, \\omega, m)\n", 210 | " \\approx \\frac1{\\lvert \\mathcal{W} \\rvert}\n", 211 | " \\sum_{\\omega \\in \\mathcal{W}}\n", 212 | " p(y_x = k \\mid x, \\omega, m)\n", 213 | " \\,, $$\n", 214 | "with $\\mathcal{W}$ -- iid draws from $q(\\omega)$." 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": null, 220 | "metadata": {}, 221 | "outputs": [], 222 | "source": [ 223 | "def predict_proba(model, dataset, n_draws=1):\n", 224 | " proba = sample_proba(model, dataset, n_draws=n_draws)\n", 225 | "\n", 226 | " return proba.mean(dim=0)" 227 | ] 228 | }, 229 | { 230 | "cell_type": "markdown", 231 | "metadata": {}, 232 | "source": [ 233 | "Gat the maximum a posteriori class label **(MAP)**: $\n", 234 | "\\hat{y}_x\n", 235 | " = \\arg \\max_k \\mathbb{E}_{\\omega \\sim q(\\omega)}\n", 236 | " p(y_x = k \\mid x, \\omega, m)\n", 237 | "$" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": null, 243 | "metadata": {}, 244 | "outputs": [], 245 | "source": [ 246 | "def predict_label(model, dataset, n_draws=1):\n", 247 | " proba = predict_proba(model, dataset, n_draws=n_draws)\n", 248 | "\n", 249 | " return proba.argmax(dim=-1)" 250 | ] 251 | }, 252 | { 253 | "cell_type": "markdown", 254 | "metadata": {}, 255 | "source": [ 256 | "We will need some functionality from scikit" 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": null, 262 | "metadata": {}, 263 | "outputs": [], 264 | "source": [ 265 | "from sklearn.metrics import confusion_matrix\n", 266 | "\n", 267 | "def evaluate(model, dataset, n_draws=1):\n", 268 | " assert isinstance(dataset, TensorDataset)\n", 269 | "\n", 270 | " predicted = predict_label(model, dataset, n_draws=n_draws)\n", 271 | "\n", 272 | " target = dataset.tensors[1].cpu().numpy()\n", 273 | " return confusion_matrix(target, predicted.cpu().numpy())" 274 | ] 275 | }, 276 | { 277 | "cell_type": "markdown", 278 | "metadata": {}, 279 | "source": [ 280 | "A function to plot images in a small dataset. 
" 281 | ] 282 | }, 283 | { 284 | "cell_type": "code", 285 | "execution_count": null, 286 | "metadata": {}, 287 | "outputs": [], 288 | "source": [ 289 | "from mlss2019bdl.flex import plot\n", 290 | "from torch.utils.data import TensorDataset\n", 291 | "from IPython.display import clear_output\n", 292 | "\n", 293 | "def display(images, n_col=None, title=None, figsize=None, refresh=False):\n", 294 | " if isinstance(images, TensorDataset):\n", 295 | " images, targets = images.tensors\n", 296 | " \n", 297 | " if refresh:\n", 298 | " clear_output(True)\n", 299 | "\n", 300 | " fig, ax = plt.subplots(1, 1, figsize=figsize)\n", 301 | " plot(ax, images, n_col=n_col, cmap=plt.cm.bone)\n", 302 | " if title is not None:\n", 303 | " ax.set_title(title)\n", 304 | "\n", 305 | " plt.show()\n", 306 | " plt.close()" 307 | ] 308 | }, 309 | { 310 | "cell_type": "markdown", 311 | "metadata": {}, 312 | "source": [ 313 | "
" 314 | ] 315 | }, 316 | { 317 | "cell_type": "markdown", 318 | "metadata": {}, 319 | "source": [ 320 | "## Bayesian Active Learning with images" 321 | ] 322 | }, 323 | { 324 | "cell_type": "markdown", 325 | "metadata": {}, 326 | "source": [ 327 | "* Data labelling is costly and time consuming\n", 328 | "* unlabeled instances are essentially free\n", 329 | "\n", 330 | "**Goal** Achieve high performance with fewer labels by\n", 331 | "identifying the best instances to learn from" 332 | ] 333 | }, 334 | { 335 | "cell_type": "markdown", 336 | "metadata": {}, 337 | "source": [ 338 | "Essential blocks of active learning:\n", 339 | "\n", 340 | "* a **model** $m$ capable of quantifying uncertainty (preferably a Bayesian model)\n", 341 | "* an **acquisition function** $a\\colon \\mathcal{M} \\times \\mathcal{X}^* \\to \\mathbb{R}$\n", 342 | " that for any finite set of inputs $S\\subset \\mathcal{X}$ quantifies their usefulness\n", 343 | " to the model $m\\in \\mathcal{M}$\n", 344 | "* a labelling **oracle**, e.g. a human expert" 345 | ] 346 | }, 347 | { 348 | "cell_type": "markdown", 349 | "metadata": {}, 350 | "source": [ 351 | "### The model" 352 | ] 353 | }, 354 | { 355 | "cell_type": "markdown", 356 | "metadata": {}, 357 | "source": [ 358 | "We reuse the `DropoutLinear` from the first part." 359 | ] 360 | }, 361 | { 362 | "cell_type": "code", 363 | "execution_count": null, 364 | "metadata": {}, 365 | "outputs": [], 366 | "source": [ 367 | "from torch.nn import Module, Sequential\n", 368 | "from torch.nn import AvgPool2d, LeakyReLU\n", 369 | "from torch.nn import Linear, Conv2d\n", 370 | "\n", 371 | "from mlss2019bdl.bdl import DropoutLinear, DropoutConv2d\n", 372 | "\n", 373 | "class MNISTModel(Module):\n", 374 | " def __init__(self, p=0.5):\n", 375 | " super().__init__()\n", 376 | "\n", 377 | " self.head = Sequential(\n", 378 | " Conv2d(1, 32, 3, 1),\n", 379 | " LeakyReLU(),\n", 380 | " DropoutConv2d(32, 64, 3, 1, p=p),\n", 381 | " LeakyReLU(),\n", 382 | " AvgPool2d(2),\n", 383 | " )\n", 384 | "\n", 385 | " self.tail = Sequential(\n", 386 | " DropoutLinear(12 * 12 * 64, 128, p=p),\n", 387 | " LeakyReLU(),\n", 388 | " DropoutLinear(128, 10, p=p),\n", 389 | " )\n", 390 | "\n", 391 | " def forward(self, input):\n", 392 | " \"\"\"Take images and compute their class logits.\"\"\"\n", 393 | " x = self.head(input)\n", 394 | " return self.tail(x.flatten(1))" 395 | ] 396 | }, 397 | { 398 | "cell_type": "markdown", 399 | "metadata": {}, 400 | "source": [ 401 | "
" 402 | ] 403 | }, 404 | { 405 | "cell_type": "markdown", 406 | "metadata": {}, 407 | "source": [ 408 | "### the Acquisition Function" 409 | ] 410 | }, 411 | { 412 | "cell_type": "markdown", 413 | "metadata": {}, 414 | "source": [ 415 | "There are many acquisition criteria (borrowed from [Gal17a](http://proceedings.mlr.press/v70/gal17a.html)):\n", 416 | "* Classification\n", 417 | " * Posterior predictive entropy\n", 418 | " * Posterior Mutual Information\n", 419 | " * Variance ratios\n", 420 | " * BALD\n", 421 | "\n", 422 | "* Regression\n", 423 | " * predictive variance\n", 424 | "\n", 425 | "... and there is always the baseline **random acquisition**" 426 | ] 427 | }, 428 | { 429 | "cell_type": "code", 430 | "execution_count": null, 431 | "metadata": {}, 432 | "outputs": [], 433 | "source": [ 434 | "random_state = np.random.RandomState(812_760_351)\n", 435 | "\n", 436 | "def random_acquisition(dataset, model, n_request=1, n_draws=1):\n", 437 | " indices = random_state.choice(len(dataset), size=n_request)\n", 438 | "\n", 439 | " return torch.from_numpy(indices).to(device)" 440 | ] 441 | }, 442 | { 443 | "cell_type": "markdown", 444 | "metadata": {}, 445 | "source": [ 446 | "
" 447 | ] 448 | }, 449 | { 450 | "cell_type": "markdown", 451 | "metadata": {}, 452 | "source": [ 453 | "### Data and the Oracle" 454 | ] 455 | }, 456 | { 457 | "cell_type": "markdown", 458 | "metadata": {}, 459 | "source": [ 460 | "Prepare the datasets from the `train` part of\n", 461 | "[MNIST](http://yann.lecun.com/exdb/mnist/)\n", 462 | "(or [Kuzushiji-MNIST](https://github.com/rois-codh/kmnist)):\n", 463 | "* ($\\mathcal{S}_\\mathrm{train}$) initial **training**: $30$ images\n", 464 | "* ($\\mathcal{S}_\\mathrm{valid}$) our **validation**:\n", 465 | " $5000$ images, stratified\n", 466 | "* ($\\mathcal{S}_\\mathrm{pool}$) acquisition **pool**:\n", 467 | " $5000$ of the unused images, skewed to class $0$\n", 468 | "\n", 469 | "The true test sample of MNIST is in $\\mathcal{S}_\\mathrm{test}$ -- we\n", 470 | "will use it to evaluate the final performance." 471 | ] 472 | }, 473 | { 474 | "cell_type": "code", 475 | "execution_count": null, 476 | "metadata": {}, 477 | "outputs": [], 478 | "source": [ 479 | "from mlss2019bdl.dataset import get_dataset\n", 480 | "\n", 481 | "S_train, S_pool, S_valid, S_test = get_dataset(\n", 482 | " n_train=30,\n", 483 | " n_valid=5000,\n", 484 | " n_pool=5000,\n", 485 | " name=\"MNIST\", # \"KMNIST\"\n", 486 | " path=\"./data\",\n", 487 | " random_state=722_257_201)" 488 | ] 489 | }, 490 | { 491 | "cell_type": "markdown", 492 | "metadata": {}, 493 | "source": [ 494 | "* `query_oracle(ix, D)` **request** the instances in `D` at the specified\n", 495 | " indices `ix` into a dataset and **remove** from them from `D`\n", 496 | "\n", 497 | "* `merge(*datasets, [out=])` merge the datasets, creting a new one, or replacing `out`" 498 | ] 499 | }, 500 | { 501 | "cell_type": "code", 502 | "execution_count": null, 503 | "metadata": {}, 504 | "outputs": [], 505 | "source": [ 506 | "from mlss2019bdl.dataset import collect as query_oracle" 507 | ] 508 | }, 509 | { 510 | "cell_type": "markdown", 511 | "metadata": {}, 512 | "source": [ 513 | "
" 514 | ] 515 | }, 516 | { 517 | "cell_type": "markdown", 518 | "metadata": {}, 519 | "source": [ 520 | "### the Active Learning loop" 521 | ] 522 | }, 523 | { 524 | "cell_type": "markdown", 525 | "metadata": {}, 526 | "source": [ 527 | "1. fit $m$ on $\\mathcal{S}_{\\mathrm{labelled}}$\n", 528 | "\n", 529 | "\n", 530 | "2. get exact (or approximate) $$\n", 531 | " \\mathcal{S}^* \\in \\arg \\max\\limits_{S \\subseteq \\mathcal{S}_\\mathrm{unlabelled}}\n", 532 | " a(m, S)\n", 533 | "$$ satisfying **budget constraints** and **without** access to targets\n", 534 | "(constraints, like $\\lvert S \\rvert \\leq \\ell$ or other economically motivated ones).\n", 535 | "\n", 536 | "\n", 537 | "3. request the **oracle** to provide labels for each $x\\in \\mathcal{S}^*$\n", 538 | "\n", 539 | "\n", 540 | "4. update $\n", 541 | "\\mathcal{S}_{\\mathrm{labelled}}\n", 542 | " \\leftarrow \\mathcal{S}^*\n", 543 | " \\cup \\mathcal{S}_{\\mathrm{labelled}}\n", 544 | "$ and goto 1." 545 | ] 546 | }, 547 | { 548 | "cell_type": "code", 549 | "execution_count": null, 550 | "metadata": {}, 551 | "outputs": [], 552 | "source": [ 553 | "import copy\n", 554 | "from mlss2019bdl.dataset import merge\n", 555 | "\n", 556 | "def active_learn(S_train,\n", 557 | " S_pool,\n", 558 | " S_valid,\n", 559 | " acquire_fn,\n", 560 | " n_budget=150,\n", 561 | " n_max_request=3,\n", 562 | " n_draws=11,\n", 563 | " n_epochs=200,\n", 564 | " p=0.5,\n", 565 | " weight_decay=1e-2):\n", 566 | "\n", 567 | " model = MNISTModel(p=p).to(device)\n", 568 | "\n", 569 | " scores, balances = [], []\n", 570 | " S_train, S_pool = copy.deepcopy(S_train), copy.deepcopy(S_pool)\n", 571 | " while True:\n", 572 | " # 1. fit on train\n", 573 | " l2_reg = weight_decay * (1 - p) / max(len(S_train), 1)\n", 574 | "\n", 575 | " model = fit(model, S_train, batch_size=32, criterion=\"cross_entropy\",\n", 576 | " weight_decay=l2_reg, n_epochs=n_epochs)\n", 577 | "\n", 578 | "\n", 579 | " # (optional) keep track of scores and plot the train dataset\n", 580 | " scores.append(evaluate(model, S_valid, n_draws))\n", 581 | " balances.append(np.bincount(S_train.tensors[1], minlength=10))\n", 582 | "\n", 583 | " accuracy = scores[-1].diagonal().sum() / scores[-1].sum()\n", 584 | " title = f\"(n_train) {len(S_train)} (Acc.) {accuracy:.1%}\"\n", 585 | " display(S_train, n_col=30, figsize=(15, 5), title=title, refresh=True)\n", 586 | "\n", 587 | "\n", 588 | " # 2-3. request new data from pool, if within budget\n", 589 | " n_request = min(n_budget - len(S_train), n_max_request)\n", 590 | " if n_request <= 0:\n", 591 | " break\n", 592 | "\n", 593 | " indices = acquire_fn(S_pool, model, n_request=n_request, n_draws=n_draws)\n", 594 | "\n", 595 | " # 4. update the train dataset\n", 596 | " S_requested = query_oracle(indices, S_pool)\n", 597 | " S_train = merge(S_train, S_requested)\n", 598 | "\n", 599 | " return model, S_train, np.stack(scores, axis=0), np.stack(balances, axis=0)" 600 | ] 601 | }, 602 | { 603 | "cell_type": "markdown", 604 | "metadata": {}, 605 | "source": [ 606 | "* `collect(ix, D)` **collect** the instances in `D` at the specified\n", 607 | " indices `ix` into a dataset and **remove** from them from `D`\n", 608 | "\n", 609 | "* `merge(*datasets, [out=])` merge the datasets, creting a new one, or replacing `out`" 610 | ] 611 | }, 612 | { 613 | "cell_type": "markdown", 614 | "metadata": {}, 615 | "source": [ 616 | "
" 617 | ] 618 | }, 619 | { 620 | "cell_type": "markdown", 621 | "metadata": {}, 622 | "source": [ 623 | "### The baseline" 624 | ] 625 | }, 626 | { 627 | "cell_type": "markdown", 628 | "metadata": {}, 629 | "source": [ 630 | "How powerful will our model with random acquisition get under a total budget of $150$ images?" 631 | ] 632 | }, 633 | { 634 | "cell_type": "code", 635 | "execution_count": null, 636 | "metadata": { 637 | "scrolled": false 638 | }, 639 | "outputs": [], 640 | "source": [ 641 | "baseline = active_learn(\n", 642 | " S_train,\n", 643 | " S_pool,\n", 644 | " S_valid,\n", 645 | " random_acquisition,\n", 646 | " n_draws=21,\n", 647 | " n_budget=150,\n", 648 | " n_max_request=3,\n", 649 | " n_epochs=200,\n", 650 | ")" 651 | ] 652 | }, 653 | { 654 | "cell_type": "markdown", 655 | "metadata": {}, 656 | "source": [ 657 | "Let's see the dynamics of the accuracy ..." 658 | ] 659 | }, 660 | { 661 | "cell_type": "code", 662 | "execution_count": null, 663 | "metadata": {}, 664 | "outputs": [], 665 | "source": [ 666 | "def accuracy(scores):\n", 667 | " tp = scores.diagonal(axis1=-2, axis2=-1)\n", 668 | " return tp.sum(-1) / scores.sum((-2, -1))" 669 | ] 670 | }, 671 | { 672 | "cell_type": "code", 673 | "execution_count": null, 674 | "metadata": {}, 675 | "outputs": [], 676 | "source": [ 677 | "model_rand, train_rand, scores_rand, balances_rand = baseline\n", 678 | "\n", 679 | "fig, ax = plt.subplots(1, 1, figsize=(12, 7))\n", 680 | "ax.plot(accuracy(scores_rand), label='Accuracy (random)', lw=2)\n", 681 | "\n", 682 | "ax.legend()\n", 683 | "plt.show()" 684 | ] 685 | }, 686 | { 687 | "cell_type": "markdown", 688 | "metadata": {}, 689 | "source": [ 690 | "..., and the frequency of each class in $\\mathcal{S}_\\mathrm{train}$." 691 | ] 692 | }, 693 | { 694 | "cell_type": "code", 695 | "execution_count": null, 696 | "metadata": {}, 697 | "outputs": [], 698 | "source": [ 699 | "fig, ax = plt.subplots(1, 1, figsize=(12, 7))\n", 700 | "\n", 701 | "lines = ax.plot(balances_rand, lw=2)\n", 702 | "plt.legend(lines, list(range(10)), ncol=2);" 703 | ] 704 | }, 705 | { 706 | "cell_type": "markdown", 707 | "metadata": {}, 708 | "source": [ 709 | "
" 710 | ] 711 | }, 712 | { 713 | "cell_type": "markdown", 714 | "metadata": {}, 715 | "source": [ 716 | "## Bayesian Active Learning by Disagreement" 717 | ] 718 | }, 719 | { 720 | "cell_type": "markdown", 721 | "metadata": {}, 722 | "source": [ 723 | "Bayesian Active Learning by Disagreement, or **BALD** criterion, is\n", 724 | "based on the posterior mutual information between model's predictions\n", 725 | "$y_x$ at some point $x$ and its parameters $\\omega$:\n", 726 | "\n", 727 | "$$\\begin{align}\n", 728 | " a(m, S)\n", 729 | " &= \\sum_{x\\in S} a(m, \\{x\\})\n", 730 | " \\\\\n", 731 | " a(m, \\{x\\})\n", 732 | " &= \\mathbb{I}(y_x; \\omega \\mid x, m, D)\n", 733 | "\\end{align}\n", 734 | " \\,, \\tag{bald} $$\n", 735 | "\n", 736 | "with the [**Mutual Information**](https://en.wikipedia.org/wiki/Mutual_information#Relation_to_Kullback%E2%80%93Leibler_divergence)\n", 737 | "(**MI**)\n", 738 | "$$\n", 739 | " \\mathbb{I}(y_x; \\omega \\mid x, m, D)\n", 740 | " = \\mathbb{H}\\bigl(\n", 741 | " \\mathbb{E}_{\\omega \\sim q(\\omega\\mid m, D)}\n", 742 | " p(y_x \\,\\mid\\, x, \\omega, m, D)\n", 743 | " \\bigr)\n", 744 | " - \\mathbb{E}_{\\omega \\sim q(\\omega\\mid m, D)}\n", 745 | " \\mathbb{H}\\bigl(\n", 746 | " p(y_x \\,\\mid\\, x, \\omega, m, D)\n", 747 | " \\bigr)\n", 748 | " \\,, \\tag{mi} $$\n", 749 | "\n", 750 | "and the [(differential) **entropy**](https://en.wikipedia.org/wiki/Differential_entropy#Differential_entropies_for_various_distributions)\n", 751 | "(all densities and/or probability mass functions can be conditional):\n", 752 | "\n", 753 | "$$\n", 754 | " \\mathbb{H}(p(y))\n", 755 | " = - \\mathbb{E}_{y\\sim p} \\log p(y)\n", 756 | " \\,. $$" 757 | ] 758 | }, 759 | { 760 | "cell_type": "markdown", 761 | "metadata": {}, 762 | "source": [ 763 | "
" 764 | ] 765 | }, 766 | { 767 | "cell_type": "markdown", 768 | "metadata": {}, 769 | "source": [ 770 | "#### (task) Implementing the acquisition function" 771 | ] 772 | }, 773 | { 774 | "cell_type": "markdown", 775 | "metadata": {}, 776 | "source": [ 777 | "Note that $a(m, S)$ is additively separable in $S$, i.e.\n", 778 | "equals $\\sum_{x\\in S} a(m, \\{x\\})$. This implies\n", 779 | "\n", 780 | "$$\n", 781 | "\\begin{align}\n", 782 | " \\max_{S \\subseteq \\mathcal{S}_\\mathrm{unlabelled}} a(m, S)\n", 783 | " &= \\max_{z \\in \\mathcal{S}_\\mathrm{unlabelled}}\n", 784 | " \\max_{F \\in \\mathcal{S}_\\mathrm{unlabelled} \\setminus \\{z\\}}\n", 785 | " \\sum_{x\\in F \\cup \\{x\\}} a(m, \\{x\\})\n", 786 | " \\\\\n", 787 | " &= \\max_{z \\in \\mathcal{S}_\\mathrm{unlabelled}}\n", 788 | " a(m, \\{z\\})\n", 789 | " + \\max_{F \\in \\mathcal{S}_\\mathrm{unlabelled} \\setminus \\{z\\}}\n", 790 | " \\sum_{x\\in F} a(m, \\{x\\})\n", 791 | "\\end{align}\n", 792 | " \\,. $$" 793 | ] 794 | }, 795 | { 796 | "cell_type": "markdown", 797 | "metadata": {}, 798 | "source": [ 799 | "Therefore selecting the $\\ell$ `most interesting` points from\n", 800 | "$\\mathcal{S}_\\mathrm{unlabelled}$ is trivial." 801 | ] 802 | }, 803 | { 804 | "cell_type": "markdown", 805 | "metadata": {}, 806 | "source": [ 807 | "The acquisition function that we implement has interface\n", 808 | "identical to `random_acquisition` but uses BALD to choose\n", 809 | "instances." 810 | ] 811 | }, 812 | { 813 | "cell_type": "code", 814 | "execution_count": null, 815 | "metadata": {}, 816 | "outputs": [], 817 | "source": [ 818 | "def BALD_acquisition(dataset, model, n_request=1, n_draws=1):\n", 819 | " proba = sample_proba(model, dataset, n_draws=n_draws)\n", 820 | "\n", 821 | " ## Exercise: implement BALD\n", 822 | "\n", 823 | " pass" 824 | ] 825 | }, 826 | { 827 | "cell_type": "markdown", 828 | "metadata": {}, 829 | "source": [ 830 | "
" 831 | ] 832 | }, 833 | { 834 | "cell_type": "markdown", 835 | "metadata": {}, 836 | "source": [ 837 | "#### (task) implementing entropy" 838 | ] 839 | }, 840 | { 841 | "cell_type": "markdown", 842 | "metadata": {}, 843 | "source": [ 844 | "For categorical (discrete) random variables $y \\sim \\mathcal{Cat}(\\mathbf{p})$,\n", 845 | "$\\mathbf{p} \\in \\{ \\mu \\in [0, 1]^d \\colon \\sum_k \\mu_k = 1\\}$, the entropy is\n", 846 | "\n", 847 | "$$\n", 848 | " \\mathbb{H}(p(y))\n", 849 | " = - \\mathbb{E}_{y\\sim p(y)} \\log p(y)\n", 850 | " = - \\sum_k p_k \\log p_k\n", 851 | " \\,. $$" 852 | ] 853 | }, 854 | { 855 | "cell_type": "markdown", 856 | "metadata": {}, 857 | "source": [ 858 | "**(note)** although in calculus $0 \\cdot \\log 0 = 0$ (because\n", 859 | "$\\lim_{p\\downarrow 0} p \\cdot \\log p = 0$), in floating point\n", 860 | "arithmetic $0 \\cdot \\log 0 = \\mathrm{NaN}$. So you need to add\n", 861 | "some **really tiny float number** to the argument of $\\log$." 862 | ] 863 | }, 864 | { 865 | "cell_type": "code", 866 | "execution_count": null, 867 | "metadata": {}, 868 | "outputs": [], 869 | "source": [ 870 | "def categorical_entropy(proba):\n", 871 | " \"\"\"Compute the entropy along the last dimension.\"\"\"\n", 872 | "\n", 873 | " ## Exercise: the probabilities sum to one along the last axis.\n", 874 | " # Please, compute their entropy.\n", 875 | "\n", 876 | " pass" 877 | ] 878 | }, 879 | { 880 | "cell_type": "markdown", 881 | "metadata": {}, 882 | "source": [ 883 | "
" 884 | ] 885 | }, 886 | { 887 | "cell_type": "markdown", 888 | "metadata": {}, 889 | "source": [ 890 | "#### (task) implementing mutual information" 891 | ] 892 | }, 893 | { 894 | "cell_type": "markdown", 895 | "metadata": {}, 896 | "source": [ 897 | "Consider a tensor $p_{bik}$ of probabilities $p(y_{x_i}=k \\mid x_i, \\omega_b, m, D)$\n", 898 | "with $\\omega_b \\sim q(\\omega \\mid m, D)$ with $\\mathcal{W} = (\\omega_b)_{b=1}^B$\n", 899 | "being iid draws from $q(\\omega \\mid m, D)$.\n", 900 | "\n", 901 | "Let's implement a procedure that computes the Monte Carlo estimate of the\n", 902 | "posterior predictive distribution, its **entropy** and **mutual information**\n", 903 | "\n", 904 | "$$\n", 905 | " \\mathbb{I}_\\mathrm{MC}(y_x; \\omega \\mid x, m, D)\n", 906 | " = \\mathbb{H}\\bigl(\n", 907 | " \\hat{p}(y_x\\mid x, m, D)\n", 908 | " \\bigr)\n", 909 | " - \\frac1{\\lvert \\mathcal{W} \\rvert} \\sum_{\\omega\\in \\mathcal{W}}\n", 910 | " \\mathbb{H}\\bigl(\n", 911 | " p(y_x \\,\\mid\\, x, \\omega, m, D)\n", 912 | " \\bigr)\n", 913 | " \\,, \\tag{mi-mc} $$\n", 914 | "where\n", 915 | "$$\n", 916 | "\\hat{p}(y_x\\mid x, m, D)\n", 917 | " = \\frac1{\\lvert \\mathcal{W} \\rvert} \\sum_{\\omega\\in \\mathcal{W}}\n", 918 | " \\,p(y_x \\mid x, \\omega, m, D)\n", 919 | " \\,. $$" 920 | ] 921 | }, 922 | { 923 | "cell_type": "code", 924 | "execution_count": null, 925 | "metadata": {}, 926 | "outputs": [], 927 | "source": [ 928 | "def mutual_information(proba):\n", 929 | " ## Exercise: compute a Monte Carlo estimator of the predictive\n", 930 | " ## distribution, its entropy and MI `H E_w p(., w) - E_w H p(., w)`\n", 931 | "\n", 932 | " pass" 933 | ] 934 | }, 935 | { 936 | "cell_type": "markdown", 937 | "metadata": {}, 938 | "source": [ 939 | "
" 940 | ] 941 | }, 942 | { 943 | "cell_type": "markdown", 944 | "metadata": {}, 945 | "source": [ 946 | "How powerful will our model with **BALD** acquisition, if we can afford no more than $150$ images?" 947 | ] 948 | }, 949 | { 950 | "cell_type": "code", 951 | "execution_count": null, 952 | "metadata": { 953 | "scrolled": false 954 | }, 955 | "outputs": [], 956 | "source": [ 957 | "bald_results = active_learn(\n", 958 | " S_train,\n", 959 | " S_pool,\n", 960 | " S_valid,\n", 961 | " BALD_acquisition,\n", 962 | " n_draws=21,\n", 963 | " n_budget=150,\n", 964 | " n_max_request=3,\n", 965 | " n_epochs=200,\n", 966 | ")" 967 | ] 968 | }, 969 | { 970 | "cell_type": "markdown", 971 | "metadata": {}, 972 | "source": [ 973 | "Let's see the dynamics of the accuracy ..." 974 | ] 975 | }, 976 | { 977 | "cell_type": "code", 978 | "execution_count": null, 979 | "metadata": {}, 980 | "outputs": [], 981 | "source": [ 982 | "model_bald, train_bald, scores_bald, balances_bald = bald_results\n", 983 | "\n", 984 | "fig, ax = plt.subplots(1, 1, figsize=(12, 7))\n", 985 | "\n", 986 | "ax.plot(accuracy(scores_rand), label='Accuracy (random)', lw=2)\n", 987 | "ax.plot(accuracy(scores_bald), label='Accuracy (BALD)', lw=2)\n", 988 | "\n", 989 | "ax.legend()\n", 990 | "plt.show()" 991 | ] 992 | }, 993 | { 994 | "cell_type": "markdown", 995 | "metadata": {}, 996 | "source": [ 997 | "..., and the frequency of each class in $\\mathcal{S}_\\mathrm{train}$." 998 | ] 999 | }, 1000 | { 1001 | "cell_type": "code", 1002 | "execution_count": null, 1003 | "metadata": {}, 1004 | "outputs": [], 1005 | "source": [ 1006 | "fig, ax = plt.subplots(1, 1, figsize=(12, 7))\n", 1007 | "\n", 1008 | "lines = ax.plot(balances_bald, lw=2)\n", 1009 | "plt.legend(lines, list(range(10)), ncol=2);" 1010 | ] 1011 | }, 1012 | { 1013 | "cell_type": "markdown", 1014 | "metadata": {}, 1015 | "source": [ 1016 | "
" 1017 | ] 1018 | }, 1019 | { 1020 | "cell_type": "markdown", 1021 | "metadata": {}, 1022 | "source": [ 1023 | "#### Class performance" 1024 | ] 1025 | }, 1026 | { 1027 | "cell_type": "markdown", 1028 | "metadata": {}, 1029 | "source": [ 1030 | "The *one-versus-rest* precision / recall scores on\n", 1031 | "$\\mathcal{S}_\\mathrm{valid}$. For binary classification:\n", 1032 | "\n", 1033 | "$$ \\begin{align}\n", 1034 | "\\mathrm{Precision}\n", 1035 | " &= \\frac{\\mathrm{TP}}{\\mathrm{TP} + \\mathrm{FP}}\n", 1036 | " \\approx \\mathbb{P}(y = 1 \\mid \\hat{y} = 1)\n", 1037 | " \\,, \\\\\n", 1038 | "\\mathrm{Recall}\n", 1039 | " &= \\frac{\\mathrm{TP}}{\\mathrm{TP} + \\mathrm{FN}}\n", 1040 | " \\approx \\mathbb{P}(\\hat{y} = 1 \\mid y = 1)\n", 1041 | " \\,.\n", 1042 | "\\end{align}$$" 1043 | ] 1044 | }, 1045 | { 1046 | "cell_type": "code", 1047 | "execution_count": null, 1048 | "metadata": {}, 1049 | "outputs": [], 1050 | "source": [ 1051 | "import pandas as pd\n", 1052 | "\n", 1053 | "def pr_scores(score_matrix):\n", 1054 | " tp = score_matrix.diagonal(axis1=-2, axis2=-1)\n", 1055 | " fp, fn = score_matrix.sum(axis=-2) - tp, score_matrix.sum(axis=-1) - tp\n", 1056 | " \n", 1057 | " return pd.DataFrame({\n", 1058 | " \"precision\": {l: f\"{p:.2%}\" for l, p in enumerate(tp / (tp + fp))},\n", 1059 | " \"recall\": {l: f\"{p:.2%}\" for l, p in enumerate(tp / (tp + fn))},\n", 1060 | " })" 1061 | ] 1062 | }, 1063 | { 1064 | "cell_type": "markdown", 1065 | "metadata": {}, 1066 | "source": [ 1067 | "Let's see the performance on the test set" 1068 | ] 1069 | }, 1070 | { 1071 | "cell_type": "code", 1072 | "execution_count": null, 1073 | "metadata": {}, 1074 | "outputs": [], 1075 | "source": [ 1076 | "scores = {}\n", 1077 | "scores[\"rand\"] = evaluate(model_rand, S_test, n_draws=21)\n", 1078 | "scores[\"bald\"] = evaluate(model_bald, S_test, n_draws=21)" 1079 | ] 1080 | }, 1081 | { 1082 | "cell_type": "markdown", 1083 | "metadata": {}, 1084 | "source": [ 1085 | "
" 1086 | ] 1087 | }, 1088 | { 1089 | "cell_type": "code", 1090 | "execution_count": null, 1091 | "metadata": {}, 1092 | "outputs": [], 1093 | "source": [ 1094 | "df = pd.concat({\n", 1095 | " name: pr_scores(score)\n", 1096 | " for name, score in scores.items()\n", 1097 | "}, axis=1).T\n", 1098 | "\n", 1099 | "df.swaplevel().sort_index()" 1100 | ] 1101 | }, 1102 | { 1103 | "cell_type": "markdown", 1104 | "metadata": {}, 1105 | "source": [ 1106 | "
" 1107 | ] 1108 | }, 1109 | { 1110 | "cell_type": "markdown", 1111 | "metadata": {}, 1112 | "source": [ 1113 | "#### Question(s) (to work on in your spare time)\n", 1114 | "\n", 1115 | "* Run the experiments on the `KMNIST` dataset\n", 1116 | "\n", 1117 | "* Replicate figure 1 from [Gat et al. (2017): p. 4](http://proceedings.mlr.press/v70/gal17a.html).\n", 1118 | " You will need to re-run each experiment several times $11$, recording\n", 1119 | " the accuracy dynamics of each, then compare the mean and $25\\%$-$75\\%$\n", 1120 | " quantiles as they evolve with the size of the training sample." 1121 | ] 1122 | }, 1123 | { 1124 | "cell_type": "markdown", 1125 | "metadata": {}, 1126 | "source": [ 1127 | "
" 1128 | ] 1129 | }, 1130 | { 1131 | "cell_type": "markdown", 1132 | "metadata": {}, 1133 | "source": [ 1134 | "### (optional) Points of improvement: batch-vs-single" 1135 | ] 1136 | }, 1137 | { 1138 | "cell_type": "markdown", 1139 | "metadata": {}, 1140 | "source": [ 1141 | "A drawback of the `pointwise` top-$\\ell$ procedure above is that, although\n", 1142 | "it acquires individually informative instances, altogether they might end\n", 1143 | "up **being** `jointly poorly informative`. This can be corrected, if we\n", 1144 | "would seek the highest mutual information among finite sets $\n", 1145 | "S \\subseteq \\mathcal{S}_\\mathrm{unlabelled}\n", 1146 | "$ of size $\\ell$." 1147 | ] 1148 | }, 1149 | { 1150 | "cell_type": "markdown", 1151 | "metadata": {}, 1152 | "source": [ 1153 | "Such acquisition function is called **batch-BALD**\n", 1154 | "([Kirsch et al.; 2019](https://arxiv.org/abs/1906.08158.pdf)):\n", 1155 | "\n", 1156 | "$$\\begin{align}\n", 1157 | " a(m, S)\n", 1158 | " &= \\mathbb{I}\\bigl((y_x)_{x\\in S}; \\omega \\mid S, m \\bigr)\n", 1159 | " = \\mathbb{H} \\bigl(\n", 1160 | " \\mathbb{E}_{\\omega \\sim q(\\omega\\mid m)} p\\bigl((y_x)_{x\\in S}\\mid S, \\omega, m \\bigr)\n", 1161 | " \\bigr)\n", 1162 | " - \\mathbb{E}_{\\omega \\sim q(\\omega\\mid m)} H\\bigl(\n", 1163 | " p\\bigl((y_x)_{x\\in S}\\mid S, \\omega, m \\bigr)\n", 1164 | " \\bigr)\n", 1165 | "\\end{align}\n", 1166 | " \\,. \\tag{batch-bald} $$" 1167 | ] 1168 | }, 1169 | { 1170 | "cell_type": "markdown", 1171 | "metadata": {}, 1172 | "source": [ 1173 | "This criterion requires combinatorially growing number of computations and\n", 1174 | "memory, however there are working solutions like random sampling of subsets\n", 1175 | "$\\mathcal{S}$ from $\\mathcal{S}_\\mathrm{unlabelled}$ or greedily maximizing\n", 1176 | "of this **submodular** criterion." 1177 | ] 1178 | }, 1179 | { 1180 | "cell_type": "markdown", 1181 | "metadata": {}, 1182 | "source": [ 1183 | "
" 1184 | ] 1185 | }, 1186 | { 1187 | "cell_type": "markdown", 1188 | "metadata": {}, 1189 | "source": [ 1190 | "### (optional) Points of improvement: bias" 1191 | ] 1192 | }, 1193 | { 1194 | "cell_type": "markdown", 1195 | "metadata": {}, 1196 | "source": [ 1197 | "The first term in the **MC** estimate of the mutual information is the\n", 1198 | "so-called **plug-in** estimator of the entropy:\n", 1199 | "\n", 1200 | "$$\n", 1201 | " \\hat{H}\n", 1202 | " = \\mathbb{H}(\\hat{p}) = - \\sum_k \\hat{p}_k \\log \\hat{p}_k\n", 1203 | " \\,, $$\n", 1204 | "\n", 1205 | "where $\\hat{p}_k = \\tfrac1B \\sum_b p_{bk}$ is the full sample estimator\n", 1206 | "of the probabilities." 1207 | ] 1208 | }, 1209 | { 1210 | "cell_type": "markdown", 1211 | "metadata": {}, 1212 | "source": [ 1213 | "It is known that this plug-in estimate is biased\n", 1214 | "(see [blog: Nowozin, 2015](http://www.nowozin.net/sebastian/blog/estimating-discrete-entropy-part-1.html)\n", 1215 | "and references therein, also this [notebook](https://colab.research.google.com/drive/1z9ZDNM6NFmuFnU28d8UO0Qymbd2LiNJW)). \n", 1216 | "In order to correct for small-sample bias we can use\n", 1217 | "[jackknife resampling](https://en.wikipedia.org/wiki/Jackknife_resampling).\n", 1218 | "It derives an estimate of the finite sample bias from the leave-one-out\n", 1219 | "estimators of the entropy and is relatively computationally cheap\n", 1220 | "(see [blog: Nowozin, 2015](http://www.nowozin.net/sebastian/blog/estimating-discrete-entropy-part-2.html),\n", 1221 | "[Miller, R. G. (1974)](http://www.math.ntu.edu.tw/~hchen/teaching/LargeSample/references/Miller74jackknife.pdf) and these [notes](http://people.bu.edu/aimcinto/jackknife.pdf)).\n", 1222 | "\n", 1223 | "The jackknife correction of a plug-in estimator $\\mathbb{H}(\\cdot)$\n", 1224 | "is computed thus: given a sample $(p_b)_{b=1}^B$ with $p_b$ -- discrete distribution on $1..K$\n", 1225 | "* for each $b=1.. B$\n", 1226 | " * get the leave-one-out estimator: $\\hat{p}_k^{-b} = \\tfrac1{B-1} \\sum_{j\\neq b} p_{jk}$\n", 1227 | " * compute the plug-in entropy estimator: $\\hat{H}_{-b} = \\mathbb{H}(\\hat{p}^{-b})$\n", 1228 | "* then compute the bias-corrected entropy estimator $\n", 1229 | "\\hat{H}_J\n", 1230 | " = \\hat{H} + (B - 1) \\bigl\\{\n", 1231 | " \\hat{H} - \\tfrac1B \\sum_b \\hat{H}^{-b}\n", 1232 | " \\bigr\\}\n", 1233 | "$" 1234 | ] 1235 | }, 1236 | { 1237 | "cell_type": "markdown", 1238 | "metadata": {}, 1239 | "source": [ 1240 | "**(note)** when we knock the $i$-th data point out of the sample mean\n", 1241 | "$\\mu = \\tfrac1n \\sum_i x_i$ and recompute the mean $\\mu_{-i}$ we get\n", 1242 | "the following relation\n", 1243 | "$$ \\mu_{-i}\n", 1244 | " = \\frac1{n-1} \\sum_{j\\neq i} x_j\n", 1245 | " = \\frac{n}{n-1} \\mu - \\tfrac1{n-1} x_i\n", 1246 | " = \\mu + \\frac{\\mu - x_i}{n-1}\n", 1247 | " \\,. $$\n", 1248 | "This makes it possible to quickly compute leave-one-out estimators of\n", 1249 | "discrete probability distribution." 1250 | ] 1251 | }, 1252 | { 1253 | "cell_type": "markdown", 1254 | "metadata": {}, 1255 | "source": [ 1256 | "#### (task*) Unbiased estimator of entropy and mutual information\n", 1257 | "\n", 1258 | "Try to efficiently implement a bias-corrected acquisition\n", 1259 | "function, and see it is worth the effort." 
1260 | ] 1261 | }, 1262 | { 1263 | "cell_type": "code", 1264 | "execution_count": null, 1265 | "metadata": {}, 1266 | "outputs": [], 1267 | "source": [ 1268 | "def BALD_jknf_acquisition(dataset, model, n_request=1, n_draws=1):\n", 1269 | " proba = sample_proba(model, dataset, n_draws=n_draws)\n", 1270 | "\n", 1271 | " ## Exercise: MC estimate of the predictive distribution, entropy and MI\n", 1272 | " ## mutual information `H E_w p(., w) - E_w H p(., w)` with jackknife\n", 1273 | " ## correction.\n", 1274 | "\n", 1275 | " pass" 1276 | ] 1277 | }, 1278 | { 1279 | "cell_type": "markdown", 1280 | "metadata": {}, 1281 | "source": [ 1282 | "
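One way to vectorize the leave-one-out step, using the mean-update identity from the note above, is sketched here (an illustration only, not the official solution). The correction concerns only the first term of (mi-mc): the second term averages exactly computed per-draw entropies and needs no correction.

```python
def jackknife_entropy_sketch(proba, eps=1e-20):
    """Bias-corrected entropy of the MC posterior predictive.

    `proba` is [n_draws, n_points, n_classes]; returns one value per point.
    """
    n_draws = proba.shape[0]                                 # needs n_draws >= 2

    p_hat = proba.mean(dim=0)                                # full-sample estimate of p
    H_hat = -(p_hat * torch.log(p_hat + eps)).sum(-1)        # plug-in entropy

    # leave-one-out estimates via mu_{-b} = mu + (mu - x_b) / (B - 1)
    p_loo = p_hat.unsqueeze(0) + (p_hat.unsqueeze(0) - proba) / (n_draws - 1)
    H_loo = -(p_loo * torch.log(p_loo + eps)).sum(-1)        # [n_draws, n_points]

    # H_J = H_hat + (B - 1) * (H_hat - mean_b H_{-b})
    return H_hat + (n_draws - 1) * (H_hat - H_loo.mean(dim=0))
```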
" 1283 | ] 1284 | }, 1285 | { 1286 | "cell_type": "markdown", 1287 | "metadata": {}, 1288 | "source": [ 1289 | "Let's see ..." 1290 | ] 1291 | }, 1292 | { 1293 | "cell_type": "code", 1294 | "execution_count": null, 1295 | "metadata": { 1296 | "scrolled": false 1297 | }, 1298 | "outputs": [], 1299 | "source": [ 1300 | "jknf_results = active_learn(\n", 1301 | " S_train,\n", 1302 | " S_pool,\n", 1303 | " S_valid,\n", 1304 | " BALD_jknf_acquisition,\n", 1305 | " n_draws=21,\n", 1306 | " n_budget=150,\n", 1307 | " n_max_request=3,\n", 1308 | " n_epochs=200,\n", 1309 | ")" 1310 | ] 1311 | }, 1312 | { 1313 | "cell_type": "code", 1314 | "execution_count": null, 1315 | "metadata": {}, 1316 | "outputs": [], 1317 | "source": [ 1318 | "fig, ax = plt.subplots(1, 1, figsize=(12, 7))\n", 1319 | "\n", 1320 | "model_jknf, train_jknf, scores_jknf, balances_jknf = jknf_results\n", 1321 | "ax.plot(accuracy(scores_rand), label='Accuracy (random)', lw=2)\n", 1322 | "ax.plot(accuracy(scores_bald), label='Accuracy (BALD)', lw=2)\n", 1323 | "ax.plot(accuracy(scores_jknf), label='Accuracy (BALD-jknf)', lw=2)\n", 1324 | "\n", 1325 | "ax.legend()\n", 1326 | "plt.show()" 1327 | ] 1328 | }, 1329 | { 1330 | "cell_type": "code", 1331 | "execution_count": null, 1332 | "metadata": {}, 1333 | "outputs": [], 1334 | "source": [ 1335 | "fig, ax = plt.subplots(1, 1, figsize=(12, 7))\n", 1336 | "\n", 1337 | "lines = ax.plot(balances_jknf, lw=2)\n", 1338 | "plt.legend(lines, list(range(10)), ncol=2);" 1339 | ] 1340 | }, 1341 | { 1342 | "cell_type": "markdown", 1343 | "metadata": {}, 1344 | "source": [ 1345 | "
" 1346 | ] 1347 | } 1348 | ], 1349 | "metadata": { 1350 | "kernelspec": { 1351 | "display_name": "Python 3", 1352 | "language": "python", 1353 | "name": "python3" 1354 | }, 1355 | "language_info": { 1356 | "codemirror_mode": { 1357 | "name": "ipython", 1358 | "version": 3 1359 | }, 1360 | "file_extension": ".py", 1361 | "mimetype": "text/x-python", 1362 | "name": "python", 1363 | "nbconvert_exporter": "python", 1364 | "pygments_lexer": "ipython3", 1365 | "version": "3.7.2" 1366 | } 1367 | }, 1368 | "nbformat": 4, 1369 | "nbformat_minor": 2 1370 | } 1371 | -------------------------------------------------------------------------------- /bayesian_deep_learning/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/bayesian_deep_learning/README.md -------------------------------------------------------------------------------- /bayesian_deep_learning/mlss2019bdl/__init__.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | import torch 3 | 4 | import torch.nn.functional as F 5 | 6 | from torch.utils.data import TensorDataset, DataLoader 7 | 8 | 9 | def dataset_from_numpy(*ndarrays, device=None, dtype=torch.float32): 10 | """Create :class:`TensorDataset` from the passed :class:`numpy.ndarray`-s. 11 | 12 | Each returned tensor in the TensorDataset and :attr:`ndarray` share 13 | the same memory, unless a type cast or device transfer took place. 14 | Modifications to any tensor in the dataset will be reflected in respective 15 | :attr:`ndarray` and vice versa. 16 | 17 | Each returned tensor in the dataset is not resizable. 18 | 19 | See Also 20 | -------- 21 | torch.from_numpy : create a tensor from an ndarray. 22 | """ 23 | tensors = map(torch.from_numpy, ndarrays) 24 | 25 | return TensorDataset(*[t.to(device, dtype) for t in tensors]) 26 | 27 | 28 | default_criteria = { 29 | "cross_entropy": 30 | lambda model, X, y: F.cross_entropy(model(X), y, reduction="mean"), 31 | "mse": 32 | lambda model, X, y: 0.5 * F.mse_loss(model(X), y, reduction="mean"), 33 | } 34 | 35 | 36 | def fit(model, dataset, criterion="mse", batch_size=32, 37 | n_epochs=1, weight_decay=0, verbose=False): 38 | """Fit the model with SGD (Adam) on the specified dataset and criterion. 39 | 40 | This bare minimum of a fit loop creates a minibatch generator from 41 | the `dataset` with batches of size `batch_size`. On each batch it 42 | computes the backward pass through the `criterion` and the `model` 43 | and updates the `model`-s parameters with the Adam optimizer step. 44 | The loop passes through the dataset `n_epochs` times. It does not 45 | output any running debugging information, except for a progress bar. 46 | 47 | The criterion can be either "mse" for mean sqaured error, "nll" for 48 | negative loglikelihood (categorical), or a callable taking `model, X, y` 49 | as arguments. 
50 | """ 51 | if len(dataset) <= 0 or batch_size <= 0: 52 | return model 53 | 54 | criterion = default_criteria.get(criterion, criterion) 55 | assert callable(criterion) 56 | 57 | # get the model's device 58 | device = next(model.parameters()).device 59 | 60 | # an optimizer for model's parameters 61 | optim = torch.optim.Adam(model.parameters(), lr=2e-3, 62 | weight_decay=weight_decay) 63 | 64 | # stochastic minibatch generator for the training loop 65 | feed = DataLoader(dataset, shuffle=True, batch_size=batch_size) 66 | for epoch in tqdm.tqdm(range(n_epochs), disable=not verbose): 67 | 68 | model.train() 69 | 70 | for X, y in feed: 71 | # forward pass through the criterion (batch-average loss) 72 | loss = criterion(model, X.to(device), y.to(device)) 73 | 74 | # get gradients with backward pass 75 | optim.zero_grad() 76 | loss.backward() 77 | 78 | # SGD update 79 | optim.step() 80 | 81 | return model 82 | 83 | 84 | def predict(model, dataset, batch_size=512): 85 | """Get model's output on the dataset. 86 | 87 | This straightforward function switches the model into `evaluation` 88 | regime, computes the forward pass on the `dataset` (in batches of 89 | size `batch_size`) and stacks the results into a tensor on the `cpu`. 90 | It temporarily disables `autograd` to gain some speed-up. 91 | """ 92 | model.eval() 93 | 94 | # get the model's device 95 | device = next(model.parameters()).device 96 | 97 | # batch generator for the evaluation loop 98 | feed = DataLoader(dataset, batch_size=batch_size, shuffle=False) 99 | 100 | # compute and collect the outputs 101 | with torch.no_grad(): 102 | return torch.cat([ 103 | model(X.to(device)).cpu() for X, *rest in feed 104 | ], dim=0) 105 | -------------------------------------------------------------------------------- /bayesian_deep_learning/mlss2019bdl/bdl/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import freeze, unfreeze 2 | 3 | 4 | from .bernoulli import DropoutLinear, DropoutConv2d 5 | -------------------------------------------------------------------------------- /bayesian_deep_learning/mlss2019bdl/bdl/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch.nn import Module 4 | 5 | 6 | class FreezableWeight(Module): 7 | def __init__(self): 8 | super().__init__() 9 | self.unfreeze() 10 | 11 | def unfreeze(self): 12 | self.register_buffer("frozen_weight", None) 13 | 14 | def is_frozen(self): 15 | """Check if a frozen weight is available.""" 16 | return isinstance(self.frozen_weight, torch.Tensor) 17 | 18 | def freeze(self): 19 | """Sample from the distribution and freeze.""" 20 | raise NotImplementedError() 21 | 22 | 23 | def freeze(module): 24 | for mod in module.modules(): 25 | if isinstance(mod, FreezableWeight): 26 | mod.freeze() 27 | 28 | return module # return self 29 | 30 | 31 | def unfreeze(module): 32 | for mod in module.modules(): 33 | if isinstance(mod, FreezableWeight): 34 | mod.unfreeze() 35 | 36 | return module # return self 37 | 38 | 39 | class PenalizedWeight(Module): 40 | def penalty(self): 41 | raise NotImplementedError() 42 | 43 | 44 | def penalties(module): 45 | for mod in module.modules(): 46 | if isinstance(mod, PenalizedWeight): 47 | yield mod.penalty() 48 | -------------------------------------------------------------------------------- /bayesian_deep_learning/mlss2019bdl/bdl/bernoulli.py: -------------------------------------------------------------------------------- 1 | import torch 2 
| import torch.nn.functional as F 3 | 4 | from torch.nn import Linear, Conv2d 5 | 6 | from .base import FreezableWeight, PenalizedWeight 7 | 8 | 9 | class DropoutLinear(Linear, FreezableWeight): 10 | """Linear layer with dropout on inputs.""" 11 | def __init__(self, in_features, out_features, bias=True, p=0.5): 12 | super().__init__(in_features, out_features, bias=bias) 13 | 14 | self.p = p 15 | 16 | def forward(self, input): 17 | if self.is_frozen(): 18 | return F.linear(input, self.frozen_weight, self.bias) 19 | 20 | return super().forward(F.dropout(input, self.p, True)) 21 | 22 | def freeze(self): 23 | # let's draw the new weight 24 | with torch.no_grad(): 25 | prob = torch.full_like(self.weight[:1, :], 1 - self.p) 26 | feature_mask = torch.bernoulli(prob) / prob 27 | 28 | frozen_weight = self.weight * feature_mask 29 | 30 | # and store it 31 | self.register_buffer("frozen_weight", frozen_weight) 32 | 33 | 34 | class DropoutConv2d(Conv2d, FreezableWeight): 35 | """2d Convolutional layer with dropout on input features.""" 36 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 37 | padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', 38 | p=0.5): 39 | 40 | super().__init__(in_channels, out_channels, kernel_size, stride=stride, 41 | padding=padding, dilation=dilation, groups=groups, 42 | bias=bias, padding_mode=padding_mode) 43 | 44 | self.p = p 45 | 46 | def forward(self, input): 47 | """Apply feature dropout and then forward pass through the convolution.""" 48 | if self.is_frozen(): 49 | return F.conv2d(input, self.frozen_weight, self.bias, self.stride, 50 | self.padding, self.dilation, self.groups) 51 | 52 | return super().forward(F.dropout2d(input, self.p, True)) 53 | 54 | def freeze(self): 55 | """Sample the weight from the parameter distribution and freeze it.""" 56 | prob = torch.full_like(self.weight[:1, :, :1, :1], 1 - self.p) 57 | feature_mask = torch.bernoulli(prob) / prob 58 | 59 | with torch.no_grad(): 60 | frozen_weight = self.weight * feature_mask 61 | 62 | self.register_buffer("frozen_weight", frozen_weight) 63 | 64 | -------------------------------------------------------------------------------- /bayesian_deep_learning/mlss2019bdl/bdl/gaussian.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from torch.nn import Linear, Conv2d 5 | 6 | from .base import FreezableWeight, PenalizedWeight 7 | 8 | 9 | class BaseGaussianLinear(Linear, FreezableWeight, PenalizedWeight): 10 | """Linear layer with Gaussian Mean Field weight distribution.""" 11 | def __init__(self, in_features, out_features, bias=True): 12 | super().__init__(in_features, out_features, bias=bias) 13 | 14 | self.log_sigma2 = torch.nn.Parameter( 15 | torch.Tensor(*self.weight.shape)) 16 | 17 | self.reset_variational_parameters() 18 | 19 | def reset_variational_parameters(self): 20 | self.log_sigma2.data.normal_(-5, 0.1) # from arxiv:1811.00596 21 | 22 | def forward(self, input): 23 | """Forward pass for the linear layer with the local reparameterization trick.""" 24 | 25 | if self.is_frozen(): 26 | return F.linear(input, self.frozen_weight, self.bias) 27 | 28 | s2 = F.linear(input * input, torch.exp(self.log_sigma2), None) 29 | 30 | return torch.normal(super().forward(input), torch.sqrt(s2 + 1e-20)) 31 | 32 | def freeze(self): 33 | 34 | with torch.no_grad(): 35 | stdev = torch.exp(0.5 * self.log_sigma2) 36 | weight = torch.normal(self.weight, std=stdev) 37 | 38 | 
self.register_buffer("frozen_weight", weight) 39 | 40 | 41 | class BaseGaussianConv2d(Conv2d, PenalizedWeight, FreezableWeight): 42 | """Convolutional layer with Gaussian Mean Field weight distribution.""" 43 | 44 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 45 | padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'): 46 | super().__init__(in_channels, out_channels, kernel_size, stride=stride, 47 | padding=padding, dilation=dilation, groups=groups, 48 | bias=bias, padding_mode=padding_mode) 49 | 50 | self.log_sigma2 = torch.nn.Parameter( 51 | torch.Tensor(*self.weight.shape)) 52 | 53 | self.reset_variational_parameters() 54 | 55 | reset_variational_parameters = BaseGaussianLinear.reset_variational_parameters 56 | 57 | def forward(self, input): 58 | """Forward pass with the local reparameterization trick.""" 59 | if self.is_frozen(): 60 | return F.conv2d(input, self.frozen_weight, self.bias, self.stride, 61 | self.padding, self.dilation, self.groups) 62 | 63 | s2 = F.conv2d(input * input, torch.exp(self.log_sigma2), None, 64 | self.stride, self.padding, self.dilation, self.groups) 65 | 66 | return torch.normal(super().forward(input), torch.sqrt(s2 + 1e-20)) 67 | 68 | freeze = BaseGaussianLinear.freeze 69 | 70 | 71 | class GaussianLinearARD(BaseGaussianLinear): 72 | def penalty(self): 73 | # compute \tfrac12 \log (1 + \tfrac{\mu_{ji}}{\sigma_{ji}^2}) 74 | log_weight2 = 2 * torch.log(torch.abs(self.weight) + 1e-20) 75 | 76 | # `softplus` is $x \mapsto \log(1 + e^x)$ 77 | return 0.5 * torch.sum(F.softplus(log_weight2 - self.log_sigma2)) 78 | 79 | 80 | class GaussianConv2dARD(BaseGaussianConv2d): 81 | penalty = GaussianLinearARD.penalty 82 | -------------------------------------------------------------------------------- /bayesian_deep_learning/mlss2019bdl/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | from torch.utils.data import TensorDataset 5 | from torchvision import datasets 6 | 7 | from sklearn.utils import check_random_state 8 | from sklearn.model_selection import train_test_split 9 | 10 | 11 | def get_data(name, path="./data", train=True): 12 | if name == "MNIST": 13 | dataset = datasets.MNIST(path, train=train, download=True) 14 | elif name == "KMNIST": 15 | dataset = datasets.KMNIST(path, train=train, download=True) 16 | 17 | images = dataset.data.float().unsqueeze(1) 18 | return TensorDataset(images / 255., dataset.targets) 19 | 20 | 21 | def get_dataset(n_train=20, n_valid=5000, n_pool=5000, 22 | name="MNIST", path="./data", random_state=None): 23 | random_state = check_random_state(random_state) 24 | 25 | dataset = get_data(name, path, train=True) 26 | S_test = get_data(name, path, train=False) 27 | 28 | # create an imbalanced class label distribution for the train 29 | targets = dataset.tensors[-1].cpu().numpy() 30 | 31 | # split the dataset into validaton and train 32 | ix_all = np.r_[:len(targets)] 33 | ix_train, ix_valid = train_test_split( 34 | ix_all, stratify=targets, shuffle=True, 35 | train_size=max(n_train, 1), test_size=max(n_valid, 1), 36 | random_state=random_state) 37 | 38 | # prepare the datasets: pool, train and validation 39 | if n_train < 1: 40 | ix_train = np.r_[:0] 41 | S_train = TensorDataset(*dataset[ix_train]) 42 | 43 | if n_valid < 1: 44 | ix_valid = np.r_[:0] 45 | S_valid = TensorDataset(*dataset[ix_valid]) 46 | 47 | # prepare the pool 48 | ix_pool = np.delete(ix_all, np.r_[ix_train, ix_valid]) 49 | 50 | # we want to have lots 
of boring/useless examples in the pool 51 | labels, share = (1, 2, 3, 4, 5, 6, 7, 8, 9), 0.95 52 | pool_targets, dropped = targets[ix_pool], [] 53 | 54 | # deplete the pool of each class 55 | for label in labels: 56 | ix_cls = np.flatnonzero(pool_targets == label) 57 | n_kept = int(share * len(ix_cls)) 58 | 59 | # pick examples at random to drop 60 | ix_cls = random_state.permutation(ix_cls) 61 | dropped.append(ix_cls[:n_kept]) 62 | 63 | ix_pool = np.delete(ix_pool, np.concatenate(dropped)) 64 | 65 | # select at most `n_pool` examples 66 | if n_pool > 0: 67 | ix_pool = random_state.permutation(ix_pool)[:n_pool] 68 | S_pool = TensorDataset(*dataset[ix_pool]) 69 | 70 | return S_train, S_pool, S_valid, S_test 71 | 72 | 73 | def collect(indices, dataset): 74 | """Collect the specified samples from the dataset and remove.""" 75 | assert len(dataset) > 0 76 | 77 | mask = torch.zeros(len(dataset), dtype=torch.uint8) 78 | mask[indices] = True 79 | 80 | collected = TensorDataset(*dataset[mask]) 81 | 82 | dataset.tensors = dataset[~mask] 83 | 84 | return collected 85 | 86 | 87 | def merge(*datasets, out=None): 88 | # Classes derived from Dataset support appending via 89 | # `+` (__add__), but this breaks slicing. 90 | 91 | data = [d.tensors for d in datasets if d is not None and d.tensors] 92 | assert all(len(data[0]) == len(d) for d in data) 93 | 94 | tensors = [torch.cat(tup, dim=0) for tup in zip(*data)] 95 | 96 | if isinstance(out, TensorDataset): 97 | out.tensors = tensors 98 | return out 99 | 100 | return TensorDataset(*tensors) 101 | -------------------------------------------------------------------------------- /bayesian_deep_learning/mlss2019bdl/flex.py: -------------------------------------------------------------------------------- 1 | """Handy plotting procedures for small 2d images.""" 2 | import numpy as np 3 | 4 | from torch import Tensor 5 | from math import sqrt 6 | 7 | 8 | def get_dimensions(n_samples, height, width, 9 | n_row=None, n_col=None, aspect=(16, 9)): 10 | """Get the dimensions that aesthetically conform to the aspect ratio.""" 11 | if n_row is None and n_col is None: 12 | ratio = (width * aspect[1]) / (height * aspect[0]) 13 | n_row = int(sqrt(n_samples * ratio)) 14 | 15 | if n_row is None: 16 | n_row = (n_samples + n_col - 1) // n_col 17 | 18 | elif n_col is None: 19 | n_col = (n_samples + n_row - 1) // n_row 20 | 21 | return n_row, n_col 22 | 23 | 24 | def setup_canvas(ax, height, width, n_row, n_col): 25 | """Setup the ticks and labels for the canvas.""" 26 | # A pair of index arrays 27 | row_index, col_index = np.r_[:n_row], np.r_[:n_col] 28 | 29 | # Setup major ticks to the seams between images and disable labels 30 | ax.set_yticks((row_index[:-1] + 1) * height - 0.5, minor=False) 31 | ax.set_xticks((col_index[:-1] + 1) * width - 0.5, minor=False) 32 | 33 | ax.set_yticklabels([], minor=False) 34 | ax.set_xticklabels([], minor=False) 35 | 36 | # Set minor ticks so that they are exactly between the major ones 37 | ax.set_yticks((row_index + 0.5) * height, minor=True) 38 | ax.set_xticks((col_index + 0.5) * width, minor=True) 39 | 40 | # ... 
and make their labels into i-j coordinates 41 | ax.set_yticklabels([f"{i:d}" for i in row_index], minor=True) 42 | ax.set_xticklabels([f"{j:d}" for j in col_index], minor=True) 43 | 44 | # Orient tick marks outward 45 | ax.tick_params(axis="both", which="both", direction="out") 46 | return ax 47 | 48 | 49 | def arrange(n_row, n_col, data, fill_value=0): 50 | """Create a grid and populate it with images.""" 51 | n_samples, height, width, *color = data.shape 52 | grid = np.full((n_row * height, n_col * width, *color), 53 | fill_value, dtype=data.dtype) 54 | 55 | for k in range(min(n_samples, n_col * n_row)): 56 | i, j = (k // n_col) * height, (k % n_col) * width 57 | grid[i:i + height, j:j + width] = data[k] 58 | 59 | return grid 60 | 61 | 62 | def to_hwc(images, format): 63 | assert format in ("chw", "hwc"), f"Unrecognized format `{format}`." 64 | 65 | if images.ndim == 3: 66 | return images[..., np.newaxis] 67 | 68 | assert images.ndim == 4, f"Images must be Nx{'x'.join(format.upper())}." 69 | 70 | if format == "chw": 71 | return images.transpose(0, 2, 3, 1) 72 | 73 | elif format == "hwc": 74 | return images 75 | 76 | 77 | def plot(ax, images, *, n_col=None, n_row=None, format="chw", **kwargs): 78 | """Plot images in the numpy array on the specified matplotlib Axis.""" 79 | if isinstance(images, Tensor): 80 | images = images.data.cpu().numpy() 81 | 82 | images = to_hwc(images, format) 83 | 84 | n_samples, height, width, *color = images.shape 85 | if n_samples < 1: 86 | return None 87 | 88 | n_row, n_col = get_dimensions(n_samples, height, width, n_row, n_col) 89 | ax = setup_canvas(ax, height, width, n_row, n_col) 90 | 91 | image = arrange(n_row, n_col, images) 92 | return ax.imshow(image.squeeze(), **kwargs, origin="upper") 93 | -------------------------------------------------------------------------------- /bayesian_deep_learning/mlss2019bdl/plotting.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | from torch import Tensor 4 | from numpy import asarray 5 | 6 | 7 | def darker(color, a=0.5): 8 | """Adapted from this stackoverflow question_. 9 | 10 | .. 
_question: https://stackoverflow.com/questions/37765197/ 11 | """ 12 | from matplotlib.colors import to_rgb 13 | from colorsys import rgb_to_hls, hls_to_rgb 14 | 15 | h, l, s = rgb_to_hls(*to_rgb(color)) 16 | return hls_to_rgb(h, max(0, min(a * l, 1)), s) 17 | 18 | 19 | def canvas1d(*, figsize=(12, 5)): 20 | """Setup canvas for 1d function plot.""" 21 | fig, ax = plt.subplots(1, 1, figsize=figsize) 22 | 23 | fig.patch.set_alpha(1.0) 24 | ax.set_xlim(-7, +7) 25 | ax.set_ylim(-7, +9) 26 | 27 | return fig, ax 28 | 29 | 30 | def to_numpy(tensor): 31 | if isinstance(tensor, Tensor): 32 | tensor = tensor.data.cpu().numpy() 33 | 34 | return asarray(tensor).squeeze() 35 | 36 | 37 | def plot1d(X, y, bands, ax=None, **kwargs): 38 | X, y = to_numpy(X), to_numpy(y) 39 | assert y.ndim == 2 and X.ndim == 1 40 | 41 | ax = plt.gca() if ax is None else ax 42 | 43 | # plot the predictive mean with the specified colour 44 | y_mean, y_std = y.mean(axis=-1), y.std(axis=-1) 45 | line, = ax.plot(X, y_mean, **kwargs) 46 | 47 | # plot paths or bands with a lighter color and slightly behind the mean 48 | color, zorder = darker(line.get_color(), 1.25), line.get_zorder() 49 | if bands is None: 50 | ax.plot(X, y, c=color, alpha=0.08, zorder=zorder - 1) 51 | 52 | else: 53 | for band in sorted(bands): 54 | ax.fill_between(X, y_mean + band * y_std, y_mean - band * y_std, 55 | color=color, zorder=zorder-1, 56 | alpha=0.4 / len(bands)) 57 | 58 | return line 59 | 60 | 61 | def plot1d_bands(X, y, ax=None, **kwargs): 62 | # return plot1d(X, y, bands=(0.5, 1.0, 1.5, 2.0), ax=ax, **kwargs) 63 | return plot1d(X, y, bands=(1.96,), ax=ax, **kwargs) 64 | 65 | 66 | def plot1d_paths(X, y, ax=None, **kwargs): 67 | return plot1d(X, y, bands=None, ax=ax, **kwargs) 68 | -------------------------------------------------------------------------------- /bayesian_deep_learning/setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | 3 | setup( 4 | name="mlss2019bdl", 5 | version="0.2", 6 | description="""Service code for MLSS2019 Tutorial on Bayesian Deep Learning""", 7 | license="MIT License", 8 | author="Ivan Nazarov, Yarin Gal", 9 | author_email="ivan.nazarov@skolkovotech.ru", 10 | packages=[ 11 | "mlss2019bdl", 12 | "mlss2019bdl.bdl", 13 | ], 14 | install_requires=[ 15 | "numpy", 16 | "tqdm", 17 | "matplotlib", 18 | "torch", 19 | "torchvision", 20 | ] 21 | ) 22 | -------------------------------------------------------------------------------- /causality/exercises-answers.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/causality/exercises-answers.pdf -------------------------------------------------------------------------------- /causality/exercises-tutorial.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/causality/exercises-tutorial.pdf -------------------------------------------------------------------------------- /geometric_techniques_in_ML/MANIFEST.in: -------------------------------------------------------------------------------- 1 | include data/*.csv 2 | -------------------------------------------------------------------------------- /geometric_techniques_in_ML/riemannian_opt_gmm_embeddings.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | 
"nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "kernelspec": { 6 | "display_name": "Python 2", 7 | "language": "python", 8 | "name": "python2" 9 | }, 10 | "language_info": { 11 | "codemirror_mode": { 12 | "name": "ipython", 13 | "version": 3 14 | }, 15 | "file_extension": ".py", 16 | "mimetype": "text/x-python", 17 | "name": "python", 18 | "nbconvert_exporter": "python", 19 | "pygments_lexer": "ipython3", 20 | "version": "3.6.5" 21 | }, 22 | "colab": { 23 | "name": "riemannian_opt_gmm_embeddings.ipynb", 24 | "version": "0.3.2", 25 | "provenance": [] 26 | } 27 | }, 28 | "cells": [ 29 | { 30 | "cell_type": "markdown", 31 | "metadata": { 32 | "id": "FnIQw3vrPEhl", 33 | "colab_type": "text" 34 | }, 35 | "source": [ 36 | "This is a tutorial notebook on Riemannian optimization for machine learning, prepared for the Machine Learning Summer School 2019 (MLSS-2019, http://mlss2019.skoltech.ru) in Moscow, Russia, Skoltech (http://skoltech.ru).\n", 37 | "\n", 38 | "Copyright 2019 by Alexey Artemov and ADASE 3DDL Team. Special thanks to Alexey Zaytsev for a valuable contribution." 39 | ] 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "metadata": { 44 | "id": "8MD6cSJaPEhn", 45 | "colab_type": "text" 46 | }, 47 | "source": [ 48 | "## Index" 49 | ] 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "metadata": { 54 | "id": "WNw_x4zrPEhn", 55 | "colab_type": "text" 56 | }, 57 | "source": [ 58 | "1. [Generate a toy dataset](#Generate-a-toy-dataset).\n", 59 | "2. [Use Riemannian optimization to obtain GMM estimates](#Use-Riemannian-optimization-to-obtain-GMM-estimates).\n", 60 | "3. [GMM with real-world data using Riemannian optimization](#GMM-with-real-world-data-using-Riemannian-optimization)." 61 | ] 62 | }, 63 | { 64 | "cell_type": "markdown", 65 | "metadata": { 66 | "id": "h69X9iE7PEho", 67 | "colab_type": "text" 68 | }, 69 | "source": [ 70 | "## Riemannian Optimisation with `pymanopt` for inference in Gaussian mixture models" 71 | ] 72 | }, 73 | { 74 | "cell_type": "markdown", 75 | "metadata": { 76 | "id": "Fvi8wp7kPEhp", 77 | "colab_type": "text" 78 | }, 79 | "source": [ 80 | "This notebook is the second in the series of two notebooks on Riemannian optimization and is based heavily on the [official mixture of Gaussian notebook](https://github.com/pymanopt/pymanopt/blob/master/examples/MoG.ipynb) from `pymanopt` docs. \n", 81 | "\n", 82 | "For the basic introduction, see the first part `riemannian_opt_for_ml.ipynb`." 
83 | ] 84 | }, 85 | { 86 | "cell_type": "markdown", 87 | "metadata": { 88 | "id": "kUbxow6wPEhp", 89 | "colab_type": "text" 90 | }, 91 | "source": [ 92 | "Install the necessary libraries" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "metadata": { 98 | "id": "6RqFzcLpPEhq", 99 | "colab_type": "code", 100 | "colab": {} 101 | }, 102 | "source": [ 103 | "!pip install --upgrade git+https://github.com/mlss-skoltech/tutorials.git#subdirectory=geometric_techniques_in_ML" 104 | ], 105 | "execution_count": 0, 106 | "outputs": [] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "metadata": { 111 | "id": "sOGvTsCrPEhu", 112 | "colab_type": "code", 113 | "colab": {} 114 | }, 115 | "source": [ 116 | "!pip install pymanopt autograd\n", 117 | "!pip install scipy==1.2.1 -U" 118 | ], 119 | "execution_count": 0, 120 | "outputs": [] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "metadata": { 125 | "id": "cZDQc1LfQBRi", 126 | "colab_type": "code", 127 | "colab": {} 128 | }, 129 | "source": [ 130 | "import pkg_resources\n", 131 | "\n", 132 | "DATA_PATH = pkg_resources.resource_filename('riemannianoptimization', 'data/')" 133 | ], 134 | "execution_count": 0, 135 | "outputs": [] 136 | }, 137 | { 138 | "cell_type": "markdown", 139 | "metadata": { 140 | "id": "rgkSzxeVPEhw", 141 | "colab_type": "text" 142 | }, 143 | "source": [ 144 | "### Generate a toy dataset" 145 | ] 146 | }, 147 | { 148 | "cell_type": "markdown", 149 | "metadata": { 150 | "id": "Q9auRDUvPEhx", 151 | "colab_type": "text" 152 | }, 153 | "source": [ 154 | "The Mixture of Gaussians (MoG) model assumes that datapoints $\\mathbf{x}_i\\in\\mathbb{R}^d$ follow a distribution described by the following probability density function:\n", 155 | "$$\n", 156 | "p(\\mathbf{x}) = \\sum_{m=1}^M \\pi_m p_\\mathcal{N}(\\mathbf{x};\\mathbf{\\mu}_m,\\mathbf{\\Sigma}_m)\n", 157 | "$$ \n", 158 | "\n", 159 | "where $\\pi_m$ is the probability that the data point belongs to the $m^\\text{th}$ mixture component and $p_\\mathcal{N}(\\mathbf{x};\\mathbf{\\mu}_m,\\mathbf{\\Sigma}_m)$ is the probability density function of a [multivariate Gaussian distribution](https://en.wikipedia.org/wiki/Multivariate_normal_distribution) with mean $\\mathbf{\\mu}_m \\in \\mathbb{R}^d$ and [positive semi-definite](https://en.wikipedia.org/wiki/Definiteness_of_a_matrix) (PSD) covariance matrix $\\mathbf{\\Sigma}_m \\in \\{\\mathbf{M}\\in\\mathbb{R}^{d\\times d}: \\mathbf{M}\\succeq 0\\}$.\n", 160 | "\n", 161 | "As an example consider the mixture of three Gaussians with means\n", 162 | "$$\n", 163 | "\\mathbf{\\mu}_1 = \\begin{bmatrix} -4 \\\\ 1 \\end{bmatrix},\n", 164 | "\\quad\n", 165 | "\\mathbf{\\mu}_2 = \\begin{bmatrix} 0 \\\\ 0 \\end{bmatrix},\n", 166 | "\\quad\n", 167 | "\\mathbf{\\mu}_3 = \\begin{bmatrix} 2 \\\\ -1 \\end{bmatrix},\n", 168 | "$$\n", 169 | "covariances\n", 170 | "$$\\mathbf{\\Sigma}_1 = \\begin{bmatrix} 3 & 0 \\\\ 0 & 1 \\end{bmatrix},\n", 171 | "\\mathbf{\\Sigma}_2 = \\begin{bmatrix} 1 & 1 \\\\ 1 & 3 \\end{bmatrix},\n", 172 | "\\mathbf{\\Sigma}_3 = \\begin{bmatrix} 0.5 & 0 \\\\ 0 & 0.5 \\end{bmatrix}$$\n", 173 | "and mixture probability vector $\\pi=\\left[0.1, 0.6, 0.3\\right]$.\n", 174 | "Let's generate $N=1000$ samples of that MoG model and scatter plot the samples:" 175 | ] 176 | }, 177 | { 178 | "cell_type": "markdown", 179 | "metadata": { 180 | "id": "QuYDX2UKPEhx", 181 | "colab_type": "text" 182 | }, 183 | "source": [ 184 | "Generate a synthetic dataset of $M=3$ Gaussian distributions, w" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | 
"metadata": { 190 | "id": "ambBqXq6PEhy", 191 | "colab_type": "code", 192 | "colab": {} 193 | }, 194 | "source": [ 195 | "import numpy as np\n", 196 | "np.set_printoptions(precision=2)\n", 197 | "\n", 198 | "toy_n_points = 1000 # Number of data\n", 199 | "toy_dim = 2 # Dimension of data\n", 200 | "toy_components = 3 # Number of clusters \n", 201 | "\n", 202 | "# mixture parameters\n", 203 | "toy_pi = [0.1, 0.6, 0.3]\n", 204 | "toy_mus = [np.array([-4, 1]),\n", 205 | " np.array([0, 0]),\n", 206 | " np.array([2, -1])]\n", 207 | "toy_sigmas = [np.array([[3, 0],[0, 1]]),\n", 208 | " np.array([[1, 1.], [1, 3]]),\n", 209 | " .5 * np.eye(2)]\n", 210 | "\n", 211 | "# select which component work in each case\n", 212 | "components = np.random.choice(toy_components, size=toy_n_points, p=toy_pi)\n", 213 | "\n", 214 | "# prepare data\n", 215 | "samples = np.zeros((toy_n_points, toy_dim))\n", 216 | "\n", 217 | "# for each component, generate all needed samples\n", 218 | "for k in range(toy_components):\n", 219 | " # indices of current component in X\n", 220 | " indices = (k == components)\n", 221 | " # number of those occurrences\n", 222 | " n_k = indices.sum()\n", 223 | " if n_k > 0:\n", 224 | " samples[indices] = np.random.multivariate_normal(toy_mus[k], toy_sigmas[k], n_k)" 225 | ], 226 | "execution_count": 0, 227 | "outputs": [] 228 | }, 229 | { 230 | "cell_type": "markdown", 231 | "metadata": { 232 | "id": "rrePOBAiPEh1", 233 | "colab_type": "text" 234 | }, 235 | "source": [ 236 | "The following is a bunch of helper functions for visualizations." 237 | ] 238 | }, 239 | { 240 | "cell_type": "code", 241 | "metadata": { 242 | "id": "IEM2wHcMPEh2", 243 | "colab_type": "code", 244 | "colab": {} 245 | }, 246 | "source": [ 247 | "import numpy as np\n", 248 | "import matplotlib.pyplot as plt\n", 249 | "from matplotlib import cm # Colormaps\n", 250 | "\n", 251 | "\n", 252 | "def multivariate_normal(x, d, mean, covariance):\n", 253 | " \"\"\"pdf of the multivariate normal distribution.\"\"\"\n", 254 | " x_m = x - mean\n", 255 | " pdf = (1. / (np.sqrt((2 * np.pi)**d * np.linalg.det(covariance))) * \n", 256 | " np.exp(-(np.linalg.solve(covariance, x_m).T.dot(x_m)) / 2))\n", 257 | " return pdf\n", 258 | "\n", 259 | "\n", 260 | "# Plot bivariate distribution\n", 261 | "def generate_surface(mean, covariance, d):\n", 262 | " \"\"\"Helper function to generate density surface.\"\"\"\n", 263 | " nb_of_x = 100 # grid size\n", 264 | " # choose limits adaptively\n", 265 | "# mu1, mu2 = mean[:, 0]\n", 266 | "# sigmasq1, sigmasq2 = covariance[0, 0], covariance[1, 1]\n", 267 | "# min_x1 = mu1 - 3. * np.sqrt(sigmasq1)\n", 268 | "# max_x1 = mu1 + 3. * np.sqrt(sigmasq1)\n", 269 | "# min_x2 = mu2 - 3. * np.sqrt(sigmasq2)\n", 270 | "# max_x2 = mu2 + 3. 
* np.sqrt(sigmasq2)\n", 271 | "# print(min_x1, max_x1)\n", 272 | "# print(min_x2, max_x2)\n", 273 | " min_x1, max_x1 = -4, 4\n", 274 | " min_x2, max_x2 = -4, 4\n", 275 | " x1s = np.linspace(min_x1, max_x1, num=nb_of_x)\n", 276 | " x2s = np.linspace(min_x2, max_x2, num=nb_of_x)\n", 277 | " x1, x2 = np.meshgrid(x1s, x2s) # Generate grid\n", 278 | " pdf = np.zeros((nb_of_x, nb_of_x))\n", 279 | " \n", 280 | " # Fill the cost matrix for each combination of weights\n", 281 | " for i in range(nb_of_x):\n", 282 | " for j in range(nb_of_x):\n", 283 | " pdf[i,j] = multivariate_normal(\n", 284 | " np.matrix([[x1[i,j]], [x2[i,j]]]), \n", 285 | " d, mean, covariance)\n", 286 | " return x1, x2, pdf # x1, x2, pdf(x1,x2)\n", 287 | "\n", 288 | "\n", 289 | "def plot_gaussian(mu, sigma, ax):\n", 290 | " bivariate_mean = np.matrix(mu) # Mean\n", 291 | " bivariate_covariance = np.matrix(sigma) # Covariance\n", 292 | " x1, x2, p = generate_surface(\n", 293 | " bivariate_mean, bivariate_covariance, d=2) \n", 294 | " # Plot bivariate distribution\n", 295 | " con = ax.contour(x1, x2, p, 10, cmap=cm.hot)\n", 296 | " # ax2.axis([-2.5, 2.5, -1.5, 3.5])\n", 297 | " ax.set_aspect('equal')" 298 | ], 299 | "execution_count": 0, 300 | "outputs": [] 301 | }, 302 | { 303 | "cell_type": "code", 304 | "metadata": { 305 | "id": "kIPDqDjLPEh4", 306 | "colab_type": "code", 307 | "colab": {} 308 | }, 309 | "source": [ 310 | "fig = plt.figure(figsize=(8,8))\n", 311 | "ax = fig.gca()\n", 312 | "\n", 313 | "for mu, sigma in zip(toy_mus, toy_sigmas):\n", 314 | " mu = np.matrix(mu).T # plot_gaussian requires mu to be a column vector\n", 315 | " plot_gaussian(mu, sigma, ax)\n", 316 | "\n", 317 | "colors = ['r', 'g', 'b', 'c', 'm']\n", 318 | "for i in range(toy_components):\n", 319 | " indices = (i == components)\n", 320 | " ax.scatter(samples[indices, 0], samples[indices, 1], alpha=.4, color=colors[i % toy_components])\n" 321 | ], 322 | "execution_count": 0, 323 | "outputs": [] 324 | }, 325 | { 326 | "cell_type": "markdown", 327 | "metadata": { 328 | "id": "Zblgje-JPEh9", 329 | "colab_type": "text" 330 | }, 331 | "source": [ 332 | "### Use Riemannian optimization to obtain GMM estimates" 333 | ] 334 | }, 335 | { 336 | "cell_type": "markdown", 337 | "metadata": { 338 | "id": "6l5-6FsOPEh-", 339 | "colab_type": "text" 340 | }, 341 | "source": [ 342 | "Given a data sample the de facto standard method to infer the parameters is the [expectation maximisation](https://en.wikipedia.org/wiki/Expectation-maximization_algorithm) (EM) algorithm that, in alternating so-called E and M steps, maximises the log-likelihood of the data.\n", 343 | "\n", 344 | "In [arXiv:1506.07677](http://arxiv.org/pdf/1506.07677v1.pdf) Hosseini and Sra propose Riemannian optimisation as a powerful counterpart to EM. 
Importantly, they introduce a reparameterisation that leaves local optima of the log-likelihood unchanged while resulting in a geodesically convex optimisation problem over a product manifold $\\prod_{m=1}^M\\mathcal{PD}^{(d+1)\\times(d+1)}$ of manifolds of $(d+1)\\times(d+1)$ positive definite matrices.\n", 345 | "The proposed method is on par with EM and shows less variability in running times.\n", 346 | "\n", 347 | "The reparameterised optimisation problem for augmented data points $\\mathbf{y}_i=[\\mathbf{x}_i\\ 1]$ can be stated as follows:\n", 348 | "\n", 349 | "$$\\min_{(S_1, ..., S_m, \\nu_1, ..., \\nu_{m-1}) \\in \\prod_{m=1}^M \\mathcal{PD}^{(d+1)\\times(d+1)}\\times\\mathbb{R}^{M-1}}\n", 350 | "-\\sum_{n=1}^N\\log\\left(\n", 351 | "\\sum_{m=1}^M \\frac{\\exp(\\nu_m)}{\\sum_{k=1}^M\\exp(\\nu_k)}\n", 352 | "q_\\mathcal{N}(\\mathbf{y}_n;\\mathbf{S}_m)\n", 353 | "\\right)$$\n", 354 | "\n", 355 | "where\n", 356 | "\n", 357 | "* $\\mathcal{PD}^{(d+1)\\times(d+1)}$ is the manifold of positive definite\n", 358 | "$(d+1)\\times(d+1)$ matrices\n", 359 | "* $\\mathcal{\\nu}_m = \\log\\left(\\frac{\\alpha_m}{\\alpha_M}\\right), \\ m=1, ..., M-1$ and $\\nu_M=0$\n", 360 | "* $q_\\mathcal{N}(\\mathbf{y}_n;\\mathbf{S}_m) =\n", 361 | "2\\pi\\exp\\left(\\frac{1}{2}\\right)\n", 362 | "|\\operatorname{det}(\\mathbf{S}_m)|^{-\\frac{1}{2}}(2\\pi)^{-\\frac{d+1}{2}}\n", 363 | "\\exp\\left(-\\frac{1}{2}\\mathbf{y}_i^\\top\\mathbf{S}_m^{-1}\\mathbf{y}_i\\right)$\n", 364 | "\n", 365 | "**Optimisation problems like this can easily be solved using Pymanopt – even without the need to differentiate the cost function manually!**\n", 366 | "\n", 367 | "So let's infer the parameters of our toy example by Riemannian optimisation using Pymanopt:" 368 | ] 369 | }, 370 | { 371 | "cell_type": "code", 372 | "metadata": { 373 | "id": "iHYOYVfMPEh_", 374 | "colab_type": "code", 375 | "colab": {} 376 | }, 377 | "source": [ 378 | "import pymanopt as opt\n", 379 | "import pymanopt.solvers as solvers\n", 380 | "import pymanopt.manifolds as manifolds" 381 | ], 382 | "execution_count": 0, 383 | "outputs": [] 384 | }, 385 | { 386 | "cell_type": "code", 387 | "metadata": { 388 | "id": "UYYpl8X4PEiB", 389 | "colab_type": "code", 390 | "colab": {} 391 | }, 392 | "source": [ 393 | "import autograd.numpy as np # import here to avoid errors\n", 394 | "from autograd.scipy.misc import logsumexp\n", 395 | "\n", 396 | "# (1) Instantiate the manifold\n", 397 | "manifold = manifolds.Product([\n", 398 | " manifolds.PositiveDefinite(toy_dim + 1, k=toy_components), \n", 399 | " manifolds.Euclidean(toy_components - 1)\n", 400 | "])\n", 401 | "\n", 402 | "# (2) Define cost function\n", 403 | "# The parameters must be contained in a list theta.\n", 404 | "def cost(theta):\n", 405 | " # Unpack parameters\n", 406 | " nu = np.concatenate([theta[1], [0]], axis=0)\n", 407 | " \n", 408 | " S = theta[0]\n", 409 | " logdetS = np.expand_dims(np.linalg.slogdet(S)[1], 1)\n", 410 | " y = np.concatenate([samples.T, np.ones((1, len(samples)))], axis=0)\n", 411 | "\n", 412 | " # Calculate log_q\n", 413 | " y = np.expand_dims(y, 0)\n", 414 | " \n", 415 | " # 'Probability' of y belonging to each cluster\n", 416 | " log_q = -0.5 * (np.sum(y * np.linalg.solve(S, y), axis=1) + logdetS)\n", 417 | "\n", 418 | " alpha = np.exp(nu)\n", 419 | " alpha = alpha / np.sum(alpha)\n", 420 | " alpha = np.expand_dims(alpha, 1)\n", 421 | " \n", 422 | " loglikvec = logsumexp(np.log(alpha) + log_q, axis=0)\n", 423 | " return -np.sum(loglikvec)\n", 424 | "\n", 425 | "\n", 426 | 
"problem = opt.Problem(manifold=manifold, cost=cost, verbosity=2)\n", 427 | "\n", 428 | "# (3) Instantiate a Pymanopt solver\n", 429 | "solver = solvers.SteepestDescent()\n", 430 | "\n", 431 | "# let Pymanopt do the rest\n", 432 | "Xopt = solver.solve(problem)" 433 | ], 434 | "execution_count": 0, 435 | "outputs": [] 436 | }, 437 | { 438 | "cell_type": "markdown", 439 | "metadata": { 440 | "id": "3yvE_pHbPEiE", 441 | "colab_type": "text" 442 | }, 443 | "source": [ 444 | "Once Pymanopt has finished the optimisation we can obtain the inferred parameters as follows:" 445 | ] 446 | }, 447 | { 448 | "cell_type": "code", 449 | "metadata": { 450 | "id": "KIxF4cPzPEiF", 451 | "colab_type": "code", 452 | "colab": {} 453 | }, 454 | "source": [ 455 | "def extract_gaussian_parameters(Xopt, n=1):\n", 456 | " params, probas = Xopt\n", 457 | " \n", 458 | " mus, sigmas = [], []\n", 459 | " \n", 460 | " for p in params:\n", 461 | " mu = p[0:2,2:3]\n", 462 | " sigma = p[:2, :2] - mu.dot(mu.T)\n", 463 | " mus.append(mu)\n", 464 | " sigmas.append(sigma)\n", 465 | " \n", 466 | " pis = np.exp(np.concatenate([probas, [0]], axis=0))\n", 467 | " pis = pis / np.sum(pis)\n", 468 | " \n", 469 | " return mus, sigmas, pis" 470 | ], 471 | "execution_count": 0, 472 | "outputs": [] 473 | }, 474 | { 475 | "cell_type": "code", 476 | "metadata": { 477 | "id": "g7H7_-WMPEiH", 478 | "colab_type": "code", 479 | "colab": {} 480 | }, 481 | "source": [ 482 | "toy_mus_opt, toy_sigmas_opt, toy_pis_opt = extract_gaussian_parameters(Xopt, n=3)\n", 483 | "toy_mus_opt, toy_sigmas_opt, toy_pis_opt" 484 | ], 485 | "execution_count": 0, 486 | "outputs": [] 487 | }, 488 | { 489 | "cell_type": "code", 490 | "metadata": { 491 | "id": "21rmQP1iPEiJ", 492 | "colab_type": "code", 493 | "colab": {} 494 | }, 495 | "source": [ 496 | "fig = plt.figure(figsize=(8,8))\n", 497 | "ax = fig.gca()\n", 498 | "\n", 499 | "for mu, sigma in zip(toy_mus_opt, toy_sigmas_opt):\n", 500 | " plot_gaussian(mu, sigma, ax)\n", 501 | "\n", 502 | "colors = ['r', 'g', 'b', 'c', 'm']\n", 503 | "for i in range(toy_components):\n", 504 | " indices = (i == components)\n", 505 | " ax.scatter(samples[indices, 0], samples[indices, 1], alpha=.4, color=colors[i % toy_components])\n" 506 | ], 507 | "execution_count": 0, 508 | "outputs": [] 509 | }, 510 | { 511 | "cell_type": "markdown", 512 | "metadata": { 513 | "id": "WMPEQeFgPEiL", 514 | "colab_type": "text" 515 | }, 516 | "source": [ 517 | "And convince ourselves that the inferred parameters are close to the ground truth parameters." 518 | ] 519 | }, 520 | { 521 | "cell_type": "markdown", 522 | "metadata": { 523 | "id": "WdO9uKI3PEiM", 524 | "colab_type": "text" 525 | }, 526 | "source": [ 527 | "### GMM with real-world data using Riemannian optimization" 528 | ] 529 | }, 530 | { 531 | "cell_type": "markdown", 532 | "metadata": { 533 | "id": "0o9UHzqXPEiN", 534 | "colab_type": "text" 535 | }, 536 | "source": [ 537 | "Certain real-world datasets can be sufficiently closely modelled by the GMM. One instance might be low-dimensional word embeddings. An accompanying notebook `riemannian_opt_text_preprocessing.ipynb` details how these data were obtained. 
" 538 | ] 539 | }, 540 | { 541 | "cell_type": "code", 542 | "metadata": { 543 | "id": "u9g3ModaPEiN", 544 | "colab_type": "code", 545 | "colab": {} 546 | }, 547 | "source": [ 548 | "import pandas as pd\n", 549 | "df = pd.read_csv(DATA_PATH + 'tsne_result_training_part.csv', index_col=0)\n", 550 | "df" 551 | ], 552 | "execution_count": 0, 553 | "outputs": [] 554 | }, 555 | { 556 | "cell_type": "code", 557 | "metadata": { 558 | "id": "qD4oYVAkPEiP", 559 | "colab_type": "code", 560 | "colab": {} 561 | }, 562 | "source": [ 563 | "samples = df[['x', 'y']].values" 564 | ], 565 | "execution_count": 0, 566 | "outputs": [] 567 | }, 568 | { 569 | "cell_type": "code", 570 | "metadata": { 571 | "id": "amEL6JivPEiR", 572 | "colab_type": "code", 573 | "colab": {} 574 | }, 575 | "source": [ 576 | "plt.scatter(samples[:, 0], samples[:, 1], alpha=0.5)" 577 | ], 578 | "execution_count": 0, 579 | "outputs": [] 580 | }, 581 | { 582 | "cell_type": "markdown", 583 | "metadata": { 584 | "id": "QIiQiVMBPEiT", 585 | "colab_type": "text" 586 | }, 587 | "source": [ 588 | "For the optimization to be a little more stable, we standardize the data." 589 | ] 590 | }, 591 | { 592 | "cell_type": "code", 593 | "metadata": { 594 | "id": "4EFQFPxXPEiU", 595 | "colab_type": "code", 596 | "colab": {} 597 | }, 598 | "source": [ 599 | "from sklearn.preprocessing import StandardScaler\n", 600 | "samples = StandardScaler().fit_transform(samples)\n", 601 | "\n", 602 | "plt.scatter(samples[:, 0], samples[:, 1], alpha=0.5)" 603 | ], 604 | "execution_count": 0, 605 | "outputs": [] 606 | }, 607 | { 608 | "cell_type": "markdown", 609 | "metadata": { 610 | "id": "V-7O_l5mPEiW", 611 | "colab_type": "text" 612 | }, 613 | "source": [ 614 | "Use pretty much the same codes as above, changing the number of components and sample size accordingly." 
615 | ] 616 | }, 617 | { 618 | "cell_type": "code", 619 | "metadata": { 620 | "id": "s49QjNaPPEiX", 621 | "colab_type": "code", 622 | "colab": {} 623 | }, 624 | "source": [ 625 | "real_components = 4\n", 626 | "real_dim = 2\n", 627 | "real_points = len(samples)" 628 | ], 629 | "execution_count": 0, 630 | "outputs": [] 631 | }, 632 | { 633 | "cell_type": "code", 634 | "metadata": { 635 | "id": "9AItZaTnPEie", 636 | "colab_type": "code", 637 | "colab": {} 638 | }, 639 | "source": [ 640 | "import autograd.numpy as np # import here to avoid errors\n", 641 | "from autograd.scipy.misc import logsumexp\n", 642 | "\n", 643 | "# (1) Instantiate the manifold\n", 644 | "manifold = manifolds.Product([\n", 645 | " manifolds.PositiveDefinite(real_dim + 1, k=real_components), \n", 646 | " manifolds.Euclidean(real_components - 1)\n", 647 | "])\n", 648 | "\n", 649 | "# (2) Define cost function\n", 650 | "# The parameters must be contained in a list theta.\n", 651 | "def cost(theta):\n", 652 | " # Unpack parameters\n", 653 | " nu = np.concatenate([theta[1], [0]], axis=0)\n", 654 | " \n", 655 | " S = theta[0]\n", 656 | " logdetS = np.expand_dims(np.linalg.slogdet(S)[1], 1)\n", 657 | " y = np.concatenate([samples.T, np.ones((1, real_points))], axis=0)\n", 658 | "\n", 659 | " # Calculate log_q\n", 660 | " y = np.expand_dims(y, 0)\n", 661 | " \n", 662 | " # 'Probability' of y belonging to each cluster\n", 663 | " log_q = -0.5 * (np.sum(y * np.linalg.solve(S, y), axis=1) + logdetS)\n", 664 | "\n", 665 | " alpha = np.exp(nu)\n", 666 | " alpha = alpha / np.sum(alpha)\n", 667 | " alpha = np.expand_dims(alpha, 1)\n", 668 | " \n", 669 | " loglikvec = logsumexp(np.log(alpha) + log_q, axis=0)\n", 670 | " return -np.sum(loglikvec)\n", 671 | "\n", 672 | "\n", 673 | "problem = opt.Problem(manifold=manifold, cost=cost, verbosity=2)\n", 674 | "\n", 675 | "# (3) Instantiate a Pymanopt solver\n", 676 | "solver = solvers.SteepestDescent()\n", 677 | "\n", 678 | "# let Pymanopt do the rest\n", 679 | "Xopt = solver.solve(problem)" 680 | ], 681 | "execution_count": 0, 682 | "outputs": [] 683 | }, 684 | { 685 | "cell_type": "code", 686 | "metadata": { 687 | "id": "hwiwvxgSPEih", 688 | "colab_type": "code", 689 | "colab": {} 690 | }, 691 | "source": [ 692 | "real_mus_opt, real_sigmas_opt, real_pis_opt = extract_gaussian_parameters(Xopt, n=3)\n", 693 | "real_mus_opt, real_sigmas_opt, real_pis_opt" 694 | ], 695 | "execution_count": 0, 696 | "outputs": [] 697 | }, 698 | { 699 | "cell_type": "code", 700 | "metadata": { 701 | "scrolled": false, 702 | "id": "sJjPaye6PEij", 703 | "colab_type": "code", 704 | "colab": {} 705 | }, 706 | "source": [ 707 | "fig = plt.figure(figsize=(8,8))\n", 708 | "ax = fig.gca()\n", 709 | "\n", 710 | "for mu, sigma in zip(real_mus_opt, real_sigmas_opt):\n", 711 | " plot_gaussian(mu, sigma, ax)\n", 712 | "\n", 713 | "ax.scatter(samples[:, 0], samples[:, 1], alpha=0.5)" 714 | ], 715 | "execution_count": 0, 716 | "outputs": [] 717 | }, 718 | { 719 | "cell_type": "markdown", 720 | "metadata": { 721 | "id": "RLge4dRLPEil", 722 | "colab_type": "text" 723 | }, 724 | "source": [ 725 | "Et voilà – this was a brief demonstration of how to do inference for MoG models by performing Manifold optimisation using Pymanopt." 
726 | ] 727 | }, 728 | { 729 | "cell_type": "markdown", 730 | "metadata": { 731 | "id": "VPVLD88NPEil", 732 | "colab_type": "text" 733 | }, 734 | "source": [ 735 | "**TODO HOMEWORK** add riemannian optimization in M-step to speed up EM" 736 | ] 737 | }, 738 | { 739 | "cell_type": "code", 740 | "metadata": { 741 | "id": "63OLLMvtPEim", 742 | "colab_type": "code", 743 | "colab": {} 744 | }, 745 | "source": [ 746 | "" 747 | ], 748 | "execution_count": 0, 749 | "outputs": [] 750 | } 751 | ] 752 | } -------------------------------------------------------------------------------- /geometric_techniques_in_ML/riemannianoptimization/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/geometric_techniques_in_ML/riemannianoptimization/__init__.py -------------------------------------------------------------------------------- /geometric_techniques_in_ML/riemannianoptimization/tutorial_helpers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import matplotlib.pyplot as plt 4 | 5 | 6 | 7 | # hardcode landmark indexes 8 | left_brow = np.array([1, 5, 3, 6, 1]) - 1 9 | right_brow = np.array([4, 7, 2, 8, 4]) - 1 10 | 11 | left_eye = np.array([9, 13, 11, 14, 9]) - 1 12 | right_eye = np.array([12, 15, 10, 16, 12]) - 1 13 | 14 | nosetip = np.array([19, 21, 20, 22, 19]) - 1 15 | mouth_outerlip = np.array([23, 25, 24, 28, 23]) - 1 16 | mouth_innerlip = np.array([23, 26, 24, 27, 23]) - 1 17 | 18 | face_outer = np.array([29, 33, 31, 35, 32, 34, 30]) - 1 19 | 20 | CONTOURS = [left_brow, right_brow, left_eye, right_eye, 21 | nosetip, mouth_outerlip, mouth_innerlip, face_outer] 22 | 23 | 24 | def get_contours(image): 25 | for contour_idx in CONTOURS: 26 | x = image[contour_idx, 0] 27 | y = -image[contour_idx, 1] 28 | yield x, y 29 | 30 | 31 | 32 | def plot_landmarks(landmarks, ax=None, draw_landmark_id=False, draw_contours=True, draw_landmarks=True, 33 | alpha=1, color_landmarks='red', color_contour='orange', get_contour_handles=False): 34 | """Plots landmarks, connecting them appropriately. 
35 | 36 | landmarks: ndarray of shape either [35, 2] or [70,] 37 | ax: axis (created if None) 38 | """ 39 | if None is ax: 40 | f = plt.figure(figsize=(8, 8)) 41 | ax = f.gca() 42 | 43 | ax.tick_params( 44 | axis='both', # changes apply to both axes 45 | which='both', # both major and minor ticks are affected 46 | bottom=False, # ticks along the bottom edge are off 47 | top=False, # ticks along the top edge are off 48 | left=False, 49 | right=False, 50 | labelbottom=False, 51 | labeltop=False, 52 | labelleft=False, 53 | labelright=False) 54 | 55 | if landmarks.shape == (70,): 56 | landmarks = landmarks.reshape((35, 2)) 57 | 58 | contour_handles = [] 59 | if draw_contours: 60 | 61 | def _plot_landmark(landmarks, idx, color): 62 | h, = ax.plot(landmarks[idx, 0], -landmarks[idx, 1], color=color) 63 | return h 64 | 65 | contour_handles = [ 66 | _plot_landmark(landmarks, idx, color_contour) 67 | for idx in CONTOURS 68 | ] 69 | 70 | if draw_landmarks: 71 | ax.scatter(landmarks[:, 0], -landmarks[:, 1], s=20, color=color_landmarks, alpha=alpha) 72 | 73 | if draw_landmark_id: 74 | for i in range(35): 75 | ax.text(s=str(i + 1), x=landmarks[i, 0], y=-landmarks[i, 1]) 76 | 77 | if get_contour_handles: 78 | return contour_handles 79 | 80 | 81 | 82 | def load_data(data_path): 83 | df = pd.read_csv(data_path + 'kbvt_lfpw_v1_train.csv', delimiter='\t') 84 | 85 | # We don't need all of the columns -- only the ones with landmarks 86 | columns_to_include = [col for col in df.columns.tolist() 87 | if col.endswith('_x') or col.endswith('_y')] 88 | print('Selecting the following columns from the dataset: {}'.format('\n'.join(columns_to_include))) 89 | 90 | # select only averaged predictions 91 | data = df[columns_to_include][df['worker'] == 'average'] 92 | landmarks = data.values 93 | 94 | print('\n\n The resulting dataset has shape {}'.format(landmarks.shape)) 95 | 96 | return landmarks 97 | 98 | 99 | def prepare_html_for_scatter_plot(projected_shapes): 100 | xs = '[' + ','.join(map(str, projected_shapes[:, 0])) + ']' 101 | ys = '[' + ','.join(map(str, projected_shapes[:, 1])) + ']' 102 | return f'x: {xs}, y: {ys}' 103 | 104 | 105 | def prepare_html_for_landmarks(landmarks_for_one_sample): 106 | landmarks = landmarks_for_one_sample.reshape(35, 2) 107 | xs = '[' + ',"nan",'.join(','.join(map(str, landmarks[idx, 0])) for idx in CONTOURS) + ']' 108 | ys = '[' + ',"nan",'.join(','.join(map(str, -landmarks[idx, 1])) for idx in CONTOURS) + ']' 109 | return f'[{{x: {xs}, y: {ys}, type: "scatter", mode: "lines+markers", line: {{width: 1, color: "orange"}}, marker: {{size: 3, color: "red"}} }}]' 110 | 111 | 112 | def prepare_html_for_all_landmarks(landmarks): 113 | return '[' + ','.join(map(prepare_html_for_landmarks, landmarks)) + ']' 114 | 115 | 116 | def prepare_html_for_visualization(projected_shapes, landmarks, scatterplot_size=[700, 700], annotation_size=[100, 100], floating_annotation=True): 117 | scatter_data = prepare_html_for_scatter_plot(projected_shapes) 118 | scatter_width = str(scatterplot_size[0]) 119 | scatter_height = str(scatterplot_size[1]) 120 | 121 | annotation_data = prepare_html_for_all_landmarks(landmarks) 122 | annotation_width = str(annotation_size[0]) 123 | annotation_height = str(annotation_size[1]) 124 | 125 | html = ''' 126 | 127 | 128 | 129 |
130 |
131 | 132 | 133 | ''' 192 | return html 193 | 194 | 195 | 196 | 197 | __all__ = [ 198 | "load_data", 199 | "plot_landmarks", 200 | "prepare_html_for_visualization", 201 | ] -------------------------------------------------------------------------------- /geometric_techniques_in_ML/setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | 3 | 4 | setup( 5 | name="riemannianoptimization", 6 | version="0.1", 7 | include_package_data=True, 8 | packages=[ 9 | "riemannianoptimization", 10 | ] 11 | 12 | ) -------------------------------------------------------------------------------- /img/img0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/img/img0.png -------------------------------------------------------------------------------- /img/img1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/img/img1.png -------------------------------------------------------------------------------- /img/img2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/img/img2.png -------------------------------------------------------------------------------- /img/img3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/img/img3.png -------------------------------------------------------------------------------- /img/img4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/img/img4.png -------------------------------------------------------------------------------- /img/img5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/img/img5.png -------------------------------------------------------------------------------- /img/img6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/img/img6.png -------------------------------------------------------------------------------- /img/img7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/img/img7.png -------------------------------------------------------------------------------- /img/img8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/img/img8.png -------------------------------------------------------------------------------- /kernels/README.md: -------------------------------------------------------------------------------- 1 | This is the practical component of the [Machine Learning Summer School, Moscow 2019](https://mlss2019.skoltech.ru/) session on kernels, focusing on hypothesis testing with kernel 
statistics. 2 | 3 | The materials here are most recently by 4 | [Dougal Sutherland](http://www.gatsby.ucl.ac.uk/~dougals/) 5 | with consultation from [Arthur Gretton](http://www.gatsby.ucl.ac.uk/~gretton/), 6 | updated from [a previous course](https://github.com/dougalsutherland/ds3-kernels/), 7 | and based in large part on [earlier materials](https://github.com/karlnapf/ds3_kernel_testing) 8 | by [Heiko Strathmann](http://herrstrathmann.de/). 9 | 10 | We'll cover, in varying levels of detail, the following topics: 11 | 12 | - Two-sample testing with the kernel Maximum Mean Discrepancy (MMD). 13 | - Basic concepts of hypothesis testing, including permutation tests. 14 | - Computing kernel values. 15 | - Estimators for the MMD. 16 | - Learning an appropriate kernel function. 17 | - Independence testing with the Hilbert-Schmidt Independence Criterion. 18 | 19 | 20 | ## Dependencies 21 | 22 | ### Colab 23 | 24 | This notebook is [available on Google Colab](https://colab.research.google.com/github/dougalsutherland/mlss-testing/blob/built/testing.ipynb). You don't have to set anything up yourself and it runs on cloud resources, so this is probably the easiest option if you trust that your network connection is going to be reasonably reliable. Make a copy to your own Google Drive to save your progress, and to use a GPU, click Runtime -> Change runtime type -> Hardware accelerator -> GPU. Everything you need is already installed on Colab; use a Python 3 notebook. 25 | 26 | ### Local setup 27 | 28 | Run `check_imports.py` to see if everything you need is installed and downloaded. If that works, you're set; otherwise, read on. 29 | 30 | 31 | #### Files 32 | There are a few Python files and some data files in the repository. By far the easiest thing to do is just put them all in the same directory: 33 | 34 | ``` 35 | git clone https://github.com/dougalsutherland/mlss-testing 36 | ``` 37 | 38 | #### Python version 39 | This notebook requires Python 3.6+. Python 3.0 was released in 2008, and it's time to stop living in the past; most important Python projects [are dropping support for Python 2 this year](https://python3statement.org/). If you've never used Python 3 before, don't worry! It's almost the same; for the purposes of this notebook, you probably only need to know that you should write `print("hi")` since it's a function call now, and you can write `A @ B` instead of `A.dot(B)`. 40 | 41 | #### Python packages 42 | 43 | The main things we use are PyTorch and Jupyter. If you already have those set up, you should be fine; just additionally make sure you also have (with `conda install` or `pip install`) `seaborn`, `tqdm`, and `scikit-learn`. We import everything right at the start, so if that runs you shouldn't hit any surprises later on. 44 | 45 | If you don't already have a setup you're happy with, we recommend the `conda` package manager - start by installing [miniconda](https://docs.conda.io/en/latest/miniconda.html). 
Then you can create an environment with everything you need as: 46 | 47 | ```bash 48 | conda create --name mlss-testing --override-channels -c pytorch -c defaults --strict-channel-priority python=3 notebook ipywidgets numpy scipy scikit-learn pytorch=1.1 torchvision matplotlib seaborn tqdm 49 | conda activate mlss-testing 50 | 51 | git clone https://github.com/dougalsutherland/mlss-testing 52 | cd mlss-testing 53 | python check_imports.py 54 | jupyter notebook 55 | ``` 56 | 57 | (If you have an old conda setup, you can use `source activate` instead of `conda activate`, but it's better to [switch to the new style of activation](https://conda.io/projects/conda/en/latest/release-notes.html#recommended-change-to-enable-conda-in-your-shell). This won't matter for this tutorial, but it's general good practice.) 58 | 59 | (You can make your life easier when using jupyter notebooks with multiple kernels by installing `nb_conda_kernels`, but as long as you install and run `jupyter` from inside the env it will also be fine.) 60 | 61 | 62 | ## PyTorch 63 | 64 | We're going to use PyTorch in this tutorial, even though we're not doing a ton of "deep learning." (The CPU version will be fine, though a GPU might let you get slightly better performance in some of the "advanced" sections.) 65 | 66 | If you haven't used PyTorch before, don't worry! The API is unfortunately a little different from NumPy (and TensorFlow), but it's pretty easy to get used to; you can refer to [a cheat sheet vs NumPy](https://github.com/wkentaro/pytorch-for-numpy-users/blob/master/README.md) as well as the docs: [tensor methods](https://pytorch.org/docs/stable/tensors.html) and [the `torch` namespace](https://pytorch.org/docs/stable/torch.html#torch.eq). Feel free to ask if you have trouble figuring something out. 67 | 68 | You can convert a `torch.Tensor` to a `numpy.ndarray` with [`t.numpy()`](https://pytorch.org/docs/stable/tensors.html#torch.Tensor.numpy), and vice versa with [`torch.as_tensor()`](https://pytorch.org/docs/stable/torch.html#torch.as_tensor). (These share data when possible.) Doing this breaks PyTorch's ability to track gradients through these objects, but it's okay for things we won't need to take derivatives of. If you have a one-element tensor, you can get a regular Python number out of it with [`t.item()`](https://pytorch.org/docs/stable/tensors.html#torch.Tensor.item). 
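
For example, here's a tiny warm-up snippet putting those conversions together (the data and variable names are made up for illustration, not taken from the notebook):

```python
import numpy as np
import torch

X = np.random.randn(5, 3)   # toy data, just for illustration

# numpy -> torch (shares memory when possible; no gradient tracking)
Xt = torch.as_tensor(X)

# torch -> numpy
Xn = Xt.numpy()

# `@` does matrix multiplication on tensors, just like on numpy arrays
K = Xt @ Xt.t()

# a one-element tensor -> a plain Python number
print(K.mean().item())
```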
69 | -------------------------------------------------------------------------------- /kernels/dril-heuristic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/kernels/dril-heuristic.png -------------------------------------------------------------------------------- /kernels/probability_testing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/kernels/probability_testing/__init__.py -------------------------------------------------------------------------------- /kernels/probability_testing/data/almost_simple.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/kernels/probability_testing/data/almost_simple.npz -------------------------------------------------------------------------------- /kernels/probability_testing/data/blobs.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/kernels/probability_testing/data/blobs.npz -------------------------------------------------------------------------------- /kernels/probability_testing/data/blobs2.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/kernels/probability_testing/data/blobs2.npz -------------------------------------------------------------------------------- /kernels/probability_testing/data/blobs_single.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/kernels/probability_testing/data/blobs_single.npz -------------------------------------------------------------------------------- /kernels/probability_testing/data/gan-samples.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/kernels/probability_testing/data/gan-samples.npz -------------------------------------------------------------------------------- /kernels/probability_testing/data/hsic.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/kernels/probability_testing/data/hsic.npz -------------------------------------------------------------------------------- /kernels/probability_testing/data/simple.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/kernels/probability_testing/data/simple.npz -------------------------------------------------------------------------------- /kernels/probability_testing/data/stopwords-english.txt: -------------------------------------------------------------------------------- 1 | i 2 | me 3 | my 4 | myself 5 | we 6 | our 7 | ours 8 | ourselves 9 | you 10 | you're 11 | you've 12 | you'll 13 | you'd 14 | your 15 | yours 16 | yourself 17 | yourselves 18 | he 
19 | him 20 | his 21 | himself 22 | she 23 | she's 24 | her 25 | hers 26 | herself 27 | it 28 | it's 29 | its 30 | itself 31 | they 32 | them 33 | their 34 | theirs 35 | themselves 36 | what 37 | which 38 | who 39 | whom 40 | this 41 | that 42 | that'll 43 | these 44 | those 45 | am 46 | is 47 | are 48 | was 49 | were 50 | be 51 | been 52 | being 53 | have 54 | has 55 | had 56 | having 57 | do 58 | does 59 | did 60 | doing 61 | a 62 | an 63 | the 64 | and 65 | but 66 | if 67 | or 68 | because 69 | as 70 | until 71 | while 72 | of 73 | at 74 | by 75 | for 76 | with 77 | about 78 | against 79 | between 80 | into 81 | through 82 | during 83 | before 84 | after 85 | above 86 | below 87 | to 88 | from 89 | up 90 | down 91 | in 92 | out 93 | on 94 | off 95 | over 96 | under 97 | again 98 | further 99 | then 100 | once 101 | here 102 | there 103 | when 104 | where 105 | why 106 | how 107 | all 108 | any 109 | both 110 | each 111 | few 112 | more 113 | most 114 | other 115 | some 116 | such 117 | no 118 | nor 119 | not 120 | only 121 | own 122 | same 123 | so 124 | than 125 | too 126 | very 127 | s 128 | t 129 | can 130 | will 131 | just 132 | don 133 | don't 134 | should 135 | should've 136 | now 137 | d 138 | ll 139 | m 140 | o 141 | re 142 | ve 143 | y 144 | ain 145 | aren 146 | aren't 147 | couldn 148 | couldn't 149 | didn 150 | didn't 151 | doesn 152 | doesn't 153 | hadn 154 | hadn't 155 | hasn 156 | hasn't 157 | haven 158 | haven't 159 | isn 160 | isn't 161 | ma 162 | mightn 163 | mightn't 164 | mustn 165 | mustn't 166 | needn 167 | needn't 168 | shan 169 | shan't 170 | shouldn 171 | shouldn't 172 | wasn 173 | wasn't 174 | weren 175 | weren't 176 | won 177 | won't 178 | wouldn 179 | wouldn't 180 | -------------------------------------------------------------------------------- /kernels/probability_testing/data/stopwords-french.txt: -------------------------------------------------------------------------------- 1 | au 2 | aux 3 | avec 4 | ce 5 | ces 6 | dans 7 | de 8 | des 9 | du 10 | elle 11 | en 12 | et 13 | eux 14 | il 15 | je 16 | la 17 | le 18 | leur 19 | lui 20 | ma 21 | mais 22 | me 23 | même 24 | mes 25 | moi 26 | mon 27 | ne 28 | nos 29 | notre 30 | nous 31 | on 32 | ou 33 | par 34 | pas 35 | pour 36 | qu 37 | que 38 | qui 39 | sa 40 | se 41 | ses 42 | son 43 | sur 44 | ta 45 | te 46 | tes 47 | toi 48 | ton 49 | tu 50 | un 51 | une 52 | vos 53 | votre 54 | vous 55 | c 56 | d 57 | j 58 | l 59 | à 60 | m 61 | n 62 | s 63 | t 64 | y 65 | été 66 | étée 67 | étées 68 | étés 69 | étant 70 | étante 71 | étants 72 | étantes 73 | suis 74 | es 75 | est 76 | sommes 77 | êtes 78 | sont 79 | serai 80 | seras 81 | sera 82 | serons 83 | serez 84 | seront 85 | serais 86 | serait 87 | serions 88 | seriez 89 | seraient 90 | étais 91 | était 92 | étions 93 | étiez 94 | étaient 95 | fus 96 | fut 97 | fûmes 98 | fûtes 99 | furent 100 | sois 101 | soit 102 | soyons 103 | soyez 104 | soient 105 | fusse 106 | fusses 107 | fût 108 | fussions 109 | fussiez 110 | fussent 111 | ayant 112 | ayante 113 | ayantes 114 | ayants 115 | eu 116 | eue 117 | eues 118 | eus 119 | ai 120 | as 121 | avons 122 | avez 123 | ont 124 | aurai 125 | auras 126 | aura 127 | aurons 128 | aurez 129 | auront 130 | aurais 131 | aurait 132 | aurions 133 | auriez 134 | auraient 135 | avais 136 | avait 137 | avions 138 | aviez 139 | avaient 140 | eut 141 | eûmes 142 | eûtes 143 | eurent 144 | aie 145 | aies 146 | ait 147 | ayons 148 | ayez 149 | aient 150 | eusse 151 | eusses 152 | eût 153 | eussions 154 | eussiez 155 | eussent 156 | 
-------------------------------------------------------------------------------- /kernels/probability_testing/data/transcripts.tar.bz2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/kernels/probability_testing/data/transcripts.tar.bz2 -------------------------------------------------------------------------------- /kernels/probability_testing/support/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | import sys 3 | assert sys.version_info >= (3, 6) 4 | 5 | from .kernels import LazyKernel 6 | from .mmd import mmd2_u_stat_variance 7 | from .utils import as_tensors, maybe_squeeze, pil_grid 8 | -------------------------------------------------------------------------------- /kernels/probability_testing/support/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/kernels/probability_testing/support/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /kernels/probability_testing/support/__pycache__/kernels.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/kernels/probability_testing/support/__pycache__/kernels.cpython-36.pyc -------------------------------------------------------------------------------- /kernels/probability_testing/support/__pycache__/mmd.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/kernels/probability_testing/support/__pycache__/mmd.cpython-36.pyc -------------------------------------------------------------------------------- /kernels/probability_testing/support/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/kernels/probability_testing/support/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /kernels/probability_testing/support/kernels.py: -------------------------------------------------------------------------------- 1 | """ 2 | Some probably over-engineered infrastructure for lazily computing kernel 3 | matrices, allowing for various sums / means / etc used by MMD-related estimators. 4 | """ 5 | from copy import copy 6 | from functools import wraps 7 | 8 | import numpy as np 9 | import torch 10 | 11 | from .utils import as_tensors 12 | 13 | 14 | def _cache(f): 15 | # Only works when the function takes no or simple arguments! 
16 | @wraps(f) 17 | def wrapper(self, *args): 18 | key = (f.__name__,) + tuple(args) 19 | if key in self._cache: 20 | return self._cache[key] 21 | self._cache[key] = val = f(self, *args) 22 | return val 23 | 24 | return wrapper 25 | 26 | 27 | ################################################################################ 28 | # Kernel base class 29 | 30 | _name_map = {"X": 0, "Y": 1, "Z": 2} 31 | 32 | 33 | class LazyKernel(torch.nn.Module): 34 | """ 35 | Base class that allows computing kernel matrices among a bunch of datasets, 36 | only computing the matrices when we use them. 37 | 38 | Constructor arguments: 39 | - A bunch of matrices we'll compute the kernel among. 40 | 2d tensors, with second dimension agreeing, or None; 41 | None is a special value meaning to use the first entry X. 42 | (This is more efficient than passing the same tensor again.) 43 | 44 | Access the results with: 45 | - K[0, 1] to get the Tensor between parts 0 and 1. 46 | - K.XX, K.XY, K.ZY, etc: shortcuts, with X=0, Y=1, Z=2. 47 | - K.matrix(0, 1) or K.XY_m: returns a Matrix subclass (see below). 48 | """ 49 | 50 | def __init__(self, X, *rest): 51 | super().__init__() 52 | self._cache = {} 53 | if not hasattr(self, "const_diagonal"): 54 | self.const_diagonal = False 55 | 56 | # want to use pytorch buffer for parts 57 | # but can't assign a list to those, so munge some names 58 | X, *rest = as_tensors(X, *rest) 59 | if len(X.shape) < 2: 60 | raise ValueError( 61 | "LazyKernel expects inputs to be at least 2d. " 62 | "If your data is 1d, make it [n, 1] with X[:, np.newaxis]." 63 | ) 64 | 65 | self.register_buffer("_part_0", X) 66 | self.n_parts = 1 67 | for p in rest: 68 | self.append_part(p) 69 | 70 | @property 71 | def X(self): 72 | return self._part_0 73 | 74 | def _part(self, i): 75 | return self._buffers[f"_part_{i}"] 76 | 77 | def part(self, i): 78 | p = self._part(i) 79 | return self.X if p is None else p 80 | 81 | def n(self, i): 82 | return self.part(i).shape[0] 83 | 84 | @property 85 | def ns(self): 86 | return [self.n(i) for i in range(self.n_parts)] 87 | 88 | @property 89 | def parts(self): 90 | return [self.part(i) for i in range(self.n_parts)] 91 | 92 | @property 93 | def dtype(self): 94 | return self.X.dtype 95 | 96 | @property 97 | def device(self): 98 | return self.X.device 99 | 100 | def __repr__(self): 101 | return f"<{type(self).__name__}({', '.join(str(n) for n in self.ns)})>" 102 | 103 | def _compute(self, A, B): 104 | """ 105 | Compute the kernel matrix between A and B. 106 | 107 | Might get called with A = X, B = X, or A = X, B = Y, etc. 108 | 109 | Should return a tensor of shape [A.shape[0], B.shape[0]]. 110 | 111 | This default, slow, version calls self._compute_one(a, b) in a loop. 112 | If you override this, you don't need to implement _compute_one at all. 113 | 114 | If you implement _precompute, this gets added to the signature here: 115 | self._compute(A, *self._precompute(A), B, *self._precompute(B)). 116 | The default _precompute returns an empty tuple, so it's _compute(A, B), 117 | but if you make a _precompute that returns [A_squared, A_cubed] then it's 118 | self._compute(A, A_squared, A_cubed, B, B_squared, B_cubed). 
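        For illustration only (a sketch, not part of the original file): a
        Gaussian/RBF subclass could implement this pair of hooks roughly as

            def _precompute(self, A):
                # cache the squared norm of every row of A
                return (A.pow(2).sum(dim=1),)

            def _compute(self, A, A_sqnorms, B, B_sqnorms):
                # squared Euclidean distances via ||a||^2 + ||b||^2 - 2 a.b
                D2 = A_sqnorms[:, None] + B_sqnorms[None, :] - 2 * (A @ B.t())
                return torch.exp(-D2 / (2 * self.sigma ** 2))

        where `sigma` is a hypothetical bandwidth attribute set by the
        subclass constructor.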
119 | """ 120 | return torch.stack( 121 | [ 122 | torch.stack([torch.as_tensor(self._compute_one(a, b)) for b in B]) 123 | for a in A 124 | ] 125 | ) 126 | 127 | def _compute_one(self, a, b): 128 | raise NotImplementedError( 129 | f"{type(self).__name__}: need to implement _compute or _compute_one" 130 | ) 131 | 132 | def _precompute(self, A): 133 | """ 134 | Compute something extra for each part A. 135 | 136 | Can be used to share computation between kernel(X, X) and kernel(X, Y). 137 | 138 | We end up calling basically (but with caching) 139 | self._compute(A, *self._precompute(A), B, *self._precompute(B)) 140 | This default _precompute returns an empty tuple, so it's 141 | self._compute(A, B) 142 | But if you return [A_squared], it'd be 143 | self._compute(A, A_squared, B, B_squared) 144 | and so on. 145 | """ 146 | return () 147 | 148 | @_cache 149 | def _precompute_i(self, i): 150 | p = self._part(i) 151 | if p is None: 152 | return self._precompute_i(0) 153 | return self._precompute(p) 154 | 155 | @_cache 156 | def __getitem__(self, k): 157 | try: 158 | i, j = k 159 | except ValueError: 160 | raise KeyError("You should index kernels with pairs") 161 | 162 | A = self._part(i) 163 | if A is None: 164 | return self[0, j] 165 | 166 | B = self._part(j) 167 | if B is None: 168 | return self[i, 0] 169 | 170 | if i > j: 171 | return self[j, i].t() 172 | 173 | A_info = self._precompute_i(i) 174 | B_info = self._precompute_i(j) 175 | return self._compute(A, *A_info, B, *B_info) 176 | 177 | @_cache 178 | def matrix(self, i, j): 179 | if self._part(i) is None: 180 | return self.matrix(0, j) 181 | 182 | if self._part(j) is None: 183 | return self.matrix(i, 0) 184 | 185 | k = self[i, j] 186 | if i == j: 187 | return as_matrix(k, const_diagonal=self.const_diagonal, symmetric=True) 188 | else: 189 | return as_matrix(k) 190 | 191 | @_cache 192 | def joint(self, *inds): 193 | if not inds: 194 | return self.joint(*range(self.n_parts)) 195 | return torch.cat([torch.cat([self[i, j] for j in inds], 1) for i in inds], 0) 196 | 197 | @_cache 198 | def joint_m(self, *inds): 199 | if not inds: 200 | return self.joint_m(*range(self.n_parts)) 201 | return as_matrix( 202 | self.joint(*inds), const_diagonal=self.const_diagonal, symmetric=True 203 | ) 204 | 205 | def __getattr__(self, name): 206 | # self.X, self.Y, self.Z 207 | if name in _name_map: 208 | i = _name_map[name] 209 | if i < self.n_parts: 210 | return self.part(i) 211 | else: 212 | raise AttributeError(f"have {self.n_parts} parts, asked for {i}") 213 | 214 | # self.XX, self.XY, self.YZ, etc; also self.XX_m 215 | ret_matrix = False 216 | if len(name) == 4 and name.endswith("_m"): 217 | ret_matrix = True 218 | name = name[:2] 219 | 220 | if len(name) == 2: 221 | i = _name_map.get(name[0], np.inf) 222 | j = _name_map.get(name[1], np.inf) 223 | if i < self.n_parts and j < self.n_parts: 224 | return self.matrix(i, j) if ret_matrix else self[i, j] 225 | else: 226 | raise AttributeError(f"have {self.n_parts} parts, asked for {i}, {j}") 227 | 228 | return super().__getattr__(name) 229 | 230 | def _invalidate_cache(self, i): 231 | for k in list(self._cache.keys()): 232 | if ( 233 | i in k[1:] 234 | or any(isinstance(arg, tuple) and i in arg for arg in k[1:]) 235 | or k in [("joint",), ("joint_m",)] 236 | ): 237 | del self._cache[k] 238 | 239 | def drop_last_part(self): 240 | assert self.n_parts >= 2 241 | i = self.n_parts - 1 242 | self._invalidate_cache(i) 243 | del self._buffers[f"_part_{i}"] 244 | self.n_parts -= 1 245 | 246 | def change_part(self, i, new): 
247 | assert i < self.n_parts 248 | if new is not None and new.shape[1:] != self.X.shape[1:]: 249 | raise ValueError(f"X has shape {self.X.shape}, new entry has {new.shape}") 250 | self._invalidate_cache(i) 251 | self._buffers[f"_part_{i}"] = new 252 | 253 | def append_part(self, new): 254 | if new is not None and new.shape[1:] != self.X.shape[1:]: 255 | raise ValueError(f"X has shape {self.X.shape}, new entry has {new.shape}") 256 | self._buffers[f"_part_{self.n_parts}"] = new 257 | self.n_parts += 1 258 | 259 | def __copy__(self): 260 | """ 261 | Doesn't deep-copy the data tensors, but copies dictionaries so that 262 | change_part/etc don't affect the original. 263 | """ 264 | cls = self.__class__ 265 | result = cls.__new__(cls) 266 | to_copy = {"_cache", "_buffers", "_parameters", "_modules"} 267 | result.__dict__.update( 268 | {k: v.copy() if k in to_copy else v for k, v in self.__dict__.items()} 269 | ) 270 | return result 271 | 272 | def _apply(self, fn): # used in to(), cuda(), etc 273 | super()._apply(fn) 274 | for key, val in self._cache.items(): 275 | if val is not None: 276 | self._cache[key] = fn(val) 277 | return self 278 | 279 | def as_tensors(self, *args, **kwargs): 280 | "Helper that makes everything a tensor with self.X's type." 281 | kwargs.setdefault("device", self.X.device) 282 | kwargs.setdefault("dtype", self.X.dtype) 283 | return tuple(None if r is None else torch.as_tensor(r, **kwargs) for r in args) 284 | 285 | 286 | ################################################################################ 287 | # Matrix wrappers that cache sums / etc. Including various subclasses; see 288 | # as_matrix() to pick between them appropriately. 289 | 290 | # TODO: could support a matrix transpose that shares the cache appropriately 291 | 292 | 293 | class Matrix: 294 | def __init__(self, M, const_diagonal=False): 295 | self.mat = M = torch.as_tensor(M) 296 | self.m, self.n = self.shape = M.shape 297 | self._cache = {} 298 | 299 | @_cache 300 | def row_sums(self): 301 | return self.mat.sum(0) 302 | 303 | @_cache 304 | def col_sums(self): 305 | return self.mat.sum(1) 306 | 307 | @_cache 308 | def row_sums_sq_sum(self): 309 | sums = self.row_sums() 310 | return sums @ sums 311 | 312 | @_cache 313 | def col_sums_sq_sum(self): 314 | sums = self.col_sums() 315 | return sums @ sums 316 | 317 | @_cache 318 | def sum(self): 319 | if "row_sums" in self._cache: 320 | return self.row_sums().sum() 321 | elif "col_sums" in self._cache: 322 | return self.col_sums().sum() 323 | else: 324 | return self.mat.sum() 325 | 326 | def mean(self): 327 | return self.sum() / (self.m * self.n) 328 | 329 | @_cache 330 | def sq_sum(self): 331 | flat = self.mat.view(-1) 332 | return flat @ flat 333 | 334 | def __repr__(self): 335 | return f"<{type(self).__name__}, {self.m} by {self.n}>" 336 | 337 | 338 | class SquareMatrix(Matrix): 339 | def __init__(self, M): 340 | super().__init__(M) 341 | assert self.m == self.n 342 | 343 | @_cache 344 | def diagonal(self): 345 | return self.mat.diagonal() 346 | 347 | @_cache 348 | def trace(self): 349 | return self.mat.trace() 350 | 351 | @_cache 352 | def sq_trace(self): 353 | diag = self.diagonal() 354 | return diag @ diag 355 | 356 | @_cache 357 | def offdiag_row_sums(self): 358 | return self.row_sums() - self.diagonal() 359 | 360 | @_cache 361 | def offdiag_col_sums(self): 362 | return self.col_sums() - self.diagonal() 363 | 364 | @_cache 365 | def offdiag_row_sums_sq_sum(self): 366 | sums = self.offdiag_row_sums() 367 | return sums @ sums 368 | 369 | @_cache 370 | 
def offdiag_col_sums_sq_sum(self): 371 | sums = self.offdiag_col_sums() 372 | return sums @ sums 373 | 374 | @_cache 375 | def offdiag_sum(self): 376 | return self.offdiag_row_sums().sum() 377 | 378 | def offdiag_mean(self): 379 | return self.offdiag_sum() / (self.n * (self.n - 1)) 380 | 381 | @_cache 382 | def offdiag_sq_sum(self): 383 | return self.sq_sum() - self.sq_trace() 384 | 385 | 386 | class SymmetricMatrix(SquareMatrix): 387 | def col_sums(self): 388 | return self.row_sums() 389 | 390 | def sums(self): 391 | return self.row_sums() 392 | 393 | def offdiag_col_sums(self): 394 | return self.offdiag_row_sums() 395 | 396 | def offdiag_sums(self): 397 | return self.offdiag_row_sums() 398 | 399 | def col_sums_sq_sum(self): 400 | return self.row_sums_sq_sum() 401 | 402 | def sums_sq_sum(self): 403 | return self.row_sums_sq_sum() 404 | 405 | def offdiag_col_sums_sq_sum(self): 406 | return self.offdiag_row_sums_sq_sum() 407 | 408 | def offdiag_sums_sq_sum(self): 409 | return self.offdiag_row_sums_sq_sum() 410 | 411 | 412 | class ConstDiagMatrix(SquareMatrix): 413 | def __init__(self, M, diag_value): 414 | super().__init__(M) 415 | self.diag_value = diag_value 416 | 417 | @_cache 418 | def diagonal(self): 419 | return self.mat.new_full((1,), self.diag_value) 420 | 421 | def trace(self): 422 | return self.n * self.diag_value 423 | 424 | def sq_trace(self): 425 | return self.n * (self.diag_value ** 2) 426 | 427 | 428 | class SymmetricConstDiagMatrix(ConstDiagMatrix, SymmetricMatrix): 429 | pass 430 | 431 | 432 | def as_matrix(M, const_diagonal=False, symmetric=False): 433 | if symmetric: 434 | if const_diagonal is not False: 435 | return SymmetricConstDiagMatrix(M, diag_value=const_diagonal) 436 | else: 437 | return SymmetricMatrix(M) 438 | elif const_diagonal is not False: 439 | return ConstDiagMatrix(M, diag_value=const_diagonal) 440 | elif M.shape[0] == M.shape[1]: 441 | return SquareMatrix(M) 442 | else: 443 | return Matrix(M) 444 | -------------------------------------------------------------------------------- /kernels/probability_testing/support/mmd.py: -------------------------------------------------------------------------------- 1 | def mmd2_u_stat_variance(K, inds=(0, 1)): 2 | """ 3 | Estimate MMD variance with estimator from https://arxiv.org/abs/1906.02104. 4 | 5 | K should be a LazyKernel; we'll compare the parts in inds, 6 | default (0, 1) to use K.XX, K.XY, K.YY. 
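    Example (a sketch; `SomeKernel` stands in for any concrete LazyKernel
    subclass you have defined, and X, Y are [m, d] tensors with equal m):

        K = SomeKernel(X, Y)
        var_hat = mmd2_u_stat_variance(K)   # uses K.XX, K.XY, K.YY by default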
7 | """ 8 | i, j = inds 9 | 10 | m = K.n(i) 11 | assert K.n(j) == m 12 | 13 | XX = K.matrix(i, i) 14 | XY = K.matrix(i, j) 15 | YY = K.matrix(j, j) 16 | 17 | mm = m * m 18 | mmm = mm * m 19 | m1 = m - 1 20 | m1_m1 = m1 * m1 21 | m1_m1_m1 = m1_m1 * m1 22 | m2 = m - 2 23 | mdown2 = m * m1 24 | mdown3 = mdown2 * m2 25 | mdown4 = mdown3 * (m - 3) 26 | twom3 = 2 * m - 3 27 | 28 | return ( 29 | (4 / mdown4) * (XX.offdiag_sums_sq_sum() + YY.offdiag_sums_sq_sum()) 30 | + (4 * (mm - m - 1) / (mmm * m1_m1)) 31 | * (XY.row_sums_sq_sum() + XY.col_sums_sq_sum()) 32 | - (8 / (mm * (mm - 3 * m + 2))) 33 | * (XX.offdiag_sums() @ XY.col_sums() + YY.offdiag_sums() @ XY.row_sums()) 34 | + 8 / (mm * mdown3) * ((XX.offdiag_sum() + YY.offdiag_sum()) * XY.sum()) 35 | - (2 * twom3 / (mdown2 * mdown4)) * (XX.offdiag_sum() + YY.offdiag_sum()) 36 | - (4 * twom3 / (mmm * m1_m1_m1)) * XY.sum() ** 2 37 | - (2 / (m * (mmm - 6 * mm + 11 * m - 6))) 38 | * (XX.offdiag_sq_sum() + YY.offdiag_sq_sum()) 39 | + (4 * m2 / (mm * m1_m1_m1)) * XY.sq_sum() 40 | ) 41 | -------------------------------------------------------------------------------- /kernels/probability_testing/support/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | 4 | 5 | def as_tensors(X, *rest): 6 | "Calls as_tensor on a bunch of args, all of the first's device and dtype." 7 | X = torch.as_tensor(X) 8 | return [X] + [ 9 | None if r is None else torch.as_tensor(r, device=X.device, dtype=X.dtype) 10 | for r in rest 11 | ] 12 | 13 | 14 | def pil_grid(X, **kwargs): 15 | return torchvision.transforms.ToPILImage()(torchvision.utils.make_grid(X, **kwargs)) 16 | 17 | 18 | def maybe_squeeze(X, dim): 19 | "Like torch.squeeze, but don't crash if dim already doesn't exist." 20 | return torch.squeeze(X, dim) if dim < len(X.shape) else X 21 | -------------------------------------------------------------------------------- /kernels/setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | 3 | 4 | setup( 5 | name="probability_testing", 6 | version="0.1", 7 | include_package_data=True, 8 | packages=[ 9 | "probability_testing", "probability_testing.support", 10 | ] 11 | 12 | ) 13 | -------------------------------------------------------------------------------- /optimal_transport_tutorial/MANIFEST.in: -------------------------------------------------------------------------------- 1 | include data/*.pickle 2 | include data/*.jpg -------------------------------------------------------------------------------- /optimal_transport_tutorial/Opt_transport_1_Introduction_to_POT_and_S.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "colab_type": "text", 7 | "id": "KV6ChN5Nt4fj" 8 | }, 9 | "source": [ 10 | "# MLSS 2019: Optimal Transport and Wasserstein Distances" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "metadata": { 16 | "colab_type": "text", 17 | "id": "Zb-hmuept4fl" 18 | }, 19 | "source": [ 20 | "The goal of this first practical session is to introduce computational optimal transport (OT) in Python. You will familiarize yourself with OT by:\n", 21 | "1. using the Python library POT (Python Optimal Transport),\n", 22 | "2. coding Sinkhorn's algorithm.\n", 23 | "\n", 24 | "In the second practical session, we will use optimal transport as a nice geometrical tool in machine learning." 
25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "metadata": { 30 | "colab_type": "text", 31 | "id": "sbBHs3xPgCrm" 32 | }, 33 | "source": [ 34 | "We are going to use Google Collab to run this notebook. In order to install all the necessary files run the following cells:" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "metadata": { 41 | "colab": {}, 42 | "colab_type": "code", 43 | "id": "zdbamnHGu7Tw" 44 | }, 45 | "outputs": [], 46 | "source": [ 47 | "import os\n", 48 | "!pip install --upgrade git+https://github.com/mlss-skoltech/tutorials.git#subdirectory=optimal_transport_tutorial" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "metadata": { 55 | "colab": {}, 56 | "colab_type": "code", 57 | "id": "wGfIhryst4fm" 58 | }, 59 | "outputs": [], 60 | "source": [ 61 | "# Check your installation by importing POT\n", 62 | "!pip install pot\n", 63 | "import ot" 64 | ] 65 | }, 66 | { 67 | "cell_type": "markdown", 68 | "metadata": { 69 | "colab_type": "text", 70 | "id": "Ahr5RgE4gtgS" 71 | }, 72 | "source": [ 73 | "Declare ```DATA_PATH``` as a path to the data from the tutorial package" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "metadata": { 80 | "colab": {}, 81 | "colab_type": "code", 82 | "id": "SPIus7WThTpU" 83 | }, 84 | "outputs": [], 85 | "source": [ 86 | "import pkg_resources\n", 87 | "\n", 88 | "DATA_PATH = pkg_resources.resource_filename('optimaltransport', 'data/')" 89 | ] 90 | }, 91 | { 92 | "cell_type": "markdown", 93 | "metadata": { 94 | "colab_type": "text", 95 | "id": "rRk-69jhgCrt" 96 | }, 97 | "source": [ 98 | "If you are running this notebook locally, make sure to clone the tutorial repository:\n", 99 | "\n", 100 | "```!pip install --upgrade git+https://github.com/mlss-skoltech/tutorials.git#subdirectory=optimal_transport_tutorial```\n", 101 | "\n", 102 | "\n", 103 | "\n", 104 | "And install the following package:\n", 105 | "\n", 106 | "* Install with pip: ```bash pip install pot```\n", 107 | "* Install with conda: ```bash conda install -c conda-forge pot ```" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": null, 113 | "metadata": { 114 | "colab": {}, 115 | "colab_type": "code", 116 | "id": "T1BG78Dtt4fp" 117 | }, 118 | "outputs": [], 119 | "source": [ 120 | "import numpy as np\n", 121 | "import matplotlib.pyplot as plt" 122 | ] 123 | }, 124 | { 125 | "cell_type": "markdown", 126 | "metadata": { 127 | "colab_type": "text", 128 | "id": "_8RMdsSot4fr" 129 | }, 130 | "source": [ 131 | "## 1. Solving Exact OT: Linear Programming" 132 | ] 133 | }, 134 | { 135 | "cell_type": "markdown", 136 | "metadata": { 137 | "colab_type": "text", 138 | "id": "tzPSVXint4fr" 139 | }, 140 | "source": [ 141 | "### Reminders on Optimal Transport" 142 | ] 143 | }, 144 | { 145 | "cell_type": "markdown", 146 | "metadata": { 147 | "colab_type": "text", 148 | "id": "I67_BnZOt4fs" 149 | }, 150 | "source": [ 151 | "Optimal Transport is a theory that allows us to compare two (weighted) points clouds $(X, a)$ and $(Y, b)$, where $X \\in \\mathbb{R}^{n \\times d}$ and $Y \\in \\mathbb{R}^{m \\times d}$ are the locations of the $n$ (resp. $m$) points in dimension $d$, and $a \\in \\mathbb{R}^n$, $b \\in \\mathbb{R}^m$ are the weights. We ask that the total weights sum to one, i.e. $\\sum_{i=1}^n a_i = \\sum_{j=1}^m b_j = 1$." 
152 | ] 153 | }, 154 | { 155 | "cell_type": "markdown", 156 | "metadata": { 157 | "colab_type": "text", 158 | "id": "03yipBjGt4fs" 159 | }, 160 | "source": [ 161 | "The basic idea of Optimal Transport is to \"transport\" the mass located at points $X$ to the mass located at points $Y$.\n", 162 | "\n", 163 | "Let us denote by $\\mathcal{U}(a,b) = \\left\\{ P \\in \\mathbb{R}^{n \\times m} \\,|\\, P \\geq 0, \\sum_{j=1}^m P_{ij} = a_i, \\sum_{i=1}^n P_{ij} = b_j\\right\\}$ the set of admissible transport plans.\n", 164 | "\n", 165 | "If $P \\in \\mathcal{U}(a,b)$, the quantity $P_{ij} \\geq 0$ should be regarded as the mass transported from point $X_i$ to point $Y_j$. For this reason, it is called a *transport plan*.\n", 166 | "\n", 167 | "We will also consider a *cost matrix* $C \\in \\mathbb{R}^{n \\times m}$. The quantity $C_{ij}$ should be regarded as the cost paid for transporting one unit of mass from $X_i$ to $Y_j$. This cost is usually computed using the positions $X_i$ and $Y_j$, for example $C_{ij} = \\|X_i - Y_j\\|$ or $C_{ij} = \\|X_i - Y_j\\|^2$.\n", 168 | "\n", 169 | "Then transporting mass according to $P \\in \\mathcal{U}(a,b)$ has a total cost of $\\sum_{ij} P_{ij} C_{ij}$." 170 | ] 171 | }, 172 | { 173 | "cell_type": "markdown", 174 | "metadata": { 175 | "colab_type": "text", 176 | "id": "icSHTV5Ut4ft" 177 | }, 178 | "source": [ 179 | "In \"Optimal Transport\", there is the word _Optimal_. Indeed, we want to find a transport plan $P \\in \\mathcal{U}(a,b)$ that will minimize its total cost. In other words, we want to solve\n", 180 | "$$\n", 181 | " \\min_{P \\in \\mathcal{U}(a,b)} \\sum_{ij} C_{ij }P_{ij}.\n", 182 | "$$" 183 | ] 184 | }, 185 | { 186 | "cell_type": "markdown", 187 | "metadata": { 188 | "colab_type": "text", 189 | "id": "pSsukYhWt4fu" 190 | }, 191 | "source": [ 192 | "This problem is a Linear Program: the objective function is linear in the variable $P$, and the constraints are linear in $P$. We can thus solve this problem using classical Linear Programming algorithms, such as the simplex algorithm." 193 | ] 194 | }, 195 | { 196 | "cell_type": "markdown", 197 | "metadata": { 198 | "colab_type": "text", 199 | "id": "ckmQk8t9t4fv" 200 | }, 201 | "source": [ 202 | "If $P^*$ is a solution to the Optimal Transport problem, we will say that $P^*$ is an optimal transport plan between $(X, a)$ and $(Y, b)$, and that $\\sum_{ij} P^*_{ij} C_{ij}$ is the optimal transport distance between $(X, a)$ and $(Y, b)$: it is the minimal amount of \"energy\" that is necessary to transport the initial mass located at points $X$ to the target mass located at points $Y$." 203 | ] 204 | }, 205 | { 206 | "cell_type": "markdown", 207 | "metadata": { 208 | "colab_type": "text", 209 | "id": "fRt9uBnWt4fw" 210 | }, 211 | "source": [ 212 | "### Computing Optimal \"Croissant\" Transport using POT" 213 | ] 214 | }, 215 | { 216 | "cell_type": "markdown", 217 | "metadata": { 218 | "colab_type": "text", 219 | "id": "27EET8bAt4fw" 220 | }, 221 | "source": [ 222 | "We will solve the Bakery/Cafés problem of transporting croissants from a number of Bakeries to Cafés in Moscow.\n", 223 | "\n", 224 | "We use fictional positions, production and sale numbers (that both sum to the same value).\n", 225 | "\n", 226 | "We have acess to the position of Bakeries $X \\in \\mathbb{R}^{8 \\times 2}$ and their respective production $a \\in \\mathbb{R}^8$ which describe the source point cloud. 
The Cafés where the croissants are sold are defined by their position $Y \\in \\mathbb{R}^{5 \\times 2}$ and $b \\in \\mathbb{R}^{5}$." 227 | ] 228 | }, 229 | { 230 | "cell_type": "code", 231 | "execution_count": null, 232 | "metadata": { 233 | "colab": {}, 234 | "colab_type": "code", 235 | "id": "ACOcSFXqt4fx" 236 | }, 237 | "outputs": [], 238 | "source": [ 239 | "# Load the data\n", 240 | "import pickle\n", 241 | "\n", 242 | "with open(DATA_PATH + 'croissants.pickle', 'rb') as file:\n", 243 | " croissants = pickle.load(file)\n", 244 | "\n", 245 | "X = croissants['bakery_pos']\n", 246 | "a = croissants['bakery_prod']\n", 247 | "Y = croissants['cafe_pos']\n", 248 | "b = croissants['cafe_prod']\n", 249 | "\n", 250 | "print('Bakery productions =', a)\n", 251 | "print('Café sales =', b)\n", 252 | "print('Total number of croissants =', a.sum())" 253 | ] 254 | }, 255 | { 256 | "cell_type": "code", 257 | "execution_count": null, 258 | "metadata": { 259 | "colab": {}, 260 | "colab_type": "code", 261 | "id": "W2hrrI_Xt4fz" 262 | }, 263 | "outputs": [], 264 | "source": [ 265 | "plt.figure(figsize=(8,8))\n", 266 | "plt.scatter(X[:,0], X[:,1], s=10*a, c='r', edgecolors='k', label='Bakeries')\n", 267 | "plt.scatter(Y[:,0], Y[:,1], s=10*b, c='b', edgecolors='k', label='Cafés')\n", 268 | "plt.legend(fontsize=20)\n", 269 | "plt.axis('off')\n", 270 | "plt.title('Moscow Bakeries and Cafés', fontsize=25)\n", 271 | "plt.show()" 272 | ] 273 | }, 274 | { 275 | "cell_type": "markdown", 276 | "metadata": { 277 | "colab_type": "text", 278 | "id": "yAIYAif9t4f1" 279 | }, 280 | "source": [ 281 | "Let us now compute the cost matrix $C \\in \\mathbb{R}^{n \\times m}$. Here, we will use two different costs: $\\ell_1$ and $\\ell_2$ costs." 282 | ] 283 | }, 284 | { 285 | "cell_type": "code", 286 | "execution_count": null, 287 | "metadata": { 288 | "colab": {}, 289 | "colab_type": "code", 290 | "id": "77sLokEtt4f1" 291 | }, 292 | "outputs": [], 293 | "source": [ 294 | "C_1 = np.zeros((8,5)) # TODO: contains the l1 distances\n", 295 | "C_2 = np.zeros((8,5)) # TODO: contains the l2 distances" 296 | ] 297 | }, 298 | { 299 | "cell_type": "markdown", 300 | "metadata": { 301 | "colab_type": "text", 302 | "id": "Cp968Q4ht4f3" 303 | }, 304 | "source": [ 305 | "We can now compute the Optimal Transport plan to transport the croissants from the bakeries to the cafés, for the two different costs." 
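One way to fill in the cost-matrix and exact-OT cells that follow (a sketch rather than the official solution; it assumes the arrays `X`, `Y`, `a`, `b` loaded in the next cell and uses SciPy's `cdist` for the ground costs):

```python
import numpy as np
from scipy.spatial.distance import cdist
import ot

# ground costs between every bakery i and café j
C_1 = cdist(X, Y, metric='cityblock')    # l1 cost:  C_ij = ||X_i - Y_j||_1
C_2 = cdist(X, Y, metric='euclidean')    # l2 cost:  C_ij = ||X_i - Y_j||_2
# use metric='sqeuclidean' instead if you prefer squared l2 costs (W2^2 convention)

# exact OT plans via linear programming; a and b must have the same total mass
optimal_plan_1 = ot.emd(a, b, C_1)
optimal_plan_2 = ot.emd(a, b, C_2)

# optimal transport cost  <C, P> = sum_ij C_ij P_ij
optimal_cost_1 = np.sum(optimal_plan_1 * C_1)
optimal_cost_2 = np.sum(optimal_plan_2 * C_2)
```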
306 | ] 307 | }, 308 | { 309 | "cell_type": "code", 310 | "execution_count": null, 311 | "metadata": { 312 | "colab": {}, 313 | "colab_type": "code", 314 | "id": "maFZKc0-t4f5" 315 | }, 316 | "outputs": [], 317 | "source": [ 318 | "optimal_plan_1 = ot.emd() # TODO: compute the exact OT plan using function ot.emd\n", 319 | "print(optimal_plan_1)\n", 320 | "optimal_cost_1 = # TODO: compute the OT cost for the l1 ground cost\n", 321 | "print('1-Wasserstein distance =', optimal_cost_1)\n", 322 | "print('')\n", 323 | "\n", 324 | "optimal_plan_2 = ot.emd() # TODO: compute the exact OT plan using function ot.emd\n", 325 | "print(optimal_plan_2)\n", 326 | "optimal_cost_2 = # TODO: compute the OT cost for the l2 ground cost\n", 327 | "print('2-Wasserstein distance =', np.sqrt(optimal_cost_2))" 328 | ] 329 | }, 330 | { 331 | "cell_type": "code", 332 | "execution_count": null, 333 | "metadata": { 334 | "colab": {}, 335 | "colab_type": "code", 336 | "id": "IMYthtNvt4f6" 337 | }, 338 | "outputs": [], 339 | "source": [ 340 | "fig = plt.figure(figsize=(17,8))\n", 341 | "\n", 342 | "ax = fig.add_subplot(1, 2, 1)\n", 343 | "ax.scatter(X[:,0], X[:,1], s=10*a, c='r', edgecolors='k', label='Bakeries')\n", 344 | "ax.scatter(Y[:,0], Y[:,1], s=10*b, c='b', edgecolors='k', label='Cafés')\n", 345 | "# TODO: plot a line between Bakery i and Café j whenever some croissants are transported between i and j\n", 346 | "ax.axis('off')\n", 347 | "ax.set_title('$\\ell_1$ cost', fontsize=30)\n", 348 | "\n", 349 | "ax = fig.add_subplot(1, 2, 2)\n", 350 | "ax.scatter(X[:,0], X[:,1], s=10*a, c='r', edgecolors='k', label='Bakeries')\n", 351 | "ax.scatter(Y[:,0], Y[:,1], s=10*b, c='b', edgecolors='k', label='Cafés')\n", 352 | "# TODO: plot a line between Bakery i and Café j whenever some croissants are transported between i and j\n", 353 | "ax.axis('off')\n", 354 | "ax.set_title('$\\ell_2$ cost', fontsize=30)\n", 355 | "\n", 356 | "plt.legend(fontsize=20)\n", 357 | "plt.show()" 358 | ] 359 | }, 360 | { 361 | "cell_type": "markdown", 362 | "metadata": { 363 | "colab_type": "text", 364 | "id": "eWnTSM2-t4f9" 365 | }, 366 | "source": [ 367 | "## 2. Sinkhorn Algorithm for Entropy Regularized Optimal Transport" 368 | ] 369 | }, 370 | { 371 | "cell_type": "markdown", 372 | "metadata": { 373 | "colab_type": "text", 374 | "id": "OS6IMu8Bt4f-" 375 | }, 376 | "source": [ 377 | "### Reminders on Sinkhorn Algorithm" 378 | ] 379 | }, 380 | { 381 | "cell_type": "markdown", 382 | "metadata": { 383 | "colab_type": "text", 384 | "id": "eQncY7VGt4f-" 385 | }, 386 | "source": [ 387 | "In real applications, and especially in Machine Learning, we often have to deal with huge numbers of points. In this case, the linear programming algorithms which have cubic complexity will take too much time to run.\n", 388 | "\n", 389 | "That's why in practise, among other reasons, people minimize another criterion given by\n", 390 | "$$\n", 391 | " \\min_{P \\in \\mathcal{U}(a,b)} \\langle C, P \\rangle + \\epsilon \\sum_{ij} P_{ij} [ \\log(P_{ij}) - 1].\n", 392 | "$$\n", 393 | "When $\\epsilon$ is sufficiently small, we can consider that a solution to the above problem (often refered to as \"Entropy-regularized Optimal Transport\") is a good approximation of a real optimal transport plan." 
394 | ] 395 | }, 396 | { 397 | "cell_type": "markdown", 398 | "metadata": { 399 | "colab_type": "text", 400 | "id": "7L05M4snt4f_" 401 | }, 402 | "source": [ 403 | "In order to solve this problem, one can remark that the optimality conditions imply that a solution $P_\\epsilon^*$ necessarily is of the form $P_\\epsilon^* = \\text{diag}(u) \\, K \\, \\text{diag}(v)$, where $K = \\exp(-C/\\epsilon)$ and $u,v$ are two non-negative vectors." 404 | ] 405 | }, 406 | { 407 | "cell_type": "markdown", 408 | "metadata": { 409 | "colab_type": "text", 410 | "id": "4NxnPr1-t4gA" 411 | }, 412 | "source": [ 413 | "$P_\\epsilon^*$ should verify the constraints, i.e. $P_\\epsilon^* \\in \\mathcal{U}(a,b)$, so that\n", 414 | "$$\n", 415 | " P_\\epsilon^* 1_m = a \\text{ and } (P_\\epsilon^*)^T 1_n = b\n", 416 | "$$\n", 417 | "which can be rewritten as\n", 418 | "$$\n", 419 | " u \\odot (Kv) = a \\text{ and } v \\odot (K^T u) = b\n", 420 | "$$\n", 421 | "\n", 422 | "Then Sinkhorn's algorithm alternate between the resolution of these two equations, and reads\n", 423 | "$$\n", 424 | " u \\leftarrow \\frac{a}{Kv} \\text{ and } v \\leftarrow \\frac{b}{K^T u}\n", 425 | "$$" 426 | ] 427 | }, 428 | { 429 | "cell_type": "code", 430 | "execution_count": null, 431 | "metadata": { 432 | "colab": {}, 433 | "colab_type": "code", 434 | "id": "a4A4wbg3t4gA" 435 | }, 436 | "outputs": [], 437 | "source": [ 438 | "def sinkhorn(a, b, C, epsilon=0.1, max_iters=100):\n", 439 | " \"\"\"Run Sinnkhorn's algorithm\"\"\"\n", 440 | " \n", 441 | " # TODO: Compute the kernel matrix K\n", 442 | " \n", 443 | " # TODO: Alternate projections\n", 444 | " \n", 445 | " return # TODO" 446 | ] 447 | }, 448 | { 449 | "cell_type": "code", 450 | "execution_count": null, 451 | "metadata": { 452 | "colab": {}, 453 | "colab_type": "code", 454 | "id": "8M5FuBFnt4gC" 455 | }, 456 | "outputs": [], 457 | "source": [ 458 | "np.round(sinkhorn(a, b, C_2/C_2.max(), epsilon=0.01), 2)" 459 | ] 460 | }, 461 | { 462 | "cell_type": "code", 463 | "execution_count": null, 464 | "metadata": { 465 | "colab": {}, 466 | "colab_type": "code", 467 | "id": "inUdN4eIt4gD" 468 | }, 469 | "outputs": [], 470 | "source": [ 471 | "optimal_plan_2" 472 | ] 473 | }, 474 | { 475 | "cell_type": "markdown", 476 | "metadata": { 477 | "colab_type": "text", 478 | "id": "mhQOXGjEt4gF" 479 | }, 480 | "source": [ 481 | "We first show that this algorithm is consistent with classical optimal transport, using the \"croissant\" transport example." 
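A minimal way to complete the `sinkhorn` stub above, directly transcribing the two update rules (a sketch only: no convergence test and no log-domain stabilization):

```python
import numpy as np

def sinkhorn(a, b, C, epsilon=0.1, max_iters=100):
    """Entropy-regularized OT via Sinkhorn's matrix-scaling iterations."""
    K = np.exp(-C / epsilon)                 # Gibbs kernel
    u = np.ones(len(a))
    v = np.ones(len(b))
    for _ in range(max_iters):
        u = a / (K @ v)                      # match the row marginals a
        v = b / (K.T @ u)                    # match the column marginals b
    return u[:, None] * K * v[None, :]       # P = diag(u) K diag(v)
```

Dividing the cost by `C.max()` before calling it, as the next cell does, keeps `exp(-C/epsilon)` from underflowing when `epsilon` is small.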
482 | ] 483 | }, 484 | { 485 | "cell_type": "code", 486 | "execution_count": null, 487 | "metadata": { 488 | "colab": {}, 489 | "colab_type": "code", 490 | "id": "mWJnrT_Ot4gG" 491 | }, 492 | "outputs": [], 493 | "source": [ 494 | "plan_diff = []\n", 495 | "distance_diff = []\n", 496 | "for epsilon in np.linspace(0.01, 1, 100):\n", 497 | " optimal_plan_sinkhorn = # TODO: compute OT plan using Sinkhorn, with regularization strength epsilon\n", 498 | " optimal_cost_sinkhorn = # TODO: compute OT distance using Sinkhorn\n", 499 | " plan_diff.append() # TODO: compute the Frobenius distance between the exact OT plan and the Sinkhorn OT plan\n", 500 | " distance_diff.append() # TODO: compute the error between exact OT and Sinkhorn values (in %)" 501 | ] 502 | }, 503 | { 504 | "cell_type": "code", 505 | "execution_count": null, 506 | "metadata": { 507 | "colab": {}, 508 | "colab_type": "code", 509 | "id": "baCU3-hMt4gH" 510 | }, 511 | "outputs": [], 512 | "source": [ 513 | "plt.figure(figsize=(16,5))\n", 514 | "plt.loglog(np.linspace(0.01, 1, 100), plan_diff, lw=4)\n", 515 | "plt.xlabel('Regularization Strength $\\epsilon$', fontsize=25)\n", 516 | "plt.ylabel('$||P^* - P_\\epsilon^*||_F$', fontsize=25)\n", 517 | "plt.xticks(fontsize=20)\n", 518 | "plt.yticks(fontsize=20)\n", 519 | "plt.grid(ls='--')\n", 520 | "plt.show()" 521 | ] 522 | }, 523 | { 524 | "cell_type": "code", 525 | "execution_count": null, 526 | "metadata": { 527 | "colab": {}, 528 | "colab_type": "code", 529 | "id": "H7pedasSt4gJ" 530 | }, 531 | "outputs": [], 532 | "source": [ 533 | "plt.figure(figsize=(16,5))\n", 534 | "plt.loglog(np.linspace(0.01, 1, 100), distance_diff, lw=4)\n", 535 | "plt.xlabel('Regularization Strength $\\epsilon$', fontsize=25)\n", 536 | "plt.ylabel('Error in %', fontsize=25)\n", 537 | "plt.xticks(fontsize=20)\n", 538 | "plt.yticks(fontsize=20)\n", 539 | "plt.grid(ls='--')\n", 540 | "plt.show()" 541 | ] 542 | }, 543 | { 544 | "cell_type": "markdown", 545 | "metadata": { 546 | "colab_type": "text", 547 | "id": "2fwnkgjpt4gL" 548 | }, 549 | "source": [ 550 | "Let us now compare the running time for sinkhorn and classical optimal transport algorithm on more data." 
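Before timing anything, here is one way the ε-sweep cells above could be filled in (a sketch; it reuses `sinkhorn`, the croissant arrays `a`, `b`, `C_2`, and the exact results `optimal_plan_2`, `optimal_cost_2` computed earlier, normalizing the cost as before so the kernel stays well-behaved):

```python
plan_diff = []
distance_diff = []
C_n = C_2 / C_2.max()   # normalized cost, as in the earlier Sinkhorn call

for epsilon in np.linspace(0.01, 1, 100):
    optimal_plan_sinkhorn = sinkhorn(a, b, C_n, epsilon=epsilon)
    optimal_cost_sinkhorn = np.sum(optimal_plan_sinkhorn * C_2)   # cost w.r.t. the original C_2
    plan_diff.append(np.linalg.norm(optimal_plan_sinkhorn - optimal_plan_2))            # Frobenius distance
    distance_diff.append(100 * abs(optimal_cost_sinkhorn - optimal_cost_2) / optimal_cost_2)  # error in %
```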
551 | ] 552 | }, 553 | { 554 | "cell_type": "code", 555 | "execution_count": null, 556 | "metadata": { 557 | "colab": {}, 558 | "colab_type": "code", 559 | "id": "hOfRY9Wut4gM" 560 | }, 561 | "outputs": [], 562 | "source": [ 563 | "n = 1000\n", 564 | "m = 1000\n", 565 | "d = 2\n", 566 | "\n", 567 | "X = np.random.randn(n,d)\n", 568 | "Y = np.random.randn(m,d)\n", 569 | "\n", 570 | "a = np.ones(n)\n", 571 | "b = np.ones(m)\n", 572 | "\n", 573 | "C = np.zeros((n,m))\n", 574 | "# TODO: compute the cost matrix (using l2 ground distance)" 575 | ] 576 | }, 577 | { 578 | "cell_type": "markdown", 579 | "metadata": { 580 | "colab_type": "text", 581 | "id": "7jgwaUChtPAf" 582 | }, 583 | "source": [ 584 | "Because of Google Colab set up the time measuring can be unreliable, in order to get more certain results try running the code locally" 585 | ] 586 | }, 587 | { 588 | "cell_type": "code", 589 | "execution_count": null, 590 | "metadata": { 591 | "colab": { 592 | "base_uri": "https://localhost:8080/", 593 | "height": 170 594 | }, 595 | "colab_type": "code", 596 | "id": "SiRL17XYt4gP", 597 | "outputId": "e16c4fc2-1e3b-4c63-ca57-cfc1afac0613" 598 | }, 599 | "outputs": [], 600 | "source": [ 601 | "%time ot.emd(a,b,C)" 602 | ] 603 | }, 604 | { 605 | "cell_type": "code", 606 | "execution_count": null, 607 | "metadata": { 608 | "colab": { 609 | "base_uri": "https://localhost:8080/", 610 | "height": 272 611 | }, 612 | "colab_type": "code", 613 | "id": "uyt7TZBzpEDC", 614 | "outputId": "03a76a69-2668-400a-aab4-741f03ae0a0b" 615 | }, 616 | "outputs": [], 617 | "source": [ 618 | "%time sinkhorn(a,b,C)" 619 | ] 620 | }, 621 | { 622 | "cell_type": "markdown", 623 | "metadata": { 624 | "colab_type": "text", 625 | "id": "OTbGsBvPt4gU" 626 | }, 627 | "source": [ 628 | "We see that sinkhorn is faster. What is even more interesting is that sinkhorn can be parallelerized on GPUs, giving further acceleration. Of course, Sinkhorn algorithm is not computing the exact optimal transport plan any more." 629 | ] 630 | }, 631 | { 632 | "cell_type": "markdown", 633 | "metadata": {}, 634 | "source": [ 635 | "## 3. Optimal Transport in Dimension 1\n", 636 | "\n", 637 | "In dimension $d=1$, computing OT boils down to sorting the points. You will check this fact, and discuss the influence of the regularization strength $\\epsilon$." 638 | ] 639 | }, 640 | { 641 | "cell_type": "code", 642 | "execution_count": null, 643 | "metadata": {}, 644 | "outputs": [], 645 | "source": [ 646 | "n = 4\n", 647 | "m = 4\n", 648 | "\n", 649 | "X = np.random.uniform(size=n)\n", 650 | "Y = np.random.uniform(size=m)\n", 651 | "\n", 652 | "a = np.ones(n)\n", 653 | "b = np.ones(m)\n", 654 | "\n", 655 | "plt.figure(figsize=(17,4))\n", 656 | "plt.scatter(X, np.zeros(n), s=200*a, c='r')\n", 657 | "plt.scatter(Y, np.zeros(m), s=200*b, c='b')\n", 658 | "for i in range(n):\n", 659 | " plt.gca().annotate(str(i+1), xy=(X[i],0.005), size=30, color='r', ha='center')\n", 660 | "for j in range(m):\n", 661 | " plt.gca().annotate(str(j+1), xy=(Y[j],0.005), size=30, color='b', ha='center')\n", 662 | "plt.axis('off')\n", 663 | "plt.show()" 664 | ] 665 | }, 666 | { 667 | "cell_type": "code", 668 | "execution_count": null, 669 | "metadata": {}, 670 | "outputs": [], 671 | "source": [ 672 | "# TODO: Compute the OT plan using sorting, POT, and Sinkhorn. Discuss the results and the running times." 
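A possible completion of the TODO above (a sketch, not the official solution): with uniform weights in one dimension, the optimal plan simply matches the i-th smallest x to the i-th smallest y, which can be checked against `ot.emd` and the `sinkhorn` function defined earlier.

```python
C = (X[:, None] - Y[None, :]) ** 2           # squared-distance cost, shape (n, m)

P_sort = np.zeros((n, m))
P_sort[np.argsort(X), np.argsort(Y)] = 1.0   # permutation plan obtained by sorting

P_emd = ot.emd(a, b, C)                      # exact LP solution
P_sink = sinkhorn(a, b, C / C.max(), epsilon=0.01)

print(np.allclose(P_sort, P_emd))            # sorting recovers the exact plan
print(np.round(P_sink, 2))                   # entropic plan: a blurred permutation
```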
673 | ] 674 | } 675 | ], 676 | "metadata": { 677 | "accelerator": "GPU", 678 | "colab": { 679 | "name": "MLSS 1 Introduction to POT and Sinkhorn Algorithm (student version).ipynb", 680 | "provenance": [], 681 | "version": "0.3.2" 682 | }, 683 | "kernelspec": { 684 | "display_name": "Python 3", 685 | "language": "python", 686 | "name": "python3" 687 | }, 688 | "language_info": { 689 | "codemirror_mode": { 690 | "name": "ipython", 691 | "version": 3 692 | }, 693 | "file_extension": ".py", 694 | "mimetype": "text/x-python", 695 | "name": "python", 696 | "nbconvert_exporter": "python", 697 | "pygments_lexer": "ipython3", 698 | "version": "3.7.1" 699 | } 700 | }, 701 | "nbformat": 4, 702 | "nbformat_minor": 1 703 | } 704 | -------------------------------------------------------------------------------- /optimal_transport_tutorial/Opt_transport_2_Optimal_Transport_for_Mac.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "colab_type": "text", 7 | "id": "57KL8eDKwAkr" 8 | }, 9 | "source": [ 10 | "# MLSS 2019: Optimal Transport for Machine Learning" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "metadata": { 16 | "colab_type": "text", 17 | "id": "axBioaovwAkv" 18 | }, 19 | "source": [ 20 | "In this second practical session, we will apply OT in two different Machine Learning applications:\n", 21 | "1. Color Transfer\n", 22 | "2. Document Clustering\n", 23 | "\n", 24 | "In Color Transfer, we will mainly be interested in the optimal transport plan itself, while in Document Clustering, we will be interested in the value of the Optimal Transport / Wasserstein distance." 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "metadata": { 30 | "colab_type": "text", 31 | "id": "oCUSMwp1wMZQ" 32 | }, 33 | "source": [ 34 | "We are going to use Google Collab to run this notebook. 
In order to install all the necessary files run the following cells:" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "metadata": { 41 | "colab": {}, 42 | "colab_type": "code", 43 | "id": "33pD5ciywNmm" 44 | }, 45 | "outputs": [], 46 | "source": [ 47 | "import os\n", 48 | "!pip install --upgrade git+https://github.com/mlss-skoltech/tutorials.git#subdirectory=optimal_transport_tutorial" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "metadata": { 55 | "colab": {}, 56 | "colab_type": "code", 57 | "id": "u7bJKDk6wyam" 58 | }, 59 | "outputs": [], 60 | "source": [ 61 | "# Check your installation by importing POT\n", 62 | "!pip install pot\n", 63 | "import ot" 64 | ] 65 | }, 66 | { 67 | "cell_type": "markdown", 68 | "metadata": { 69 | "colab_type": "text", 70 | "id": "jE5dwh5jw4mK" 71 | }, 72 | "source": [ 73 | "Declare ```DATA_PATH``` as a path to the data from the tutorial package" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "metadata": { 80 | "colab": {}, 81 | "colab_type": "code", 82 | "id": "h_urSV_Nw7gz" 83 | }, 84 | "outputs": [], 85 | "source": [ 86 | "import pkg_resources\n", 87 | "\n", 88 | "DATA_PATH = pkg_resources.resource_filename('optimaltransport', 'data/')" 89 | ] 90 | }, 91 | { 92 | "cell_type": "markdown", 93 | "metadata": { 94 | "colab_type": "text", 95 | "id": "xdjSuC2JxA0g" 96 | }, 97 | "source": [ 98 | "If you are running this notebook locally, make sure to clone the tutorial repository:\n", 99 | "\n", 100 | "```!pip install --upgrade git+https://github.com/mlss-skoltech/tutorials.git#subdirectory=optimal_transport_tutorial```\n", 101 | "\n", 102 | "\n", 103 | "\n", 104 | "And install the following package:\n", 105 | "\n", 106 | "* Install with pip: ```bash pip install pot```\n", 107 | "* Install with conda: ```bash conda install -c conda-forge pot ```" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": null, 113 | "metadata": { 114 | "colab": {}, 115 | "colab_type": "code", 116 | "id": "y9eqVdkrwAkw" 117 | }, 118 | "outputs": [], 119 | "source": [ 120 | "import numpy as np\n", 121 | "import ot" 122 | ] 123 | }, 124 | { 125 | "cell_type": "markdown", 126 | "metadata": { 127 | "colab_type": "text", 128 | "id": "d4TDmWaawAk0" 129 | }, 130 | "source": [ 131 | "## 1. Color Transfer" 132 | ] 133 | }, 134 | { 135 | "cell_type": "markdown", 136 | "metadata": { 137 | "colab_type": "text", 138 | "id": "KsTBGYN3wAk0" 139 | }, 140 | "source": [ 141 | "Given a source and a target image, the goal of color transfer is to transform the colors of the source image so that it looks similar to the target image color palette. In the end, we want to find a \"color mapping\", giving for each color of the source image a new color. This can be done by computing the optimal transport plan between the two images, seen as point clouds in the RGB space." 
142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": null, 147 | "metadata": { 148 | "colab": {}, 149 | "colab_type": "code", 150 | "id": "VzfbrFDFwAk1" 151 | }, 152 | "outputs": [], 153 | "source": [ 154 | "# For plotting\n", 155 | "import matplotlib.pyplot as plt\n", 156 | "from matplotlib.pyplot import imread\n", 157 | "from mpl_toolkits.mplot3d import Axes3D" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": null, 163 | "metadata": { 164 | "colab": {}, 165 | "colab_type": "code", 166 | "id": "13BAZ2KpwAk8" 167 | }, 168 | "outputs": [], 169 | "source": [ 170 | "# Load the images\n", 171 | "I1 = imread(DATA_PATH + 'schiele.jpg').astype(np.float64) / 256\n", 172 | "I2 = imread(DATA_PATH + 'schiele2.jpg').astype(np.float64) / 256\n", 173 | "\n", 174 | "fig = plt.figure(figsize=(17, 30))\n", 175 | "\n", 176 | "ax = fig.add_subplot(1, 2, 1)\n", 177 | "ax.imshow(I1)\n", 178 | "ax.set_title('Landscape', fontsize=25)\n", 179 | "ax.axis('off')\n", 180 | "\n", 181 | "ax = fig.add_subplot(1, 2, 2)\n", 182 | "ax.imshow(I2)\n", 183 | "ax.set_title('Portrait', fontsize=25)\n", 184 | "ax.axis('off')\n", 185 | "\n", 186 | "plt.show()" 187 | ] 188 | }, 189 | { 190 | "cell_type": "markdown", 191 | "metadata": { 192 | "colab_type": "text", 193 | "id": "pSPpFriLwAlA" 194 | }, 195 | "source": [ 196 | "We will need to work with \"matrices\" instead of images. Since there are 3 colors, images have shape `(Width, Height, 3)`, and the corresponding matrices will have shape `(Width*Height, 3)`." 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": null, 202 | "metadata": { 203 | "colab": {}, 204 | "colab_type": "code", 205 | "id": "emal4z2EwAlB" 206 | }, 207 | "outputs": [], 208 | "source": [ 209 | "def im2mat(I):\n", 210 | " '''Convert image I to matrix.'''\n", 211 | " return # TODO: reshape\n", 212 | "\n", 213 | "def mat2im(X, shape):\n", 214 | " '''Convert matrix X to image with shape 'shape'.'''\n", 215 | " return # TODO: reshape\n", 216 | "\n", 217 | "X1 = im2mat(I1)\n", 218 | "X2 = im2mat(I2)" 219 | ] 220 | }, 221 | { 222 | "cell_type": "markdown", 223 | "metadata": { 224 | "colab_type": "text", 225 | "id": "KprLOWg7wAlE" 226 | }, 227 | "source": [ 228 | "Real images have way too many different colors, so we will need to subsample them. In order to do this, we use K-means over all the colors, and keep only the computed centroids. Note that using Mini Batch K-Means will speed the computations up." 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": null, 234 | "metadata": { 235 | "colab": {}, 236 | "colab_type": "code", 237 | "id": "ypyN5TMjwAlF" 238 | }, 239 | "outputs": [], 240 | "source": [ 241 | "from sklearn.cluster import MiniBatchKMeans\n", 242 | "\n", 243 | "# Size of the subsampled point clouds\n", 244 | "nbsamples = 1000\n", 245 | "\n", 246 | "kmeans1 = # TODO: Mini Batch K-Means for X1\n", 247 | "X1_sampled = # TODO: get the centroids\n", 248 | "\n", 249 | "kmeans2 = # TODO: Mini Batch K-Means for X2\n", 250 | "X2_sampled = # TODO: get the centroids" 251 | ] 252 | }, 253 | { 254 | "cell_type": "markdown", 255 | "metadata": { 256 | "colab_type": "text", 257 | "id": "xYguNfDhwAlL" 258 | }, 259 | "source": [ 260 | "Each image is represented by its \"matrix\", i.e. is seen as a point cloud $X \\in \\mathbb{R}^{N\\times3}$ in the RGB color space, identified with $\\mathbb{R}^3$. 
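A possible completion of the reshaping and subsampling cells above (a sketch; it keeps the variable names `X1`, `X2`, `nbsamples`, `kmeans1`, `kmeans2` used in those cells):

```python
from sklearn.cluster import MiniBatchKMeans

def im2mat(I):
    '''Convert image I of shape (W, H, 3) to a matrix of shape (W*H, 3).'''
    return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))

def mat2im(X, shape):
    '''Convert matrix X back to an image with the given shape.'''
    return X.reshape(shape)

X1 = im2mat(I1)
X2 = im2mat(I2)

# subsample each color cloud to `nbsamples` representative colors
kmeans1 = MiniBatchKMeans(n_clusters=nbsamples).fit(X1)
X1_sampled = kmeans1.cluster_centers_

kmeans2 = MiniBatchKMeans(n_clusters=nbsamples).fit(X2)
X2_sampled = kmeans2.cluster_centers_
```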
" 261 | ] 262 | }, 263 | { 264 | "cell_type": "code", 265 | "execution_count": null, 266 | "metadata": { 267 | "colab": {}, 268 | "colab_type": "code", 269 | "id": "jHszq7CPwAlM" 270 | }, 271 | "outputs": [], 272 | "source": [ 273 | "def showImageAsPointCloud(X, Y):\n", 274 | " '''Show the color palette associated with images X and Y.'''\n", 275 | " fig = plt.figure(figsize=(17,8))\n", 276 | " ax = fig.add_subplot(121, projection='3d')\n", 277 | " ax.set_xlim(0,1)\n", 278 | " ax.scatter(X[:,0], X[:,1], X[:,2], c=X, s=10, marker='o', alpha=0.6)\n", 279 | " ax.set_xlabel('R',fontsize=22)\n", 280 | " ax.set_xticklabels([])\n", 281 | " ax.set_ylim(0,1)\n", 282 | " ax.set_ylabel('G',fontsize=22)\n", 283 | " ax.set_yticklabels([])\n", 284 | " ax.set_zlim(0,1)\n", 285 | " ax.set_zlabel('B',fontsize=22)\n", 286 | " ax.set_zticklabels([])\n", 287 | " ax.set_title('Landscape Color Palette', fontsize=20)\n", 288 | " ax.grid('off')\n", 289 | " \n", 290 | " ax = fig.add_subplot(122, projection='3d')\n", 291 | " ax.set_xlim(0,1)\n", 292 | " ax.scatter(Y[:,0], Y[:,1], Y[:,2], c=Y, s=10, marker='o', alpha=0.6)\n", 293 | " ax.set_xlabel('R',fontsize=22)\n", 294 | " ax.set_xticklabels([])\n", 295 | " ax.set_ylim(0,1)\n", 296 | " ax.set_ylabel('G',fontsize=22)\n", 297 | " ax.set_yticklabels([])\n", 298 | " ax.set_zlim(0,1)\n", 299 | " ax.set_zlabel('B',fontsize=22)\n", 300 | " ax.set_zticklabels([])\n", 301 | " ax.set_title('Portrait Color Palette', fontsize=20)\n", 302 | " ax.grid('off')\n", 303 | " \n", 304 | " plt.show()" 305 | ] 306 | }, 307 | { 308 | "cell_type": "code", 309 | "execution_count": null, 310 | "metadata": { 311 | "colab": {}, 312 | "colab_type": "code", 313 | "id": "d4r_n9j5wAlO" 314 | }, 315 | "outputs": [], 316 | "source": [ 317 | "showImageAsPointCloud(X1_sampled, X2_sampled)" 318 | ] 319 | }, 320 | { 321 | "cell_type": "markdown", 322 | "metadata": { 323 | "colab_type": "text", 324 | "id": "bWtBge5KwAlR" 325 | }, 326 | "source": [ 327 | "In order to compute the optimal transport plans between the two point clouds, we have to compute the corresponding cost matrix. In the following, we will always consider the squared distance, _i.e._ $C_{ij} = \\|X_i - Y_j\\|^2$." 328 | ] 329 | }, 330 | { 331 | "cell_type": "code", 332 | "execution_count": null, 333 | "metadata": { 334 | "colab": {}, 335 | "colab_type": "code", 336 | "id": "AV-S1C6-wAlR" 337 | }, 338 | "outputs": [], 339 | "source": [ 340 | "C = # TODO: compute the cost matrix using l2 ground distance" 341 | ] 342 | }, 343 | { 344 | "cell_type": "markdown", 345 | "metadata": { 346 | "colab_type": "text", 347 | "id": "hg6CSfzBwAlV" 348 | }, 349 | "source": [ 350 | "### Landscape with Portrait colors" 351 | ] 352 | }, 353 | { 354 | "cell_type": "markdown", 355 | "metadata": { 356 | "colab_type": "text", 357 | "id": "wexKcWriwAlV" 358 | }, 359 | "source": [ 360 | "Here, the goal is to transfer the colors of the portrait to the landscape. We will compute the exact Optimal Transport Plan, as well as the Entropy Regularized Optimal Transport plans." 
361 | ] 362 | }, 363 | { 364 | "cell_type": "code", 365 | "execution_count": null, 366 | "metadata": { 367 | "colab": {}, 368 | "colab_type": "code", 369 | "id": "LrILq63TwAlW" 370 | }, 371 | "outputs": [], 372 | "source": [ 373 | "regs = [0.01, 0.1, 0.5]\n", 374 | "OT_plans = [] # Contains the OT plans for regularization strengths : 0, 0.01, 0.1, 0.5\n", 375 | "OT_plans.append(ot.emd()) # TODO: OT plan for exact OT\n", 376 | "for reg in regs:\n", 377 | " OT_plans.append() # TODO: OT plan for regularization strength = reg" 378 | ] 379 | }, 380 | { 381 | "cell_type": "code", 382 | "execution_count": null, 383 | "metadata": { 384 | "colab": {}, 385 | "colab_type": "code", 386 | "id": "VSLLPM4AwAld" 387 | }, 388 | "outputs": [], 389 | "source": [ 390 | "def colorTransfer(OT_plan, kmeans1, kmeans2, shape):\n", 391 | " '''Return the color-transfered image of shape \"shape\".'''\n", 392 | " return # TODO" 393 | ] 394 | }, 395 | { 396 | "cell_type": "code", 397 | "execution_count": null, 398 | "metadata": { 399 | "colab": {}, 400 | "colab_type": "code", 401 | "id": "IUJHdGJKwAlh" 402 | }, 403 | "outputs": [], 404 | "source": [ 405 | "fig = plt.figure(figsize=(17, 20))\n", 406 | "\n", 407 | "ax = fig.add_subplot(1, 2, 1)\n", 408 | "ax.imshow(I1)\n", 409 | "ax.set_title('Source Image', fontsize=20)\n", 410 | "ax.axis('off')\n", 411 | "\n", 412 | "ax = fig.add_subplot(1, 2, 2)\n", 413 | "I = colorTransfer(OT_plans[0], kmeans1, kmeans2, I1.shape)\n", 414 | "ax.imshow(I)\n", 415 | "ax.set_title('Reg = 0', fontsize=20)\n", 416 | "ax.axis('off')\n", 417 | "\n", 418 | "plt.show()\n", 419 | "\n", 420 | "fig = plt.figure(figsize=(17, 20))\n", 421 | "for i in range(3):\n", 422 | " ax = fig.add_subplot(2, 3, i+1)\n", 423 | " I = colorTransfer(OT_plans[i+1], kmeans1, kmeans2, I1.shape)\n", 424 | " ax.imshow(I)\n", 425 | " ax.set_title('Reg = '+str(regs[i]), fontsize=20)\n", 426 | " ax.axis('off')\n", 427 | "\n", 428 | "plt.show()" 429 | ] 430 | }, 431 | { 432 | "cell_type": "markdown", 433 | "metadata": { 434 | "colab_type": "text", 435 | "id": "zR5JREbywAlk" 436 | }, 437 | "source": [ 438 | "### Portait with Landscape colors" 439 | ] 440 | }, 441 | { 442 | "cell_type": "markdown", 443 | "metadata": { 444 | "colab_type": "text", 445 | "id": "vEuyDVKswAll" 446 | }, 447 | "source": [ 448 | "We now transfer the colors of the landscape to the portrait." 
449 | ] 450 | }, 451 | { 452 | "cell_type": "code", 453 | "execution_count": null, 454 | "metadata": { 455 | "colab": {}, 456 | "colab_type": "code", 457 | "id": "4y4w6yoYwAlm" 458 | }, 459 | "outputs": [], 460 | "source": [ 461 | "C = # TODO" 462 | ] 463 | }, 464 | { 465 | "cell_type": "code", 466 | "execution_count": null, 467 | "metadata": { 468 | "colab": {}, 469 | "colab_type": "code", 470 | "id": "hkKbTE74wAlr" 471 | }, 472 | "outputs": [], 473 | "source": [ 474 | "regs = [0.01, 0.03, 0.1]\n", 475 | "OT_plans = []\n", 476 | "OT_plans.append(ot.emd()) # TODO\n", 477 | "for reg in regs:\n", 478 | " OT_plans.append() # TODO" 479 | ] 480 | }, 481 | { 482 | "cell_type": "code", 483 | "execution_count": null, 484 | "metadata": { 485 | "colab": {}, 486 | "colab_type": "code", 487 | "id": "wrM8E7_4wAlv" 488 | }, 489 | "outputs": [], 490 | "source": [ 491 | "def colorTransfer(OT_plan, kmeans1, kmeans2, shape):\n", 492 | " return # TODO" 493 | ] 494 | }, 495 | { 496 | "cell_type": "code", 497 | "execution_count": null, 498 | "metadata": { 499 | "colab": {}, 500 | "colab_type": "code", 501 | "id": "vMIVlsDGwAly" 502 | }, 503 | "outputs": [], 504 | "source": [ 505 | "fig = plt.figure(figsize=(17, 20))\n", 506 | "\n", 507 | "ax = fig.add_subplot(1, 2, 1)\n", 508 | "ax.imshow(I2)\n", 509 | "ax.set_title('Source Image', fontsize=20)\n", 510 | "ax.axis('off')\n", 511 | "\n", 512 | "ax = fig.add_subplot(1, 2, 2)\n", 513 | "I = colorTransfer(OT_plans[0], kmeans1, kmeans2, I2.shape)\n", 514 | "ax.imshow(I)\n", 515 | "ax.set_title('Reg = 0', fontsize=20)\n", 516 | "ax.axis('off')\n", 517 | "\n", 518 | "plt.show()\n", 519 | "\n", 520 | "fig = plt.figure(figsize=(17, 20))\n", 521 | "for i in range(3):\n", 522 | " ax = fig.add_subplot(2, 3, i+1)\n", 523 | " I = colorTransfer(OT_plans[i+1], kmeans1, kmeans2, I2.shape)\n", 524 | " ax.imshow(I)\n", 525 | " ax.set_title('Reg = '+str(regs[i]), fontsize=20)\n", 526 | " ax.axis('off')\n", 527 | "\n", 528 | "plt.show()" 529 | ] 530 | }, 531 | { 532 | "cell_type": "markdown", 533 | "metadata": { 534 | "colab_type": "text", 535 | "id": "eNFlM3RywAl1" 536 | }, 537 | "source": [ 538 | "## 2. Document Clustering" 539 | ] 540 | }, 541 | { 542 | "cell_type": "markdown", 543 | "metadata": { 544 | "colab_type": "text", 545 | "id": "cq9danJpwAl2" 546 | }, 547 | "source": [ 548 | "We would likle to classify several text documents. In order to do this, we will:\n", 549 | "1. Transform each text into a point cloud\n", 550 | "2. Compute the Optimal Transport distances between each pair of point clouds\n", 551 | "3. Use MDS to plot the different clusters in 2 dimensions" 552 | ] 553 | }, 554 | { 555 | "cell_type": "markdown", 556 | "metadata": { 557 | "colab_type": "text", 558 | "id": "iLjCFnUHwAl2" 559 | }, 560 | "source": [ 561 | "### Load the Data and Preprocessing\n", 562 | "We consider seven movie scenarios. We transformed each of them into a point cloud using the following steps:\n", 563 | "1. Keep only the words among the $2.000 - 20.000$ most common words\n", 564 | "2. Each remaining word is transformed into a $300$-dimensional vector using word2vec\n", 565 | "3. Each word is given a weight proportional to its frequency\n", 566 | "\n", 567 | "The variable `texts` is a list of tuples. Each tuple represents a movie, and contains two parts:\n", 568 | "1. A matrix $X \\in \\mathbb{R}^{n \\times 300}$ where $n$ is the number of different words, containing the position of the points\n", 569 | "2. 
A vector $a \\in \\mathbb{R}^n$ containing the corresponding weights" 570 | ] 571 | }, 572 | { 573 | "cell_type": "code", 574 | "execution_count": null, 575 | "metadata": { 576 | "colab": {}, 577 | "colab_type": "code", 578 | "id": "SwMrZHa3wAl5" 579 | }, 580 | "outputs": [], 581 | "source": [ 582 | "import pickle\n", 583 | "\n", 584 | "with open(DATA_PATH + 'texts.pickle', 'rb') as file:\n", 585 | " texts = pickle.load(file)\n", 586 | "\n", 587 | "movies = ['DUNKIRK', 'GRAVITY', 'INTERSTELLAR', 'KILL BILL VOL.1', 'KILL BILL VOL.2', 'THE MARTIAN', 'TITANIC']" 588 | ] 589 | }, 590 | { 591 | "cell_type": "markdown", 592 | "metadata": { 593 | "colab_type": "text", 594 | "id": "Q9kWV4ltwAl8" 595 | }, 596 | "source": [ 597 | "### Compute the OT distances" 598 | ] 599 | }, 600 | { 601 | "cell_type": "code", 602 | "execution_count": null, 603 | "metadata": { 604 | "colab": {}, 605 | "colab_type": "code", 606 | "id": "hCLVUY9xwAl9" 607 | }, 608 | "outputs": [], 609 | "source": [ 610 | "# Set regularization strength\n", 611 | "reg = 0.1" 612 | ] 613 | }, 614 | { 615 | "cell_type": "code", 616 | "execution_count": null, 617 | "metadata": { 618 | "colab": {}, 619 | "colab_type": "code", 620 | "id": "yv4xFXtHwAmB" 621 | }, 622 | "outputs": [], 623 | "source": [ 624 | "def costMatrix(i,j):\n", 625 | " '''Return the cost matrix C between movies number i and j.'''\n", 626 | " X = texts[i][0]\n", 627 | " Y = texts[j][0]\n", 628 | " \n", 629 | " return # TODO" 630 | ] 631 | }, 632 | { 633 | "cell_type": "code", 634 | "execution_count": null, 635 | "metadata": { 636 | "colab": {}, 637 | "colab_type": "code", 638 | "id": "T9jSqIywwAmD" 639 | }, 640 | "outputs": [], 641 | "source": [ 642 | "#this cell will take approximately 1 minute to compute in Google Colaboratory after you complete it\n", 643 | "OT_distances = np.zeros((7,7))\n", 644 | "# TODO: compute the OT distance (using Sinkhorn algorithm ot.sinkhorn) between all the pairs of scenarios" 645 | ] 646 | }, 647 | { 648 | "cell_type": "code", 649 | "execution_count": null, 650 | "metadata": {}, 651 | "outputs": [], 652 | "source": [ 653 | "for film in movies:\n", 654 | " print('The film most similar to', film, 'is', # TODO)" 655 | ] 656 | }, 657 | { 658 | "cell_type": "markdown", 659 | "metadata": { 660 | "colab_type": "text", 661 | "id": "QPJaNctTwAmF" 662 | }, 663 | "source": [ 664 | "### Plot the MDS projection" 665 | ] 666 | }, 667 | { 668 | "cell_type": "code", 669 | "execution_count": null, 670 | "metadata": { 671 | "colab": {}, 672 | "colab_type": "code", 673 | "id": "GFeOuUdbwAmF" 674 | }, 675 | "outputs": [], 676 | "source": [ 677 | "from sklearn.manifold import MDS\n", 678 | "embedding = MDS(n_components=2, dissimilarity='precomputed')\n", 679 | "dis = OT_distances - OT_distances[OT_distances>0].min()\n", 680 | "np.fill_diagonal(dis, 0.)\n", 681 | "embedding = embedding.fit(dis)\n", 682 | "X = embedding.embedding_\n", 683 | "\n", 684 | "import matplotlib.pyplot as plt\n", 685 | "plt.figure(figsize=(17,6))\n", 686 | "plt.scatter(X[:,0], X[:,1], alpha=0.)\n", 687 | "plt.axis('equal')\n", 688 | "plt.axis('off')\n", 689 | "c = {'KILL BILL VOL.1':'red', 'KILL BILL VOL.2':'red', 'TITANIC':'blue', 'DUNKIRK':'blue', 'GRAVITY':'black', 'INTERSTELLAR':'black', 'THE MARTIAN':'black'}\n", 690 | "for film in movies:\n", 691 | " i = movies.index(film)\n", 692 | " plt.gca().annotate(film, X[i], size=30, ha='center', color=c[film], weight=\"bold\", alpha=0.7)\n", 693 | "plt.show()" 694 | ] 695 | }, 696 | { 697 | "cell_type": "code", 698 | "execution_count": 
null, 699 | "metadata": { 700 | "colab": {}, 701 | "colab_type": "code", 702 | "id": "EVruQxj1zSfW" 703 | }, 704 | "outputs": [], 705 | "source": [] 706 | } 707 | ], 708 | "metadata": { 709 | "colab": { 710 | "name": "MLSS 2 Optimal Transport for Machine Learning (student version).ipynb", 711 | "provenance": [], 712 | "version": "0.3.2" 713 | }, 714 | "kernelspec": { 715 | "display_name": "Python 3", 716 | "language": "python", 717 | "name": "python3" 718 | }, 719 | "language_info": { 720 | "codemirror_mode": { 721 | "name": "ipython", 722 | "version": 3 723 | }, 724 | "file_extension": ".py", 725 | "mimetype": "text/x-python", 726 | "name": "python", 727 | "nbconvert_exporter": "python", 728 | "pygments_lexer": "ipython3", 729 | "version": "3.7.1" 730 | } 731 | }, 732 | "nbformat": 4, 733 | "nbformat_minor": 1 734 | } 735 | -------------------------------------------------------------------------------- /optimal_transport_tutorial/optimaltransport/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/optimal_transport_tutorial/optimaltransport/__init__.py -------------------------------------------------------------------------------- /optimal_transport_tutorial/optimaltransport/data/croissants.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/optimal_transport_tutorial/optimaltransport/data/croissants.pickle -------------------------------------------------------------------------------- /optimal_transport_tutorial/optimaltransport/data/schiele.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/optimal_transport_tutorial/optimaltransport/data/schiele.jpg -------------------------------------------------------------------------------- /optimal_transport_tutorial/optimaltransport/data/schiele2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/optimal_transport_tutorial/optimaltransport/data/schiele2.jpg -------------------------------------------------------------------------------- /optimal_transport_tutorial/optimaltransport/data/texts.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlss-skoltech/tutorials/e23a317eb748102712b8c16452b696d37a1ac065/optimal_transport_tutorial/optimaltransport/data/texts.pickle -------------------------------------------------------------------------------- /optimal_transport_tutorial/setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | 3 | 4 | setup( 5 | name="optimaltransport", 6 | version="0.1", 7 | include_package_data=True, 8 | packages=[ 9 | "optimaltransport", 10 | ] 11 | 12 | ) --------------------------------------------------------------------------------
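The TODO cells of the color-transfer exercise in the student notebook (`Opt_transport_2_Optimal_Transport_for_Mac.ipynb`) could be completed along the following lines. This is only a sketch, not the solutions notebook: it assumes the `kmeans1`/`kmeans2` objects fitted on the two images in earlier cells, and the histogram names `a`, `b` and the `cluster_weights` helper are hypothetical; `ot.dist`, `ot.emd`, `ot.sinkhorn` and the scikit-learn `KMeans` attributes are real APIs.

```python
import numpy as np
import ot  # POT: Python Optimal Transport


def cluster_weights(kmeans):
    """Hypothetical helper: fraction of pixels assigned to each color cluster."""
    counts = np.bincount(kmeans.labels_, minlength=kmeans.n_clusters)
    return counts / counts.sum()


# Histograms over the two color palettes (the names a and b are assumptions)
a = cluster_weights(kmeans1)
b = cluster_weights(kmeans2)

# Squared-Euclidean cost between the two sets of k-means color centers
C = ot.dist(kmeans1.cluster_centers_, kmeans2.cluster_centers_)
C = C / C.max()  # normalizing the cost keeps Sinkhorn numerically stable

regs = [0.01, 0.1, 0.5]
OT_plans = [ot.emd(a, b, C)]                    # exact (unregularized) OT plan
for reg in regs:
    OT_plans.append(ot.sinkhorn(a, b, C, reg))  # entropy-regularized plan
```

For the reverse experiment ("Portrait with Landscape colors"), the `C = # TODO` cell would build the cost matrix with the roles of the two palettes swapped, and the histograms `a` and `b` swapped accordingly.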
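One possible `colorTransfer` implementation, under the same assumptions: each source cluster center is mapped to the barycenter of the target centers weighted by its (row-normalized) share of the transport plan, and every pixel then takes the new color of its cluster. The clipping assumes the images were converted to floats in [0, 1] earlier in the notebook.

```python
import numpy as np


def colorTransfer(OT_plan, kmeans1, kmeans2, shape):
    """Sketch: recolor the image clustered by kmeans1 with the palette of kmeans2."""
    # Row-normalize the plan so the mass leaving each source cluster sums to 1
    row_mass = OT_plan.sum(axis=1, keepdims=True)
    row_mass[row_mass == 0] = 1.0
    new_centers = (OT_plan / row_mass) @ kmeans2.cluster_centers_
    # Every pixel takes the transported color of its cluster
    img = new_centers[kmeans1.labels_].reshape(shape)
    return np.clip(img, 0.0, 1.0)  # assumes float images in [0, 1]
```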
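The document-clustering TODOs can be sketched in the same spirit. `ot.dist` and `ot.sinkhorn2` (which returns the regularized transport cost rather than the plan) are actual POT functions; the normalization of the cost matrix and of the word weights, and the nearest-neighbor printout, are illustrative choices, with `texts` and `movies` taken from the notebook cells above.

```python
import numpy as np
import ot


def costMatrix(i, j):
    """Squared-Euclidean cost between the word2vec clouds of movies i and j."""
    X = texts[i][0]
    Y = texts[j][0]
    C = ot.dist(X, Y)
    return C / C.max()  # normalization keeps Sinkhorn well conditioned


reg = 0.1
n_movies = len(movies)
OT_distances = np.zeros((n_movies, n_movies))
for i in range(n_movies):
    for j in range(i + 1, n_movies):
        a = texts[i][1] / texts[i][1].sum()  # word weights, renormalized to sum to 1
        b = texts[j][1] / texts[j][1].sum()
        W = float(ot.sinkhorn2(a, b, costMatrix(i, j), reg))
        OT_distances[i, j] = OT_distances[j, i] = W

for film in movies:
    i = movies.index(film)
    d = OT_distances[i].copy()
    d[i] = np.inf  # a film should not count as its own nearest neighbor
    print('The film most similar to', film, 'is', movies[int(np.argmin(d))])
```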