├── .gitattributes ├── 010-understanding-DL-models.odp ├── 020-dl-pca-ica.ipynb ├── 030-2d-layers.ipynb ├── 040-bayesian-decision-theory.ipynb ├── 050-parameter-estimation.ipynb ├── PDF ├── 010-understanding-DL-models.pdf ├── 020-dl-pca-ica.pdf ├── 030-2d-layers.pdf ├── 040-bayesian-decision-theory.pdf ├── 050-parameter-estimation.pdf └── 100-Large Scale DL.pdf ├── chapter.png ├── flex.py ├── helpers.py ├── imgclass.py ├── imgimg.py ├── init.py ├── layers.py ├── nb.png ├── note.png ├── roadmodel.py ├── sicon.png ├── summary.png ├── texture.png ├── texture3.png ├── texture4.png └── trainers.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb linguist-language=Python 2 | *.tex linguist-detectable=false 3 | *.html linguist-detectable=false 4 | *.pdf linguist-detectable=false 5 | *.md linguist-detectable=false 6 | html/* linguist-vendored 7 | Figures/* linguist-vendored 8 | pdf/* linguist-vendored 9 | -------------------------------------------------------------------------------- /010-understanding-DL-models.odp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tmbdev-tutorials/dl-2018/7b42601e90a9cad1d0fb8d399c55c178561313ab/010-understanding-DL-models.odp -------------------------------------------------------------------------------- /050-parameter-estimation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Bayesian vs Maximum Likelihood Parameter Estimation\n", 8 | "\n", 9 | "Thomas Breuel\n", 10 | "\n", 11 | "Deep Learning Summer School 2018, Genoa, Italy\n", 12 | "\n", 13 | "_This is a low-quality PDF conversion; to see the original notebook, please go to: [github.com/tmbdev/dl-2018](github.com/tmbdev/dl-2018)_\n", 14 | "\n", 15 | "The notebooks are directly executable." 
16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": 1, 21 | "metadata": { 22 | "slideshow": { 23 | "slide_type": "skip" 24 | } 25 | }, 26 | "outputs": [ 27 | { 28 | "name": "stdout", 29 | "output_type": "stream", 30 | "text": [ 31 | "Populating the interactive namespace from numpy and matplotlib\n" 32 | ] 33 | } 34 | ], 35 | "source": [ 36 | "%pylab inline\n", 37 | "from scipy import stats\n", 38 | "from scipy.stats import norm" 39 | ] 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "metadata": { 44 | "slideshow": { 45 | "slide_type": "slide" 46 | } 47 | }, 48 | "source": [ 49 | "# Tiny Review" 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "metadata": { 55 | "slideshow": { 56 | "slide_type": "slide" 57 | } 58 | }, 59 | "source": [ 60 | "## Notation for Classification Problems\n", 61 | "\n", 62 | "- $\\omega$: class\n", 63 | "- $p(\\omega, x)$: joint distribution defining the classification problem\n", 64 | "- $D=\\{(\\omega_1, x_1),...\\}$: i.i.d. sample from $p$\n", 65 | "- $\\arg\\max_\\omega P(\\omega | x)$: optimal classifier under zero-one loss\n", 66 | "- $\\hat{\\theta}$: a maximum likelihood estimate\n" 67 | ] 68 | }, 69 | { 70 | "cell_type": "markdown", 71 | "metadata": { 72 | "slideshow": { 73 | "slide_type": "slide" 74 | } 75 | }, 76 | "source": [ 77 | "## Bayes Rule \n", 78 | "\n", 79 | "$$ P( \\omega | x ) = \\frac{ p(x | \\omega) P(\\omega) } { p(x) } $$\n", 80 | "\n", 81 | "- $P(\\omega | x)$ posterior distribution at input $x$\n", 82 | "- $p(x | \\omega)$ class conditional density (\"generative model\")\n", 83 | "- $p(x)$ sample distribution or \"evidence\"\n", 84 | "- $P(\\omega)$ prior class probabilities\n" 85 | ] 86 | }, 87 | { 88 | "cell_type": "markdown", 89 | "metadata": { 90 | "slideshow": { 91 | "slide_type": "slide" 92 | } 93 | }, 94 | "source": [ 95 | "# Maximum Likelihood Estimation\n", 96 | "\n" 97 | ] 98 | }, 99 | { 100 | "cell_type": "markdown", 101 | "metadata": { 102 | "slideshow": { 103 | 
"slide_type": "slide" 104 | } 105 | }, 106 | "source": [ 107 | "## Neural Networks and Maximum Likelihood Estimation\n", 108 | "\n", 109 | "A DNN is ultimately a parameterized function $f_\\theta(x)$, with $\\theta\\in\\mathbb{R}^{1000...}$.\n", 110 | "\n", 111 | "We're training the DNN on some random training sample $D$ drawn according to some joint distribution $p(y, x)$, and we try to choose $\\theta$ so that the model matches the dataset best under some loss function." 112 | ] 113 | }, 114 | { 115 | "cell_type": "markdown", 116 | "metadata": { 117 | "slideshow": { 118 | "slide_type": "slide" 119 | } 120 | }, 121 | "source": [ 122 | "## \"One Parameter Generative Network\"\n", 123 | "\n", 124 | "Let's drop all the complexity of classification etc. and just focus on parameter estimation. Let's use a one-parameter model:\n", 125 | "\n", 126 | "$$p_\\theta(x) = (2\\pi)^{-1/2} e^{-\\frac{(x-\\theta)^2}{2}}$$\n", 127 | "\n", 128 | "Our training set $D$ consists of $N$ samples from $p_\\theta(x)$ and we need to find $\\theta$.\n", 129 | "\n", 130 | "You might call this a \"one parameter generative network\"." 131 | ] 132 | }, 133 | { 134 | "cell_type": "markdown", 135 | "metadata": { 136 | "slideshow": { 137 | "slide_type": "slide" 138 | } 139 | }, 140 | "source": [ 141 | "## Maximum Likelihood Estimate\n", 142 | "\n", 143 | "How do we find $\\theta$? You already know a formula for finding $\\theta$:\n", 144 | "\n", 145 | "$$ \\hat{\\theta} = \\frac{1}{N} \\sum_{x\\in D} x $$\n", 146 | "\n", 147 | "This is the _maximum likelihood estimate_ of $\\theta$.\n", 148 | "\n", 149 | "Let's derive this." 150 | ] 151 | }, 152 | { 153 | "cell_type": "markdown", 154 | "metadata": { 155 | "slideshow": { 156 | "slide_type": "slide" 157 | } 158 | }, 159 | "source": [ 160 | "## Univariate Normal Density\n", 161 | "\n", 162 | "For the univariate normal density, in general $\\theta = (\\mu,\\sigma)$, but we assume $\\sigma=1$ is known, so $\\theta=\\mu$. 
We write:\n", 163 | "\n", 164 | "$$p(x|\\theta) = p_\\theta(x) = (2\\pi)^{-1/2} e^{-\\frac{(x-\\theta)^2}{2}}$$" 165 | ] 166 | }, 167 | { 168 | "cell_type": "markdown", 169 | "metadata": { 170 | "slideshow": { 171 | "slide_type": "slide" 172 | } 173 | }, 174 | "source": [ 175 | "## Univariate Normal Density (2)\n", 176 | "\n", 177 | "Now we can write a density for the entire dataset:\n", 178 | "\n", 179 | "$$ p(D|\\theta) = p(x_1,\\ldots,x_n|\\theta) $$\n", 180 | "$$ = p(x_1|\\theta)\\cdot\\ldots\\cdot p(x_n|\\theta)$$\n", 181 | "$$ = \\prod_{i=1}^n p(x_i|\\theta)$$\n", 182 | "\n", 183 | "We call this the _likelihood_ of the data.\n", 184 | "\n", 185 | "Note that likelihoods are parameterized densities viewed as functions of the parameters." 186 | ] 187 | }, 188 | { 189 | "cell_type": "markdown", 190 | "metadata": { 191 | "slideshow": { 192 | "slide_type": "slide" 193 | } 194 | }, 195 | "source": [ 196 | "## Maximum Likelihood Estimate\n", 197 | "\n", 198 | "The _maximum likelihood estimate_ is given by:\n", 199 | "\n", 200 | "$$\\hat{\\theta} = \\arg\\max_\\theta ~ p(D|\\theta)$$\n", 201 | "\n", 202 | "This seems like a reasonable thing to do: choose the parameter that was\n", 203 | "most likely to produce the data set." 204 | ] 205 | }, 206 | { 207 | "cell_type": "markdown", 208 | "metadata": { 209 | "slideshow": { 210 | "slide_type": "slide" 211 | } 212 | }, 213 | "source": [ 214 | "## Maximum Likelihood Instability\n", 215 | "\n", 216 | "However, in general, there is little reason why the maximum of the likelihood function\n", 217 | "should mean anything.\n", 218 | "\n", 219 | "For example, we can easily modify the likelihood function to put a tiny spike in it\n", 220 | "that moves the maximum somewhere arbitrary without actually changing the problem much at all.\n", 221 | "\n", 222 | "For bimodal densities, the Maximum Likelihood solution can jump between the two peaks with tiny changes in the data." 
223 | ] 224 | }, 225 | { 226 | "cell_type": "markdown", 227 | "metadata": { 228 | "slideshow": { 229 | "slide_type": "slide" 230 | } 231 | }, 232 | "source": [ 233 | "## Maximum Likelihood Estimation of the Mean of a Normal Density\n", 234 | "\n", 235 | "Let's derive the ML estimator explicitly.\n", 236 | "\n", 237 | "We're trying to maximize:\n", 238 | "\n", 239 | "$$p(D|\\theta) = \\prod_i (2\\pi)^{-1/2} e^{-\\frac{(x_i-\\theta)^2}{2}}$$\n", 240 | "\n", 241 | "Let's take logarithms:\n", 242 | "\n", 243 | "$$l(\\theta) = \\sum_i \\log p(x_i|\\theta) = \\hbox{const} - {1\\over 2} \\sum_i (x_i-\\theta)\\cdot(x_i-\\theta)$$" 244 | ] 245 | }, 246 | { 247 | "cell_type": "markdown", 248 | "metadata": { 249 | "slideshow": { 250 | "slide_type": "slide" 251 | } 252 | }, 253 | "source": [ 254 | "## Maximizing the Likelihood\n", 255 | "\n", 256 | "If $l$ is sufficiently well behaved (it is), then a necessary condition for a local maximum is that the gradient with respect to the parameter to be estimated is 0. The gradient is:\n", 257 | "\n", 258 | "$$\\nabla_\\theta l(\\theta) = \\nabla_\\theta(\\hbox{const} - {1\\over 2} \\sum_i (x_i-\\theta)\\cdot(x_i-\\theta)) = \\sum_i (x_i-\\theta)$$\n", 259 | "\n", 260 | "$$ \\nabla_\\theta l(\\theta)= \\sum_i (x_i-\\theta) = 0 \\Rightarrow \\sum_i x_i = n \\theta $$\n", 261 | "\n", 262 | "$$ \\hat{\\theta} = \\frac{1}{n}\\sum_i x_i $$\n", 263 | "\n", 264 | "Therefore: the arithmetic mean is the maximum likelihood estimator for the mean of the normal density with known variance." 
265 | ] 266 | }, 267 | { 268 | "cell_type": "markdown", 269 | "metadata": { 270 | "slideshow": { 271 | "slide_type": "slide" 272 | } 273 | }, 274 | "source": [ 275 | "## Connection to Classification\n", 276 | "\n", 277 | "Let's say you have a classification problem in which you know that both classes are distributed according to $p_\\theta(x)$ for two different parameters, $\\theta_1$ and $\\theta_2$.\n", 278 | "\n", 279 | "Recipe for classification:\n", 280 | "- compute $\\theta_1$ and $\\theta_2$ from the training samples\n", 281 | "- apply Bayes formula to derive $P(\\omega | x; \\theta_1, \\theta_2)$ using the parametric densities\n", 282 | "\n", 283 | "Other recipe:\n", 284 | "- write down $P(\\omega | x; \\theta_1, \\theta_2)$\n", 285 | "- use gradient descent to optimize training set error" 286 | ] 287 | }, 288 | { 289 | "cell_type": "markdown", 290 | "metadata": { 291 | "slideshow": { 292 | "slide_type": "slide" 293 | } 294 | }, 295 | "source": [ 296 | "# Bayesian vs Maximum Likelihood\n", 297 | "\n", 298 | "Let's say that the $\\theta$ are the parameters of a DNN classifier $P_\\theta(\\omega|x)$. 
In typical DL, we estimate a single parameter vector $\\hat{\\theta}$ and use that for classification:\n", 299 | "\n", 300 | "$$\\hat{\\theta} = \\arg\\max_\\theta p(D | \\theta)$$\n", 301 | "\n", 302 | "In Bayesian methods, we average our classifications over all possible parameter vectors, weighted by their posterior density:\n", 303 | "\n", 304 | "$$P(\\omega | x) \\propto \\int P_\\theta(\\omega | x) p(\\theta | D) d\\theta$$" 305 | ] 306 | }, 307 | { 308 | "cell_type": "code", 309 | "execution_count": 2, 310 | "metadata": { 311 | "slideshow": { 312 | "slide_type": "slide" 313 | } 314 | }, 315 | "outputs": [], 316 | "source": [ 317 | "# parameter estimate for normal density, unknown mean, sigma=1\n", 318 | "\n", 319 | "def p_x_theta(x, theta, sigma=1):  # normal density p(x|theta)\n", 320 | "    return exp(-(x-theta)**2/(2*sigma**2))/sqrt(2*pi*sigma**2)\n", 321 | "def p_theta_D(theta, D):  # likelihood p(D|theta); proportional to p(theta|D) under a flat prior\n", 322 | "    return prod([p_x_theta(x, theta) for x in D], axis=0)\n", 323 | "\n", 324 | "D = [0.7, 3.3]" 325 | ] 326 | }, 327 | { 328 | "cell_type": "code", 329 | "execution_count": 3, 330 | "metadata": { 331 | "code_folding": [], 332 | "slideshow": { 333 | "slide_type": "slide" 334 | } 335 | }, 336 | "outputs": [ 337 | { 338 | "data": { 339 | "text/plain": [ 340 | "[]" 341 | ] 342 | }, 343 | "execution_count": 3, 344 | "metadata": {}, 345 | "output_type": "execute_result" 346 | }, 347 | { 348 | "data": { 349 | "image/png": "\n", 350 | "text/plain": [ 351 | "" 352 | ] 353 | }, 354 | "metadata": {}, 355 | "output_type": "display_data" 356 | } 357 | ], 358 | "source": [ 359 | "# plot\n", 360 | "xs = linspace(-2, 6, 1000)\n", 361 | "thetas = linspace(-1, 5, 100)\n", 362 | "# p(theta|D), unnormalized\n", 363 | "plot(thetas, p_theta_D(thetas, D), color=\"blue\")\n", 364 | "# arg max p(theta|D)... 
or mean(D)\n", 365 | "plot([mean(D)]*2, [0, 0.03], color=\"red\")" 366 | ] 367 | }, 368 | { 369 | "cell_type": "code", 370 | "execution_count": 4, 371 | "metadata": { 372 | "code_folding": [], 373 | "slideshow": { 374 | "slide_type": "slide" 375 | } 376 | }, 377 | "outputs": [ 378 | { 379 | "data": { 380 | "text/plain": [ 381 | "[]" 382 | ] 383 | }, 384 | "execution_count": 4, 385 | "metadata": {}, 386 | "output_type": "execute_result" 387 | }, 388 | { 389 | "data": { 390 | "image/png": "\n", 391 | "text/plain": [ 392 | "" 393 | ] 394 | }, 395 | "metadata": {}, 396 | "output_type": "display_data" 397 | } 398 | ], 399 | "source": [ 400 | "# p(x|D) using Bayesian vs ML\n", 401 | "\n", 402 | "def bayesian_integral(x, D):\n", 403 | " thetas = linspace(-5, 10, 1000)\n", 404 | " return sum(p_x_theta(x, thetas) * p_theta_D(thetas, D)) / sum(p_theta_D(thetas, D)) \n", 405 | "\n", 406 | "ps = [(x, bayesian_integral(x, D)) for x in xs]\n", 407 | "\n", 408 | "ps = array(ps); plot(ps[:,0], ps[:,1])\n", 409 | "plot(xs, p_x_theta(xs, mean(D)), color=\"red\")" 410 | ] 411 | }, 412 | { 413 | "cell_type": "code", 414 | "execution_count": 5, 415 | "metadata": { 416 | "slideshow": { 417 | "slide_type": "slide" 418 | } 419 | }, 420 | "outputs": [ 421 | { 422 | "data": { 423 | "text/plain": [ 424 | "[]" 425 | ] 426 | }, 427 | "execution_count": 5, 428 | "metadata": {}, 429 | "output_type": "execute_result" 430 | }, 431 | { 432 | "data": { 433 | "image/png": "\n", 434 | "text/plain": [ 435 | "" 436 | ] 437 | }, 438 | "metadata": {}, 439 | "output_type": "display_data" 440 | } 441 | ], 442 | "source": [ 443 | "# overparameterizing the distribution\n", 444 | "plot(ps[:,0], ps[:,1])\n", 445 | "plot(xs, p_x_theta(xs, mean(D)), color=\"black\", alpha=0.3)\n", 446 | "plot(xs, p_x_theta(xs, mean(D), var(D)**.5), color=\"red\")" 447 | ] 448 | }, 449 | { 450 | "cell_type": "markdown", 451 | "metadata": { 452 | "slideshow": { 453 | "slide_type": "slide" 454 | } 455 | }, 456 | "source": [ 457 | 
"## Maximum Likelihood vs Bayesian Estimates\n", 458 | "\n", 459 | "The Bayesian estimate $p(x|D)$ usually is quite different from the ML estimate $p(x|\\hat{\\theta}(D))$. It often doesn't even have the same form.\n", 460 | "\n", 461 | "ML estimates tend to be \"overtrained\": they underestimate variance. Bayesian estimates can never be \"overtrained\".\n", 462 | "\n", 463 | "Accidental overparameterization can compensate for this to some degree (which is probably part of why DNNs work as well as they do)." 464 | ] 465 | }, 466 | { 467 | "cell_type": "markdown", 468 | "metadata": { 469 | "slideshow": { 470 | "slide_type": "slide" 471 | } 472 | }, 473 | "source": [ 474 | "# Maximum A-Posteriori Estimate" 475 | ] 476 | }, 477 | { 478 | "cell_type": "markdown", 479 | "metadata": { 480 | "slideshow": { 481 | "slide_type": "slide" 482 | } 483 | }, 484 | "source": [ 485 | "## Maximum A Posteriori Estimate\n", 486 | "\n", 487 | "By analogy to Bayesian methods, we can also \"multiply in\" a prior:\n", 488 | "\n", 489 | "$$\\hat{\\theta} = \\arg\\max_\\theta ~ p(D|\\theta) p(\\theta)$$\n", 490 | "\n", 491 | "This is called the _maximum a posteriori estimate_ (MAP) for the parameter $\\theta$." 
492 | ] 493 | }, 494 | { 495 | "cell_type": "markdown", 496 | "metadata": { 497 | "slideshow": { 498 | "slide_type": "slide" 499 | } 500 | }, 501 | "source": [ 502 | "## MAP vs Bayesian Methods\n", 503 | "\n", 504 | "Although MAP looks like the derivation may have involved _Bayes rule_, it is\n", 505 | "not a Bayesian method at all.\n", 506 | "Bayesian methods are *not* methods that involve Bayes rule somewhere.\n", 507 | "\n", 508 | "**Bayesian methods are methods that result in decisions that minimize expected loss.**\n", 509 | "\n", 510 | "**Bayesian methods are NOT methods that happen to use Bayes rule somewhere.**" 511 | ] 512 | }, 513 | { 514 | "cell_type": "markdown", 515 | "metadata": { 516 | "slideshow": { 517 | "slide_type": "slide" 518 | } 519 | }, 520 | "source": [ 521 | "# Summary\n", 522 | "\n", 523 | "We know how to do classification / estimation optimally and correctly: Bayesian methods.\n", 524 | "\n", 525 | "Why aren't we using them with DNNs? Because they are computationally prohibitive.\n", 526 | "\n", 527 | "However, several methods attempt to approximate Bayesian solutions:\n", 528 | "- dropout methods\n", 529 | "- ensemble methods\n", 530 | "- some forms of variational methods" 531 | ] 532 | }, 533 | { 534 | "cell_type": "markdown", 535 | "metadata": { 536 | "slideshow": { 537 | "slide_type": "slide" 538 | } 539 | }, 540 | "source": [ 541 | "# Uniform Distribution Example" 542 | ] 543 | }, 544 | { 545 | "cell_type": "markdown", 546 | "metadata": { 547 | "slideshow": { 548 | "slide_type": "slide" 549 | } 550 | }, 551 | "source": [ 552 | "Bayesian vs Maximum Likelihood Example\n", 553 | "==============================\n", 554 | "\n", 555 | "Let's apply this to another simple example and compare the maximum likelihood and Bayesian approaches.\n", 556 | "\n", 557 | "Assume that the samples $x$ come from a uniform density over the interval $[0,\\theta]$:\n", 558 | "\n", 559 | "$$p(x|\\theta) = U(x;0,\\theta) = 1/\\theta \\cdot \\lfloor 
x\\in[0,\\theta]\\rfloor$$\n", 560 | "\n", 561 | "We also assume a prior\n", 562 | "\n", 563 | "$$p(\\theta) = U(\\theta;0,10)$$\n", 564 | "\n", 565 | "That is, $\\theta$ is distributed uniformly over the interval $[0,10]$\n", 566 | "\n", 567 | "Let's assume we see a sequence of training examples $D = \\\\{4,7,2\\\\}$" 568 | ] 569 | }, 570 | { 571 | "cell_type": "code", 572 | "execution_count": 8, 573 | "metadata": { 574 | "slideshow": { 575 | "slide_type": "slide" 576 | } 577 | }, 578 | "outputs": [ 579 | { 580 | "data": { 581 | "text/plain": [ 582 | "[]" 583 | ] 584 | }, 585 | "execution_count": 8, 586 | "metadata": {}, 587 | "output_type": "execute_result" 588 | }, 589 | { 590 | "data": { 591 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAEACAYAAABfxaZOAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAEpBJREFUeJzt3W2MHeV5xvHr2jhWeKsbKuIIvzUN1GlIE5cPjlta5bS0iUMFjlRR2ZGahirIamNKQtWSUlWsP0FbtRURtMGKi0IFIgJKcSRCDaLbFjmAIUCA2Ng0jWMb1wkECCRAjffuhzO2x0frPTNnZnfm2fn/JMt7zpk5exueufZ57pk564gQAKAbxpouAAAwewh9AOgQQh8AOoTQB4AOIfQBoEMIfQDokEKhb3u17Z22d9m+corXP2n7yezPg7Y/WHRfAMDs8bDr9G2PSdol6XxJz0vaLmltROzMbbNK0o6IeMX2aknjEbGqyL4AgNlTZKa/UtLuiNgTEYck3SZpTX6DiHgoIl7JHj4kaVHRfQEAs6dI6C+StDf3eJ+OhfpUPiPp6yPuCwCYQfPqfDPbvy7pEkm/Wuf7AgDqUST090tamnu8OHvuONnJ202SVkfES2X2zfbnQ4AAoKSIcJnti7R3tks6y/Yy2/MlrZW0Jb+B7aWS7pT0exHx32X2HSg+yT9XX3114zVQf/N1UH+af1KufxRDZ/oRcdj2Bklb1f8hsTkidthe3385Nkn6S0mnS/oH25Z0KCJWnmjfkSoFAFRWqKcfEfdKWj7w3I25ry+VdGnRfQEAzeCO3Br0er2mS6iE+ptF/c1Kvf6yht6cNVtsR1tqAYAU2FbMwIlcAMAcQegDQIcQ+gDQIYQ+AHQIoQ8AHULoA0CHEPoA0CGEPgB0CKEPAB1C6ANAhxD6ANAhhD4AdAihDwAdUuvvyO2qa6+Vnnuu6SrQFp/7nPSBDzRdBTA1Qr8GN9wgXXaZdPrpTVeCpt18s/SNbxD6aC9Cvybr1klLljRdBZr2yCNNVwBMj55+DfjdL8hjPKDNCH0A6BBCvyYu9QvLMFcxDtB2hH4NWM4jj/GANiP0a8IMDxLjAO1H6ANAhxD6NWA5jzzGA9qM0K8Jy3pIjAO0H6FfA2Z2yGM8oM0IfQDoEEK/JizrITEO0H6Efg1YziOP8YA2I/QBoEMI/ZqwrIfEOED7Efo1YDmPPMYD2ozQB4AOIfRrwrIeEuMA7Ufo14DlPPIYD2gzQr8mzPAgMQ7QfoQ+UDNm+mgzQr8GHOQAUkHo14RlPSTGAdqP0K8BM33kMR7QZoQ+AHQIoV8TlvWQGAdoP0K/Biznkcd4QJsR+gDQIYR+TVj
WQ+qPA2b6aDNCvwYc5ABSUSj0ba+2vdP2LttXTvH6ctvbbL9h+4qB175r+0nbj9t+pK7C24aZPiTGAdpv3rANbI9Jul7S+ZKel7Td9t0RsTO32YuSLpP0iSneYlJSLyJeqqFeoPVY+aHNisz0V0raHRF7IuKQpNskrclvEBEvRMRjkt6aYn8X/D7J4iAHkIoiYbxI0t7c433Zc0WFpPtsb7d9aZniUsKyHhLjAO03tL1Tg/Mi4oDtM9QP/x0R8eBUG46Pjx/9utfrqdfrzUJ51THTRx7jATNlYmJCExMTld6jSOjvl7Q093hx9lwhEXEg+/sHtu9Sv100NPQBAMcbnAxv3Lix9HsUae9sl3SW7WW250taK2nLNNsfXeDaPtn2qdnXp0j6qKSnS1eZAJb1kBgHaL+hM/2IOGx7g6St6v+Q2BwRO2yv778cm2wvlPSopNMkTdq+XNL7JZ0h6S7bkX2vWyJi60z9Y5rCch55jAe0WaGefkTcK2n5wHM35r4+KGnJFLu+JmlFlQIBAPWZ05dSziaW9ZD4GAa0H6FfAw5yAKkg9GvCTB8S4wDtR+gDNWPlhzYj9GvAQQ4gFYR+TVjWQ2IcoP0IfaBmrPzQZoR+DTjIAaSC0K8Jy3pIjAO0H6FfA2b6yGM8oM0IfQDoEEK/JizrIfExDGg/Qr8GHOQAUkHoA0CHEPo1ob0DifYO2o/QrwEHOYBUEPo1YaYPiXGA9iP0gZqx8kObEfo14CAHkApCvyYs6yExDtB+hH4NmOkjj/GANiP0AaBDCP2asKyHxHX6aD9CvwYc5ABSQegDQIcQ+jWhvQOJ9g7aj9CvAQc5gFQQ+jVhpg+JcYD2I/SBmrHyQ5sR+jXgIAeQCkK/JizrITEO0H6Efg2Y6SOP8YA2I/QBoEMI/ZqwrIfEdfpoP0K/BhzkAFJB6ANAhxD6NaG9A4n2DtqP0K8BBzmAVBD6ANAhhH5NaO9Aor2D9iP0AaBDCP2aMNOHxEwf7UfoA0CHEPoVMasDkBJCH6gR7R20HaFfEQc4gJQQ+gDQIYVC3/Zq2ztt77J95RSvL7e9zfYbtq8os+9cwJU7OIL2DtpuaOjbHpN0vaSPSTpH0jrb7xvY7EVJl0n6mxH2TRoHOICUFJnpr5S0OyL2RMQhSbdJWpPfICJeiIjHJL1Vdl8AwOyZV2CbRZL25h7vUz/Mi6iybzrOfFRfe/ZA01WgBXaG9Ialrz3bdCUoYsmCJVrx7hVNlzGrioT+rBkfHz/6da/XU6/Xa6yWoiKkuPh3dN3DZ+ukt5/UdDlo2O5J6fCY9MY3m64Ew7z65qt68fUX9dQfPtV0KYVNTExoYmKi0nsUCf39kpbmHi/Oniui1L750E/K2GHdtOYmLVmwpOlK0LBrrpF+9Lp0zbqmK8Ewz3z/GV18+8VNl1HK4GR448aNpd+jSE9/u6SzbC+zPV/SWklbptk+fy1L2X0BADNo6Ew/Ig7b3iBpq/o/JDZHxA7b6/svxybbCyU9Kuk0SZO2L5f0/oh4bap9Z+xf04D+1Tshc90mkBTbCnXv8rtCPf2IuFfS8oHnbsx9fVDSlL2NqfadcxyyCH1wnX5KLCs6+D+LO3Ir6uCYAZAwQr8WtHeA1HS1vUPo14H2DjK0d9JBewcj6eCYAZAwQr8WtHeA1NDeweho7yBDeycdtHcwkg6OGQAJI/TrYNo76GOmnw7aO6iA9g6QGto7GEkHxwyAhBH6daC9gwztnXTQ3kEFtHeA1NDewUg6OGYAJIzQrwPtHWRo76SD9g5GcvTz9GnvAEmhvQMAmPMI/TrQ3kGG9k46aO9gJBzgAFJC6NeCnj6OYSKQBnr6GB3tHWQYBumgvYORdHCiACBhhH4taO+gjxO56aC9g9HR3gGSQ3sHI+ngRAFAwgj9WtDeQR/tnXTQ3sFIjn4MA+0dICm0dwDUooOTRySE0K+Dae+gjwVfOmjvYCS0d4A00d4BUIsOTh6REEK/DrR3kGHBlw7
aOxgJ7R0gTbR3AFTGdfpoO0K/DrR3gOTQ3sFIaO8AaaK9A6Ay2jtoO0K/ogjR3gESRHsHo+OjlZHTwRxJEu0dAJXxsx9tR+hXdGRWR3sHSAvtHYzkyKChvYMjOpgjSaK9A6Ayfvaj7Qj9iiaZ1gFJor2DkUSEFEzvcEwHcyRJtHcAVEZ7B21H6Fc0yUwfSBLtHYykP2gIffTxMQzpoL0zDdurbe+0vcv2lSfY5ou2d9t+wvYv5Z7/ru0nbT9u+5G6CgcAlDdv2Aa2xyRdL+l8Sc9L2m777ojYmdvm45LeGxFn2/6wpH+UtCp7eVJSLyJeqr36FqC9g0HM9NNAe+fEVkraHRF7IuKQpNskrRnYZo2kmyUpIh6WtMD2wuw1F/w+SeovDwl99HEiNx20d05skaS9ucf7suem22Z/bpuQdJ/t7bYvHbXQturgRAFAwoa2d2pwXkQcsH2G+uG/IyIenGrD8fHxo1/3ej31er1ZKK8aTuRiEBOBNKTY3pmYmNDExESl9ygS+vslLc09Xpw9N7jNkqm2iYgD2d8/sH2X+u2ioaEPpIj2DmbS4GR448aNpd+jSHtnu6SzbC+zPV/SWklbBrbZIulTkmR7laSXI+Kg7ZNtn5o9f4qkj0p6unSVLcaJXCBNXe3pD53pR8Rh2xskbVX/h8TmiNhhe33/5dgUEffYvsD2c5J+LOmSbPeFku6yHdn3uiUits7MP6UZnMjFoMQ6Bp2VYnunDoV6+hFxr6TlA8/dOPB4wxT7/Y+kFVUKBFJCewdtN2cvpZwtMUl7B0hRV9s7hH5FIX4pOo7hYxjS0dX2DqEPAB1C6Fc0OcmJXByvg5PHJNHewUhC9PRxDCdy09HVtiyhXxGzOgApIfQr4jp9DGIikAZny7Kuncwl9Cvis3eQR3snPV3r6xP6FXVskgAgcYR+RZzIxSAmAuno4rX6hH5ltHdwDO2dtHTxsk1Cv6KOTRIAJI7QryiCj2HAMXwMQ1po76A0evpAumjvoLSOTRJQAGMCbUboV8TNWcjjRG5aaO+gtODXJQLJor2D0jo2SUABjAm0GaFfEe0d5NHeSQvtHZTGZ+8A6aK9A6Cyjk0ckRhCvyJm+sijvZMW2jsobTJC5uod5HQsQ5JGewcAMKcR+hVx9Q7yaO+khfYOSpucJPRxvI5lSNJo7wCohJk+2o7Qr4hP2QTSRXsHpXHJJgZ1LEOSRnsHQCW0d9B2hH5Fk/zmLCBZtHcAVNaxDEFiCP2K+Dx95NHeSQs9fZQ2yYlcDGCmnw7aOwCAOY3Qr4yZPo6hvZMW2jsobXKSnj6O17FuQdJo72AkzO5wBGMBbUfoV8THMADpor2D0rh6B4M61i1IGu0dAJXQ3kHbEfoV8YFrQLpo76A07sjFoI51C5JGewdAJbR30HaEfkV8yiYGdWzimDTaOyiNX4wOpIv2zgnYXm17p+1dtq88wTZftL3b9hO2V5TZF5graO+g7YaGvu0xSddL+pikcySts/2+gW0+Lum9EXG2pPWSvlR039RFhCb3/qTpMiqZmJhouoRK2lZ/2Ylj2+ovK+X6bWvbf21ruoxZVWSmv1LS7ojYExGHJN0mac3ANmsk3SxJEfGwpAW2FxbcN2n90H+96TIqSfmglai/aSnXbxH6U1kkaW/u8b7suSLbFNkXmDNo76Dt5s3Q+7Zm6P/focNa8mefmLH3fzNe5eodHDU2Jm3bJl14YfF9nn1WeuyxmatppqVc/w9/8W268T/u1q2ff25G3n+e52v/3905I+89Kg87c217laTxiFidPf6CpIiIv8pt8yVJ/x4RX80e75T0EUnvGbZv7j26dQodAGoQUe7u0CIz/e2SzrK9TNIBSWslrRvYZoukz0r6avZD4uWIOGj7hQL7jlQ4AKC8oaEfEYdtb5C0Vf1zAJsjYoft9f2XY1NE3GP7AtvPSfqxpEum23fG/jUAgGkNbe8AAOaO1tyRa/uvbe/Ibu660/ZPNV1TESnffGZ
7se0HbD9j+ynbf9x0TWXZHrP9Tdtbmq6lLNsLbN+ejftnbH+46ZrKsP1520/b/pbtW2zPb7qm6djebPug7W/lnnun7a22n7X9b7YXNFnjdE5Qf+ncbE3oq98COiciVkjaLenPG65nqDlw89lbkq6IiHMk/bKkzyZWvyRdLunbTRcxousk3RMRvyDpQ5KSaX3aPlPSZZLOjYgPqt8qXttsVUPdpP6xmvcFSfdHxHJJD6jduTNV/aVzszWhHxH3R8Rk9vAhSYubrKegpG8+i4j/jYgnsq9fUz90krmPwvZiSRdI+nLTtZSVzch+LSJukqSIeCsiftRwWWW9TdIptudJOlnS8w3XM62IeFDSSwNPr5H0lezrr0iaueu7K5qq/lFyszWhP+APJH296SIKmDM3n9n+WUkrJD3cbCWl/L2kP5WS/JjE90h6wfZNWXtqk+2Tmi6qqIh4XtLfSvqepP3qX7F3f7NVjeRdEXFQ6k+CJL2r4XqqKJSbsxr6tu/L+n9H/jyV/X1hbpu/kHQoIm6dzdq6zPapku6QdHk24289278t6WC2UrFadENgQfMknSvphog4V9JP1G81JMH2T6s/S14m6UxJp9r+ZLNV1SLFCUSp3JypO3KnFBG/Nd3rtj+t/nL9N2aloOr2S1qae7w4ey4Z2dL8Dkn/HBF3N11PCedJusj2BZJOknSa7Zsj4lMN11XUPkl7I+LR7PEdklK6EOA3JX0nIn4oSbb/RdKvSEptsnbQ9sLsvqJ3S/p+0wWVVTY3W9Pesb1a/aX6RRHxZtP1FHT0xrXsyoW16t+olpJ/kvTtiLiu6ULKiIirImJpRPyc+v/dH0go8JW1FPba/vnsqfOV1gnp70laZfsdtq1+/SmciB5cFW6R9Ons69+X1PaJz3H1j5KbrblO3/ZuSfMlvZg99VBE/FGDJRWS/Ue/TsduPru24ZIKs32epP+U9JT6y9qQdFVE3NtoYSXZ/oikP4mIi5qupQzbH1L/JPTbJX1H0iUR8UqzVRVn+2r1f+AekvS4pM9kFzS0ku1bJfUk/Yykg5KulvSvkm6XtETSHkm/GxEvN1XjdE5Q/1UqmZutCX0AwMxrTXsHADDzCH0A6BBCHwA6hNAHgA4h9AGgQwh9AOgQQh8AOoTQB4AO+X/yyD6I7NWHhAAAAABJRU5ErkJggg==\n", 592 | "text/plain": [ 593 | "" 594 | ] 595 | }, 596 | "metadata": {}, 597 | "output_type": "display_data" 598 | } 599 | ], 600 | "source": [ 601 | "xs = linspace(-1,11,1000)\n", 602 | "mus = xs\n", 603 | "C = (amax(xs)-amin(xs))/len(xs)\n", 604 | "def pxt(x,mu): return (xs>=0)*(x<=mu)*1.0/maximum(mu,1e-6)\n", 605 | "plot(xs,pxt(xs,5.5))\n", 606 | "pmu = (xs>=0)*(xs<=10)*0.1\n", 607 | "plot(xs,pmu)" 608 | ] 609 | }, 610 | { 611 | "cell_type": "markdown", 612 | "metadata": { 613 | "slideshow": { 614 | "slide_type": "slide" 615 | } 616 | }, 617 | "source": [ 618 | "Now assume we draw the sample $x_1=4$. 
How should we update our estimate?\n", 619 | "\n", 620 | "$$p(\\mu|x) = \\frac{p(x|\\mu) p(\\mu)}{p(x)}$$" 621 | ] 622 | }, 623 | { 624 | "cell_type": "code", 625 | "execution_count": 9, 626 | "metadata": { 627 | "slideshow": { 628 | "slide_type": "slide" 629 | } 630 | }, 631 | "outputs": [ 632 | { 633 | "data": { 634 | "text/plain": [ 635 | "[]" 636 | ] 637 | }, 638 | "execution_count": 9, 639 | "metadata": {}, 640 | "output_type": "execute_result" 641 | }, 642 | { 643 | "data": { 644 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAEACAYAAABfxaZOAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAGwFJREFUeJzt3XuUVPWZ7vHv0yBeSLwlwpmAeE0wMhEkM+gRNaUxSnQQczGCMTOaJToKxkRzJF6izUQTxYm3GCcyMiY4ZqHG4GV5Cd56lHEUVLymEaJREdGjJhqN8UjwPX/sQsu26a6qru5dv97PZ61eVu3au+pt7P3Ur9767b0VEZiZWTG05F2AmZn1HYe+mVmBOPTNzArEoW9mViAOfTOzAnHom5kVSFWhL2mCpKWSlkma0cnjB0p6RNISSYskja92WzMz6zvqbp6+pBZgGfB54AVgMTA5IpZWrLNRRLxVvv0Z4OqI+HQ125qZWd+pZqQ/DlgeEc9GxGpgHjCpcoW1gV/2EeDdarc1M7O+U03oDwNWVNx/vrzsAyQdJKkduBH4Zi3bmplZ32jYF7kRcV1EfBo4CDizUc9rZmaNM7CKdVYCIyruDy8v61RELJS0raTNa9lWkk8CZGZWo4hQLetXM9JfDGwvaStJg4DJwA2VK0jaruL2WGBQRPyhmm07FJ/kzxlnnJF7Da4//zpcf5o/Kddfj25H+hGxRtJ0YAHZm8SciGiXdHT2cMwGviLpH4F3gL8AX+tq27oqNTOzHqumvUNE3AqM7LDs0orbs4BZ1W5rZmb58BG5DVAqlfIuoUdcf75cf75Sr79W3R6c1VckRbPUYmaWAklEL3yRa2Zm/YRD38ysQBz6ZmYF4tA3MysQh76ZWYE49M3MCsShb2ZWIA59M7MCceibmRWIQ9/MrEAc+mZmBeLQNzMrEIe+mVmBOPQtaS+9BOeem3cVZulw6FvSnnoKTjoJHn8870rM0uDQt37hrLPyrsAsDQ59S1oEfOYzcMcd8OSTeVdj1vwc+pa8j34UjjsOfvjDvCsxa35VXRjdrNkddxxsvz08/TRsu23e1Zg1L4/0LWkRIMGmm8Kxx8KPfpR3RWbNzaFv/ca3vw2//jU8+2zelZg1L4e+9Rubbw7HHAP/8i95V2LWvBz6lrS17Z21vvtduPFGz+QxWxeHvvUrm24KJ54I3/9+3pWYNSeHviUt4sPLpk+HhQvhoYf6vh6zZufQt+RVtncABg+GU0+F007Lpx6zZlZV6EuaIGmppGWSZnTy+KGSHin/LJS0U8Vjz5SXL5G0qJHFm63L1KnQ3g733JN3JWbNpdvQl9QCXAzsB4wCpkjaocNqTwN7RsRo4ExgdsVj7wKliNg5IsY1pmyzTMcvctcaNAhmzoQZMzpvAZkVVTUj/XHA8oh4NiJWA/OASZUrRMR9EfF6+e59wLCKh1Xl65g11GGHwdtvwzXX5F2JWfOoJoyHASsq7j/PB0O9oyOBWyruB3CbpMWSptZeoll9Wlrgxz/ORvtvv513NWbNoaEjcEl7AUcAlX3/8RExFtgfmCZp90a+p
hXbuto7a+21F4weDRdd1Hc1mTWzak64thIYUXF/eHnZB5S/vJ0NTIiIP65dHhGryv99WdJ8snbRws5eqLW19b3bpVKJUqlURXlmXZs1C3bbDQ4/HIYMybsas/q1tbXR1tbWo+dQdPMtl6QBwJPA54FVwCJgSkS0V6wzArgD+EZE3FexfCOgJSLelDQYWADMjIgFnbxOdFeLWUd33ZV9YdvdfvDtb8M778All/RJWWZ9QhIR0cVn3Q/rdqQfEWskTScL7BZgTkS0Szo6ezhmA98HNgcukSRgdXmmzlBgvqQov9aVnQW+WU901d5Z6/TTYeRImDYNRo3q/ZrMmlW3I/2+4pG+1eOuu7ITrN11V/frXnQRXH893H57dW8UZs2unpG+p1Ja0moZJxx7LLzyClx9de/VY9bsHPqWvGpH7QMHZj39E0+EN97o3ZrMmpVD3wpl/Hj4wheyL3/Nisihb0nrbp5+Z845B+bOhccf752azJqZQ98KZ8gQaG3NZvJ47oAVjUPfCunoo+HPf4af/zzvSsz6lkPfklZPewdgwAC47LLsvDwvvtj4usyalUPfCmvMGDjyyOxKW2ZF4dC3pPW0J3/66fDYYzB/fmPqMWt2Dn1LXk+Ort1gg6zNM306vPZa42oya1YOfSu8PfaASZPgu9/NuxKz3ufQt6Q1asrl2Wdn5+S5+ebGPJ9Zs3LoW/IacfK0jTfOpm9OnZqdn8esv3Lom5WVSjB5MhxzjA/asv7LoW9Jq3ee/rqcdRa0t8OVVzbuOc2aiUPfrMIGG8AVV8AJJ8CKFXlXY9Z4Dn2zDnbeObu84uGHw7vv5l2NWWM59C1pjW7vrHXSSbB6dTarx6w/ceibdWLgQPjlL7NLLC5cmHc1Zo3j0Lek9eYsm+HDYc4cOPRQePXV3nsds77k0Lfk9eZFzg84AA4+GI44wtM4rX9w6Jt140c/yk6/fOGFeVdi1nMD8y7ArCf6YvQ9aBDMmwe77grjxsFuu/X+a5r1Fo/0LXm92d5Za9tt4fLL4Wtfg1Wrev/1zHqLQ9+sSgccAEcdlfX433kn72rM6uPQt6T11jz9dTntNPjYx7Ijds1S5NA3q0FLC8ydCwsWwC9+kXc1ZrXzF7lmNdpkk+zyiqUS7Lgj/P3f512RWfU80rek9XV7Z61Ro7LLLB50EDz3XN+/vlm9qgp9SRMkLZW0TNKMTh4/VNIj5Z+FknaqdluzVE2aBN/5DkycCG+8kXc1ZtXpNvQltQAXA/sBo4ApknbosNrTwJ4RMRo4E5hdw7Zmdcv7KNkTT4RddoEpU2DNmnxrMatGNSP9ccDyiHg2IlYD84BJlStExH0R8Xr57n3AsGq3NeupPNo7la/905/C22/7wuqWhmpCfxhQeTmJ53k/1DtzJHBLnduaJWe99eCaa+CWW7I3ALNm1tDZO5L2Ao4Adq9n+9bW1vdul0olSqVSQ+qy/ivv9s5am20GN98Me+wBQ4fCV7+ad0XWH7W1tdHW1taj56gm9FcCIyruDy8v+4Dyl7ezgQkR8cdatl2rMvTNqpVne6fSttvCTTfBvvtmB3DttVfeFVl/03EwPHPmzJqfo5r2zmJge0lbSRoETAZuqFxB0gjgWuAbEfFULdua9SdjxsBVV8Ehh8CSJXlXY/Zh3YZ+RKwBpgMLgCeAeRHRLuloSUeVV/s+sDlwiaQlkhZ1tW0v/B5WUHnN0+/KXnvBJZfAP/wDPPVU9+ub9aWqevoRcSswssOySytuTwWmVrutWX/31a/Cyy/DfvvBPffA3/xN3hWZZXwaBrNecswx2WUW99kH2tpgiy3yrsjMoW+Ja8b2TqXTTsvm8O+zD9x1F2y+ed4VWdH53DtmvewHP8hm9Oy7L7z2Wt7VWNE59C1pzTJPvysSzJoF48fDF7/o8/RYvhz6lrxmbu+sJcEFF8Do0Vnw/+lPeVdkReXQN+sjUjaVc6edsh7/H/6Qd0VWRA59S1oK7Z1KLS3Z+Xn22AP23
jub1mnWlxz6lrwU2juVJPjXf83Ow/+5z8ELL+RdkRWJp2ya5UDKZvVsuCHsuSfccQdstVXeVVkROPQtaam1dzo65RQYPDhr99xyS3YZRrPe5NC35KXW3uno+OPh4x/PevzXXgu713VicrPquKdvSUt9pL/W178OV1wBX/4yXHdd3tVYf+aRviUv9ZH+Wvvum7V4Jk6EF1+Ef/7nvCuy/sihb9ZEPvtZuPtumDABVqzIvuxt8edxayD/OVnS+kt7p9L228O992Zn5jzkEHjrrbwrsv7EoW/J6y/tnUpDhmTTODfYwHP5rbEc+mZNaoMNYO5c+NKXYJdd4KGH8q7I+gOHviWtP7Z3KknZXP4LLsiuwnXttXlXZKnzF7mWvP7Y3unoK1+BrbfOpnQuXgxnngkDvfdaHTzSN0vEZz8LDzyQ/UyY4JO1WX0c+pa0/t7e6WiLLeDWW+Hv/i77Wbw474osNQ59S14R2juVBg6Es8+G88+H/feHyy7LuyJLiUPfkla0kX6lL38Z7rknC/9vfMOXYbTqOPQteUUb6VfaYQdYtAjWXx/GjvW0TuueQ98scYMHZy2eH/wgm9Z54YXF/gRkXXPoW9Icbu+bPBnuuw+uvBIOPBBeeSXviqwZOfQteUVu73S03XawcGHW9hk9Gm66Ke+KrNk49M36mUGD4NxzsxH/tGkwdaq/5LX3VRX6kiZIWippmaQZnTw+UtK9kt6WdEKHx56R9IikJZIWNapwM3B7pyulEjz6KLz7Luy0E/zXf+VdkTWDbkNfUgtwMbAfMAqYImmHDqu9ChwHnNvJU7wLlCJi54gY18N6zT7E7Z1123hjmDMHLroIpkyBE0+Ev/wl76osT9WM9McByyPi2YhYDcwDJlWuEBGvRMSDwF872V5Vvo6Z9ZKJE7NR//PPZ71+j/qLq5owHgasqLj/fHlZtQK4TdJiSVNrKc6sO27vVO/jH4errsr6/YcdBkcdBa+9lndV1tf64jx94yNilaQtyMK/PSIWdrZia2vre7dLpRKlUqkPyrPUub1Tm0mTsn7/ySfDqFHwk59kR/da82tra6Otra1Hz6HoZqgkaVegNSImlO9/D4iIOKeTdc8A3oiI89bxXOt8XFJ0V4tZR//5n9nFxK+8Mu9K0rRwYTa7Z4cdsvAfPjzviqwWkoiImoY91bR3FgPbS9pK0iBgMnBDV3VUFLSRpI+Ubw8G9gUer6VAs+54pF+/3XeHhx/OZveMGQPnnAPvvJN3Vdabug39iFgDTAcWAE8A8yKiXdLRko4CkDRU0grgO8Cpkp4rh/1QYKGkJcB9wI0RsaC3fhkzq93668PMmdnRvPfck70B3HZb3lVZb+m2vdNX3N6xelxxBfzmN1mbxxrjxhvh+OOzE7iddx6MGJF3RbYuvdXeMWtqbu801sSJ8MQT2Yh/7Njs0oxvvZV3VdYoDn0z+5ANN4TTT8+uzPXoozByJMydmx3da2lz6FvS3BHsXdtsA1dfnc3v/9nPsuv03nln3lVZTzj0LXlu7/S+3XaD//5vOOWUbIrnxInQ3p53VVYPh76ZVUWCgw+G3/4W9toL9twTjjwSnnsu78qsFg59S5rbO31v/fXhhBNg2TIYOhR23hm+9S148cW8K7NqOPQteW7v5GOzzeCss7I2z4ABsOOOMGMGvPpq3pVZVxz6ljSP9PM3ZAicf342y+f11+FTn4LWVp/MrVk59C15Huk3h+HDsxk+ixbBM89kl2489VR4+eW8K7NKDn0za6jttoOf/xweeCBr9YwcmX0H8MILeVdm4NC3xLm907y22SYb+T/2WPb/6W//Fo49NvsUYPlx6Fvy3N5pbsOGZT3/pUth002zA7wOOwyWLMm7smJy6JtZnxgyBH74Q3jqqey8PhMnwt57w003+fQOfcmhb0lzeyc9m24KJ50ETz8N3/wmnHZa1vr593+Ht9/Ou7r+z6FvyXN7J02DBmVtnocegosvhuuug623zqZ7rlqVd
3X9l0PfzHIlvd/mufPO7MjeHXeEQw6Bu+/2p7lGc+hb0hwI/cuOO2Yzfp55BsaPh6OOgtGj4dJL4c03866uf3DoW/Lc3ul/NtkkO59Pe3t29a5bb4Wttnp/mdXPoW9J80i/f5Ngn31g/vzsAu4bb5y1gsaPh8sv9+i/Hg59S55H+sWw5ZbZpRufey47sdv8+dmyo46C++/3AKBaDn0zS8p668GBB8INN2TX8t1mG/j617O5/xdcAK+8kneFzc2hb0nz6K7YPvEJOPnk7Nz+P/kJPPhgdu6fgw6Ca6/1vP/OOPQteW7vWEsLlEpwxRVZ+2fSJLjkkuxNYerUbOqnj/rNOPTNrF/ZZBM44gi44w545BH45Cdh2jTYdtvsVM9Fn/3j0Lekub1jXdlyy+yUD489ln0H8M472WygsWNh1ixYuTLvCvueQ9+S5/aOVWOnneDcc7P2z3nnZVf6+tKX8q6q7w3MuwCznvBI32o1YEDW/99iCzj44Lyr6Xse6ZtZIUnFHDRUFfqSJkhaKmmZpBmdPD5S0r2S3pZ0Qi3bmvWU2ztWD4f+OkhqAS4G9gNGAVMk7dBhtVeB44Bz69jWrG5F3GmtcYr491PNSH8csDwino2I1cA8YFLlChHxSkQ8CPy11m3NesojfauHR/rrNgxYUXH/+fKyavRkWzOzXlPUwUJTzd5pbW1973apVKJUKuVWi6WhiCM1a4wUR/ptbW20tbX16DmqCf2VwIiK+8PLy6pR07aVoW9WraKO2KxnUgz9joPhmTNn1vwc1bR3FgPbS9pK0iBgMnBDF+tX7oK1bmtm1idSDP1G6HakHxFrJE0HFpC9ScyJiHZJR2cPx2xJQ4EHgI8C70o6HtgxIt7sbNte+22scIq401pjOPS7EBG3AiM7LLu04vZLwJbVbmvWSG7vWD2KGvo+IteSVsSd1hrDoW9mViAOfbNEub1j9XLomyWmiDutNYZH+maJ8kjf6uHQNzMrkKIOFhz6lrQijtSsMTzSN0tUUUds1jMOfTOzAnHomyWoiDutNYZD3yxRbu9YPRz6Zgkq4k5rjeHQNzMrGIe+WYLc3rF6eKRvlqAi7rTWGA59s0R5pG/1KOrfjUPfzArJI32zBBVxp7XGcOibJaqoH9OtZxz6ZmYF4tA3S1ARd1prDIe+WaLc3rF6OPTNElTEndYap4h/Pw59Myskj/TNEuX2jtXDoW+WoCLutNYYRR0sOPQteUXdea1nPNI3MysQh34XJE2QtFTSMkkz1rHORZKWS3pY0s4Vy5+R9IikJZIWNapwMyjmTmuNUdTQH9jdCpJagIuBzwMvAIslXR8RSyvW+SKwXUR8UtIuwL8Bu5YffhcoRcQfG169GW7vWH2KGvrVjPTHAcsj4tmIWA3MAyZ1WGcSMBcgIu4HNpE0tPyYqnwds5oVcae1xnDor9swYEXF/efLy7paZ2XFOgHcJmmxpKn1Fmpm1khFDf1u2zsNMD4iVknagiz82yNiYR+8rhWE2ztWL4d+51YCIyruDy8v67jOlp2tExGryv99WdJ8snZRp6Hf2tr63u1SqUSpVKqiPCuyIu601hgpjvTb2tpoa2vr0XMouvmtJQ0AniT7IncVsAiYEhHtFevsD0yLiAMk7QpcEBG7StoIaImINyUNBhYAMyNiQSevE93VYtbROefAq6/CrFl5V2KpWbMGBg5ML/grSSIiavqs2+1IPyLWSJpOFtgtwJyIaJd0dPZwzI6ImyXtL+l3wJ+BI8qbDwXmS4rya13ZWeCb9YTbO1aPov7dVNXTj4hbgZEdll3a4f70Trb7PTCmJwWadSXlUZrlq6ih76mUlryi7rzWM2v/boo2cHDom1mhOfTNElK0HdYaK8UZPD3l0Lfkub1j9XLomyWmaDusNZZD38ysYBz6Zolxe8fq5ZG+WWKKtsNaYzn0zcwKpIifEh36lrwi7rjWGB7pmyWmaDusNZZD3yxBHulbvRz6ZmYF4tA3S0zRdlhrL
Ie+WYLc3rF6OfTNElO0HdYar2h/Qw59Myssj/TNEuT2jtXLoW+WmKLtsNZYDn0zswIp4qdEh74lr4g7rjWGR/pmiSnaDmuN5dA3S5BH+lYvh76ZWYE49M0SU7Qd1hrLoW+WILd3rCcc+mYJKdoOa43lkb6ZWYE49NdB0gRJSyUtkzRjHetcJGm5pIcljallW7OecHvH6uXQ74SkFuBiYD9gFDBF0g4d1vkisF1EfBI4GvhZtdv2B21tbXmX0CMp1x8Bv/99W95l9EjK//6Qdv0S3HtvW85V9K1qRvrjgOUR8WxErAbmAZM6rDMJmAsQEfcDm0gaWuW2yUv5jx7Sr/+ZZ9ryLqFHUv/3T7l+Cf7nf9ryLqNPVRP6w4AVFfefLy+rZp1qtjUzy0UR2zsDe+l5m6bLumYNHHRQ777Gk0/Cgw/27mv0ppTrX7oUttwy7yosVQMGwPXXw+9+1zvPP2gQXHtt7zx3vRTdvM1J2hVojYgJ5fvfAyIizqlY52fAXRFxVfn+UuBzwDbdbVvxHAV7vzUz67mIqGmQXc1IfzGwvaStgFXAZGBKh3VuAKYBV5XfJF6LiJckvVLFtnUVbmZmtes29CNijaTpwAKy7wDmRES7pKOzh2N2RNwsaX9JvwP+DBzR1ba99tuYmVmXum3vmJlZ/9E0R+RKmiWpvXxw17WSNs67pmqkfPCZpOGS7pT0hKTHJH0r75pqJalF0kOSbsi7llpJ2kTSNeW/+yck7ZJ3TbWQ9B1Jj0t6VNKVkgblXVNXJM2R9JKkRyuWbSZpgaQnJf1G0iZ51tiVddRfc242TeiTtYBGRcQYYDlwcs71dKsfHHz2V+CEiBgF/G9gWmL1AxwP/DbvIup0IXBzRHwaGA0k0/qU9AngOGBsROxE1iqenG9V3bqcbF+t9D3g9ogYCdxJc+dOZ/XXnJtNE/oRcXtEvFu+ex8wPM96qpT0wWcR8WJEPFy+/SZZ6CRzHIWk4cD+wGV511Kr8ohsj4i4HCAi/hoRf8q5rFoNAAZLGghsBLyQcz1dioiFwB87LJ4E/KJ8+xdAL0/wrl9n9deTm00T+h18E7gl7yKq0G8OPpO0NTAGuD/fSmpyPvB/gBS/mNoGeEXS5eX21GxJG+ZdVLUi4gXgx8BzwEqyGXu351tVXYZExEuQDYKAITnX0xNV5Wafhr6k28r9v7U/j5X/O7FinVOB1RHxy76srcgkfQT4FXB8ecTf9CQdALxU/qQimuiAwCoNBMYCP42IscBbZK2GJEjalGyUvBXwCeAjkg7Nt6qGSHEAUVNu9tYRuZ2KiC909bikw8k+ru/dJwX13EpgRMX94eVlySh/NP8VcEVEXJ93PTUYDxwoaX9gQ+CjkuZGxD/mXFe1ngdWRMQD5fu/AlKaCLAP8HRE/AFA0q+B3YDUBmsvSRpaPq7ofwH/N++CalVrbjZNe0fSBLKP6gdGxP/Lu54qvXfgWnnmwmSyA9VS8h/AbyPiwrwLqUVEnBIRIyJiW7J/9zsTCnzKLYUVkj5VXvR50vpC+jlgV0kbSBJZ/Sl8Ed3xU+ENwOHl2/8ENPvA5wP115ObTTNPX9JyYBDwannRfRFxbI4lVaX8j34h7x98dnbOJVVN0njgbuAxso+1AZwSEbfmWliNJH0OODEiDsy7llpIGk32JfR6wNPAERHxer5VVU/SGWRvuKuBJcCR5QkNTUnSL4ES8DHgJeAM4DrgGmBL4FngaxHxWl41dmUd9Z9CjbnZNKFvZma9r2naO2Zm1vsc+mZmBeLQNzMrEIe+mVmBOPTNzArEoW9mViAOfTOzAnHom5kVyP8H3slfhZZAVgkAAAAASUVORK5CYII=\n", 645 | "text/plain": [ 646 | "" 647 | ] 648 | }, 649 | "metadata": {}, 650 | "output_type": "display_data" 651 | } 652 | ], 653 | "source": [ 
654 | "pmu1 = pxt(4,mus)*pmu\n", 655 | "pmu1 /= C*sum(pmu1)\n", 656 | "plot(xs,pmu1)" 657 | ] 658 | }, 659 | { 660 | "cell_type": "markdown", 661 | "metadata": { 662 | "slideshow": { 663 | "slide_type": "slide" 664 | } 665 | }, 666 | "source": [ 667 | "The maximum likelihood estimate is clearly at $\\mu=4$. If we now plug this\n", 668 | "into $p(x|\\mu)$ we get..." 669 | ] 670 | }, 671 | { 672 | "cell_type": "code", 673 | "execution_count": 10, 674 | "metadata": { 675 | "slideshow": { 676 | "slide_type": "slide" 677 | } 678 | }, 679 | "outputs": [ 680 | { 681 | "data": { 682 | "text/plain": [ 683 | "[]" 684 | ] 685 | }, 686 | "execution_count": 10, 687 | "metadata": {}, 688 | "output_type": "execute_result" 689 | }, 690 | { 691 | "data": { 692 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAEACAYAAABI5zaHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAD+pJREFUeJzt3H+s3Xddx/Hnq5QJ8mMKhBlaNt34MV1ky6JzisqRoSszUOIf2i0BmYE0yJAo0QHG7JKQACZGgaFY7SYYyQiDaDX8KAgnBGU/+DE2WLsWB6XtRgmMYcBgSvf2j3s27y7tvd/v6bn99nz6fCQ3u9/v+fR7XrntXvdz3+d8b6oKSVJb1g0dQJI0e5a7JDXIcpekBlnuktQgy12SGmS5S1KDVi33JNuTHEpy+wpr3pZkb5Lbklww24iSpL667NyvBy491oNJng+cU1VPB7YC75xRNknSlFYt96r6FPDtFZZsBt49WXszcHqSM2YTT5I0jVnM3DcA+5ccH5yckyQNxBdUJalB62dwjYPAU5ccb5yc+yFJ/EU2kjSFqkqf9V137pl8HM0O4CUASS4G7q+qQysEnNuPa665ZvAM5h8+x6mYf56zt5B/Gqvu3JO8BxgBT0zyNeAa4LTFnq5tVfXBJJcl+TLwPeDKqZJIkmZm1XKvqis6rLlqNnEkSbPgC6o9jEajoSMcF/MPa57zz3N2mP/808i085ypniypE/l8ktSCJNQavaAqSZojlrskNchyl6QGWe6S1CDLXZIaZLlLUoMsd0lqkOUuSQ2y3CWpQZa7JDXIcpekBlnuktQgy12SGmS5S1KDLHdJapDlLkkNstwlqUGWuyQ1yHKXpAZZ7pLUIMtdkhpkuUtSgyx3SWqQ5S5JDbLcJalBlrskNchyl6QGWe6S1CDLXZIaZLlLUoMsd0lqkOUuSQ2y3CWpQZa7JDWoU7kn2ZRkd5I9Sa4+yuOPT7IjyW1J7kjy0pknlSR1lqpaeUGyDtgDXALcA9wKbKmq3UvWvA54fFW9LsmTgLuAM6rqB8uuVas9nyTp4ZJQVenzZ7rs3C8C9lbVvqo6DNwAbF62poDHTT5/HPCt5cUuSTpxupT7BmD/kuMDk3NLXQv8TJJ7gC8Ar55NPEnSNNbP6DqXAp+vqucmOQf4aJJnVdV3ly9cWFh46PPRaMRoNJpRBElqw3g8ZjweH9c1uszcLwYWqmrT5Pi1QFXVW5as+TfgTVX1H5PjfweurqrPLLuWM3dJ6mmtZu63Ak9LclaS04AtwI5la/YBz5uEOAN4BnB3nyCSpNlZdS
xTVUeSXAXsZPGbwfaq2pVk6+LDtQ14I/APSW6f/LE/qar71iy1JGlFq45lZvpkjmUkqbe1GstIkuaM5S5JDbLcJalBlrskNchyl6QGWe6S1CDLXZIaZLlLUoMsd0lqkOUuSQ2y3CWpQZa7JDXIcpekBlnuktQgy12SGmS5S1KDLHdJapDlLkkNstwlqUGWuyQ1yHKXpAZZ7pLUIMtdkhpkuUtSgyx3SWqQ5S5JDbLcJalBlrskNchyl6QGWe6S1CDLXZIaZLlLUoMsd0lqkOUuSQ2y3CWpQZ3KPcmmJLuT7Ely9THWjJJ8PskXk3xitjElSX2kqlZekKwD9gCXAPcAtwJbqmr3kjWnA/8J/EZVHUzypKr65lGuVas9nyTp4ZJQVenzZ7rs3C8C9lbVvqo6DNwAbF625grg/VV1EOBoxS5JOnG6lPsGYP+S4wOTc0s9A3hCkk8kuTXJi2cVUJLU3/oZXudC4LnAY4BPJ/l0VX15RteXJPXQpdwPAmcuOd44ObfUAeCbVfV94PtJPgmcD/xQuS8sLDz0+Wg0YjQa9UssSY0bj8eMx+PjukaXF1QfAdzF4guq9wK3AJdX1a4la84F3g5sAn4EuBn4naq6c9m1fEFVknqa5gXVVXfuVXUkyVXAThZn9NuraleSrYsP17aq2p3kI8DtwBFg2/JilySdOKvu3Gf6ZO7cJam3tXorpCRpzljuktQgy12SGmS5S1KDLHdJapDlLkkNstwlqUGWuyQ1yHKXpAZZ7pLUIMtdkhpkuUtSgyx3SWqQ5S5JDbLcJalBlrskNchyl6QGWe6S1CDLXZIaZLlLUoMsd0lqkOUuSQ2y3CWpQZa7JDXIcpekBlnuktQgy12SGmS5S1KDLHdJapDlLkkNstwlqUGWuyQ1yHKXpAZZ7pLUIMtdkhpkuUtSgzqVe5JNSXYn2ZPk6hXW/XySw0l+a3YRJUl9rVruSdYB1wKXAucBlyc59xjr3gx8ZNYhJUn9dNm5XwTsrap9VXUYuAHYfJR1rwJuBL4xw3ySpCl0KfcNwP4lxwcm5x6S5CnAi6rqb4DMLp4kaRqzekH1r4Cls3gLXpIGtL7DmoPAmUuON07OLfVzwA1JAjwJeH6Sw1W1Y/nFFhYWHvp8NBoxGo16Rpakto3HY8bj8XFdI1W18oLkEcBdwCXAvcAtwOVVtesY668H/rWqPnCUx2q155MkPVwSqqrXRGTVnXtVHUlyFbCTxTHO9qralWTr4sO1bfkf6RNAkjR7q+7cZ/pk7twlqbdpdu7eoSpJDbLcJalBlrskNchyl6QGWe6S1CDLXZIaZLlLUoMsd0lqkOUuSQ2y3CWpQZa7JDXIcpekBlnuktQgy12SGmS5S1KDLHdJapDlLkkNstwlqUGWuyQ1yHKXpAZZ7pLUIMtdkhpkuUtSgyx3SWqQ5S5JDbLcJalBlrskNchyl6QGrR86wLy47z54xzvgyJGhk5yaHvUoeM1r4JGPHDqJNB/cuXf0uc/BddcNneLU9aY3wb33Dp1Cmh/u3DuqgrPPhoWFoZOcmq67bvHvQFI37tw7qoJ1frUGs26d5S71YV11VAXJ0ClOXYnlLvVhuXf0wAOW+5CSxb8DSd1Y7h05lhmWYxmpH+uqI8cyw3IsI/XTqdyTbEqyO8meJFcf5fErknxh8vGpJD87+6jDciwzLMcyUj+rlnuSdcC1wKXAecDlSc5dtuxu4Fer6nzgjcDfzTro0BzLDMuxjNRPl7q6CNhbVfuq6jBwA7B56YKquqmqvjM5vAnYMNuYw3MsMyzHMlI/Xcp9A7B/yfEBVi7vlwEfOp5QJyPHMsNyLCP1M9M7VJP8GnAl8MvHWrOw5BbP0WjEaDSaZYQ141hmWI5ldCoZj8eMx+PjukaXcj8InLnkeOPk3MMkeRawDdhUVd8+1sUW5vT+fccyw3Iso1PJ8o3vG97wht7X6LIXvRV4WpKzkpwGbAF2LF2Q5Ezg/cCLq+q/eqeYA5b7sB
zLSP2sunOvqiNJrgJ2svjNYHtV7UqydfHh2gb8GfAE4K+TBDhcVRetZfAT7YEHHMsMybGM1E+nmXtVfRh45rJzf7vk85cDL59ttJOLO/dhOZaR+nEv2pHlPizHMlI/lntHjmWG5VhG6se66sid+7Acy0j9WO4dWe7Dciwj9WO5d+RYZliOZaR+rKuO3LkPy7GM1I/l3pHlPizHMlI/lntH/m6ZYTmWkfqxrjryt0IOy7GM1I/l3pFjmWE5lpH6sdw7styH5c5d6sdy78i3Qg7LmbvUj3XVkTv3Yblzl/qx3Duy3IflzF3qx3LvyLHMsBzLSP1YVx25cx+WYxmpH8u9I8t9WI5lpH4s9468Q3VYjmWkfqyrjrxDdViOZaR+LPeOHMsMy7GM1I/l3pFjmWE5lpH6sa46ciwzLMcyUj+We0eOZYblWEbqx3LvyLHMsBzLSP1YVx05lhmWYxmpH8u9I8cyw3IsI/VjuXfkWGZYjmWkfqyrjty5D8uxjNSP5d6RM/dhOZaR+rHcO3IsMyzHMlI/1lVHjmWG5VhG6sdy78ixzLAcy0j9WO4dOZYZlmMZqR/rqiPHMsNyLCP1Y7l35FhmWI5lpH46lXuSTUl2J9mT5OpjrHlbkr1JbktywWxjDs+d+7DcuUv9rFruSdYB1wKXAucBlyc5d9ma5wPnVNXTga3AO9cg66Cq4KtfHQ8d47iMx+OhI0xt3TrYtWs8dIzjMs9f/3nODvOffxpddu4XAXural9VHQZuADYvW7MZeDdAVd0MnJ7kjJkmHdgDD8BXvjIeOsZxmed/4Ans3j0eOsZxmeev/zxnh/nPP40u5b4B2L/k+MDk3EprDh5lzVxzJDAsxzJSP+uHDjArO3fC29++dte/804455y1u75Wtn493HILvOAFQyeZ3l13wWc/O3SK6cxzdlj7/JddBq94xdpdfxqpVbZDSS4GFqpq0+T4tUBV1VuWrHkn8Imqeu/keDfwnKo6tOxa7r0kaQpV1estHV127rcCT0tyFnAvsAW4fNmaHcArgfdOvhncv7zYpwknSZrOquVeVUeSXAXsZHFGv72qdiXZuvhwbauqDya5LMmXge8BV65tbEnSSlYdy0iS5s8Jv0M1yZ8n2TW52en9SR5/ojP01eUmrpNVko1JPp7kS0nuSPIHQ2eaRpJ1ST6XZMfQWfpKcnqS903+3X8pyS8MnamPJH+Y5ItJbk/yT0lOGzrTSpJsT3Ioye1Lzv14kp1J7krykSSnD5lxJcfI37s3h/j1AzuB86rqAmAv8LoBMnTW5Sauk9wPgD+qqvOAXwReOWf5H/Rq4M6hQ0zprcAHq+qngfOBXQPn6SzJU4BXARdW1bNYHOVuGTbVqq5n8f/XpV4LfKyqngl8nJO7d46Wv3dvnvByr6qPVdWDvyXkJmDjic7QU5ebuE5aVfX1qrpt8vl3WSyWuboHIclG4DLg74fO0tdkh/UrVXU9QFX9oKr+e+BYfT0CeEyS9cCPAvcMnGdFVfUp4NvLTm8G3jX5/F3Ai05oqB6Oln+a3hz6F4f9HvChgTOspstNXHMhyU8CFwA3D5ukt78E/hiYxxeIfgr4ZpLrJ2OlbUkePXSorqrqHuAvgK+xeHPi/VX1sWFTTeXJD76Dr6q+Djx54DzHo1Nvrkm5J/noZD734Mcdk/++YMmaPwUOV9V71iKDHi7JY4EbgVdPdvBzIclvAocmP31k8jFP1gMXAu+oqguB/2FxRDAXkvwYi7ves4CnAI9NcsWwqWZiHjcKvXpzTe5QrapfX+nxJC9l8cfs567F88/YQeDMJccbJ+fmxuTH6RuBf6yqfxk6T0/PBl6Y5DLg0cDjkry7ql4ycK6uDgD7q+ozk+MbgXl6Uf55wN1VdR9Akg8AvwTM26bsUJIzqupQkp8AvjF0oL769uYQ75bZxOKP2C+sqv890c8/hYdu4pq8S2ALizdtzZPrgDur6q1DB+mrql5fVWdW1dksfu0/PkfFzm
QUsD/JMyanLmG+Xhj+GnBxkkclCYv55+EF4eU/5e0AXjr5/HeBk32T87D80/TmCX+fe5K9wGnAtyanbqqq3z+hIXqafGHfyv/fxPXmgSN1luTZwCeBO1j8UbSA11fVhwcNNoUkzwFeU1UvHDpLH0nOZ/HF4EcCdwNXVtV3hk3VXZJrWPzGehj4PPCyyZsLTkpJ3gOMgCcCh4BrgH8G3gc8FdgH/HZV3T9UxpUcI//r6dmb3sQkSQ0a+t0ykqQ1YLlLUoMsd0lqkOUuSQ2y3CWpQZa7JDXIcpekBlnuktSg/wNC/JuqhprR7wAAAABJRU5ErkJggg==\n", 693 | "text/plain": [ 694 | "" 695 | ] 696 | }, 697 | "metadata": {}, 698 | "output_type": "display_data" 699 | } 700 | ], 701 | "source": [ 702 | "ylim(0,1)\n", 703 | "plot(xs,pxt(xs,4))" 704 | ] 705 | }, 706 | { 707 | "cell_type": "markdown", 708 | "metadata": { 709 | "slideshow": { 710 | "slide_type": "slide" 711 | } 712 | }, 713 | "source": [ 714 | "This is odd because it predicts that only values between 0 and 4 can occur.\n", 715 | "But the training sample $x_1=4$ only excludes that $\\mu\\lt4$; it doesn't \n", 716 | "exclude any values greater than $4$.\n", 717 | "\n", 718 | "What's the Bayesian estimate?\n", 719 | "\n", 720 | "$$p(x|D) = \\int p(x|\\theta) p(\\theta|D) d\\theta \\propto \\int_0^{10} 1/\\theta \\cdot \\lfloor x\\in[0,\\theta]\\rfloor \\cdot 1/\\theta \\cdot \\lfloor \\theta\\in[4,10]\\rfloor d\\theta\n", 721 | "= \\int_{x_1}^{10} 1/\\theta^2 \\cdot \\lfloor x\\in[0,\\theta]\\rfloor d\\theta$$\n", 722 | "\n", 723 | "You can either think about it, or we can simply perform this integral numerically." 
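
For readers who prefer to think about it first: with a flat prior on $\theta\in(0,10]$ and the single sample $x_1=4$, the integral above has the closed form $p(x|D) = \left(1/\max(x,4) - 1/10\right)/\ln(10/4)$ on $[0,10]$. The sketch below is a self-contained check of that formula against a brute-force sum over a $\theta$ grid. The function names, grid size, and the `theta_max=10` prior bound are assumptions made for this illustration; they are not the notebook's own `pxt`/`mus` helpers.

```python
import numpy as np

def predictive_closed_form(x, x1=4.0, theta_max=10.0):
    """Closed-form p(x | D) for one sample x1 and a flat prior on (0, theta_max]."""
    x = np.asarray(x, dtype=float)
    lower = np.maximum(x, x1)  # theta has to cover both x and x1
    inside = (x >= 0) & (x <= theta_max)
    unnorm = np.where(inside, 1.0 / lower - 1.0 / theta_max, 0.0)
    return unnorm / np.log(theta_max / x1)

def predictive_numeric(x, x1=4.0, theta_max=10.0, n=20001):
    """Same quantity as a Riemann sum over a theta grid (illustrative helper)."""
    thetas = np.linspace(x1, theta_max, n)
    dtheta = thetas[1] - thetas[0]
    posterior = 1.0 / thetas                # p(theta | D) is proportional to 1/theta on [x1, theta_max]
    posterior /= posterior.sum() * dtheta   # normalise on the grid
    x = np.atleast_1d(np.asarray(x, dtype=float))
    # p(x | theta) = 1/theta if 0 <= x <= theta, else 0
    lik = ((x[:, None] >= 0) & (x[:, None] <= thetas[None, :])) / thetas[None, :]
    return (lik * posterior[None, :]).sum(axis=1) * dtheta

points = np.array([1.0, 4.0, 7.0, 9.5])
closed = predictive_closed_form(points)
numeric = predictive_numeric(points)
print(np.c_[points, closed, numeric])
```

The predictive density is flat below the observed sample and decays above it, which is the shape the numerical integration in the next cell should reproduce.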
724 | ] 725 | }, 726 | { 727 | "cell_type": "code", 728 | "execution_count": 11, 729 | "metadata": { 730 | "slideshow": { 731 | "slide_type": "slide" 732 | } 733 | }, 734 | "outputs": [ 735 | { 736 | "data": { 737 | "text/plain": [ 738 | "[]" 739 | ] 740 | }, 741 | "execution_count": 11, 742 | "metadata": {}, 743 | "output_type": "execute_result" 744 | }, 745 | { 746 | "data": { 747 | "image/png": "\n", 748 | "text/plain": [ 749 | "" 750 | ] 751 | }, 752 | "metadata": {}, 753 | "output_type": "display_data" 754 | } 755 | ], 756 | "source": [ 757 | "result = zeros(xs.shape)\n", 758 | "for i,mu in enumerate(mus):\n", 759 | "    weight = pmu1[i]\n", 760 | "    result += weight * pxt(xs,mu)\n", 761 | "result /= C*sum(result)\n", 762 | "plot(xs,result)\n", 763 | "plot(xs,pxt(xs,4))" 764 | ] 765 | }, 766 | { 767 | "cell_type": "markdown", 768 | "metadata": { 769 | "slideshow": { 770 | "slide_type": "slide" 771 | } 772 | }, 773 | "source": [ 774 | "Now assume we get another sample, $x_2=7$" 775 | ] 776 | }, 777 | { 778 | "cell_type": "code", 779 | "execution_count": 12, 780 | "metadata": { 781 | "slideshow": { 782 | "slide_type": "slide" 783 | } 784 | }, 785 | "outputs": [ 786 | { 787 | "data": { 788 | "text/plain": [ 789 | "[]" 790 | ] 791 | }, 792 | "execution_count": 12, 793 | "metadata": {}, 794 | "output_type": "execute_result" 795 | }, 796 | { 797 | "data": { 798 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXcAAAEACAYAAABI5zaHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAFnJJREFUeJzt3X20VHW9x/H398BlXTXNuhUpiNcn1Kg0S8qbD4fwAalEi7xg6TVFScGjeS3UfDhULrVVJlc0Jb2uvGlgelMsH4irZ6mZhVhYBkK6JEBFzfC5QvzeP/YcHKYzZ2bP2TO/+e39ea111pmH3cyH4+lzfvOdvfeYuyMiIvnSETqAiIhkT+UuIpJDKncRkRxSuYuI5JDKXUQkh1TuIiI5VFe5m9k4M1tmZsvNbEYf9x9gZuvM7OHS1znZRxURkXoNrrWBmXUAs4GxwFPAIjO71d2XVWx6r7sf1oSMIiKSUj0r99HACndf6e7rgbnAhD62s0yTiYhIw+op92HAqrLrq0u3VdrHzH5rZj8zs/dlkk5ERBpScyxTp8XACHd/zcwOBW4BRmb02CIiklI95b4GGFF2fXjpto3c/ZWyy3eY2RVm9k53f6F8OzPTiWxERBrg7qlG3/WMZRYBO5vZ9mY2BJgEzC/fwMyGll0eDVhlsZcFjPbr/PPPD55B+cPnKGL+mLPnIX8jaq7c3X2DmU0HFpD8MbjG3Zea2dTkbp8DTDSzk4D1wOvAvzeURkREMlHXzN3d7wR2rbjtqrLLlwOXZxtNREQapSNUU+js7AwdYUCUP6yY88ecHeLP3whrdJ7T0JOZeSufT0QkD8wMb8IbqiIiEhmVu4hIDqncRURySOUukgN6K0sqqdxFIvfKK7DjjvDqq6GTSDtRuYtE7m9/gyefhHnzQieRdqJyF8mJ730vdAJpJyp3kci5w9Zbw/PPw0MPhU4j7ULlLpIDgwbBiSfClVeGTiLtQuUukhPHHQc33wzr1oVOIu1A5S4SOXcwg6FD4eCD4Yc/DJ1I2oHKXSRHTjopeWNV+72Lyl0kRw44ADZsgPvvD51EQlO5i0SudywDyfcvfQmuuCJsJglP5S6SM8ceC3fdBWvW1NxUckzlLhK5yvn61lvDUUfpoKaiU7mL5IBVfIxDVxfMmQOvvx4mj4SnchfJoZEjYe+94YYbQieRUFTuIpGrttvjaafBrFnaLbKoVO4iOVA5lgE48MBkt8h77ml9HglP5S6SU2bJ7H3WrNBJJASVu0jk+hu7HH00PPAAPP546/JIe1C5i+RAX2MZgM03hylT4NJLW5tHwlO5i+RcVxdcfz0891zoJNJKKneRyNXaG2abbWDiRJg9uzV5pD2o3EVyoNpYptcZZyRHrOpDtItD5S5SACNHwn77wTXXhE4iraJyF4lcvQcpzZgB3/kOrF/f3DzSHlTuIjlQaywDMHo07LgjzJvX/DwSnspdpEC++lX41rd0SoIiULmLRC5NUY8bBx0d8NOfNi+PtAeVu0gO1DOW6d3u3HPh61/X6j3vVO4iBXPEEfDXv8Idd4ROIs2kcheJXPlnqNajoyNZvc+cqdV7ntVV7mY2zsyWmdlyM5vRz3Z7m9l6M/tMdhFFJGsTJ8IrrySftSr5VLPczawDmA0cAowCJpvZblW2uwjQr4tICzWy+tbqPf/qWbmPBla4+0p3Xw/MBSb0sd0pwE3AsxnmE5E6pBnL9Prc52DdOli4MPs8El495T4MWFV2fXXpto3MbFvgcHf/HtDAr5mItNqgQXDOOVq959XgjB7nUqB8Fl+14Lu7uzde7uzspLOzM6MIIsU0kGKeNAm+8Q34+c/h4IOzyyQD09PTQ09Pz4Aew7zGb4aZfQzodvdxpetnAu7uF5dt80TvReBdwKvAie4+v+KxvNbziUg6K1fC/vsn3xtx443JUauLFjU23pHmMzPcPdV/nXrGMouAnc1sezMbAkwCNiltd9+x9LUDydz95MpiF5H2NHEivPkm3Hxz6CSSpZrl7u4bgOnAAuBRYK67LzWzqWZ2Yl//k4wzikg/BvpiuKMDLrwwmb+/8UY2mSS
8mmOZTJ9MYxmRzD35JHR2Jt8b5Q5jxiQfqH388RkFk8w0aywjIjlnlqzeZ85MTk0g8VO5i0QuqxfD++wDH/pQ8nF8Ej+Vu0gOZLWXywUXwEUXwUsvZfN4Eo7KXUQ2ev/74dBDkxGNxE3lLhK5rPdRuOACmDNnYG/QSngqd5EcyPLgo2HDoKsLzjwzu8eU1lO5i8g/OOMM+MUv4Je/DJ1EGqVyF4lcMw4d2WKLZDzz5S8nR69KfFTuIjnQjHPCfOELyRGr8+Zl/9jSfCp3EelTRwd897vJ7P3110OnkbRU7iKRS/sZqmnstx+MHg0XX1x7W2kvKncR6dcll8Ds2fDEE7W3lfahcheJXLPPxbfddsneM6ee2tznkWyp3EVyoNkfsnH66bBiBdx2W3OfR7KjcheRmoYMgcsuSw5u0purcVC5i0SuVR+RcNBBsPfeyYnFpP2p3EVyoFWffXrJJXD55cmIRtqbyl1E6jZ8OJx9Nkyd2rpXDNIYlbtI5Fpdsl1d8PLLcO21rX1eSUflLpIDrRrLAAweDFdfnRy5+swzrXteSUflLiKp7bEHTJkCp5wSOolUo3IXiVyo2fe558KSJXDLLWGeX/qnchfJgVaOZXptthl8//swfTq8+GLrn1/6p3IXkYYdcACMHw9f+UroJFJJ5S4SudC7JH7727BgAdx+e9gcsimVu0gOhBjL9Npqq2S3yBNOgBdeCJdDNqVyF5EBGzMGJk6EadNCJ5FeKneRyIUey/S68EJ4+GG48cbQSQRU7iK5EHIs02vzzeG665J9359+OnQaUbmLSGY++tFk9j5lSvu8oigqlbtI5NqtRM87D559NvloPglncOgAIjJw7TCW6TVkCPzoR7DPPskHbO+5Z+hExaSVu4hkbued4dJLYdIkePXV0GmKSeUuEjn39lq59/r855MZfFdX6CTFpHIXkaaZPRvuuw/mzg2dpHjqKnczG2dmy8xsuZnN6OP+w8xsiZn9xsx+bWYfzz6qiPSl3d5QLbfllsn8vasLHn88dJpiqVnuZtYBzAYOAUYBk81st4rNFrr7Hu7+IeB44OrMk4pIVe04lun14Q8ne9B89rPw2muh0xRHPSv30cAKd1/p7uuBucCE8g3cvfw/2duAN7OLKCKxmzYNRo2Ck09u71caeVJPuQ8DVpVdX126bRNmdriZLQVuA47LJp6I1BJDWZrBnDmweHHyXZovszdU3f0Wd98dOBz4ZlaPKyK1tfNYptcWW8DNNyef4LRoUeg0+VfPQUxrgBFl14eXbuuTu99vZjua2Tvd/R9OANrd3b3xcmdnJ52dnXWHFZG4jRwJV12VnEFy8WJ417tCJ2pPPT099PT0DOgxzGu8pjOzQcBjwFjgaeDXwGR3X1q2zU7u/njp8l7Are6+XR+P5bWeT0TSWbIEjj4aHnkkdJL6zZiRlPudd8JgHSdfk5nh7qlen9Ucy7j7BmA6sAB4FJjr7kvNbKqZnVja7LNm9nszexi4DDgyZXYRGYAYxjLlLrggKfXTTw+dJL9qrtwzfTKt3EUyt2QJHHNM8j0m69Yl55857TSYOjV0mvbWyMpdL4hEIhfremnrrWH+fNh332QWP2ZM6ET5otMPiORAbGOZXrvsAjfcAJMnwx//GDpNvqjcRSSosWOTI1gPOwxefDF0mvxQuYtELtaxTLmTT4ZPfCLZRfLvfw+dJh9U7iI5EOtYptysWcnnsOoj+rKhcheRtjBoUHIGyeXLk6NYZWBU7iKRy9Mqd/PN4bbbYN685EhWaZx2hRTJgTyMZXq9+93Jkav77gvDhsGnPhU6UZy0cheRtrPTTnDrrXDccXD//aHTxEnlLhK5PI1lyo0eDddfD5/5DDz8cOg08VG5i+RAnsYy5Q46KJm9f/KTsHRp7e3lLZq5i0Quryv3XkccAS+/DAcfDPfeCzvsEDpRHFTuIjmQ15V7r2OOSQr+wAPhvvtg221DJ2p/KncRicK0afDSS0nB33MPDB0
aOlF7U7mLRC7vY5lyZ52VnJ5gzBi4+25473tDJ2pfKneRHMj7WKbc+ecn/97egt9mm9CJ2pPKXUSic955bxX8Pfeo4PuicheJXJHGMuXOPRc6OlTw1ajcRXKgSGOZcl/7WlLw++8PCxfC9tuHTtQ+VO4iErWzzoIttoD99oO77oLddw+dqD2o3EUiV9SxTLmuLnjHO5IP/LjtNvjIR0InCk/lLpIDRR3LlDv6aNhqKxg/Hm68ETo7QycKS+eWEZHcmDAhORf8kUfC/Pmh04SllbtI5DSW2dSYMfCznyUfuP3cc3D88aEThaFyF8kBjWU2tffeyUnGxo2DVaveOvCpSDSWEZFc2mUXeOAB+MlP4LrrQqdpPZW7SOQ0lqlu6NDkVMFr14ZO0noqd5EcKNrIQWpTuYtIrpkV89WNyl0kckUsrjRU7iISLY1lpJLKXSRyRVyVpqGVu4hIDqncRSRaGstUV9SfjcpdJHJFXJWmVcSfkcpdJAeKujqth8Yy/TCzcWa2zMyWm9mMPu4/ysyWlL7uN7MPZB9VRCQ9lXsVZtYBzAYOAUYBk81st4rNngD2d/c9gG8C3886qIj0rYjFJbXVs3IfDaxw95Xuvh6YC0wo38DdH3T3F0tXHwSGZRtTRPqjsUx1WrlXNwxYVXZ9Nf2X9xTgjoGEEhHJSlHLPdPzuZvZGOCLwL5ZPq6IVFfE4pLa6in3NcCIsuvDS7dtwsw+CMwBxrn7X6o9WHd398bLnZ2ddBb9gw5FMqCxTHUxrtx7enro6ekZ0GOY1/hXm9kg4DFgLPA08GtgsrsvLdtmBPB/wNHu/mA/j+W1nk9E0rn3XjjnnOS7/KPu7qTcZ84MnaRxZoa7p/oTXnPl7u4bzGw6sIBkRn+Nuy81s6nJ3T4HOBd4J3CFmRmw3t1Hp/8niEhaWi/1L8aVexbqmrm7+53ArhW3XVV2+QTghGyjiUi9NJbpXxHLXUeoikiuFXXlrnIXiVwRiysNlbuIREtjGamkcheRXNPKXUSiVMTiSkPlLiLR0lhGKqncRSTXtHIXkSgVsbjSULmLSLQ0lqlO5S4iUSpicUltKncRyTWt3EUkWhrLVKdyF5EoFbG4pDaVu0gOaOVenVbuIiI5pHIXkSgVsbjSKOqrGpW7SA4UtcDqVcQ/gCp3Eck1jWVEJEpFLK40VO4iEi2NZaSSyl1Eck0rdxGJUhGLKw2Vu4hES2MZqaRyF5Fc08pdRKJUxOJKQ+UuItHSWKa6ov5sVO4ikntauYtIdIpYXGloLCMi0Srq6KEeKncRiVIRi0tqU7mLSK5p5S4i0dJYpjqVu4hEqYjFJbWp3EUk17RyF5FoaSxTncpdRKJUxOJKQ+XeDzMbZ2bLzGy5mc3o4/5dzewBM/urmZ2efUwR6Y9W7lJpcK0NzKwDmA2MBZ4CFpnZre6+rGyzPwOnAIc3JaWISIO0cq9uNLDC3Ve6+3pgLjChfAN3f97dFwNvNCGjiPSjiMWVhsq9umHAqrLrq0u3iUib0FhGKtUcy2Stu7t74+XOzk46OztbHUFECiTGlXtPTw89PT0Deox6yn0NMKLs+vDSbQ0pL3cRGbjYiqvVYiz3yoXvzJkzUz9GPWOZRcDOZra9mQ0BJgHz+9leLxBFWkxjGalUc+Xu7hvMbDqwgOSPwTXuvtTMpiZ3+xwzGwo8BGwJvGlmpwLvc/dXmhleRKSWGFfuWahr5u7udwK7Vtx2VdnltcB22UYTkXoUsbjSKGq56whVkRzQWKY6lbuIiOSGyl0kckVclaahlbuIREtjmepU7iISpSIWl9SmcheRXNPKXUSipbFMdSp3EYlSEYtLalO5i0iuaeUuItHSWKY6lbuIRKmIxZWGyl1ERHJD5S6SAxrLVKeVu4hEqYjFlYbKXUSipZW7VFK5i0iuaeUuIlEqYnGloXIXkWhpLNM/lbuISM4U9Q+fyl0kckVclaahsYyIRKuoq9N6qNxFRCQ3VO4ikSviqjQNrdx
FJFoay1SncheRKBWxuKQ2lbuI5JpW7iISLY1lqlO5i0iUilhcaajcRUQkN1TuIjmgsUx1WrmLSJSKWFxpqNxFRCQ3VO4iOaCxTHVauYtIlIpYXGmo3EVEJDfqKnczG2dmy8xsuZnNqLLNf5nZCjP7rZntmW1MEemPxjLVaeVehZl1ALOBQ4BRwGQz261im0OBndx9F2AqcGUTsgbX09MTOsKAKH9YzcrfiuKK+WdvBi+80BM6RsvVs3IfDaxw95Xuvh6YC0yo2GYCcB2Au/8KeLuZDc00aRuI+RcclD+0ZuZv9so95p+9yr26YcCqsuurS7f1t82aPrYREZEWGRw6QFYWLIDLLmvuczz2GCxe3NznaCblD6tZ+Vevhg98IPvHzYtBg2DtWvj0p5v3HOPHw0knNe/xG2FeY2BnZh8Dut19XOn6mYC7+8Vl21wJ3OPu80rXlwEHuPvaiscq4NsaIiID5+6phm/1rNwXATub2fbA08AkYHLFNvOBacC80h+DdZXF3kg4ERFpTM1yd/cNZjYdWEAyo7/G3Zea2dTkbp/j7reb2Xgz+yPwKvDF5sYWEZH+1BzLiIhIfFp+hKqZfcvMlpYOdrrZzLZqdYa06jmIq12Z2XAzu9vMHjWz35lZV+hMjTCzDjN72Mzmh86Slpm93cx+XPq9f9TMPho6Uxpm9mUz+72ZPWJm15vZkNCZ+mNm15jZWjN7pOy2d5jZAjN7zMzuMrO3h8zYnyr5U/dmiNMPLABGufuewArgrAAZ6lbPQVxt7g3gdHcfBewDTIssf69TgT+EDtGgWcDt7r47sAewNHCeupnZtsApwF7u/kGSUe6ksKlqupbk/6/lzgQWuvuuwN20d+/0lT91b7a83N19obu/Wbr6IDC81RlSqucgrrbl7s+4+29Ll18hKZaojkEws+HAeODq0FnSKq2w9nP3awHc/Q13fylwrLQGAVuY2WBgc+CpwHn65e73A3+puHkC8IPS5R8Ah7c0VAp95W+kN0OfOOw44I7AGWqp5yCuKJjZvwJ7Ar8KmyS17wJfAWJ8g2gH4Hkzu7Y0VppjZpuFDlUvd38K+A7wJ5KDE9e5+8KwqRrynt49+Nz9GeA9gfMMRF292ZRyN7Ofl+ZzvV+/K33/dNk2XwPWu/sNzcggmzKztwE3AaeWVvBRMLNPAmtLrz6s9BWTwcBewOXuvhfwGsmIIApmtjXJqnd7YFvgbWZ2VNhUmYhxoZCqN5tyhKq7H9Tf/WZ2LMnL7E804/kztgYYUXZ9eOm2aJReTt8E/I+73xo6T0ofBw4zs/HAZsCWZnadux8TOFe9VgOr3P2h0vWbgJjelD8QeMLdXwAws/8F/g2IbVG21syGuvtaM3sv8GzoQGml7c0Qe8uMI3mJfZi7/63Vz9+AjQdxlfYSmERy0FZM/hv4g7vPCh0kLXc/291HuPuOJD/7uyMqdkqjgFVmNrJ001jiemP4T8DHzOyfzcxI8sfwhnDlq7z5wLGly/8BtPsiZ5P8jfRmy/dzN7MVwBDgz6WbHnT3k1saIqXSD3YWbx3EdVHgSHUzs48D9wK/I3kp6sDZ7n5n0GANMLMDgP9098NCZ0nDzPYgeTP4n4AngC+6+4thU9XPzM4n+cO6HvgNMKW0c0FbMrMbgE7gX4C1wPnALcCPge2AlcCR7r4uVMb+VMl/Nil7UwcxiYjkUOi9ZUREpAlU7iIiOaRyFxHJIZW7iEgOqdxFRHJI5S4ikkMqdxGRHFK5i4jk0P8DbKFfDbjhJEoAAAAASUVORK5CYII=\n", 799 | "text/plain": [ 800 | "" 801 | ] 802 | }, 803 | "metadata": {}, 804 | "output_type": "display_data" 805 | } 806 | ], 807 | "source": [ 808 | "pmu2 = pxt(7,mus)*pmu1\n", 809 | "pmu2 /= C*sum(pmu2)\n", 
810 | "plot(xs,pmu2)" 811 | ] 812 | }, 813 | { 814 | "cell_type": "code", 815 | "execution_count": 13, 816 | "metadata": { 817 | "slideshow": { 818 | "slide_type": "slide" 819 | } 820 | }, 821 | "outputs": [ 822 | { 823 | "data": { 824 | "text/plain": [ 825 | "[]" 826 | ] 827 | }, 828 | "execution_count": 13, 829 | "metadata": {}, 830 | "output_type": "execute_result" 831 | }, 832 | { 833 | "data": { 834 | "image/png": "\n", 835 | "text/plain": [ 836 | "" 837 | ] 838 | }, 839 | "metadata": {}, 840 | "output_type": "display_data" 841 | } 842 | ], 843 | "source": [ 844 | "result = zeros(xs.shape)\n", 845 | "for i,mu in enumerate(mus):\n", 846 | "    weight = pmu2[i]\n", 847 | "    result += weight * pxt(xs,mu)\n", 848 | "result /= C*sum(result)\n", 849 | "plot(xs,result)\n", 850 | "plot(xs,pxt(xs,7))\n", 851 | "plot(xs,pxt(xs,4))" 852 | ] 853 | }, 854 | { 855 | "cell_type": "markdown", 856 | "metadata": { 857 | "slideshow": { 858 | "slide_type": "slide" 859 | } 860 | }, 861 | "source": [ 862 | "This is even weirder. After seeing the first sample, the maximum likelihood\n", 863 | "estimator predicts that only values between 0 and 4 can occur, but after seeing\n", 864 | "another training sample, it changes its mind and now predicts that values\n", 865 | "between 0 and 7 can occur.\n", 866 | "\n", 867 | "The Bayesian estimator, in contrast, \"knows\" that the parameter must be at least 7,\n", 868 | "so it predicts a flat density on the interval [0...7] and, beyond 7, a density that\n", 869 | "decays with $x$, because it averages the uniform densities $p(x|\\theta)$ over the posterior $p(\\theta|D)$." 870 | ] 871 | }, 872 | { 873 | "cell_type": "markdown", 874 | "metadata": { 875 | "slideshow": { 876 | "slide_type": "slide" 877 | } 878 | }, 879 | "source": [ 880 | "The last sample illustrates this further.\n", 881 | "\n", 882 | "A sample of $x_3=2$ doesn't cause any update to the maximum likelihood estimate,\n", 883 | "but it does cause an update to the posterior distribution." 
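
The contrast between the two estimators can be sketched in a few self-contained lines. For a uniform density on $[0,\theta]$ the maximum likelihood estimate is simply the largest sample seen so far, while the posterior under a flat prior on $(0,10]$ is proportional to $\theta^{-n}$ above that maximum, so it sharpens with every sample, including ones that leave the maximum unchanged. The grid, the prior bound of 10, and the helper names `likelihood` and `update` are assumptions chosen for this illustration rather than the notebook's own definitions.

```python
import numpy as np

thetas = np.linspace(0.01, 10.0, 1000)  # illustrative parameter grid
dtheta = thetas[1] - thetas[0]

def likelihood(x, thetas):
    """p(x | theta) for Uniform(0, theta), evaluated on the whole theta grid."""
    return np.where(thetas >= x, 1.0 / thetas, 0.0)

def update(posterior, x):
    """One Bayes step: multiply by the likelihood and renormalise on the grid."""
    posterior = posterior * likelihood(x, thetas)
    return posterior / (posterior.sum() * dtheta)

posterior = np.full_like(thetas, 1.0 / 10.0)  # flat prior on (0, 10]
samples = [4.0, 7.0, 2.0]
history = []
for n, x in enumerate(samples, start=1):
    posterior = update(posterior, x)
    mle = max(samples[:n])  # ML estimate of theta is the largest sample so far
    post_mean = (posterior * thetas).sum() * dtheta
    history.append((x, mle, post_mean))
    print(f"x={x}: MLE={mle}, posterior mean={post_mean:.3f}")
```

In this sketch the third sample leaves the maximum likelihood estimate where it was, while the posterior mean moves closer to 7, because each additional sample below the current maximum makes large values of $\theta$ less plausible.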
884 | ] 885 | }, 886 | { 887 | "cell_type": "code", 888 | "execution_count": 14, 889 | "metadata": { 890 | "slideshow": { 891 | "slide_type": "slide" 892 | } 893 | }, 894 | "outputs": [ 895 | { 896 | "data": { 897 | "text/plain": [ 898 | "[]" 899 | ] 900 | }, 901 | "execution_count": 14, 902 | "metadata": {}, 903 | "output_type": "execute_result" 904 | }, 905 | { 906 | "data": { 907 | "image/png": "\n", 908 | "text/plain": [ 909 | "" 910 | ] 911 | }, 912 | "metadata": {}, 913 | "output_type": "display_data" 914 | } 915 | ], 916 | "source": [ 917 | "pmu3 = pxt(2,mus)*pmu2\n", 918 | "pmu3 /= C*sum(pmu3)\n", 919 | "plot(xs,pmu3)\n", 920 | "plot(xs,pmu2)" 921 | ] 922 | }, 923 | { 924 | "cell_type": "code", 925 | "execution_count": 15, 926 | "metadata": { 927 | "slideshow": { 928 | "slide_type": "slide" 929 | } 930 | }, 931 | "outputs": [ 932 | { 933 | "data": { 934 | "text/plain": [ 935 | "[]" 936 | ] 937 | }, 938 | "execution_count": 15, 939 | "metadata": {}, 940 | "output_type": "execute_result" 941 | }, 942 | { 943 | "data": { 944 | "image/png": "\n", 945 | "text/plain": [ 946 | "" 947 | ] 948 | }, 949 | "metadata": {}, 950 | "output_type": "display_data" 951 | } 952 | ], 953 | "source": [ 954 | "result = zeros(xs.shape)\n", 955 | "total = 0\n", 956 | "for i,mu in enumerate(mus):\n", 957 | " weight = pmu3[i]\n", 958 | " result += weight * pxt(xs,mu)\n", 959 | " total += weight\n", 960 | "result /= total\n", 961 | "plot(xs,result)\n", 962 | "plot(xs,pxt(xs,7))" 963 | ] 964 | }, 965 | { 966 | "cell_type": "markdown", 967 | "metadata": { 968 | "slideshow": { 969 | "slide_type": "slide" 970 | } 971 | }, 972 | "source": [ 973 | "In fact, if we repeat the same process with a lot of samples (in this case\n", 974 | "the true parameter is 7), we see that the Bayesian parameter estimate\n", 975 | "becomes an increasingly peaked distribution close to the true value.\n", 976 | "\n", 977 | "I.e., if, out of 100 samples, we haven't seen a value greater than 7,\n", 978 | "then the 
probability that the mean is significantly greater than 7 must\n", 979 | "be very small." 980 | ] 981 | }, 982 | { 983 | "cell_type": "code", 984 | "execution_count": 16, 985 | "metadata": { 986 | "slideshow": { 987 | "slide_type": "slide" 988 | } 989 | }, 990 | "outputs": [ 991 | { 992 | "data": { 993 | "text/plain": [ 994 | "[]" 995 | ] 996 | }, 997 | "execution_count": 16, 998 | "metadata": {}, 999 | "output_type": "execute_result" 1000 | }, 1001 | { 1002 | "data": { 1003 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAEACAYAAACj0I2EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAEndJREFUeJzt3XusnHWdx/H3tz2Ua1tUUkS6gizUNSTIEi+sLrsTK9JoBP/Y3SC7q2CWf1aBoAGh/tGDiUYlXshe/jBAwxqbTaxmxY2XgjhsWFddQShCKSQaiyD1slyD1kK/+8eZqQ/Dc3rmds4zv+n7lUw685zn8k3z9NPvfOd55kRmIkkq37KmC5AkjYeBLklTwkCXpClhoEvSlDDQJWlKGOiSNCUWDPSIuCEidkfE9pqffSgi9kXESxenPElSv/rp0DcD5/QujIi1wNnAz8ZdlCRpcAsGembeATxe86PPAleMvSJJ0lCGmqFHxLnAw5l575jrkSQNaWbQDSLicGAjc+OW/YvHVpEkaSgDBzrwx8CJwD0REcBa4M6IeENm/rJ35Yjwy2IkaQiZOVCz3O/IJToPMvPHmfnyzDwpM18F/Bz407owrxRV7GPTpk2N13Cw1l9y7dbf/KP0+ofRz2WLW4DvAusiYldEXNSb1zhykaTGLThyycwLFvj5SeMrR5I0LO8UXUCr1Wq6hJGUXH/JtYP1N630+ocRw85q+j5ARC72MSRp2kQEuUgfikqSJpyBLklTwkCXpClhoEvSlDDQJWlKGOiSNCUMdEmaEga6JE0JA10q0J498OijTVehSWOgSwXauBGOP77pKjRpDHSpQL+c98uqdTAz0KUCLV/edAWaRAa6VKBl/stVDU8LqUAGuup4WkgFMtBVx9NCKpCBrjqeFlKBDHTV8bSQCmSgq46nhVQgA111PC2kAhnoquNpIRXIQFedBU+LiLghInZHxPbKsk9FxI6IuDsivhwRqxa3TElVBrrq9HNabAbO6Vm2DTg1M08HHgKuHndhkubnrf+qs2CgZ+YdwOM9y27NzH2dl98D1i5CbZLmYYeuOuM4Ld4HfGMM+5HUJwNddWZG2TgiPgLszcwtB1pvdnZ2//NWq0Wr1RrlsNJBz0CfPu12m3a7PdI+IjMXXiniBOBrmXlaZdmFwMXAWzJzzwG2zX6OIal/mzbBRz8K/tOaXhFBZsYg2/TboUfn0T3QBuAK4C8OFOaSFocduur0c9niFuC7wLqI2BURFwH/BBwF3BIRd0XEvy5ynZIqDHTVWbBDz8wLahZvXoRaJPXJQFcdTwupQAa66nhaSAUy0FXH00IqkIGuOp4WUoG89V91DHSpQDHQ1ck6WBjoUoEcuaiOp4VUIEcuqmOgSwWyQ1cdTwupQAa66nhaSAUy0FXH00IqkIGuOp4WUoEMdNXxtJAKZKCrjqeFVKDujUX+ggtVGehSwZ5/vukKNEkMdKlA3c58375m69BkMdClgtmhq8pAlwpmh64qA10qUHfkYoeuKgNdKpgduqoMdKlgBrqqDHSpQI5cVMdAlwpmh66qBQM9Im6IiN0Rsb2y7CURsS0i
dkbEtyJi9eKWKamOHbqq+unQNwPn9Cy7Crg1M18N3AZcPe7CJM3PG4tUZ8FAz8w7gMd7Fp8H3NR5fhPwrjHXJakPduiqGnaGviYzdwNk5mPAmvGVJGkhduiqMzOm/RzwO99mZ2f3P2+1WrRarTEdVjq42aFPj3a7TbvdHmkfkX18/2ZEnAB8LTNP67zeAbQyc3dEvBz4Tma+Zp5ts59jSOrf9dfDxRfDzp2wbl3T1WgxRASZGYNs0+/IJTqPrpuBCzvP3wt8dZCDShqNIxfV6eeyxS3Ad4F1EbErIi4CPgGcHRE7gfWd15KWmCMXVS04Q8/MC+b50VvHXIukAdmhq8o7RaUCeeu/6hjoUsHs0FVloEsFs0NXlYEuFcirXFTHQJcKZoeuKgNdKpAduuoY6FLBDHRVGehSwRy5qMpAlwrkdeiqY6BLBXvuuaYr0CQx0KWC2aGrykCXCtQdudihq8pAlwpmh64qA10qmB26qgx0qUBe5aI6BrpUMDt0VRnoUsHs0FVloEsFcuSiOga6VDBHLqoy0KUC2aGrjoEuFcwOXVUGulQwO3RVGehSgbz1X3VGCvSIuDwifhwR2yPiixGxYlyFSVqYHbqqhg70iHgFcAlwRmaeBswA54+rMEkLs0NX1cyI2y8HjoyIfcARwKOjlyRpIV7lojpDd+iZ+SjwaWAX8AjwRGbeOq7CJC3MDl1VQ3foEXE0cB5wAvAksDUiLsjMLb3rzs7O7n/earVotVrDHlZShR369Gi327Tb7ZH2Edl97zbohhF/BZyTmRd3Xv898MbM/EDPejnsMSTV+9zn4PLL4cor4ZOfbLoaLYaIIDNjkG1GucplF3BmRBwWEQGsB3aMsD9JA3LkoqpRZug/ALYCPwLuAQL4/JjqktQHRy6qGukql8y8BrhmTLVI6lMmLF9uh64X8k5RqVAzM3boeiEDXSqQHbrqGOhSoezQ1ctAlwo1M2OHrhcy0KUCZdqh68UMdKlQBrp6GehSofxQVL0MdKlAjlxUx0CXCuWHouploEuFskNXLwNdKlB35GKHrioDXSqUHbp6GehSobzKRb0MdKlAXuWiOga6VChn6OploEsF6n7boh26qgx0qVCOXNTLQJcK5chFvQx0qUBeh646BrpUqBUrYO/epqvQJDHQpUIdcoiBrhcy0KUCZdqh68UMdKlQdujqNVKgR8TqiPhSROyIiPsi4o3jKkzSgfmhqHrNjLj9dcDXM/OvI2IGOGIMNUlagCMX1Rk60CNiFXBWZl4IkJnPAU+NqS5JC3Dkol6jjFxeBfw6IjZHxF0R8fmIOHxchUk6MANdvUYZucwAZwDvz8wfRsTngKuATb0rzs7O7n/earVotVojHFZS98aiffvmHsu8vKF47Xabdrs90j4iM4fbMOJY4H8y86TO6z8HPpyZ7+xZL4c9hqR6H/sYPPssXHstPP00HHpo0xVp3CKCzIxBthn6//XM3A08HBHrOovWA/cPuz9J/ev2SI5dVDXqVS6XAl+MiEOAnwAXjV6SpH4dcoiXLuoPRgr0zLwHeP2YapE0gIi5Obodurr8KEUqkCMX1THQpYIZ6Koy0KVCRRjoeiEDXSqQIxfVMdClQtmhq5eBLhXMQFeVgS4VqDty8St0VWWgS4Vy5KJeBrpUMANdVQa6VCCvclEdA10qlCMX9TLQpQLZoauOgS4VzEBXlYEuFar7bYtetqguA10qkCMX1THQpYIZ6Koy0KVCRcCKFfD73zddiSaFgS4VqDtyMdBVZaBLhYqAQw+FPXuarkSTwkCXCmagq8pAlwrUHbkY6Koy0KVCOXJRLwNdKpiBrqqRAz0ilkXEXRFx8zgKkrQwRy6qM44O/TLg/jHsR9IAvA5dvUYK9IhYC7wduH485Ujqhx266ozaoX8WuALIMdQiaUAGuqpmht0wIt4B7M7MuyOiBcR8687Ozu5/3mq1aLVawx5WUodXuUyXdrtNu90eaR+ROVxzHREfB/4OeA44
HFgJfCUz39OzXg57DEn1rr4aVq6E170Orr0Wbrml6Yo0bhFBZs7bKNcZeuSSmRsz85WZeRJwPnBbb5hLWlx26KryOnSpUI5c1GvoGXpVZt4O3D6OfUlamN+2qDp26FKh7NDVy0CXCmagq8pAlwrkjUWqY6BLhXLkol4GulQwA11VBrpUIEcuqmOgS4Xqjlz27oV9+5quRpPAQJcK1O3QI+Cww+C3v222Hk0GA10q3BFHwLPPNl2FJoGBLhUqOl/bdMQRduiaY6BLBap+genhh9uha46BLhXOkYu6DHSpUI5c1MtAlwrkyEV1DHSpUNUO3UAXGOhS8Ry5qMtAlwrkyEV1DHSpUI5c1MtAlwrnyEVdBrpUIEcuqmOgS4Vy5KJeBrpUoGqHftRR8MwzzdWiyWGgS4VbuRKefrrpKjQJhg70iFgbEbdFxH0RcW9EXDrOwiQdWHfksmoVPPVUs7VoMsyMsO1zwAcz8+6IOAq4MyK2ZeYDY6pN0jyqIxc7dHUN3aFn5mOZeXfn+TPADuD4cRUmqT926Ooayww9Ik4ETge+P479SVpYd+SycqWBrjmjjFwA6IxbtgKXdTr1F5mdnd3/vNVq0Wq1Rj2sdFCrjlxWrXLkMg3a7TbtdnukfURWz4xBN46YAf4T+EZmXjfPOjnKMSS92KWXwsknz/359NNw3HFeujhtIoLMjEG2GXXkciNw/3xhLmnxHXnk3K3/zz/fdCVq2iiXLb4Z+FvgLRHxo4i4KyI2jK80SfOpvuldtmwu1O3QNfQMPTP/G1g+xlokDSAqb8ZXroQnn4TVq5urR83zTlFpCrz0pfD4401XoaYZ6FKBeq8zeNnL4De/aaYWTQ4DXSpUdeRioAsMdKlIduiqY6BLU8BAFxjoUrEcuaiXgS4VqHfkcswxBroMdGkqHHMM/OpXTVehphnoUqGqI5fjjoNf/KK5WjQZDHSpQL0jl+OPh0ceaaYWTQ4DXSpUtUNfswaeeAL27GmuHjXPQJemwLJlcOyx8NhjTVeiJhnoUoHqfsWAYxcZ6FKhoudXH5x4Ivz0p42UoglhoEtT4pRT4MEHm65CTTLQpQLVjVzWrYOHHlr6WjQ5DHSpUL0jl1NOgZ07m6lFk8FAlwpU16Gfeio88ADs3bv09WgyGOjSlFi5Ek44Ae67r+lK1BQDXSpU78gF4PWvh+9/f+lr0WQw0KUC1Y1cANavh23blrYWTQ4DXZoiGzbAt78Nv/td05WoCSMFekRsiIgHIuLBiPjwuIqStLC6kcuaNXDmmbB169LXo+YNHegRsQz4Z+Ac4FTg3RHxJ+MqbFK02+2mSxhJyfWXXDssbv3zjVwALrsMPv7x0a928e+/PKN06G8AHsrMn2XmXuDfgfPGU9bkKP2kKLn+kmuHxa+/rkOHubHLySfDJZfAvn3D79+///KMEujHAw9XXv+8s0xSgyLgC1+A7dvhbW+Db34Tnnmm6aq0FGaaLmAUn/kMfOc7i3uMnTvhzjsX9xiLqeT6S64dFrf+e+6BN71p/p+vXg233w433gjXXDO3/tFHz/0y6VWrYGYGli9/4aO343/wQfjhDxen/qVwoPrne3cziCuvhLPOGn0/4xR5oGHcgTaMOBOYzcwNnddXAZmZn+xZb7gDSNJBLjMH+q9nlEBfDuwE1gO/AH4AvDszdwy1Q0nSSIYeuWTm8xHxAWAbc7P4GwxzSWrO0B26JGmyLMmdohHxqYjYERF3R8SXI2LVUhx3FCXfNBURayPitoi4LyLujYhLm65pGBGxLCLuioibm65lUBGxOiK+1Dnv74uINzZd0yAi4vKI+HFEbI+IL0bEiqZrOpCIuCEidkfE9sqyl0TEtojYGRHfiojVTdZ4IPPUP3BuLtWt/9uAUzPzdOAh4OolOu5QpuCmqeeAD2bmqcCfAe8vrP6uy4D7my5iSNcBX8/M1wCvBYoZR0bEK4BLgDMy8zTmRrPnN1vVgjYz9++16irg1sx8NXAbk507
dfUPnJtLEuiZeWtmdm9x+B6wdimOO4Kib5rKzMcy8+7O82eYC5Oi7hGIiLXA24Hrm65lUJ1O6qzM3AyQmc9l5lMNlzWo5cCRETEDHAE82nA9B5SZdwCP9yw+D7ip8/wm4F1LWtQA6uofJjeb+HKu9wHfaOC4g5iam6Yi4kTgdKC0L1X9LHAFUOKHPK8Cfh0Rmzsjo89HxOFNF9WvzHwU+DSwC3gEeCIzb222qqGsyczdMNfkAGsarmcUfeXm2AI9Im7pzNu6j3s7f76zss5HgL2ZuWVcx9X8IuIoYCtwWadTL0JEvAPY3XmXEZ1HSWaAM4B/ycwzgGeZe/tfhIg4mrnu9gTgFcBREXFBs1WNRYnNwUC5ObY7RTPz7AWKupC5t9BvGdcxF9EjwCsrr9d2lhWj81Z5K/CFzPxq0/UM6M3AuRHxduBwYGVE/Ftmvqfhuvr1c+DhzOzep7gVKOmD9bcCP8nM/wOIiK8AbwJKa8R2R8Sxmbk7Il4O/LLpggY1aG4u1VUuG5h7+3xuZu5ZimOO6H+BkyPihM6n++cDpV1pcSNwf2Ze13Qhg8rMjZn5ysw8ibm/+9sKCnM6b/Mfjoh1nUXrKevD3V3AmRFxWEQEc/WX8KFu77u5m4ELO8/fC0x6Y/OC+ofJzSW5Dj0iHgJWAL/pLPpeZv7joh94BJ2/zOv4w01Tn2i4pL5FxJuB/wLuZe5tZgIbM/ObjRY2hIj4S+BDmXlu07UMIiJey9wHuocAPwEuyswnm62qfxGxibn/TPcCPwL+oXOBwESKiC1AC3gZsBvYBPwH8CXgj4CfAX+TmU80VeOBzFP/RgbMTW8skqQp4a+gk6QpYaBL0pQw0CVpShjokjQlDHRJmhIGuiRNCQNdkqaEgS5JU+L/AZJJYMxLBHozAAAAAElFTkSuQmCC\n", 1004 | "text/plain": [ 1005 | "" 1006 | ] 1007 | }, 1008 | "metadata": {}, 1009 | "output_type": "display_data" 1010 | } 1011 | ], 1012 | "source": [ 1013 | "p = pmu.copy()\n", 1014 | "for i in range(100):\n", 1015 | " p = pxt(rand()*7,mus)*p\n", 1016 | " p /= C*sum(p)\n", 1017 | "plot(mus,p)" 1018 | ] 1019 | }, 1020 | { 1021 | "cell_type": "markdown", 1022 | "metadata": { 1023 | "slideshow": { 1024 | "slide_type": "slide" 1025 | } 1026 | }, 1027 | "source": [ 1028 | "Loss Functions for Parameter Estimation\n", 1029 | "========================================\n", 1030 | "\n", 1031 | "Consider $p(\\theta|x)$ from the previous example again." 
1032 | ] 1033 | }, 1034 | { 1035 | "cell_type": "code", 1036 | "execution_count": 17, 1037 | "metadata": { 1038 | "slideshow": { 1039 | "slide_type": "slide" 1040 | } 1041 | }, 1042 | "outputs": [ 1043 | { 1044 | "data": { 1045 | "text/plain": [ 1046 | "[]" 1047 | ] 1048 | }, 1049 | "execution_count": 17, 1050 | "metadata": {}, 1051 | "output_type": "execute_result" 1052 | }, 1053 | { 1054 | "data": { 1055 | "image/png": "[base64-encoded PNG data omitted: plot of pmu2 against mus]", 1056 | "text/plain": [ 1057 | "" 1058 | ] 1059 | }, 1060 | "metadata": {}, 1061 | "output_type": "display_data" 1062 | } 1063 | ], 1064 | "source": [ 1065 | "plot(mus,pmu2)" 1066 | ] 1067 | }, 1068 | { 1069 | "cell_type": "markdown", 1070 | "metadata": { 1071 | "slideshow": { 1072 | "slide_type": "slide" 1073 | } 1074 | }, 1075 | "source": [ 1076 | "Assume now that we are supposed to return a \"best estimate\" of the parameter.\n", 1077 | "\n", 1078 | "By itself, that request is underspecified.\n", 1079 | "\n", 1080 | "But now assume that we are given a loss function: if our estimate is within $\\pm 0.5$ of the true value,\n", 1081 | "we pay no penalty; otherwise, we pay a penalty of 1. What value should we return?\n", 1082 | "\n", 1083 | "The most likely value is 7, but no values less than 7 can occur." 1084 | ] 1085 | }, 1086 | { 1087 | "cell_type": "markdown", 1088 | "metadata": { 1089 | "slideshow": { 1090 | "slide_type": "slide" 1091 | } 1092 | }, 1093 | "source": [ 1094 | "Therefore, it is better to return 7.5. That way, the interval around our estimate still contains the most likely value, but it also\n", 1095 | "covers all the probability mass between 7 and 8, and our expected loss is about half of what it would be\n", 1096 | "if we had returned 7." 1097 | ] 1098 | }, 1099 | { 1100 | "cell_type": "markdown", 1101 | "metadata": { 1102 | "slideshow": { 1103 | "slide_type": "slide" 1104 | } 1105 | }, 1106 | "source": [ 1107 | "Now assume we are penalized if our estimate is more than 1 away from the true value.\n", 1108 | "By the same reasoning, our parameter estimate should now be 8." 
1109 | ] 1110 | }, 1111 | { 1112 | "cell_type": "markdown", 1113 | "metadata": { 1114 | "slideshow": { 1115 | "slide_type": "slide" 1116 | } 1117 | }, 1118 | "source": [ 1119 | "As you can see from this simple example, there is no single \"best\" answer to the parameter estimation\n", 1120 | "problem; our answer depends on the loss function.\n", 1121 | "\n", 1122 | "But we can see that for this posterior and the symmetric loss functions considered here, the value 7 (the maximum likelihood estimate)\n", 1123 | "is never the optimal answer." 1124 | ] 1125 | }, 1126 | { 1127 | "cell_type": "code", 1128 | "execution_count": null, 1129 | "metadata": {}, 1130 | "outputs": [], 1131 | "source": [] 1132 | } 1133 | ], 1134 | "metadata": { 1135 | "celltoolbar": "Slideshow", 1136 | "kernelspec": { 1137 | "display_name": "Python 2", 1138 | "language": "python", 1139 | "name": "python2" 1140 | }, 1141 | "language_info": { 1142 | "codemirror_mode": { 1143 | "name": "ipython", 1144 | "version": 2 1145 | }, 1146 | "file_extension": ".py", 1147 | "mimetype": "text/x-python", 1148 | "name": "python", 1149 | "nbconvert_exporter": "python", 1150 | "pygments_lexer": "ipython2", 1151 | "version": "2.7.15" 1152 | } 1153 | }, 1154 | "nbformat": 4, 1155 | "nbformat_minor": 1 1156 | } 1157 | -------------------------------------------------------------------------------- /PDF/010-understanding-DL-models.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tmbdev-tutorials/dl-2018/7b42601e90a9cad1d0fb8d399c55c178561313ab/PDF/010-understanding-DL-models.pdf -------------------------------------------------------------------------------- /PDF/020-dl-pca-ica.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tmbdev-tutorials/dl-2018/7b42601e90a9cad1d0fb8d399c55c178561313ab/PDF/020-dl-pca-ica.pdf -------------------------------------------------------------------------------- /PDF/030-2d-layers.pdf: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/tmbdev-tutorials/dl-2018/7b42601e90a9cad1d0fb8d399c55c178561313ab/PDF/030-2d-layers.pdf -------------------------------------------------------------------------------- /PDF/040-bayesian-decision-theory.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tmbdev-tutorials/dl-2018/7b42601e90a9cad1d0fb8d399c55c178561313ab/PDF/040-bayesian-decision-theory.pdf -------------------------------------------------------------------------------- /PDF/050-parameter-estimation.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tmbdev-tutorials/dl-2018/7b42601e90a9cad1d0fb8d399c55c178561313ab/PDF/050-parameter-estimation.pdf -------------------------------------------------------------------------------- /PDF/100-Large Scale DL.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tmbdev-tutorials/dl-2018/7b42601e90a9cad1d0fb8d399c55c178561313ab/PDF/100-Large Scale DL.pdf -------------------------------------------------------------------------------- /chapter.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tmbdev-tutorials/dl-2018/7b42601e90a9cad1d0fb8d399c55c178561313ab/chapter.png -------------------------------------------------------------------------------- /flex.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017 NVIDIA CORPORATION. All rights reserved. 2 | # See the LICENSE file for licensing terms (BSD-style). 
3 | 4 | import numpy as np 5 | import torch 6 | from torch import nn 7 | from torch import autograd 8 | from torch.autograd import Variable 9 | from torch.legacy import nn as legnn 16 | import layers 17 | 18 | class Flex(nn.Module): 19 | def __init__(self, creator): 20 | super(Flex, self).__init__() 21 | self.creator = creator 22 | self.layer = None 23 | self.dummy = nn.Parameter(torch.zeros(1)) 24 | def forward(self, *args): 25 | if self.layer is None: 26 | self.layer = self.creator(*args) 27 | self.layer.to(self.dummy.device) 28 | return self.layer.forward(*args) 29 | def __repr__(self): 30 | return "Flex:"+repr(self.layer) 31 | def __str__(self): 32 | return "Flex:"+str(self.layer) 33 | 34 | def Linear(*args, **kw): 35 | def creator(x): 36 | assert x.ndimension()==2 37 | return nn.Linear(x.size(1), *args, **kw) 38 | return Flex(creator) 39 | 40 | 41 | def Conv1d(*args, **kw): 42 | def creator(x): 43 | assert x.ndimension()==3 44 | d = x.size(1) 45 | return nn.Conv1d(x.size(1), *args, **kw) 46 | return Flex(creator) 47 | 48 | 49 | def Conv2d(*args, **kw): 50 | def creator(x): 51 | assert x.ndimension()==4 52 | return nn.Conv2d(x.size(1), *args, **kw) 53 | return Flex(creator) 54 | 55 | 56 | def Conv3d(*args, **kw): 57 | def creator(x): 58 | assert x.ndimension()==5 59 | return nn.Conv3d(x.size(1), *args, **kw) 60 | return Flex(creator) 61 | 62 | 63 | def Lstm1(*args, **kw): 64 | def creator(x): 65 | assert x.ndimension()==3 66 | return layers.LSTM1(x.size(1), *args, **kw) 67 | return Flex(creator) 68 | 69 | 70 | def Lstm1to0(*args, **kw): 71 | def creator(x): 72 | assert x.ndimension()==3 73 | return layers.LSTM1to0(x.size(1), *args, **kw) 74 | return Flex(creator) 75 | 76 | 77 | def Lstm2(*args, **kw): 78 | def creator(x): 79 | assert x.ndimension()==4 80 | return 
layers.LSTM2(x.size(1), *args, **kw) 81 | return Flex(creator) 82 | 83 | 84 | def Lstm2to1(*args, **kw): 85 | def creator(x): 86 | assert x.ndimension()==4 87 | return layers.LSTM2to1(x.size(1), *args, **kw) 88 | return Flex(creator) 89 | 90 | def BatchNorm1d(*args, **kw): 91 | def creator(x): 92 | assert x.ndimension()==3 93 | return nn.BatchNorm1d(x.size(1), *args, **kw) 94 | return Flex(creator) 95 | 96 | def BatchNorm2d(*args, **kw): 97 | def creator(x): 98 | assert x.ndimension()==4 99 | return nn.BatchNorm2d(x.size(1), *args, **kw) 100 | return Flex(creator) 101 | 102 | def BatchNorm3d(*args, **kw): 103 | def creator(x): 104 | assert x.ndimension()==5 105 | return nn.BatchNorm3d(x.size(1), *args, **kw) 106 | return Flex(creator) 107 | 108 | def replace_modules(model, f): 109 | for key in model._modules.keys(): 110 | sub = model._modules[key] 111 | replacement = f(sub) 112 | if replacement is not None: 113 | model._modules[key] = replacement 114 | else: 115 | replace_modules(sub, f) 116 | 117 | def flex_replacer(module): 118 | if isinstance(module, Flex): 119 | return module.layer 120 | else: 121 | return None 122 | 123 | def flex_freeze(model): 124 | replace_modules(model, flex_replacer) 125 | 126 | def delete_modules(model, f): 127 | for key in model._modules.keys(): 128 | if f(model._modules[key]): 129 | del model._modules[key] 130 | 131 | -------------------------------------------------------------------------------- /helpers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017 NVIDIA CORPORATION. All rights reserved. 2 | # See the LICENSE file for licensing terms (BSD-style). 
3 | 4 | """A set of helper functions for dealing uniformly with tensors and 5 | ndarrays.""" 6 | 7 | import numpy as np 8 | import torch 9 | from torch import autograd, nn, optim 10 | from torch.autograd import Variable 11 | import torch.nn.functional as F 12 | from scipy import ndimage 13 | 14 | torch_tensor_types = tuple([ 15 | torch.Tensor, 16 | torch.FloatTensor, torch.IntTensor, torch.LongTensor, 17 | torch.cuda.FloatTensor, torch.cuda.IntTensor, torch.cuda.LongTensor 18 | ]) 19 | 20 | def asnd(x): 21 | """Convert torch/numpy to numpy.""" 22 | if isinstance(x, np.ndarray): 23 | return x 24 | if isinstance(x, Variable): 25 | x = x.data 26 | if isinstance(x, (torch.cuda.FloatTensor, torch.cuda.DoubleTensor, torch.cuda.IntTensor)): 27 | x = x.cpu() 28 | return x.numpy() 29 | 30 | def as_nda(x, transpose_on_convert=None): 31 | """Turns any tensor into an ndarray.""" 32 | if isinstance(x, np.ndarray): 33 | return x 34 | if isinstance(x, list): 35 | return np.array(x) 36 | if isinstance(x, autograd.Variable): 37 | x = x.data 38 | if isinstance(x, torch_tensor_types): 39 | x = x.cpu().numpy() 40 | return np.ascontiguousarray(maybe_transpose(x, transpose_on_convert)) 41 | raise ValueError("{}: can't convert to np.array".format(type(x))) 42 | 43 | def astorch(x, single=True): 44 | """Convert torch/numpy to torch.""" 45 | if isinstance(x, np.ndarray): 46 | if x.dtype == np.dtype("f"): 47 | return torch.FloatTensor(x) 48 | elif x.dtype == np.dtype("d"): 49 | if single: 50 | return torch.FloatTensor(x) 51 | else: 52 | return torch.DoubleTensor(x) 53 | elif x.dtype == np.dtype("i"): 54 | return torch.IntTensor(x) 55 | else: 56 | raise ValueError("{}: unknown np.dtype".format(x.dtype)) 57 | return x 58 | 59 | def as_torch(x, transpose_on_convert=None, single=True): 60 | """Converts any kind of tensor/array into a torch tensor.""" 61 | if isinstance(x, Variable): 62 | return x.data 63 | if isinstance(x, torch_tensor_types): 64 | return x 65 | if isinstance(x, list): 66 | x = np.array(x) 67 | if 
isinstance(x, np.ndarray): 68 | x = maybe_transpose(x, transpose_on_convert) 69 | if x.dtype == np.dtype("f"): 70 | return torch.FloatTensor(x) 71 | elif x.dtype == np.dtype("d"): 72 | if single: 73 | return torch.FloatTensor(x) 74 | else: 75 | return torch.DoubleTensor(x) 76 | elif x.dtype in [np.dtype("i"), np.dtype("int64")]: 77 | return torch.LongTensor(x) 78 | else: 79 | raise ValueError("{} {}: unknown dtype".format(x, x.dtype)) 80 | raise ValueError("{} {}: unknown type".format(x, type(x))) 81 | 82 | def is_tensor(x): 83 | if isinstance(x, Variable): 84 | x = x.data 85 | return isinstance(x, torch_tensor_types) 86 | 87 | def rank(x): 88 | """Return the rank of the ndarray or tensor.""" 89 | if isinstance(x, np.ndarray): 90 | return x.ndim 91 | else: 92 | return x.dim() 93 | 94 | def size(x, i): 95 | """Return the size of dimension i.""" 96 | if isinstance(x, np.ndarray): 97 | return x.shape[i] 98 | else: 99 | return x.size(i) 100 | 101 | def shp(x): 102 | """Returns the shape of a tensor or ndarray as a tuple.""" 103 | if isinstance(x, Variable): 104 | return tuple(x.data.size()) 105 | elif isinstance(x, np.ndarray): 106 | return tuple(x.shape) 107 | elif isinstance(x, torch_tensor_types): 108 | return tuple(x.size()) 109 | else: 110 | raise ValueError("{}: unknown type".format(type(x))) 111 | 112 | def novar(x): 113 | """Turns a variable into a tensor; does nothing for a tensor.""" 114 | if isinstance(x, Variable): 115 | return x.data 116 | return x 117 | 118 | def maybe_transpose(x, axes): 119 | if axes is None: return x 120 | return x.transpose(axes) 121 | 122 | def typeas(x, y): 123 | """Make x the same type as y, for numpy, torch, torch.cuda.""" 124 | assert not isinstance(x, Variable) 125 | if isinstance(y, Variable): 126 | y = y.data 127 | if isinstance(y, np.ndarray): 128 | return asnd(x) 129 | if isinstance(x, np.ndarray): 130 | if isinstance(y, (torch.FloatTensor, torch.cuda.FloatTensor)): 131 | x = torch.FloatTensor(x) 132 | else: 133 | x = 
torch.DoubleTensor(x) 134 | return x.type_as(y) 135 | 136 | def sequence_is_normalized(x, d, eps=1e-3): 137 | """Check whether a batch of sequences BDL is normalized in d.""" 138 | if isinstance(x, Variable): 139 | x = x.data 140 | assert x.dim() == 3 141 | marginal = x.sum(d) 142 | return (marginal - 1.0).abs().lt(eps).all() 143 | 144 | def bhwd2bdhw(images, depth1=False): 145 | images = as_torch(images) 146 | if depth1: 147 | assert len(shp(images)) == 3, shp(images) 148 | images = images.unsqueeze(3) 149 | assert len(shp(images)) == 4, shp(images) 150 | return images.permute(0, 3, 1, 2) 151 | 152 | def bdhw2bhwd(images, depth1=False): 153 | images = as_torch(images) 154 | assert len(shp(images)) == 4, shp(images) 155 | images = images.permute(0, 2, 3, 1) 156 | if depth1: 157 | assert images.size(3) == 1 158 | images = images.select(3, 0) # drop the singleton depth axis 159 | return images 160 | 161 | def reorder(batch, inp, out): 162 | """Reorder the dimensions of the batch from inp to out order. 163 | 164 | E.g. BHWD -> BDHW. 
165 | """ 166 | if inp is None: return batch 167 | if out is None: return batch 168 | assert isinstance(inp, str) 169 | assert isinstance(out, str) 170 | assert len(inp) == len(out), (inp, out) 171 | assert rank(batch) == len(inp), (rank(batch), inp) 172 | result = [inp.find(c) for c in out] 173 | # print ">>>>>>>>>>>>>>>> reorder", result 174 | for x in result: assert x >= 0, result 175 | if is_tensor(batch): 176 | return batch.permute(*result) 177 | elif isinstance(batch, np.ndarray): 178 | return batch.transpose(*result) 179 | 180 | def assign(dest, src, transpose_on_convert=None): 181 | """Resizes the destination and copies the source.""" 182 | src = as_torch(src, transpose_on_convert) 183 | if isinstance(dest, Variable): 184 | dest.data.resize_(*shp(src)).copy_(src) 185 | elif isinstance(dest, torch.Tensor): 186 | dest.resize_(*shp(src)).copy_(src) 187 | else: 188 | raise ValueError("{}: unknown type".format(type(dest))) 189 | 190 | def one_sequence_softmax(x): 191 | """Compute softmax over a sequence; shape is (l, d)""" 192 | y = asnd(x) 193 | assert y.ndim==2, "%s: input should be (length, depth)" % (y.shape,) 194 | l, d = y.shape 195 | y = y - np.amax(y, axis=1)[:, np.newaxis] # subtract the row maximum for numerical stability 196 | y = np.clip(y, -80, 80) 197 | y = np.exp(y) 198 | y = y / np.sum(y, axis=1)[:, np.newaxis] 199 | return typeas(y, x) 200 | 201 | def sequence_softmax(x): 202 | """Compute softmax over a batch of sequences; shape is (b, l, d).""" 203 | y = asnd(x) 204 | assert y.ndim==3, "%s: input should be (batch, length, depth)" % (y.shape,) 205 | for i in range(len(y)): 206 | y[i] = one_sequence_softmax(y[i]) 207 | return typeas(y, x) 208 | 209 | def ctc_align(prob, target): 210 | """Perform CTC alignment on torch sequence batches (using ocrolstm)""" 211 | import cctc 212 | prob_ = prob.cpu() 213 | target = target.cpu() 214 | b, l, d = prob.size() 215 | bt, lt, dt = target.size() 216 | assert bt==b, (bt, b) 217 | assert dt==d, (dt, d) 218 | assert sequence_is_normalized(prob, 2), prob 219 | assert 
sequence_is_normalized(target, 2), target 220 | result = torch.rand(1) 221 | cctc.ctc_align_targets_batch(result, prob_, target) 222 | return typeas(result, prob) 223 | 224 | def ctc_loss(logits, target): 225 | """A CTC loss function for BLD sequence training.""" 226 | assert logits.is_contiguous() 227 | assert target.is_contiguous() 228 | probs = sequence_softmax(logits) 229 | aligned = ctc_align(probs, target) 230 | assert aligned.size()==probs.size(), (aligned.size(), probs.size()) 231 | deltas = aligned - probs 232 | logits.backward(deltas.contiguous()) 233 | return deltas, aligned 234 | 235 | class LearningRateSchedule(object): 236 | def __init__(self, schedule): 237 | if ":" in schedule: 238 | self.learning_rates = [[float(y) for y in x.split(",")] for x in schedule.split(":")] 239 | assert self.learning_rates[0][0] == 0 240 | else: 241 | lr0 = float(schedule) 242 | self.learning_rates = [[0, lr0]] 243 | def __call__(self, count): 244 | _, lr = self.learning_rates[0] 245 | for n, l in self.learning_rates: 246 | if count < n: break 247 | lr = l 248 | return lr 249 | 250 | -------------------------------------------------------------------------------- /imgclass.py: -------------------------------------------------------------------------------- 1 | from numpy import * 2 | import numpy as np 3 | from random import randint, uniform 4 | from pylab import randn 5 | import h5py 6 | import torch 7 | from torch import nn 8 | from torch import optim 9 | import numpy as np 10 | import torchvision 11 | from scipy import ndimage as ndi 12 | import flex 13 | import layers 14 | import time 15 | import copy 16 | import pylab 17 | from IPython import display 18 | import matplotlib.pyplot as plt 19 | 20 | training_figsize = (4, 4) 21 | 22 | def one_hot(classes, nclasses=None, value=1.0): 23 | if nclasses is None: nclasses = 1+np.amax(classes) 24 | targets = torch.FloatTensor(len(classes), nclasses) 25 | targets[:, :] = 0 26 | return targets.scatter(1, classes.reshape(-1, 1), 
value) 27 | 28 | def C(images): 29 | if isinstance(images, np.ndarray): 30 | raise TypeError("accepts only Torch tensors") 31 | if images.dtype == torch.uint8: 32 | return images.type(torch.float)/255.0 33 | elif images.dtype == torch.float: 34 | return images 35 | else: 36 | raise ValueError("unknown dtype: {}".format(images.dtype)) 37 | 38 | def evaluate(model, images, classes, bs=200, return_results=False): 39 | results = [] 40 | errs = 0 41 | total = 0 42 | with torch.no_grad(): 43 | for i in range(0, len(images), bs): 44 | outputs = model.forward(C(images[i:i+bs])) 45 | _, indexes = outputs.max(1) 46 | results.append(indexes) 47 | errs += int((indexes!=classes[i:i+bs]).sum()) 48 | total += outputs.size(0) 49 | erate = float(errs) / total 50 | if return_results: return erate, total, results 51 | return erate 52 | 53 | def train(model, images, classes, ntrain=100000, bs=20, lr=0.001, momentum=0.9, decay=0, mode="ce"): 54 | 55 | with torch.no_grad(): 56 | model.forward(C(images[:bs])) 57 | if mode.lower() in ["ce", "crossentropy"]: 58 | criterion = nn.CrossEntropyLoss() 59 | expand = lambda target, output: target 60 | elif mode.lower() in ["mse", "meansquared", "meansquarederror"]: 61 | criterion = nn.MSELoss() 62 | expand = lambda target, output: one_hot(target, output.size(1)) 63 | optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=decay) 64 | losses = [] 65 | for i in range(ntrain//bs): 66 | with torch.set_grad_enabled(True): 67 | optimizer.zero_grad() 68 | start = randint(0, len(images)-bs) 69 | inputs = C(images[start:start+bs]) 70 | outputs = model(inputs) 71 | target = expand(classes[start:start+bs], outputs) 72 | loss = criterion(outputs, target) 73 | losses.append(float(loss)) 74 | loss.backward() 75 | optimizer.step() 76 | return mean(losses[-100:]) 77 | 78 | class AutoMLP(object): 79 | def __init__(self, make_model, images, classes, test_images, test_classes, 80 | initial_bs=50, 81 | initial_lrs=10**linspace(-6, 2, 20), 82 | 
initial_ntrain=100000, 83 | momentum=0.9, 84 | decay=0, 85 | ntrain=50000, 86 | maxtrain=3e6, 87 | selection_noise=0.05, 88 | mintrain=1e6, 89 | stop_no_improvement=2.0 90 | ): 91 | self.make_model = make_model 92 | self.images = images 93 | self.classes = classes 94 | self.test_images = test_images 95 | self.test_classes = test_classes 96 | self.verbose = False 97 | self.initial_bs = initial_bs 98 | self.initial_lrs = initial_lrs 99 | self.initial_ntrain = int(initial_ntrain) 100 | self.ntrain = int(ntrain) 101 | self.momentum = momentum 102 | self.decay = decay 103 | self.maxrounds = int(maxtrain) // int(ntrain) 104 | self.selection_noise = selection_noise 105 | self.mintrain = int(mintrain) 106 | self.stop_no_improvement = stop_no_improvement 107 | self.best_model = None 108 | 109 | def initial_population(self, make_model): 110 | population = [] 111 | for lr in self.initial_lrs: 112 | model = make_model().cuda() 113 | model.PARAMS = dict(bs=self.initial_bs, lr=lr, id=randint(0, 1000000000)) 114 | model.LOG = [] 115 | population.append(model) 116 | return population 117 | 118 | def selection(self, population, size=4, key="training_loss"): 119 | if len(population) == 0: return [] 120 | for model in population: 121 | model.KEY = (1.0+randn()*self.selection_noise) * model.LOG[-1][key] 122 | population = sorted(population, key=lambda m: m.KEY) 123 | while len(population) > size: del population[-1] 124 | return population 125 | 126 | def mutation(self, old_population, variants=1, variation=0.8): 127 | population = [] 128 | for model in old_population: 129 | population += [model] 130 | for _ in range(variants): 131 | cloned = copy.deepcopy(model) 132 | cloned.PARAMS = dict( 133 | lr = clip(cloned.PARAMS["lr"] * uniform(variation, 1.0/variation), 1e-7, 1e2), 134 | bs = clip(int(cloned.PARAMS["bs"] * uniform(variation, 1.0/variation)), 1, 1000), 135 | id = randint(0, 1000000000) 136 | ) 137 | population += [cloned] 138 | return population 139 | 140 | def is_better(self, 
model, other): 141 | if model is None: return False 142 | if other is None: return True 143 | return model.LOG[-1]["test_loss"] < other.LOG[-1]["test_loss"] 144 | 145 | def train_population(self, population, ntrain=50000, momentum=0.9, verbose=False): 146 | infos = [] 147 | for model in population: 148 | lr, bs = [model.PARAMS[name] for name in "lr bs".split()] 149 | ntrained = 0 if len(model.LOG)==0 else model.LOG[-1]["ntrain"] 150 | training_loss = train(model, self.images, self.classes, lr=lr, bs=bs, 151 | momentum=self.momentum, ntrain=ntrain, decay=self.decay) 152 | test_loss = evaluate(model, self.test_images, self.test_classes) 153 | info = dict( 154 | training_loss=training_loss, 155 | test_loss=test_loss, 156 | lr=lr, 157 | ntrain=ntrain+ntrained, 158 | momentum=momentum, 159 | bs=bs) 160 | if self.verbose: print info 161 | model.LOG += [info] 162 | infos += [info] 163 | if self.is_better(model, self.best_model): 164 | self.best_model = model 165 | return infos 166 | 167 | def to(self, device): 168 | for model in self.population: 169 | model.to(device) 170 | 171 | def cpu(self): 172 | for model in self.population: 173 | model.cpu() 174 | 175 | def train(self): 176 | self.fig = plt.figure(figsize=training_figsize) 177 | self.fig.add_subplot(1,1,1) 178 | self.ax = self.fig.get_axes()[0] 179 | self.infos = [] 180 | self.best_model = None 181 | self.population = self.initial_population(self.make_model) 182 | initial_infos = self.train_population(self.population, ntrain=self.initial_ntrain) 183 | # infos += initial_infos 184 | self.population = self.selection(self.population) 185 | for r in xrange(self.maxrounds): 186 | old_population = [copy.deepcopy(model) for model in self.population] 187 | self.population = self.mutation(self.population) 188 | self.infos += self.train_population(self.population, ntrain=self.ntrain) 189 | self.population = self.selection(self.population + old_population) 190 | # information display 191 | if len(self.infos)>0: self.display() 
192 | maxtrained = amax([m.LOG[-1]["ntrain"] for m in self.population]) 193 | l = self.best_model.LOG[-1] 194 | print "#best", l["test_loss"], "@", l["ntrain"], "of", maxtrained 195 | last_best = l["ntrain"] 196 | if maxtrained > self.mintrain and maxtrained > self.stop_no_improvement * last_best: 197 | print "# stopping b/c no improvement" 198 | break 199 | self.cpu() 200 | return self.population[0].cpu() 201 | 202 | def display(self, key="training_loss", yscale="log", ylim=None): 203 | self.ax.cla() 204 | self.ax.set_yscale(yscale) 205 | if ylim is not None: self.ax.set_ylim(ylim) 206 | self.ax.scatter(*zip(*[(l["ntrain"], l[key]) for l in self.infos])) 207 | display.clear_output(wait=True) 208 | display.display(self.fig) 209 | -------------------------------------------------------------------------------- /imgimg.py: -------------------------------------------------------------------------------- 1 | from numpy import * 2 | import numpy as np 3 | from random import randint, uniform 4 | from pylab import randn 5 | import h5py 6 | import torch 7 | from torch import nn 8 | from torch import optim 9 | import torchvision 10 | from scipy import ndimage as ndi 11 | import flex 12 | import layers 13 | import time 14 | import copy 15 | import pylab 16 | import matplotlib.pyplot as plt 17 | from IPython import display 18 | 19 | training_figsize = (4,4) 20 | 21 | 22 | def evaluate(model, images, targets, bs=200): 23 | assert images.dtype == torch.float 24 | assert targets.dtype == torch.float 25 | losses = [] 26 | criterion = nn.MSELoss() # same criterion as train() below 27 | with torch.no_grad(): 28 | for i in range(0, len(images), bs): 29 | outputs = model.forward(images[i:i+bs].type(torch.float)) 30 | loss = criterion(outputs, targets[i:i+bs]) 31 | losses.append(float(loss)) 32 | return mean(losses) 33 | 34 | def train(model, images, targets, ntrain=100000, bs=20, lr=0.001, momentum=0.9, decay=0.0): 35 | assert images.dtype == torch.float 36 | assert targets.dtype == torch.float 37 | with 
torch.no_grad(): 38 | model.forward(images[:bs].type(torch.float)) 39 | criterion = nn.MSELoss() 40 | optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=decay) 41 | losses = [] 42 | for i in range(ntrain//bs): 43 | with torch.set_grad_enabled(True): 44 | optimizer.zero_grad() 45 | start = randint(0, len(images)-bs) 46 | inputs = images[start:start+bs].type(torch.float) 47 | outputs = model(inputs) 48 | if isinstance(outputs, tuple): 49 | loss = criterion(outputs[0], targets[start:start+bs].type(torch.float)) + \ 50 | outputs[1] 51 | else: 52 | loss = criterion(outputs, targets[start:start+bs].type(torch.float)) 53 | losses.append(float(loss)) 54 | loss.backward() 55 | optimizer.step() 56 | return losses 57 | 58 | 59 | class AutoMLP(object): 60 | def __init__(self, make_model, images, targets, 61 | initial_bs=50, 62 | initial_lrs=10**linspace(-6, 2, 20), 63 | initial_ntrain=10000, 64 | momentum=0.9, 65 | ntrain=10000, 66 | maxtrain=1e6, 67 | selection_noise=0.05, 68 | mintrain=1e6, 69 | stop_no_improvement=2.0, 70 | decay=1e-6 71 | ): 72 | self.make_model = make_model 73 | self.images = images 74 | self.targets = targets 75 | self.verbose = False 76 | self.initial_bs = initial_bs 77 | self.initial_lrs = initial_lrs 78 | self.initial_ntrain = int(initial_ntrain) 79 | self.ntrain = int(ntrain) 80 | self.momentum = momentum 81 | self.decay = decay 82 | self.maxrounds = int(maxtrain) // int(ntrain) 83 | self.selection_noise = selection_noise 84 | self.mintrain = int(mintrain) 85 | self.stop_no_improvement = stop_no_improvement 86 | self.best_model = None 87 | 88 | def initial_population(self, make_model): 89 | population = [] 90 | for lr in self.initial_lrs: 91 | model = make_model().cuda() 92 | model.PARAMS = dict(bs=self.initial_bs, lr=lr, id=randint(0, 1000000000)) 93 | model.LOG = [] 94 | population.append(model) 95 | return population 96 | 97 | def selection(self, population, size=4, key="training_loss"): 98 | if len(population) == 0: 
return [] 99 | for model in population: 100 | model.KEY = (1.0+randn()*self.selection_noise) * model.LOG[-1][key] 101 | population = sorted(population, key=lambda m: m.KEY) 102 | while len(population) > size: del population[-1] 103 | return population 104 | 105 | def mutation(self, old_population, variants=1, variation=0.8): 106 | population = [] 107 | for model in old_population: 108 | population += [model] 109 | for _ in range(variants): 110 | cloned = copy.deepcopy(model) 111 | cloned.PARAMS = dict( 112 | lr = clip(cloned.PARAMS["lr"] * uniform(variation, 1.0/variation), 1e-7, 1e2), 113 | bs = clip(int(cloned.PARAMS["bs"] * uniform(variation, 1.0/variation)), 1, 1000), 114 | id = randint(0, 1000000000) 115 | ) 116 | population += [cloned] 117 | return population 118 | 119 | def train_population(self, population, ntrain=50000, momentum=0.9, verbose=False): 120 | infos = [] 121 | for model in population: 122 | lr, bs = [model.PARAMS[name] for name in "lr bs".split()] 123 | ntrained = 0 if len(model.LOG)==0 else model.LOG[-1]["ntrain"] 124 | training_loss = train(model, self.images, self.targets, lr=lr, bs=bs, 125 | momentum=self.momentum, 126 | decay=self.decay, 127 | ntrain=ntrain) 128 | if isinstance(training_loss, list): 129 | training_loss = mean(training_loss[-100:]) 130 | info = dict( 131 | training_loss=training_loss, 132 | lr=lr, 133 | ntrain=ntrain+ntrained, 134 | momentum=momentum, 135 | bs=bs) 136 | if self.verbose: print info 137 | model.LOG += [info] 138 | infos += [info] 139 | return infos 140 | 141 | def to(self, device): 142 | for model in self.population: 143 | model.to(device) 144 | 145 | def cpu(self): 146 | for model in self.population: 147 | model.cpu() 148 | 149 | def train(self): 150 | self.fig = plt.figure(figsize=training_figsize) 151 | self.fig.add_subplot(1,1,1) 152 | self.ax = self.fig.get_axes()[0] 153 | self.infos = [] 154 | self.population = self.initial_population(self.make_model) 155 | initial_infos = 
self.train_population(self.population, ntrain=self.initial_ntrain) 156 | # infos += initial_infos 157 | self.population = self.selection(self.population) 158 | for r in xrange(self.maxrounds): 159 | old_population = [copy.deepcopy(model) for model in self.population] 160 | self.population = self.mutation(self.population) 161 | self.infos += self.train_population(self.population, ntrain=self.ntrain) 162 | self.population = self.selection(self.population + old_population) 163 | # information display 164 | if len(self.infos)>0: self.display() 165 | self.cpu() 166 | self.fig.clf() 167 | return self.best() 168 | 169 | def display(self, key="training_loss", yscale="log", ylim=None): 170 | self.ax.cla() 171 | self.ax.set_yscale(yscale) 172 | if ylim is not None: self.ax.set_ylim(ylim) 173 | self.ax.scatter(*zip(*[(l["ntrain"], l["training_loss"]) for l in self.infos])) 174 | display.clear_output(wait=True) 175 | display.display(self.fig) 176 | 177 | def best(self): 178 | return self.population[0].cuda() 179 | 180 | def info(self, model=None): 181 | for i, key in enumerate("training_loss lr bs".split()): 182 | pylab.subplot(2, 2, i+1) 183 | pylab.plot(*zip(*[(l["ntrain"], l[key]) for l in model.LOG])) 184 | -------------------------------------------------------------------------------- /init.py: -------------------------------------------------------------------------------- 1 | import scipy 2 | import scipy.ndimage as ndi 3 | 4 | rc("image", cmap="gray", interpolation="bicubic") 5 | figsize(10, 10) 6 | 7 | import dlinputs as dli 8 | import dltrainers as dlt 9 | import dlinputs.filters as dlf 10 | -------------------------------------------------------------------------------- /layers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017 NVIDIA CORPORATION. All rights reserved. 2 | # See the LICENSE file for licensing terms (BSD-style). 
3 | 4 | import numpy as np 5 | import torch 6 | from torch import nn 7 | from torch import autograd 8 | from torch.legacy import nn as legnn 9 | import numpy as np 10 | import torch 11 | from torch import nn 12 | from torch import autograd 13 | from torch.legacy import nn as legnn 14 | 15 | BD = "BD" 16 | LBD = "LBD" 17 | LDB = "LDB" 18 | BDL = "BDL" 19 | BLD = "BLD" 20 | BWHD = "BWHD" 21 | BDWH = "BDWH" 22 | BWH = "BWH" 23 | 24 | def deprecated(f): 25 | def g(*args, **kw): 26 | raise Exception("deprecated") 27 | return g 28 | 29 | def lbd2bdl(x): 30 | assert len(x.size()) == 3 31 | return x.permute(1, 2, 0).contiguous() 32 | 33 | 34 | def bdl2lbd(x): 35 | assert len(x.size()) == 3 36 | return x.permute(2, 0, 1).contiguous() 37 | 38 | def data(x): 39 | return x 40 | 41 | @deprecated 42 | def typeas(x, y): 43 | """Make x the same type as y, for numpy, torch, torch.cuda.""" 44 | if isinstance(y, np.ndarray): 45 | return asnd(x) 46 | if isinstance(x, np.ndarray): 47 | if isinstance(y, (torch.FloatTensor, torch.cuda.FloatTensor)): 48 | x = torch.FloatTensor(x) 49 | else: 50 | x = torch.DoubleTensor(x) 51 | return x.type_as(y) 52 | 53 | class Fun(nn.Module): 54 | def __init__(self, f, info=None): 55 | nn.Module.__init__(self) 56 | assert isinstance(f, str) 57 | self.f = eval(f) 58 | self.f_str = f 59 | self.info = info 60 | def __getnewargs__(self): 61 | return (self.f_str, self.info) 62 | def forward(self, x): 63 | return self.f(x) 64 | def __repr__(self): 65 | return "Fun {} {}".format(self.info, self.f) 66 | 67 | class PixelsToBatch(nn.Module): 68 | def forward(self, x): 69 | b, d, h, w = x.size() 70 | return x.permute(0, 2, 3, 1).contiguous().view(b*h*w, d) 71 | 72 | class WeightedGrad(autograd.Function): 73 | def forward(self, input, weights): 74 | self.weights = weights 75 | return input 76 | def backward(self, grad_output): 77 | return grad_output * self.weights, None 78 | 79 | def weighted_grad(x, y): 80 | return WeightedGrad()(x, y) 81 | 82 | class 
Info(nn.Module): 83 | def __init__(self, info="", count=1, mod=1): 84 | nn.Module.__init__(self) 85 | self.mod = mod 86 | self.count = count 87 | self.steps = 0 88 | self.outputs = 0 89 | self.info = info 90 | def forward(self, x): 91 | if self.outputs < self.count: 92 | if self.steps % self.mod == 0: 93 | print "Info", self.info, x.size(), float(x.min()), float(x.max()) 94 | self.outputs += 1 95 | self.steps += 1 96 | return x 97 | def __repr__(self): 98 | return "Info {}".format(self.info) 99 | 100 | class CheckSizes(nn.Module): 101 | def __init__(self, *args, **kw): 102 | nn.Module.__init__(self) 103 | self.order = kw.get("order") 104 | self.name = kw.get("name") 105 | self.limits = [(x, x) if isinstance(x, int) else x for x in args] 106 | def forward(self, x): 107 | for (i, actual), (lo, hi) in zip(enumerate(tuple(x.size())), self.limits): 108 | if actual < lo: 109 | raise Exception("{} ({}): index {} too low ({} not >= {})" 110 | .format(self.name, self.order, 111 | i, actual, lo)) 112 | if actual > hi: 113 | raise Exception("{} ({}): index {} too high ({} not <= {})" 114 | .format(self.name, self.order, 115 | i, actual, hi)) 116 | return x 117 | 118 | def __repr__(self): 119 | return "CheckSizes {}".format(self.limits) 120 | 121 | 122 | class AutoDevice(nn.Module): 123 | def __init__(self): 124 | nn.Module.__init__(self) 125 | self.dummy = nn.Parameter(torch.zeros(1)) 126 | def forward(self, x): 127 | return x.type(type(self.dummy.data)) 128 | def __repr__(self): 129 | return "AutoDevice:{}:{}".format(type(self.dummy.data), self.dummy.device) 130 | 131 | class Cpu(nn.Module): 132 | def __init__(self): 133 | nn.Module.__init__(self) 134 | def forward(self, x): 135 | return x.cpu() 136 | def __repr__(self): 137 | return "Cpu" 138 | 139 | class Check(nn.Module): 140 | def __init__(self, *shape, **kw): 141 | nn.Module.__init__(self) 142 | self.expected = tuple(shape) 143 | self.valid = kw.get("valid", (-1e-5, 1+1e-5)) 144 | def forward(self, x): 145 | 
expected_shape = self.expected 146 | actual_shape = tuple(x.size()) 147 | assert len(actual_shape)==len(expected_shape) 148 | for i in range(len(actual_shape)): 149 | assert expected_shape[i]<0 or expected_shape[i]==actual_shape[i], \ 150 | (expected_shape, actual_shape, i) 151 | assert data(x).min() >= self.valid[0], (data(x).min(), self.valid) 152 | assert data(x).max() <= self.valid[1], (data(x).max(), self.valid) 153 | return x 154 | 155 | class Reorder(nn.Module): 156 | def __init__(self, old, new): 157 | self.old = old 158 | self.new = new 159 | nn.Module.__init__(self) 160 | self.permutation = tuple([old.find(c) for c in new]) 161 | def forward(self, x): 162 | return x.permute(*self.permutation).contiguous() 163 | def __repr__(self): 164 | return "Reorder {}->{}".format(self.old, self.new) 165 | 166 | class Permute(nn.Module): 167 | def __init__(self, *args): 168 | nn.Module.__init__(self) 169 | self.permutation = args 170 | def forward(self, x): 171 | return x.permute(*self.permutation).contiguous() 172 | def __repr__(self): 173 | return "Permute({})".format(self.permutation) 174 | 175 | class Reshape(nn.Module): 176 | def __init__(self, *args): 177 | nn.Module.__init__(self) 178 | self.shape = args 179 | def forward(self, x): 180 | newshape = [] 181 | for s in self.shape: 182 | if isinstance(s, int): 183 | newshape.append(int(x.size(s))) 184 | elif isinstance(s, (tuple, list)): 185 | total = 1 186 | for j in s: 187 | total *= int(x.size(j)) 188 | newshape.append(total) 189 | else: 190 | raise ValueError("shape spec must be either int or tuple, got {}".format(s)) 191 | return x.view(*newshape) 192 | def __repr__(self): 193 | return "Reshape({})".format(self.shape) 194 | 195 | class Viewer(nn.Module): 196 | def __init__(self, *args): 197 | nn.Module.__init__(self) 198 | self.shape = args 199 | 200 | def forward(self, x): 201 | return x.view(*self.shape) 202 | 203 | def __repr__(self): 204 | return "Viewer %s" % (self.shape,) 205 | 206 | class 
Norm(nn.Module): 207 | def __init__(self, r=2): 208 | nn.Module.__init__(self) 209 | self.r = r 210 | 211 | def forward(self, x): 212 | assert x.ndimension() == 2 213 | r = self.r 214 | return x / (((x.abs()**r).sum(1))**(1.0/r)).unsqueeze(1) 215 | 216 | def __repr__(self): 217 | return "Norm-{}".format(self.r) 218 | 219 | class Flat(nn.Module): 220 | def __init__(self): 221 | nn.Module.__init__(self) 222 | 223 | def forward(self, x): 224 | rank = len(x.size()) 225 | assert rank > 2 226 | new_depth = np.prod(tuple(x.size())[1:]) 227 | return x.view(-1, new_depth) 228 | 229 | def __repr__(self): 230 | return "Flat" 231 | 232 | class Img2FlatSum(nn.Module): 233 | input_order = BDWH 234 | output_order = BD 235 | 236 | def __init__(self): 237 | nn.Module.__init__(self) 238 | 239 | def forward(self, img): 240 | # BDWH -> BD 241 | return img.sum(3).sum(2) 242 | 243 | def __repr__(self): 244 | return "Img2FlatSum" 245 | 246 | class Img2FlatMax(nn.Module): 247 | input_order = BDWH 248 | output_order = BD 249 | 250 | def __init__(self): 251 | nn.Module.__init__(self) 252 | 253 | def forward(self, img): 254 | # BDWH -> BD 255 | return img.max(3)[0].max(2)[0] 256 | 257 | def __repr__(self): 258 | return "Img2FlatMax" 259 | 260 | class Textline2Img(nn.Module): 261 | input_order = BWH 262 | output_order = BDWH 263 | 264 | def __init__(self): 265 | nn.Module.__init__(self) 266 | 267 | def forward(self, seq): 268 | b, l, d = seq.size() 269 | return seq.view(b, 1, l, d) 270 | 271 | def __repr__(self): 272 | return "Textline2Img" 273 | 274 | 275 | class Img2Seq(nn.Module): 276 | input_order = BDWH 277 | output_order = BDL 278 | 279 | def __init__(self): 280 | nn.Module.__init__(self) 281 | 282 | def forward(self, img): 283 | b, d, w, h = img.size() 284 | perm = img.permute(0, 1, 3, 2).contiguous() 285 | return perm.view(b, d * h, w) 286 | 287 | def __repr__(self): 288 | return "Img2Seq" 289 | 290 | class ImgMaxSeq(nn.Module): 291 | input_order = BDWH 292 | output_order = BDL 293 | 
294 | def __init__(self): 295 | nn.Module.__init__(self) 296 | 297 | def forward(self, img): 298 | # BDWH -> BDW -> BWD 299 | return img.max(3)[0].squeeze(3) 300 | 301 | def __repr__(self): 302 | return "ImgMaxSeq" 303 | 304 | class ImgSumSeq(nn.Module): 305 | input_order = BDWH 306 | output_order = BDL 307 | 308 | def __init__(self): 309 | nn.Module.__init__(self) 310 | 311 | def forward(self, img): 312 | # BDWH -> BDW -> BWD 313 | return img.sum(3)[0].squeeze(3).permute(0, 2, 1).contiguous() 314 | 315 | def __repr__(self): 316 | return "ImgSumSeq" 317 | 318 | 319 | class LSTM1(nn.Module): 320 | """A simple bidirectional LSTM. 321 | 322 | All the sequence processing layers use BDL order by default to 323 | be consistent with 1D convolutions. 324 | """ 325 | input_order = BDL 326 | output_order = BDL 327 | 328 | def __init__(self, ninput=None, noutput=None, ndir=2): 329 | nn.Module.__init__(self) 330 | assert ninput is not None 331 | assert noutput is not None 332 | self.ndir = ndir 333 | self.ninput = ninput 334 | self.noutput = noutput 335 | self.lstm = nn.LSTM(ninput, noutput, 1, bidirectional=self.ndir - 1) 336 | self.lstm.flatten_parameters() 337 | 338 | def forward(self, seq): 339 | seq = bdl2lbd(seq) 340 | l, bs, d = seq.size() 341 | assert d == self.ninput, seq.size() 342 | h0 = torch.zeros(self.ndir, bs, self.noutput, dtype=seq.dtype) 343 | c0 = torch.zeros(self.ndir, bs, self.noutput, dtype=seq.dtype) 344 | post_lstm, _ = self.lstm(seq, (h0, c0)) 345 | return lbd2bdl(post_lstm) 346 | 347 | def __repr__(self): 348 | return "LSTM1:"+self.lstm.__repr__() 349 | 350 | 351 | 352 | 353 | class LSTM2to1(nn.Module): 354 | """An LSTM that summarizes one dimension.""" 355 | input_order = BDWH 356 | output_order = BDL 357 | 358 | def __init__(self, ninput=None, noutput=None): 359 | nn.Module.__init__(self) 360 | self.ninput = ninput 361 | self.noutput = noutput 362 | self.lstm = nn.LSTM(ninput, noutput, 1, bidirectional=False) 363 | self.lstm.flatten_parameters() 364 
| 365 | def forward(self, img): 366 | # BDWH -> HBWD -> HBsD 367 | b, d, w, h = img.size() 368 | seq = img.permute(3, 0, 2, 1).contiguous().view(h, b * w, d) 369 | bs = b * w 370 | h0 = torch.zeros(1, bs, self.noutput, dtype=img.dtype) 371 | c0 = torch.zeros(1, bs, self.noutput, dtype=img.dtype) 372 | # HBsD -> HBsD 373 | assert seq.size() == (h, b * w, d), (seq.size(), (h, b * w, d)) 374 | post_lstm, _ = self.lstm(seq, (h0, c0)) 375 | assert post_lstm.size() == (h, b * w, self.noutput), (post_lstm.size(), 376 | (h, b * w, self.noutput)) 377 | # HBsD -> BsD -> BWD 378 | final = post_lstm.select(0, h - 1).view(b, w, self.noutput) 379 | assert final.size() == (b, w, self.noutput), (final.size(), (b, w, self.noutput)) 380 | # BWD -> BDW 381 | final = final.permute(0, 2, 1).contiguous() 382 | assert final.size() == (b, self.noutput, w), (final.size(), 383 | (b, self.noutput, w)) 384 | return final 385 | 386 | 387 | class LSTM1to0(nn.Module): 388 | """An LSTM that summarizes one dimension.""" 389 | input_order = BDL 390 | output_order = BD 391 | 392 | def __init__(self, ninput=None, noutput=None): 393 | nn.Module.__init__(self) 394 | self.ninput = ninput 395 | self.noutput = noutput 396 | self.lstm = nn.LSTM(ninput, noutput, 1, bidirectional=False) 397 | self.lstm.flatten_parameters() 398 | 399 | def forward(self, seq): 400 | seq = bdl2lbd(seq) 401 | l, b, d = seq.size() 402 | assert d == self.ninput, (d, self.ninput) 403 | h0 = torch.zeros(1, b, self.noutput, dtype=seq.dtype) 404 | c0 = torch.zeros(1, b, self.noutput, dtype=seq.dtype) 405 | assert seq.size() == (l, b, d) 406 | post_lstm, _ = self.lstm(seq, (h0, c0)) 407 | assert post_lstm.size() == (l, b, self.noutput) 408 | final = post_lstm.select(0, l - 1).view(b, self.noutput) 409 | return final 410 | 411 | 412 | class RowwiseLSTM(nn.Module): 413 | def __init__(self, ninput=None, noutput=None, ndir=2): 414 | nn.Module.__init__(self) 415 | self.ndir = ndir 416 | self.ninput = ninput 417 | self.noutput =
noutput 418 | self.lstm = nn.LSTM(ninput, noutput, 1, bidirectional=self.ndir - 1) 419 | self.lstm.flatten_parameters() 420 | 421 | def forward(self, img): 422 | b, d, h, w = img.size() 423 | # BDHW -> WHBD -> WB'D 424 | seq = img.permute(3, 2, 0, 1).contiguous().view(w, h * b, d) 425 | # WB'D 426 | h0 = torch.zeros(self.ndir, h * b, self.noutput, dtype=img.dtype) 427 | c0 = torch.zeros(self.ndir, h * b, self.noutput, dtype=img.dtype) 428 | seqresult, _ = self.lstm(seq, (h0, c0)) 429 | # WB'D' -> BD'HW 430 | result = seqresult.view( 431 | w, h, b, self.noutput * self.ndir).permute(2, 3, 1, 0) 432 | return result 433 | 434 | 435 | class LSTM2(nn.Module): 436 | """A 2D LSTM module.""" 437 | 438 | def __init__(self, ninput=None, noutput=None, nhidden=None, ndir=2): 439 | nn.Module.__init__(self) 440 | assert ndir in [1, 2] 441 | nhidden = nhidden or noutput 442 | self.hlstm = RowwiseLSTM(ninput, nhidden, ndir=ndir) 443 | self.vlstm = RowwiseLSTM(nhidden * ndir, noutput, ndir=ndir) 444 | 445 | def forward(self, img): 446 | horiz = self.hlstm(img) 447 | horizT = horiz.permute(0, 1, 3, 2).contiguous() 448 | vert = self.vlstm(horizT) 449 | vertT = vert.permute(0, 1, 3, 2).contiguous() 450 | return vertT 451 | 452 | -------------------------------------------------------------------------------- /nb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tmbdev-tutorials/dl-2018/7b42601e90a9cad1d0fb8d399c55c178561313ab/nb.png -------------------------------------------------------------------------------- /note.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tmbdev-tutorials/dl-2018/7b42601e90a9cad1d0fb8d399c55c178561313ab/note.png -------------------------------------------------------------------------------- /roadmodel.py: -------------------------------------------------------------------------------- 1 | from numpy import * 2 | from 
pylab import rand, randn, randint 3 | import scipy 4 | import scipy.ndimage as ndi 5 | 6 | 7 | def spatial_sampler(image): 8 | h, w = image.shape 9 | probs = image.reshape(h*w) 10 | assert amin(probs) >= 0 11 | probs = probs * 1.0 / sum(probs) 12 | probs = add.accumulate(probs) 13 | def f(): 14 | x = rand() 15 | i = searchsorted(probs, x) 16 | return (i//w, i%w) 17 | return f 18 | 19 | class RoadModel(object): 20 | def __init__(self, n=256, background=0.01, sigma=10.0): 21 | self.n = n 22 | self.background = background 23 | self.minsize = 3 24 | self.scale = 0.05 25 | 26 | obstacle_prior = zeros([n, n]) 27 | xs, ys = meshgrid(linspace(-1, 1, n), linspace(-1, 1, n)) 28 | prior = 1.0 * maximum((xs>0.5+0.5*ys), (xs<-0.5-0.5*ys)) 29 | self.road_map = prior 30 | 31 | prior = ndi.gaussian_filter(prior, sigma) 32 | prior -= amin(prior); prior /= amax(prior) 33 | prior = maximum(prior, self.background) 34 | self.prior = prior 35 | 36 | self.xs, self.ys = meshgrid(range(n), range(n)) 37 | self.sampler = spatial_sampler(self.prior) 38 | 39 | def sampled_image(self, k=100000): 40 | n = self.n 41 | result = zeros([n, n]) 42 | for _ in xrange(k): 43 | i, j = self.sampler() 44 | result[i, j] += 1.0 45 | return result 46 | 47 | def sample(self, k): 48 | if isinstance(k, tuple): k = randint(*k) 49 | return [self.sampler() for _ in range(k)] 50 | 51 | def render(self, samples): 52 | n = self.n 53 | target = zeros([n, n], 'f') 54 | for i, j in samples: 55 | target[i, j] = 1.0 56 | d = ndi.distance_transform_cdt(target==0) 57 | target = 1.0*(d < maximum(self.minsize, self.ys*self.scale)) 58 | return target 59 | 60 | def road_truth(self, k=(5, 20)): 61 | return self.render(self.sample(k)) 62 | 63 | def sense(self, target, fg=0.2, bg=0.02): 64 | n = self.n 65 | return 1.0*maximum(rand(n, n) < fg * target, rand(n, n) < bg) 66 | -------------------------------------------------------------------------------- /sicon.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/tmbdev-tutorials/dl-2018/7b42601e90a9cad1d0fb8d399c55c178561313ab/sicon.png -------------------------------------------------------------------------------- /summary.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tmbdev-tutorials/dl-2018/7b42601e90a9cad1d0fb8d399c55c178561313ab/summary.png -------------------------------------------------------------------------------- /texture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tmbdev-tutorials/dl-2018/7b42601e90a9cad1d0fb8d399c55c178561313ab/texture.png -------------------------------------------------------------------------------- /texture3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tmbdev-tutorials/dl-2018/7b42601e90a9cad1d0fb8d399c55c178561313ab/texture3.png -------------------------------------------------------------------------------- /texture4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tmbdev-tutorials/dl-2018/7b42601e90a9cad1d0fb8d399c55c178561313ab/texture4.png -------------------------------------------------------------------------------- /trainers.py: -------------------------------------------------------------------------------- 1 | # copyright (c) 2017 NVIDIA CORPORATION. All rights reserved. 2 | # See the LICENSE file for licensing terms (BSD-style). 
3 | 4 | """A set of "trainers", classes that wrap around Torch models 5 | and provide methods for training and evaluation.""" 6 | 7 | import time 8 | import types 9 | import platform 10 | import numpy as np 11 | import torch 12 | from torch import autograd, nn, optim 13 | from torch.autograd import Variable, Function 14 | import torch.nn.functional as F 15 | from scipy import ndimage 16 | import helpers as dlh 17 | 18 | def add_log(log, logname, **kw): 19 | entry = dict(kw, __log__=logname, __at__=time.time(), __node__=platform.node()) 20 | log.append(entry) 21 | 22 | def get_log(log, logname, **kw): 23 | records = [x for x in log if x.get("__log__")==logname] 24 | return records 25 | 26 | def update_display(): 27 | from matplotlib import pyplot 28 | from IPython import display 29 | display.clear_output(wait=True) 30 | display.display(pyplot.gcf()) 31 | 32 | class Weighted(Function): 33 | def forward(self, x, weights): 34 | self.saved_for_backward = [weights] 35 | return x 36 | def backward(self, grad_output): 37 | weights, = self.saved_for_backward 38 | grad_input = weights * grad_output 39 | return grad_input, None 40 | 41 | 42 | class BasicTrainer(object): 43 | """Trainers take care of bookkeeping for training models. 44 | 45 | The basic method is `train_batch(inputs, targets)`. It catches errors 46 | during forward propagation and reports the model and input shapes 47 | (shape mismatches are the most common source of errors). 48 | 49 | Trainers are just a temporary tool that's wrapped around a model 50 | for training purposes, so you can create, use, and discard them 51 | as convenient.
52 | """ 53 | 54 | def __init__(self, model, use_cuda=True, 55 | fields = ("input", "output"), 56 | input_axes = None, 57 | output_axes = None): 58 | self.use_cuda = use_cuda 59 | self.model = self._cuda(model) 60 | self.init_loss() 61 | self.input_name, self.output_name = fields 62 | self.no_display = False 63 | self.current_lr = None 64 | self.optimizer = None 65 | self.weighted = Weighted() 66 | self.ntrain = 0 67 | self.log = [] 68 | 69 | def _cuda(self, x): 70 | """Convert object to CUDA if use_cuda==True.""" 71 | if self.use_cuda: 72 | return x.cuda() 73 | else: 74 | return x.cpu() 75 | 76 | def set_training(self, mode=True): 77 | """Set training or prediction mode.""" 78 | if mode: 79 | if not self.model.training: 80 | self.model.train() 81 | self.cuinput = autograd.Variable( 82 | torch.randn(1, 1, 100, 100).cuda()) 83 | self.cutarget = autograd.Variable(torch.randn(1, 11).cuda()) 84 | else: 85 | if self.model.training: 86 | self.model.eval() 87 | self.cuinput = autograd.Variable(torch.randn(1, 1, 100, 100).cuda(), 88 | volatile=True) 89 | self.cutarget = autograd.Variable(torch.randn(1, 11).cuda(), 90 | volatile=True) 91 | 92 | def set_lr(self, lr, momentum=0.9, weight_decay=0.0): 93 | """Set the optimizer to SGD with the given parameters.""" 94 | self.current_lr = lr 95 | self.optimizer = optim.SGD(self.model.parameters(), 96 | lr=lr, 97 | momentum=momentum, 98 | weight_decay=weight_decay) 99 | 100 | def get_outputs(self): 101 | """Performs any necessary transformations on the output tensor. 102 | """ 103 | return dlh.novar(self.cuoutput).cpu() 104 | 105 | def set_inputs(self, batch): 106 | """Sets the cuinput variable from the input data. 107 | """ 108 | assert isinstance(batch, torch.Tensor) 109 | dlh.assign(self.cuinput, batch) 110 | 111 | def set_targets(self, targets, weights=None): 112 | """Sets the cutarget variable from the given tensor. 
113 | """ 114 | dlh.assign(self.cutarget, targets, False) 115 | assert self.cuoutput.size() == self.cutarget.size() 116 | if weights is not None: 117 | dlh.assign(self.cuweights, weights, False) 118 | assert self.cuoutput.size() == self.cuweights.size() 119 | else: 120 | self.cuweights = None 121 | 122 | def init_loss(self, loss=nn.MSELoss()): 123 | self.criterion = self._cuda(loss) 124 | 125 | def compute_loss(self, targets, weights=None): 126 | self.set_targets(targets, weights=weights) 127 | return self.criterion(self.cuoutput, self.cutarget) 128 | 129 | def forward(self): 130 | try: 131 | self.cuoutput = self.model(self.cuinput) 132 | except RuntimeError as err: 133 | print "runtime error in forward step:" 134 | print "input", self.cuinput.size() 135 | raise err 136 | 137 | def train_batch(self, inputs, targets, weights=None, update=True, logname="train"): 138 | if update: 139 | self.set_training(True) 140 | self.optimizer.zero_grad() 141 | else: 142 | self.set_training(False) 143 | self.set_inputs(inputs) 144 | self.forward() 145 | if weights is not None: 146 | self.cuweights = autograd.Variable(torch.randn(1, 1).cuda()) 147 | dlh.assign(self.cuweights, weights, False) 148 | self.cuoutput = self.weighted(self.cuoutput, self.cuweights) 149 | culoss = self.compute_loss(targets, weights=weights) 150 | if update: 151 | culoss.backward() 152 | self.optimizer.step() 153 | ploss = dlh.novar(culoss)[0] 154 | self.ntrain += dlh.size(inputs, 0) 155 | add_log(self.log, logname, loss=ploss, ntrain=self.ntrain, lr=self.current_lr) 156 | return self.get_outputs(), ploss 157 | 158 | def eval_batch(self, inputs, targets): 159 | return self.train_batch(inputs, targets, update=False, logname="eval") 160 | 161 | def predict_batch(self, inputs): 162 | self.set_training(False) 163 | self.set_inputs(inputs) 164 | self.forward() 165 | return self.get_outputs() 166 | 167 | def loss_curve(self, logname): 168 | records = get_log(self.log, logname) 169 | records = [(x["ntrain"],
x["loss"]) for x in records] 170 | records = sorted(records) 171 | if len(records)==0: 172 | return [], [] 173 | else: 174 | return zip(*records) 175 | 176 | def plot_loss(self, every=100, smooth=1e-2, yscale=None): 177 | if self.no_display: return 178 | # we import these locally to avoid dependence on display 179 | # functions for training 180 | import matplotlib as mpl 181 | from matplotlib import pyplot 182 | from scipy.ndimage import filters 183 | x, y = self.loss_curve("train") 184 | pyplot.plot(x, y) 185 | x, y = self.loss_curve("test") 186 | pyplot.plot(x, y) 187 | 188 | def display_loss(self, *args, **kw): 189 | from matplotlib import pyplot; pyplot.clf() 190 | self.plot_loss(*args, **kw) 191 | update_display() 192 | 193 | def set_sample_fields(self, input_name, output_name): 194 | self.input_name = input_name 195 | self.output_name = output_name 196 | 197 | def train_for(self, training, training_size=1e99): 198 | if isinstance(training, types.FunctionType): 199 | training = training() 200 | count = 0 201 | losses = [] 202 | for batch in training: 203 | if count >= training_size: break 204 | input_tensor = batch[self.input_name] 205 | output_tensor = batch[self.output_name] 206 | _, loss = self.train_batch(input_tensor, output_tensor) 207 | count += len(input_tensor) 208 | losses.append(loss) 209 | loss = np.mean(losses) 210 | return loss, count 211 | 212 | def eval_for(self, testset, testset_size=1e99): 213 | if isinstance(testset, types.FunctionType): 214 | testset = testset() 215 | count = 0 216 | losses = [] 217 | for batch in testset: 218 | if count >= testset_size: break 219 | input_tensor = batch[self.input_name] 220 | output_tensor = batch[self.output_name] 221 | _, loss = self.eval_batch(input_tensor, output_tensor) 222 | count += len(input_tensor) 223 | losses.append(loss) 224 | loss = np.mean(losses) 225 | return loss, count 226 | 227 | class ImageClassifierTrainer(BasicTrainer): 228 | def __init__(self, *args, **kw): 229 | BasicTrainer.__init__(self, *args, **kw) 230 | 231 |
    def set_inputs(self, images, depth1=False):
        # Permute BHWD images to BDHW while copying them to the input buffer.
        dlh.assign(self.cuinput, images, transpose_on_convert=(0, 3, 1, 2))

    def set_targets(self, targets, weights=None):
        assert weights is None, "weights not implemented"
        if isinstance(targets, list):
            targets = np.array(targets)
        if dlh.rank(targets) == 1:
            # Rank-1 targets are class indices: expand them to one-hot rows.
            targets = dlh.as_torch(targets)
            targets = targets.unsqueeze(1)
            b, c = dlh.shp(self.cuoutput)
            onehot = torch.zeros(b, c)
            onehot.scatter_(1, targets, 1)
            dlh.assign(self.cutarget, onehot)
        else:
            # Otherwise the targets must already match the output shape.
            assert dlh.shp(targets) == dlh.shp(self.cuoutput)
            dlh.assign(self.cutarget, targets)


def zoom_like(batch, target_shape, order=0):
    """Rescale a numpy batch to target_shape using ndimage.zoom."""
    assert isinstance(batch, np.ndarray)
    scales = [r * 1.0 / b for r, b in zip(target_shape, batch.shape)]
    result = np.zeros(target_shape)
    ndimage.zoom(batch, scales, order=order, output=result)
    return result

def pixels_to_batch(x):
    """Flatten a BDHW tensor into a (B*H*W, D) batch with one row per pixel."""
    b, d, h, w = x.size()
    return x.permute(0, 2, 3, 1).contiguous().view(b*h*w, d)

class Image2ImageTrainer(BasicTrainer):
    """Train image to image models."""
    def __init__(self, *args, **kw):
        BasicTrainer.__init__(self, *args, **kw)

    def compute_loss(self, targets, weights=None):
        self.set_targets(targets, weights=weights)
        return self.criterion(pixels_to_batch(self.cuoutput),
                              pixels_to_batch(self.cutarget))

    def set_inputs(self, images):
        dlh.assign(self.cuinput, images, (0, 3, 1, 2))

    def get_outputs(self):
        return dlh.as_nda(self.cuoutput, (0, 2, 3, 1))

    def set_targets(self, targets, weights=None):
        b, d, h, w = tuple(self.cuoutput.size())
        # Resize targets (and weights) to the spatial size of the output.
        targets = dlh.as_nda(targets, (0, 2, 3, 1))
        targets = zoom_like(targets, (b, h, w, d))
        dlh.assign(self.cutarget, targets, (0, 3, 1, 2))
        assert self.cutarget.size() == self.cuoutput.size()
        if weights is not None:
            weights = dlh.as_nda(weights, (0, 2, 3, 1))
            weights = zoom_like(weights, (b, h, w, d))
            dlh.assign(self.cuweights, weights, (0, 3, 1, 2))

def ctc_align(prob, target):
    """Perform CTC alignment on torch sequence batches (using ocrolstm).

    Inputs are in BDL format.
    """
    import cctc
    assert dlh.sequence_is_normalized(prob), prob
    assert dlh.sequence_is_normalized(target), target
    # inputs are BDL
    prob_ = dlh.novar(prob).permute(0, 2, 1).cpu().contiguous()
    target_ = dlh.novar(target).permute(0, 2, 1).cpu().contiguous()
    # prob_ and target_ are both BLD now
    assert prob_.size(0) == target_.size(0), (prob_.size(), target_.size())
    assert prob_.size(2) == target_.size(2), (prob_.size(), target_.size())
    assert prob_.size(1) >= target_.size(1), (prob_.size(), target_.size())
    result = torch.rand(1)
    cctc.ctc_align_targets_batch(result, prob_, target_)
    return dlh.typeas(result.permute(0, 2, 1).contiguous(), prob)

def sequence_softmax(seq):
    """Given a BDL sequence, computes the softmax for each time step."""
    b, d, l = seq.size()
    batch = seq.permute(0, 2, 1).contiguous().view(b*l, d)
    # Normalize over the class axis; passing dim explicitly avoids the
    # deprecated implicit-dimension behavior of F.softmax.
    smbatch = F.softmax(batch, dim=1)
    result = smbatch.view(b, l, d).permute(0, 2, 1).contiguous()
    return result

class Image2SeqTrainer(BasicTrainer):
    """Train image to sequence models using CTC.

    This takes images in BHWD order, plus output sequences
    consisting of lists of integers.
    """
    def __init__(self, *args, **kw):
        BasicTrainer.__init__(self, *args, **kw)

    def init_loss(self, loss=None):
        assert loss is None, "Image2SeqTrainer must be trained with BCELoss (default)"
        # size_average=False sums the loss over all elements instead of averaging.
        self.criterion = nn.BCELoss(size_average=False)

    def compute_loss(self, targets, weights=None):
        self.cutargets = None # not used
        assert weights is None
        logits = self.cuoutput
        b, d, l = logits.size()
        probs = sequence_softmax(logits)
        assert dlh.sequence_is_normalized(probs), probs
        ttargets = torch.FloatTensor(targets)
        target_b, target_d, target_l = ttargets.size()
        assert b == target_b, (b, target_b)
        assert dlh.sequence_is_normalized(ttargets), ttargets
        # Align the targets to the predictions, then score the predictions
        # against the aligned targets.
        aligned = ctc_align(probs.cpu(), ttargets.cpu())
        assert dlh.sequence_is_normalized(aligned)
        return self.criterion(probs, Variable(self._cuda(aligned)))

    def set_inputs(self, images):
        dlh.assign(self.cuinput, images, (0, 3, 1, 2))

    def set_targets(self, targets, outputs, weights=None):
        raise Exception("overridden by compute_loss")

--------------------------------------------------------------------------------
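The per-time-step normalization that `sequence_softmax` applies to a BDL batch can be checked without PyTorch. The following is a minimal, dependency-free sketch, not the repository's implementation: `sequence_softmax_lists` is a hypothetical helper that works on nested Python lists, and the sample logits are made up for illustration.

```python
import math

def sequence_softmax_lists(seq):
    """Softmax over the D axis of a BDL nested list, one time step at a time.

    Hypothetical stand-in for the tensor-based sequence_softmax above.
    """
    result = []
    for sample in seq:  # each sample holds D rows of length L
        depth, length = len(sample), len(sample[0])
        out = [[0.0] * length for _ in range(depth)]
        for t in range(length):
            column = [sample[d][t] for d in range(depth)]
            peak = max(column)  # subtract the max for numerical stability
            exps = [math.exp(v - peak) for v in column]
            total = sum(exps)
            for d in range(depth):
                out[d][t] = exps[d] / total
        result.append(out)
    return result

# Made-up logits: one sample (B=1), three classes (D=3), two time steps (L=2).
logits = [[[1.0, 0.0], [2.0, 0.0], [3.0, 0.0]]]
probs = sequence_softmax_lists(logits)

# The class probabilities at each time step should sum to one.
for t in range(2):
    step_total = sum(probs[0][d][t] for d in range(3))
    assert abs(step_total - 1.0) < 1e-9
```

The normalization runs over the class axis for each time step separately, which matches reshaping to `(b*l, d)` and taking the softmax along `dim=1` in the tensor version.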