├── Day1 ├── .ipynb_checkpoints │ ├── students_Bayesian_regression-checkpoint.ipynb │ └── students_PPLs_Intro-checkpoint.ipynb ├── slides-L1.pdf ├── slides │ └── Figures │ │ ├── Ice-cream_shop_-_Florida.jpg │ │ ├── PGM-Tem-Sensor.png │ │ ├── PGM-Tem-Sensor2.png │ │ ├── PGM-Tem-Sensor3.png │ │ ├── africa-III.png │ │ ├── africa-IX.png │ │ ├── icecream-model-temporal.png │ │ ├── tempmodel-II-graph.png │ │ ├── tempmodel-V-b.png │ │ ├── tempmodel-VI-b.png │ │ └── tempmodel-temporal-III.png ├── solutions_PPLs_Intro.ipynb ├── solutions_bayesian_regression.ipynb ├── students_Bayesian_regression.ipynb └── students_PPLs_Intro.ipynb ├── Day2 ├── slides-L2.pdf ├── solution_lin_reg.ipynb ├── solution_simple_model.ipynb ├── students_lin_reg.ipynb └── students_simple_model.ipynb ├── Day3 ├── .ipynb_checkpoints │ ├── Bayesian_linear_regression-checkpoint.ipynb │ ├── FA-checkpoint.ipynb │ ├── VAE-checkpoint.ipynb │ ├── solution_BBVI-checkpoint.ipynb │ └── solution_simple_model-checkpoint.ipynb ├── BBVI-gradient-variance.eps ├── BBVI_exercise.png ├── Bayesian_linear_regression.ipynb ├── FA.ipynb ├── FA_model.png ├── VAE.ipynb ├── elbo_evolution.pdf ├── elbo_evolution_with_1_samples.pdf ├── reg_model.png ├── simple_pyro_exercise.png ├── slides-L3.pdf ├── solution_BBVI.ipynb ├── solution_simple_model.ipynb ├── student_BBVI.ipynb └── student_simple_model.ipynb ├── Readme.md └── environment.yml /Day1/slides-L1.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PGM-Lab/probabilisticAI_tutorials/40b644dccee2d26e469bb5fd32ae01e430e66200/Day1/slides-L1.pdf -------------------------------------------------------------------------------- /Day1/slides/Figures/Ice-cream_shop_-_Florida.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PGM-Lab/probabilisticAI_tutorials/40b644dccee2d26e469bb5fd32ae01e430e66200/Day1/slides/Figures/Ice-cream_shop_-_Florida.jpg 
-------------------------------------------------------------------------------- /Day1/slides/Figures/PGM-Tem-Sensor.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PGM-Lab/probabilisticAI_tutorials/40b644dccee2d26e469bb5fd32ae01e430e66200/Day1/slides/Figures/PGM-Tem-Sensor.png -------------------------------------------------------------------------------- /Day1/slides/Figures/PGM-Tem-Sensor2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PGM-Lab/probabilisticAI_tutorials/40b644dccee2d26e469bb5fd32ae01e430e66200/Day1/slides/Figures/PGM-Tem-Sensor2.png -------------------------------------------------------------------------------- /Day1/slides/Figures/PGM-Tem-Sensor3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PGM-Lab/probabilisticAI_tutorials/40b644dccee2d26e469bb5fd32ae01e430e66200/Day1/slides/Figures/PGM-Tem-Sensor3.png -------------------------------------------------------------------------------- /Day1/slides/Figures/africa-III.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PGM-Lab/probabilisticAI_tutorials/40b644dccee2d26e469bb5fd32ae01e430e66200/Day1/slides/Figures/africa-III.png -------------------------------------------------------------------------------- /Day1/slides/Figures/africa-IX.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PGM-Lab/probabilisticAI_tutorials/40b644dccee2d26e469bb5fd32ae01e430e66200/Day1/slides/Figures/africa-IX.png -------------------------------------------------------------------------------- /Day1/slides/Figures/icecream-model-temporal.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/PGM-Lab/probabilisticAI_tutorials/40b644dccee2d26e469bb5fd32ae01e430e66200/Day1/slides/Figures/icecream-model-temporal.png -------------------------------------------------------------------------------- /Day1/slides/Figures/tempmodel-II-graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PGM-Lab/probabilisticAI_tutorials/40b644dccee2d26e469bb5fd32ae01e430e66200/Day1/slides/Figures/tempmodel-II-graph.png -------------------------------------------------------------------------------- /Day1/slides/Figures/tempmodel-V-b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PGM-Lab/probabilisticAI_tutorials/40b644dccee2d26e469bb5fd32ae01e430e66200/Day1/slides/Figures/tempmodel-V-b.png -------------------------------------------------------------------------------- /Day1/slides/Figures/tempmodel-VI-b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PGM-Lab/probabilisticAI_tutorials/40b644dccee2d26e469bb5fd32ae01e430e66200/Day1/slides/Figures/tempmodel-VI-b.png -------------------------------------------------------------------------------- /Day1/slides/Figures/tempmodel-temporal-III.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PGM-Lab/probabilisticAI_tutorials/40b644dccee2d26e469bb5fd32ae01e430e66200/Day1/slides/Figures/tempmodel-temporal-III.png -------------------------------------------------------------------------------- /Day2/slides-L2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PGM-Lab/probabilisticAI_tutorials/40b644dccee2d26e469bb5fd32ae01e430e66200/Day2/slides-L2.pdf -------------------------------------------------------------------------------- 
/Day2/students_lin_reg.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import pandas as pd\n", 11 | "import seaborn as sns\n", 12 | "import matplotlib.pyplot as plt\n", 13 | "%matplotlib inline" 14 | ] 15 | }, 16 | { 17 | "cell_type": "markdown", 18 | "metadata": {}, 19 | "source": [ 20 | "In this code task you should implement your updating for the variational distribution of the intercept 'b', and use your implementation to learn a Bayesian linear regression model for the 'ruggedness' data, which we also considered yesterday. " 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "metadata": {}, 26 | "source": [ 27 | "# Dataset \n", 28 | "\n", 29 | "The following example is adapted from \\[1\\]. We would like to explore the relationship between topographic heterogeneity of a nation as measured by the Terrain Ruggedness Index (variable *rugged* in the dataset) and its GDP per capita. In particular, it was noted by the authors in \\[1\\] that terrain ruggedness or bad geography is related to poorer economic performance outside of Africa, but rugged terrains have had a reverse effect on income for African nations. Let us look at the data \\[2\\] and investigate this relationship. We will be focusing on three features from the dataset:\n", 30 | " - `rugged`: quantifies the Terrain Ruggedness Index\n", 31 | " - `cont_africa`: whether the given nation is in Africa\n", 32 | " - `rgdppc_2000`: Real GDP per capita for the year 2000\n", 33 | " \n", 34 | "We will take the logarithm for the response variable GDP as it tends to vary exponentially." 
35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 2, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "DATA_URL = \"https://d2fefpcigoriu7.cloudfront.net/datasets/rugged_data.csv\"\n", 44 | "data = pd.read_csv(DATA_URL, encoding=\"ISO-8859-1\")\n", 45 | "df = data[[\"cont_africa\", \"rugged\", \"rgdppc_2000\"]]\n", 46 | "df = df[np.isfinite(df.rgdppc_2000)]\n", 47 | "df[\"rgdppc_2000\"] = np.log(df[\"rgdppc_2000\"])\n", 48 | "df[\"african_rugged\"] = data[\"cont_africa\"] * data[\"rugged\"]\n", 49 | "df = df[[\"cont_africa\", \"rugged\", \"african_rugged\", \"rgdppc_2000\"]]\n", 50 | "\n", 51 | "# Divide the data into poredictors and response and store the data in numpy arrays\n", 52 | "data = np.array(df)\n", 53 | "x_data = data[:, :-1]\n", 54 | "y_data = data[:, -1]" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 3, 60 | "metadata": { 61 | "scrolled": true 62 | }, 63 | "outputs": [ 64 | { 65 | "data": { 66 | "text/html": [ 67 | "
\n", 68 | "\n", 81 | "\n", 82 | " \n", 83 | " \n", 84 | " \n", 85 | " \n", 86 | " \n", 87 | " \n", 88 | " \n", 89 | " \n", 90 | " \n", 91 | " \n", 92 | " \n", 93 | " \n", 94 | " \n", 95 | " \n", 96 | " \n", 97 | " \n", 98 | " \n", 99 | " \n", 100 | " \n", 101 | " \n", 102 | " \n", 103 | " \n", 104 | " \n", 105 | " \n", 106 | " \n", 107 | " \n", 108 | " \n", 109 | " \n", 110 | " \n", 111 | " \n", 112 | " \n", 113 | " \n", 114 | " \n", 115 | " \n", 116 | " \n", 117 | " \n", 118 | " \n", 119 | " \n", 120 | " \n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | "
cont_africaruggedafrican_ruggedrgdppc_2000
210.8580.8587.492609
403.4270.0008.216929
700.7690.0009.933263
800.7750.0009.407032
902.6880.0007.792343
1100.0060.0009.212541
1200.1430.00010.143191
1303.5130.00010.274632
1401.6720.0007.852028
1511.7801.7806.432380
\n", 164 | "
" 165 | ], 166 | "text/plain": [ 167 | " cont_africa rugged african_rugged rgdppc_2000\n", 168 | "2 1 0.858 0.858 7.492609\n", 169 | "4 0 3.427 0.000 8.216929\n", 170 | "7 0 0.769 0.000 9.933263\n", 171 | "8 0 0.775 0.000 9.407032\n", 172 | "9 0 2.688 0.000 7.792343\n", 173 | "11 0 0.006 0.000 9.212541\n", 174 | "12 0 0.143 0.000 10.143191\n", 175 | "13 0 3.513 0.000 10.274632\n", 176 | "14 0 1.672 0.000 7.852028\n", 177 | "15 1 1.780 1.780 6.432380" 178 | ] 179 | }, 180 | "metadata": {}, 181 | "output_type": "display_data" 182 | } 183 | ], 184 | "source": [ 185 | "# Display first 10 entries \n", 186 | "display(df[0:10])" 187 | ] 188 | }, 189 | { 190 | "cell_type": "markdown", 191 | "metadata": {}, 192 | "source": [ 193 | "# The model\n", 194 | "\n", 195 | "Following the approach from Day 1 we will model the data using a Bayesian linear regression model:\n", 196 | "\n", 197 | "The quantitative part of the model is specified as: \n", 198 | "- Number of data dim: $M$\n", 199 | "- Number of data inst: $N$\n", 200 | "- $Y_{i}|\\{{\\bf w}, {\\bf x}_i, b, \\theta \\} \\sim \\mathcal{N}({\\bf w}^T{\\bf x}_i +b, 1/\\theta)$ \n", 201 | "- ${\\bf W} \\sim {\\mathcal N}({\\bf 0}, \\gamma_w^{-1}{\\bf I}_{M\\times M})$\n", 202 | "- $b\\sim {\\mathcal N}(0,\\gamma_b^{-1})$" 203 | ] 204 | }, 205 | { 206 | "cell_type": "markdown", 207 | "metadata": {}, 208 | "source": [ 209 | "## Helper-routine: Calculate ELBO" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": 4, 215 | "metadata": {}, 216 | "outputs": [], 217 | "source": [ 218 | "def calculate_ELBO(x_data, y_data, gamma_w, gamma_b, theta, q_w_mean, q_w_prec, q_b_mean, q_b_prec):\n", 219 | " \"\"\"\n", 220 | " Helper routine: Calculate ELBO. Data is the sampled x and y values, gamma_w and gamma_b are the prior precisions \n", 221 | " over the weights and intercdpt, respectively, and theta is the prior precision associated with y. 
Everything \n", 222 | " prefixed with a 'q' relates to the variational posterior.\n", 223 | " \n", 224 | " Note: This function obviously only works for this particular model and is not a general solution.\n", 225 | "\n", 226 | " :param x_data: The predictors\n", 227 | " :param y_data: The response variable\n", 228 | " :param gamma_w: prior precision for the weights\n", 229 | " :param gamma_b: prior precision for the intercept\n", 230 | " :param theta: prior precision for y\n", 231 | " :param q_w_mean: VB posterior mean for the distribution of the weights w \n", 232 | " :param q_w_prec: VB posterior precision (diagonal matrix) for the distribution of the weights w \n", 233 | " :param q_b_mean: VB posterior mean for the intercept b\n", 234 | " :param q_b_prec: VB posterior precision for the intercept b\n", 235 | " :return: the ELBO\n", 236 | " \"\"\"\n", 237 | " \n", 238 | " # We calculate the ELBO as E_q log p(y,x,w,b) - E_q log q(w,b), where\n", 239 | " # log p(y,x,w) = sum_i log p(y|x,w,b) + log p(w) + log p(b)\n", 240 | " # log q(w,b) = log q(w) + log q(b)\n", 241 | "\n", 242 | " M = x_data.shape[1]\n", 243 | "\n", 244 | " # E_q log p(w)\n", 245 | " E_log_p = -0.5 * M * np.log(2 * np.pi) + 0.5 * M * gamma_w - 0.5 * gamma_w * np.sum(np.diagonal(np.linalg.inv(q_w_prec))\n", 246 | " + (q_w_mean*q_w_mean).flatten())\n", 247 | " # E_q log p(b)\n", 248 | " E_log_p += -0.5 * np.log(2 * np.pi) + 0.5 * np.log(gamma_b) - 0.5 * gamma_b * (1/q_b_prec + q_b_mean**2)\n", 249 | "\n", 250 | " # sum_i E_q log p(y|x,w,b)\n", 251 | " E_w_w = np.linalg.inv(q_w_prec) + q_w_mean @ q_w_mean.transpose()\n", 252 | " E_b_b = 1/q_b_prec + q_b_mean**2\n", 253 | " for i in range(x_data.shape[0]):\n", 254 | " E_x_ww_x = np.matmul(x_data[i, :].transpose(), np.matmul(E_w_w, x_data[i, :]))\n", 255 | " E_log_p += -0.5 * np.log(2 * np.pi) + 0.5 * np.log(theta) \\\n", 256 | " - 0.5 * theta * (y_data[i]**2 + E_x_ww_x + E_b_b\n", 257 | " + 2 * q_b_mean * np.matmul(q_w_mean.transpose(), x_data[i, 
:])\n", 258 | " - 2 * y_data[i] * np.matmul(q_w_mean.transpose(), x_data[i,:])\n", 259 | " - 2 * y_data[i] * q_b_mean)\n", 260 | "\n", 261 | " # Entropy of q_b\n", 262 | " ent = 0.5 * np.log(1 * np.pi * np.exp(1) / q_b_prec)\n", 263 | " ent += 0.5 * np.log(np.linalg.det(2 * np.pi * np.exp(1) * np.linalg.inv(q_w_prec)))\n", 264 | "\n", 265 | " return E_log_p - ent" 266 | ] 267 | }, 268 | { 269 | "cell_type": "markdown", 270 | "metadata": {}, 271 | "source": [ 272 | "# Full mean field\n", 273 | "First we consider a full mean-field approach, where the variational approximation factorizes as\n", 274 | "$$\n", 275 | "q({\\bf w}, b) = q(b)\\prod _{i=1}^Mq(w_i)\n", 276 | "$$\n", 277 | "\n", 278 | "The following function implements the variational updating equation for the linear regression weights, $\\textbf{W}$, derived on slide 11 (page 39). " 279 | ] 280 | }, 281 | { 282 | "cell_type": "code", 283 | "execution_count": null, 284 | "metadata": {}, 285 | "outputs": [], 286 | "source": [ 287 | "# The variational updating rule for weight component 'comp'. Observe that this is a direct 
Observe that this is a direct implementaiton of the \n", 288 | "# updating rule from the slide.\n", 289 | "def update_w_comp(x_data, y_data, gamma_w, theta, q_w_mean, q_w_prec, q_b_mean, comp):\n", 290 | "\n", 291 | " # Lenght of weight vector\n", 292 | " M = x_data.shape[1]\n", 293 | " # The precision (a scalar)\n", 294 | " tau = gamma_w\n", 295 | " # The mean (a scalar)\n", 296 | " mu = 0.0\n", 297 | " for i in range(x_data.shape[0]):\n", 298 | " tau += theta * x_data[i, comp]**2\n", 299 | " mu += (y_data[i] - q_b_mean - (np.sum(x_data[i, :] @ q_w_mean) - x_data[i, comp]*q_w_mean[comp])) \\\n", 300 | " * x_data[i, comp]\n", 301 | " mu = theta * 1/tau * mu\n", 302 | "\n", 303 | " # Update the appropriate entries in the mean vector and precision matrix\n", 304 | " q_w_prec[comp, comp] = tau\n", 305 | " q_w_mean[comp] = mu.item()\n", 306 | "\n", 307 | " return q_w_prec, q_w_mean\n", 308 | "\n" 309 | ] 310 | }, 311 | { 312 | "cell_type": "markdown", 313 | "metadata": {}, 314 | "source": [ 315 | "Now you have to code the variational updating rule for the intercetp $B$. This updating rule only depends on $\\textbf{x}$, $\\textbf{y}$, $\\gamma_b$, $\\theta$ and the mean of the variational posterior distribution over the weights $\\textbf{W}$." 316 | ] 317 | }, 318 | { 319 | "cell_type": "code", 320 | "execution_count": null, 321 | "metadata": {}, 322 | "outputs": [], 323 | "source": [ 324 | "# The variational updating rule for the intercept\n", 325 | "def update_b(x_data, y_data, gamma_b, theta, q_w_mean):\n", 326 | "\n", 327 | " # The precision (a scalar)\n", 328 | " tau = ???????\n", 329 | " # The mean (a scalar)\n", 330 | " mu = ???????\n", 331 | "\n", 332 | " return tau, mu" 333 | ] 334 | }, 335 | { 336 | "cell_type": "markdown", 337 | "metadata": {}, 338 | "source": [ 339 | "Once coded, you can test if it works by running the code below and looking that the ELBO monotonically increases. 
" 340 | ] 341 | }, 342 | { 343 | "cell_type": "markdown", 344 | "metadata": {}, 345 | "source": [ 346 | "## Do the VB (full mean field)" 347 | ] 348 | }, 349 | { 350 | "cell_type": "code", 351 | "execution_count": null, 352 | "metadata": {}, 353 | "outputs": [], 354 | "source": [ 355 | "# Initialize the variational distributions\n", 356 | "M = x_data.shape[1]\n", 357 | "gamma_w = 1\n", 358 | "gamma_b = 1\n", 359 | "theta = 1\n", 360 | "q_w_mean = np.random.normal(0, 1, (3, 1))\n", 361 | "q_w_prec = np.diag((1, 1, 1)) # We store the precisions for the weights in sa diagonal matrix\n", 362 | "q_b_mean = np.random.normal(0, 1)\n", 363 | "q_b_prec = 1\n", 364 | "\n", 365 | "# Keep track of the ELBO values\n", 366 | "elbos = []\n", 367 | "\n", 368 | "# Calculate ELBO\n", 369 | "this_lb = calculate_ELBO(x_data, y_data, gamma_w, gamma_b, theta, q_w_mean, q_w_prec, q_b_mean, q_b_prec)\n", 370 | "elbos.append(this_lb)\n", 371 | "\n", 372 | "# Start iterating\n", 373 | "previous_lb = -np.inf\n", 374 | "print(\"\\n\" + 100 * \"=\" + \"\\n VB iterations:\\n\" + 100 * \"=\")\n", 375 | "for iteration in range(100):\n", 376 | "\n", 377 | " # Update the variational distributions; one update for each component in the weight vectoe\n", 378 | " for i in range(M):\n", 379 | " q_w_prec, q_w_mean = update_w_comp(x_data, y_data, gamma_w, theta, q_w_mean, q_w_prec, q_b_mean, i)\n", 380 | " q_b_prec, q_b_mean = update_b(x_data, y_data, gamma_b, theta, q_w_mean)\n", 381 | "\n", 382 | " # Calculate the ELBO\n", 383 | " this_lb = calculate_ELBO(x_data, y_data, gamma_w, gamma_b, theta, q_w_mean, q_w_prec, q_b_mean, q_b_prec)\n", 384 | " elbos.append(this_lb)\n", 385 | " print(f\"Iteration {iteration:2d}. ELBO: {this_lb.item():13.7f}\")\n", 386 | " if this_lb < previous_lb:\n", 387 | " raise ValueError(\"ELBO is decreasing. Something is wrong! Goodbye...\")\n", 388 | " \n", 389 | " if iteration > 0 and np.abs((this_lb - previous_lb) / previous_lb) < 1E-8:\n", 390 | " # Very little improvement. 
We are done.\n", 391 | " break\n", 392 | " \n", 393 | " # If we didn't break we need to run again. Update the value for \"previous\"\n", 394 | " previous_lb = this_lb\n", 395 | "print(\"\\n\" + 100 * \"=\" + \"\\n\")\n", 396 | "\n", 397 | "# Store the results\n", 398 | "w_mean_mf = q_w_mean\n", 399 | "w_prec_mf = q_w_prec\n", 400 | "b_mean_mf = q_b_mean\n", 401 | "b_prec_mf = q_b_prec\n", 402 | "\n", 403 | "plt.plot(range(len(elbos)), elbos)\n", 404 | "plt.xlabel('NUmber of iterations')\n", 405 | "plt.ylabel('ELBO')" 406 | ] 407 | }, 408 | { 409 | "cell_type": "markdown", 410 | "metadata": {}, 411 | "source": [ 412 | "## Model evaluation\n", 413 | "\n", 414 | "To get a sense of the robustness of the model we draw samples from the posterior variational distributions over the weights and intercept; each sample correspond to a regression line" 415 | ] 416 | }, 417 | { 418 | "cell_type": "code", 419 | "execution_count": null, 420 | "metadata": {}, 421 | "outputs": [], 422 | "source": [ 423 | "fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 6), sharey=True)\n", 424 | "fig.suptitle(\"Uncertainty in Regression line \", fontsize=16)\n", 425 | "num_samples = 20\n", 426 | "\n", 427 | "ax[0].scatter(x_data[x_data[:,0]==0,1], y_data[x_data[:,0]==0])\n", 428 | "for _ in range(num_samples):\n", 429 | " b_sample = np.random.normal(loc=q_b_mean, scale=1/np.sqrt(q_b_prec))\n", 430 | " w_sample = np.random.multivariate_normal(mean=q_w_mean.flatten(), cov=np.linalg.inv(q_w_prec))\n", 431 | " ax[0].plot(x_data[x_data[:,0]==0,1], (x_data[x_data[:,0]==0,:] @ w_sample)+b_sample, 'r-')\n", 432 | "ax[0].set(xlabel=\"Terrain Ruggedness Index\",\n", 433 | " ylabel=\"log GDP (2000)\",\n", 434 | " title=\"Non African Nations\")\n", 435 | "\n", 436 | "ax[1].scatter(x_data[x_data[:,0]==1,1], y_data[x_data[:,0]==1])\n", 437 | "for _ in range(num_samples):\n", 438 | " b_sample = np.random.normal(loc=q_b_mean, scale=1/np.sqrt(q_b_prec))\n", 439 | " w_sample = 
np.random.multivariate_normal(mean=q_w_mean.flatten(), cov=np.linalg.inv(q_w_prec))\n", 440 | " ax[1].plot(x_data[x_data[:,0]==1,1], (x_data[x_data[:,0]==1,:] @ w_sample)+b_sample, 'r-')\n", 441 | "ax[1].set(xlabel=\"Terrain Ruggedness Index\",\n", 442 | " ylabel=\"log GDP (2000)\",\n", 443 | " title=\"African Nations\")\n", 444 | "\n", 445 | "plt.show()" 446 | ] 447 | } 448 | ], 449 | "metadata": { 450 | "kernelspec": { 451 | "display_name": "probabilistic.ai", 452 | "language": "python", 453 | "name": "probabilistic.ai" 454 | }, 455 | "language_info": { 456 | "codemirror_mode": { 457 | "name": "ipython", 458 | "version": 3 459 | }, 460 | "file_extension": ".py", 461 | "mimetype": "text/x-python", 462 | "name": "python", 463 | "nbconvert_exporter": "python", 464 | "pygments_lexer": "ipython3", 465 | "version": "3.7.0" 466 | } 467 | }, 468 | "nbformat": 4, 469 | "nbformat_minor": 2 470 | } 471 | -------------------------------------------------------------------------------- /Day2/students_simple_model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Simple model description\n", 8 | "In this code-task we work with a fairly simple model, where we have observations $x_i$, $i=1,\\ldots N$, that we assume follow a Gaussian distribution. The mean and precision (inverse variance) are unknown, so we model them in Bayesian way: The mean denoted by the random variable $\\mu$ is a Gaussian with a priori mean $0$ and precision $\\tau$. The precision of the data generating process is modelled using the random variable $\\gamma$. 
$\\gamma$ is a priori Gamma distributed with parameters $\\alpha$ (shape) and $\\beta$ (rate).\n", 9 | "\n", 10 | "$$\n", 11 | "\\mu \\sim Normal(0,\\tau^{-1})\\\\\n", 12 | "\\gamma \\sim Gamma(\\alpha,\\beta)\\\\\n", 13 | "x_i \\sim Normal(\\mu, \\gamma^{-1})\n", 14 | "$$\n", 15 | "\n", 16 | "In total, the model has the structure $\\mu \\rightarrow X_i \\leftarrow \\gamma$ (hyper-parameters not shown)." 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "metadata": {}, 22 | "source": [ 23 | "### Imports" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 1, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "import numpy as np\n", 33 | "from scipy import special, stats\n", 34 | "import matplotlib.pyplot as plt\n", 35 | "%matplotlib notebook" 36 | ] 37 | }, 38 | { 39 | "cell_type": "markdown", 40 | "metadata": {}, 41 | "source": [ 42 | "### Startup: Define priors, and sample data" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 2, 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "# Define priors\n", 52 | "alpha_prior, beta_prior = 1E-2, 1E-2 # Parameters for the prior for the precision of x\n", 53 | "tau_prior = 1E-6 # A priori precision for the precision of mu\n", 54 | "\n", 55 | "# Sample data\n", 56 | "np.random.seed(123)\n", 57 | "N = 4\n", 58 | "correct_mean = 5\n", 59 | "correct_precision = 1\n", 60 | "x = np.random.normal(loc=correct_mean, scale=1./np.sqrt(correct_precision), size=N)" 61 | ] 62 | }, 63 | { 64 | "cell_type": "markdown", 65 | "metadata": {}, 66 | "source": [ 67 | "## Helper-routine: Make plot of posterior" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 3, 73 | "metadata": {}, 74 | "outputs": [], 75 | "source": [ 76 | "def plot_posterior(posterior_mean_mu, posterior_prec_mu,\n", 77 | " posterior_alpha_gamma, posterior_beta_gamma,\n", 78 | " correct_mean, correct_precision):\n", 79 | " mu_range = np.linspace(posterior_mean_mu - 
5./np.sqrt(posterior_prec_mu),\n", 80 | " posterior_mean_mu + 5. / np.sqrt(posterior_prec_mu), 500).astype(np.float32)\n", 81 | " precision_range = np.linspace(1E-2, 3, 500).astype(np.float32)\n", 82 | " mu_mesh, precision_mesh = np.meshgrid(mu_range, precision_range)\n", 83 | " variational_log_pdf = \\\n", 84 | " stats.norm.logpdf(mu_mesh, loc=posterior_mean_mu, scale=1. / np.sqrt(posterior_prec_mu)) + \\\n", 85 | " stats.gamma.logpdf(x=precision_mesh,\n", 86 | " a=posterior_alpha_gamma,\n", 87 | " scale=1. / posterior_beta_gamma)\n", 88 | " plt.figure()\n", 89 | " plt.contour(mu_mesh, precision_mesh, variational_log_pdf, 25)\n", 90 | " plt.plot(correct_mean, correct_precision, \"bo\")\n", 91 | " plt.title('Posterior over $(\\mu, \\\\tau)$. Blue dot: True parameters')\n", 92 | " plt.xlabel(\"Mean $\\mu$\")\n", 93 | " plt.ylabel(\"Precision $\\\\tau$\")" 94 | ] 95 | }, 96 | { 97 | "cell_type": "markdown", 98 | "metadata": {}, 99 | "source": [ 100 | "## Helper-routine: Calculate ELBO" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": 4, 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 | "def calculate_lower_bound(data, tau, alpha, beta, nu_star, tau_star, alpha_star, beta_star):\n", 110 | " \"\"\"\n", 111 | " Helper routine: Calculate ELBO. 
Data is the sampled x-values, anything without a star relates to the prior,\n", 112 | " everything _with_ a star relates to the variational posterior.\n", 113 | " Note that we have no nu without a star; I am simplifying by forcing this to be zero a priori\n", 114 | "\n", 115 | " Note: This function obviously only works when the model is as in this code challenge,\n", 116 | " and is not a general solution.\n", 117 | "\n", 118 | " :param data: The sampled data\n", 119 | " :param tau: prior precision for mu, the mean for the data generation\n", 120 | " :param alpha: prior shape of dist for gamma, the precision of the data generation\n", 121 | " :param beta: prior rate of dist for gamma, the precision of the data generation\n", 122 | " :param nu_star: VB posterior mean for the distribution of mu - the mean of the data generation\n", 123 | " :param tau_star: VB posterior precision for the distribution of mu - the mean of the data generation\n", 124 | " :param alpha_star: VB posterior shape of dist for gamma, the precision of the data generation\n", 125 | " :param beta_star: VB posterior shape of dist for gamma, the precision of the data generation\n", 126 | " :return: the ELBO\n", 127 | " \"\"\"\n", 128 | "\n", 129 | " # We calculate ELBO as E_q log p(x,z) - E_q log q(z)\n", 130 | " # log p(x,z) here is log p(mu) + log p(gamma) + \\sum_i log p(x_i | mu, gamma)\n", 131 | "\n", 132 | " # E_q log p(mu)\n", 133 | " log_p = -.5 * np.log(2 * np.pi) + .5 * np.log(tau) - .5 * tau * (1 / tau_star + nu_star * nu_star)\n", 134 | "\n", 135 | " # E_q log p(gamma)\n", 136 | " log_p = log_p + alpha * np.log(beta) + \\\n", 137 | " (alpha - 1) * (special.digamma(alpha_star) - np.log(beta_star)) - beta * alpha_star / beta_star\n", 138 | "\n", 139 | " # E_q log p(x_i|mu, gamma)\n", 140 | " for xi in data:\n", 141 | " log_p += -.5 * np.log(2 * np.pi) \\\n", 142 | " + .5 * (special.digamma(alpha_star) - np.log(beta_star)) \\\n", 143 | " - .5 * alpha_star / beta_star * (xi * xi - 2 * xi * 
nu_star + 1 / tau_star + nu_star * nu_star)\n", 144 | "\n", 145 | " # Entropy of mu (Gaussian)\n", 146 | " entropy = .5 * np.log(2 * np.pi * np.exp(1) / tau_star)\n", 147 | " entropy += alpha_star - np.log(beta_star) + special.gammaln(alpha_star) \\\n", 148 | " + (1 - alpha_star) * special.digamma(alpha_star)\n", 149 | "\n", 150 | " return log_p + entropy\n" 151 | ] 152 | }, 153 | { 154 | "cell_type": "markdown", 155 | "metadata": {}, 156 | "source": [ 157 | "## Do the VB" 158 | ] 159 | }, 160 | { 161 | "cell_type": "markdown", 162 | "metadata": {}, 163 | "source": [ 164 | "The task is to compute a variational approximation of the posterior over the unknown parameters $\\mu$ and $\\gamma$, \n", 165 | "\n", 166 | "$$\n", 167 | "p(\\mu,\\gamma|x_1,\\ldots,x_N) \\approx q(\\mu)q(\\gamma)\n", 168 | "$$\n", 169 | "\n", 170 | "\n", 171 | "We are looking for VB posteriors over $\\mu$ and $\\gamma$. It turns out, after some pencil pushing, that the posteriors are in the same distributional families as the priors were, so $\\mu$ remains Gaussian and $\\gamma$ remains Gamma distributed. What we need are the updated parameters of these two distributions. $q(\\mu)$ has two parameters, denoted `q_mu` and `q_tau`, and $q(\\gamma)$ has two parameters, denoted `q_alpha` and `q_beta`.\n", 172 | "The parameters of the (prior) distribution $p(\\cdot)$ are called something ending with `_prior`, like `alpha_prior` for $\\alpha$."
173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": null, 178 | "metadata": {}, 179 | "outputs": [], 180 | "source": [ 181 | "# Initialization\n", 182 | "q_alpha = alpha_prior\n", 183 | "q_beta = beta_prior\n", 184 | "q_mu = 0\n", 185 | "q_tau = tau_prior\n", 186 | "previous_lb = -np.inf\n", 187 | "\n", 188 | "# Start iterating\n", 189 | "print(\"\\n\" + 100 * \"=\" + \"\\n VB iterations:\\n\" + 100 * \"=\")\n", 190 | "for iteration in range(1000):\n", 191 | " # Update gamma distribution\n", 192 | " q_alpha = ?????\n", 193 | " q_beta = ?????\n", 194 | " expected_gamma = ?????\n", 195 | " \n", 196 | " # Update Gaussian distribution\n", 197 | " q_tau = ?????\n", 198 | " q_mu = ?????\n", 199 | " \n", 200 | " # Calculate Lower-bound\n", 201 | " this_lb = calculate_lower_bound(data=x, tau=tau_prior, alpha=alpha_prior, beta=beta_prior,\n", 202 | " nu_star=q_mu, tau_star=q_tau, alpha_star=q_alpha, beta_star=q_beta)\n", 203 | "\n", 204 | " print(\"{:2d}. alpha: {:6.3f}, beta: {:12.3f}, nu: {:6.3f}, tau: {:6.3f}, ELBO: {:12.7f}\".format(\n", 205 | " iteration + 1, q_alpha, q_beta, q_mu, q_tau, this_lb))\n", 206 | " \n", 207 | " if this_lb < previous_lb:\n", 208 | " raise ValueError(\"ELBO is decreasing. Something is wrong! Goodbye...\")\n", 209 | " \n", 210 | " if iteration > 0 and np.abs((this_lb - previous_lb) / previous_lb) < 1E-8:\n", 211 | " # Very little improvement. We are done.\n", 212 | " break\n", 213 | " \n", 214 | " # If we didn't break we need to run again. Update the value for \"previous\"\n", 215 | " previous_lb = this_lb\n", 216 | " \n", 217 | "\n", 218 | "print(\"\\n\" + 100 * \"=\" + \"\\n Result:\\n\" + 100 * \"=\")\n", 219 | "print(\"E[mu] = {:5.3f} with data average {:5.3f} and prior mean {:5.3f}.\".format(q_mu, np.mean(x), 0.))\n", 220 | "print(\"E[gamma] = {:5.3f} with inverse of data covariance {:5.3f} and prior {:5.3f}.\".format(\n", 221 | " q_alpha / q_beta, 1. 
/ np.cov(x), alpha_prior / beta_prior))" 222 | ] 223 | }, 224 | { 225 | "cell_type": "markdown", 226 | "metadata": {}, 227 | "source": [ 228 | "### Make plot of Variational Bayes posterior" 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": null, 234 | "metadata": {}, 235 | "outputs": [], 236 | "source": [ 237 | "plot_posterior(q_mu, q_tau, q_alpha, q_beta, correct_mean, correct_precision)\n", 238 | "plt.show()" 239 | ] 240 | } 241 | ], 242 | "metadata": { 243 | "kernelspec": { 244 | "display_name": "probabilistic.ai", 245 | "language": "python", 246 | "name": "probabilistic.ai" 247 | }, 248 | "language_info": { 249 | "codemirror_mode": { 250 | "name": "ipython", 251 | "version": 3 252 | }, 253 | "file_extension": ".py", 254 | "mimetype": "text/x-python", 255 | "name": "python", 256 | "nbconvert_exporter": "python", 257 | "pygments_lexer": "ipython3", 258 | "version": "3.7.0" 259 | } 260 | }, 261 | "nbformat": 4, 262 | "nbformat_minor": 2 263 | } 264 | -------------------------------------------------------------------------------- /Day3/.ipynb_checkpoints/VAE-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Adapt the variational auto encoder\n", 8 | "\n", 9 | "Below you will find an implementation of a VAE for the MNIST data. To reduce training time, we only consider the digits 5, 6, and 7 (as selected in the data-loading code below) and only the first 100 samples of each of these digits.\n", 10 | "\n", 11 | "In this exercise, you should familiarize yourself with the implementation below and experiment with the structure of the VAE specification in order to emphasize digit separation in the latent space and the generation of images when sampling from the latent space.\n", 12 | "\n", 13 | "Part of the implementation is based on code from the official Pyro examples."
14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 1, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "import numpy as np\n", 23 | "import torch\n", 24 | "import torchvision.datasets as datasets\n", 25 | "import torch.nn as nn\n", 26 | "import torchvision.transforms as transforms\n", 27 | "import pyro\n", 28 | "import pyro.distributions as dist\n", 29 | "from pyro.infer import SVI, Trace_ELBO\n", 30 | "from pyro.optim import Adam\n", 31 | "import datetime\n", 32 | "import os\n", 33 | "import matplotlib.gridspec as gridspec\n", 34 | "from matplotlib import pyplot\n", 35 | "import matplotlib.pyplot as plt\n", 36 | "from scipy.stats import norm" 37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "metadata": {}, 42 | "source": [ 43 | "### Get the MNIST data\n", 44 | "\n", 45 | "We will wrap the MNIST data set in a PyTorch `DataLoader`. " 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 2, 51 | "metadata": { 52 | "scrolled": true 53 | }, 54 | "outputs": [], 55 | "source": [ 56 | "def setup_data_loader(batch_size=64):\n", 57 | " #data = datasets.MNIST('./data', train=True, download=True,\n", 58 | " # transform=transforms.Compose([\n", 59 | " # transforms.ToTensor(),\n", 60 | " # transforms.Normalize((0.1307,), (0.3081,))\n", 61 | " # ]))\n", 62 | " data = datasets.MNIST('./data', train=True, download=True, transform=transforms.Compose([transforms.ToTensor()]))\n", 63 | " \n", 64 | " # We only select the digits 5, 6, and 7 and only the first 100 of each of these\n", 65 | " # digits\n", 66 | " selector = np.array([], dtype=int)\n", 67 | " for i in [5, 6, 7]:\n", 68 | " selector = np.concatenate((selector, np.where(data.targets == i)[0][:100]))\n", 69 | " data.data = data.data[selector, :, :]\n", 70 | " data.targets = data.targets[selector]\n", 71 | " \n", 72 | " # Binarize the data\n", 73 | " data.data[data.data<128] = 0\n", 74 | " data.data[data.data>=128] = 1\n", 75 | "\n", 76 | " data.data = 
data.data.type(torch.float)\n", 77 | " \n", 78 | " # Put the data within a data loader \n", 79 | " train_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True)\n", 80 | " return train_loader\n", 81 | "\n", 82 | "\n", 83 | "train_loader = setup_data_loader(batch_size=300)" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 3, 89 | "metadata": {}, 90 | "outputs": [ 91 | { 92 | "data": { 93 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAA5NJREFUeJzt3VFuwjAUAMG64v5Xdi9QoGowCd6ZbwQBafU+XhzGnPML6Pk++wKAc4gfosQPUeKHKPFDlPghSvwQJX6IEj9E3d75YWMMtxPCYnPO8ZfXmfwQJX6IEj9EiR+ixA9R4oco8UOU+CFK/BAlfogSP0SJH6LED1HihyjxQ5T4IUr8ECV+iBI/RIkfosQPUeKHKPFDlPghSvwQJX6IEj9EiR+ixA9R4oco8UOU+CFK/BAlfogSP0SJH6LED1Hih6jb2RfAtc05l733GGPZe/OcyQ9R4oco8UOU+CFK/BAlfogSP0TZ88et3OMf/Wz3Aaxl8kOU+CFK/BAlfogSP0SJH6Ks+jZ3dJW3ct125poRkx+yxA9R4oco8UOU+CFK/BAlfoiy59/AkX35lY/NOvK7lskPUeKHKPFDlPghSvwQJX6IEj9E2fNvzi6ce0x+iBI/RIkfosQPUeKHKPFDlPghyp7/A3i+PSuY/BAlfogSP0SJH6LED1HihyjxQ5T4IUr8ECV+iBI/RIkfosQPUeKHKPFDlPghSvwQJX6IEj9EiR+ixA9R4ocoj+7+ANW/2a5+73cx+SFK/BAlfogSP0SJH6LED1Hihyh7fpby9+LXZfJDlPghSvwQJX6IEj9EiR+ixA9R9vw8ZE+/L5MfosQPUeKHKPFDlPghSvwQJX6Isue/ALv03z37XTzX/xiTH6LED1HihyjxQ5T4IUr8EGXV9wJWdf9jVXcukx+ixA9R4oco8UOU+CFK/BAlfoiy53+BnffVR+9h2Pm3+XQmP0SJH6LED1HihyjxQ5T4IUr8EGXPH+dZBF0mP0SJH6LED1HihyjxQ5T4IUr8EGXPzyHO638ukx+ixA9R4oco8UOU+CFK/BBl1bc5R3a5x+SHKPFDlPghSvwQJX6IEj9EiR+i7Pl5yJHdfZn8ECV+iBI/RIkfosQPUeKHKPFDlD1/nD1+l8kPUeKHKPFDlPghSvwQJX6IEj9EiR+ixA9R4oco8UOU+CFK/BAlfogSP0Q5z7855/W5x+SHKPFDlPghSvwQJX6IEj9EiR+ixA9R4oco8UOU+CFK/BAlfogSP0SJH6LED1HihyjxQ5T4IUr8ECV+iBI/RI0559nXAJzA5Ico8UOU+CFK/BAlfogSP0SJH6LED1HihyjxQ5T4IUr8ECV+iBI/RIkfosQPUeKHKPFDlPghSvwQJX6IEj9EiR+ifgCJ5jwXeHW/1QAAAABJRU5ErkJggg==\n", 94 | "text/plain": [ 95 | "
" 96 | ] 97 | }, 98 | "metadata": {}, 99 | "output_type": "display_data" 100 | } 101 | ], 102 | "source": [ 103 | "def display_image(x):\n", 104 | " plt.axis('off')\n", 105 | " pyplot.imshow(x.reshape((28, 28)), cmap=\"gray\")\n", 106 | " \n", 107 | "toy_image = train_loader.dataset.data[215,:,:]\n", 108 | "display_image(toy_image)" 109 | ] 110 | }, 111 | { 112 | "cell_type": "markdown", 113 | "metadata": {}, 114 | "source": [ 115 | "### Setup the decoder network" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 4, 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "class Decoder(nn.Module):\n", 125 | " def __init__(self, z_dim, hidden_dim):\n", 126 | " super(Decoder, self).__init__()\n", 127 | " # setup the two linear transformations used\n", 128 | " self.fc1 = nn.Linear(z_dim, hidden_dim)\n", 129 | " self.fc21 = nn.Linear(hidden_dim, 784)\n", 130 | " # setup the non-linearities\n", 131 | " self.softplus = nn.Softplus()\n", 132 | " self.sigmoid = nn.Sigmoid()\n", 133 | "\n", 134 | " def forward(self, z):\n", 135 | " # define the forward computation on the latent z\n", 136 | " # first compute the hidden units\n", 137 | " hidden = self.softplus(self.fc1(z))\n", 138 | " # return the parameter for the output Bernoulli\n", 139 | " # each is of size batch_size x 784\n", 140 | " #loc_img = self.sigmoid(self.fc21(hidden))\n", 141 | " loc_img = self.fc21(hidden)\n", 142 | " return loc_img" 143 | ] 144 | }, 145 | { 146 | "cell_type": "markdown", 147 | "metadata": {}, 148 | "source": [ 149 | "### Setup the encoder network" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": 5, 155 | "metadata": {}, 156 | "outputs": [], 157 | "source": [ 158 | "class Encoder(nn.Module):\n", 159 | " def __init__(self, z_dim, hidden_dim):\n", 160 | " super(Encoder, self).__init__()\n", 161 | " # setup the three linear transformations used\n", 162 | " self.fc1 = nn.Linear(784, hidden_dim)\n", 163 | " self.fc21 = 
nn.Linear(hidden_dim, z_dim)\n", 164 | " self.fc22 = nn.Linear(hidden_dim, z_dim)\n", 165 | " # setup the non-linearities\n", 166 | " self.softplus = nn.Softplus()\n", 167 | "\n", 168 | " def forward(self, x):\n", 169 | " # define the forward computation on the image x\n", 170 | " # first shape the mini-batch to have pixels in the rightmost dimension\n", 171 | " x = x.reshape(-1, 784)\n", 172 | " # then compute the hidden units\n", 173 | " hidden = self.softplus(self.fc1(x))\n", 174 | " # then return a mean vector and a (positive) square root covariance\n", 175 | " # each of size batch_size x z_dim\n", 176 | " z_loc = self.fc21(hidden)\n", 177 | " z_scale = torch.exp(self.fc22(hidden))\n", 178 | " return z_loc, z_scale" 179 | ] 180 | }, 181 | { 182 | "cell_type": "markdown", 183 | "metadata": {}, 184 | "source": [ 185 | "### Packaging it all together" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": 6, 191 | "metadata": {}, 192 | "outputs": [], 193 | "source": [ 194 | "class VAE(nn.Module):\n", 195 | " # by default our latent space is 2-dimensional\n", 196 | " # and we use 400 hidden units\n", 197 | " def __init__(self, z_dim=2, hidden_dim=400, use_cuda=False):\n", 198 | " super(VAE, self).__init__()\n", 199 | " # create the encoder and decoder networks\n", 200 | " self.encoder = Encoder(z_dim, hidden_dim)\n", 201 | " self.decoder = Decoder(z_dim, hidden_dim)\n", 202 | " self.z_dim = z_dim\n", 203 | "\n", 204 | " # define the model p(x|z)p(z)\n", 205 | " def model(self, x):\n", 206 | " # register PyTorch module `decoder` with Pyro\n", 207 | " pyro.module(\"decoder\", self.decoder)\n", 208 | " with pyro.plate(\"data\", x.shape[0]):\n", 209 | " # setup hyperparameters for prior p(z)\n", 210 | " z_loc = x.new_zeros(torch.Size((x.shape[0], self.z_dim)))\n", 211 | " z_scale = x.new_ones(torch.Size((x.shape[0], self.z_dim)))\n", 212 | " # sample from prior (value will be sampled by guide when computing the ELBO)\n", 213 | " z = 
pyro.sample(\"latent\", dist.Normal(z_loc, z_scale).to_event(1))\n", 214 | " # decode the latent code z\n", 215 | " loc_img = self.decoder.forward(z)\n", 216 | " # score against actual images\n", 217 | " pyro.sample(\"obs\", dist.Bernoulli(logits=loc_img).to_event(1), obs=x.reshape(-1, 784))\n", 218 | " #pyro.sample(\"obs\", dist.Bernoulli(loc_img).to_event(1), obs=x.reshape(-1, 784))\n", 219 | "\n", 220 | " # define the guide (i.e. variational distribution) q(z|x)\n", 221 | " def guide(self, x):\n", 222 | " # register PyTorch module `encoder` with Pyro\n", 223 | " pyro.module(\"encoder\", self.encoder)\n", 224 | " with pyro.plate(\"data\", x.shape[0]):\n", 225 | " # use the encoder to get the parameters used to define q(z|x)\n", 226 | " z_loc, z_scale = self.encoder.forward(x)\n", 227 | " # sample the latent code z\n", 228 | " pyro.sample(\"latent\", dist.Normal(z_loc, z_scale).to_event(1))\n", 229 | "\n", 230 | " # define a helper function for reconstructing images\n", 231 | " def reconstruct_img(self, x):\n", 232 | " # encode image x\n", 233 | " z_loc, z_scale = self.encoder(x)\n", 234 | " # sample in latent space\n", 235 | " z = dist.Normal(z_loc, z_scale).sample()\n", 236 | " # decode the image (note we don't sample in image space)\n", 237 | " loc_img = self.decoder(z)\n", 238 | " return loc_img\n", 239 | "\n", 240 | " def sample_images(self, dim=10):\n", 241 | "\n", 242 | " plt.figure(figsize=(dim, dim))\n", 243 | " gs1 = gridspec.GridSpec(dim, dim)\n", 244 | " gs1.update(wspace=0.025, hspace=0.05) # set the spacing between axes.\n", 245 | "\n", 246 | " z_1 = norm.ppf(np.linspace(0.00001, 0.99999, dim), loc=0, scale=1)\n", 247 | " z_2 = norm.ppf(np.linspace(0.00001, 0.99999, dim), loc=0, scale=1)\n", 248 | " for j in range(dim):\n", 249 | " for i in range(dim):\n", 250 | " x_val = self.decoder.forward(torch.tensor([z_1[i], z_2[j]], dtype=torch.float32))\n", 251 | " plt.subplot(gs1[i*dim+j])\n", 252 | " plt.axis('off')\n", 253 | " 
plt.imshow(x_val.detach().numpy().reshape((28, 28)), cmap=\"gray_r\")\n", 254 | " plt.show()" 255 | ] 256 | }, 257 | { 258 | "cell_type": "markdown", 259 | "metadata": {}, 260 | "source": [ 261 | "### Setup training (single epoch)" 262 | ] 263 | }, 264 | { 265 | "cell_type": "code", 266 | "execution_count": 7, 267 | "metadata": {}, 268 | "outputs": [], 269 | "source": [ 270 | "def train(svi, train_loader):\n", 271 | " # initialize loss accumulator\n", 272 | " epoch_loss = 0.\n", 273 | " # do a training epoch over each mini-batch x returned\n", 274 | " # by the data loader\n", 275 | " for x, _ in train_loader:\n", 276 | " # do ELBO gradient and accumulate loss\n", 277 | " epoch_loss += svi.step(x)\n", 278 | "\n", 279 | " # return epoch loss\n", 280 | " normalizer_train = len(train_loader.dataset)\n", 281 | " total_epoch_loss_train = epoch_loss / normalizer_train\n", 282 | " return total_epoch_loss_train" 283 | ] 284 | }, 285 | { 286 | "cell_type": "markdown", 287 | "metadata": {}, 288 | "source": [ 289 | "### Perform learning" 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": 8, 295 | "metadata": {}, 296 | "outputs": [ 297 | { 298 | "name": "stdout", 299 | "output_type": "stream", 300 | "text": [ 301 | "[epoch 000] average training loss: 561.9964\n", 302 | "[epoch 100] average training loss: 20.7505\n", 303 | "[epoch 200] average training loss: 19.0819\n", 304 | "[epoch 300] average training loss: 18.7468\n", 305 | "[epoch 400] average training loss: 18.1215\n", 306 | "[epoch 500] average training loss: 17.7618\n", 307 | "[epoch 600] average training loss: 17.4973\n", 308 | "[epoch 700] average training loss: 17.4621\n", 309 | "[epoch 800] average training loss: 17.3345\n", 310 | "[epoch 900] average training loss: 17.1521\n" 311 | ] 312 | }, 313 | { 314 | "data": { 315 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAZAAAAEKCAYAAAA8QgPpAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAG+dJREFUeJzt3X20HHWd5/H3p/s+hCSYBEEICTFBgoqKyNwB3Z3dReUhsLNmncEZWM+CDzuMu6DjHOe4MOwZXBk8KrqOOsgxMug6w4isD0OWzRgBgZ2zZ5UEB0NAkMuDkgyuQRTEwM293d/9o359U7fTVZ1bSafvvfm8zmno+lV11a+6bvpTv/rVgyICMzOz6ar1uwJmZjY7OUDMzKwSB4iZmVXiADEzs0ocIGZmVokDxMzMKnGAmJlZJQ4QMzOrxAFiZmaVDPS7AtMlaQ3waaAOXBcRHy2a9vDDD4+VK1ceqKqZmc0J99xzz1MRcUS36WZVgEiqA9cAZwDbgE2S1kfEA52mX7lyJZs3bz6QVTQzm/Uk/Xhvpptth7BOAUYj4tGI2AXcCKztc53MzA5Ksy1AlgFP5Ia3pTIzMzvAZluAdCXpIkmbJW3esWNHv6tjZjZnzbYA2Q4ckxtensomRcS6iBiJiJEjjujaB2RmZhXNtgDZBKyWtErSEHAesL7PdTIzOyjNqrOwImJC0iXARrLTeK+PiPv7XC0zs4PSrAoQgIjYAGzodz3MzA52sy5AzPKazeyRzM0IovX/9JTm1vvWuAiIorLc9NE+31ZZ633us/llkC+bUpegGVOnD6Lj5yWh3PopNzDRCMYmmtSUTVcT1CSU/r9rokm9nn2+GYGkrOLA8+ONyfqNTTQ4dN4gEUEzgkYTGhGT381AXUw0AgkazamPvG7VJwLGG00kMVSvUauJF8YbNJtBvSZ2NZoM1JTqmdW1GTBYT/WsabLurXXevS6w6JAhdjWaPPfCBL98fhfzh+oM1Gr8emyC8UaztVqT27q1TfJlkM2rlurRbMbkdzVQExPNYCKtQ6fvG9i9LSTqEhPNJmL399KMbBn1mhhvBI20/o1mMDRQm/L3mP8m2x8lvnua6FC253oND9SoCcYbwc5dDYYHa5PrK7Lv8cgXzePfvPZoeskBMgM0m8HO8Qa/HpvgubEJfrlznB2/GmNsosGzz4/zwniTZmQ/Qs/vmmCs0WRsvMmuRnPKD2Kz7Qep/f+taVo/gtP6DLnPNKf+EObnMdFsAtk/0vZ/FM1mNo9G+tFqzW+gXmO80aTRiMkfPonJf6jjzdjjh6ylqNzsYHfSMYsdILPFrokmjz71HD995gWeem4XP33meZ4ba/D8rgl27mqwc7zBzrEJfj3W4Ne7JlJYNNiZxu8tCQbrNYbrNYYH65N7bDVpyt5o+/9b00AarrXGK+21TJ1HvSYG017i5GfU9pncchHUJeq1bPpmBILJvTul/9Rbe6Q1Ua9le0sTzWCwLgZqtck9wFYgDdREvZ7tMU7dN4cXxhsMDdQYHtj9PShX38ky2vd02R1SbXu/rfe01hWmfJdqL4MpgZf/zulQ1voeRPZdteq4e193zz3QllpNDA/U0g5Ae0snC95GM3hhvMEhg3WAyb3hhfMGJr+/8UaT59M0tbTN6rXdfx+NZrZT0GgG84fqU1odeYMDNSIi7Xk3GR6oM95oMlCrMTxYY6KZ1au1w1GTGJvItllRq64ZwUQjeOq5MRYODzA0UOP5XQ0WzR+k2YShAbHokKHJ75u2v7FWXYUmW3mNtKNSlyaX0Whm9Rms7z6PaI9WQdt2GG80J6dv/Z3XJRqRtWQG6jXqEi9MNBisZ9/N7r+ryYpNUlvRlH8rufGt9WmVNSMYG89aYYN1MW+wzthEk7o05W+irTHVEw6Qiv7pl8/zvcd+zl0P7eCHT/6KR3Y8x0Tb3vBgXcwfGmD+UJ1DhuosHM7eL100j/lDAywYHmDBUJ0FwwP
ZuOFsmhfNG+SIQ4eZN1hj0SFDDA/WJn9AhweywwVmNjMtYrDny5g/NHU4H4QHkgNkGhrN4K//7+Ns2fYMG7Y+yQvjTZbMH+R1K5bw5le+hJcfdSjLlxzCixcMc9SiecxLe4FmZnORA2QavvH9bXzof2b3bTzuJQu5+twTeeXSFzkozOyg5ADZC3f9aAdbnvglt2x5kuOPXMhf/ruTWbpoHofO631T1cxspnKA7IU/v+UBHv7ZcwB87u0nc/yRh/a5RmZm/ecA6eLqjQ9Ohsff/odT+WfHHd7nGpmZzQyz7V5YB9S1dz7CNXc8AsB733Scw8PMLMcBUuJj33pw8v3hC4f7WBMzs5nHAbKXzjjhyH5XwcxsRnGAFMhflfr+01dz9OJD+lgbM7OZxwFSoHXzOYA/+BfH9rEmZmYzkwOkwHNjEwBcufZVLBj2yWpmZu0cIAV2jmUtEIeHmVlnDpACp33iTgDmDzlAzMw6cYB0sdAtEDOzjhwgXSwY9o0Szcw6cYB04RaImVlnDpAu5jtAzMw6coB0sWDIh7DMzDpxgHThQ1hmZp3517HAicsXcdiCIQb69KxhM7OZzr+OJdTvCpiZzWAOEDMzq8QBUiB3M14zM+vAAVJC8kEsM7MifQkQSW+TdL+kpqSRtnGXSRqV9JCks3Lla1LZqKRLe13HwE0QM7My/WqBbAV+B/jf+UJJJwDnAa8C1gCfk1SXVAeuAc4GTgDOT9P2lNsfZmbF+nIab0T8EDoeIloL3BgRY8BjkkaBU9K40Yh4NH3uxjTtA72rY6/mbGY2N8y0PpBlwBO54W2prKh8D5IukrRZ0uYdO3bsU2XcBWJmVqxnLRBJtwFHdRh1eUTc3KvlRsQ6YB3AyMiI2xFmZj3SswCJiNMrfGw7cExueHkqo6S8J3wIy8ys3Ew7hLUeOE/SsKRVwGrgbmATsFrSKklDZB3t63tfHR/DMjMr0pdOdElvBT4LHAH8L0n3RsRZEXG/pJvIOscngIsjopE+cwmwEagD10fE/b2soxsgZmbl+nUW1jeBbxaMuwq4qkP5BmBDj6s2hTvRzcyKzbRDWDNGuBPEzKyUA6SEGyBmZsUcIGZmVokDxMzMKnGAlHAnuplZMQdIAfehm5mVc4CUkLvRzcwKOUAK+HkgZmblHCAl3AdiZlbMAWJmZpU4QAq4E93MrJwDpIQPYZmZFXOAFHADxMysnAOkhE/jNTMr5gAp4LvxmpmVc4CUcQPEzKyQA8TMzCpxgBTwASwzs3IOkBI+gmVmVswBUsRNEDOzUg6QEvKVhGZmhRwgBdwAMTMr5wAp4faHmVkxB4iZmVXiACngK9HNzMo5QEq4D93MrJgDpIDbH2Zm5foSIJKulvSgpC2SvilpcW7cZZJGJT0k6axc+ZpUNirp0gNSzwOxEDOzWapfLZBbgVdHxInAj4DLACSdAJwHvApYA3xOUl1SHbgGOBs4ATg/Tdsz7gIxMyvXlwCJiG9HxEQa/C6wPL1fC9wYEWMR8RgwCpySXqMR8WhE7AJuTNP2lC8kNDMrNhP6QN4F/H16vwx4IjduWyorKjczsz4Z6NWMJd0GHNVh1OURcXOa5nJgArhhPy73IuAigBUrVlSeT7gb3cysVM8CJCJOLxsv6R3AbwNvjt0XXWwHjslNtjyVUVLevtx1wDqAkZGRfUoBH8AyMyvWr7Ow1gAfBN4SETtzo9YD50kalrQKWA3cDWwCVktaJWmIrKN9fS/r6E50M7NyPWuBdPGXwDBwa+qo/m5EvCci7pd0E/AA2aGtiyOiASDpEmAjUAeuj4j7e15LN0HMzAr1JUAi4riScVcBV3Uo3wBs6GW9pi7vQC3JzGx2mglnYc1YchPEzKyQA8TMzCpxgJiZWSUOkBK+EN3MrJgDpICfB2JmVs4BUsINEDOzYg4QMzOrxAFSwAewzMzKOUBKuBPdzKyYA6SA+9DNzMo5QEr4SnQzs2IOkAJ+HoiZWTkHSAn3gZiZFXO
AmJlZJQ6QAu5ENzMr5wAp4UNYZmbFHCAF3AAxMyvXNUAkvVrSlyVtTq//LunEA1G5/nMTxMysSGmASFoLfBO4E3hXet0FfD2Nm7PcB2JmVq7bM9E/DJwREY/nyrZI+g5wc3rNWe4DMTMr1u0Q1kBbeACQygZ7USEzM5sdugXIhKQV7YWSXgpM9KZKM4WPYZmZlel2COsK4DZJHwHuSWUjwKXAf+5lxWYCH8EyMytWGiAR8XeSHgM+ALw3FT8A/F5E/KDXlesnd6KbmZXr1gIhBcUFB6AuM4470c3MinU7jfdwSVdIep+khZKulbRV0s2SjjtQlewHN0DMzMp160T/W2AYWA3cDTwGnAvcAlzX26r1n58HYmZWrFuAHBkRfwq8D1gYER+PiAcj4gvA4qoLlXSlpC2S7pX0bUlHp3JJ+oyk0TT+5NxnLpT0cHpdWHXZZma2f3QLkAZARATwVNu45j4s9+qIODEiTiJrzfxZKj+brLWzGrgIuBZA0mFkZ4SdCpwCXCFpyT4sv6twL7qZWalunejHSlpPdkZr6z1peFXVhUbEs7nBBezuclgLfDkF1nclLZa0FDgNuDUingaQdCuwBvhK1TrsDXeim5kV6xYg+ftdfaJtXPvwtEi6iuzsrmeAN6biZcATucm2pbKi8p5x+8PMrFy360DuKhon6atkN1YsGn8bcFSHUZdHxM0RcTlwuaTLgEvIDlHtM0kXkR3+YsWKPS6in9689keFzMzmqK7XgZR4Q9nIiDh9L+dzA7CBLEC2A8fkxi1PZdvJDmPly+8sWO46YB3AyMhI5YaEu0DMzMr15YFSklbnBtcCD6b364EL0tlYrweeiYgngY3AmZKWpM7zM1NZr+vZ60WYmc1apS2Q/Gm07aPYt7vxflTSy8nO5Pox8J5UvgE4BxgFdgLvBIiIpyVdCWxK03241aFuZmb90e0Q1idLxj1YMq5URPxuQXkAFxeMux64vuoyp8un8ZqZlevWif7GsvFmZnbw6nYvrA/m3r+tbdxHelWpmcDtDzOzct060c/Lvb+sbdya/VyXGcd96GZmxboFiAredxqeW9wEMTMr1S1AouB9p+E5x3fjNTMr1u0srNdKepastXFIek8antfTmpmZ2YzW7Sys+oGqyEwz55tXZmb7qC9Xos8W7kQ3MyvmACngCwnNzMo5QEq4AWJmVswBUsDtDzOzcg6QEu4DMTMr5gAxM7NKHCAF3IduZlbOAVLCD5QyMyvmACkQ7kY3MyvlACnh9oeZWTEHSAH3gZiZlXOAlHETxMyskAPEzMwqcYAU8BEsM7NyDpASfqCUmVkxB0gRN0HMzEo5QEr4OkIzs2IOkAK+kNDMrJwDpIQbIGZmxRwgZmZWSV8DRNIHJIWkw9OwJH1G0qikLZJOzk17oaSH0+vCXtfNV6KbmZUb6NeCJR0DnAn8JFd8NrA6vU4FrgVOlXQYcAUwQnZ+1D2S1kfEL3pbx17O3cxsdutnC+RTwAeZesLsWuDLkfkusFjSUuAs4NaIeDqFxq3Aml5Wzg0QM7NyfQkQSWuB7RHxg7ZRy4AncsPbUllReU/5QkIzs2I9O4Ql6TbgqA6jLgf+lOzwVS+WexFwEcCKFSt6sQgzM6OHARIRp3cql/QaYBXwg/TEv+XA9yWdAmwHjslNvjyVbQdOayu/s2C564B1ACMjI5WPRIV70c3MSh3wQ1gRcV9EvCQiVkbESrLDUSdHxE+B9cAF6Wys1wPPRMSTwEbgTElLJC0ha71s7HVd3YluZlasb2dhFdgAnAOMAjuBdwJExNOSrgQ2pek+HBFP97Iibn+YmZXre4CkVkjrfQAXF0x3PXD9AaoW4CvRzczK+Er0Au4CMTMr5wAp404QM7NCDhAzM6vEAWJmZpU4QEr4AJaZWTEHSAe+iNDMrDsHSAn3oZuZFXOAdOAGiJlZdw6QEr4br5lZMQeImZlV4gDpwEewzMy6c4CUcCe6mVkxB0gHPo3XzKw7B0g
JN0DMzIo5QDpw+8PMrDsHSAn3gZiZFXOAmJlZJQ6QDtyHbmbWnQOkhHwMy8yskAOkg3A3uplZVw4QMzOrxAHSgftAzMy6c4CUcBeImVkxB4iZmVXiADEzs0ocICX8QCkzs2IOkA7ciW5m1p0DpIQ70c3MivUlQCR9SNJ2Sfem1zm5cZdJGpX0kKSzcuVrUtmopEt7WT9fSGhm1t1AH5f9qYj4RL5A0gnAecCrgKOB2yQdn0ZfA5wBbAM2SVofEQ/0soJugJiZFetngHSyFrgxIsaAxySNAqekcaMR8SiApBvTtD0NEDMzK9bPPpBLJG2RdL2kJalsGfBEbpptqayofA+SLpK0WdLmHTt2VKqYO9HNzLrrWYBIuk3S1g6vtcC1wMuAk4AngU/ur+VGxLqIGImIkSOOOGKf5uVOdDOzYj07hBURp+/NdJK+ANySBrcDx+RGL09llJTvd26AmJl116+zsJbmBt8KbE3v1wPnSRqWtApYDdwNbAJWS1olaYiso319z+vpbnQzs0L96kT/uKSTyHb2Hwf+ECAi7pd0E1nn+ARwcUQ0ACRdAmwE6sD1EXF/ryoX7gQxM+uqLwESEf++ZNxVwFUdyjcAG3pZr3buAzEzK+Yr0c3MrBIHSAc+gGVm1p0DxMzMKnGAdOA+dDOz7hwgJeRedDOzQg6QTtwCMTPrygFSwu0PM7NiDhAzM6vEAdKBHyhlZtadA6SE+9DNzIo5QDrwabxmZt05QEq4AWJmVswB0oEbIGZm3TlASvhCQjOzYg4QMzOrxAHSgR8oZWbWnQOkhI9gmZkVc4B0MDRQ41+/ZikrDpvf76qYmc1Y/Xom+ox26LxBrnn7yf2uhpnZjOYWiJmZVeIAMTOzShwgZmZWiQPEzMwqcYCYmVklDhAzM6vEAWJmZpU4QMzMrBLN5fs+SdoB/HgfZnE48NR+qs5s4XWe+w629QWv83S9NCKO6DbRnA6QfSVpc0SM9LseB5LXee472NYXvM694kNYZmZWiQPEzMwqcYCUW9fvCvSB13nuO9jWF7zOPeE+EDMzq8QtEDMzq8QB0oGkNZIekjQq6dJ+12d/kXSMpDskPSDpfkl/lMoPk3SrpIfT/5ekckn6TPoetkiatQ9JkVSX9I+SbknDqyR9L63bVyUNpfLhNDyaxq/sZ72rkrRY0tckPSjph5LeMNe3s6Q/Tn/XWyV9RdK8ubadJV0v6WeStubKpr1dJV2Ypn9Y0oVV6+MAaSOpDlwDnA2cAJwv6YT+1mq/mQA+EBEnAK8HLk7rdilwe0SsBm5Pw5B9B6vT6yLg2gNf5f3mj4Af5oY/BnwqIo4DfgG8O5W/G/hFKv9Umm42+jTwrYh4BfBasnWfs9tZ0jLgfcBIRLwaqAPnMfe285eANW1l09qukg4DrgBOBU4BrmiFzrRFhF+5F/AGYGNu+DLgsn7Xq0frejNwBvAQsDSVLQUeSu8/D5yfm35yutn0Apanf1hvAm4BRHaB1UD7Ngc2Am9I7wfSdOr3OkxzfRcBj7XXey5vZ2AZ8ARwWNputwBnzcXtDKwEtlbdrsD5wOdz5VOmm87LLZA9tf4QW7alsjklNdlfB3wPODIinkyjfgocmd7Ple/iL4APAs00/GLglxExkYbz6zW5zmn8M2n62WQVsAP4Yjpsd52kBczh7RwR24FPAD8BniTbbvcwt7dzy3S3637b3g6Qg5CkhcDXgfdHxLP5cZHtksyZU/Mk/Tbws4i4p991OYAGgJOBayPidcCv2X1YA5iT23kJsJYsPI8GFrDnoZ4570BvVwfInrYDx+SGl6eyOUHSIFl43BAR30jF/0/S0jR+KfCzVD4Xvot/DrxF0uPAjWSHsT4NLJY0kKbJr9fkOqfxi4CfH8gK7wfbgG0R8b00/DWyQJnL2/l04LGI2BER48A3yLb9XN7OLdPdrvtteztA9rQJWJ3O3hgi64hb3+c67ReSBPwV8MOI+G+5UeuB1pkYF5L1jbTKL0h
nc7weeCbXVJ4VIuKyiFgeESvJtuV3IuLtwB3AuWmy9nVufRfnpuln1Z56RPwUeELSy1PRm4EHmMPbmezQ1eslzU9/5611nrPbOWe623UjcKakJanldmYqm75+dwjNxBdwDvAj4BHg8n7XZz+u12+RNW+3APem1zlkx35vBx4GbgMOS9OL7Iy0R4D7yM5w6ft67MP6nwbckt4fC9wNjAL/AxhO5fPS8Ggaf2y/611xXU8CNqdt/XfAkrm+nYH/CjwIbAX+Ghiea9sZ+ApZH884WUvz3VW2K/CutO6jwDur1sdXopuZWSU+hGVmZpU4QMzMrBIHiJmZVeIAMTOzShwgZmZWiQPEZhVJIemTueE/kfSh/TTvL0k6t/uU+7yct6U75N7RVn60pK+l9ydJOmc/LnOxpP/UaVlmVTlAbLYZA35H0uH9rkhe7mrnvfFu4A8i4o35woj4p4hoBdhJZNfo7K86LAYmA6RtWWaVOEBstpkge1TnH7ePaG9BSHou/f80SXdJulnSo5I+Kuntku6WdJ+kl+Vmc7qkzZJ+lO6j1XqWyNWSNqXnKvxhbr7/IGk92VXP7fU5P81/q6SPpbI/I7ug868kXd02/co07RDwYeD3Jd0r6fclLUjPgrg73SBxbfrMOyStl/Qd4HZJCyXdLun7adlr0+w/Crwsze/q1rLSPOZJ+mKa/h8lvTE3729I+pay50Z8PPd9fCnV9T5Je2wLOzhMZ6/JbKa4BtjS+kHbS68FXgk8DTwKXBcRpyh7qNZ7gfen6VaSPSPhZcAdko4DLiC7DcRvShoG/o+kb6fpTwZeHRGP5Rcm6WiyZ0z8BtlzKL4t6d9GxIclvQn4k4jY3KmiEbErBc1IRFyS5vcRstttvEvSYuBuSbfl6nBiRDydWiFvjYhnUyvtuyngLk31PCnNb2VukRdni43XSHpFquvxadxJZHdtHgMekvRZ4CXAssieu0Gqjx2E3AKxWSeyOwh/mewBQntrU0Q8GRFjZLd2aAXAfWSh0XJTRDQj4mGyoHkF2b2CLpB0L9nt719M9pAegLvbwyP5TeDOyG7uNwHcAPzLadS33ZnApakOd5LdimNFGndrRDyd3gv4iKQtZLe1WMbu23sX+S3gbwAi4kHgx0ArQG6PiGci4gWyVtZLyb6XYyV9VtIa4NkO87SDgFsgNlv9BfB94Iu5sgnSTpGkGjCUGzeWe9/MDTeZ+u+g/d4+Qfaj/N6ImHLDOUmnkd0q/UAQ8LsR8VBbHU5tq8PbgSOA34iIcWV3IZ63D8vNf28Nsocz/ULSa8ke2PQe4PfI7q1kBxm3QGxWSnvcN7H7EaUAj5MdMgJ4CzBYYdZvk1RL/SLHkj3FbSPwH5XdCh9Jxyt7QFOZu4F/JelwZY9JPh+4axr1+BVwaG54I/BeSUp1eF3B5xaRPf9kPPVlvLRgfnn/QBY8pENXK8jWu6N0aKwWEV8H/gvZITQ7CDlAbDb7JJA/G+sLZD/aPyB7fGmV1sFPyH78/x54Tzp0cx3Z4Zvvp47nz9Ol9R7ZbbMvJbud+A+AeyLi5rLPtLkDOKHViQ5cSRaIWyTdn4Y7uQEYkXQfWd/Ng6k+Pyfru9na3nkPfA6opc98FXhHOtRXZBlwZzqc9jdkj322g5DvxmtmZpW4BWJmZpU4QMzMrBIHiJmZVeIAMTOzShwgZmZWiQPEzMwqcYCYmVklDhAzM6vk/wMAnrJKV3jAxQAAAABJRU5ErkJggg==\n", 316 | "text/plain": [ 317 | "
" 318 | ] 319 | }, 320 | "metadata": {}, 321 | "output_type": "display_data" 322 | } 323 | ], 324 | "source": [ 325 | "vae = VAE(z_dim=2, hidden_dim=400)\n", 326 | "\n", 327 | "# Run options\n", 328 | "LEARNING_RATE = 1.0e-2\n", 329 | "\n", 330 | "# Run only for a single iteration for testing\n", 331 | "NUM_EPOCHS = 1000\n", 332 | "\n", 333 | "#train_loader = setup_data_loader(batch_size=300)\n", 334 | "\n", 335 | "# clear param store\n", 336 | "pyro.clear_param_store()\n", 337 | "\n", 338 | "# setup the optimizer\n", 339 | "adam_args = {\"lr\": LEARNING_RATE}\n", 340 | "optimizer = Adam(adam_args)\n", 341 | "\n", 342 | "# setup the inference algorithm\n", 343 | "svi = SVI(vae.model, vae.guide, optimizer, loss=Trace_ELBO())\n", 344 | "train_elbo = []\n", 345 | "# training loop\n", 346 | "for epoch in range(NUM_EPOCHS):\n", 347 | " total_epoch_loss_train = train(svi, train_loader)\n", 348 | " train_elbo.append(-total_epoch_loss_train)\n", 349 | " if (epoch % 100) == 0:\n", 350 | " print(\"[epoch %03d] average training loss: %.4f\" % (epoch, total_epoch_loss_train))\n", 351 | "\n", 352 | "plt.plot(range(len(train_elbo)), train_elbo)\n", 353 | "plt.xlabel(\"Number of iterations\")\n", 354 | "plt.ylabel(\"ELBO\")\n", 355 | "plt.show()" 356 | ] 357 | }, 358 | { 359 | "cell_type": "markdown", 360 | "metadata": {}, 361 | "source": [ 362 | "### Plot the data in the embedding space" 363 | ] 364 | }, 365 | { 366 | "cell_type": "code", 367 | "execution_count": 9, 368 | "metadata": { 369 | "scrolled": true 370 | }, 371 | "outputs": [ 372 | { 373 | "data": { 374 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXYAAAD8CAYAAABjAo9vAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJztnX90VdW1778rh0ACUuIPfMEEBarlKRKITVuftqMKGvTFH8hVrLwxfG1v67D3dpRr76PFp0OC145StbX4et94A2+9tqNqi0ijbeogVuzzV2kNgiClWG+qmAiP8CNiIYH8WO+Pk5Occ7J/77X3Wnuf72cMB2Znn73XPifnu+aac645hZQShBBC0kOZ7gEQQghRC4WdEEJSBoWdEEJSBoWdEEJSBoWdEEJSBoWdEEJSBoWdEEJSBoWdEEJSBoWdEEJSxjgdNz3jjDPkjBkzdNyaEEISy9atWw9KKae6nadF2GfMmIH29nYdtyaEkMQihHjPy3l0xRBCSMqgsBNCSMqgsBNCSMqgsBNCSMqgsBNCSMqgsBNCSMqgsBNCSMqgsBNCSBh2rAceuhBorsr+u2O97hHp2aBECCGpYMd64FffAPp7sz9/+H72ZwCoW6ptWLTYCSEkKC/cOyrqOfp7s8c1QoudkCLm/PtcCDH6s5TAri/t1DcgYi4fdvo7HhO02AnJIyfqxf/N+fe5uodGTGRKrb/jMUFhJySPnJC7HSMEALDwHqC8svBYeWX2uEYo7IQQEpS6pcC1DwNTpgMQ2X+vfVhr4BSgj50QQsJRt1S7kBdDi52QPKTM/ud2jBCTobATkseuL+0cEfL8/5gVQ5IEXTGEFEERJ0mHFjshhKQMCjshhKQMCjshhKQMCjshhKQMCjshhKQMCjshhKQMCjshhKQMCjshhKQMJcIuhHhUCHFACPGWiusRQggJjiqL/TEAVym6FiGEuGNgr1FTUFJSQEr5khBihoprEUKIK4b2GjWF2HzsQojbhBDtQoj27u7uuG5LiBJaO1rRuKERdT+pQ+OGRrR2tOoeUmljaK/RMWhaVcQm7FLKdVLKBillw9SpU+O6LUkJOoW1taMVza81Y9+xfZCQ2HdsH5pfa6a468TQXqMF5FYVH74PQI6uKmIQd2bFEOPRLaxr31iLvsG+gmN9g31Y+8baWO5PLDC012gBGlcVFHZiPLqFdf+x/b6OkxgwtNdoARpXFarSHZ8E8HsAs4UQnUKIv1dxXUIA/cJaPana13ESA4b2Gi1A46pCVVbMLSquQ4gV1ZOqse/YPsvjcbD8ouVofq25YNVQkanA8ouWx3J/YoOBvUYLWHhPYeYOENuqgq4YYjzLL1qOikxFwbE4hbVpVhOuP/d6lIns16VMlOH6c69H06ymWO5PEorGVQVb4xHjyQno2jfWYv+x/aieVI3lFy2PTVhbO1rxzDvPYEgOAQCG5BCeeecZ1J9ZT3EnzmhaVQipof16Q0ODbG9vj/2+hAShcUOjpSto2qRpaLuxTcOISKkihNgqpWxwO48WOyEWtHa0jqwQJKyNH2bFpIAd67Pphx92ZoOaC+8x22/vEQo7IUXk8uaLUyyLYVZMwomjLIGmiYPBU0KKsMqbL4ZZMSkg6g1E3HlKiDk4uVgEBKZNmobmS5oZOE06UW8g0rjzlK4YQoqwy5tnsDRlTKkdtqYtjqsg6TtPCfFCUiok6s6bJzERdVmCpO88JVlatnXhgU178EFPL86qqsSKRbOxuL4mtff1Q3FAMlfIC4BxLg2VefP52TVx598TF3JBzKiCmxp3njKPXREt27pw58ad6O0fHDlWWZ7Bd5fMjVRkdd3XL6WYC26VXVORqQjln79vy3146u2nMCSHUCbKcNMnbsLdF9+tashENYqzYpjHHjMPbNpTIK4A0Ns/iAc27YlUYMPeNy5rX3chLx04VaUMIuz3bbkPv9jzi5Gfh+TQyM8Ud0PRtPOUPnZFfND
T6+u4CffNWftdPb2QALp6enHnxp1o2daleJSlWSFR9WT21NtP+TpOShcKuyLOqqr0ddyE+zpZ+6opxYCk6sksV6vG63FSulDYFbFi0WxUlmcKjlWWZ7Bi0Wxj7xvnKqNpVhOaL2nGtEnTtOaC37flPsz76TzM/clczPvpPNy35b7I7qV6MstVl7QirmciyYA+dkXk/NJxZ6eEue9ZVZXoshDxqFYZTbOatGaE2PqoD76Du//SXhDgaj1lUuhsFtVVKW/6xE0F488nv/Ik/e6EWTElTFIyalQx76fzLN0WZVLizXdHN6q0fqwKzWecjj7ZP3IsTDZL2JTH/NdXZCrQN9gHCYkyUQYppWWRsjJRhjdvfdP3WInZeM2KoSumhFlcX4PvLpmLmqpKCAA1VZWpFXXAwUdd9PPaj00sEHUgeI/VsI24i1/fO9iLCZkJWPO5NXjz1jdtK0/S717a0BVT4iyur0mtkBdTJsqsLfain/ePy4w5BwiWzRI25dHt9bbP5OCPJ+mHnz4pGW76xE1jD0qJm45+VHCoemBw7HkIls0SNuXR7fWWz+RwnJQGFHaDadnWhUvXbMbMla24dM3mSPLLS4m7L74bN8++uaB36c1nNODuj04UnLf86HFUiPKCY0GzWcKmPLq93vKZZt/MwGmJw+CpoaQxsGls3RSLbd8qsmKAsZk4gL9AbBRlCUqOFHVJYkmBhKOrREFQ3EoTGF0EzGLbdxPCjyvXBLuY68+93vO1dTfyTgROwh1HlyQDobBHTNBaLG6bh0yq6Fi8usiVJgBG8+xV103xis5Vgl0nppc6X/J1HdX5/yb97YTGTbidml2kWNjpY/eBX593mFosTqUC4qzx4gUvpQl0FAELm2oYFhMLn5n2txMaty5FGptd6ITC7pEgX4gwtVicSgX4vW7UQVgvpQnCBBGDNuhwWiXEgd2zCSG0NRmJsz5QLLgJd5zNLnasBx66EGiuyv4bQ29TOyjsHvHzhcgJqdV2fcBbLRanzUN+arzEYaG5FSJr2daFI51XQA4VZpqUiwk40nmF44Tjxeq2E37dFrNVrRggu3kozpVDPrqqkEaGm3BH3SUph8bG1VZQ2D3i9QuRL6R2eKnF4uQH9VPRMQ4LzWl1kXs/uvfPQd++JRg6WQUpgUpxBvr2LUH3/jmOE46b1e0k/LpLBecKn1ltFopz5ZBP6CqkBlmlANyFu24pcO3DwJTpAET232sfVu9f19i42goKu0e8fiGshDQfL5UX3axsPxUd47DQnFYX+e/HwNF6HPuPlfjbn9fg0O4VOH5kXsF1rCYcN6vbSfhNKBXcNKsJdinFOnztoaqQGmaVAvAm3HVLgTveApp7sv9GETQ1zJfPrBiPrFg02zKvvPgL4SSYNR4zENxSHZ0qOhZb+lUTy3HkeP+Ye6iu4GhXmsDu/Ri0Ebvi86snVVu21MtZ3U7Cb0qqoNszhMVP5k+oKqSmZpho6lJUwJTa4QnP4rgGKOwe8fqFsCuFW1NViVdXLvB0Ly9WtpWQWqUdlpcJlGcE+gdHhTSOOvE57N6PjBCW4l484Sy/aLnlBp2c1e0mmrpLBQPuzxCGIPsDAtcHMswqNQqNjautoCvGB4vra/DqygX465omvLpygeWXQ0XDjaB+UCtLv39IYtL4cdoqONq9H7d8Zrqn98mtQYcJ7hY3/DYZ8ZPFFGvmT5wZJkkjLl++R2ixK0ZFww2vbp9i7Cz9D3v7sX1Vo+f7q8Tp/Wg45zRP75OT1R3U3RL3xiWvKwcvm73yiTXzxzCr1DhMcAkNw1oxhhJkd6BdiqUfN1ApYHL9Fb+fYeOGRktX1LRJ09B2Y5v6Aaao7koSYa2YhBPEDxrU0i81dJU38ILfLKYo/feWGGSVEnso7ClCV9/VpKF745ITfvvQmpL5Q8yCwp4CUlXUKQaiTD8M+1kEWXWZkPlDzIJZMQkndUWdYiCqTBo/n4Vd5kup9aEl0aAkeCqEuArAWgAZAP8
mpVzjdL7JwdOkWb8MmAYjiqwYr59F3E1UkvY3TeyJLXgqhMgA+FcAVwLoBPC6EOJZKeWfwl47bvymmplA6oo6xUQU7guvn0WcTVSS+DdNwqPCFfNpAO9IKTuklCcB/BzA9QquGzs6SpqGLakbuqiTB9h71RteP4s4J+PUlek1rQiZoagQ9hoA+UUSOoePJY64rV8V/nEVO12jHmOp4PWziGMyzpGqFZ2OImQJnUhiC54KIW4TQrQLIdq7u7vjuq0v4vzCtWzrwj+vfzO0NRV1sM0Eiy8pKwavn0XUk3E+cf5NR07cpXFNrGbpERXpjl0Apuf9XDt8rAAp5ToA64Bs8FTBfZUT1wafnBXstcKhG4GLOnlAt8Vngo/YT/DRy2cR534D4zathdm5GncRMlOrWXpAhbC/DuA8IcRMZAX9CwCWKbhu7MT1hXOr2W6SNeV3w4xq4gw0WhHVxFI8AeRWJar/7ozatObWeNqNuEvjJriaZWhhl1IOCCG+DmATsumOj0opd4UemQVxpG1Faf3mcLJ2TSsBoNvi071iiGNiiXpVEsfftCfCWsBWRcgA4OSx7KSh2oo2rMa6H5TsPJVS/gbAb1Rcyw4TluSqcKpRbsJmlOIJ9O8+WYMX/9wdm8WXf/8yj3Xb7cYedqxxTCyBJo8kFuMKawHnnu+5bwO9h0eP9x72Z/l7JcHVLBNTUkD3klwldlawKaJePIE+vbUrtrEV399K1O1WDFZjX/HUm1j9q13oOd7vKPTFE8Ll/3kqXvxzN+yCQSpdUb4nj7AuDV2osIDrlmYntHxhB6LxfeeulbQJFAkSdt1LcpUY5fcswm4CXf2rXbGM1y7+kBECQ1I63tuu0UiuNaDdKs9qQvjZlr22Y1TtivIdx0hqUE+VBRyn7zuh1SwTI+y6g3iqidPv6cc9YTdRHjne7yqQKrC7/5CU+Osa552iXiZ5q1WeWzA7H699a/3gO46R1KCeKgs4wb7vuEhMEbA4c3/ThN8NRl4nyqhy2cPkXXsde/EEoGrV19rRisYNjaj7SR0aNzSitaPV0+t870VIcou6uqXAHW8BzT3Zf4NYwwvvyVr6+STE9x0XibHYTXZfRIWKQKDf2ISV9WhHFG6wMFk4XsdePAHYrQatyE2M7e8dLggoN366C7/+4GFfTaXz8bWCS3BQT0nQN8G+77hIjLADBqVtxYCqLCC/sQmrCfTYiQH09PaPOTcKN1iYCXxxfQ3a3zuMJ//wvu3mr/IyMWaS8DOZAdmJMd8H39XTi/Ud61BWHlNXJlOFzU20VQZ9g/i+k5hJFJBECXspoSoLKEhswmrzTJy57EEn8JZtXXh6a5etqAPAKRXZP/nizUDfXTLXMivGqyUvxvVYHo+sK5NpQT0voq0z6JvUTKKAUNgNRVUWkIoNRklxg3kJgh453m+5EvrukrmW9evtaqwXI/urIMaPFXcVXZkSgRfR1hn0TWomUUAo7IaiKgtIlSgnwQ3mZdLLCBE45jDuY9swYeomiPIeyP4qnOhehIGj9QCAE92LUDFtI0TZqMsq0qbSpuFFtHVmsyQ1kyggFHZDUbmVPwmirIKqieUjKZlWVJZnbC16t5jDd/7v4+idMircYnwPKqZtRB+AgaP1GDhaj7LxGVTPeLE0m0p7Ee2F9wDP/CMweHL0WGZ8PEHfEkuRTEy6Y6nB3pf+aNnWhb/1Ddj+Pvf+1QRIp1xcX4NTa39bYI0DgCjrx4SpmwAA5RmBexfeirYb27Djv+9A241t8Yq67rrh5zV6O14c/1DQmtMTJZYiSYvdYErB0lZV2+WBTXvQP2QtEj+8eX7BNYOshOyCoKK8J5JNS74wITD4lzb34y/cCwwVraiG+uPxc5uaSRQRFHaiDZWF3Zz86/n+86Axh+pJ1dh3bN+Y42edMg1tDk3DY2kkbUJg0IsPW7ef27RMogihsGvm7padI3nXGSFwy2em477FcyO/rwmd65uf3aWssJvTJqNi0Q+yElp
+0XI0v9Y8sgEJcA+Oqpq4XD8r3YIJ2PuwK091Pyelfm6d0MeukbtbduJnW/aO5F0PSomfbdmLGRG3gDOhj2nLti7LTU9AsB2tKxbNhrD5nYqNVE2zmtB8STOmTZoGAYFpk6ah+ZJmRz+6iraCnj6rqEoMuPnt839/8hhQlhl7jRMfjb7uvEag+FNKsZ9bJ7TYNfLkHyysl2GiLLRlQglkJ3ELIsS5XaePb9lbUGpX5UaqpllNvgKiKvYiePqsoigx4Oa3L/59cRndHDkfOgC8+QRQ8OkIYN6yknGPxAktdo047ZAE3K27oE2eTSiB7HSvoEJ83+K5eOjm+cZkEqloJO3ps6pbClz7MDBlOgCR/ffah8MJplvjaKvf2/Fhp8350j7o6gXdmUAGQ4tdIxmb7kD52H2xw/hvTSiBbDeGUyeWK+0lqhPPexEcaph4/qxUBwbd/PZ+/PdTatXHAUzIBDIYWuwaueUz013PsRPbMP5bE0ogW41BILvlP8r4Qpx42ouQE6gP3wcgRwVq2PrU9lm5+e29+u9zLiHVcQC3FUWJQ4tdI7nsF7tqhE5f4DDuFBNqv+SPoaunFwKj3tfI+9lGVeXP4rqL65c6P4NLqqK2z8rNb2/XWDofkSl0CamMA5iQCWQwQsa18yuPhoYG2d7eHvt9TcdPCqJdcaqaqkrLYlYmE+uzFC/hgazAhPVJB71ucxVg2VlVZJtR6MRLGd4X7h1ebeRPzbB+dpUT6kMX2qROTs828EgpQoitUsoGt/NosRuEH/+wyloyuok1mBvVZp6g1/WS262rjrib3z7/93GPMcnNRmKAwp5QTHCnqCLWYG5US/ig13UTqKQECd0mAdXPUWIlAvxCYU8wJmWAhCHW1UfY3Y92lmnQ67oJlAnlAlQQxXOUUIkAv1DYiXZiXX2EWcI7WZ1hruskUBGsMLSUk2CwM1Yo7MQIYlt9hFnCO1mduYCdateA4voqKguv+aLyVOvdqfm1ZIgyKOyk9Ai6hHezOqNwDSgOEm5vXYfnxc9w1oSD+ECegfsHluLZ/s/GWk6CRA+FnRCv6KhOGHSFYRULAPCt/v+NiWXZDka14iDWlP8b0A/8quezztfIWda9R4KtRnqP+DtOQkFhJ8QrYa3noCmBflcCdrGAcZWYKE4WnDpRnMS3xq3H1olXOl8j340SJKOFJXtjhSUFiBaCFjDTSphiWy6lA8acG6a4lV0swKYC41ni0NgMJLciX36375dYazrd0GInsaMtgKcCj9ZzcebJ8+IeTLQQ2+PP3YMrf3PGyHk/vOAv+NTOVeHyvX1mmvRNrB77vnu5hp/7MO88VijsCcOEzkdOeBmfCfXgo8Rq4qqYsH9MjwkAqDi+H10nekfOO2vr/YAIme9t283oNGCgd4wraeLVFpa33TWKz3HDyf2UW5lQ6JVDV0yCMKHzkRNex2dCPfgosZq4PpCnW55bfHwaDlpf1I91bOf2uPp73l1JVtcovp6bG8XJ/eTHNUV8Q2FPECparUWJ1/GpaEBhMlYT1P0DS3Fcji84dlyOx/0DhaL6gTzD+qJ+goxOsYC6pdmc++ae7L92FnLxNSpPy/7nJ7bglPdv97vnvu39OYktFPYEYbqlazeOrp7eAqvdhHrwUWI1QT079FncX/4PBWJ7f/k/4NmhwjTD+weWohcTCl8cJMjoVcC9Mn5S1uL3cz2nvH+73/UeptWuAAp7gjDd0nUaR75LxlMDigRjN3HNb7qtQGznN9025rznM5/HWxf9i9o2d/l4zbhR4Spxaq7htALZ+FW2ugsJ67Frxk8wtDgoB2QFwxRRtBpfPkmsFR8Ur5+rp/PcApBeM0381Iz3U+/cbgxO9wOyAu6Eijr5KcNrPXYKu0aCCHUSsmL+6RfbLX8nAPx1TVO8A4qCOGuPu4mjn+YefsTaawMQt8nC6b363kzb3HrHsZUwbLSRAIKk/eku1es2sSyurxlpd1eMk6vG9AlrhLjro7v19vR
TCtdPhUWvO0WdgqD5gr5k3dgxXf099/Z6rP4YiFA+diHETUKIXUKIISGE6yxCCjE9GFqM13RGv8FR09M4C4i7ibKtGL9vn2du9xo/DaW97hR1CoK6+ecLMm9sYMmBQIQNnr4FYAmAlxSMpeQwPRhajNd0Rr/BUdPTOAuIu654EGGze42fbf1eyyd4HZ/d5JfL3lnyCEsOKCSUK0ZKuRsAhLDYUkcAOLsYkta31M8Kw4/LKFErl7iLWVkVHnPCSQz9buv3Uj7Bz/icJj+WHFBKbD52IcRtAG4DgLPPPjuu22rFrSZK0vqWRtWbNNaep2GJu4lyseBZBjSHmTLdXQxV14y3EuSTx6yDoqIMaJ4CiAwgB8eOl63ulOGaFSOE+C2Aaotf3SWlfGb4nN8B+B9SSk+pLqWSFXPpms2WgpXUtD+vWTx+A6Fu1zUusBpnVkwxfjJbnIjyGawyZexgSqMvlGXFSCmvUDOk0iNRLgYPeFlhBKncuLi+Bu3vHcaTf3gfg1IiIwT+7pM1I6JuXCVInZalihVD1Jk9xVa8KMta6FYksTF3AmC6Y4QkysXgETffuV0gdPWvdjnm5j+9tQuDw6vHQSnx9NYuNJxzWuorQfpGhS/aKbNHlcDmT37NVc7nMqVROWHTHW8QQnQC+C8AWoUQm9QMKx2kvSaKFXarkSPH+23TF53EO22rHiWErQNjWmYPUxqVE0rYpZS/lFLWSiknSCn/k5RykaqBpYG010Sxwmk1Ype+6CTeoVNCw3YjSiN+8tlVsPAeoKzc+ndl5UxpjAAWAYuYxfU1eHXlAvx1TRNeXbnACFGPsi2d02rETsCdxDvUqoc1v62Ju01d3VJgwmTr302YTP96BFDYS4yod3kurq9BVaW1dWYn4E7iHWrVE/cu0TD8+pvA6tOy6YCrT8v+HBVhercGpfeIv+MkFAyehsC4NDwPxBGMbL5ujq+NV27ZNoHr48TtS3bDLsXw198E2n88ep4cHP35mh+ovVfxcasaLlEQ98auEofCHhAj0/A8EEcw0u/Gq8gmSJPE5NffBNofxcgGo/wUw62PWb9m62PBhN0unXHvFuDNJ+IrYJZP3Bu7ShwKe0CSmoYXVwqmVys70gnSFDHZsb5Q1HPk3EJ2Od75x/1sKLJzQW19bOy94sojZ8mAWKGwBySpaXim1aeJdII0RUxeuBe2pQA+7BzdYl+MGI47+N1QZOdqsptA4nJNsWRAbFDYA5LUzUem1aeJfII0QUychHNKLXBeY6GPPccnv5j91++GIjsXlN0EQj936qCwB8Q0y9cPupt15JPUCdIXdkILUbiCyLlKRAaY8VngL20OnYxgP2HYuaDmLSv0seeO51xTOmvgEKUw3TEgpbj5KApKYneuVd44BNDw5VHhvOYHwKrDQPOHwA3/B+j842j+vR12lrZdOuM1P7BPc2TOf6pgz1OinSSmjfrGjzVsV8ExH9VVEVVVjSSRwp6nJDEUuIZ2rAde+AbwTMrcAX58/Y7BTBHN+2Jazj8JBYWdmINT/vVf2krH92ubfx+h9WxSzj8JDX3sxBzssj/aHy0t36/XWi4qC5zFXT+GRAqFnZiD7bLfZmNPWvFSy0V1sFNH/RgSGQyeEjWoSJXzEjQcQWTrkevAhLRABjtLEq/BU1rsJDyqrEe7tEArdPl+TUkLZLCTOEBhJ+FRVR7Xyh3Q8GWzfL+mlAKOu1kGSRTMiiHhUWk9WqUFnn2xftdHDlMsZZUFzkxwLRGl0GIn4YnSeiwWnfMasz/ranVniqWsKthpimuJKIXBUxKe4vxzQM3OSKvrFqN6B2aQMcU9Br84WeQMwiYKBk9JfESVKmflzy4mbv920tIC3SxyU1xLRCn0sRM1RFEe16u4xC1CJpQC9opbyV/uOE0ltNiJuXgVlyAipHLXpsm4WeTccZpKKOzEXCzz2osIIkKlFDB0C/YmzbVEPEFhJ+Zimdf+98FFKGelb/yqGbnoceDFIq9bmg2
ULlmX/XnjbelexZQA9LETs1Hlz/aSYWN6wDBIvrnXvq9++6oSo6Gwk9LAS4aNCQFDO/EOI7xeJke/fVWJ0VDYSXLxY8G6WeM6AoZWm6/ye5Lmi3fUwsu0x1RBHztJJn4DoE7WuI6AodX42x+1F28n4VWR4WPKjlqiBAo7iZao0gr9FuOyCyIueSQbOIzb3WDpGrLZBZ6z6K2oPDW6yppMe0wsFHYSHVGmFfp1HZiW1ufHxZFzM1kJLxBdZU2mPSYW+thJdAT1C3vxnQfZMWnSjlG78Y9BFD5/8fuy8Tbrl6mqrEkSCS12Eh1BAnJerfykuw68bL4CAMhRsc3lmzf3jLqP6BsnFhhjsff396OzsxN9fX26h2IEFRUVqK2tRXl5ue6hBCeIVW1n5T/37UJr0mt+tqkUj1+UAXJw7HlTpjtfR2VddpIajBH2zs5OTJ48GTNmzIAQNu3QSgQpJQ4dOoTOzk7MnDlT93CCE0R07NwTvYezVnuxuOsSchXNKfLHb1cO2E2g65YCe7cAWx/LTgwiA8xblpwJjkSCMa6Yvr4+nH766SUv6gAghMDpp5+e/NWL34DcjvWw7XEKmLPlP4qgsNt7ZZddtGN9Nvc9Z+3LwezPLAdQ0hhjsQOgqOeRmvfCj1X9wr2wTfkDzNksE9VmIbv3ymnXKXeMEguMsdhNIJPJYP78+ZgzZw7mzZuH73//+xgaGgIAtLe34xvf+IbrNS655BIAwLvvvosnnnjC9V7z58/Hddddp+YBko6bcJsSEIx7l6aTeHPHKLEglMUuhHgAwLUATgL4DwBfklL2qBiYDiorK7F9+3YAwIEDB7Bs2TIcPXoUq1evRkNDAxoaXDtS4bXXXgMwKuzLli1zvZdWTGpk7JQCGGVA0O97EHdzCifxZqMMYkFYi/15ABdKKesAvA3gzvBD8kbLti5cumYzZq5sxaVrNqNlW5fS65955plYt24dfvSjH0FKid/97ne45pprAADd3d248sorMWfOHHzlK1/BOeecg4MHDwIATjnlFADAypUr8fLLL2P+/Pl46KGHlI5NGabVJbdLAaw8LbrNMkHeg7hTLZ1SGpOe9kkiIZSwSynbpJQDwz9uARCLmdAbdHFMAAAKj0lEQVSyrQt3btyJrp5eSABdPb24c+NO5eI+a9YsDA4O4sCBAwXHV69ejQULFmDXrl248cYbsXfv3jGvXbNmDT73uc9h+/btuOOOO8b8vq+vDw0NDbj44ovR0tKidNye8bstP2qsAohLHgG+/dfoVhFB3oO4d2k6iTd3jBILVAZPvwzgFwqvZ8sDm/agt78w57e3fxAPbNqDxfU1kd//lVdewS9/+UsAwFVXXYVTTz3V9zXee+891NTUoKOjAwsWLMDcuXPx8Y9/XPVQnTHRPxt3CmPQ9yDOcbrl7HPHKCnCVdiFEL8FUG3xq7uklM8Mn3MXgAEAjztc5zYAtwHA2WefHWiwOT7osa6rbXc8KB0dHchkMjjzzDOxe/dupdeuqclOQLNmzcJll12Gbdu2RSfsdj5k+meT8x5QvIkPXF0xUsorpJQXWvyXE/UvArgGwH+TUtrmqkkp10kpG6SUDVOnTg016LOqrLdi2x0PQnd3N26//XZ8/etfH5N6eOmll2L9+qwPtq2tDUeOHBnz+smTJ+Ojjz6yvPaRI0dw4sQJAMDBgwfx6quv4oILLlA29gKcfMj0z/I9IKkklI9dCHEVgG8BuE5KeVzNkNxZsWg2KsszBccqyzNYsWh2qOv29vaOpDteccUVaGxsxKpVq8act2rVKrS1teHCCy/EU089herqakyePLngnLq6OmQyGcybN29M8HT37t1oaGjAvHnzcPnll2PlypXRCbtbnnPc/tmoyvgGJe73wLTnJ6lEOBjZ7i8W4h0AEwAcGj60RUp5u9vrGhoaZHt7e8Gx3bt34/zzz/d875ZtXXhg0x580NOLs6oqsWLR7Fj86wBw4sQJZDIZjBs3Dr///e/
xta99LZLURb/viSXNVbDe9COyxaTixG7bfKkE+0r9+UlohBBbpZSuedehgqdSynPDvD4Mi+trYhPyYvbu3YulS5diaGgI48ePxyOPPKJlHJ4wyYdc6rskS/35SWwYVVIgKZx33nnYtm2b7mF4w6TqfyZm4cRJqT8/iQ2WFEg7JuU5l3rt8FJ/fhIbtNhLAVNS5UxaPeig1J+fxAYtdhIfJq0edFDqz09igxY78Y7q5hKlSKk/P4kFWux5xFm2d+/evWhsbMT555+PCy64AO+++66SZ4gM0wqGEUJsobDnkSulu2vXLjz//PN47rnnsHr1agBAQ0MDHn74YddrFJfttePWW2/FihUrsHv3bvzxj3/EmWeeWXjC8cPA0Q/M2ciiu2AYN/YQ4pnkCnvEX/Qoy/b+6U9/wsDAAK688sqR10ycOHH0hOOHsxbx0ACMsY51pupxtUCIL5Ip7DF90aMq2/v222+jqqoKS5YsQX19PVasWIHBwbxqlR/tA+RQ4QV1ltMF9Kbq6V4tEJIwkinsmr/or7zyCr7whS8ACFa2d2BgAC+//DIefPBBvP766+jo6MBjjz02esLgSesX6tzIorNYFjf2EOKLZAp7TF/0/LK9KqmtrcX8+fMxa9YsjBs3DosXL8Ybb7wxekJmvPULdW5k0Zmqx409hPgimcIewxc9yrK9n/rUp9DT04Pu7m4AwObNmwurO06eBoiij8aEjSx1S4E73soWD7vjrfjS9lhalxBfJFPYI/qix1W2N5PJ4MEHH8TChQsxd+5cSCnx1a9+dfSEiadlLeKyceBGFnBjDyE+CVW2NygqyvYq2SwTkESV7SWEpIZYyvZqReMOvkSV7SWElBzJFXaNJKpsLyGk5Eimj50QQogtFHZCCEkZFHZCCEkZFHZCCEkZFPY84irb++KLL2L+/Pkj/1VUVKClpUXdgxBCShpmxeSRK9sLAAcOHMCyZctw9OhRrF69Gg0NDWhocE0fHVO2d9myZWPOufzyy0fuc/jwYZx77rlobGxU+CSEkFImsRZ7a0crGjc0ou4ndWjc0IjWjlal14+ybG8+GzZswNVXX11YtpcQQkKQSGFv7WhF82vN2HdsHyQk9h3bh+bXmpWLe1Rle/P5+c9/jltuuUXpuAkhpU0ihX3tG2vRN9hXcKxvsA9r31gby/3Dlu3NsW/fPuzcuROLFi1SOTwSBHZoIikikcK+/9h+X8eDElXZ3hzr16/HDTfcgPLy8kiuTzzCDk0kZSRS2KsnVfs6HoQoy/bmePLJJ+mGMQF2aCIpI5HCvvyi5ajIVBQcq8hUYPlFy0NdN66yvUA2a+b999/H5z//+VBjJgpghyaSMhKZ7tg0qwlA1te+/9h+VE+qxvKLlo8cD0pB39EiLrvsMlx22WUAgClTpmDTpk0jZXtff/11TJgwAQDwt7/9DQBQXl6OzZs3215vxowZ6OrqCjVeoogptcNuGIvjhCSQRAo7kBX3sEIeFJbtTRkL78n61PPdMezQRBJMYoVdJyzbmzJydf01NW4hRDUUdkIArY1bCFGNUcFTHW36TIXvBSEkKMYIe0VFBQ4dOkRBQ1bUDx06hIqKCveTCSGkCGNcMbW1tejs7ER3d7fuoRhBRUUFamuZlUEI8Y8xwl5eXo6ZM2fqHgYhhCQeY1wxhBBC1EBhJ4SQlEFhJ4SQlCF0ZKEIIT4CsCf2G8fHGQAO6h5EhKT9+YD0PyOfL5mcI6Wc6naSruDpHimle5+5hCKEaOfzJZu0PyOfL93QFUMIISmDwk4IISlDl7Cv03TfuODzJZ+0PyOfL8VoCZ4SQgiJDrpiCCEkZWgTdiHEvwghdgghtgsh2oQQZ+kaSxQIIR4QQvx5+Bl/KYSo0j0mlQghbhJC7BJCDAkhUpN9IIS4SgixRwjxjhBipe7xqEYI8agQ4oAQ4i3dY4kCIcR0IcSLQog/Df99huuXmVB0WuwPSCnrpJTzAfwaQNra1TwP4EIpZR2AtwH
cqXk8qnkLwBIAL+keiCqEEBkA/wrgagAXALhFCHGB3lEp5zEAV+keRIQMAPhnKeUFAC4G8I8p/Axd0SbsUsqjeT9OApAqZ7+Usk1KOTD84xYAqSrVKKXcLaVM2yazTwN4R0rZIaU8CeDnAK7XPCalSClfAnBY9ziiQkq5T0r5xvD/fwRgN4AavaOKH63VHYUQ3wFwK4APAVyucywR82UAv9A9COJKDYD8rtadAD6jaSwkJEKIGQDqAfxB70jiJ1JhF0L8FkC1xa/uklI+I6W8C8BdQog7AXwdwKoox6Mat+cbPucuZJeHj8c5NhV4eT5CTEQIcQqApwH8U5F3oCSIVNillFd4PPVxAL9BwoTd7fmEEF8EcA2AhTKBeaU+Pr+00AVget7PtcPHSIIQQpQjK+qPSyk36h6PDnRmxZyX9+P1AP6sayxRIIS4CsC3AFwnpTyuezzEE68DOE8IMVMIMR7AFwA8q3lMxAdCCAHgxwB2Syl/oHs8utC2QUkI8TSA2QCGALwH4HYpZWqsIyHEOwAmADg0fGiLlPJ2jUNSihDiBgD/C8BUAD0AtkspF+kdVXiEEP8VwA8BZAA8KqX8juYhKUUI8SSAy5Ctfvj/AKySUv5Y66AUIoT4LICXAexEVlsA4H9KKX+jb1Txw52nhBCSMrjzlBBCUgaFnRBCUgaFnRBCUgaFnRBCUgaFnRBCUgaFnRBCUgaFnRBCUgaFnRBCUsb/B6MGGqDCqMoYAAAAAElFTkSuQmCC\n", 375 | "text/plain": [ 376 | "
" 377 | ] 378 | }, 379 | "metadata": {}, 380 | "output_type": "display_data" 381 | } 382 | ], 383 | "source": [ 384 | "for x, x_l in train_loader:\n", 385 | " z_loc, z_scale = vae.encoder(x)\n", 386 | "\n", 387 | "legends = [\"Digit 5\", \"Digit 6\", \"Digit 7\"]\n", 388 | "z_loc = z_loc.detach().numpy()\n", 389 | "for idx, i in enumerate([5,6,7]):\n", 390 | " plt.scatter(z_loc[x_l.numpy()==i,0], z_loc[x_l.numpy()==i,1], label=legends[idx])\n", 391 | "plt.legend()\n", 392 | "plt.show()" 393 | ] 394 | } 395 | ], 396 | "metadata": { 397 | "kernelspec": { 398 | "display_name": "probabilistic.ai", 399 | "language": "python", 400 | "name": "probabilistic.ai" 401 | }, 402 | "language_info": { 403 | "codemirror_mode": { 404 | "name": "ipython", 405 | "version": 3 406 | }, 407 | "file_extension": ".py", 408 | "mimetype": "text/x-python", 409 | "name": "python", 410 | "nbconvert_exporter": "python", 411 | "pygments_lexer": "ipython3", 412 | "version": "3.7.0" 413 | } 414 | }, 415 | "nbformat": 4, 416 | "nbformat_minor": 2 417 | } 418 | -------------------------------------------------------------------------------- /Day3/.ipynb_checkpoints/solution_simple_model-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import numpy as np\n", 17 | "import torch\n", 18 | "from torch.distributions import constraints\n", 19 | "import matplotlib.pyplot as plt\n", 20 | "\n", 21 | "import pyro\n", 22 | "from pyro.distributions import Normal, Gamma, MultivariateNormal\n", 23 | "from pyro.infer import SVI, Trace_ELBO\n", 24 | "from pyro.optim import Adam\n", 25 | "import pyro.optim as optim" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": {}, 31 | "source": [ 32 | "## Generate some data" 33 
| ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 2, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "# Sample data\n", 42 | "np.random.seed(123)\n", 43 | "N = 100\n", 44 | "correct_mean = 5\n", 45 | "correct_precision = 1\n", 46 | "data = torch.tensor(np.random.normal(loc=correct_mean, scale=np.sqrt(1./correct_precision), size=N), dtype=torch.float)\n" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "metadata": {}, 52 | "source": [ 53 | "## Our model specification" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 3, 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "def model(data):\n", 63 | " gamma = pyro.sample(\"gamma\", Gamma(torch.tensor(1.), torch.tensor(1.)))\n", 64 | " mu = pyro.sample(\"mu\", Normal(torch.zeros(1), torch.tensor(10000.0)))\n", 65 | " with pyro.plate(\"data\", len(data)):\n", 66 | " pyro.sample(\"x\", Normal(loc=mu, scale=torch.sqrt(1. / gamma)), obs=data)" 67 | ] 68 | }, 69 | { 70 | "cell_type": "markdown", 71 | "metadata": {}, 72 | "source": [ 73 | "## Our guide specification" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": 4, 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "def guide(data=None):\n", 83 | " rate = pyro.param(\"rate\", torch.tensor(1.))\n", 84 | " conc = pyro.param(\"conc\", torch.tensor(1.))\n", 85 | " pyro.sample(\"gamma\", Gamma(rate, conc))\n", 86 | "\n", 87 | " mu_mean = pyro.param(\"mu_mean\", torch.tensor(0.))\n", 88 | " mu_scale = pyro.param(\"mu_scale\", torch.tensor(1.))\n", 89 | " pyro.sample(\"mu\", Normal(mu_mean, mu_scale))" 90 | ] 91 | }, 92 | { 93 | "cell_type": "markdown", 94 | "metadata": {}, 95 | "source": [ 96 | "## Do learning" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": 5, 102 | "metadata": {}, 103 | "outputs": [ 104 | { 105 | "name": "stdout", 106 | "output_type": "stream", 107 | "text": [ 108 | "[epoch 000] average training loss: 1443.1143\n", 109 | 
"[epoch 500] average training loss: 336.6232\n", 110 | "[epoch 1000] average training loss: 243.5829\n", 111 | "[epoch 1500] average training loss: 178.8034\n", 112 | "[epoch 2000] average training loss: 179.9048\n", 113 | "[epoch 2500] average training loss: 187.7421\n" 114 | ] 115 | } 116 | ], 117 | "source": [ 118 | "# setup the optimizer\n", 119 | "adam_args = {\"lr\": 0.01}\n", 120 | "optimizer = Adam(adam_args)\n", 121 | "\n", 122 | "pyro.clear_param_store()\n", 123 | "svi = SVI(model, guide, optimizer, loss=Trace_ELBO(), num_samples=10)\n", 124 | "train_elbo = []\n", 125 | "# training loop\n", 126 | "for epoch in range(3000):\n", 127 | " loss = svi.step(data)\n", 128 | " train_elbo.append(-loss)\n", 129 | " if (epoch % 500) == 0:\n", 130 | " print(\"[epoch %03d] average training loss: %.4f\" % (epoch, loss))" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": 6, 136 | "metadata": {}, 137 | "outputs": [ 138 | { 139 | "name": "stdout", 140 | "output_type": "stream", 141 | "text": [ 142 | "rate tensor(1.8184, requires_grad=True)\n", 143 | "conc tensor(2.1202, requires_grad=True)\n", 144 | "mu_mean tensor(5.0471, requires_grad=True)\n", 145 | "mu_scale tensor(0.0859, requires_grad=True)\n" 146 | ] 147 | } 148 | ], 149 | "source": [ 150 | "for name, value in pyro.get_param_store().items():\n", 151 | " print(name, pyro.param(name))" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": 7, 157 | "metadata": {}, 158 | "outputs": [ 159 | { 160 | "data": { 161 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAZcAAAEKCAYAAADenhiQAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3Xl8VNXdx/HPL5ONkJCFHUJI2AkiW0RUKousaotLrdi626Italu1FmpdirV1eWz7dLOl1aotrVprK3Upgnt9CghYREAFARFUdllEliTn+ePehEkyk0nCTGYyfN+vV17MnLudkwn3N2e555hzDhERkWhKiXcGREQk+Si4iIhI1Cm4iIhI1Cm4iIhI1Cm4iIhI1Cm4iIhI1Cm4iIhI1Cm4iIhI1Cm4iIhI1KXGOwPRYmaTgP8FAsDvnXN31rd/u3btXHFxcXNkTUQkaSxdunS7c659pP2SIriYWQD4FTAe2AS8bmZznXOrwh1TXFzMkiVLmiuLIiJJwczeb8h+ydIsNhxY65xb55w7BDwCTIlznkREjlnJEly6Ah8Evd/kp9VgZtPMbImZLdm2bVuzZU5E5FiTLMGlQZxzs51zZc65svbtIzYZiohIEyVLcNkMdAt6X+iniYhIHCRLcHkd6G1mJWaWDkwF5sY5TyIix6ykGC3mnCs3s6uBeXhDkR9wzq2Mc7ZERI5ZSRFcAJxzzwDPxDsfIiKSRMFFJNHs+vQQ727Zy7Du+ez+7DCHKirZ/dlhStq1JiM1QHlFJQCpgRS27j1A6/RUzOBwhSMnI5XHl25ieEkBua3SaJUeYM6ijbTJTCUrPZVAChyqcHTNy2TWU6tpn53OqL4dSE0xsjNSWbt1Hx9+8hkf7NpPflY6AwtzeXzpJsb378jhCkdhfis27fqMzZ/s53CFY83WvUz7XA9WfriHNq3SyEwL4Jzjz4s2MnV4Nx5bsolBhXnktkqjS14m2RmpvLdtH4EU46PdB1j54R4Gd8ujU5tM9h0sp3fHbOYs3MjZQ7qybe9BXlmzjbH9OrDnQDnOOT4/qAuPL93E/FVbOL4wl7c/3suh8koyUlM4WF7J+WVeF2paqvHBzs/o37kNr2/Yyda9BzhnSCHb9x3ktbXb2bX/MEUFWZhBVnqAnZ8eomObTNZv/5Rzhhby8+fXMLy4gDc+2MXhCsc1Y3vx50UbOaG4gD4ds1n54R7WbN3Hx7sPcFr/DvTumMOmnfurf99TT+jG4vU76dc5h/2HKth/qILt+w6S2yqN7gVZvPHBJ5wzpJCVH+5m38Fy3tq8m4kDOvHcqi2cUJzPgtVbaZUW4Nvje/OnhRsZ1j2f7IxUVn+0h4GFuaz8cA/rtn3K9n0HAWifk0F6IIWBXXNZt30f727Zx4TSjgzqlsdTb37E3gOHGdGjLbmt0piz6H2Gl7RlRI8CHl+6iXXbPuXCEUWs2bKPy04p5qV3tvGP/26mQ04mew4cZsqgLhw4XMmsswaQkRqI+d+/OedifpFEVFZW5vQQZcvinMM5SEkxlm3cRWZqgNIubQBYuG4HpV3a0CYzrcYxr2/YCcAJxQUcKq/kqw8vYWzf9rTLySAjNUCXvEz+9dbH/OKFtay4bQIP/d8G7v/3en5/yQnkZKby3MqP+Z/n3uXxq07ijmdWU5CVTlZGKleMLGHFpk+4+Um1vkrLs+HOM5p8rJktdc6VRdxPwUUaqryikk27PqO4XeuQ2w9XVJIWqDtG5NU12yjrXsDB8goOllfSPjuD8kpHeWUlWempbNt7kIPlFRTmZ3nflhdvpH12BuNLO7Lk/V10zMmkqG0Wp979Iht37q9x7tWzJnGovJJBs54DoE1mKtPH9GLHp4dwzvG7V9dH/xch0sK996PTCaRYk45taHBRs1iCeHbFRwztnk/HNpk10jfu2M+WvQc4obigwefaf6icWf9cxczJ/cnNSot8QAPdPe8dZr+yjtdmjKVTm0zKKyv
JSA2wbe9Bdn92mHE/eZlWaQE+O1wBwKDCXP7nvEFcdP9ivjiskH8u/5CD5ZW0y05n+75DAPzgCwO4dW7ob/+ZaSkcOOw1HVU1l9TW/5Z/1Xi/50A5P3727aiVWeLn4pO68/B/ws80svim0xh+x/NRu15OZip7D5Q3aN9zhnbliWUNf9rhguHd+MviD+qkf2N0T3Iy0yhoncZ3/7YCgAmlHXlu1ZYj1xrSlSfeqP9aFwwvYuueAzz/9lZunNSXu//1To3t3dtm8f6OI1/MDpZXkJUe29u/gkuUzF3+Ic452mdnMLR7Ptv2HmT/oQr6dsqJeGx5RSVfn7OM4rZZvPSdMRyuqKxuwz71nhcBWPmDibTOqPlxHSyvIDUlhRSDeSs/Zmy/jqSnpnDfS+/xyOsfEEgxzh1WyJ7PDjO6bweeWLaJ6x5bzmn9OvCtcX0YWJjL9Y8tJ8XgpjP6U17pmL9qC+NLO7L7s8Ns33uQ82cvBODmM0uZ/co6AEbe9QK1K7wFrdMBqgMLwPJNuxn/01cAeHzppur0qsAChA0sQHVg8cpaN7Acq746soTf/7tmjax3h2zWbN1XZ7/pY3ox5Pb59Z5v9axJdYJ0KF85sYgJAzrRtnU6xe1as3Lz7uq/j0enjeDEHm35ZP8hLrp/MTmZqdzy+VIm/ezVsOe7/JQS1m/fx4vveLNlnNqnPa+8672edmoPvnd6f2ZNOY5Z/1zFA6+t58pTe/Bb/29w/Y9Px+zIN++nrx1JflY6I+96gcoGNMbktkpj4czTqsudl5XGST3a8uxbH0c89tUbx9CtIIuigizWbNnH0ys+YtKATvxrZd1je3XI5sfnDGRoUT4HyytrBKSTe7bluvF9SPVr++1zMrj8wSWccXzn6uByy5mlXD6yJGRwyctK45P9hwH44VnH8Z3HlwPQLjuDJ6efQqv0ABP8/3/zvz2Kfy7/kHufe4eXbxwTsoUh2hRcmuCDnfvZtu8gQ4vy+en8dxnTrwPX/uWN6u2d2mTy8Z4DAHQraMUHOz/jkpO68+HuA/znvR289YOJACx9fxdb9xzgtP4dAdiwYz/PrviIr89ZVueaA26dx/JbJzDoB17zzwOXlnH5g0sY07c9l5xczFV/WkbXvFYM6NKm+g9zzqKNzFm0EYDfXjSM6x7z/vief3srz7+9lZ+dP5i/LfNu+n8NuvnPfGJFnevf/tSROUBDtaTu/PRQ3USJaGDXXFZs3g1AemoKhxoQRM8c1IWbzujPf9bt4Mu/WwRAp9zMGsFlTN/2fP/M0pDH33JmKbOCPs9W6QFev2kcJ9yxAICvfa6kTnPiazPG0jWvVY20E3u0rdN2n5eVzj+vGVn9Pvgbs5n3t1MVCC8fWUzb1hn8ZfFGZj21iq55mWy48wzWb/+UkqCm1xsn9aVLXiaTB3bmt6+sIy8rrUZgARjQJReAVbMm8e6WvXzhl68B8Oevnsiw4nzu//f66m/zOZmpLL91AgDfP6M/J5a0ZWBhLnsOHOZwhWPB6iO1hqrynf3r13hj4yf8/uIyuhVkAfCtcX04cLiCjm0yOWdo1+rg8u4PJ5MWMP66ZBOTBnaq7gf8yZcGc8uZpTz4fxv42YI19OqQXR1YAMb268iC606lV4cc1m7dxy9eWMulJxcDXnPvngPlLLhuFEs27GTGEyv45QVDufB+7/MPpBgnlhTwxLLNlLRrzaBueTV+P+mpKZw7rJBzhxXSXNTn0khvbNzF2b/+PwAuHFHEnxZubPQ5TunVlj2flVffVIK/sUlyyMlMZdH3TmPOwo3c8czq6vTfXDiUq/60rPr1pOM6M2fR+wzulseALrnMW/kxV/5xKcNLCli8fmfIcz91zUiO65rL0vd3ce593t/iou+dxp8XbWRg11y++vASzhtWyD3nDQLgn8s/5Br/y88j00Ywokdbimc8DcDrN42jfU4GAP95bwcOx8k921Vvr9LUDuCfLXiXny1YA8CC604lKz2VLnmtOFh
eUT1iqbyikt++so5LTy6uUzsP9vHuA4z48fN0yMlg8U3jAKrzWTt/1/7lDeYu/7D6Zh28b26rtOrgEsoLb2+p/uL2h8uGN6ic723bx2n3vkxJu9a8eMPoevetuofcf0lZ9RfLSA74LQJVo/je3bKPvp1yapTfOcfGnfvp3vZIYP73mu0crqxkTN8ODbpOQ6jPJUaqAgvQpMAC8NraHTXeK7DUr1eHbNbWavKJtatG9eQ3L7/X6OP+9a3PVTcFZaWn8tXPlTCydzuyM1IpaJ1e4+Y56bjOAHzlxO7VaRMHdKq+UU7/8zKefvMj7j1vENf/1at1fmdiXwb4I+SCazkd22Ty7fF9cM5x85mlnH/CkdmQPj+oC3/8z/ss3lA3WFUFFoCTerZtdHkjuWZsb8q6F5CTmVp9kwdqDIVNDaQwfUyvBp8zuNJyfGEu73y8t84+Pz5nIF8Y1KXGNQMpRkWlY0SP+vsvKyurrtPwDu9u+VkMKsxlxuT+EfcdUpTPqlkTG9XnkZl25PdlZtXN7d88rXd1c7WZ1QgsACN7t2vwNaJNwUUSXmoDR7WUdm7Dqo/2HNW1nr52JGu37mPTrs8ArwnnvguHMe4nL1fv89Dlw9m449Maw5BzMlLZe7Cc1JSabdlmRv/ObWqkzZjcj4FdcyPm5edTh3DveYPITAtUB5erRvWsvumVFefXOcbMuGJkSd2T1foV9mjXmh0NaMoc178DC1ZvjbhfOIEUi9oNrurPIDhAz716JKFaX1pnpDKutGatoHeHbN7+eC/fmdi33utU+udrzGCq9NQUnrx6ZOQdfdHqTP/2+D58e3yfqJwr2pJlbjFJYr/6ytAG7XfhiO413pe0a82ym8eH3T9U0OrfqQ1TBh9ZrWF8aUd6dciufv/fW8Yzqk97LjqpuDptw51nkOKfqyGB8KpRPTmlV+QbbiDFanxjhZo3vKPplH3++lG8Uc/vpsp9Fw5jzR2Tm3ydaOrQJpOZk/vxUK2mqobWMGZfVMb0MT3p2T673v2qBgSkNKLmInUpuEjC+MUFQ0Km92yfzYUjirh9ygB++eXQ+4A33DNYcdss8usZin3tab156HLvRnXFyBIeu/Kk6iBRpfZ34rys9HpKENRkE+WuzGHd8/3zN+2GV+YfX9UMZmZ1yhrsslOKAS9YNsfIooa6clTP6g71xipqm8V3JvaL+Ds8UnNRcDkaahaThJEWqPuf+Zun9Qbgh2cNBGDb3oM1tv/gCwMY0cPrKzAznrpmJGf+4t8AfH10rxo3kr4dc3hny5H2+RSDUX3a89INo8M+GFpl/rdPrbffp+oyVrv9KUoevOwEPvzkQJOPv35CX84Z2jXit/Yqt5xZys1nlDY5mLVk1cElcWJqi6Rfn8TE3eceX6evIZJQ3xRrtyfXDkCXnFxM30451R2cx3XNZUiRNwyz9hfuf0w/haXfH0e3Am9IbVUtJFRgOa+skAFd2nCR39TWu2MOkwd2Dpv3GZP6AZCdGZvvazmZaQ16ZiqcQIrV6NyOJFLNJplVdeEci4E1mlRzkZj40gndeHXtdlY3ooO9IZ2cqU1ooql6rqJVeoBW6QFeuH40jy35gKknFIU9pkNOJk9f+7l6z/vqjWPY4j/PNHV4EVOHF7H/kPeEd88ODashSOIZ7D8jcl4zPhOSjBRcJOp+er73fEVjv/ed0qst3z+jPz982nsu5OXvjK6zT6AJ3yb/ec3IGk/4pwVSagz/bapuBVl12v+z0lN5+PLhDRoNFg1//8bJdeZbk6PTrSDrqCZ2FI+axSSq/nDZCZw9xPvGF6pPe9nN4/nhWceFPNbM+OrnelS/rz1mH0L3y9T2v+cP4csnFjGo0PsGmpkWILdV9OZYi+TUPu3Jb11/x3+0DCnKrzG6TSRRKLjIUXsmqPko0pPABa3TQwad26cMaNC1UgMp9Q4vBm9U0I/OHtikJjQRiQ41i0lE4WYkrlK1pkqD+T2mJe1as377pwzqllfjuZF
lN4+vHrETSkEz1QpEpOn01U4iqpqjKlqqwkbrjNCr4RW0TqdddkbIbVVmTu7HH69o2LxPItL8VHORiE7qEX7OqaaMVq0e6nkUz4RcOapnk48VkdhTzSWJVE098vz1o2qk/7rW9CmPTBvRqPPWN0Drf6eGf2I+nKq5oPQYgUjyUnBpgb5y4pHnM84ZemSkUNXNOrvWtOW11+Fo7D29vv3rmyKjY07opq3avSmKMSLJR8ElQeTUs45FbY4jtZSCoLmuqp4obsrNukNOBqtmTQy5rb4nlesbkHVDmNlnaz8BfWyuKCSS3BRcEsSNk/s1eF/n4O/fOAWAsf2ODP2t7v+oFQvq1BRCBIvUFCMrPZWOberWNppac6k9o2/t/KjGIpK8FFxaiKtG9axe8tQ5x8DCXDbceQYnB03dHnyjXzjztOo5tGqrr6+jaknWUEpCzMHVlJlj1ecikvwUXFqI/p1z6OdPXBjuGZAjFRejU24m+X6TWe17eKh7enWTWoiNZvDiDaP5x/RT6mw7mpljFVtEkpeCS4KIdKM98/gu1Tf+ep4v9M4V4WSR1iQJpaRd65BTqJzcs/GrDGrWWZHkp+CSIOqLF8NLCgikWPVzIeH2TWlgh373tlk8Of2UkA8qNubZk5mT+4XtV6mPq1UChRiR5KPg0pJEqLmk1jOpY+3JIgd1yyM9xP7NUZlw1cvI+u9jf0kRaWYKLgmgXXZGvd/eaw8Cq/3Nv3q/eiJD7fXlAcorj5zHRWpr89X3tH5DHRktpjqLSLJScGkBrhhZAhxp9qodB166YTSvzRgb9vhwDy1WVDYsoARf7y/TRjDt1B7hd27M+RRbRJKW5hZLcMGLFlWNzKo9WizS+u/hHK6oO9Nxc3SyV9W8FFtEkldcai5mdp6ZrTSzSjMrq7VtppmtNbN3zGxiUPokP22tmc0ISi8xs0V++qNmllDzsedlRW+RquoO/UZ2UtQZiuwHkPIG1lxqG9GjAPAWqmqKI6PFmnS4iLQA8WoWews4B3glONHMSoGpwABgEvBrMwuYWQD4FTAZKAUu8PcFuAv4qXOuF7ALuKJ5itAwz37zc4zu2z7CXg27yVsDO8AbGjJ++eW6k04G3++r5ihrlV5zRNjYfh1587YJDC8paOCVajp7SFc652ZywfDwa9iLSMsWl+DinFvtnHsnxKYpwCPOuYPOufXAWmC4/7PWObfOOXcIeASYYt5X8LHA4/7xDwFnxb4EDdc5txUjotAJHixc53tDKwJV+43t17HOtguGd6t+ff2EPrz3o9NDDjeu70n+SLrkteI/M0+jqNb68yKSPBKtQ78r8EHQ+01+Wrj0tsAnzrnyWulJqTDfm86lf+dGrvzYCBedVMzFJ3kjywwINGXBFhE55sWsQ9/MFgCdQmy6yTn3ZKyuWx8zmwZMAygqSpwmmStGNmz01bDuBcy9+hSO65J7VNcLPcVLqGdeYhtYqmYKKI1hsBSR+IhZcHHOjWvCYZuBbkHvC/00wqTvAPLMLNWvvQTvHypPs4HZAGVlZQnx7N6/vzuGwvws/rTw/Qbtf3xhXoxzBO39J/ejORghlJJ2rfnb10/muK4KLiLJJtGGIs8F/mxmPwG6AL2BxXgtNL3NrAQveEwFvuycc2b2IvBFvH6YS4C41IrqE6inBlCYH/1+h8hzj9XNzw0T+1S/vmp0T7rmt+ILg7o0+Jq3TxnQpOa6Yd2bNuJMRBJbvIYin21mm4CTgKfNbB6Ac24l8BiwCvgXMN05V+HXSq4G5gGrgcf8fQG+C1xnZmvx+mDub97SRBbq6fhY+N7p/QmkWMgJJiM5e0hh9eu0QArnDC1sVLPYRScVU1bctNFjIpJ84lJzcc79Hfh7mG13AHeESH8GeCZE+jq80WQJq1V6gJmT+zGsez63PLmSVR/tqbNPNLo3zh1WyLnDCiPvKCISY4k2WixpXTmqJ2XFBVxycuhaTGMfjBQRSWQKLiIiEnUKLgkiFqN+e/hzjrX
OSLRxGyKS7HTXaWbNOc38j84ZyFlDutKrQ3azXVNEBFRzSWpZ6amM7tsh3tkQkWOQgkuCaMpa9CIiiUrBJUGUtGtN6/TGr0cvIpKIFFxi4MQmTkUvIpIsFFxi4OYzSxncLcwcYPX057f15/Q6oVhToohIy6bgEgNm8Mi0Efz3lvF1tqUHwv/K22Z7swTfOKlfzPImItIcFFxiJDMtUD2lfLAzj+/M9DE96z02Rev/ikgLp+ASA/U9y5IaSOE7E1UzEZHkpuASA6p4iMixTk/oJ6jfXjSMjTv2xzsbIiJNouCSoCYOCLVCtIhIy6BmMYmpf0w/hQXXnRrvbIhIM1PNJQbU53JE2Od9RCSpqeYSA80587GISCJScBERkahTcIkBNYuJyLFOfS4J5K5zj+eeee8wsGtuvLMSdbd9vpStew/GOxsi0kwUXBJIn445/O7ispidf9qpPZj9yrqYnb8+l55SEpfrikh8qFnsGPK90/uz4c4z4p0NETkGKLjEgLpcRORYp+ASA+rQF5FjnYKLiIhEnTr0Y6LlVF16tm9Ndob+DEQkunRXOcY9f/3oeGdBRJKQmsWOQlFBVsh09bmIyLFOweUonNSjbZ20oUV5YYOOiMixQsElyuZ8dQRpAf1aReTYprugiIhEXVyCi5ndY2Zvm9mbZvZ3M8sL2jbTzNaa2TtmNjEofZKfttbMZgSll5jZIj/9UTNLb+7yBFN/i4hI/Gou84HjnHPHA+8CMwHMrBSYCgwAJgG/NrOAmQWAXwGTgVLgAn9fgLuAnzrnegG7gCuatSQiIlJHXIKLc+4551y5/3YhUOi/ngI84pw76JxbD6wFhvs/a51z65xzh4BHgClmZsBY4HH/+IeAs5qrHCIiEloi9LlcDjzrv+4KfBC0bZOfFi69LfBJUKCqSm8R+nXKiXcWRERiImYPUZrZAqBTiE03Oeee9Pe5CSgH5sQqH7XyNA2YBlBUVBSjazRsv5U/mEhqQB00IpKcYhZcnHPj6ttuZpcCZwKnOeecn7wZ6Ba0W6GfRpj0HUCemaX6tZfg/UPlaTYwG6CsrMyF2685tNaUKyKSxOI1WmwScCPwBefc/qBNc4GpZpZhZiVAb2Ax8DrQ2x8Zlo7X6T/XD0ovAl/0j78EeLL5ytFcVxIRaVni9fX5l0AGMN/rk2ehc+4q59xKM3sMWIXXXDbdOVcBYGZXA/OAAPCAc26lf67vAo+Y2Q+BN4D7m7coNVkLmrRSRCRW4hJc/GHD4bbdAdwRIv0Z4JkQ6evwRpOJiEiCSITRYklFTWUiIgouIiISAwouRyFULUUVFxERBRcREYkBBZcoM3W6iIgouIiISPQpuESZ6i0iIgouIiISAxGDi5kdZ2YPm9kS/+chMzu+OTLXEqnLRUQkQnAxsynA34GX8KbGvxx4Gfibv01ERKSOSNO/zALGO+c2BKW9aWYv4E0Q2WyTRLYUGi0mIhK5WSy1VmABwE9Li0WGRESk5YsUXMrNrM6qWmbWHW/WYhERkToiBZdbgQVmdqmZDfR/LgOeA26JffYSz5CivHhnQUQk4dXb5+Kc+4eZrQeuB67xk1cBX3LOLY915hLRyT3b8sbGT/x36l8REQkl4noufhC5uBnyIiIiSSLSUOR2ZnarmV1rZtlmdp+ZvWVmT5pZ2AW/kplWmhQRiSxSn8uf8ZYjrlrLfj3eevVPAb+PbdZERKSlitQs1tE59z3zHt543zl3t5/+tplNj3HeRESkhYpUc6kAcM45YHutbZUxyVGCC35GUs9LioiEFqnm0sPM5uINi6p6jf++JKY5ExGRFitScAmeP+x/am2r/V5ERASI/JzLy+G2mdmjeJNYioiI1HA067mcFLVciIhIUtFiYY2kPnwRkcjqbRYzs6HhNqFZkUVEJIxIHfr31rPt7WhmREREkkekDv0xzZWRFkMPt4iIRBRpbrEbg16fV2vbj2KVKRERadkidehPDXo9s9a2SVHOi4iIJIlIwcX
CvA71XkREBIgcXFyY16HeH3OCo2uKQq2ISLVIo8UGmdkevPtoK/81/vvMmOYsQYWLIfOvG8WKTbubNS8iIomq3pqLcy7gnGvjnMtxzqX6r6veN/k5FzO73czeNLP/mtlzZtbFTzcz+7mZrfW3Dw065hIzW+P/XBKUPszMVvjH/NxfHiBmwp29Z/tszhrSNZaXFhFpMeL1hP49zrnjnXOD8RYeu8VPn4y3MFlvYBpwH4CZFQC3AicCw4FbzSzfP+Y+4GtBx2mggYhInMUluDjn9gS9bc2R/pspwMPOsxDIM7POwERgvnNup3NuFzAfmORva+OcW+ivOfMwcFZs8x7Ls4uIJIdIfS4xY2Z3ABcDu4GqhzW7Ah8E7bbJT6svfVOI9HDXnIZXI6KoqOjoCiAiImHFrOZiZgvM7K0QP1MAnHM3Oee6AXOAq2OVj2DOudnOuTLnXFn79u2bdA49oC8iElnMai7OuXEN3HUO8Axen8pmoFvQtkI/bTMwulb6S356YYj9Y8b0eI+ISERx6XMxs95Bb6dwZBLMucDF/qixEcBu59xHwDxggpnl+x35E4B5/rY9ZjbCHyV2MfBk85Wjua4kItKyxKvP5U4z6wtUAu8DV/npzwCnA2uB/cBlAM65nWZ2O/C6v98s59xO//U3gAeBVsCz/o+IiMRRXIKLc+7cMOkOmB5m2wPAAyHSlwDHRTWDIiJyVLQSpYiIRJ2Ci4iIRJ2CSyOpE19EJDIFFxERiToFl0ZSxUVEJDIFFxERiToFl0bSvJUiIpEpuDSSmsVERCJTcGmk4NFimmdMRCQ0BZdGCl7P5dQ+TZtZWUQk2Sm4NFJlUHAZX9oxfhkREUlgCi6N5NSlLyISkYJLI7VKCwAwsle7OOdERCRxKbg0Uf/OOfHOgohIwlJwaSLTJGMiImEpuIiISNQpuIiISNQpuIiISNQpuDSRcxqSLCISjoJLI6kfX0QkMgWXRlKFRUQkMgWXJtJQZBGR8BRcREQk6hRcREQk6hRcmkijxUREwlNwaSR1tYiIRKbgIiIiUafg0khqDRMRiUzBpYk0FFlEJDwFlyZSh76ISHgKLo2kCouISGQKLiIiEnVxDS5mdr2ZOTNr57+QPuHkAAAM0klEQVQ3M/u5ma01szfNbGjQvpeY2Rr/55Kg9GFmtsI/5uemzhARkbiLW3Axs27ABGBjUPJkoLf/Mw24z9+3ALgVOBEYDtxqZvn+MfcBXws6blJz5F9ERMKLZ83lp8CNQHDP+BTgYedZCOSZWWdgIjDfObfTObcLmA9M8re1cc4tdF4P+8PAWbHMtPrxRUQii0twMbMpwGbn3PJam7oCHwS93+Sn1Ze+KUS6iIjEUWqsTmxmC4BOITbdBHwPr0msWZnZNLzmNoqKipp4jmjmSEQkOcUsuDjnxoVKN7OBQAmw3O97LwSWmdlwYDPQLWj3Qj9tMzC6VvpLfnphiP3D5Wk2MBugrKxMDVwiIjHS7M1izrkVzrkOzrli51wxXlPWUOfcx8Bc4GJ/1NgIYLdz7iNgHjDBzPL9jvwJwDx/2x4zG+GPErsYeLK5yyQiIjXFrObSRM8ApwNrgf3AZQDOuZ1mdjvwur/fLOfcTv/1N4AHgVbAs/6PiIjEUdyDi197qXrtgOlh9nsAeCBE+hLguFjlr+71mutKIiItl57QFxGRqFNwaSSNFhMRiUzBJQoGd8uLdxZERBJK3PtcWrq3b59EaoqqMyIiwRRcjlJmWiDeWRARSThqFmsijRoTEQlPwUVERKJOwaWJNGpMRCQ8BRcREYk6BRcREYk6BRcREYk6BZcm0mgxEZHwFFwayVBPvohIJAoujeRQlUVEJBIFlybSUGQRkfAUXEREJOoUXJpIHfoiIuEpuDSSOvRFRCJTcBERkahTcBERkahTcGkkDUUWEYlMwaWJNBRZRCQ8BZcm0mgxEZHwFFwaSaPFREQiU3AREZGoU3AREZGoU3AREZGoU3BpJA1FFhGJTMF
FRESiTsGlkTRaTEQkMgUXERGJOgUXERGJurgEFzO7zcw2m9l//Z/Tg7bNNLO1ZvaOmU0MSp/kp601sxlB6SVmtshPf9TM0pu7PCIiUlM8ay4/dc4N9n+eATCzUmAqMACYBPzazAJmFgB+BUwGSoEL/H0B7vLP1QvYBVwRy0ynBrw+l7RUVfpERMJJjXcGapkCPOKcOwisN7O1wHB/21rn3DoAM3sEmGJmq4GxwJf9fR4CbgPui1UGv3xiER/vPsDVY3rF6hIiIi1ePL9+X21mb5rZA2aW76d1BT4I2meTnxYuvS3wiXOuvFZ6zGSkBph5en9aZyRaXBYRSRwxCy5mtsDM3grxMwWvZtETGAx8BNwbq3zUytM0M1tiZku2bdvWHJcUETkmxezrt3NuXEP2M7PfAU/5bzcD3YI2F/pphEnfAeSZWapfewneP1SeZgOzAcrKyvSovYhIjMRrtFjnoLdnA2/5r+cCU80sw8xKgN7AYuB1oLc/Miwdr9N/rnPOAS8CX/SPvwR4sjnKICIi4cWr4+BuMxsMOGADcCWAc26lmT0GrALKgenOuQoAM7samAcEgAeccyv9c30XeMTMfgi8AdzfnAUREZG6zB2jSyqWlZW5JUuWxDsbIiItipktdc6VRdpPD2uIiEjUKbiIiEjUKbiIiEjUHbN9Lma2DXi/iYe3A7ZHMTvxlCxlSZZygMqSqJKlLEdbju7OufaRdjpmg8vRMLMlDenQagmSpSzJUg5QWRJVspSlucqhZjEREYk6BRcREYk6BZemmR3vDERRspQlWcoBKkuiSpayNEs51OciIiJRp5qLiIhEnYJLI4RbajmRmdkGM1vhLye9xE8rMLP5ZrbG/zffTzcz+7lfvjfNbGic8/6AmW01s7eC0hqddzO7xN9/jZldkkBlidpy381Yjm5m9qKZrTKzlWb2TT+9xX0u9ZSlJX4umWa22MyW+2X5gZ9eYiGWgfcnB37UT19kZsWRythozjn9NOAHb8LM94AeQDqwHCiNd74akO8NQLtaaXcDM/zXM4C7/NenA88CBowAFsU576cCQ4G3mpp3oABY5/+b77/OT5Cy3AbcEGLfUv/vKwMo8f/uAonwNwh0Bob6r3OAd/38trjPpZ6ytMTPxYBs/3UasMj/fT8GTPXTfwN83X/9DeA3/uupwKP1lbEpeVLNpeGG4y+17Jw7BDyCtyxzSzQFb0lo/H/PCkp/2HkW4q2V0znUCZqDc+4VYGet5MbmfSIw3zm30zm3C5gPTIp97msKU5Zwqpf7ds6tB6qW+47736Bz7iPn3DL/9V5gNd7qry3uc6mnLOEk8ufinHP7/Ldp/o/DWwb+cT+99udS9Xk9DpxmZkb4MjaagkvDhVtqOdE54DkzW2pm0/y0js65j/zXHwMd/dctoYyNzXuilykay33Hhd+UMgTvW3KL/lxqlQVa4OdiZgEz+y+wFS9Yv0f4ZeCr8+xv3423bHzUyqLgkvxGOueGApOB6WZ2avBG59WFW+SQwZacd19clvuOBjPLBv4GfMs5tyd4W0v7XEKUpUV+Ls65CufcYLwVeYcD/eKZHwWXhqtvCeaE5Zzb7P+7Ffg73h/dlqrmLv/frf7uLaGMjc17wpbJObfFvyFUAr/jSPNDQpfFzNLwbsZznHNP+Mkt8nMJVZaW+rlUcc59grdC70n4y8CHyFd1nv3tuXjLxketLAouDRdyqeU456leZtbazHKqXgMT8JaUnou3JDTUXBp6LnCxP8JnBLA7qKkjUTQ27/OACWaW7zdvTPDT4s6itNx3M+fZ8FZ7Xe2c+0nQphb3uYQrSwv9XNqbWZ7/uhUwHq8PKdwy8MGf1xeBF/waZ7gyNl5zjmho6T94I1/exWvLvCne+WlAfnvgjfxYDqysyjNe2+rzwBpgAVDgpxvwK798K4CyOOf/L3jNEofx2n6vaEregcvxOibXApclUFn+6Of1Tf8/deeg/W/yy/IOMDlR/gaBkXhNXm8
C//V/Tm+Jn0s9ZWmJn8vxeMu8v4kXDG/x03vgBYe1wF+BDD8903+/1t/eI1IZG/ujJ/RFRCTq1CwmIiJRp+AiIiJRp+AiIiJRp+AiIiJRp+AiIiJRp+AiScHMnJndG/T+BjO7LUrnftDMvhh5z6O+znlmttrMXqyV3sXMHvdfDw6epTcK18wzs2+EupbI0VBwkWRxEDjHzNrFOyPBgp6ObogrgK8558YEJzrnPnTOVQW3wXjPVEQrD3l4M+SGupZIkym4SLIox1u+9du1N9SueZjZPv/f0Wb2spk9aWbrzOxOM/uKvy7GCjPrGXSacWa2xMzeNbMz/eMDZnaPmb3uT3J4ZdB5XzWzucCqEPm5wD//W2Z2l592C95Dffeb2T219i/2900HZgHnm7fOyPn+LAwP+Hl+w8ym+MdcamZzzewF4Hkzyzaz581smX/tqll77wR6+ue7p+pa/jkyzewP/v5vmNmYoHM/YWb/Mm8tlruDfh8P+nldYWZ1Pgs5djTmW5VIovsV8GbVza6BBgH98abDXwf83jk33LyFo64BvuXvV4w3x1RP4EUz6wVcjDedyQlmlgG8ZmbP+fsPBY5z3rTl1cysC3AXMAzYhTdj9VnOuVlmNhZvHZEloTLqnDvkB6Ey59zV/vl+hDd1x+X+9B+LzWxBUB6Od87t9GsvZzvn9vi1u4V+8Jvh53Owf77ioEtO9y7rBppZPz+vffxtg/FmET4IvGNmvwA6AF2dc8f558qL8LuXJKaaiyQN581o+zBwbSMOe91563ocxJvyoio4rMALKFUec85VOufW4AWhfnjzYV1s3jTni/CmQOnt77+4dmDxnQC85Jzb5rypzufgLSTWVBOAGX4eXsKb1qPI3zbfOVe1howBPzKzN/GmZ+nKkWnxwxkJ/AnAOfc28D5QFVyed87tds4dwKuddcf7vfQws1+Y2SRgT4hzyjFCNRdJNj8DlgF/CEorx/8iZWYpeKsFVjkY9Loy6H0lNf9/1J4nyeHdsK9xztWYcNHMRgOfNi37jWbAuc65d2rl4cRaefgK0B4Y5pw7bGYb8AJRUwX/3iqAVOfcLjMbhLcQ2FXAl/DmD5NjkGouklT8b+qP4XWOV9mA1wwF8AW8Vfoa6zwzS/H7YXrgTeo3D/i6edO2Y2Z9zJt9uj6LgVFm1s7MAsAFwMuNyMdevCV5q8wDrjEz8/MwJMxxucBWP7CMwatphDpfsFfxghJ+c1gRXrlD8pvbUpxzfwO+j9csJ8coBRdJRvcCwaPGfod3Q1+Ot8ZFU2oVG/ECw7PAVX5z0O/xmoSW+Z3gvyVCa4DzppufgTcV+nJgqXPuyfqOqeVFoLSqQx+4HS9YvmlmK/33ocwBysxsBV5f0dt+fnbg9RW9VXsgAfBrIMU/5lHgUr/5MJyuwEt+E92fgJmNKJckGc2KLCIiUaeai4iIRJ2Ci4iIRJ2Ci4iIRJ2Ci4iIRJ2Ci4iIRJ2Ci4iIRJ2Ci4iIRJ2Ci4iIRN3/AyvrVt/o/EiHAAAAAElFTkSuQmCC\n", 162 | "text/plain": [ 163 | "
" 164 | ] 165 | }, 166 | "metadata": {}, 167 | "output_type": "display_data" 168 | } 169 | ], 170 | "source": [ 171 | "plt.plot(range(len(train_elbo)), train_elbo)\n", 172 | "plt.xlabel(\"Number of iterations\")\n", 173 | "plt.ylabel(\"ELBO\")\n", 174 | "plt.show()" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": null, 180 | "metadata": {}, 181 | "outputs": [], 182 | "source": [] 183 | } 184 | ], 185 | "metadata": { 186 | "kernelspec": { 187 | "display_name": "probabilistic.ai", 188 | "language": "python", 189 | "name": "probabilistic.ai" 190 | }, 191 | "language_info": { 192 | "codemirror_mode": { 193 | "name": "ipython", 194 | "version": 3 195 | }, 196 | "file_extension": ".py", 197 | "mimetype": "text/x-python", 198 | "name": "python", 199 | "nbconvert_exporter": "python", 200 | "pygments_lexer": "ipython3", 201 | "version": "3.7.0" 202 | } 203 | }, 204 | "nbformat": 4, 205 | "nbformat_minor": 2 206 | } 207 | -------------------------------------------------------------------------------- /Day3/BBVI-gradient-variance.eps: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PGM-Lab/probabilisticAI_tutorials/40b644dccee2d26e469bb5fd32ae01e430e66200/Day3/BBVI-gradient-variance.eps -------------------------------------------------------------------------------- /Day3/BBVI_exercise.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PGM-Lab/probabilisticAI_tutorials/40b644dccee2d26e469bb5fd32ae01e430e66200/Day3/BBVI_exercise.png -------------------------------------------------------------------------------- /Day3/FA_model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PGM-Lab/probabilisticAI_tutorials/40b644dccee2d26e469bb5fd32ae01e430e66200/Day3/FA_model.png 
-------------------------------------------------------------------------------- /Day3/VAE.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Adapt the variational autoencoder\n", 8 | "\n", 9 | "Below you will find an implementation of a VAE for the MNIST data. To keep training fast, we only consider the digits 5, 6, and 7 and only the first 100 samples of each of those digits.\n", 10 | "\n", 11 | "In this exercise, you should familiarize yourself with the implementation below and experiment with the structure of the VAE specification in order to emphasize digit separation in the latent space and the generation of images when sampling from the latent space.\n", 12 | "\n", 13 | "Part of the implementation is based on code from the official Pyro examples." 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 1, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "import numpy as np\n", 23 | "import torch\n", 24 | "import torchvision.datasets as datasets\n", 25 | "import torch.nn as nn\n", 26 | "import torchvision.transforms as transforms\n", 27 | "import pyro\n", 28 | "import pyro.distributions as dist\n", 29 | "from pyro.infer import SVI, Trace_ELBO\n", 30 | "from pyro.optim import Adam\n", 31 | "import datetime\n", 32 | "import os\n", 33 | "import matplotlib.gridspec as gridspec\n", 34 | "from matplotlib import pyplot\n", 35 | "import matplotlib.pyplot as plt\n", 36 | "from scipy.stats import norm" 37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "metadata": {}, 42 | "source": [ 43 | "### Get the MNIST data\n", 44 | "\n", 45 | "We will wrap the MNIST data set in a Pyro data loader. 
" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 2, 51 | "metadata": { 52 | "scrolled": true 53 | }, 54 | "outputs": [], 55 | "source": [ 56 | "def setup_data_loader(batch_size=64):\n", 57 | " #data = datasets.MNIST('./data', train=True, download=True,\n", 58 | " # transform=transforms.Compose([\n", 59 | " # transforms.ToTensor(),\n", 60 | " # transforms.Normalize((0.1307,), (0.3081,))\n", 61 | " # ]))\n", 62 | " data = datasets.MNIST('./data', train=True, download=True, transform=transforms.Compose([transforms.ToTensor()]))\n", 63 | " \n", 64 | " # We only select the digits 5, 6, and 7 and only the first 100 of each of these\n", 65 | " # digits\n", 66 | " selector = np.array([], dtype=int)\n", 67 | " for i in [5, 6, 7]:\n", 68 | " selector = np.concatenate((selector, np.where(data.targets == i)[0][:100]))\n", 69 | " data.data = data.data[selector, :, :]\n", 70 | " data.targets = data.targets[selector]\n", 71 | " \n", 72 | " # Binarize the data\n", 73 | " data.data[data.data<128] = 0\n", 74 | " data.data[data.data>=128] = 1\n", 75 | "\n", 76 | " data.data = data.data.type(torch.float)\n", 77 | " \n", 78 | " # Put the data within a data loader \n", 79 | " train_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True)\n", 80 | " return train_loader\n", 81 | "\n", 82 | "\n", 83 | "train_loader = setup_data_loader(batch_size=300)" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 3, 89 | "metadata": {}, 90 | "outputs": [ 91 | { 92 | "data": { 93 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAA5NJREFUeJzt3VFuwjAUAMG64v5Xdi9QoGowCd6ZbwQBafU+XhzGnPML6Pk++wKAc4gfosQPUeKHKPFDlPghSvwQJX6IEj9E3d75YWMMtxPCYnPO8ZfXmfwQJX6IEj9EiR+ixA9R4oco8UOU+CFK/BAlfogSP0SJH6LED1HihyjxQ5T4IUr8ECV+iBI/RIkfosQPUeKHKPFDlPghSvwQJX6IEj9EiR+ixA9R4oco8UOU+CFK/BAlfogSP0SJH6LED1Hih6jb2RfAtc05l733GGPZe/OcyQ9R4oco8UOU+CFK/BAlfogSP0TZ88et3OMf/Wz3Aaxl8kOU+CFK/BAlfogSP0SJH6Ks+jZ3dJW3ct125poRkx+yxA9R4oco8UOU+CFK/BAlfoiy59/AkX35lY/NOvK7lskPUeKHKPFDlPghSvwQJX6IEj9E2fNvzi6ce0x+iBI/RIkfosQPUeKHKPFDlPghyp7/A3i+PSuY/BAlfogSP0SJH6LED1HihyjxQ5T4IUr8ECV+iBI/RIkfosQPUeKHKPFDlPghSvwQJX6IEj9EiR+ixA9R4ocoj+7+ANW/2a5+73cx+SFK/BAlfogSP0SJH6LED1Hihyh7fpby9+LXZfJDlPghSvwQJX6IEj9EiR+ixA9R9vw8ZE+/L5MfosQPUeKHKPFDlPghSvwQJX6Isue/ALv03z37XTzX/xiTH6LED1HihyjxQ5T4IUr8EGXV9wJWdf9jVXcukx+ixA9R4oco8UOU+CFK/BAlfoiy53+BnffVR+9h2Pm3+XQmP0SJH6LED1HihyjxQ5T4IUr8EGXPH+dZBF0mP0SJH6LED1HihyjxQ5T4IUr8EGXPzyHO638ukx+ixA9R4oco8UOU+CFK/BBl1bc5R3a5x+SHKPFDlPghSvwQJX6IEj9EiR+i7Pl5yJHdfZn8ECV+iBI/RIkfosQPUeKHKPFDlD1/nD1+l8kPUeKHKPFDlPghSvwQJX6IEj9EiR+ixA9R4oco8UOU+CFK/BAlfogSP0Q5z7855/W5x+SHKPFDlPghSvwQJX6IEj9EiR+ixA9R4oco8UOU+CFK/BAlfogSP0SJH6LED1HihyjxQ5T4IUr8ECV+iBI/RI0559nXAJzA5Ico8UOU+CFK/BAlfogSP0SJH6LED1HihyjxQ5T4IUr8ECV+iBI/RIkfosQPUeKHKPFDlPghSvwQJX6IEj9EiR+ifgCJ5jwXeHW/1QAAAABJRU5ErkJggg==\n", 94 | "text/plain": [ 95 | "
" 96 | ] 97 | }, 98 | "metadata": {}, 99 | "output_type": "display_data" 100 | } 101 | ], 102 | "source": [ 103 | "def display_image(x):\n", 104 | " plt.axis('off')\n", 105 | " pyplot.imshow(x.reshape((28, 28)), cmap=\"gray\")\n", 106 | " \n", 107 | "toy_image = train_loader.dataset.data[215,:,:]\n", 108 | "display_image(toy_image)" 109 | ] 110 | }, 111 | { 112 | "cell_type": "markdown", 113 | "metadata": {}, 114 | "source": [ 115 | "### Setup the decoder network" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 4, 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "class Decoder(nn.Module):\n", 125 | " def __init__(self, z_dim, hidden_dim):\n", 126 | " super(Decoder, self).__init__()\n", 127 | " # setup the two linear transformations used\n", 128 | " self.fc1 = nn.Linear(z_dim, hidden_dim)\n", 129 | " self.fc21 = nn.Linear(hidden_dim, 784)\n", 130 | " # setup the non-linearities\n", 131 | " self.softplus = nn.Softplus()\n", 132 | " self.sigmoid = nn.Sigmoid()\n", 133 | "\n", 134 | " def forward(self, z):\n", 135 | " # define the forward computation on the latent z\n", 136 | " # first compute the hidden units\n", 137 | " hidden = self.softplus(self.fc1(z))\n", 138 | " # return the parameter for the output Bernoulli\n", 139 | " # each is of size batch_size x 784\n", 140 | " #loc_img = self.sigmoid(self.fc21(hidden))\n", 141 | " loc_img = self.fc21(hidden)\n", 142 | " return loc_img" 143 | ] 144 | }, 145 | { 146 | "cell_type": "markdown", 147 | "metadata": {}, 148 | "source": [ 149 | "### Setup the encoder network" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": 5, 155 | "metadata": {}, 156 | "outputs": [], 157 | "source": [ 158 | "class Encoder(nn.Module):\n", 159 | " def __init__(self, z_dim, hidden_dim):\n", 160 | " super(Encoder, self).__init__()\n", 161 | " # setup the three linear transformations used\n", 162 | " self.fc1 = nn.Linear(784, hidden_dim)\n", 163 | " self.fc21 = 
nn.Linear(hidden_dim, z_dim)\n", 164 | " self.fc22 = nn.Linear(hidden_dim, z_dim)\n", 165 | " # setup the non-linearities\n", 166 | " self.softplus = nn.Softplus()\n", 167 | "\n", 168 | " def forward(self, x):\n", 169 | " # define the forward computation on the image x\n", 170 | " # first shape the mini-batch to have pixels in the rightmost dimension\n", 171 | " x = x.reshape(-1, 784)\n", 172 | " # then compute the hidden units\n", 173 | " hidden = self.softplus(self.fc1(x))\n", 174 | " # then return a mean vector and a (positive) square root covariance\n", 175 | " # each of size batch_size x z_dim\n", 176 | " z_loc = self.fc21(hidden)\n", 177 | " z_scale = torch.exp(self.fc22(hidden))\n", 178 | " return z_loc, z_scale" 179 | ] 180 | }, 181 | { 182 | "cell_type": "markdown", 183 | "metadata": {}, 184 | "source": [ 185 | "### Packaging it all together" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": 6, 191 | "metadata": {}, 192 | "outputs": [], 193 | "source": [ 194 | "class VAE(nn.Module):\n", 195 | " # by default our latent space is 2-dimensional\n", 196 | " # and we use 400 hidden units\n", 197 | " def __init__(self, z_dim=2, hidden_dim=400, use_cuda=False):\n", 198 | " super(VAE, self).__init__()\n", 199 | " # create the encoder and decoder networks\n", 200 | " self.encoder = Encoder(z_dim, hidden_dim)\n", 201 | " self.decoder = Decoder(z_dim, hidden_dim)\n", 202 | " self.z_dim = z_dim\n", 203 | "\n", 204 | " # define the model p(x|z)p(z)\n", 205 | " def model(self, x):\n", 206 | " # register PyTorch module `decoder` with Pyro\n", 207 | " pyro.module(\"decoder\", self.decoder)\n", 208 | " with pyro.plate(\"data\", x.shape[0]):\n", 209 | " # setup hyperparameters for prior p(z)\n", 210 | " z_loc = x.new_zeros(torch.Size((x.shape[0], self.z_dim)))\n", 211 | " z_scale = x.new_ones(torch.Size((x.shape[0], self.z_dim)))\n", 212 | " # sample from prior (value will be sampled by guide when computing the ELBO)\n", 213 | " z = 
pyro.sample(\"latent\", dist.Normal(z_loc, z_scale).to_event(1))\n", 214 | " # decode the latent code z\n", 215 | " loc_img = self.decoder.forward(z)\n", 216 | " # score against actual images\n", 217 | " pyro.sample(\"obs\", dist.Bernoulli(logits=loc_img).to_event(1), obs=x.reshape(-1, 784))\n", 218 | " #pyro.sample(\"obs\", dist.Bernoulli(loc_img).to_event(1), obs=x.reshape(-1, 784))\n", 219 | "\n", 220 | " # define the guide (i.e. variational distribution) q(z|x)\n", 221 | " def guide(self, x):\n", 222 | " # register PyTorch module `encoder` with Pyro\n", 223 | " pyro.module(\"encoder\", self.encoder)\n", 224 | " with pyro.plate(\"data\", x.shape[0]):\n", 225 | " # use the encoder to get the parameters used to define q(z|x)\n", 226 | " z_loc, z_scale = self.encoder.forward(x)\n", 227 | " # sample the latent code z\n", 228 | " pyro.sample(\"latent\", dist.Normal(z_loc, z_scale).to_event(1))\n", 229 | "\n", 230 | " # define a helper function for reconstructing images\n", 231 | " def reconstruct_img(self, x):\n", 232 | " # encode image x\n", 233 | " z_loc, z_scale = self.encoder(x)\n", 234 | " # sample in latent space\n", 235 | " z = dist.Normal(z_loc, z_scale).sample()\n", 236 | " # decode the image (note we don't sample in image space)\n", 237 | " loc_img = self.decoder(z)\n", 238 | " return loc_img\n", 239 | "\n", 240 | " def sample_images(self, dim=10):\n", 241 | "\n", 242 | " plt.figure(figsize=(dim, dim))\n", 243 | " gs1 = gridspec.GridSpec(dim, dim)\n", 244 | " gs1.update(wspace=0.025, hspace=0.05) # set the spacing between axes.\n", 245 | "\n", 246 | " z_1 = norm.ppf(np.linspace(0.00001, 0.99999, dim), loc=0, scale=1)\n", 247 | " z_2 = norm.ppf(np.linspace(0.00001, 0.99999, dim), loc=0, scale=1)\n", 248 | " for j in range(dim):\n", 249 | " for i in range(dim):\n", 250 | " x_val = self.decoder.forward(torch.tensor([z_1[i], z_2[j]], dtype=torch.float32))\n", 251 | " plt.subplot(gs1[i*dim+j])\n", 252 | " plt.axis('off')\n", 253 | " 
plt.imshow(x_val.detach().numpy().reshape((28, 28)), cmap=\"gray_r\")\n", 254 | " plt.show()" 255 | ] 256 | }, 257 | { 258 | "cell_type": "markdown", 259 | "metadata": {}, 260 | "source": [ 261 | "### Setup training (single epoch)" 262 | ] 263 | }, 264 | { 265 | "cell_type": "code", 266 | "execution_count": 7, 267 | "metadata": {}, 268 | "outputs": [], 269 | "source": [ 270 | "def train(svi, train_loader):\n", 271 | " # initialize loss accumulator\n", 272 | " epoch_loss = 0.\n", 273 | " # do a training epoch over each mini-batch x returned\n", 274 | " # by the data loader\n", 275 | " for x, _ in train_loader:\n", 276 | " # do ELBO gradient and accumulate loss\n", 277 | " epoch_loss += svi.step(x)\n", 278 | "\n", 279 | " # return epoch loss\n", 280 | " normalizer_train = len(train_loader.dataset)\n", 281 | " total_epoch_loss_train = epoch_loss / normalizer_train\n", 282 | " return total_epoch_loss_train" 283 | ] 284 | }, 285 | { 286 | "cell_type": "markdown", 287 | "metadata": {}, 288 | "source": [ 289 | "### Perform learning" 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": 8, 295 | "metadata": {}, 296 | "outputs": [ 297 | { 298 | "name": "stdout", 299 | "output_type": "stream", 300 | "text": [ 301 | "[epoch 000] average training loss: 561.9964\n", 302 | "[epoch 100] average training loss: 20.7505\n", 303 | "[epoch 200] average training loss: 19.0819\n", 304 | "[epoch 300] average training loss: 18.7468\n", 305 | "[epoch 400] average training loss: 18.1215\n", 306 | "[epoch 500] average training loss: 17.7618\n", 307 | "[epoch 600] average training loss: 17.4973\n", 308 | "[epoch 700] average training loss: 17.4621\n", 309 | "[epoch 800] average training loss: 17.3345\n", 310 | "[epoch 900] average training loss: 17.1521\n" 311 | ] 312 | }, 313 | { 314 | "data": { 315 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAZAAAAEKCAYAAAA8QgPpAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAG+dJREFUeJzt3X20HHWd5/H3p/s+hCSYBEEICTFBgoqKyNwB3Z3dReUhsLNmncEZWM+CDzuMu6DjHOe4MOwZXBk8KrqOOsgxMug6w4isD0OWzRgBgZ2zZ5UEB0NAkMuDkgyuQRTEwM293d/9o359U7fTVZ1bSafvvfm8zmno+lV11a+6bvpTv/rVgyICMzOz6ar1uwJmZjY7OUDMzKwSB4iZmVXiADEzs0ocIGZmVokDxMzMKnGAmJlZJQ4QMzOrxAFiZmaVDPS7AtMlaQ3waaAOXBcRHy2a9vDDD4+VK1ceqKqZmc0J99xzz1MRcUS36WZVgEiqA9cAZwDbgE2S1kfEA52mX7lyJZs3bz6QVTQzm/Uk/Xhvpptth7BOAUYj4tGI2AXcCKztc53MzA5Ksy1AlgFP5Ia3pTIzMzvAZluAdCXpIkmbJW3esWNHv6tjZjZnzbYA2Q4ckxtensomRcS6iBiJiJEjjujaB2RmZhXNtgDZBKyWtErSEHAesL7PdTIzOyjNqrOwImJC0iXARrLTeK+PiPv7XC0zs4PSrAoQgIjYAGzodz3MzA52sy5AzPKazeyRzM0IovX/9JTm1vvWuAiIorLc9NE+31ZZ633us/llkC+bUpegGVOnD6Lj5yWh3PopNzDRCMYmmtSUTVcT1CSU/r9rokm9nn2+GYGkrOLA8+ONyfqNTTQ4dN4gEUEzgkYTGhGT381AXUw0AgkazamPvG7VJwLGG00kMVSvUauJF8YbNJtBvSZ2NZoM1JTqmdW1GTBYT/WsabLurXXevS6w6JAhdjWaPPfCBL98fhfzh+oM1Gr8emyC8UaztVqT27q1TfJlkM2rlurRbMbkdzVQExPNYCKtQ6fvG9i9LSTqEhPNJmL399KMbBn1mhhvBI20/o1mMDRQm/L3mP8m2x8lvnua6FC253oND9SoCcYbwc5dDYYHa5PrK7Lv8cgXzePfvPZoeskBMgM0m8HO8Qa/HpvgubEJfrlznB2/GmNsosGzz4/zwniTZmQ/Qs/vmmCs0WRsvMmuRnPKD2Kz7Qep/f+taVo/gtP6DLnPNKf+EObnMdFsAtk/0vZ/FM1mNo9G+tFqzW+gXmO80aTRiMkfPonJf6jjzdjjh6ylqNzsYHfSMYsdILPFrokmjz71HD995gWeem4XP33meZ4ba/D8rgl27mqwc7zBzrEJfj3W4Ne7JlJYNNiZxu8tCQbrNYbrNYYH65N7bDVpyt5o+/9b00AarrXGK+21TJ1HvSYG017i5GfU9pncchHUJeq1bPpmBILJvTul/9Rbe6Q1Ua9le0sTzWCwLgZqtck9wFYgDdREvZ7tMU7dN4cXxhsMDdQYHtj9PShX38ky2vd02R1SbXu/rfe01hWmfJdqL4MpgZf/zulQ1voeRPZdteq4e193zz3QllpNDA/U0g5Ae0snC95GM3hhvMEhg3WAyb3hhfMGJr+/8UaT59M0tbTN6rXdfx+NZrZT0GgG84fqU1odeYMDNSIi7Xk3GR6oM95oMlCrMTxYY6KZ1au1w1GTGJvItllRq64ZwUQjeOq5MRYODzA0UOP5XQ0WzR+k2YShAbHokKHJ75u2v7FWXYUmW3mNtKNSlyaX0Whm9Rms7z6PaI9WQdt2GG80J6dv/Z3XJRqRtWQG6jXqEi9MNBisZ9/N7r+ryYpNUlvRlH8rufGt9WmVNSMYG89aYYN1MW+wzthEk7o05W+irTHVEw6Qiv7pl8/zvcd+zl0P7eCHT/6KR3Y8x0Tb3vBgXcwfGmD+UJ1DhuosHM7eL100j/lDAywYHmDBUJ0FwwP
ZuOFsmhfNG+SIQ4eZN1hj0SFDDA/WJn9AhweywwVmNjMtYrDny5g/NHU4H4QHkgNkGhrN4K//7+Ns2fYMG7Y+yQvjTZbMH+R1K5bw5le+hJcfdSjLlxzCixcMc9SiecxLe4FmZnORA2QavvH9bXzof2b3bTzuJQu5+twTeeXSFzkozOyg5ADZC3f9aAdbnvglt2x5kuOPXMhf/ruTWbpoHofO631T1cxspnKA7IU/v+UBHv7ZcwB87u0nc/yRh/a5RmZm/ecA6eLqjQ9Ohsff/odT+WfHHd7nGpmZzQyz7V5YB9S1dz7CNXc8AsB733Scw8PMLMcBUuJj33pw8v3hC4f7WBMzs5nHAbKXzjjhyH5XwcxsRnGAFMhflfr+01dz9OJD+lgbM7OZxwFSoHXzOYA/+BfH9rEmZmYzkwOkwHNjEwBcufZVLBj2yWpmZu0cIAV2jmUtEIeHmVlnDpACp33iTgDmDzlAzMw6cYB0sdAtEDOzjhwgXSwY9o0Szcw6cYB04RaImVlnDpAu5jtAzMw6coB0sWDIh7DMzDpxgHThQ1hmZp3517HAicsXcdiCIQb69KxhM7OZzr+OJdTvCpiZzWAOEDMzq8QBUiB3M14zM+vAAVJC8kEsM7MifQkQSW+TdL+kpqSRtnGXSRqV9JCks3Lla1LZqKRLe13HwE0QM7My/WqBbAV+B/jf+UJJJwDnAa8C1gCfk1SXVAeuAc4GTgDOT9P2lNsfZmbF+nIab0T8EDoeIloL3BgRY8BjkkaBU9K40Yh4NH3uxjTtA72rY6/mbGY2N8y0PpBlwBO54W2prKh8D5IukrRZ0uYdO3bsU2XcBWJmVqxnLRBJtwFHdRh1eUTc3KvlRsQ6YB3AyMiI2xFmZj3SswCJiNMrfGw7cExueHkqo6S8J3wIy8ys3Ew7hLUeOE/SsKRVwGrgbmATsFrSKklDZB3t63tfHR/DMjMr0pdOdElvBT4LHAH8L0n3RsRZEXG/pJvIOscngIsjopE+cwmwEagD10fE/b2soxsgZmbl+nUW1jeBbxaMuwq4qkP5BmBDj6s2hTvRzcyKzbRDWDNGuBPEzKyUA6SEGyBmZsUcIGZmVokDxMzMKnGAlHAnuplZMQdIAfehm5mVc4CUkLvRzcwKOUAK+HkgZmblHCAl3AdiZlbMAWJmZpU4QAq4E93MrJwDpIQPYZmZFXOAFHADxMysnAOkhE/jNTMr5gAp4LvxmpmVc4CUcQPEzKyQA8TMzCpxgBTwASwzs3IOkBI+gmVmVswBUsRNEDOzUg6QEvKVhGZmhRwgBdwAMTMr5wAp4faHmVkxB4iZmVXiACngK9HNzMo5QEq4D93MrJgDpIDbH2Zm5foSIJKulvSgpC2SvilpcW7cZZJGJT0k6axc+ZpUNirp0gNSzwOxEDOzWapfLZBbgVdHxInAj4DLACSdAJwHvApYA3xOUl1SHbgGOBs4ATg/Tdsz7gIxMyvXlwCJiG9HxEQa/C6wPL1fC9wYEWMR8RgwCpySXqMR8WhE7AJuTNP2lC8kNDMrNhP6QN4F/H16vwx4IjduWyorKjczsz4Z6NWMJd0GHNVh1OURcXOa5nJgArhhPy73IuAigBUrVlSeT7gb3cysVM8CJCJOLxsv6R3AbwNvjt0XXWwHjslNtjyVUVLevtx1wDqAkZGRfUoBH8AyMyvWr7Ow1gAfBN4SETtzo9YD50kalrQKWA3cDWwCVktaJWmIrKN9fS/r6E50M7NyPWuBdPGXwDBwa+qo/m5EvCci7pd0E/AA2aGtiyOiASDpEmAjUAeuj4j7e15LN0HMzAr1JUAi4riScVcBV3Uo3wBs6GW9pi7vQC3JzGx2mglnYc1YchPEzKyQA8TMzCpxgJiZWSUOkBK+EN3MrJgDpICfB2JmVs4BUsINEDOzYg4QMzOrxAFSwAewzMzKOUBKuBPdzKyYA6SA+9DNzMo5QEr4SnQzs2IOkAJ+HoiZWTkHSAn3gZiZFXO
AmJlZJQ6QAu5ENzMr5wAp4UNYZmbFHCAF3AAxMyvXNUAkvVrSlyVtTq//LunEA1G5/nMTxMysSGmASFoLfBO4E3hXet0FfD2Nm7PcB2JmVq7bM9E/DJwREY/nyrZI+g5wc3rNWe4DMTMr1u0Q1kBbeACQygZ7USEzM5sdugXIhKQV7YWSXgpM9KZKM4WPYZmZlel2COsK4DZJHwHuSWUjwKXAf+5lxWYCH8EyMytWGiAR8XeSHgM+ALw3FT8A/F5E/KDXlesnd6KbmZXr1gIhBcUFB6AuM4470c3MinU7jfdwSVdIep+khZKulbRV0s2SjjtQlewHN0DMzMp160T/W2AYWA3cDTwGnAvcAlzX26r1n58HYmZWrFuAHBkRfwq8D1gYER+PiAcj4gvA4qoLlXSlpC2S7pX0bUlHp3JJ+oyk0TT+5NxnLpT0cHpdWHXZZma2f3QLkAZARATwVNu45j4s9+qIODEiTiJrzfxZKj+brLWzGrgIuBZA0mFkZ4SdCpwCXCFpyT4sv6twL7qZWalunejHSlpPdkZr6z1peFXVhUbEs7nBBezuclgLfDkF1nclLZa0FDgNuDUingaQdCuwBvhK1TrsDXeim5kV6xYg+ftdfaJtXPvwtEi6iuzsrmeAN6biZcATucm2pbKi8p5x+8PMrFy360DuKhon6atkN1YsGn8bcFSHUZdHxM0RcTlwuaTLgEvIDlHtM0kXkR3+YsWKPS6in9689keFzMzmqK7XgZR4Q9nIiDh9L+dzA7CBLEC2A8fkxi1PZdvJDmPly+8sWO46YB3AyMhI5YaEu0DMzMr15YFSklbnBtcCD6b364EL0tlYrweeiYgngY3AmZKWpM7zM1NZr+vZ60WYmc1apS2Q/Gm07aPYt7vxflTSy8nO5Pox8J5UvgE4BxgFdgLvBIiIpyVdCWxK03241aFuZmb90e0Q1idLxj1YMq5URPxuQXkAFxeMux64vuoyp8un8ZqZlevWif7GsvFmZnbw6nYvrA/m3r+tbdxHelWpmcDtDzOzct060c/Lvb+sbdya/VyXGcd96GZmxboFiAredxqeW9wEMTMr1S1AouB9p+E5x3fjNTMr1u0srNdKepastXFIek8antfTmpmZ2YzW7Sys+oGqyEwz55tXZmb7qC9Xos8W7kQ3MyvmACngCwnNzMo5QEq4AWJmVswBUsDtDzOzcg6QEu4DMTMr5gAxM7NKHCAF3IduZlbOAVLCD5QyMyvmACkQ7kY3MyvlACnh9oeZWTEHSAH3gZiZlXOAlHETxMyskAPEzMwqcYAU8BEsM7NyDpASfqCUmVkxB0gRN0HMzEo5QEr4OkIzs2IOkAK+kNDMrJwDpIQbIGZmxRwgZmZWSV8DRNIHJIWkw9OwJH1G0qikLZJOzk17oaSH0+vCXtfNV6KbmZUb6NeCJR0DnAn8JFd8NrA6vU4FrgVOlXQYcAUwQnZ+1D2S1kfEL3pbx17O3cxsdutnC+RTwAeZesLsWuDLkfkusFjSUuAs4NaIeDqFxq3Aml5Wzg0QM7NyfQkQSWuB7RHxg7ZRy4AncsPbUllReU/5QkIzs2I9O4Ql6TbgqA6jLgf+lOzwVS+WexFwEcCKFSt6sQgzM6OHARIRp3cql/QaYBXwg/TEv+XA9yWdAmwHjslNvjyVbQdOayu/s2C564B1ACMjI5WPRIV70c3MSh3wQ1gRcV9EvCQiVkbESrLDUSdHxE+B9cAF6Wys1wPPRMSTwEbgTElLJC0ha71s7HVd3YluZlasb2dhFdgAnAOMAjuBdwJExNOSrgQ2pek+HBFP97Iibn+YmZXre4CkVkjrfQAXF0x3PXD9AaoW4CvRzczK+Er0Au4CMTMr5wAp404QM7NCDhAzM6vEAWJmZpU4QEr4AJaZWTEHSAe+iNDMrDsHSAn3oZuZFXOAdOAGiJlZdw6QEr4br5lZMQeImZlV4gDpwEewzMy6c4CUcCe6mVkxB0gHPo3XzKw7B0g
JN0DMzIo5QDpw+8PMrDsHSAn3gZiZFXOAmJlZJQ6QDtyHbmbWnQOkhHwMy8yskAOkg3A3uplZVw4QMzOrxAHSgftAzMy6c4CUcBeImVkxB4iZmVXiADEzs0ocICX8QCkzs2IOkA7ciW5m1p0DpIQ70c3MivUlQCR9SNJ2Sfem1zm5cZdJGpX0kKSzcuVrUtmopEt7WT9fSGhm1t1AH5f9qYj4RL5A0gnAecCrgKOB2yQdn0ZfA5wBbAM2SVofEQ/0soJugJiZFetngHSyFrgxIsaAxySNAqekcaMR8SiApBvTtD0NEDMzK9bPPpBLJG2RdL2kJalsGfBEbpptqayofA+SLpK0WdLmHTt2VKqYO9HNzLrrWYBIuk3S1g6vtcC1wMuAk4AngU/ur+VGxLqIGImIkSOOOGKf5uVOdDOzYj07hBURp+/NdJK+ANySBrcDx+RGL09llJTvd26AmJl116+zsJbmBt8KbE3v1wPnSRqWtApYDdwNbAJWS1olaYiso319z+vpbnQzs0L96kT/uKSTyHb2Hwf+ECAi7pd0E1nn+ARwcUQ0ACRdAmwE6sD1EXF/ryoX7gQxM+uqLwESEf++ZNxVwFUdyjcAG3pZr3buAzEzK+Yr0c3MrBIHSAc+gGVm1p0DxMzMKnGAdOA+dDOz7hwgJeRedDOzQg6QTtwCMTPrygFSwu0PM7NiDhAzM6vEAdKBHyhlZtadA6SE+9DNzIo5QDrwabxmZt05QEq4AWJmVswB0oEbIGZm3TlASvhCQjOzYg4QMzOrxAHSgR8oZWbWnQOkhI9gmZkVc4B0MDRQ41+/ZikrDpvf76qYmc1Y/Xom+ox26LxBrnn7yf2uhpnZjOYWiJmZVeIAMTOzShwgZmZWiQPEzMwqcYCYmVklDhAzM6vEAWJmZpU4QMzMrBLN5fs+SdoB/HgfZnE48NR+qs5s4XWe+w629QWv83S9NCKO6DbRnA6QfSVpc0SM9LseB5LXee472NYXvM694kNYZmZWiQPEzMwqcYCUW9fvCvSB13nuO9jWF7zOPeE+EDMzq8QtEDMzq8QB0oGkNZIekjQq6dJ+12d/kXSMpDskPSDpfkl/lMoPk3SrpIfT/5ekckn6TPoetkiatQ9JkVSX9I+SbknDqyR9L63bVyUNpfLhNDyaxq/sZ72rkrRY0tckPSjph5LeMNe3s6Q/Tn/XWyV9RdK8ubadJV0v6WeStubKpr1dJV2Ypn9Y0oVV6+MAaSOpDlwDnA2cAJwv6YT+1mq/mQA+EBEnAK8HLk7rdilwe0SsBm5Pw5B9B6vT6yLg2gNf5f3mj4Af5oY/BnwqIo4DfgG8O5W/G/hFKv9Umm42+jTwrYh4BfBasnWfs9tZ0jLgfcBIRLwaqAPnMfe285eANW1l09qukg4DrgBOBU4BrmiFzrRFhF+5F/AGYGNu+DLgsn7Xq0frejNwBvAQsDSVLQUeSu8/D5yfm35yutn0Apanf1hvAm4BRHaB1UD7Ngc2Am9I7wfSdOr3OkxzfRcBj7XXey5vZ2AZ8ARwWNputwBnzcXtDKwEtlbdrsD5wOdz5VOmm87LLZA9tf4QW7alsjklNdlfB3wPODIinkyjfgocmd7Ple/iL4APAs00/GLglxExkYbz6zW5zmn8M2n62WQVsAP4Yjpsd52kBczh7RwR24FPAD8BniTbbvcwt7dzy3S3637b3g6Qg5CkhcDXgfdHxLP5cZHtksyZU/Mk/Tbws4i4p991OYAGgJOBayPidcCv2X1YA5iT23kJsJYsPI8GFrDnoZ4570BvVwfInrYDx+SGl6eyOUHSIFl43BAR30jF/0/S0jR+KfCzVD4Xvot/DrxF0uPAjWSHsT4NLJY0kKbJr9fkOqfxi4CfH8gK7wfbgG0R8b00/DWyQJnL2/l04LGI2BER48A3yLb9XN7OLdPdrvtteztA9rQJWJ3O3hgi64hb3+c67ReSBPwV8MOI+G+5UeuB1pkYF5L1jbTKL0h
nc7weeCbXVJ4VIuKyiFgeESvJtuV3IuLtwB3AuWmy9nVufRfnpuln1Z56RPwUeELSy1PRm4EHmMPbmezQ1eslzU9/5611nrPbOWe623UjcKakJanldmYqm75+dwjNxBdwDvAj4BHg8n7XZz+u12+RNW+3APem1zlkx35vBx4GbgMOS9OL7Iy0R4D7yM5w6ft67MP6nwbckt4fC9wNjAL/AxhO5fPS8Ggaf2y/611xXU8CNqdt/XfAkrm+nYH/CjwIbAX+Ghiea9sZ+ApZH884WUvz3VW2K/CutO6jwDur1sdXopuZWSU+hGVmZpU4QMzMrBIHiJmZVeIAMTOzShwgZmZWiQPEZhVJIemTueE/kfSh/TTvL0k6t/uU+7yct6U75N7RVn60pK+l9ydJOmc/LnOxpP/UaVlmVTlAbLYZA35H0uH9rkhe7mrnvfFu4A8i4o35woj4p4hoBdhJZNfo7K86LAYmA6RtWWaVOEBstpkge1TnH7ePaG9BSHou/f80SXdJulnSo5I+Kuntku6WdJ+kl+Vmc7qkzZJ+lO6j1XqWyNWSNqXnKvxhbr7/IGk92VXP7fU5P81/q6SPpbI/I7ug868kXd02/co07RDwYeD3Jd0r6fclLUjPgrg73SBxbfrMOyStl/Qd4HZJCyXdLun7adlr0+w/Crwsze/q1rLSPOZJ+mKa/h8lvTE3729I+pay50Z8PPd9fCnV9T5Je2wLOzhMZ6/JbKa4BtjS+kHbS68FXgk8DTwKXBcRpyh7qNZ7gfen6VaSPSPhZcAdko4DLiC7DcRvShoG/o+kb6fpTwZeHRGP5Rcm6WiyZ0z8BtlzKL4t6d9GxIclvQn4k4jY3KmiEbErBc1IRFyS5vcRstttvEvSYuBuSbfl6nBiRDydWiFvjYhnUyvtuyngLk31PCnNb2VukRdni43XSHpFquvxadxJZHdtHgMekvRZ4CXAssieu0Gqjx2E3AKxWSeyOwh/mewBQntrU0Q8GRFjZLd2aAXAfWSh0XJTRDQj4mGyoHkF2b2CLpB0L9nt719M9pAegLvbwyP5TeDOyG7uNwHcAPzLadS33ZnApakOd5LdimNFGndrRDyd3gv4iKQtZLe1WMbu23sX+S3gbwAi4kHgx0ArQG6PiGci4gWyVtZLyb6XYyV9VtIa4NkO87SDgFsgNlv9BfB94Iu5sgnSTpGkGjCUGzeWe9/MDTeZ+u+g/d4+Qfaj/N6ImHLDOUmnkd0q/UAQ8LsR8VBbHU5tq8PbgSOA34iIcWV3IZ63D8vNf28Nsocz/ULSa8ke2PQe4PfI7q1kBxm3QGxWSnvcN7H7EaUAj5MdMgJ4CzBYYdZvk1RL/SLHkj3FbSPwH5XdCh9Jxyt7QFOZu4F/JelwZY9JPh+4axr1+BVwaG54I/BeSUp1eF3B5xaRPf9kPPVlvLRgfnn/QBY8pENXK8jWu6N0aKwWEV8H/gvZITQ7CDlAbDb7JJA/G+sLZD/aPyB7fGmV1sFPyH78/x54Tzp0cx3Z4Zvvp47nz9Ol9R7ZbbMvJbud+A+AeyLi5rLPtLkDOKHViQ5cSRaIWyTdn4Y7uQEYkXQfWd/Ng6k+Pyfru9na3nkPfA6opc98FXhHOtRXZBlwZzqc9jdkj322g5DvxmtmZpW4BWJmZpU4QMzMrBIHiJmZVeIAMTOzShwgZmZWiQPEzMwqcYCYmVklDhAzM6vk/wMAnrJKV3jAxQAAAABJRU5ErkJggg==\n", 316 | "text/plain": [ 317 | "
" 318 | ] 319 | }, 320 | "metadata": {}, 321 | "output_type": "display_data" 322 | } 323 | ], 324 | "source": [ 325 | "vae = VAE(z_dim=2, hidden_dim=400)\n", 326 | "\n", 327 | "# Run options\n", 328 | "LEARNING_RATE = 1.0e-2\n", 329 | "\n", 330 | "# Number of training epochs (lower it, e.g. to 1, for a quick test run)\n", 331 | "NUM_EPOCHS = 1000\n", 332 | "\n", 333 | "#train_loader = setup_data_loader(batch_size=300)\n", 334 | "\n", 335 | "# clear param store\n", 336 | "pyro.clear_param_store()\n", 337 | "\n", 338 | "# setup the optimizer\n", 339 | "adam_args = {\"lr\": LEARNING_RATE}\n", 340 | "optimizer = Adam(adam_args)\n", 341 | "\n", 342 | "# setup the inference algorithm\n", 343 | "svi = SVI(vae.model, vae.guide, optimizer, loss=Trace_ELBO())\n", 344 | "train_elbo = []\n", 345 | "# training loop\n", 346 | "for epoch in range(NUM_EPOCHS):\n", 347 | " total_epoch_loss_train = train(svi, train_loader)\n", 348 | " train_elbo.append(-total_epoch_loss_train)\n", 349 | " if (epoch % 100) == 0:\n", 350 | " print(\"[epoch %03d] average training loss: %.4f\" % (epoch, total_epoch_loss_train))\n", 351 | "\n", 352 | "plt.plot(range(len(train_elbo)), train_elbo)\n", 353 | "plt.xlabel(\"Number of iterations\")\n", 354 | "plt.ylabel(\"ELBO\")\n", 355 | "plt.show()" 356 | ] 357 | }, 358 | { 359 | "cell_type": "markdown", 360 | "metadata": {}, 361 | "source": [ 362 | "### Plot the data in the embedding space" 363 | ] 364 | }, 365 | { 366 | "cell_type": "code", 367 | "execution_count": 9, 368 | "metadata": { 369 | "scrolled": true 370 | }, 371 | "outputs": [ 372 | { 373 | "data": { 374 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXYAAAD8CAYAAABjAo9vAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJztnX90VdW1778rh0ACUuIPfMEEBarlKRKITVuftqMKGvTFH8hVrLwxfG1v67D3dpRr76PFp0OC145StbX4et94A2+9tqNqi0ijbeogVuzzV2kNgiClWG+qmAiP8CNiIYH8WO+Pk5Occ7J/77X3Wnuf72cMB2Znn73XPifnu+aac645hZQShBBC0kOZ7gEQQghRC4WdEEJSBoWdEEJSBoWdEEJSBoWdEEJSBoWdEEJSBoWdEEJSBoWdEEJSBoWdEEJSxjgdNz3jjDPkjBkzdNyaEEISy9atWw9KKae6nadF2GfMmIH29nYdtyaEkMQihHjPy3l0xRBCSMqgsBNCSMqgsBNCSMqgsBNCSMqgsBNCSMqgsBNCSMqgsBNCSMqgsBNCSBh2rAceuhBorsr+u2O97hHp2aBECCGpYMd64FffAPp7sz9/+H72ZwCoW6ptWLTYCSEkKC/cOyrqOfp7s8c1QoudkCLm/PtcCDH6s5TAri/t1DcgYi4fdvo7HhO02AnJIyfqxf/N+fe5uodGTGRKrb/jMUFhJySPnJC7HSMEALDwHqC8svBYeWX2uEYo7IQQEpS6pcC1DwNTpgMQ2X+vfVhr4BSgj50QQsJRt1S7kBdDi52QPKTM/ud2jBCTobATkseuL+0cEfL8/5gVQ5IEXTGEFEERJ0mHFjshhKQMCjshhKQMCjshhKQMCjshhKQMCjshhKQMCjshhKQMCjshhKQMCjshhKQMJcIuhHhUCHFACPGWiusRQggJjiqL/TEAVym6FiGEuGNgr1FTUFJSQEr5khBihoprEUKIK4b2GjWF2HzsQojbhBDtQoj27u7uuG5LiBJaO1rRuKERdT+pQ+OGRrR2tOoeUmljaK/RMWhaVcQm7FLKdVLKBillw9SpU+O6LUkJOoW1taMVza81Y9+xfZCQ2HdsH5pfa6a468TQXqMF5FYVH74PQI6uKmIQd2bFEOPRLaxr31iLvsG+gmN9g31Y+8baWO5PLDC012gBGlcVFHZiPLqFdf+x/b6OkxgwtNdoARpXFarSHZ8E8HsAs4UQnUKIv1dxXUIA/cJaPana13ESA4b2Gi1A46pCVVbMLSquQ4gV1ZOqse/YPsvjcbD8ouVofq25YNVQkanA8ouWx3J/YoOBvUYLWHhPYeYOENuqgq4YYjzLL1qOikxFwbE4hbVpVhOuP/d6lIns16VMlOH6c69H06ymWO5PEorGVQVb4xHjyQno2jfWYv+x/aieVI3lFy2PTVhbO1rxzDvPYEgOAQCG5BCeeecZ1J9ZT3EnzmhaVQipof16Q0ODbG9vj/2+hAShcUOjpSto2qRpaLuxTcOISKkihNgqpWxwO48WOyEWtHa0jqwQJKyNH2bFpIAd67Pphx92ZoOaC+8x22/vEQo7IUXk8uaLUyyLYVZMwomjLIGmiYPBU0KKsMqbL4ZZMSkg6g1E3HlKiDk4uVgEBKZNmobmS5oZOE06UW8g0rjzlK4YQoqwy5tnsDRlTKkdtqYtjqsg6TtPCfFCUiok6s6bJzERdVmCpO88JVlatnXhgU178EFPL86qqsSKRbOxuL4mtff1Q3FAMlfIC4BxLg2VefP52TVx598TF3JBzKiCmxp3njKPXREt27pw58ad6O0fHDlWWZ7Bd5fMjVRkdd3XL6WYC26VXVORqQjln79vy3146u2nMCSHUCbKcNMnbsLdF9+tashENYqzYpjHHjMPbNpTIK4A0Ns/iAc27YlUYMPeNy5rX3chLx04VaUMIuz3bbkPv9jzi5Gfh+TQyM8Ud0PRtPOUPnZFfND
T6+u4CffNWftdPb2QALp6enHnxp1o2daleJSlWSFR9WT21NtP+TpOShcKuyLOqqr0ddyE+zpZ+6opxYCk6sksV6vG63FSulDYFbFi0WxUlmcKjlWWZ7Bi0Wxj7xvnKqNpVhOaL2nGtEnTtOaC37flPsz76TzM/clczPvpPNy35b7I7qV6MstVl7QirmciyYA+dkXk/NJxZ6eEue9ZVZXoshDxqFYZTbOatGaE2PqoD76Du//SXhDgaj1lUuhsFtVVKW/6xE0F488nv/Ik/e6EWTElTFIyalQx76fzLN0WZVLizXdHN6q0fqwKzWecjj7ZP3IsTDZL2JTH/NdXZCrQN9gHCYkyUQYppWWRsjJRhjdvfdP3WInZeM2KoSumhFlcX4PvLpmLmqpKCAA1VZWpFXXAwUdd9PPaj00sEHUgeI/VsI24i1/fO9iLCZkJWPO5NXjz1jdtK0/S717a0BVT4iyur0mtkBdTJsqsLfain/ePy4w5BwiWzRI25dHt9bbP5OCPJ+mHnz4pGW76xE1jD0qJm45+VHCoemBw7HkIls0SNuXR7fWWz+RwnJQGFHaDadnWhUvXbMbMla24dM3mSPLLS4m7L74bN8++uaB36c1nNODuj04UnLf86HFUiPKCY0GzWcKmPLq93vKZZt/MwGmJw+CpoaQxsGls3RSLbd8qsmKAsZk4gL9AbBRlCUqOFHVJYkmBhKOrREFQ3EoTGF0EzGLbdxPCjyvXBLuY68+93vO1dTfyTgROwh1HlyQDobBHTNBaLG6bh0yq6Fi8usiVJgBG8+xV103xis5Vgl0nppc6X/J1HdX5/yb97YTGTbidml2kWNjpY/eBX593mFosTqUC4qzx4gUvpQl0FAELm2oYFhMLn5n2txMaty5FGptd6ITC7pEgX4gwtVicSgX4vW7UQVgvpQnCBBGDNuhwWiXEgd2zCSG0NRmJsz5QLLgJd5zNLnasBx66EGiuyv4bQ29TOyjsHvHzhcgJqdV2fcBbLRanzUN+arzEYaG5FSJr2daFI51XQA4VZpqUiwk40nmF44Tjxeq2E37dFrNVrRggu3kozpVDPrqqkEaGm3BH3SUph8bG1VZQ2D3i9QuRL6R2eKnF4uQH9VPRMQ4LzWl1kXs/uvfPQd++JRg6WQUpgUpxBvr2LUH3/jmOE46b1e0k/LpLBecKn1ltFopz5ZBP6CqkBlmlANyFu24pcO3DwJTpAET232sfVu9f19i42goKu0e8fiGshDQfL5UX3axsPxUd47DQnFYX+e/HwNF6HPuPlfjbn9fg0O4VOH5kXsF1rCYcN6vbSfhNKBXcNKsJdinFOnztoaqQGmaVAvAm3HVLgTveApp7sv9GETQ1zJfPrBiPrFg02zKvvPgL4SSYNR4zENxSHZ0qOhZb+lUTy3HkeP+Ye6iu4GhXmsDu/Ri0Ebvi86snVVu21MtZ3U7Cb0qqoNszhMVP5k+oKqSmZpho6lJUwJTa4QnP4rgGKOwe8fqFsCuFW1NViVdXLvB0Ly9WtpWQWqUdlpcJlGcE+gdHhTSOOvE57N6PjBCW4l484Sy/aLnlBp2c1e0mmrpLBQPuzxCGIPsDAtcHMswqNQqNjautoCvGB4vra/DqygX465omvLpygeWXQ0XDjaB+UCtLv39IYtL4cdoqONq9H7d8Zrqn98mtQYcJ7hY3/DYZ8ZPFFGvmT5wZJkkjLl++R2ixK0ZFww2vbp9i7Cz9D3v7sX1Vo+f7q8Tp/Wg45zRP75OT1R3U3RL3xiWvKwcvm73yiTXzxzCr1DhMcAkNw1oxhhJkd6BdiqUfN1ApYHL9Fb+fYeOGRktX1LRJ09B2Y5v6Aaao7koSYa2YhBPEDxrU0i81dJU38ILfLKYo/feWGGSVEnso7ClCV9/VpKF745ITfvvQmpL5Q8yCwp4CUlXUKQaiTD8M+1kEWXWZkPlDzIJZMQkndUWdYiCqTBo/n4Vd5kup9aEl0aAkeCqEuArAWgAZAP8
mpVzjdL7JwdOkWb8MmAYjiqwYr59F3E1UkvY3TeyJLXgqhMgA+FcAVwLoBPC6EOJZKeWfwl47bvymmplA6oo6xUQU7guvn0WcTVSS+DdNwqPCFfNpAO9IKTuklCcB/BzA9QquGzs6SpqGLakbuqiTB9h71RteP4s4J+PUlek1rQiZoagQ9hoA+UUSOoePJY64rV8V/nEVO12jHmOp4PWziGMyzpGqFZ2OImQJnUhiC54KIW4TQrQLIdq7u7vjuq0v4vzCtWzrwj+vfzO0NRV1sM0Eiy8pKwavn0XUk3E+cf5NR07cpXFNrGbpERXpjl0Apuf9XDt8rAAp5ToA64Bs8FTBfZUT1wafnBXstcKhG4GLOnlAt8Vngo/YT/DRy2cR534D4zathdm5GncRMlOrWXpAhbC/DuA8IcRMZAX9CwCWKbhu7MT1hXOr2W6SNeV3w4xq4gw0WhHVxFI8AeRWJar/7ozatObWeNqNuEvjJriaZWhhl1IOCCG+DmATsumOj0opd4UemQVxpG1Faf3mcLJ2TSsBoNvi071iiGNiiXpVEsfftCfCWsBWRcgA4OSx7KSh2oo2rMa6H5TsPJVS/gbAb1Rcyw4TluSqcKpRbsJmlOIJ9O8+WYMX/9wdm8WXf/8yj3Xb7cYedqxxTCyBJo8kFuMKawHnnu+5bwO9h0eP9x72Z/l7JcHVLBNTUkD3klwldlawKaJePIE+vbUrtrEV399K1O1WDFZjX/HUm1j9q13oOd7vKPTFE8Ll/3kqXvxzN+yCQSpdUb4nj7AuDV2osIDrlmYntHxhB6LxfeeulbQJFAkSdt1LcpUY5fcswm4CXf2rXbGM1y7+kBECQ1I63tuu0UiuNaDdKs9qQvjZlr22Y1TtivIdx0hqUE+VBRyn7zuh1SwTI+y6g3iqidPv6cc9YTdRHjne7yqQKrC7/5CU+Osa552iXiZ5q1WeWzA7H699a/3gO46R1KCeKgs4wb7vuEhMEbA4c3/ThN8NRl4nyqhy2cPkXXsde/EEoGrV19rRisYNjaj7SR0aNzSitaPV0+t870VIcou6uqXAHW8BzT3Zf4NYwwvvyVr6+STE9x0XibHYTXZfRIWKQKDf2ISV9WhHFG6wMFk4XsdePAHYrQatyE2M7e8dLggoN366C7/+4GFfTaXz8bWCS3BQT0nQN8G+77hIjLADBqVtxYCqLCC/sQmrCfTYiQH09PaPOTcKN1iYCXxxfQ3a3zuMJ//wvu3mr/IyMWaS8DOZAdmJMd8H39XTi/Ud61BWHlNXJlOFzU20VQZ9g/i+k5hJFJBECXspoSoLKEhswmrzTJy57EEn8JZtXXh6a5etqAPAKRXZP/nizUDfXTLXMivGqyUvxvVYHo+sK5NpQT0voq0z6JvUTKKAUNgNRVUWkIoNRklxg3kJgh453m+5EvrukrmW9evtaqwXI/urIMaPFXcVXZkSgRfR1hn0TWomUUAo7IaiKgtIlSgnwQ3mZdLLCBE45jDuY9swYeomiPIeyP4qnOhehIGj9QCAE92LUDFtI0TZqMsq0qbSpuFFtHVmsyQ1kyggFHZDUbmVPwmirIKqieUjKZlWVJZnbC16t5jDd/7v4+idMircYnwPKqZtRB+AgaP1GDhaj7LxGVTPeLE0m0p7Ee2F9wDP/CMweHL0WGZ8PEHfEkuRTEy6Y6nB3pf+aNnWhb/1Ddj+Pvf+1QRIp1xcX4NTa39bYI0DgCjrx4SpmwAA5RmBexfeirYb27Djv+9A241t8Yq67rrh5zV6O14c/1DQmtMTJZYiSYvdYErB0lZV2+WBTXvQP2QtEj+8eX7BNYOshOyCoKK8J5JNS74wITD4lzb34y/cCwwVraiG+uPxc5uaSRQRFHaiDZWF3Zz86/n+86Axh+pJ1dh3bN+Y42edMg1tDk3DY2kkbUJg0IsPW7ef27RMogihsGvm7padI3nXGSFwy2em477FcyO/rwmd65uf3aWssJvTJqNi0Q+yElp
+0XI0v9Y8sgEJcA+Oqpq4XD8r3YIJ2PuwK091Pyelfm6d0MeukbtbduJnW/aO5F0PSomfbdmLGRG3gDOhj2nLti7LTU9AsB2tKxbNhrD5nYqNVE2zmtB8STOmTZoGAYFpk6ah+ZJmRz+6iraCnj6rqEoMuPnt839/8hhQlhl7jRMfjb7uvEag+FNKsZ9bJ7TYNfLkHyysl2GiLLRlQglkJ3ELIsS5XaePb9lbUGpX5UaqpllNvgKiKvYiePqsoigx4Oa3L/59cRndHDkfOgC8+QRQ8OkIYN6yknGPxAktdo047ZAE3K27oE2eTSiB7HSvoEJ83+K5eOjm+cZkEqloJO3ps6pbClz7MDBlOgCR/ffah8MJplvjaKvf2/Fhp8350j7o6gXdmUAGQ4tdIxmb7kD52H2xw/hvTSiBbDeGUyeWK+0lqhPPexEcaph4/qxUBwbd/PZ+/PdTatXHAUzIBDIYWuwaueUz013PsRPbMP5bE0ogW41BILvlP8r4Qpx42ouQE6gP3wcgRwVq2PrU9lm5+e29+u9zLiHVcQC3FUWJQ4tdI7nsF7tqhE5f4DDuFBNqv+SPoaunFwKj3tfI+9lGVeXP4rqL65c6P4NLqqK2z8rNb2/XWDofkSl0CamMA5iQCWQwQsa18yuPhoYG2d7eHvt9TcdPCqJdcaqaqkrLYlYmE+uzFC/hgazAhPVJB71ucxVg2VlVZJtR6MRLGd4X7h1ebeRPzbB+dpUT6kMX2qROTs828EgpQoitUsoGt/NosRuEH/+wyloyuok1mBvVZp6g1/WS262rjrib3z7/93GPMcnNRmKAwp5QTHCnqCLWYG5US/ig13UTqKQECd0mAdXPUWIlAvxCYU8wJmWAhCHW1UfY3Y92lmnQ67oJlAnlAlQQxXOUUIkAv1DYiXZiXX2EWcI7WZ1hruskUBGsMLSUk2CwM1Yo7MQIYlt9hFnCO1mduYCdateA4voqKguv+aLyVOvdqfm1ZIgyKOyk9Ai6hHezOqNwDSgOEm5vXYfnxc9w1oSD+ECegfsHluLZ/s/GWk6CRA+FnRCv6KhOGHSFYRULAPCt/v+NiWXZDka14iDWlP8b0A/8quezztfIWda9R4KtRnqP+DtOQkFhJ8QrYa3noCmBflcCdrGAcZWYKE4WnDpRnMS3xq3H1olXOl8j340SJKOFJXtjhSUFiBaCFjDTSphiWy6lA8acG6a4lV0swKYC41ni0NgMJLciX36375dYazrd0GInsaMtgKcCj9ZzcebJ8+IeTLQQ2+PP3YMrf3PGyHk/vOAv+NTOVeHyvX1mmvRNrB77vnu5hp/7MO88VijsCcOEzkdOeBmfCfXgo8Rq4qqYsH9MjwkAqDi+H10nekfOO2vr/YAIme9t283oNGCgd4wraeLVFpa33TWKz3HDyf2UW5lQ6JVDV0yCMKHzkRNex2dCPfgosZq4PpCnW55bfHwaDlpf1I91bOf2uPp73l1JVtcovp6bG8XJ/eTHNUV8Q2FPECparUWJ1/GpaEBhMlYT1P0DS3Fcji84dlyOx/0DhaL6gTzD+qJ+goxOsYC6pdmc++ae7L92FnLxNSpPy/7nJ7bglPdv97vnvu39OYktFPYEYbqlazeOrp7eAqvdhHrwUWI1QT079FncX/4PBWJ7f/k/4NmhwjTD+weWohcTCl8cJMjoVcC9Mn5S1uL3cz2nvH+73/UeptWuAAp7gjDd0nUaR75LxlMDigRjN3HNb7qtQGznN9025rznM5/HWxf9i9o2d/l4zbhR4Spxaq7htALZ+FW2ugsJ67Frxk8wtDgoB2QFwxRRtBpfPkmsFR8Ur5+rp/PcApBeM0381Iz3U+/cbgxO9wOyAu6Eijr5KcNrPXYKu0aCCHUSsmL+6RfbLX8nAPx1TVO8A4qCOGuPu4mjn+YefsTaawMQt8nC6b363kzb3HrHsZUwbLSRAIKk/eku1es2sSyurxlpd1eMk6vG9AlrhLjro7v19vR
TCtdPhUWvO0WdgqD5gr5k3dgxXf099/Z6rP4YiFA+diHETUKIXUKIISGE6yxCCjE9GFqM13RGv8FR09M4C4i7ibKtGL9vn2du9xo/DaW97hR1CoK6+ecLMm9sYMmBQIQNnr4FYAmAlxSMpeQwPRhajNd0Rr/BUdPTOAuIu654EGGze42fbf1eyyd4HZ/d5JfL3lnyCEsOKCSUK0ZKuRsAhLDYUkcAOLsYkta31M8Kw4/LKFErl7iLWVkVHnPCSQz9buv3Uj7Bz/icJj+WHFBKbD52IcRtAG4DgLPPPjuu22rFrSZK0vqWRtWbNNaep2GJu4lyseBZBjSHmTLdXQxV14y3EuSTx6yDoqIMaJ4CiAwgB8eOl63ulOGaFSOE+C2Aaotf3SWlfGb4nN8B+B9SSk+pLqWSFXPpms2WgpXUtD+vWTx+A6Fu1zUusBpnVkwxfjJbnIjyGawyZexgSqMvlGXFSCmvUDOk0iNRLgYPeFlhBKncuLi+Bu3vHcaTf3gfg1IiIwT+7pM1I6JuXCVInZalihVD1Jk9xVa8KMta6FYksTF3AmC6Y4QkysXgETffuV0gdPWvdjnm5j+9tQuDw6vHQSnx9NYuNJxzWuorQfpGhS/aKbNHlcDmT37NVc7nMqVROWHTHW8QQnQC+C8AWoUQm9QMKx2kvSaKFXarkSPH+23TF53EO22rHiWErQNjWmYPUxqVE0rYpZS/lFLWSiknSCn/k5RykaqBpYG010Sxwmk1Ype+6CTeoVNCw3YjSiN+8tlVsPAeoKzc+ndl5UxpjAAWAYuYxfU1eHXlAvx1TRNeXbnACFGPsi2d02rETsCdxDvUqoc1v62Ju01d3VJgwmTr302YTP96BFDYS4yod3kurq9BVaW1dWYn4E7iHWrVE/cu0TD8+pvA6tOy6YCrT8v+HBVhercGpfeIv+MkFAyehsC4NDwPxBGMbL5ujq+NV27ZNoHr48TtS3bDLsXw198E2n88ep4cHP35mh+ovVfxcasaLlEQ98auEofCHhAj0/A8EEcw0u/Gq8gmSJPE5NffBNofxcgGo/wUw62PWb9m62PBhN0unXHvFuDNJ+IrYJZP3Bu7ShwKe0CSmoYXVwqmVys70gnSFDHZsb5Q1HPk3EJ2Od75x/1sKLJzQW19bOy94sojZ8mAWKGwBySpaXim1aeJdII0RUxeuBe2pQA+7BzdYl+MGI47+N1QZOdqsptA4nJNsWRAbFDYA5LUzUem1aeJfII0QUychHNKLXBeY6GPPccnv5j91++GIjsXlN0EQj936qCwB8Q0y9cPupt15JPUCdIXdkILUbiCyLlKRAaY8VngL20OnYxgP2HYuaDmLSv0seeO51xTOmvgEKUw3TEgpbj5KApKYneuVd44BNDw5VHhvOYHwKrDQPOHwA3/B+j842j+vR12lrZdOuM1P7BPc2TOf6pgz1OinSSmjfrGjzVsV8ExH9VVEVVVjSSRwp6nJDEUuIZ2rAde+AbwTMrcAX58/Y7BTBHN+2Jazj8JBYWdmINT/vVf2krH92ubfx+h9WxSzj8JDX3sxBzssj/aHy0t36/XWi4qC5zFXT+GRAqFnZiD7bLfZmNPWvFSy0V1sFNH/RgSGQyeEjWoSJXzEjQcQWTrkevAhLRABjtLEq/BU1rsJDyqrEe7tEArdPl+TUkLZLCTOEBhJ+FRVR7Xyh3Q8GWzfL+mlAKOu1kGSRTMiiHhUWk9WqUFnn2xftdHDlMsZZUFzkxwLRGl0GIn4YnSeiwWnfMasz/ranVniqWsKthpimuJKIXBUxKe4vxzQM3OSKvrFqN6B2aQMcU9Br84WeQMwiYKBk9JfESVKmflzy4mbv920tIC3SxyU1xLRCn0sRM1RFEe16u4xC1CJpQC9opbyV/uOE0ltNiJuXgVlyAipHLXpsm4WeTccZpKKOzEXCzz2osIIkKlFDB0C/YmzbVEPEFhJ+Zimdf+98FFKGelb/yqGbnoceDFIq9bmg2
ULlmX/XnjbelexZQA9LETs1Hlz/aSYWN6wDBIvrnXvq9++6oSo6Gwk9LAS4aNCQFDO/EOI7xeJke/fVWJ0VDYSXLxY8G6WeM6AoZWm6/ye5Lmi3fUwsu0x1RBHztJJn4DoE7WuI6AodX42x+1F28n4VWR4WPKjlqiBAo7iZao0gr9FuOyCyIueSQbOIzb3WDpGrLZBZ6z6K2oPDW6yppMe0wsFHYSHVGmFfp1HZiW1ufHxZFzM1kJLxBdZU2mPSYW+thJdAT1C3vxnQfZMWnSjlG78Y9BFD5/8fuy8Tbrl6mqrEkSCS12Eh1BAnJerfykuw68bL4CAMhRsc3lmzf3jLqP6BsnFhhjsff396OzsxN9fX26h2IEFRUVqK2tRXl5ue6hBCeIVW1n5T/37UJr0mt+tqkUj1+UAXJw7HlTpjtfR2VddpIajBH2zs5OTJ48GTNmzIAQNu3QSgQpJQ4dOoTOzk7MnDlT93CCE0R07NwTvYezVnuxuOsSchXNKfLHb1cO2E2g65YCe7cAWx/LTgwiA8xblpwJjkSCMa6Yvr4+nH766SUv6gAghMDpp5+e/NWL34DcjvWw7XEKmLPlP4qgsNt7ZZddtGN9Nvc9Z+3LwezPLAdQ0hhjsQOgqOeRmvfCj1X9wr2wTfkDzNksE9VmIbv3ymnXKXeMEguMsdhNIJPJYP78+ZgzZw7mzZuH73//+xgaGgIAtLe34xvf+IbrNS655BIAwLvvvosnnnjC9V7z58/Hddddp+YBko6bcJsSEIx7l6aTeHPHKLEglMUuhHgAwLUATgL4DwBfklL2qBiYDiorK7F9+3YAwIEDB7Bs2TIcPXoUq1evRkNDAxoaXDtS4bXXXgMwKuzLli1zvZdWTGpk7JQCGGVA0O97EHdzCifxZqMMYkFYi/15ABdKKesAvA3gzvBD8kbLti5cumYzZq5sxaVrNqNlW5fS65955plYt24dfvSjH0FKid/97ne45pprAADd3d248sorMWfOHHzlK1/BOeecg4MHDwIATjnlFADAypUr8fLLL2P+/Pl46KGHlI5NGabVJbdLAaw8LbrNMkHeg7hTLZ1SGpOe9kkiIZSwSynbpJQDwz9uARCLmdAbdHFMAAAKj0lEQVSyrQt3btyJrp5eSABdPb24c+NO5eI+a9YsDA4O4sCBAwXHV69ejQULFmDXrl248cYbsXfv3jGvXbNmDT73uc9h+/btuOOOO8b8vq+vDw0NDbj44ovR0tKidNye8bstP2qsAohLHgG+/dfoVhFB3oO4d2k6iTd3jBILVAZPvwzgFwqvZ8sDm/agt78w57e3fxAPbNqDxfU1kd//lVdewS9/+UsAwFVXXYVTTz3V9zXee+891NTUoKOjAwsWLMDcuXPx8Y9/XPVQnTHRPxt3CmPQ9yDOcbrl7HPHKCnCVdiFEL8FUG3xq7uklM8Mn3MXgAEAjztc5zYAtwHA2WefHWiwOT7osa6rbXc8KB0dHchkMjjzzDOxe/dupdeuqclOQLNmzcJll12Gbdu2RSfsdj5k+meT8x5QvIkPXF0xUsorpJQXWvyXE/UvArgGwH+TUtrmqkkp10kpG6SUDVOnTg016LOqrLdi2x0PQnd3N26//XZ8/etfH5N6eOmll2L9+qwPtq2tDUeOHBnz+smTJ+Ojjz6yvPaRI0dw4sQJAMDBgwfx6quv4oILLlA29gKcfMj0z/I9IKkklI9dCHEVgG8BuE5KeVzNkNxZsWg2KsszBccqyzNYsWh2qOv29vaOpDteccUVaGxsxKpVq8act2rVKrS1teHCCy/EU089herqakyePLngnLq6OmQyGcybN29M8HT37t1oaGjAvHnzcPnll2PlypXRCbtbnnPc/tmoyvgGJe73wLTnJ6lEOBjZ7i8W4h0AEwAcGj60RUp5u9vrGhoaZHt7e8Gx3bt34/zzz/d875ZtXXhg0x580NOLs6oqsWLR7Fj86wBw4sQJZDIZjBs3Dr///e/
xta99LZLURb/viSXNVbDe9COyxaTixG7bfKkE+0r9+UlohBBbpZSuedehgqdSynPDvD4Mi+trYhPyYvbu3YulS5diaGgI48ePxyOPPKJlHJ4wyYdc6rskS/35SWwYVVIgKZx33nnYtm2b7mF4w6TqfyZm4cRJqT8/iQ2WFEg7JuU5l3rt8FJ/fhIbtNhLAVNS5UxaPeig1J+fxAYtdhIfJq0edFDqz09igxY78Y7q5hKlSKk/P4kFWux5xFm2d+/evWhsbMT555+PCy64AO+++66SZ4gM0wqGEUJsobDnkSulu2vXLjz//PN47rnnsHr1agBAQ0MDHn74YddrFJfttePWW2/FihUrsHv3bvzxj3/EmWeeWXjC8cPA0Q/M2ciiu2AYN/YQ4pnkCnvEX/Qoy/b+6U9/wsDAAK688sqR10ycOHH0hOOHsxbx0ACMsY51pupxtUCIL5Ip7DF90aMq2/v222+jqqoKS5YsQX19PVasWIHBwbxqlR/tA+RQ4QV1ltMF9Kbq6V4tEJIwkinsmr/or7zyCr7whS8ACFa2d2BgAC+//DIefPBBvP766+jo6MBjjz02esLgSesX6tzIorNYFjf2EOKLZAp7TF/0/LK9KqmtrcX8+fMxa9YsjBs3DosXL8Ybb7wxekJmvPULdW5k0Zmqx409hPgimcIewxc9yrK9n/rUp9DT04Pu7m4AwObNmwurO06eBoiij8aEjSx1S4E73soWD7vjrfjS9lhalxBfJFPYI/qix1W2N5PJ4MEHH8TChQsxd+5cSCnx1a9+dfSEiadlLeKyceBGFnBjDyE+CVW2NygqyvYq2SwTkESV7SWEpIZYyvZqReMOvkSV7SWElBzJFXaNJKpsLyGk5Eimj50QQogtFHZCCEkZFHZCCEkZFHZCCEkZFPY84irb++KLL2L+/Pkj/1VUVKClpUXdgxBCShpmxeSRK9sLAAcOHMCyZctw9OhRrF69Gg0NDWhocE0fHVO2d9myZWPOufzyy0fuc/jwYZx77rlobGxU+CSEkFImsRZ7a0crGjc0ou4ndWjc0IjWjlal14+ybG8+GzZswNVXX11YtpcQQkKQSGFv7WhF82vN2HdsHyQk9h3bh+bXmpWLe1Rle/P5+c9/jltuuUXpuAkhpU0ihX3tG2vRN9hXcKxvsA9r31gby/3Dlu3NsW/fPuzcuROLFi1SOTwSBHZoIikikcK+/9h+X8eDElXZ3hzr16/HDTfcgPLy8kiuTzzCDk0kZSRS2KsnVfs6HoQoy/bmePLJJ+mGMQF2aCIpI5HCvvyi5ajIVBQcq8hUYPlFy0NdN66yvUA2a+b999/H5z//+VBjJgpghyaSMhKZ7tg0qwlA1te+/9h+VE+qxvKLlo8cD0pB39EiLrvsMlx22WUAgClTpmDTpk0jZXtff/11TJgwAQDwt7/9DQBQXl6OzZs3215vxowZ6OrqCjVeoogptcNuGIvjhCSQRAo7kBX3sEIeFJbtTRkL78n61PPdMezQRBJMYoVdJyzbmzJydf01NW4hRDUUdkIArY1bCFGNUcFTHW36TIXvBSEkKMYIe0VFBQ4dOkRBQ1bUDx06hIqKCveTCSGkCGNcMbW1tejs7ER3d7fuoRhBRUUFamuZlUEI8Y8xwl5eXo6ZM2fqHgYhhCQeY1wxhBBC1EBhJ4SQlEFhJ4SQlCF0ZKEIIT4CsCf2G8fHGQAO6h5EhKT9+YD0PyOfL5mcI6Wc6naSruDpHimle5+5hCKEaOfzJZu0PyOfL93QFUMIISmDwk4IISlDl7Cv03TfuODzJZ+0PyOfL8VoCZ4SQgiJDrpiCCEkZWgTdiHEvwghdgghtgsh2oQQZ+kaSxQIIR4QQvx5+Bl/KYSo0j0mlQghbhJC7BJCDAkhUpN9IIS4SgixRwjxjhBipe7xqEYI8agQ4oAQ4i3dY4kCIcR0IcSLQog/Df99huuXmVB0WuwPSCnrpJTzAfwaQNra1TwP4EIpZR2AtwH
cqXk8qnkLwBIAL+keiCqEEBkA/wrgagAXALhFCHGB3lEp5zEAV+keRIQMAPhnKeUFAC4G8I8p/Axd0SbsUsqjeT9OApAqZ7+Usk1KOTD84xYAqSrVKKXcLaVM2yazTwN4R0rZIaU8CeDnAK7XPCalSClfAnBY9ziiQkq5T0r5xvD/fwRgN4AavaOKH63VHYUQ3wFwK4APAVyucywR82UAv9A9COJKDYD8rtadAD6jaSwkJEKIGQDqAfxB70jiJ1JhF0L8FkC1xa/uklI+I6W8C8BdQog7AXwdwKoox6Mat+cbPucuZJeHj8c5NhV4eT5CTEQIcQqApwH8U5F3oCSIVNillFd4PPVxAL9BwoTd7fmEEF8EcA2AhTKBeaU+Pr+00AVget7PtcPHSIIQQpQjK+qPSyk36h6PDnRmxZyX9+P1AP6sayxRIIS4CsC3AFwnpTyuezzEE68DOE8IMVMIMR7AFwA8q3lMxAdCCAHgxwB2Syl/oHs8utC2QUkI8TSA2QCGALwH4HYpZWqsIyHEOwAmADg0fGiLlPJ2jUNSihDiBgD/C8BUAD0AtkspF+kdVXiEEP8VwA8BZAA8KqX8juYhKUUI8SSAy5Ctfvj/AKySUv5Y66AUIoT4LICXAexEVlsA4H9KKX+jb1Txw52nhBCSMrjzlBBCUgaFnRBCUgaFnRBCUgaFnRBCUgaFnRBCUgaFnRBCUgaFnRBCUgaFnRBCUsb/B6MGGqDCqMoYAAAAAElFTkSuQmCC\n", 375 | "text/plain": [ 376 | "
" 377 | ] 378 | }, 379 | "metadata": {}, 380 | "output_type": "display_data" 381 | } 382 | ], 383 | "source": [ 384 | "for x, x_l in train_loader:\n", 385 | " z_loc, z_scale = vae.encoder(x)\n", 386 | "\n", 387 | "legends = [\"Digit 5\", \"Digit 6\", \"Digit 7\"]\n", 388 | "z_loc = z_loc.detach().numpy()\n", 389 | "for idx, i in enumerate([5,6,7]):\n", 390 | " plt.scatter(z_loc[x_l.numpy()==i,0], z_loc[x_l.numpy()==i,1], label=legends[idx])\n", 391 | "plt.legend()\n", 392 | "plt.show()" 393 | ] 394 | } 395 | ], 396 | "metadata": { 397 | "kernelspec": { 398 | "display_name": "probabilistic.ai", 399 | "language": "python", 400 | "name": "probabilistic.ai" 401 | }, 402 | "language_info": { 403 | "codemirror_mode": { 404 | "name": "ipython", 405 | "version": 3 406 | }, 407 | "file_extension": ".py", 408 | "mimetype": "text/x-python", 409 | "name": "python", 410 | "nbconvert_exporter": "python", 411 | "pygments_lexer": "ipython3", 412 | "version": "3.7.0" 413 | } 414 | }, 415 | "nbformat": 4, 416 | "nbformat_minor": 2 417 | } 418 | -------------------------------------------------------------------------------- /Day3/elbo_evolution.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PGM-Lab/probabilisticAI_tutorials/40b644dccee2d26e469bb5fd32ae01e430e66200/Day3/elbo_evolution.pdf -------------------------------------------------------------------------------- /Day3/elbo_evolution_with_1_samples.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PGM-Lab/probabilisticAI_tutorials/40b644dccee2d26e469bb5fd32ae01e430e66200/Day3/elbo_evolution_with_1_samples.pdf -------------------------------------------------------------------------------- /Day3/reg_model.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/PGM-Lab/probabilisticAI_tutorials/40b644dccee2d26e469bb5fd32ae01e430e66200/Day3/reg_model.png -------------------------------------------------------------------------------- /Day3/simple_pyro_exercise.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PGM-Lab/probabilisticAI_tutorials/40b644dccee2d26e469bb5fd32ae01e430e66200/Day3/simple_pyro_exercise.png -------------------------------------------------------------------------------- /Day3/slides-L3.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PGM-Lab/probabilisticAI_tutorials/40b644dccee2d26e469bb5fd32ae01e430e66200/Day3/slides-L3.pdf -------------------------------------------------------------------------------- /Day3/solution_simple_model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import numpy as np\n", 17 | "import torch\n", 18 | "from torch.distributions import constraints\n", 19 | "import matplotlib.pyplot as plt\n", 20 | "\n", 21 | "import pyro\n", 22 | "from pyro.distributions import Normal, Gamma, MultivariateNormal\n", 23 | "from pyro.infer import SVI, Trace_ELBO\n", 24 | "from pyro.optim import Adam\n", 25 | "import pyro.optim as optim" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": {}, 31 | "source": [ 32 | "## Generate some data" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 2, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "# Sample data\n", 42 | "np.random.seed(123)\n", 43 | "N = 100\n", 44 | "correct_mean = 5\n", 45 | "correct_precision = 1\n", 46 | "data = 
torch.tensor(np.random.normal(loc=correct_mean, scale=np.sqrt(1./correct_precision), size=N), dtype=torch.float)\n" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "metadata": {}, 52 | "source": [ 53 | "## Our model specification" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 3, 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "def model(data):\n", 63 | " gamma = pyro.sample(\"gamma\", Gamma(torch.tensor(1.), torch.tensor(1.)))\n", 64 | " mu = pyro.sample(\"mu\", Normal(torch.zeros(1), torch.tensor(10000.0)))\n", 65 | " with pyro.plate(\"data\", len(data)):\n", 66 | " pyro.sample(\"x\", Normal(loc=mu, scale=torch.sqrt(1. / gamma)), obs=data)" 67 | ] 68 | }, 69 | { 70 | "cell_type": "markdown", 71 | "metadata": {}, 72 | "source": [ 73 | "## Our guide specification" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": 4, 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "def guide(data=None):\n", 83 | " rate = pyro.param(\"rate\", torch.tensor(1.), contrainst=constrain.positive)\n", 84 | " conc = pyro.param(\"conc\", torch.tensor(1.), contrainst=constrain.positive)\n", 85 | " pyro.sample(\"gamma\", Gamma(rate, conc))\n", 86 | "\n", 87 | " mu_mean = pyro.param(\"mu_mean\", torch.tensor(0.))\n", 88 | " mu_scale = pyro.param(\"mu_scale\", torch.tensor(1.), contrainst=constrain.positive)\n", 89 | " pyro.sample(\"mu\", Normal(mu_mean, mu_scale))" 90 | ] 91 | }, 92 | { 93 | "cell_type": "markdown", 94 | "metadata": {}, 95 | "source": [ 96 | "## Do learning" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": 5, 102 | "metadata": {}, 103 | "outputs": [ 104 | { 105 | "name": "stdout", 106 | "output_type": "stream", 107 | "text": [ 108 | "[epoch 000] average training loss: 1599.6830\n", 109 | "[epoch 500] average training loss: 546.2211\n", 110 | "[epoch 1000] average training loss: 284.1279\n", 111 | "[epoch 1500] average training loss: 185.3350\n", 112 | "[epoch 2000] 
average training loss: 195.8745\n", 113 | "[epoch 2500] average training loss: 178.9226\n" 114 | ] 115 | } 116 | ], 117 | "source": [ 118 | "# setup the optimizer\n", 119 | "adam_args = {\"lr\": 0.01}\n", 120 | "optimizer = Adam(adam_args)\n", 121 | "\n", 122 | "pyro.clear_param_store()\n", 123 | "svi = SVI(model, guide, optimizer, loss=Trace_ELBO(), num_samples=10)\n", 124 | "train_elbo = []\n", 125 | "# training loop\n", 126 | "for epoch in range(3000):\n", 127 | " loss = svi.step(data)\n", 128 | " train_elbo.append(-loss)\n", 129 | " if (epoch % 500) == 0:\n", 130 | " print(\"[epoch %03d] average training loss: %.4f\" % (epoch, loss))" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": 6, 136 | "metadata": {}, 137 | "outputs": [ 138 | { 139 | "name": "stdout", 140 | "output_type": "stream", 141 | "text": [ 142 | "rate tensor(1.8581, requires_grad=True)\n", 143 | "conc tensor(2.1757, requires_grad=True)\n", 144 | "mu_mean tensor(5.0201, requires_grad=True)\n", 145 | "mu_scale tensor(0.1044, requires_grad=True)\n" 146 | ] 147 | } 148 | ], 149 | "source": [ 150 | "for name, value in pyro.get_param_store().items():\n", 151 | " print(name, pyro.param(name))" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": 7, 157 | "metadata": {}, 158 | "outputs": [ 159 | { 160 | "data": { 161 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAZcAAAEKCAYAAADenhiQAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3Xl8VPW9//HXZyYLIRASdsIW9kU2MYAooiwioBa12qK3VSu/InW5bW9bi1rrVq3Ltbb2WltatWp7r1q6SC2KuFtbFFzYXdgUEGSVnUCS7++PcxJmwkySIWcymfB+Ph7zyMz3fM+Z7zeTnM98l/M95pxDREQkSKFUF0BERBofBRcREQmcgouIiAROwUVERAKn4CIiIoFTcBERkcApuIiISOAUXEREJHAKLiIiEriMVBcgKGY2EfgFEAZ+55y7q7r8rVu3dkVFRfVRNBGRRuOdd97Z5pxrU1O+RhFczCwMPAicCWwAFprZHOfcinj7FBUVsWjRovoqoohIo2Bmn9QmX2PpFhsOrHLOrXHOHQKeBKakuEwiIsetxhJcOgLrI15v8NOimNl0M1tkZou2bt1ab4UTETneNJbgUivOuVnOuWLnXHGbNjV2GYqIyDFqLMFlI9A54nUnP01ERFKgsQSXhUAvM+tmZlnAVGBOisskInLcahSzxZxzpWZ2DTAPbyryI8655SkulojIcatRBBcA59xcYG6qyyEiIo0ouIjUxf5DpZSVO5o3yaxV/i/2H6LcQcvcrMq0NVv30javCc2yj/xb7S0pJSczTDhkbPziAAVNM9m+9xDtWzQB4M/vbKBrq1wcjo75OezYd4h/fryNddv3M6ZvGz774gDZGWFO6dGKjV8cYNbra8jKCDFlSCE92jTjk+37yW+aybOLN7F9Xwmn9mwNwKZdB2mVm8UpPVqzYtMuln+2m10HDvP57oN8pbgz2/aWsOKz3ZgZ553YkdysMI/+ax39O+RRXu7IzAjhHBTmN+Hvizfx9trtnDu4kHMGFfKrV1cxrKglpeWOnMwwQ7vk88mO/azfsZ9Pd+ynaVYG67btY0DHPDbtOkhRq1x2HThMrv97ee/TnbTLa8KzSz5jypCOTBvVjQ07D/De+p2s3rKXcwYV0r8wj/fXf8GidTvo1yGPDzbtodw52jTPJjMcIisjRGbY2LqnhO17D7Fy8x4mD2hPucN/rzDb9h6ibfNsSkrL6dwyh/U7DpCTGSYjbJSUlrOvpJRVW/Yyvn87XlzxOVv2lHDu4A5khUOs3LSHM/q0YeWm3azZuo+BnVrw1todFLVqSqeCHLbvO0RuVgYTB7Rn254S1mzbx5zFn7H3YCnhkDGqV2tO7JzP22t3MGlge1Zt2csHm/fQu11z+rZvzsdb9vLC8s10aZnLCys2M6ZPW8b2bcuCtdtZvnE3h8rKOalrAS+v3EJ+00wOlZWT1ySTolZNGdCxBc8t20xRq1yaZYf5dMd+WjfLZmjXAnKzM9i2p4R12/exc/8hXv9oG93b5HJ67zaEQ8a85Zu5eHgXJg3oQDhkQf37xGTOuaS+QUNVXFzsdBFl41FSWsatf1/BVWf0oFNB02rz7tx3iHtf+JCbz+1PdkaY9Tv2c9o9rwCw+s7JHCotJzsjxD+WbuLa/3uPvu2b8/0JfRjRvSUlpeUU/+TFymMN79aSt9fuSGrdRIL23k1nUhDxxSgRZvaOc664xnwKLpJMuw4cpklmiOyMcI15y8odzy75jHMGFUZ9qzpcVs6yjbt49M11/GLqEA6XOTLDhpmXp2jmP6KO87erT+W62YsZ0a0Vzy/fTNvm2Vw8vAun925D27xs+vzo+WArKZJm1t119jHvq+BSAwWX5Pp890EWrtvBNf/7HsOLWvL0jJGV6Ss37WbTroOs2rKXVz7cwg2T+rFq617ueu6Dyv0zw8bhMkePNrms3rovVdU47vzjP0dx9gP/THUxku6eLw/iuj8vqXy
dlRHiUGn5MR8vOyNESR32jzSqZ2v+uWpb5evpo7szaUB7ilrlcuLt8+t8/Bsm92X66B7HvL+CSw3qK7g8+fanjO3XlrbNm9Qq/+h7XqFZdgZzv31azO2fbN9HWbmje5tmMbeXlJZRWuYq+7cjvblqG2XljmFFLdm2t4R/r97OV4YduTxo8fov+HDzHgrzcyjIzeSWOcsZ1bMN3dvkMr5fO4p/Mp/zh3bkgqGd2LanhOlPvMMzV5/Kyx9s4RcvfczMSX35anFn5q/8nOtmL4l673MGdeDELgXc/mzc5d4khrU/nczBw+X0+3H9tLbW3XU2W/eUMOwOr+vvhxP7cvfzXtBvlu2NMcx+Z0Pc/U/u3pI/TBvBik27mf74O2zefTBmvg9un0h2Roj5Kz5n+hPvADCub1te+mBLzPx5TTLYfbCUa8b0ZO22ffxj6aaj8rx9wzi27zvEgcNlXPCrf3H/Vwfz3acWH5Xv79eMYmCnFnyyfR97DpZSkJtFx/wcSkrL4rZqf3R2P07u3opzfukF3nMHF3rHWvwZV47uznfG92b2uxu46W/LaJadwbfH9eKOuSsr92+f14SfTx1C+7wmXPLbBXy26yDhkPG9Cb254MROzPjDO7y//guW3DKBrHCIvjd55XjjujF0bhndzfvKB1vIyQrz/LLN/P5f6yrTX/vBGZx+76tHlf29m86MCkpPXzmS4d1axqxnbdQ2uGhAPwB/WrSeH8xewkc/mURWRoiyckdZuePAoTJm/mUp7fKyeeuG8Rw8XMacxZ9x3ewlvHvTmVGDwUDlgCjAb19fw9h+bVmy4QvOHVRIzxufi8r743P6c2KXfHq0bcagW16gV9tmnD+0I/c8/yEAXzu5C39Y8CkAs2eMZPY7G3hy4Xqq2rTrIF8+qSMfbt7DtMeODrYL1+2Mev2HBZ9WHhdgyoNvVj6/67kPolofkZ5dsolnlxx9QmgsrpvYp/J3fyyywiH6FeaxeP0XAAwrKuC+i4ZgZuRkhXnxv0Yz/mevAzB1WOeYn2WF753Zm/vmfwTAo5cP4xu/X1jte19+ShG//9c6ilp5J7E2zbMrt33rjB6VweXNH47l6UXx3xegf4cWZIRDDOqUz4IbxvGPJZvoX5hHt9a59LvpeQ4cLgOgSabXTVoxAQHgd5cVs/tAKYNve4EvDS6kc8scHnxlNVec2o1pp3Vj/Y79nNy9FQcPl/GPpZsY3Dmf/zqzNwAdWjShbZ73gCPdPhXBJfLEmxH2ulO7tsqNKntW2Lvsb3CnFvzsq0MYd99rldsm9G9Pl1ZHTvI/vWAg1/9lqVfnwjxyssKM7duWm/CC8DdHd68MLr/5+kkM6tSCDi1yAJj9rVN45v3PmHF698qu3dkzRlLmHNkZYcrLvS/8fds3PyqwAIzp2xaAk7u34vtn9WHAzfMq6xP5dwLw1g3jKMjN4g/TRjB/xWamn96Djvk58T/AACm41NETCz7hpr8tA+CXL39M62bZvLjyc974eBsvfHc0AJ/vLmHW66u5c+6RE+/D/1zDg6+sjnvcO+aurPzjjDyZV7itSgvg4y17o05ukftc+Ot/x32f+1/8iPtf/Ki6KkotdI4zieDK0d35zetrqt336StH0r8wj4tnLahM++6ZvaNOZj3bNq98fteXB0UFl3F921KYn8MTCz5hTJ82fHN0d+6b/xE3n9u/8kRUoVvrXNZui+5mrAgm00Z1q7acLZpmMm1UN9bv3M/j/45eGHfVHZN45M21XDqyKCr97EEdKp+X+b0kd10wsDItJzNMftNMZk7si5nRomkm/75+LK2bZVNW7th7sJRvj+9Fi5zMypNik8wwT0wbzoDCFjUOSn/j1CIefXMd7fKa0LNtM1Zt2UtGnFlSZsbbN44jPyeLrIwQ6+46m0+37+fP726gc0vvvX/z9ZPoVJBDs+wMCpp6MwubZnmn0Wb+z4kD2kcd96wTol8X5ufwrTOiu6UywqHKk3EoZPxr5tijvnz
G0iw7g79cdQobdx4AoGN+U3Kzwuw75AXxdn6wHdWrNaN6tY57nGRQcElQSWkZ33h0ITdM7sfO/YcqAwvAL19eFZV3wv1HvkFEBhag2sBS1Tuf7Kw5kwTu2WtH0btdc/aWlDLU71Y4rVdrcjLD3DrlBEb+9OXKvFkZRy92MeP0HnxvQu/K4NIiJ5NdBw4D3jfrl1Z+ztY9JZVdFGX+N9Znrx3FgI4talXGD26fSEbIyAiHGNuvLcOKWtIkMxx3wHbigPY89Gr0397/O60bTTLDXDy8S43vFwoZt00ZcFRwyQiHauzHz2uSwba9h7hgaKeo473/4wlR+Sq+4WeG4dYpA2Ie67RetVsb8Edn9+d7E/rQJDPMzef25+ZnlsdsDVSo2n3dpVVTvuu3jiA6UPxwYl+6tsplnB/AWzTN5O0bxtUqKNSkMIHWxdAuBQztUgBATlaY5bdNZPveEvb7ASZVFFwS9OHmPfxr9XamP76Iz3bF7k+W9NQxP4eNXxyofN2/Qx6hkNEyI4thRQUsXLeTb4/rRXFRdH/1uzedycJ13nTk8f3acduUE1j+2W7O7N8uKt/U4Z35zWtruPyUIgDG9YveXu5/sw9Z/OsPKgLRPRcO4sTO+ZXdSwBj+rSNuU/FN/bXfnAGhfk5nNarNZf89i0AXvjuaLIzwke1Wv7nkhPZc7A0bjmOxZ9mnMIbH2+NGYiTJRyyyuuOTuvVhpe/f0Zgx87Nzjjq91bRLQdel2ReTu2umwpaq2bZtErJOx+hAf0Erd66N6ovVupH62bZbNtbUvl6zZ2T6X6DtyDD+z8+k1PuejnuN7WQgd8o4Edn9+PELvnkN82iTfNsJv/iDTbsPMBL3zudri2bMvjWF+jephlzrjm1sj8cvPGwh15bzW1fOoEMv2/++WWb2VtSyoUndeLDzXs46+evc8u5/bn81OgTTsVU6dV3TmbPwcPkN439zXbhuh3c9vcV/GnGyKigUWFfSal/8WBiJ+fycoeDqOndf1q0nn4d8mrVQlq8/gtCZgzsdCTv0g27yMoIkZMZZteBw1HbpHHTgL40KlXPp6GIE2V+0yxW3DbxqOtdKjTNyiC/aSYbdh7ghMIWnNT1SMvjqjN6csNfl9I+rwkZ4RDLb5sY8xidWzblzvMHRqVF9q33ad+cBdePo11edtVd+f6E3mSEQ4RDFjewAAwrasnfrx0Vd3usGYC1EYoxxnBRcecYOWMb3Dn/qDQFE6mJgkuCjtOGXspV/N5//bWTaNUs9gn6spFdeXHllqiuLQDDG7TesPMAmeHoE+0lI7pwyYiaxxpqo2JJl6quGdsrkOOLpBMFl4QpuqTCf180mJ/N/4jx/dpWdks9evmwqP77W6cM4NYpR1+xbwYlh70L3GJ1N4lI8BrL/VzqTTq2XCZUGVhOB7efN4Dnv3PkQtLRvdvwt6tPrQws4M33j7xOosLfrj416nU4ZDxw8RC+UtyJvu2bH5VfRIKn4FLPbjqnf+DHvHJ092q3f+3kroG/Z9D++6LBUa8rLiQ7FkM65/ODs/pUvg6Z0bNtc+65cHBUcBKR5NF/Wj2r2ucfhJys6rt6qpnZ2mDkVOmuOlxWt3WarjqjB29cNwZIj/qLNDYKLgmqa69YRij4X3nVK4Crqu66idp47IrhzPYXnkxEIstMdIm4sG1Et5Z8aUhhwu8Xycwqx2NM0UWk3im4JKiuYy4ZSWi51BQ86vqOg+JcC9EkM7g/n8iprU9dOZK2zZtgdSx5C/8CtumnVd9tKCLB02yxBD23rG6LL4aT8C26pkPW9Zt7QW4W/TLyaJeXzee7j1zIWFOgLcjNPGpacCJ6tW3GxcO7MG1U0THtX90yKCKSXGq5JOi5pZvrtH8ybi3aooYlJhJ9y4uHH32BXW52Bm/dMD7qZF1dbOndrhnnn9ipmhw1C4WMn14wMGrRRhFJDwouCaprw2PSwOrHRxL10vdOr1z5NJ5EWy5VV7aNp7qlgyYP7FC5oF9
NhsS4AlxE0puCSz2rze1+E9Ejzk3DIiVrPLu6dakMo6h1bq26pQb74y1dWzWle+vcGnKLSDpQcElQOs48SrRbrLZVnFhlltrXI66niTzG7y6tcY07AF77wZhAV60VkdTRgP5xIbHoUttZWtNHd+en/p0nK1oo2RkhfvfPtVFHiLfmVoU0XPRARGqglkuCkjAeH4jq7iKYrJZLrFbckWtLEntPEWlcFFwSFMRJM3Jpkup8f0LvmjP5qltWpmN+DsOKCpjsTyao6Y6DdaliRSukpu5DTREWadwUXBJU1wv7AK4e07PGcYi1P53M1WN61vm9ADLDIf404xQ+2b4fgP97+9Nq8zeNuG9IxX3CaysdF/YUkeApuKTImL5tmXF6/HuOm1lgkwcqDhN5J8d4npg2PGrZlupubhWL89susYrev0NeQscSkfSl4JKgoMZcwiFj5qS+wRysBokEqdN6tYl6nfBtsP3siaxnptaOSOOj4JKgoV0Ljnnff80cW6t8p/U6+h4lsVRdSTieioB4LKvYt0iw5TKiu3cL4RN1YaTIcU3BJUFDuxx7cCms5SrB/WrZfVTbRTArWi7xWgiRKxJXNevrJ9XqPSqM7duOxTdPYET3VgntJyKNi65zaWAeu2I4o2LcXTGWqqHlzZlj2XPwMBN//kZU+pGuvNjRpboerJqWlokl3lpn8RpOmrYs0vgouDQwp/duU3MmX9WxFG8g/ujWUV3v51JXVd/+m6fFvyZHRBqHlHSLmdlFZrbczMrNrLjKtuvNbJWZfWhmZ0WkT/TTVpnZzIj0bmb2lp/+lJklNkiQoIY09jy8W8uE8sfrFqvP0LPurrO58WzvmpyvFNdt1WQRabhSNeayDLgAeD0y0cz6A1OBE4CJwK/MLGxmYeBBYBLQH7jYzwtwN3C/c64nsBOYVj9VSK153xnNA1NPrFXeipZLQwqMACcUegtWaraYSOOTkm4x59xKiDlFdgrwpHOuBFhrZquA4f62Vc65Nf5+TwJTzGwlMBa4xM/zGHAL8FBSK9AA9Glf+3ucVPya400rjjVVeXTvNoysxaB8p4IcBnWKvzpybcolIo1PQxtz6QgsiHi9wU8DWF8lfQTQCvjCOVcaI7/4qmu5/MeILlx2ShET7o9qRPL4FcNj5D7aP39Yu+nVcAzXzIhI2kpacDGzF4FYd8a60Tn3TLLetzpmNh2YDtClS/Xra8WT6AmyW+tc1m7bd0zvVVfNm2Sw52Bp5ZhKrKLfcf7ApJejV9vmnNar9VFrqnXz793Sv1BX7os0NkkLLs658cew20Yg8h67nfw04qRvB/LNLMNvvUTmj1WmWcAsgOLi4kb/NfqvV53K6x9tJRSquM4lNVXOygjxxLQRR6Wf1qsNz3/nNPq0022MRRqbhnYR5Rxgqpllm1k3oBfwNrAQ6OXPDMvCG/Sf47yz5SvAhf7+lwEpaRU1RD3bNuOKiKX4G2I07ds+Ly1vwCYi1UvVVOTzzWwDMBL4h5nNA3DOLQeeBlYAzwNXO+fK/FbJNcA8YCXwtJ8X4IfAf/mD/62Ah+u3NulDQx4iUl9SNVvsr8Bf42y7A7gjRvpcYG6M9DUcmVHW6PRok8vqrcGM2Yzq2Zrnl2+mTfNstu6peYVkEZFj1dC6xaSKk+qwUGZVl4zwJjFojENEkk3BJUEbdh5IKH9dRxOCuDlZ5bE0tCEi9UTBJUH3zvswofzJGOY4e2CHJBxVRCQ4De0iSqnBo98YVutVk+NxDXLemIg0JgouDUhtuq3G9GlbY553bzqTck0NE5EUUrdYAxcZcE7uXrtVkFvmZtG6WfZR6a1yvbS+7XVFvIgkl1ouaeThy4bVaf/+hXnMnjGSwZ3zmTywAx9s3h1QyUREoim4NCA19YrlZtf94you8lo/J3UtCHSas4hIJHWLJVkis3+1DIqINBYKLklwrDFCoUVEGgt1iyWBEX19S0bIGFaU2C2JK4+liCMiaUjBpR6sunNyrfKlMpD
MnNSXTgU5qSuAiDQqCi4BysoIcai03Bs7SbPrTGac3qPWeR+7Yjjb92rhSxGJT8ElCY6lAdIsO4MbJvcLvCzJcHrvNqkugog0cAouDcSyW89KdRFERAKj2WIBSs6QiUb0RST9KLgkQeTA/ENfOyl1BRERSREFlwDFmu3Vp71uzCUixx8FlyQI8gZfIiLpSMFFREQCp+ASoGS0WHSFvoikIwWXZFBAEJHjnIJLgNTKEBHxKLgkgWKMiBzvFFxERCRwCi4BqmixqHtMRI53Ci4NnOKUiKQjBRcREQmcgkuALhnRJdVFEBFpELTkfkBW3zmZktIyfvvG2kDvE5ZetxwTEfEouAQkHLKoK/S/O7434/q1TWGJRERSR8ElCczg2+N7pboYIiIpozGXJAiyW0xEJB0puDRwGSGvq+3Unq1SXBIRkdpLSXAxs3vN7AMzW2JmfzWz/Iht15vZKjP70MzOikif6KetMrOZEendzOwtP/0pM8uq7/rUhzP7tUt1EUREai1VLZf5wADn3CDgI+B6ADPrD0wFTgAmAr8ys7CZhYEHgUlAf+BiPy/A3cD9zrmewE5gWr3WJEIyr8w3XfYvImkkJcHFOfeCc67Uf7kA6OQ/nwI86Zwrcc6tBVYBw/3HKufcGufcIeBJYIp5Z9yxwGx//8eA8+qrHlVprEVExNMQxlyuAJ7zn3cE1kds2+CnxUtvBXwREagq0kVEJIWSNhXZzF4E2sfYdKNz7hk/z41AKfDHZJWjSpmmA9MBunQJ/mp69VyJiHiSFlycc+Or225mlwPnAOOcq+xQ2gh0jsjWyU8jTvp2IN/MMvzWS2T+WGWaBcwCKC4uVieWiEiSpGq22ETgOuBLzrn9EZvmAFPNLNvMugG9gLeBhUAvf2ZYFt6g/xw/KL0CXOjvfxnwTH3Voz5oHEdE0lGqrtD/HyAbmO/PglrgnJvhnFtuZk8DK/C6y652zpUBmNk1wDwgDDzinFvuH+uHwJNm9hPgPeDh+q1K/VCXm4ikk5QEF3/acLxtdwB3xEifC8yNkb4GbzZZo6YWjIikk4YwW0xERBoZBZc0oW4xEUknWhU5QNkZISb0b8elI4sCO6bTHV1EJA0puATIzJh1aXFyjp2Uo4qIJIe6xdKE2i8ikk5qDC5mNsDMHjezRf7jMTMbVB+FExGR9FRtcDGzKcBfgVfx1gC7AngN+LO/TeqJusVEJJ3UNOZyG3Cmc25dRNoSM3sZ70r4RnU1vIiIBKOmbrGMKoEFAD8tMxkFkmi6eFJE0lFNLZdSM+vinPs0MtHMuuItzyJ19PBlxazdtq/mjLrQRUTSSE3B5WbgRTO7E3jHTysGZuKt6SV1NE63LxaRRqja4OKc+5uZrQW+B1zrJ68AvuKcW5zswomISHqq8SJKP4hcWg9lERGRRqKmqcitzexmM/tPM2tmZg+Z2TIze8bM4q5sLMHReL6IpKOaZov9L959Vypu2rUW78ZczwK/S27RJJKG80UkndTULdbOOXeDeXf0+sQ5d4+f/oGZXZ3kskkEtWBEJJ3U1HIpA/BvJ7ytyrbypJRIRETSXk0tl+5mNgevV6biOf7rbkktmURRt5iIpJOagkvk+mH/XWVb1deSBLpCX0TSUU3XubwWb5uZPYW3iKXUA12gLyLppC73cxkZWCnSUF4T3WdNRCQe3SxMREQCV+3XbzMbGm8Tx/mqyBoKERGJr6a+nfuq2fZBkAUREZHGo6YB/TH1VRCJR20kEUk/Na0tdl3E84uqbLszWYWSo5mudBGRNFLTgP7UiOfXV9k2MeCyiIhII1FTcLE4z2O9Pm6dO7gw1UUQEWlQagouLs7zWK+PWz+ZMiDVRRARaVBqmi022Mx247VScvzn+K+bJLVkAmj5FxFJTzXNFgvXV0HSTuRJvx46CLX8i4ikE12hf4y6tm5a+VwnfhGRaAouxyikiCIiEpeCSwAUZkREoqUkuJjZ7Wa2xMzeN7MXzKzQTzcze8DMVvn
bh0bsc5mZfew/LotIP8nMlvr7PODfkllERFIoVS2Xe51zg5xzQ4BngR/76ZOAXv5jOvAQgJm1BG4GRgDDgZvNrMDf5yHgmxH71fvFncmMZ5otJiLpKCXBxTm3O+JlLkfmXk0BHneeBUC+mXUAzgLmO+d2OOd2AvOBif62POfcAuecAx4HzqufOhx5Xh9NJTXHRCSdpOyOV2Z2B3ApsAuoWCCzI7A+ItsGP6269A0x0kVEJIWS1nIxsxfNbFmMxxQA59yNzrnOwB+Ba5JVjiplmm5mi8xs0datW+t4rNjPRUQkiS0X59z4Wmb9IzAXb0xlI9A5YlsnP20jcEaV9Ff99E4x8scr0yxgFkBxcbFGM0REkiRVs8V6RbycwpEbj80BLvVnjZ0M7HLObQLmARPMrMAfyJ8AzPO37Tazk/1ZYpcCz9RfTTzJXA7faQk3EUlDqRpzucvM+gDlwCfADD99LjAZWAXsB74B4JzbYWa3Awv9fLc553b4z68Cfg/kAM/5j0ZHXW8ikk5SElycc1+Ok+6Aq+NsewR4JEb6IqDelyWOmi1WDyd+TUkWkXSiK/RFRCRwCi7HqL67qdQtJiLpRMElAOGQrtAXEYmk4FJHPzlvAJnh5P8akzkjTUQkaAoudTSwY4tUF0FEpMFRcBERkcApuIiISOAUXEREJHAKLg2cJouJSDpScEkXmiwmImlEwUVERAKn4CIiIoFTcBERkcApuDRwWv5FRNKRgkua0Hi+iKQTBRcREQmcgouIiAROwUVERAKn4NLAOV2jLyJpSMElTZhuRSkiaUTBRUREAqfg0sB1KmgKQKvcrBSXRESk9jJSXQCp3rVjezKgMI8z+rRJdVFERGpNwaWBywyHmHBC+1QXQ0QkIeoWExGRwCm4iIhI4BRcREQkcAouIiISOAUXEREJnIKLiIgETsFFREQCp+AiIiKBU3AREZHAKbiIiEjgUhpczOx7ZubMrLX/2szsATNbZWZLzGxoRN7LzOxj/3FZRPpJZrbU3+cB09r0IiIpl7LgYmadgQnApxHJk4Be/mM68JCftyVwMzACGA7cbGYF/j4PAd+M2G9ifZT/F1NP5KvFnTmhMK8+3k5EJK2ksuVyP3AX8KN5AAAL6klEQVQdRN1qcQrwuPMsAPLNrANwFjDfObfDObcTmA9M9LflOecWOOcc8DhwXn0UvlvrXO6+cBAZYfUsiohUlZIzo5lNATY65xZX2dQRWB/xeoOfVl36hhjpIiKSQklbct/MXgRirRV/I3ADXpdYvTKz6XjdbXTp0qW+315E5LiRtODinBsfK93MBgLdgMX+2Hsn4F0zGw5sBDpHZO/kp20EzqiS/qqf3ilG/nhlmgXMAiguLnbx8omISN3Ue7eYc26pc66tc67IOVeE15U11Dm3GZgDXOrPGjsZ2OWc2wTMAyaYWYE/kD8BmOdv221mJ/uzxC4FnqnvOomISLSGdifKucBkYBWwH/gGgHNuh5ndDiz0893mnNvhP78K+D2QAzznP0REJIVSHlz81kvFcwdcHSffI8AjMdIXAQOSVT4REUmc5tGKiEjgFFxERCRwCi4iIhI4BRcREQmcgouIiAROwUVERAKn4CIiIoFTcBERkcApuIiISOAUXEREJHAKLiIiEjgFFxERCZyCi4iIBE7BRUREAqfgIiIigVNwERGRwCm4iIhI4BRcREQkcAouIiISOAUXEREJnIKLiIgETsFFREQCp+AiIiKBU3AREZHAKbiIiEjgFFxERCRwCi4iIhI4BRcREQlcRqoLkG4evqyYw2Uu1cUQEWnQFFwSNK5fu1QXQUSkwVO3mIiIBE7BRUREAqfgIiIigVNwERGRwKUkuJjZLWa20cze9x+TI7Zdb2arzOxDMzsrIn2in7bKzGZGpHczs7f89KfMLKu+6yMiItFS2XK53zk3xH/MBTCz/sBU4ARgIvArMwubWRh4EJgE9Acu9vMC3O0fqyewE5hW3xUREZFoDa1bbAr
wpHOuxDm3FlgFDPcfq5xza5xzh4AngSlmZsBYYLa//2PAeSkot4iIREhlcLnGzJaY2SNmVuCndQTWR+TZ4KfFS28FfOGcK62SHpOZTTezRWa2aOvWrUHVQ0REqkjaRZRm9iLQPsamG4GHgNsB5/+8D7giWWWp4JybBczyy7fVzD45xkO1BrYFVrDUaix1aSz1ANWloWosdalrPbrWJlPSgotzbnxt8pnZb4Fn/Zcbgc4Rmzv5acRJ3w7km1mG33qJzF9T+drUJl+cMi9yzhUf6/4NSWOpS2OpB6guDVVjqUt91SNVs8U6RLw8H1jmP58DTDWzbDPrBvQC3gYWAr38mWFZeIP+c5xzDngFuNDf/zLgmfqog4iIxJeqtcXuMbMheN1i64ArAZxzy83saWAFUApc7ZwrAzCza4B5QBh4xDm33D/WD4EnzewnwHvAw/VZEREROVpKgotz7uvVbLsDuCNG+lxgboz0NXizyerTrHp+v2RqLHVpLPUA1aWhaix1qZd6mNezJCIiEpyGdp2LiIg0AgouCYi3BE1DZmbrzGypv8zOIj+tpZnNN7OP/Z8FfrqZ2QN+/ZaY2dAUl/0RM9tiZssi0hIuu5ld5uf/2Mwua0B1CWwZpHqsR2cze8XMVpjZcjP7tp+edp9LNXVJx8+liZm9bWaL/brc6qd3sxjLY/mTpp7y098ys6Ka6pgw55wetXjgTSRYDXQHsoDFQP9Ul6sW5V4HtK6Sdg8w038+E7jbfz4ZeA4w4GTgrRSXfTQwFFh2rGUHWgJr/J8F/vOCBlKXW4Dvx8jb3//7yga6+X934YbwNwh0AIb6z5sDH/nlTbvPpZq6pOPnYkAz/3km8Jb/+34amOqn/xr4lv/8KuDX/vOpwFPV1fFYyqSWS+3FXIImxWU6VlPwlsqB6CVzpgCPO88CvGuIOsQ6QH1wzr0O7KiSnGjZzwLmO+d2OOd2AvPx1q2rV3HqEk9CyyAlpcBxOOc2Oefe9Z/vAVbirYqRdp9LNXWJpyF/Ls45t9d/mek/HPGXx4r8vGYD48zMiF/HhCm41F68JWgaOge8YGbvmNl0P62dc26T/3wzUHHv5nSoY6Jlb+h1CmIZpJTwu1JOxPuWnNafS5W6QBp+LuYt8vs+sAUvWK8m/vJYlWX2t+/CW04rsLoouDR+o5xzQ/FWlL7azEZHbnReWzgtpwymc9l9DwE9gCHAJrxlkNKCmTUD/gx8xzm3O3Jbun0uMeqSlp+Lc67MOTcEb6WS4UDfVJZHwaX2qluapsFyzm30f24B/or3R/d5RXeX/3OLnz0d6pho2RtsnZxzn/snhHLgtxzpfmjQdTGzTLyT8R+dc3/xk9Pyc4lVl3T9XCo4577AW7lkJP7yWDHKVVlmf3sLvOW0AquLgkvtxVyCJsVlqpaZ5ZpZ84rnwAS8pXbm4C2VA9FL5swBLvVn+JwM7Iro6mgoEi37PGCCmRX43RsT/LSUs4CWQarnMhveKhgrnXM/i9iUdp9LvLqk6efSxszy/ec5wJl4Y0jxlseK/LwuBF72W5zx6pi4+pzRkO4PvJkvH+H1Zd6Y6vLUorzd8WZ+LAaWV5QZr2/1JeBj4EWgpZ9ueDdlWw0sBYpTXP7/w+uWOIzX9zvtWMqOt+L2Kv/xjQZUlyf8si7x/6k7ROS/0a/Lh8CkhvI3CIzC6/JaArzvPyan4+dSTV3S8XMZhLf81RK8YPhjP707XnBYBfwJyPbTm/ivV/nbu9dUx0QfukJfREQCp24xEREJnIKLiIgETsFFREQCp+AiIiKBU3AREZHAKbhIo2Bmzszui3j9fTO7JaBj/97MLqw5Z53f5yIzW2lmr1RJLzSz2f7zIZGr9AbwnvlmdlWs9xKpCwUXaSxKgAvMrHWqCxIp4uro2pgGfNM5NyYy0Tn3mXOuIrgNwbumIqgy5OOtkBvrvUSOmYKLNBaleLdv/W7VDVVbHma21/95hpm9ZmbPmNkaM7vLzP7Dvy/GUjP
rEXGY8Wa2yMw+MrNz/P3DZnavmS30Fzm8MuK4b5jZHGBFjPJc7B9/mZnd7af9GO+ivofN7N4q+Yv8vFnAbcBXzbvPyFf9VRge8cv8nplN8fe53MzmmNnLwEtm1szMXjKzd/33rli19y6gh3+8eyveyz9GEzN71M//npmNiTj2X8zsefPuxXJPxO/j935Zl5rZUZ+FHD8S+VYl0tA9CCypONnV0mCgH95y+GuA3znnhpt346hrge/4+Yrw1pjqAbxiZj2BS/GWMxlmZtnAm2b2gp9/KDDAecuWVzKzQuBu4CRgJ96K1ec5524zs7F49xFZFKugzrlDfhAqds5d4x/vTrylO67wl/9428xejCjDIOfcDr/1cr5zbrffulvgB7+ZfjmH+McrinjLq723dQPNrK9f1t7+tiF4qwiXAB+a2S+BtkBH59wA/1j5NfzupRFTy0UaDeetaPs48J8J7LbQeff1KMFb8qIiOCzFCygVnnbOlTvnPsYLQn3x1sO61Lxlzt/CWwKll5//7aqBxTcMeNU5t9V5S53/Ee9GYsdqAjDTL8OreMt6dPG3zXfOVdxDxoA7zWwJ3vIsHTmyLH48o4A/ADjnPgA+ASqCy0vOuV3OuYN4rbOueL+X7mb2SzObCOyOcUw5TqjlIo3Nz4F3gUcj0krxv0iZWQjvboEVSiKel0e8Lif6/6PqOkkO74R9rXMuasFFMzsD2HdsxU+YAV92zn1YpQwjqpThP4A2wEnOucNmtg4vEB2ryN9bGZDhnNtpZoPxbgQ2A/gK3vphchxSy0UaFf+b+tN4g+MV1uF1QwF8Ce8ufYm6yMxC/jhMd7xF/eYB3zJv2XbMrLd5q09X523gdDNrbWZh4GLgtQTKsQfvlrwV5gHXmpn5ZTgxzn4tgC1+YBmD19KIdbxIb+AFJfzusC549Y7J724LOef+DPwIr1tOjlMKLtIY3QdEzhr7Ld4JfTHePS6OpVXxKV5geA6Y4XcH/Q6vS+hdfxD8N9TQG+C85eZn4i2Fvhh4xzn3THX7VPEK0L9iQB+4HS9YLjGz5f7rWP4IFJvZUryxog/88mzHGytaVnUiAfArIOTv8xRwud99GE9H4FW/i+4PwPUJ1EsaGa2KLCIigVPLRUREAqfgIiIigVNwERGRwCm4iIhI4BRcREQkcAouIiISOAUXEREJnIKLiIgE7v8DczkENBjqMxMAAAAASUVORK5CYII=\n", 162 | "text/plain": [ 163 | "
" 164 | ] 165 | }, 166 | "metadata": {}, 167 | "output_type": "display_data" 168 | } 169 | ], 170 | "source": [ 171 | "plt.plot(range(len(train_elbo)), train_elbo)\n", 172 | "plt.xlabel(\"Number of iterations\")\n", 173 | "plt.ylabel(\"ELBO\")\n", 174 | "plt.show()" 175 | ] 176 | } 177 | ], 178 | "metadata": { 179 | "kernelspec": { 180 | "display_name": "probabilistic.ai", 181 | "language": "python", 182 | "name": "probabilistic.ai" 183 | }, 184 | "language_info": { 185 | "codemirror_mode": { 186 | "name": "ipython", 187 | "version": 3 188 | }, 189 | "file_extension": ".py", 190 | "mimetype": "text/x-python", 191 | "name": "python", 192 | "nbconvert_exporter": "python", 193 | "pygments_lexer": "ipython3", 194 | "version": "3.7.0" 195 | } 196 | }, 197 | "nbformat": 4, 198 | "nbformat_minor": 2 199 | } 200 | -------------------------------------------------------------------------------- /Day3/student_BBVI.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Applying BBVI for a simple Gaussian Model" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 13, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "import numpy as np\n", 24 | "from scipy.stats import norm\n", 25 | "import matplotlib.pyplot as plt\n", 26 | "import seaborn as sns" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "metadata": {}, 32 | "source": [ 33 | "# Data" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 14, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "# Generate data from a simple model: Normal(10, 1)\n", 43 | "data = np.random.normal(loc = 10, scale = 1, size = 100)" 44 | ] 45 | }, 46 | { 47 | "cell_type": "markdown", 48 | "metadata": {}, 49 | "source": [ 50 | "# Helper function: 
ELBO\n", 51 | "\n", 52 | "Calculate the exact value of the ELBO. Generally one would have to estimate this using sampling, but for this simple model we can evaluate it exactly " 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 15, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "def calculate_lower_bound(tau, q_mu):\n", 62 | " \"\"\"\n", 63 | " Helper routine: Calculate ELBO. Data is the sampled x-values, anything without a star relates to the prior,\n", 64 | " everything _with_ a star relates to the variational posterior.\n", 65 | " Note that we have no nu without a star; I am simplifying by forcing this to be zero a priori\n", 66 | "\n", 67 | " Note: This function obviously only works when the model is as in this code challenge,\n", 68 | " and is not a general solution.\n", 69 | "\n", 70 | " :param data: The sampled data\n", 71 | " :param tau: prior precision for mu, the mean for the data generation\n", 72 | " :param alpha: prior shape of dist for gamma, the precision of the data generation\n", 73 | " :param beta: prior rate of dist for gamma, the precision of the data generation\n", 74 | " :param nu_star: VB posterior mean for the distribution of mu - the mean of the data generation\n", 75 | " :param tau_star: VB posterior precision for the distribution of mu - the mean of the data generation\n", 76 | " :param alpha_star: VB posterior shape of dist for gamma, the precision of the data generation\n", 77 | " :param beta_star: VB posterior shape of dist for gamma, the precision of the data generation\n", 78 | " :return: the ELBO\n", 79 | " \"\"\"\n", 80 | "\n", 81 | " # We calculate ELBO as E_q log p(x,mu) - E_q log q(mu)\n", 82 | " # log p(x,z) here is log p(mu) + \\sum_i log p(x_i | mu, 1)\n", 83 | "\n", 84 | " # E_q log p(mu)\n", 85 | " log_p = -.5 * np.log(2 * np.pi) - .5 * (1/tau) * (1 + q_mu**2)\n", 86 | "\n", 87 | "\n", 88 | " # E_q log p(x_i|mu, 1)\n", 89 | " for xi in data:\n", 90 | " log_p += -.5 * np.log(2 * np.pi) - 
.5 * (xi * xi - 2 * xi * q_mu + 1 + q_mu**2)\n", 91 | "\n", 92 | " # Entropy of mu (Gaussian)\n", 93 | " entropy = .5 * np.log(2 * np.pi * np.exp(1))\n", 94 | "\n", 95 | " return log_p + entropy" 96 | ] 97 | }, 98 | { 99 | "cell_type": "markdown", 100 | "metadata": {}, 101 | "source": [ 102 | "# Manual estimation of the gradient of the ELBO for the above model" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": 4, 108 | "metadata": {}, 109 | "outputs": [], 110 | "source": [ 111 | "# Gradient estimator using sampling -- vanilla BBVI\n", 112 | "# We here assume the model X ~ Normal(mu, 1)\n", 113 | "# with unknown mu, that in itself is Normal, mean 0 and standard deviation 1000, \n", 114 | "# so effectively an uninformative prior. \n", 115 | "# The variational distribution for mu is also Normal, with parameter q_mu_lambda\n", 116 | "# -- taking the role of lambda in the calculations -- and variance 1.\n", 117 | "#\n", 118 | "# Note:\n", 119 | "# We can sample from a normal using:\n", 120 | "# * np.random.normal(loc=mu, scale=1, size=1)\n", 121 | "# We can evaluate the normal log-density using\n", 122 | "# * norm.logpdf(sample, loc = mu, scale = std. dev.)\n", 123 | "\n", 124 | "def grad_estimate(q_mu_lambda, samples = 1):\n", 125 | " # sum_grad_estimate will hold the sum as we move along over the samples. 
\n", 126 | " sum_grad_estimate = 0\n", 127 | " for i in range(samples):\n", 128 | " # Sample one example from current best guess for the variational distribution\n", 129 | " mu_sample = np.random.normal(loc=q_mu_lambda, scale=1, size=1)\n", 130 | " \n", 131 | " # Now we want to calculate the contribution from this sample, namely \n", 132 | " # [log p(x, mu_sample) - log q(mu|lambda) ] * grad( log q(mu_sample|lambda) )\n", 133 | " #\n", 134 | " value = ?\n", 135 | "\n", 136 | " # Next grad (log q(mu_sample|lambda))\n", 137 | " # The Normal distribution gives the score function with known variance as - \n", 138 | " grad_q = ?\n", 139 | "\n", 140 | " \n", 141 | " # grad ELBO for this sample is therefore in total given by\n", 142 | " sum_grad_estimate = sum_grad_estimate + grad_q * value\n", 143 | " \n", 144 | " # Divide by number of samples to get average value -- the estimated expectation \n", 145 | " return sum_grad_estimate/samples" 146 | ] 147 | }, 148 | { 149 | "cell_type": "markdown", 150 | "metadata": {}, 151 | "source": [ 152 | "# Perform BBVI using the estimated gradient" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": 11, 158 | "metadata": {}, 159 | "outputs": [ 160 | { 161 | "name": "stdout", 162 | "output_type": "stream", 163 | "text": [ 164 | " 100 sample(s) -- Estimate: 9.91943; error 0.8% -- Calc.time: 17.40 sec.\n" 165 | ] 166 | } 167 | ], 168 | "source": [ 169 | "import time\n", 170 | "no_loops = 500\n", 171 | "sample_count = 100\n", 172 | "##### Starting point\n", 173 | "q_mu = -10\n", 174 | "start = time.time()\n", 175 | "elbos = []\n", 176 | "lr = 1E-4 \n", 177 | "\n", 178 | "\n", 179 | "#loop a couple of times\n", 180 | "for t in range(no_loops):\n", 181 | " elbos.append(calculate_lower_bound(1000, q_mu))\n", 182 | " q_grad = grad_estimate(q_mu, samples=sample_count)\n", 183 | " q_mu = q_mu + lr * q_grad\n", 184 | "\n", 185 | "print(\"{:4d} sample(s) -- Estimate: {:9.5f}; error {:5.1f}% -- Calc.time: {:5.2f} 
sec.\".format(\n", 186 | " sample_count, float(q_mu), float(10*np.abs(q_mu-10)), time.time() - start))" 187 | ] 188 | }, 189 | { 190 | "cell_type": "markdown", 191 | "metadata": {}, 192 | "source": [ 193 | "### Exercise\n", 194 | "* Try varying the number of samples used for estimating the gradient. What effect does it have on the results?" 195 | ] 196 | } 197 | ], 198 | "metadata": { 199 | "kernelspec": { 200 | "display_name": "Python (TFEnv conda)", 201 | "language": "python", 202 | "name": "tfenv" 203 | }, 204 | "language_info": { 205 | "codemirror_mode": { 206 | "name": "ipython", 207 | "version": 3 208 | }, 209 | "file_extension": ".py", 210 | "mimetype": "text/x-python", 211 | "name": "python", 212 | "nbconvert_exporter": "python", 213 | "pygments_lexer": "ipython3", 214 | "version": "3.6.8" 215 | } 216 | }, 217 | "nbformat": 4, 218 | "nbformat_minor": 2 219 | } 220 | -------------------------------------------------------------------------------- /Day3/student_simple_model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import numpy as np\n", 17 | "import torch\n", 18 | "from torch.distributions import constraints\n", 19 | "import matplotlib.pyplot as plt\n", 20 | "\n", 21 | "import pyro\n", 22 | "from pyro.distributions import Normal, Gamma, MultivariateNormal\n", 23 | "from pyro.infer import SVI, Trace_ELBO\n", 24 | "from pyro.optim import Adam\n", 25 | "import pyro.optim as optim" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": {}, 31 | "source": [ 32 | "## Generate some data" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 2, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "# Sample data\n", 42 | "np.random.seed(123)\n", 43 
| "N = 100\n", 44 | "correct_mean = 5\n", 45 | "correct_precision = 1\n", 46 | "data = torch.tensor(np.random.normal(loc=correct_mean, scale=np.sqrt(1./correct_precision), size=N), dtype=torch.float)\n" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "metadata": {}, 52 | "source": [ 53 | "## Our model specification" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 3, 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "# Fill the method with code to define a simple Gaussian model with mean \\mu and precision \\gamma\n", 63 | "def model(data):\n" 64 | ] 65 | }, 66 | { 67 | "cell_type": "markdown", 68 | "metadata": {}, 69 | "source": [ 70 | "## Our guide specification" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 4, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "# Define the right guide for the above model, including the variational parameters. \n", 80 | "def guide(data=None):\n" 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "metadata": {}, 86 | "source": [ 87 | "## Do learning" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": 5, 93 | "metadata": {}, 94 | "outputs": [ 95 | { 96 | "name": "stdout", 97 | "output_type": "stream", 98 | "text": [ 99 | "[epoch 000] average training loss: 1599.6830\n", 100 | "[epoch 500] average training loss: 546.2211\n", 101 | "[epoch 1000] average training loss: 284.1279\n", 102 | "[epoch 1500] average training loss: 185.3350\n", 103 | "[epoch 2000] average training loss: 195.8745\n", 104 | "[epoch 2500] average training loss: 178.9226\n" 105 | ] 106 | } 107 | ], 108 | "source": [ 109 | "# setup the optimizer\n", 110 | "adam_args = {\"lr\": 0.01}\n", 111 | "optimizer = Adam(adam_args)\n", 112 | "\n", 113 | "pyro.clear_param_store()\n", 114 | "svi = SVI(model, guide, optimizer, loss=Trace_ELBO(), num_samples=10)\n", 115 | "train_elbo = []\n", 116 | "# training loop\n", 117 | "for epoch in range(3000):\n", 118 | " loss = 
svi.step(data)\n", 119 | " train_elbo.append(-loss)\n", 120 | " if (epoch % 500) == 0:\n", 121 | " print(\"[epoch %03d] average training loss: %.4f\" % (epoch, loss))" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": 6, 127 | "metadata": {}, 128 | "outputs": [ 129 | { 130 | "name": "stdout", 131 | "output_type": "stream", 132 | "text": [ 133 | "rate tensor(1.8581, requires_grad=True)\n", 134 | "conc tensor(2.1757, requires_grad=True)\n", 135 | "mu_mean tensor(5.0201, requires_grad=True)\n", 136 | "mu_scale tensor(0.1044, requires_grad=True)\n" 137 | ] 138 | } 139 | ], 140 | "source": [ 141 | "for name, value in pyro.get_param_store().items():\n", 142 | " print(name, pyro.param(name))" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": 7, 148 | "metadata": {}, 149 | "outputs": [ 150 | { 151 | "data": { 152 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAEKCAYAAADenhiQAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3Xl8VPW9//HXZyYLIRASdsIW9kU2MYAooiwioBa12qK3VSu/InW5bW9bi1rrVq3Ltbb2WltatWp7r1q6SC2KuFtbFFzYXdgUEGSVnUCS7++PcxJmwkySIWcymfB+Ph7zyMz3fM+Z7zeTnM98l/M95pxDREQkSKFUF0BERBofBRcREQmcgouIiAROwUVERAKn4CIiIoFTcBERkcApuIiISOAUXEREJHAKLiIiEriMVBcgKGY2EfgFEAZ+55y7q7r8rVu3dkVFRfVRNBGRRuOdd97Z5pxrU1O+RhFczCwMPAicCWwAFprZHOfcinj7FBUVsWjRovoqoohIo2Bmn9QmX2PpFhsOrHLOrXHOHQKeBKakuEwiIsetxhJcOgLrI15v8NOimNl0M1tkZou2bt1ab4UTETneNJbgUivOuVnOuWLnXHGbNjV2GYqIyDFqLMFlI9A54nUnP01ERFKgsQSXhUAvM+tmZlnAVGBOisskInLcahSzxZxzpWZ2DTAPbyryI8655SkulojIcatRBBcA59xcYG6qyyEiIo0ouIjUxf5DpZSVO5o3yaxV/i/2H6LcQcvcrMq0NVv30javCc2yj/xb7S0pJSczTDhkbPziAAVNM9m+9xDtWzQB4M/vbKBrq1wcjo75OezYd4h/fryNddv3M6ZvGz774gDZGWFO6dGKjV8cYNbra8jKCDFlSCE92jTjk+37yW+aybOLN7F9Xwmn9mwNwKZdB2mVm8UpPVqzYtMuln+2m10HDvP57oN8pbgz2/aWsOKz3ZgZ553YkdysMI/+ax39O+RRXu7IzAjhHBTmN+Hvizfx9trtnDu4kHMGFfKrV1cxrKglpeWOnMwwQ7vk88mO/azfsZ9Pd+ynaVYG67btY0DHPDbtOkhRq1x2HThMrv97ee/Tn
bTLa8KzSz5jypCOTBvVjQ07D/De+p2s3rKXcwYV0r8wj/fXf8GidTvo1yGPDzbtodw52jTPJjMcIisjRGbY2LqnhO17D7Fy8x4mD2hPucN/rzDb9h6ibfNsSkrL6dwyh/U7DpCTGSYjbJSUlrOvpJRVW/Yyvn87XlzxOVv2lHDu4A5khUOs3LSHM/q0YeWm3azZuo+BnVrw1todFLVqSqeCHLbvO0RuVgYTB7Rn254S1mzbx5zFn7H3YCnhkDGqV2tO7JzP22t3MGlge1Zt2csHm/fQu11z+rZvzsdb9vLC8s10aZnLCys2M6ZPW8b2bcuCtdtZvnE3h8rKOalrAS+v3EJ+00wOlZWT1ySTolZNGdCxBc8t20xRq1yaZYf5dMd+WjfLZmjXAnKzM9i2p4R12/exc/8hXv9oG93b5HJ67zaEQ8a85Zu5eHgXJg3oQDhkQf37xGTOuaS+QUNVXFzsdBFl41FSWsatf1/BVWf0oFNB02rz7tx3iHtf+JCbz+1PdkaY9Tv2c9o9rwCw+s7JHCotJzsjxD+WbuLa/3uPvu2b8/0JfRjRvSUlpeUU/+TFymMN79aSt9fuSGrdRIL23k1nUhDxxSgRZvaOc664xnwKLpJMuw4cpklmiOyMcI15y8odzy75jHMGFUZ9qzpcVs6yjbt49M11/GLqEA6XOTLDhpmXp2jmP6KO87erT+W62YsZ0a0Vzy/fTNvm2Vw8vAun925D27xs+vzo+WArKZJm1t119jHvq+BSAwWX5Pp890EWrtvBNf/7HsOLWvL0jJGV6Ss37WbTroOs2rKXVz7cwg2T+rFq617ueu6Dyv0zw8bhMkePNrms3rovVdU47vzjP0dx9gP/THUxku6eLw/iuj8vqXydlRHiUGn5MR8vOyNESR32jzSqZ2v+uWpb5evpo7szaUB7ilrlcuLt8+t8/Bsm92X66B7HvL+CSw3qK7g8+fanjO3XlrbNm9Qq/+h7XqFZdgZzv31azO2fbN9HWbmje5tmMbeXlJZRWuYq+7cjvblqG2XljmFFLdm2t4R/r97OV4YduTxo8fov+HDzHgrzcyjIzeSWOcsZ1bMN3dvkMr5fO4p/Mp/zh3bkgqGd2LanhOlPvMMzV5/Kyx9s4RcvfczMSX35anFn5q/8nOtmL4l673MGdeDELgXc/mzc5d4khrU/nczBw+X0+3H9tLbW3XU2W/eUMOwOr+vvhxP7cvfzXtBvlu2NMcx+Z0Pc/U/u3pI/TBvBik27mf74O2zefTBmvg9un0h2Roj5Kz5n+hPvADCub1te+mBLzPx5TTLYfbCUa8b0ZO22ffxj6aaj8rx9wzi27zvEgcNlXPCrf3H/Vwfz3acWH5Xv79eMYmCnFnyyfR97DpZSkJtFx/wcSkrL4rZqf3R2P07u3opzfukF3nMHF3rHWvwZV47uznfG92b2uxu46W/LaJadwbfH9eKOuSsr92+f14SfTx1C+7wmXPLbBXy26yDhkPG9Cb254MROzPjDO7y//guW3DKBrHCIvjd55XjjujF0bhndzfvKB1vIyQrz/LLN/P5f6yrTX/vBGZx+76tHlf29m86MCkpPXzmS4d1axqxnbdQ2uGhAPwB/WrSeH8xewkc/mURWRoiyckdZuePAoTJm/mUp7fKyeeuG8Rw8XMacxZ9x3ewlvHvTmVGDwUDlgCjAb19fw9h+bVmy4QvOHVRIzxufi8r743P6c2KXfHq0bcagW16gV9tmnD+0I/c8/yEAXzu5C39Y8CkAs2eMZPY7G3hy4Xqq2rTrIF8+qSMfbt7DtMeODrYL1+2Mev2HBZ9WHhdgyoNvVj6/67kPolofkZ5dsolnlxx9QmgsrpvYp/J3fyyywiH6FeaxeP0XAAwrKuC+i4ZgZuRkhXnxv0Yz/mevAzB1WOeYn2WF753Zm/vmfwTAo5cP4xu/X1jte19+ShG//9c6ilp5J7E2zbMrt33rjB6VweXNH47l6UXx3xegf4cWZIRDD
OqUz4IbxvGPJZvoX5hHt9a59LvpeQ4cLgOgSabXTVoxAQHgd5cVs/tAKYNve4EvDS6kc8scHnxlNVec2o1pp3Vj/Y79nNy9FQcPl/GPpZsY3Dmf/zqzNwAdWjShbZ73gCPdPhXBJfLEmxH2ulO7tsqNKntW2Lvsb3CnFvzsq0MYd99rldsm9G9Pl1ZHTvI/vWAg1/9lqVfnwjxyssKM7duWm/CC8DdHd68MLr/5+kkM6tSCDi1yAJj9rVN45v3PmHF698qu3dkzRlLmHNkZYcrLvS/8fds3PyqwAIzp2xaAk7u34vtn9WHAzfMq6xP5dwLw1g3jKMjN4g/TRjB/xWamn96Djvk58T/AACm41NETCz7hpr8tA+CXL39M62bZvLjyc974eBsvfHc0AJ/vLmHW66u5c+6RE+/D/1zDg6+sjnvcO+aurPzjjDyZV7itSgvg4y17o05ukftc+Ot/x32f+1/8iPtf/Ki6KkotdI4zieDK0d35zetrqt336StH0r8wj4tnLahM++6ZvaNOZj3bNq98fteXB0UFl3F921KYn8MTCz5hTJ82fHN0d+6b/xE3n9u/8kRUoVvrXNZui+5mrAgm00Z1q7acLZpmMm1UN9bv3M/j/45eGHfVHZN45M21XDqyKCr97EEdKp+X+b0kd10wsDItJzNMftNMZk7si5nRomkm/75+LK2bZVNW7th7sJRvj+9Fi5zMypNik8wwT0wbzoDCFjUOSn/j1CIefXMd7fKa0LNtM1Zt2UtGnFlSZsbbN44jPyeLrIwQ6+46m0+37+fP726gc0vvvX/z9ZPoVJBDs+wMCpp6MwubZnmn0Wb+z4kD2kcd96wTol8X5ufwrTOiu6UywqHKk3EoZPxr5tijvnzG0iw7g79cdQobdx4AoGN+U3Kzwuw75AXxdn6wHdWrNaN6tY57nGRQcElQSWkZ33h0ITdM7sfO/YcqAwvAL19eFZV3wv1HvkFEBhag2sBS1Tuf7Kw5kwTu2WtH0btdc/aWlDLU71Y4rVdrcjLD3DrlBEb+9OXKvFkZRy92MeP0HnxvQu/K4NIiJ5NdBw4D3jfrl1Z+ztY9JZVdFGX+N9Znrx3FgI4talXGD26fSEbIyAiHGNuvLcOKWtIkMxx3wHbigPY89Gr0397/O60bTTLDXDy8S43vFwoZt00ZcFRwyQiHauzHz2uSwba9h7hgaKeo473/4wlR+Sq+4WeG4dYpA2Ie67RetVsb8Edn9+d7E/rQJDPMzef25+ZnlsdsDVSo2n3dpVVTvuu3jiA6UPxwYl+6tsplnB/AWzTN5O0bxtUqKNSkMIHWxdAuBQztUgBATlaY5bdNZPveEvb7ASZVFFwS9OHmPfxr9XamP76Iz3bF7k+W9NQxP4eNXxyofN2/Qx6hkNEyI4thRQUsXLeTb4/rRXFRdH/1uzedycJ13nTk8f3acduUE1j+2W7O7N8uKt/U4Z35zWtruPyUIgDG9YveXu5/sw9Z/OsPKgLRPRcO4sTO+ZXdSwBj+rSNuU/FN/bXfnAGhfk5nNarNZf89i0AXvjuaLIzwke1Wv7nkhPZc7A0bjmOxZ9mnMIbH2+NGYiTJRyyyuuOTuvVhpe/f0Zgx87Nzjjq91bRLQdel2ReTu2umwpaq2bZtErJOx+hAf0Erd66N6ovVupH62bZbNtbUvl6zZ2T6X6DtyDD+z8+k1PuejnuN7WQgd8o4Edn9+PELvnkN82iTfNsJv/iDTbsPMBL3zudri2bMvjWF+jephlzrjm1sj8cvPGwh15bzW1fOoEMv2/++WWb2VtSyoUndeLDzXs46+evc8u5/bn81OgTTsVU6dV3TmbPwcPkN439zXbhuh3c9vcV/GnGyKigUWFfSal/8WBiJ+fycoeDqOndf1q0nn4d8mrVQlq8/gtCZgzsdCTv0g27yMoIkZMZZteBw1HbpHHTgL40KlXPp6GIE2V+0yxW3DbxqOtdKjTNyiC/aSYbdh7ghMIWnNT1S
MvjqjN6csNfl9I+rwkZ4RDLb5sY8xidWzblzvMHRqVF9q33ad+cBdePo11edtVd+f6E3mSEQ4RDFjewAAwrasnfrx0Vd3usGYC1EYoxxnBRcecYOWMb3Dn/qDQFE6mJgkuCjtOGXspV/N5//bWTaNUs9gn6spFdeXHllqiuLQDDG7TesPMAmeHoE+0lI7pwyYiaxxpqo2JJl6quGdsrkOOLpBMFl4QpuqTCf180mJ/N/4jx/dpWdks9evmwqP77W6cM4NYpR1+xbwYlh70L3GJ1N4lI8BrL/VzqTTq2XCZUGVhOB7efN4Dnv3PkQtLRvdvwt6tPrQws4M33j7xOosLfrj416nU4ZDxw8RC+UtyJvu2bH5VfRIKn4FLPbjqnf+DHvHJ092q3f+3kroG/Z9D++6LBUa8rLiQ7FkM65/ODs/pUvg6Z0bNtc+65cHBUcBKR5NF/Wj2r2ucfhJys6rt6qpnZ2mDkVOmuOlxWt3WarjqjB29cNwZIj/qLNDYKLgmqa69YRij4X3nVK4Crqu66idp47IrhzPYXnkxEIstMdIm4sG1Et5Z8aUhhwu8Xycwqx2NM0UWk3im4JKiuYy4ZSWi51BQ86vqOg+JcC9EkM7g/n8iprU9dOZK2zZtgdSx5C/8CtumnVd9tKCLB02yxBD23rG6LL4aT8C26pkPW9Zt7QW4W/TLyaJeXzee7j1zIWFOgLcjNPGpacCJ6tW3GxcO7MG1U0THtX90yKCKSXGq5JOi5pZvrtH8ybi3aooYlJhJ9y4uHH32BXW52Bm/dMD7qZF1dbOndrhnnn9ipmhw1C4WMn14wMGrRRhFJDwouCaprw2PSwOrHRxL10vdOr1z5NJ5EWy5VV7aNp7qlgyYP7FC5oF9NhsS4AlxE0puCSz2rze1+E9Ejzk3DIiVrPLu6dakMo6h1bq26pQb74y1dWzWle+vcGnKLSDpQcElQOs48SrRbrLZVnFhlltrXI66niTzG7y6tcY07AF77wZhAV60VkdTRgP5xIbHoUttZWtNHd+en/p0nK1oo2RkhfvfPtVFHiLfmVoU0XPRARGqglkuCkjAeH4jq7iKYrJZLrFbckWtLEntPEWlcFFwSFMRJM3Jpkup8f0LvmjP5qltWpmN+DsOKCpjsTyao6Y6DdaliRSukpu5DTREWadwUXBJU1wv7AK4e07PGcYi1P53M1WN61vm9ADLDIf404xQ+2b4fgP97+9Nq8zeNuG9IxX3CaysdF/YUkeApuKTImL5tmXF6/HuOm1lgkwcqDhN5J8d4npg2PGrZlupubhWL89susYrev0NeQscSkfSl4JKgoMZcwiFj5qS+wRysBokEqdN6tYl6nfBtsP3siaxnptaOSOOj4JKgoV0Ljnnff80cW6t8p/U6+h4lsVRdSTieioB4LKvYt0iw5TKiu3cL4RN1YaTIcU3BJUFDuxx7cCms5SrB/WrZfVTbRTArWi7xWgiRKxJXNevrJ9XqPSqM7duOxTdPYET3VgntJyKNi65zaWAeu2I4o2LcXTGWqqHlzZlj2XPwMBN//kZU+pGuvNjRpboerJqWlokl3lpn8RpOmrYs0vgouDQwp/duU3MmX9WxFG8g/ujWUV3v51JXVd/+m6fFvyZHRBqHlHSLmdlFZrbczMrNrLjKtuvNbJWZfWhmZ0WkT/TTVpnZzIj0bmb2lp/+lJklNkiQoIY09jy8W8uE8sfrFqvP0LPurrO58WzvmpyvFNdt1WQRabhSNeayDLgAeD0y0cz6A1OBE4CJwK/MLGxmYeBBYBLQH7jYzwtwN3C/c64nsBOYVj9VSK153xnNA1NPrFXeipZLQwqMACcUegtWaraYSOOTkm4x59xKiDlFdgrwpHOuBFhrZquA4f62Vc65Nf5+TwJTzGwlMBa4xM/zGHAL8FBSK9AA9Glf+3ucVPya400rjjVVeXTvNoysxaB8p4IcBnWKvzpybcolIo1PQxtz6QgsiHi9wU8DWF8lfQTQC
vjCOVcaI7/4qmu5/MeILlx2ShET7o9qRPL4FcNj5D7aP39Yu+nVcAzXzIhI2kpacDGzF4FYd8a60Tn3TLLetzpmNh2YDtClS/Xra8WT6AmyW+tc1m7bd0zvVVfNm2Sw52Bp5ZhKrKLfcf7ApJejV9vmnNar9VFrqnXz793Sv1BX7os0NkkLLs658cew20Yg8h67nfw04qRvB/LNLMNvvUTmj1WmWcAsgOLi4kb/NfqvV53K6x9tJRSquM4lNVXOygjxxLQRR6Wf1qsNz3/nNPq0022MRRqbhnYR5Rxgqpllm1k3oBfwNrAQ6OXPDMvCG/Sf47yz5SvAhf7+lwEpaRU1RD3bNuOKiKX4G2I07ds+Ly1vwCYi1UvVVOTzzWwDMBL4h5nNA3DOLQeeBlYAzwNXO+fK/FbJNcA8YCXwtJ8X4IfAf/mD/62Ah+u3NulDQx4iUl9SNVvsr8Bf42y7A7gjRvpcYG6M9DUcmVHW6PRok8vqrcGM2Yzq2Zrnl2+mTfNstu6peYVkEZFj1dC6xaSKk+qwUGZVl4zwJjFojENEkk3BJUEbdh5IKH9dRxOCuDlZ5bE0tCEi9UTBJUH3zvswofzJGOY4e2CHJBxVRCQ4De0iSqnBo98YVutVk+NxDXLemIg0JgouDUhtuq3G9GlbY553bzqTck0NE5EUUrdYAxcZcE7uXrtVkFvmZtG6WfZR6a1yvbS+7XVFvIgkl1ouaeThy4bVaf/+hXnMnjGSwZ3zmTywAx9s3h1QyUREoim4NCA19YrlZtf94you8lo/J3UtCHSas4hIJHWLJVkis3+1DIqINBYKLklwrDFCoUVEGgt1iyWBEX19S0bIGFaU2C2JK4+liCMiaUjBpR6sunNyrfKlMpDMnNSXTgU5qSuAiDQqCi4BysoIcai03Bs7SbPrTGac3qPWeR+7Yjjb92rhSxGJT8ElCY6lAdIsO4MbJvcLvCzJcHrvNqkugog0cAouDcSyW89KdRFERAKj2WIBSs6QiUb0RST9KLgkQeTA/ENfOyl1BRERSREFlwDFmu3Vp71uzCUixx8FlyQI8gZfIiLpSMFFREQCp+ASoGS0WHSFvoikIwWXZFBAEJHjnIJLgNTKEBHxKLgkgWKMiBzvFFxERCRwCi4BqmixqHtMRI53Ci4NnOKUiKQjBRcREQmcgkuALhnRJdVFEBFpELTkfkBW3zmZktIyfvvG2kDvE5ZetxwTEfEouAQkHLKoK/S/O7434/q1TWGJRERSR8ElCczg2+N7pboYIiIpozGXJAiyW0xEJB0puDRwGSGvq+3Unq1SXBIRkdpLSXAxs3vN7AMzW2JmfzWz/Iht15vZKjP70MzOikif6KetMrOZEendzOwtP/0pM8uq7/rUhzP7tUt1EUREai1VLZf5wADn3CDgI+B6ADPrD0wFTgAmAr8ys7CZhYEHgUlAf+BiPy/A3cD9zrmewE5gWr3WJEIyr8w3XfYvImkkJcHFOfeCc67Uf7kA6OQ/nwI86Zwrcc6tBVYBw/3HKufcGufcIeBJYIp5Z9yxwGx//8eA8+qrHlVprEVExNMQxlyuAJ7zn3cE1kds2+CnxUtvBXwREagq0kVEJIWSNhXZzF4E2sfYdKNz7hk/z41AKfDHZJWjSpmmA9MBunQJ/mp69VyJiHiSFlycc+Or225mlwPnAOOcq+xQ2gh0jsjWyU8jTvp2IN/MMvzWS2T+WGWaBcwCKC4uVieWiEiSpGq22ETgOuBLzrn9EZvmAFPNLNvMugG9gLeBhUAvf2ZYFt6g/xw/KL0CXOjvfxnwTH3Voz5oHEdE0lGqrtD/HyAbmO/PglrgnJvhnFtuZk8DK/C6y652zpUBmNk1wDwgDDzinFvuH+uHwJNm9hPgPeDh+q1K/VCXm4ikk5QEF3/acLxtdwB3xEifC8yNkb4GbzZZo6YWjIikk4YwW0xERBoZBZc0oW4xEUknWhU5QNkZISb0b8elI4sCO6bTHV1EJA0pu
ATIzJh1aXFyjp2Uo4qIJIe6xdKE2i8ikk5qDC5mNsDMHjezRf7jMTMbVB+FExGR9FRtcDGzKcBfgVfx1gC7AngN+LO/TeqJusVEJJ3UNOZyG3Cmc25dRNoSM3sZ70r4RnU1vIiIBKOmbrGMKoEFAD8tMxkFkmi6eFJE0lFNLZdSM+vinPs0MtHMuuItzyJ19PBlxazdtq/mjLrQRUTSSE3B5WbgRTO7E3jHTysGZuKt6SV1NE63LxaRRqja4OKc+5uZrQW+B1zrJ68AvuKcW5zswomISHqq8SJKP4hcWg9lERGRRqKmqcitzexmM/tPM2tmZg+Z2TIze8bM4q5sLMHReL6IpKOaZov9L959Vypu2rUW78ZczwK/S27RJJKG80UkndTULdbOOXeDeXf0+sQ5d4+f/oGZXZ3kskkEtWBEJJ3U1HIpA/BvJ7ytyrbypJRIRETSXk0tl+5mNgevV6biOf7rbkktmURRt5iIpJOagkvk+mH/XWVb1deSBLpCX0TSUU3XubwWb5uZPYW3iKXUA12gLyLppC73cxkZWCnSUF4T3WdNRCQe3SxMREQCV+3XbzMbGm8Tx/mqyBoKERGJr6a+nfuq2fZBkAUREZHGo6YB/TH1VRCJR20kEUk/Na0tdl3E84uqbLszWYWSo5mudBGRNFLTgP7UiOfXV9k2MeCyiIhII1FTcLE4z2O9Pm6dO7gw1UUQEWlQagouLs7zWK+PWz+ZMiDVRRARaVBqmi022Mx247VScvzn+K+bJLVkAmj5FxFJTzXNFgvXV0HSTuRJvx46CLX8i4ikE12hf4y6tm5a+VwnfhGRaAouxyikiCIiEpeCSwAUZkREoqUkuJjZ7Wa2xMzeN7MXzKzQTzcze8DMVvnbh0bsc5mZfew/LotIP8nMlvr7PODfkllERFIoVS2Xe51zg5xzQ4BngR/76ZOAXv5jOvAQgJm1BG4GRgDDgZvNrMDf5yHgmxH71fvFncmMZ5otJiLpKCXBxTm3O+JlLkfmXk0BHneeBUC+mXUAzgLmO+d2OOd2AvOBif62POfcAuecAx4HzqufOhx5Xh9NJTXHRCSdpOyOV2Z2B3ApsAuoWCCzI7A+ItsGP6269A0x0kVEJIWS1nIxsxfNbFmMxxQA59yNzrnOwB+Ba5JVjiplmm5mi8xs0datW+t4rNjPRUQkiS0X59z4Wmb9IzAXb0xlI9A5YlsnP20jcEaV9Ff99E4x8scr0yxgFkBxcbFGM0REkiRVs8V6RbycwpEbj80BLvVnjZ0M7HLObQLmARPMrMAfyJ8AzPO37Tazk/1ZYpcCz9RfTTzJXA7faQk3EUlDqRpzucvM+gDlwCfADD99LjAZWAXsB74B4JzbYWa3Awv9fLc553b4z68Cfg/kAM/5j0ZHXW8ikk5SElycc1+Ok+6Aq+NsewR4JEb6IqDelyWOmi1WDyd+TUkWkXSiK/RFRCRwCi7HqL67qdQtJiLpRMElAOGQrtAXEYmk4FJHPzlvAJnh5P8akzkjTUQkaAoudTSwY4tUF0FEpMFRcBERkcApuIiISOAUXEREJHAKLg2cJouJSDpScEkXmiwmImlEwUVERAKn4CIiIoFTcBERkcApuDRwWv5FRNKRgkua0Hi+iKQTBRcREQmcgouIiAROwUVERAKn4NLAOV2jLyJpSMElTZhuRSkiaUTBRUREAqfg0sB1KmgKQKvcrBSXRESk9jJSXQCp3rVjezKgMI8z+rRJdVFERGpNwaWBywyHmHBC+1QXQ0QkIeoWExGRwCm4iIhI4BRcREQkcAouIiISOAUXEREJnIKLiIgETsFFREQCp+AiIiKBU3AREZHAKbiIiEjgUhpczOx7ZubMrLX/2szsATNbZWZLzGxoRN7LzOxj/3FZRPpJZrbU3+cB09r0IiIpl7LgYmadgQnApxHJk4Be/mM68JCftyVwMzACGA7cbGYF/j4PAd+M2G9ifZT/F1NP5KvFnTmhMK8+3k5EJK2ks
uVyP3AX8KN5AAAL6klEQVQdRN1qcQrwuPMsAPLNrANwFjDfObfDObcTmA9M9LflOecWOOcc8DhwXn0UvlvrXO6+cBAZYfUsiohUlZIzo5lNATY65xZX2dQRWB/xeoOfVl36hhjpIiKSQklbct/MXgRirRV/I3ADXpdYvTKz6XjdbXTp0qW+315E5LiRtODinBsfK93MBgLdgMX+2Hsn4F0zGw5sBDpHZO/kp20EzqiS/qqf3ilG/nhlmgXMAiguLnbx8omISN3Ue7eYc26pc66tc67IOVeE15U11Dm3GZgDXOrPGjsZ2OWc2wTMAyaYWYE/kD8BmOdv221mJ/uzxC4FnqnvOomISLSGdifKucBkYBWwH/gGgHNuh5ndDiz0893mnNvhP78K+D2QAzznP0REJIVSHlz81kvFcwdcHSffI8AjMdIXAQOSVT4REUmc5tGKiEjgFFxERCRwCi4iIhI4BRcREQmcgouIiAROwUVERAKn4CIiIoFTcBERkcApuIiISOAUXEREJHAKLiIiEjgFFxERCZyCi4iIBE7BRUREAqfgIiIigVNwERGRwCm4iIhI4BRcREQkcAouIiISOAUXEREJnIKLiIgETsFFREQCp+AiIiKBU3AREZHAKbiIiEjgFFxERCRwCi4iIhI4BRcREQlcRqoLkG4evqyYw2Uu1cUQEWnQFFwSNK5fu1QXQUSkwVO3mIiIBE7BRUREAqfgIiIigVNwERGRwKUkuJjZLWa20cze9x+TI7Zdb2arzOxDMzsrIn2in7bKzGZGpHczs7f89KfMLKu+6yMiItFS2XK53zk3xH/MBTCz/sBU4ARgIvArMwubWRh4EJgE9Acu9vMC3O0fqyewE5hW3xUREZFoDa1bbArwpHOuxDm3FlgFDPcfq5xza5xzh4AngSlmZsBYYLa//2PAeSkot4iIREhlcLnGzJaY2SNmVuCndQTWR+TZ4KfFS28FfOGcK62SHpOZTTezRWa2aOvWrUHVQ0REqkjaRZRm9iLQPsamG4GHgNsB5/+8D7giWWWp4JybBczyy7fVzD45xkO1BrYFVrDUaix1aSz1ANWloWosdalrPbrWJlPSgotzbnxt8pnZb4Fn/Zcbgc4Rmzv5acRJ3w7km1mG33qJzF9T+drUJl+cMi9yzhUf6/4NSWOpS2OpB6guDVVjqUt91SNVs8U6RLw8H1jmP58DTDWzbDPrBvQC3gYWAr38mWFZeIP+c5xzDngFuNDf/zLgmfqog4iIxJeqtcXuMbMheN1i64ArAZxzy83saWAFUApc7ZwrAzCza4B5QBh4xDm33D/WD4EnzewnwHvAw/VZEREROVpKgotz7uvVbLsDuCNG+lxgboz0NXizyerTrHp+v2RqLHVpLPUA1aWhaix1qZd6mNezJCIiEpyGdp2LiIg0AgouCYi3BE1DZmbrzGypv8zOIj+tpZnNN7OP/Z8FfrqZ2QN+/ZaY2dAUl/0RM9tiZssi0hIuu5ld5uf/2Mwua0B1CWwZpHqsR2cze8XMVpjZcjP7tp+edp9LNXVJx8+liZm9bWaL/brc6qd3sxjLY/mTpp7y098ys6Ka6pgw55wetXjgTSRYDXQHsoDFQP9Ul6sW5V4HtK6Sdg8w038+E7jbfz4ZeA4w4GTgrRSXfTQwFFh2rGUHWgJr/J8F/vOCBlKXW4Dvx8jb3//7yga6+X934YbwNwh0AIb6z5sDH/nlTbvPpZq6pOPnYkAz/3km8Jb/+34amOqn/xr4lv/8KuDX/vOpwFPV1fFYyqSWS+3FXIImxWU6VlPwlsqB6CVzpgCPO88CvGuIOsQ6QH1wzr0O7KiSnGjZzwLmO+d2OOd2AvPx1q2rV3HqEk9CyyAlpcBxOOc2Oefe9Z/vAVbirYqRdp9LNXWJpyF/Ls45t9d/mek/HPGXx4r8vGYD48zMiF/HhCm41F68JWgaOge8YGbvmNl0P62dc26T/3wzUHHv5nSoY6Jlb+h1CmIZpJTwu
1JOxPuWnNafS5W6QBp+LuYt8vs+sAUvWK8m/vJYlWX2t+/CW04rsLoouDR+o5xzQ/FWlL7azEZHbnReWzgtpwymc9l9DwE9gCHAJrxlkNKCmTUD/gx8xzm3O3Jbun0uMeqSlp+Lc67MOTcEb6WS4UDfVJZHwaX2qluapsFyzm30f24B/or3R/d5RXeX/3OLnz0d6pho2RtsnZxzn/snhHLgtxzpfmjQdTGzTLyT8R+dc3/xk9Pyc4lVl3T9XCo4577AW7lkJP7yWDHKVVlmf3sLvOW0AquLgkvtxVyCJsVlqpaZ5ZpZ84rnwAS8pXbm4C2VA9FL5swBLvVn+JwM7Iro6mgoEi37PGCCmRX43RsT/LSUs4CWQarnMhveKhgrnXM/i9iUdp9LvLqk6efSxszy/ec5wJl4Y0jxlseK/LwuBF72W5zx6pi4+pzRkO4PvJkvH+H1Zd6Y6vLUorzd8WZ+LAaWV5QZr2/1JeBj4EWgpZ9ueDdlWw0sBYpTXP7/w+uWOIzX9zvtWMqOt+L2Kv/xjQZUlyf8si7x/6k7ROS/0a/Lh8CkhvI3CIzC6/JaArzvPyan4+dSTV3S8XMZhLf81RK8YPhjP707XnBYBfwJyPbTm/ivV/nbu9dUx0QfukJfREQCp24xEREJnIKLiIgETsFFREQCp+AiIiKBU3AREZHAKbhIo2Bmzszui3j9fTO7JaBj/97MLqw5Z53f5yIzW2lmr1RJLzSz2f7zIZGr9AbwnvlmdlWs9xKpCwUXaSxKgAvMrHWqCxIp4uro2pgGfNM5NyYy0Tn3mXOuIrgNwbumIqgy5OOtkBvrvUSOmYKLNBaleLdv/W7VDVVbHma21/95hpm9ZmbPmNkaM7vLzP7Dvy/GUjPrEXGY8Wa2yMw+MrNz/P3DZnavmS30Fzm8MuK4b5jZHGBFjPJc7B9/mZnd7af9GO+ivofN7N4q+Yv8vFnAbcBXzbvPyFf9VRge8cv8nplN8fe53MzmmNnLwEtm1szMXjKzd/33rli19y6gh3+8eyveyz9GEzN71M//npmNiTj2X8zsefPuxXJPxO/j935Zl5rZUZ+FHD8S+VYl0tA9CCypONnV0mCgH95y+GuA3znnhpt346hrge/4+Yrw1pjqAbxiZj2BS/GWMxlmZtnAm2b2gp9/KDDAecuWVzKzQuBu4CRgJ96K1ec5524zs7F49xFZFKugzrlDfhAqds5d4x/vTrylO67wl/9428xejCjDIOfcDr/1cr5zbrffulvgB7+ZfjmH+McrinjLq723dQPNrK9f1t7+tiF4qwiXAB+a2S+BtkBH59wA/1j5NfzupRFTy0UaDeetaPs48J8J7LbQeff1KMFb8qIiOCzFCygVnnbOlTvnPsYLQn3x1sO61Lxlzt/CWwKll5//7aqBxTcMeNU5t9V5S53/Ee9GYsdqAjDTL8OreMt6dPG3zXfOVdxDxoA7zWwJ3vIsHTmyLH48o4A/ADjnPgA+ASqCy0vOuV3OuYN4rbOueL+X7mb2SzObCOyOcUw5TqjlIo3Nz4F3gUcj0krxv0iZWQjvboEVSiKel0e8Lif6/6PqOkkO74R9rXMuasFFMzsD2HdsxU+YAV92zn1YpQwjqpThP4A2wEnOucNmtg4vEB2ryN9bGZDhnNtpZoPxbgQ2A/gK3vphchxSy0UaFf+b+tN4g+MV1uF1QwF8Ce8ufYm6yMxC/jhMd7xF/eYB3zJv2XbMrLd5q09X523gdDNrbWZh4GLgtQTKsQfvlrwV5gHXmpn5ZTgxzn4tgC1+YBmD19KIdbxIb+AFJfzusC549Y7J724LOef+DPwIr1tOjlMKLtIY3QdEzhr7Ld4JfTHePS6OpVXxKV5geA6Y4XcH/Q6vS+hdfxD8N9TQG+C85eZn4i2Fvhh4xzn3THX7VPEK0L9iQB+4HS9YLjGz5f7rWP4IFJvZUryxog/88mzHGytaVnUiAfArIOTv8xRwud99GE9H4FW/i+4Pw
PUJ1EsaGa2KLCIigVPLRUREAqfgIiIigVNwERGRwCm4iIhI4BRcREQkcAouIiISOAUXEREJnIKLiIgE7v8DczkENBjqMxMAAAAASUVORK5CYII=\n", 153 | "text/plain": [ 154 | "
" 155 | ] 156 | }, 157 | "metadata": {}, 158 | "output_type": "display_data" 159 | } 160 | ], 161 | "source": [ 162 | "plt.plot(range(len(train_elbo)), train_elbo)\n", 163 | "plt.xlabel(\"Number of iterations\")\n", 164 | "plt.ylabel(\"ELBO\")\n", 165 | "plt.show()" 166 | ] 167 | } 168 | ], 169 | "metadata": { 170 | "kernelspec": { 171 | "display_name": "probabilistic.ai", 172 | "language": "python", 173 | "name": "probabilistic.ai" 174 | }, 175 | "language_info": { 176 | "codemirror_mode": { 177 | "name": "ipython", 178 | "version": 3 179 | }, 180 | "file_extension": ".py", 181 | "mimetype": "text/x-python", 182 | "name": "python", 183 | "nbconvert_exporter": "python", 184 | "pygments_lexer": "ipython3", 185 | "version": "3.7.0" 186 | } 187 | }, 188 | "nbformat": 4, 189 | "nbformat_minor": 2 190 | } 191 | -------------------------------------------------------------------------------- /Readme.md: -------------------------------------------------------------------------------- 1 | # Day1-Day3 Tutorial of the Nordic Probabilistic AI School 2 | 3 | 4 | Make sure you have installed Python 3.6 (e.g. 
running the command *python -V* on the console) and the following python packages: 5 | - [Numpy](https://www.numpy.org/) 6 | - [Scipy](https://www.scipy.org/) 7 | - [Matplotlib](https://matplotlib.org/) 8 | - [Pandas](https://pandas.pydata.org/) 9 | - [seaborn](https://seaborn.pydata.org/) 10 | - [Pytorch](https://pytorch.org/) 11 | - [TorchVision](https://pypi.org/project/torchvision/) 12 | - [Pyro](http://pyro.ai/) 13 | 14 | 15 | 16 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: probai 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - blas=1.0=mkl 7 | - ca-certificates=2019.1.23=0 8 | - certifi=2019.3.9=py36_0 9 | - cffi=1.12.3=py36hb5b8e2f_0 10 | - cycler=0.10.0=py36hfc81398_0 11 | - freetype=2.9.1=hb4e5f40_0 12 | - intel-openmp=2019.3=199 13 | - jpeg=9b=he5867d9_2 14 | - kiwisolver=1.1.0=py36h0a44026_0 15 | - libcxx=4.0.1=hcfea43d_1 16 | - libcxxabi=4.0.1=hcfea43d_1 17 | - libedit=3.1.20181209=hb402a30_0 18 | - libffi=3.2.1=h475c297_4 19 | - libgfortran=3.0.1=h93005f0_2 20 | - libpng=1.6.37=ha441bb4_0 21 | - libtiff=4.0.10=hcb84e12_2 22 | - matplotlib=3.1.0=py36h54f8f79_0 23 | - mkl=2019.3=199 24 | - mkl_fft=1.0.12=py36h5e564d8_0 25 | - mkl_random=1.0.2=py36h27c97d8_0 26 | - ncurses=6.1=h0a44026_1 27 | - ninja=1.9.0=py36h04f5b5a_0 28 | - numpy=1.16.4=py36hacdab7b_0 29 | - numpy-base=1.16.4=py36h6575580_0 30 | - olefile=0.46=py36_0 31 | - openssl=1.1.1c=h1de35cc_1 32 | - pandas=0.24.2=py36h0a44026_0 33 | - patsy=0.5.1=py36_0 34 | - pillow=6.0.0=py36hb68e598_0 35 | - pip=19.1.1=py36_0 36 | - pycparser=2.19=py36_0 37 | - pyparsing=2.4.0=py_0 38 | - python=3.6.8=haf84260_0 39 | - python-dateutil=2.8.0=py36_0 40 | - pytorch=1.1.0=py3.6_0 41 | - pytz=2019.1=py_0 42 | - readline=7.0=h1de35cc_5 43 | - scipy=1.2.1=py36h1410ff5_0 44 | - seaborn=0.9.0=py36_0 45 | - setuptools=41.0.1=py36_0 46 | - 
six=1.12.0=py36_0 47 | - sqlite=3.28.0=ha441bb4_0 48 | - statsmodels=0.9.0=py36h1d22016_0 49 | - tk=8.6.8=ha441bb4_0 50 | - torchvision=0.3.0=py36_cuNone_1 51 | - tornado=6.0.2=py36h1de35cc_0 52 | - wheel=0.33.4=py36_0 53 | - xz=5.2.4=h1de35cc_4 54 | - zlib=1.2.11=h1de35cc_3 55 | - zstd=1.3.7=h5bba6e5_0 56 | prefix: /anaconda3/envs/probai 57 | --------------------------------------------------------------------------------