├── .gitignore ├── 2018-01-03-bayesian-linreg.ipynb ├── 2018-01-08-bayesian-linreg-sample.ipynb ├── 2018-01-09-bayesian-linreg-posterior.ipynb ├── 2018-01-10-bayesian-linreg-plots.ipynb ├── 2018-05-02-hmm-alpha-recursion.ipynb ├── 2018-05-03-SmallNORB.ipynb ├── 2018-05-04-ipywidgets-for-learning-logistic-sigmoid-and-bayes-classifiers.ipynb ├── 2018-05-07-reading-jupyter-notebooks-into-Python.ipynb ├── 2018-05-09-gibbs-sampling.ipynb ├── 2018-05-12-pymc3-bayesian-linear-regression.ipynb ├── 2018-05-13-hmm-check-results.ipynb ├── 2018-05-13-viterbi-message-passing.ipynb ├── 2018-05-16-SVD-demo.ipynb ├── 2018-05-16-library-catalog-subject.ipynb ├── 2018-05-18-Gaussian-Processes.ipynb ├── 2018-05-20-Multidimensional-Scaling.ipynb ├── 2018-05-22-logistic-regression.ipynb ├── 2018-05-25-quick-example-of-dirichlet-distribution.ipynb ├── 2018-05-27-coffee.ipynb ├── 2018-06-11-climate-classification-with-neural-nets.ipynb ├── 2018-06-22-spacy.ipynb ├── 2018-12-27-KL-Divergence.ipynb ├── 2019-01-09-sum-product-message-passing.ipynb ├── 2019-03-17-style-transfer.ipynb ├── 2019-06-02-climate-with-keras.ipynb ├── 2019-07-01-systems-modeling-from-scratch.ipynb ├── 2019-11-06-learning-chernoff-faces.ipynb ├── 2020-01-06-faces-in-a-crowd-visualization.ipynb ├── LICENSE ├── README.md ├── data ├── .git_keep_folder ├── coffee_comparisons.json └── coffee_metadata.yaml ├── images ├── 2018-01-03-different-w.png ├── 2018-01-08-linear-sample-example.png ├── 2018-01-08-sigmoid.png ├── 2018-01-09-mean-far-away.png ├── 2018-01-09-mean.png ├── 2018-01-10-five-degrees-uncertainty-another-point.png ├── 2018-01-10-five-degrees-uncertainty-few.png ├── 2018-01-10-five-degrees.png ├── 2018-01-10-sample-with-error.png ├── 2018-01-10-samples.png ├── 2018-01-10-uncertainty-zoom-out.png ├── 2018-01-10-uncertainty.png ├── 2018-05-02-bumps-creaks.png ├── 2018-05-02-filtering.png ├── 2018-05-02-hmm.png ├── 2018-05-02-room-coordinates.png ├── 2018-05-02-sample-paths.png ├── 2018-05-02-transition.png ├── 
2018-05-02-v-given-h.png ├── 2018-05-02-v-over-time.png ├── 2018-05-04-bayes-1.png ├── 2018-05-04-bayes-2.png ├── 2018-05-04-sigmoid.png ├── 2018-05-12-posterior.png ├── 2018-05-12-simulated-data.png ├── 2018-05-12-trace.png ├── 2018-05-12-weight-samples.png ├── 2018-05-13-viterbi.png ├── 2018-05-16-4d-plot.png ├── 2018-05-16-ex.png ├── 2018-05-16-labeled-plot.png ├── 2018-05-16-svd.png ├── 2018-05-18-cov-element.png ├── 2018-05-18-cov-similarity.png ├── 2018-05-18-cov.png ├── 2018-05-18-different-cov.png ├── 2018-05-18-large-ell.png ├── 2018-05-18-large-sigma.png ├── 2018-05-18-observations.png ├── 2018-05-18-poor-prediction.png ├── 2018-05-18-prediction.png ├── 2018-05-18-prior-beliefs.png ├── 2018-05-18-small-ell.png ├── 2018-05-20-comparison.png ├── 2018-05-20-distances.png ├── 2018-05-20-from-gram.png ├── 2018-05-20-north-america-per-state-province.png ├── 2018-05-20-north-america.png ├── 2018-05-20-pc-scores-from-x.png ├── 2018-05-22-2d-sigmoid.png ├── 2018-05-22-bad-fit.png ├── 2018-05-22-fitted-weights.png ├── 2018-05-22-fuzzy.png ├── 2018-05-22-less-steep.png ├── 2018-05-22-lin-sep.png ├── 2018-05-22-optimizer.png ├── 2018-05-22-reg-vs-no.png ├── 2018-05-22-sigmoid.png ├── 2018-05-26-dist.png ├── 2018-05-27-cairngorm.png ├── 2018-05-27-comparisons-box.png ├── 2018-05-27-pairwise-comparison.png ├── 2018-05-27-posteriors.png ├── 2018-05-27-predictions.png ├── 2018-05-27-ranking.png ├── 2018-05-27-spoon.png ├── 2018-05-27-traceplot.png ├── 2018-05-27-twelve-triangles.png ├── 2018-06-11-data.png ├── 2018-06-11-koppen.png ├── 2018-06-11-nn.png ├── 2018-06-11-silly-nn-1.png ├── 2018-06-11-silly-nn-2.png ├── 2018-06-11-silly-nn-3.png ├── 2018-06-11-stations.png ├── 2018-06-11-tensorboard.png ├── 2018-06-11-us.png ├── 2018-06-11-year-precipitation-predictions.png ├── 2018-06-11-year-temp-predictions.png ├── 2018-06-11-year-temp-single-pred.png ├── 2018-12-27-comparison.png ├── 2018-12-27-divergence-ex-bimodal-1.png ├── 2018-12-27-divergence-ex-bimodal-2.png ├── 
2018-12-27-divergence-ex.png ├── 2018-12-27-examples.png ├── 2018-12-27-learning.png ├── 2018-12-27-minimizing.png ├── 2019-01-09-factor-graph.png ├── 2019-03-17-content-convs.png ├── 2019-03-17-ghost-church.png ├── 2019-03-17-initial-image.png ├── 2019-03-17-mt-hood.jpg ├── 2019-03-17-orig-content.png ├── 2019-03-17-orig-convs.png ├── 2019-03-17-orig-primordial-chaos.jpg ├── 2019-03-17-orig-scream.png ├── 2019-03-17-orig-start.png ├── 2019-03-17-orig-style.png ├── 2019-03-17-orig-tsunami.jpg ├── 2019-03-17-primordial-chaos.png ├── 2019-03-17-result.png ├── 2019-03-17-scream.png ├── 2019-03-17-style-convs.png ├── 2019-03-17-sutro.png ├── 2019-03-17-three-images.png ├── 2019-03-17-tsunami.png ├── 2019-06-02-new-map.png ├── 2019-07-01-delay.png ├── 2019-07-01-delays_system.png ├── 2019-07-01-no-delay.png ├── 2019-07-01-overfishing-equal.png ├── 2019-07-01-overfishing-oscillations-spiral.png ├── 2019-07-01-overfishing-oscillations.png ├── 2019-07-01-overfishing.png ├── 2019-07-01-renewable-equations.png ├── 2019-07-01-renewable_resource.png ├── 2019-07-01-response-time.png ├── 2019-07-01-spiral.png ├── 2019-07-01-temperatures.png ├── 2019-07-01-thermostat_system.png ├── 2019-07-01-yield-per-unit.png ├── 2019-11-06-autoencoder-diagram.png ├── 2019-11-06-autoencoder-example.png ├── 2019-11-06-chernoff-1.png ├── 2019-11-06-chernoff-2.png ├── 2019-11-06-clipping.png ├── 2019-11-06-dataset-examples-points.png ├── 2019-11-06-dataset-examples.png ├── 2019-11-06-ex.png ├── 2019-11-06-face-changing.png ├── 2019-11-06-face-interpolate.png ├── 2019-11-06-not-fixed.png ├── 2019-11-06-other-faces.png ├── 2020-01-06-facemap-labeled.png ├── 2020-01-06-facemap-many.png └── 2020-01-06-neighbors.png ├── nb_code ├── __init__.py ├── hmm_alpha_recursion.py └── viterbi.py ├── requirements.txt └── scripts ├── model_simulation.py ├── process_weather_data.py └── run_simulation_for_d3.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 
2 | migrated 3 | .ipynb_checkpoints 4 | __pycache__ 5 | data/* 6 | nbconfig 7 | -------------------------------------------------------------------------------- /2018-05-04-ipywidgets-for-learning-logistic-sigmoid-and-bayes-classifiers.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# `ipywidgets.interact` for logistic sigmoid and Bayes classifiers\n", 8 | "\n", 9 | "This notebook shows how I used `ipywidgets.interact` to better understand equations in machine learning, including the logistic sigmoid and Bayes classifiers." 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 1, 15 | "metadata": { 16 | "collapsed": true 17 | }, 18 | "outputs": [], 19 | "source": [ 20 | "import numpy as np\n", 21 | "import matplotlib.pyplot as plt\n", 22 | "\n", 23 | "from ipywidgets import interact" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": {}, 29 | "source": [ 30 | "## Logistic Sigmoid\n", 31 | "\n", 32 | "The sigmoid function shows up a lot in machine learning. For one input dimension, one form I frequently see is\n", 33 | "\n", 34 | "$$\\sigma(x) = \\frac{1}{1 + \\exp(-(v^{\\top}x + b))}$$\n", 35 | "\n", 36 | "This code allows one to vary $v$ and $b$ for two sigmoid functions. The code produces something like:\n", 37 | "\n", 38 | "![sigmoid example](images/2018-05-04-sigmoid.png)" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 2, 44 | "metadata": {}, 45 | "outputs": [ 46 | { 47 | "data": { 48 | "application/vnd.jupyter.widget-view+json": { 49 | "model_id": "a97904098819464da2d62ef02565c67c", 50 | "version_major": 2, 51 | "version_minor": 0 52 | }, 53 | "text/html": [ 54 | "

Failed to display Jupyter Widget of type interactive.

\n", 55 | "

\n", 56 | " If you're reading this message in Jupyter Notebook or JupyterLab, it may mean\n", 57 | " that the widgets JavaScript is still loading. If this message persists, it\n", 58 | " likely means that the widgets JavaScript library is either not installed or\n", 59 | " not enabled. See the Jupyter\n", 60 | " Widgets Documentation for setup instructions.\n", 61 | "

\n", 62 | "

\n", 63 | " If you're reading this message in another notebook frontend (for example, a static\n", 64 | " rendering on GitHub or NBViewer),\n", 65 | " it may mean that your frontend doesn't currently support widgets.\n", 66 | "

\n" 67 | ], 68 | "text/plain": [ 69 | "interactive(children=(FloatSlider(value=-1.0, description='vv1', max=10.0, min=-12.0, step=0.25), IntSlider(value=0, description='bb1', max=10, min=-10), IntSlider(value=1, description='vv2', max=12, min=-10), IntSlider(value=0, description='bb2', max=10, min=-10), Output()), _dom_classes=('widget-interact',))" 70 | ] 71 | }, 72 | "metadata": {}, 73 | "output_type": "display_data" 74 | }, 75 | { 76 | "data": { 77 | "text/plain": [ 78 | "" 79 | ] 80 | }, 81 | "execution_count": 2, 82 | "metadata": {}, 83 | "output_type": "execute_result" 84 | } 85 | ], 86 | "source": [ 87 | "def logistic_sigmoid(xx, vv, b):\n", 88 | " return 1 / (1 + np.exp(-(np.dot(vv, xx) + b)))\n", 89 | "\n", 90 | "plt.clf()\n", 91 | "grid_size = 0.01\n", 92 | "x_grid = np.arange(-5, 5, grid_size)\n", 93 | "\n", 94 | "\n", 95 | "def plot_logistic_sigmoid(vv1, bb1, vv2, bb2):\n", 96 | " plt.plot(x_grid, logistic_sigmoid(x_grid, vv=vv1, b=bb1), '-b')\n", 97 | " plt.plot(x_grid, logistic_sigmoid(x_grid, vv=vv2, b=bb2), '-r')\n", 98 | " plt.axis([-5, 5, -0.5, 1.5])\n", 99 | " plt.show()\n", 100 | "\n", 101 | "\n", 102 | "interact(\n", 103 | " plot_logistic_sigmoid, \n", 104 | " vv1=(-12, 10, .25), \n", 105 | " bb1=(-10, 10), \n", 106 | " vv2=(-10, 12), \n", 107 | " bb2=(-10, 10)\n", 108 | ")" 109 | ] 110 | }, 111 | { 112 | "cell_type": "markdown", 113 | "metadata": {}, 114 | "source": [ 115 | "## Bayes Classifiers\n", 116 | "\n", 117 | "(These notes are from the machine learning course I took last semester. The relevant notes are here: [MLPR](http://www.inf.ed.ac.uk/teaching/courses/mlpr/2017/notes/w3a_intro_classification.html))\n", 118 | "\n", 119 | "Bayes classifiers are a method of classifying used in machine learning. \n", 120 | "\n", 121 | "For each class $k$, the Bayes classifier has a prior, $P(y = k)$, and a distribution over features, $P(\\textbf x \\mid y = k)$, such as a Gaussian $\\mathcal N (\\textbf x; \\hat\\mu_k, \\Sigma_k)$. 
\n", 122 | "The posterior probability of class $k$ is given by Bayes' rule, $P(y=k \\mid \\textbf x) = \\frac{P(y=k)\\,\\mathcal N(\\textbf x; \\hat \\mu_k, \\Sigma_k)}{\\sum_{k'} P(y=k')\\,\\mathcal N(\\textbf x; \\hat \\mu_{k'}, \\Sigma_{k'})}$, so each class's Gaussian is weighted by its prior. ([Naive Bayes](https://en.wikipedia.org/wiki/Naive_Bayes_classifier) is an example of a classifier that assumes features are conditionally independent given the class; for Gaussian features this means $\\Sigma_k$ is diagonal.)\n", 123 | "\n", 124 | "I wanted to see what shapes the decision surface could take on. \n", 125 | "If the covariances of the classes are equal, the decision boundary is linear.\n", 126 | "Even with one input dimension, the decision boundaries do interesting things. For example, if one class's variance is larger than the other's, then there are two decision boundaries! And the sigmoid function makes another appearance." 127 | ] 128 | }, 129 | { 130 | "cell_type": "markdown", 131 | "metadata": {}, 132 | "source": [ 133 | "### One input dimension\n", 134 | "\n", 135 | "This code allows you to adjust the parameters of a Bayes classifier with one input dimension and two classes. It produces something like\n", 136 | "\n", 137 | "![Bayes classifier with one dimension](images/2018-05-04-bayes-1.png)\n" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 3, 143 | "metadata": {}, 144 | "outputs": [ 145 | { 146 | "data": { 147 | "application/vnd.jupyter.widget-view+json": { 148 | "model_id": "668a7957ab6c49c98436e825195f47ae", 149 | "version_major": 2, 150 | "version_minor": 0 151 | }, 152 | "text/html": [ 153 | "

Failed to display Jupyter Widget of type interactive.

\n", 154 | "

\n", 155 | " If you're reading this message in Jupyter Notebook or JupyterLab, it may mean\n", 156 | " that the widgets JavaScript is still loading. If this message persists, it\n", 157 | " likely means that the widgets JavaScript library is either not installed or\n", 158 | " not enabled. See the Jupyter\n", 159 | " Widgets Documentation for setup instructions.\n", 160 | "

\n", 161 | "

\n", 162 | " If you're reading this message in another notebook frontend (for example, a static\n", 163 | " rendering on GitHub or NBViewer),\n", 164 | " it may mean that your frontend doesn't currently support widgets.\n", 165 | "

\n" 166 | ], 167 | "text/plain": [ 168 | "interactive(children=(IntSlider(value=5, description='zoom', max=10, min=1), FloatSlider(value=0.505, description='pi_1', max=1.0, min=0.01), IntSlider(value=3, description='mu_1', max=12, min=-5), IntSlider(value=-4, description='mu_2', max=5, min=-12), FloatSlider(value=1.0, description='var_1', max=2.0), FloatSlider(value=1.0, description='var_2', max=2.0), Output()), _dom_classes=('widget-interact',))" 169 | ] 170 | }, 171 | "metadata": {}, 172 | "output_type": "display_data" 173 | }, 174 | { 175 | "data": { 176 | "text/plain": [ 177 | "" 178 | ] 179 | }, 180 | "execution_count": 3, 181 | "metadata": {}, 182 | "output_type": "execute_result" 183 | } 184 | ], 185 | "source": [ 186 | "import numpy as np\n", 187 | "from ipywidgets import interact\n", 188 | "import matplotlib.pyplot as plt\n", 189 | "from scipy.stats import multivariate_normal\n", 190 | "\n", 191 | "plt.clf()\n", 192 | "grid_size = 0.1\n", 193 | "x_grid = np.arange(-5, 5, grid_size)\n", 194 | "\n", 195 | "def draw(zoom, pi_1, mu_1, mu_2, var_1, var_2):\n", 196 | "\n", 197 | " Z1 = multivariate_normal(mu_1, var_1)\n", 198 | " Z2 = multivariate_normal(mu_2, var_2)\n", 199 | " \n", 200 | " # Decision boundary!\n", 201 | " Z_num = pi_1 * Z1.pdf(x_grid)\n", 202 | " z_den = ((1 - pi_1) * Z2.pdf(x_grid) + pi_1 * Z1.pdf(x_grid)) + 1e-300\n", 203 | " Z = Z_num / z_den\n", 204 | "\n", 205 | " plt.figure()\n", 206 | " plt.plot(x_grid, Z, c='k')\n", 207 | " plt.plot(x_grid, Z1.pdf(x_grid))\n", 208 | " plt.plot(x_grid, Z2.pdf(x_grid))\n", 209 | "\n", 210 | " # set the axis based on the zoom\n", 211 | " plt.axis([-5, 5, -0.05, 1.05])\n", 212 | " plt.show()\n", 213 | "\n", 214 | "\n", 215 | "interact(\n", 216 | " draw, \n", 217 | " zoom=(1, 10), # zoom into image\n", 218 | " pi_1=(0.01, 1), # prior on first class\n", 219 | " mu_1=(-5, 12), # mean of first class\n", 220 | " mu_2=(-12, 5), # mean of second class\n", 221 | " var_1=(0, 2.), # variance of first class\n", 222 | 
" var_2=(0, 2.), # variance of second class\n", 223 | ")" 224 | ] 225 | }, 226 | { 227 | "cell_type": "markdown", 228 | "metadata": {}, 229 | "source": [ 230 | "## Two input dimensions\n", 231 | "\n", 232 | "This code produces something like\n", 233 | "\n", 234 | "![Bayes classifier with 2 input dimensions](images/2018-05-04-bayes-2.png)" 235 | ] 236 | }, 237 | { 238 | "cell_type": "code", 239 | "execution_count": 4, 240 | "metadata": {}, 241 | "outputs": [ 242 | { 243 | "data": { 244 | "application/vnd.jupyter.widget-view+json": { 245 | "model_id": "48d361e08db541c888ac92f36daf3c84", 246 | "version_major": 2, 247 | "version_minor": 0 248 | }, 249 | "text/html": [ 250 | "

Failed to display Jupyter Widget of type interactive.

\n", 251 | "

\n", 252 | " If you're reading this message in Jupyter Notebook or JupyterLab, it may mean\n", 253 | " that the widgets JavaScript is still loading. If this message persists, it\n", 254 | " likely means that the widgets JavaScript library is either not installed or\n", 255 | " not enabled. See the Jupyter\n", 256 | " Widgets Documentation for setup instructions.\n", 257 | "

\n", 258 | "

\n", 259 | " If you're reading this message in another notebook frontend (for example, a static\n", 260 | " rendering on GitHub or NBViewer),\n", 261 | " it may mean that your frontend doesn't currently support widgets.\n", 262 | "

\n" 263 | ], 264 | "text/plain": [ 265 | "interactive(children=(IntSlider(value=5, description='zoom', max=10, min=1), FloatSlider(value=0.505, description='pi_1', max=1.0, min=0.01), IntSlider(value=3, description='m_1', max=12, min=-5), IntSlider(value=-4, description='m_2', max=5, min=-12), IntSlider(value=3, description='stretch', max=5, min=1), IntSlider(value=3, description='stretch_2', max=5, min=1), FloatSlider(value=0.5, description='cov_1', max=1.0), FloatSlider(value=0.5, description='cov_2', max=1.0), Output()), _dom_classes=('widget-interact',))" 266 | ] 267 | }, 268 | "metadata": {}, 269 | "output_type": "display_data" 270 | }, 271 | { 272 | "data": { 273 | "text/plain": [ 274 | "" 275 | ] 276 | }, 277 | "execution_count": 4, 278 | "metadata": {}, 279 | "output_type": "execute_result" 280 | } 281 | ], 282 | "source": [ 283 | "import numpy as np\n", 284 | "from ipywidgets import interact\n", 285 | "import matplotlib.pyplot as plt\n", 286 | "from scipy.stats import multivariate_normal\n", 287 | "\n", 288 | "\n", 289 | "delta = 0.2\n", 290 | "x = np.arange(-60, 60, delta)\n", 291 | "y = np.arange(-40, 40, delta)\n", 292 | "X, Y = np.meshgrid(x, y)\n", 293 | "pos = np.empty(X.shape + (2,))\n", 294 | "pos[:, :, 0] = X\n", 295 | "pos[:, :, 1] = Y\n", 296 | "\n", 297 | "\n", 298 | "def draw(zoom, pi_1, m_1, m_2, stretch, stretch_2, cov_1, cov_2):\n", 299 | " Sigma_1 = np.array([[stretch, cov_1], [cov_1, 1.0]])\n", 300 | " mu_1 = np.array([m_1, 0.0])\n", 301 | "\n", 302 | " Sigma_2 = np.array([[stretch_2, cov_2], [cov_2, stretch_2]])\n", 303 | " mu_2 = np.array([m_2, 0.0])\n", 304 | "\n", 305 | " Z1 = multivariate_normal(mu_1, Sigma_1)\n", 306 | " Z2 = multivariate_normal(mu_2, Sigma_2)\n", 307 | " \n", 308 | " # compute the decision boundary!\n", 309 | " Z_num = pi_1 * Z1.pdf(pos)\n", 310 | " z_den = ((1 - pi_1) * Z2.pdf(pos) + pi_1 * Z1.pdf(pos)) + 1e-300 # add an offset to avoid divide by 0\n", 311 | " Z = Z_num / z_den - 0.5\n", 312 | "\n", 313 | " 
plt.figure()\n", 314 | " plt.contour(X, Y, Z, 10, colors='k') \n", 315 | " plt.contour(X, Y, pi_1 * Z1.pdf(pos), 5)\n", 316 | " plt.contour(X, Y, (1 - pi_1) * Z2.pdf(pos), 5)\n", 317 | "\n", 318 | " # set the axis based on the zoom\n", 319 | " plt.axis([int(i / (zoom * 4)) for i in [-150, 150, -100, 100]])\n", 320 | " plt.show()\n", 321 | "\n", 322 | "\n", 323 | "interact(\n", 324 | " draw, \n", 325 | " zoom=(1, 10), \n", 326 | " pi_1=(0.01, 1), # prior on the first class\n", 327 | " m_1=(-5, 12), # x dimension of the first class's mu\n", 328 | " m_2=(-12, 5), # x dimension of the second class's mu\n", 329 | " stretch=(1, 5), # x dimension variance\n", 330 | " stretch_2=(1, 5), # x and y dimension variance\n", 331 | " cov_1=(0, 1.0), # covariance between two dimension's for first class\n", 332 | " cov_2=(0, 1.0), # covariance between two dimension's for second class\n", 333 | ")" 334 | ] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "execution_count": null, 339 | "metadata": { 340 | "collapsed": true 341 | }, 342 | "outputs": [], 343 | "source": [] 344 | } 345 | ], 346 | "metadata": { 347 | "kernelspec": { 348 | "display_name": "Python 3", 349 | "language": "python", 350 | "name": "python3" 351 | }, 352 | "language_info": { 353 | "codemirror_mode": { 354 | "name": "ipython", 355 | "version": 3 356 | }, 357 | "file_extension": ".py", 358 | "mimetype": "text/x-python", 359 | "name": "python", 360 | "nbconvert_exporter": "python", 361 | "pygments_lexer": "ipython3", 362 | "version": "3.6.2" 363 | } 364 | }, 365 | "nbformat": 4, 366 | "nbformat_minor": 2 367 | } 368 | -------------------------------------------------------------------------------- /2018-05-07-reading-jupyter-notebooks-into-Python.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Reading Jupyter notebooks into Python\n", 8 | "\n", 9 | "For my [digitized 
notes](https://jessicastringham.net/2018/05/06/notebook-tour.html) project, I wrote a few scripts that read Markdown cells from Jupyter notebook files. Specifically, I read a notebook's non-empty Markdown cells and used them for my search index and flashcard database. \n", 10 | "\n", 11 | "## Reading Jupyter notebooks as data\n", 12 | "Reading Jupyter notebooks as data is pretty easy! Below I'll read the non-empty markdown cells." 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 1, 18 | "metadata": { 19 | "collapsed": true, 20 | "scrolled": false 21 | }, 22 | "outputs": [], 23 | "source": [ 24 | "import nbformat\n", 25 | "\n", 26 | "path = '2018-05-02-HMM.ipynb'\n", 27 | "\n", 28 | "NB_VERSION = 4\n", 29 | "\n", 30 | "with open(path) as f:\n", 31 | " nb = nbformat.read(f, NB_VERSION)\n", 32 | "\n", 33 | "markdown_cells = [\n", 34 | " cell['source']\n", 35 | " for cell in nb['cells'] # go through the cells\n", 36 | " if cell['cell_type'] == 'markdown' and cell['source'] # skip things like 'code' cells, and empty markdown cells\n", 37 | "]" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": {}, 43 | "source": [ 44 | "## Rendering Markdown and LaTeX\n", 45 | "\n", 46 | "The code below renders a loaded Markdown cell in an IPython notebook, showing what I can do with a dictionary of Jupyter notebook data. This is also how I render flashcards in [digitized notes](https://jessicastringham.net/2018/05/06/notebook-tour.html)."
47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 2, 52 | "metadata": { 53 | "scrolled": false 54 | }, 55 | "outputs": [ 56 | { 57 | "data": { 58 | "text/markdown": [ 59 | "**Below is data loaded from [this other file](2018-05-02-HMM.ipynb)!** \n", 60 | "\n" 61 | ], 62 | "text/plain": [ 63 | "" 64 | ] 65 | }, 66 | "metadata": {}, 67 | "output_type": "display_data" 68 | }, 69 | { 70 | "data": { 71 | "text/markdown": [ 72 | "# Inference in discrete state Hidden Markov Models using numpy\n", 73 | "\n", 74 | "This demo shows exact inference on a Hidden Markov Model with known, discrete transition and emission distributions that are fixed over time. It does alpha recursion, which is a sum-product algorithm on HMMs.\n", 75 | "\n", 76 | "The problem is taken from Example 23.3 from [\"Bayesian Reasoning and Machine Learning\"](http://www.cs.ucl.ac.uk/staff/d.barber/brml/), and this is more or less a port of the provided Matlab implementation with my commentary.\n", 77 | "\n", 78 | "I'm trying to understand how my probabilistic graphical modelling class can be represented in code, so I try to go through this painfully slow and show how the `numpy` arrays map to the discrete probability distributions.\n", 79 | "\n", 80 | "## Problem\n", 81 | "\n", 82 | "The idea is that there's someone moving around a room. You can't see them, but you can hear bumps and creaks from them moving around. You know the probability of someone bumping or creaking at each location in the room. You also know how one moves around the room.\n", 83 | "Using the sequence of bumps and/or creaks, the goal is to figure out the locations in the room.\n", 84 | "\n", 85 | "Below is what I will end up with:\n", 86 | "\n", 87 | "![Three sets of 10 images, second row showing a dot moving around a grid, third row showing highlighted area where model thinks the dot is](images/2018-05-02-filtering.png)\n", 88 | "\n", 89 | "The first row represents the bumps/creaks you hear. 
The second row represents where the person actually is. The third row represents where alpha recursion guesses the person is.\n", 90 | "\n", 91 | "### HMMs\n", 92 | "\n", 93 | "![HMM diagram](images/2018-05-02-hmm.png)\n", 94 | "\n", 95 | "This problem can be modeled using a Hidden Markov Model. At each timestep $t$ in the sequence, there's the visible state $v_t$ (the bump/creak combination) and the hidden state $h_t$ (the location in the room.)\n", 96 | "\n", 97 | "The goal of *filtering* is to show the probability distribution over the locations in the room at a given timestep $t$ given the bumps/creaks from the current and previous timesteps, $p(h_t\\mid v_{1:t})$. Plotting this distribution at each timestep gives a heatmap of where in the room we think the person is. \n", 98 | "\n", 99 | "In this model, there is a transition distribution $p(h_t \\mid h_{t - 1})$ to show that each hidden state depends on the previous timestep's hidden state. There also is the emission distribution $p(v_t \\mid h_t)$ to show each visible state depends on the corresponding hidden state. There's also the probability of where in the room the person starts, $p(h_1)$.\n", 100 | "\n", 101 | "The rest of the notebook shows:\n", 102 | " - Setting up the known distributions $p(h_1)$, $p(h_t \\mid h_{t - 1})$, and $p(v_t \\mid h_t)$. This is a lot of the notebook!\n", 103 | " - Generating some data for the locations ($h_{1:t}$) and sounds ($v_{1:t}$) from that distribution.\n", 104 | " - Finally I get to alpha recursion. I try to predict $h_{1:t}$ based on $v_{1:t}$, $p(h_1)$, $p(h_t \\mid h_{t - 1})$, and $p(v_t \\mid h_t)$." 
105 | ], 106 | "text/plain": [ 107 | "" 108 | ] 109 | }, 110 | "metadata": {}, 111 | "output_type": "display_data" 112 | } 113 | ], 114 | "source": [ 115 | "from IPython.display import display, Markdown\n", 116 | "\n", 117 | "display(Markdown(\"**Below is data loaded from [this other file]({})!** \\n\\n\".format(path)))\n", 118 | "display(Markdown(markdown_cells[0]))" 119 | ] 120 | }, 121 | { 122 | "cell_type": "markdown", 123 | "metadata": {}, 124 | "source": [ 125 | "## Bonus application (added 2018/05/08)\n", 126 | "\n", 127 | "I realized that GitHub doesn't always render Jupyter notebooks nicely, so here's how I [generate Jekyll posts from my Jupyter notebooks](https://gist.github.com/jessstringham/1ff8ec24dafc0fcff15d4a0e88be074e). This post is an example! Another example is my [alpha recursion post](https://jessicastringham.net/2018/05/02/hmm-alpha-recursion.html)." 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": null, 133 | "metadata": { 134 | "collapsed": true 135 | }, 136 | "outputs": [], 137 | "source": [] 138 | } 139 | ], 140 | "metadata": { 141 | "kernelspec": { 142 | "display_name": "Python 3", 143 | "language": "python", 144 | "name": "python3" 145 | }, 146 | "language_info": { 147 | "codemirror_mode": { 148 | "name": "ipython", 149 | "version": 3 150 | }, 151 | "file_extension": ".py", 152 | "mimetype": "text/x-python", 153 | "name": "python", 154 | "nbconvert_exporter": "python", 155 | "pygments_lexer": "ipython3", 156 | "version": "3.6.1" 157 | } 158 | }, 159 | "nbformat": 4, 160 | "nbformat_minor": 2 161 | } 162 | -------------------------------------------------------------------------------- /2018-05-13-hmm-check-results.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Verify HMM Checks\n", 8 | "\n", 9 | "This post recreates the example from [\"Bayesian Reasoning and Machine
Learning\"](http://www.cs.ucl.ac.uk/staff/d.barber/brml/) to sanity-check my results. See [my alpha recursion HMM post](2018-05-02-hmm-alpha-recursion.ipynb) and [my Viterbi HMM post](2018-05-13-viterbi-message-passing.ipynb) for the real posts (which use custom-generated data instead)." 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 1, 15 | "metadata": { 16 | "scrolled": false 17 | }, 18 | "outputs": [ 19 | { 20 | "name": "stdout", 21 | "output_type": "stream", 22 | "text": [ 23 | "[11 10 4 3 4 10 16 17 23 17]\n" 24 | ] 25 | } 26 | ], 27 | "source": [ 28 | "import numpy as np\n", 29 | "import matplotlib.pyplot as plt\n", 30 | "\n", 31 | "import nb_code.hmm_alpha_recursion as prev_post\n", 32 | "import nb_code.viterbi as viterbi_post" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 2, 38 | "metadata": { 39 | "collapsed": true 40 | }, 41 | "outputs": [], 42 | "source": [ 43 | "# helper functions you can skip over :D\n", 44 | "def hide_ticks(plot):\n", 45 | " plot.axes.get_xaxis().set_visible(False)\n", 46 | " plot.axes.get_yaxis().set_visible(False)" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "metadata": {}, 52 | "source": [ 53 | "## Code\n", 54 | "\n", 55 | "This adjusts the constants from the previous posts to match the example in the textbook."
56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 3, 61 | "metadata": { 62 | "scrolled": false 63 | }, 64 | "outputs": [], 65 | "source": [ 66 | "width = 5\n", 67 | "height = 5\n", 68 | "num_hidden_states = width * height\n", 69 | "\n", 70 | "map_x_y_to_hidden_state_id = np.arange(num_hidden_states).reshape(height, width).T\n", 71 | "\n", 72 | "# prob of starting starting locations\n", 73 | "p_hidden_start = np.ones(num_hidden_states) / num_hidden_states\n", 74 | "\n", 75 | "# verify it's a valid probability distribution\n", 76 | "assert np.all(np.isclose(np.sum(p_hidden_start), 1))\n", 77 | "assert np.all(p_hidden_start >= 0)\n", 78 | "\n", 79 | "p_transition = prev_post.create_transition(width, height)\n", 80 | "\n", 81 | "# Hard code the creak and bump locations!\n", 82 | "prob_bump_true_given_location = 0.1 * np.ones(num_hidden_states)\n", 83 | "prob_bump_true_given_location[[3, 4, 6, 7, 11, 15, 18, 19, 20, 21]] = 0.9\n", 84 | "\n", 85 | "prob_creak_true_given_location = 0.1 * np.ones(num_hidden_states)\n", 86 | "prob_creak_true_given_location[[0, 3, 4, 5, 8, 9, 11, 13, 16, 20]] = 0.9\n", 87 | "\n", 88 | "p_emission = prev_post.get_emission_matrix(prob_bump_true_given_location, prob_creak_true_given_location)\n", 89 | "\n", 90 | "# hard code the visibles and hidden\n", 91 | "visibles = np.array([1, 2, 0, 3, 0, 3, 2, 2, 3, 3])\n", 92 | "hiddens = np.array([15, 16, 11, 6, 11, 10, 5, 0, 1, 2])\n", 93 | "\n", 94 | "alphas = prev_post.alpha_recursion(\n", 95 | " visibles, \n", 96 | " p_hidden_start,\n", 97 | " p_transition,\n", 98 | " p_emission,\n", 99 | ")\n", 100 | "\n", 101 | "most_likely_states = viterbi_post.viterbi(\n", 102 | " visibles, \n", 103 | " p_hidden_start,\n", 104 | " p_transition,\n", 105 | " p_emission,\n", 106 | ")" 107 | ] 108 | }, 109 | { 110 | "cell_type": "markdown", 111 | "metadata": {}, 112 | "source": [ 113 | "### Visualizing bumps and creaks\n", 114 | "\n", 115 | "This should match the image on page 501." 
116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 4, 121 | "metadata": {}, 122 | "outputs": [ 123 | { 124 | "data": { 125 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWQAAAC6CAYAAACZWDfLAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAC3VJREFUeJzt3HuMXHUZxvHngUJLqaYpRaVQWrzUAFoqpEVuAcNlA4rE\nIF5ajIAV8RL8B/CGCoJBE0SNiCgJRSgoRLyDLq2koAIKSilRtEql3VKQ0lJS2nLt6x/nbD2d7u7M\n7O28234/ySQ7M78zv/fsvvPMmd+ZWUeEAAD126nuAgAABQIZAJIgkAEgCQIZAJIgkAEgCQIZAJIY\nkYFse7Ttv9veq7x+ne1L664rK9vH2V5YuT7f9kW9jN3N9j9t7zFsBWKLxt4exnmPsb1yOOccKNuX\n2r6wcn2l7WN6GXuw7d8PW3H9NCIDWdLZku6OiCfqLqRdtve1/VzlErY3VK4fVWd9EbFJ0g8lXVBn\nHTuwEdnbtudUeniT7c3VPq+7voj4q6RNtk+su5a+jNRAPkfSDXUX0R8RsSIixnVfypsPqty2zau4\n7Z2HucwbJZ1pe5dhnhf97G3bo4aglpZFxI2Vnj5R0qoe+nyLmuq9UdLHapi3ZWkD2fZjtj9Xvn17\nxvY822Ns7yvp9ZL+1LDJRNsLbK+3fZftKeXjTC2PQkdVHnuR7bnlz2fY/qPtb9peZ3uZ7cPL27ts\nP2X7w5Vtr7N9dU9zDeK+z7f9Xdu/tb1B0lG2/2D7jMqYubYXVa4fYHuh7bW2/2H71CbTTLD9m3If\n7rW9X/cdEbFc0gZJswZzv1Bop7fLJaRv2F5u+9myD3ar9PVHbK+QdGc5/u227yl7+aHqW3jbZ9p+\npPybL7PdazjZPresb59B3O+Vts+3/bCkDbZHlfswtTJmq+U02+8u92Ndue9vaTLNwbYfLn9XP7I9\nunLfIknHZz7QSBvIpTmSOiS9QdI0SRdKequkZRHxcg9jL5E0UdJiFa+GrTpU0hJJe0i6SdKPJc2U\n9EZJp0u60nb1Vb7XuWwvKZunp8tVbdQ0W9LFkl4l6d6+Bpa1LZB0vaTXlPX9wPabmzz+FyVNkLSi\n3J+qRyQd1Ea9aE+rvX25pEMkHa7ib3WBpM2V+4+WtL+kDtt7S7pN0qXl2PMk3Wp7z3LsU5LeJenV\nks6U9E3bBzcWZvtLks6QdHRErHSxzNZbT6+zPbuN/f6AiiPo8c0G2p4p6RpJc1U8N6+V9Avbu/ax\n2fskHa/ihe0QSR/qvqM80LCkN7VR77DKHshXRkRXRKyV9FVJH1Txh1zfw9jbIuLuiHhB0hckHWZ7\ncovz/Cci5kXEK5JuljRZ0lci4oWIuEPSiyrCuelcETE9Isb3cvlEG/v+s4i4NyI2l/P05RRJSyPi\n+oh4OSL+Iunnkt7bxzY/iYgHIuIlFS8oMxruX68WnjTot6a9bXsnSWdJ+nREPB4Rr0TEPQ39cFFE\nbCjX/k+XdHtE3F72zQJJD0g6SZIi4raIeDQKd0m6Q1L1nIVtXyHpBEnviIjV5XYr+ujp8RFxUxv7\n/e2IWFnW28zZkq6KiPvLfb+2vH1mH9t8KyKejIg1kn6tEdbX2QO5q/LzckmTJD2j4qix17ER8Zyk\nteX4Vvy38vOm8jEab6seIQ9krlZ1NR+yxRRJR1SPWiS9X1JfZ+qfrPy8UVvvn1T8jte1UQPa00pv\nT5Q0RtKjLT7OFEmnNfTBkSr7wPaJtu8rl7XWqQjqiZXtx6sIwcsi4tl+7lcz7fb1Zxr2Zy9Je/ex\nzYju6+yBXD3C3VfSKhVLC/t525MCW8aWb+EnlOM3l
DePrYx93WDV1TCXbP/NW3+Konq5uo05Gv8N\n3wb1vg9dkn7XcNQyLiI+1d5ubWV/SQ8NYHv0rZXeflrS8yqWNXpT7ZMuSTc09MHuEfG1ci31VhVL\nIK+NiPGSblfxFr7bMyqWNObZPqL7Rm/7yaDGy5w29ntLveXSzAvqu68vbtifsRFxSxvzbeH/n+v5\nV3+2Hw7ZA/mTtvexPUHF0sDNEbFS0r+17Qmnk2wfWa4vXSLpvvIt4WpJj0s63fbOts9S3w3eih7n\nkqSIOLB6drnhcs4A5lws6dTyhM40FW9lu/1S0oG2Z9vepbzMarKG3CsXJ5fGSbp/APWib017OyI2\nq1g3vcL2pLJ/D2s4UVU1X9LJtjvKsWNcfL54H0m7ShotabWkl118/OuExgeIiEUq1rd/aru7jhV9\n9PS4iGjnfE2jhyTNKet9p4oj+m7XlL+nmS6Ms32y7d37OdfRkhaWy3QpZQ/km1Sscy1T8bat+8sf\n31dlsb4y9ssqlg8OUbGe1u2jks6XtEbSgZLuGYS6eptrqFyu4ujiKRVP0vndd5RvLzvKOp5Q8bbt\nMhVPwP6YI2leRLw4kILRp1Z7+zxJD6t4cVwr6evq5XlbHhScIunzKoK3S0Xf7xQR6yWdK+kWFUfC\ns1W8kPf0OAtUvOD/qqeTfoPsXEnvUbGMcFq1poi4T9LHJX2vrHmpBvZcmyOpnXepw86R9B/U235M\n0tyIWNjDfaMlPSjp2BjmD9Dbvk7Syoi4sNnYLGwfJ+mzEXFcC2N3U3E0fkREPD3kxe2Asvb2SOPi\n27nPR0TTb+nafpuk70TEkc3G1qnWD5P3V3mW+YC669gelWe/+7XUgYGjt4dGRDyorZdDUhqRgYy2\nLVPxGWVge3KnpMbvI4xoaZcsAGBHk/2kHgDsMAhkAEiirTXkXT06xqi/HwFEu6ZN31h3CcPqsa6X\n9PTaV9x85OCaOGHnmDq5nv83s3TJ2OaDtjP0de/aCuQx2l2H+tj+VYW2dXYurruEYTWro51v1Q6e\nqZN30Z87W/23J4OrY1Ljv1rY/tHXvWPJAgCSIJABIAkCGQCSIJABIAkCGQCSIJABIAkCGQCSIJAB\nIAkCGQCSIJABIAkCGQCSIJABIAkCGQCSIJABIAkCGQCSIJABIAkCGQCSIJABIAkCGQCSIJABIAkC\nGQCSIJABIAkCGQCSIJABIAkCGQCSIJABIAkCGQCSIJABIIlRdRfQqs5Vi+suYdh1TJpRdwnDamms\nqbsEoFYcIQNAEgQyACRBIANAEgQyACRBIANAEgQyACRBIANAEgQyACRBIANAEgQyACRBIANAEgQy\nACRBIANAEgQyACRBIANAEgQyACRBIANAEgQyACRBIANAEgQyACRBIANAEgQyACRBIANAEgQyACRB\nIANAEgQyACRBIANAEgQyACRBIANAEqPqLiC7jkkz6i4BQ2zpkrG1/Z07Vy2uZd4dsa/r2uelsabl\nsRwhA0ASBDIAJEEgA0ASBDIAJEEgA0ASBDIAJEEgA0ASBDIAJEEgA0ASBDIAJEEgA0ASBDIAJEEg\nA0ASBDIAJEEgA0ASBDIAJEEgA0ASBDIAJEEgA0ASBDIAJEEgA0ASBDIAJEEgA0ASBDIAJEEgA0AS\nBDIAJEEgA0ASBDIAJEEgA0ASBDIAJDGqncHTpm9UZ+fioaqlTx2TZtQyb506V9Xzu67LrI6NdZeA\nYbAjPpdbxREyACRBIANAEgQyACRBIANAEgQyACRBIANAEgQyACRBIANAEgQyACRBIANAEgQyACRB\nIANAEgQyACRBIANAEgQyACRBIANAEgQyACRBIANAEgQyACRBIANAEgQyACRBIANAEgQyACRBIANA\nEgQyACRBIANAEgQyACRBIANAEgQyACQxqu4CWtW5anEt83ZMmlHLvHXPXYelsaaWeadN36jOzh2v\nv5APR8gAkASBD
ABJEMgAkASBDABJEMgAkASBDABJEMgAkASBDABJEMgAkASBDABJEMgAkASBDABJ\nEMgAkASBDABJEMgAkASBDABJEMgAkASBDABJEMgAkASBDABJEMgAkASBDABJEMgAkASBDABJEMgA\nkASBDABJEMgAkASBDABJEMgAkASBDABJOCJaH2yvlrR86MrBDm5KROw53JPS1xhiLfd1W4EMABg6\nLFkAQBIEMgAkQSADQBIEMgAkQSADQBIEMgAkQSADQBIEMgAkQSADQBL/A6spq8cG7RjvAAAAAElF\nTkSuQmCC\n", 126 | "text/plain": [ 127 | "" 128 | ] 129 | }, 130 | "metadata": {}, 131 | "output_type": "display_data" 132 | } 133 | ], 134 | "source": [ 135 | "fig, (bump, creak) = plt.subplots(1, 2)\n", 136 | "\n", 137 | "bump.imshow(prob_bump_true_given_location.reshape(height, width))\n", 138 | "bump.set_title('p(bump=True|h)')\n", 139 | "hide_ticks(bump)\n", 140 | "\n", 141 | "creak.imshow(prob_creak_true_given_location.reshape(height, width))\n", 142 | "creak.set_title('p(creak=True|h)')\n", 143 | "hide_ticks(creak)\n", 144 | "\n", 145 | "plt.show()" 146 | ] 147 | }, 148 | { 149 | "cell_type": "markdown", 150 | "metadata": {}, 151 | "source": [ 152 | "### Visualizing different methods\n", 153 | "\n", 154 | "This should match the image on page 502." 
155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": 5, 160 | "metadata": {}, 161 | "outputs": [ 162 | { 163 | "data": { 164 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAA8sAAADqCAYAAACV6dehAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAGdtJREFUeJzt3XuQneddH/DvT7uyJCeMbzJ2FNneGNshcYDQNYRkEjCD\nBxOSEpcyKaXBJJRLOh1aCkwJMKGG4d40FKaUUDoNScMlhg7m1kTGQNK4uWpDMLk4Jk7kS3yJZVmJ\nb5Kt3ad/7BFdXkt+X0lnz54jfz4zZ2b33d95nmd/77vvnu+e9+yp1loAAACA/2/TRi8AAAAApo2w\nDAAAAB3CMgAAAHQIywAAANAhLAMAAECHsAwAAAAdwjIAAAB0CMsAAADQISwDAABAh7AMAAAAHfMb\nvYBps3379rawsLDRy3iCpaWljV7C0extrZ19PHesqjbuxYzD4uLiRi/hiJaWlo6719N6XE+jPXv2\nZO/evXU8953WY3qKOaYn5ETOH47rY3bS9fpk/L2o18P5vThRJ90xPcUG9VpY7lhYWMju3bs3ehlP\nUHVc56hJuG2jFzBu07j/k6SqjrvX03pcT6PLLrtso5fwVOKYnpATOX9wzE66Xk/rz9rJeFxPY6/9\nXpyok+6YnmKDeu0ybAAAAOgQlgEAAKBDWAYAAIAOYRkAAAA6hGUAAADoEJYBAACgQ1gGAACADmEZ\nAAAAOoRlAAAA6BCWAQAAoENYBgAAgA5hGQAAADpOKCxX1Zuq6vUD6j5WVZePPr6mqt72JLV7quqK\nE1kXAAAAnIj5J/tiVb0zyQdbaz/Z2f6KJL+RZGdr7VDfJK21S09olQAAADBBfc8svyXJq6qqOtu/\nM8lvDwnKAAAAMGv6wvJ1Sc5K8pLDG6rqjCQvT/LWqvqtqvqZ0fbtVfWnVbW/qvZV1XuqatPoa91L\nq7dW1dur6sGq+nBVfcWRJq+qTVX1uqq6tarur6prq+rM0de2VtXbRtv3V9WHquqc428FAAAArHrS\nsNxaezTJtUmuXrP5lUlubq39Taf8h5PcmeTsJOck+fEk7ShDvyLJ7yc5M8nvJLmuqjYfoe4HklyV\n5OuS7EjyQJJfG33tu5KcluS8rAb61yZ59Mm+HwAAABhiyD/4ekuSb6uqraPPrx5t63o8yTOSXNBa\ne7y19p7W2tHC8lJr7Q9aa48neWOSrUm+5gh1r03yE621O1trB5NcM1rL/Gi+s5Jc1Fpbbq0ttda+\nMOD7AQAAgCfVG5Zbazcm2Zvkqqr6kiRfndVng7v+Y5JPJbm+qj5dVa97kmHvWDP+Slafkd5xhLoL\nkvzh6DLr/Uk+kWQ5q89c/88ku5L8XlXdVVW/dJRnpwEAAOCYDH3rqLdm9RnlVyXZ1Vq7t1vQWnuw\ntfbDrbULk3xLkh+qqm84ynjnHf5g9LrmnUnuOkLdHUle2lo7fc1ta2vts6Nnr3+qtfbcJC/K6uuo\nrz7CGAAAAHBMjiUsX5Hke3PkS7BTVS+vqotG/zn781l9BnjlKOMtVtW3ji6n/sEkB5O8/wh1b0ry\ns1V1wWiOs0dvW5Wq+vqq+rKqmkvyhaxeln20+QAAAGCwQWG5tbYnyXuTPC3JHx+l7OIkNyR5KMn7\nkvzX1tpfHaX2j5L8s6z+w67vTPKto9cvd/3KaL7rq+rBrAbqF4y+dm6SP8hqUP5Ekndn9dJsAAAA\nOCHzQwtba5cfYdur13z8y0l++Sj3XVjz8TU986ytXcnqPwB74xHqfjfJ7/YsGwAAAI7Z0MuwAQAA\n4ClDWAYAAIAOYRkAAAA6hGUAAADoEJYBAACgQ1gGA
ACADmEZAAAAOoRlAAAA6BCWAQAAoENYBgAA\ngA5hGQAAADqEZQAAAOio1tpGr2GqVNV9SW7b6HXMkAtaa2cfzx31+pjp9WTo8+To9eTo9eTo9eTo\n9WTo8+To9eQM6rWwDAAAAB0uwwYAAIAOYRkAAAA6hGUAAADoEJYBAACgQ1gGAACADmEZAAAAOoRl\nAAAA6BCWAQAAoENYBgAAgA5hGQAAADqEZQAAAOgQlgEAAKBDWAYAAICO+Y1ewMmiqtpGr+EE7W2t\nnb3Rixhi1nvdWquNXsNQej05ej0Zs97nOFdPkl5PyKycP5Jk+/btbWFhobduaWlp/RdzHGap10OP\n68XFxd6ajdgfs9LrcZ4/huyLZOz7Y13P1cIyh9220QsAoJdz9eToNU+wsLCQ3bt399ZVzUROOinY\nH9NjyL5Ixr4/1vVc7TJsAAAA6BCWAQAAoENYBgAAgA5hGQAAADqEZQAAAOgQlgEAAKBDWAYAAIAO\nYRlghrTWem/AU8vi4qJzw4QsLS2lqnpvTI79MT2G7ItZ2x/CMgAAAHQIywAAANAhLA9UVZdX1Z0b\nvQ4AAADW30yF5ap6V1U9UFVbBtQuVFWrqvlJrA0AAICTx8yE5apaSPKSJC3Jt2zoYgAAADipzUxY\nTnJ1kvcn+a0k33V4Y1Vtq6r/VFW3VdXnq+rGqtqW5P+MSvZX1UNV9cKquqaq3rbmvv/g2eeqek1V\nfaKqHqyqT1fV90/u2wMAAGBazNIlylcneWOSDyR5f1Wd01q7N8kbklya5EVJ7knygiQrSb42yWeS\nnN5aO5QkVXVlzxyfS/LyJJ8e3f8dVfWh1tqH1+H7AQAAYErNxDPLVfXiJBckuba1tpTk1iTfUVWb\nknx3kn/bWvtsa225tfbe1trB45mntfZnrbVb26p3J7k+q5d+AwAA8BQyE2E5q5ddX99a2zv6/HdG\n27Yn2ZrV8HzCquqlVfX+qtpXVfuTfPNoDuBJtNZ6b4xHVfXeOHGO6RO3uLiojxOytLTk3ACwDqb+\nMuzR649fmWSuqu4Zbd6S5PQkz0hyIMmXJPmbzl2P9Bv44SSnrvn83DXzbEnyv7J6ufcftdYer6rr\nkvjtAgAA8BQzC88sX5VkOclzkzx/dHtOkvdkNdj+jyRvrKodVTU3+kdeW5Lcl9XXLl+4ZqyPJPna\nqjq/qk5L8mNrvnZKVkP4fUkOVdVLk3zj+n5rAAAATKNZCMvfleTNrbXbW2v3HL4l+S9J/kWS1yX5\n2yQfSrIvyS8m2dRaeyTJzyb5v1W1v6q+prX250nenuSmJEtJ/vTwJK21B5P8myTXJnkgyXck+eNJ\nfZMAAABMj/J6ofGoqllv5FJr7bKNXsQQs97r1trMXNo/tNdDziMb8Xq5k7HX02pWej3rx3Rm6Fx9\n2WWXtd27d/fWTfFraWem184fk6PXk6PXkzHrfc46n6tn4ZllAAAAmChhGQAAADqEZQAAAOgQlgEA\nAKBDWAYAAIAOYXlMFhcX01rrvTE59sfkVFXvDWaJY/rELS0t6SMAM01YBgAAgA5hGQAAADqEZQAA\nAOgQlgEAAKBDWAYAAIAOYRkAAAA6hGUAAADoEJYBAACgY36jF3CyWFpaSlVt9DJYw/4AAACOl2eW\nAQAAoENYBgAAgA5hGQAAADqEZQAAAOgQlgEAAKBDWAYAAIAOYRkAAAA6hGUAAADoEJYBAACgY36j\nF3AS2Zvkto1exAm4YKMXcAxmudez1OdErydJrydjlvuc6PUk6fVkzFKfE72eJL2ejFnuc7LOva7W\n2nqODwAAADPHZdgAAADQISwDAABAh7AMAAAAHcIyAAAAdAjLAAAA0CEsAwAAQIewDAAAAB3zG72A\nk8UptaVtzdN662rLKRNYzbH7wsF797bWzt7odQxxyqatbdvcF/UXjvEtxNvK8ljmO5CH81g7WCe+\noskY2ut2aEB/B
nrsGf0/R6d8YaW35sDB/Xns8Ydnp9cDzyE5dWtvycHtc4Pm3HL7w4Pq+szScT24\nz2O0fFb/fHP7Hhk01oNt3+ycq4f+XtzU/3f7g+dsGzbn3eM5ppPkwTwwQ73e2rZtenp/4ab+H9Oh\n5/Oa6z/PtOX+sWbp/JEk28+cawvnbe6tu+WmU/sHG/pdD3k8U/2DHWgP57F2YGZ6PfQccsmX958/\nB+2PMZql4/qU+VPbts2n9da1Awd7a4bsi2S8+2O9z9XC8phszdPygvqG3rr5nQvrv5jj8M5b33Db\nRq9hqG1zX5QXnvFP+wsPHeqvqWEXV6w8+GBvzZAHBR9YuWHQfNNiaK+X994/tjlv/74X9dacd0P/\nA+IPfuTXx7GciRl6DqlLn9dbc8trhoXBi//1B/qLBjwAm6XjemuelhfMfWN/4ZA/kA207xUv7K05\n63c+PGisPz/w2zNzrh7a601bt/TW7Pm+5w+a8/yffm9/0YBjOkluWPn9men1tk1Pz9dse1lvXW3r\n/2Pb8v37Bs05d9oZ/WM98EBvzQfaXwyab1osnLc5H9x1Xm/dlc/8yt6aIX9wSJI24PFMbe5/Mub9\nj79z0HzTYujvxV27PtJbc+WOYeeQQeeH1v/Xi1k6rrdtPi0vvPA1vXXLn/x0b82uXUuD5hy8Pwa4\nof3Bup6rXYYNAAAAHcIyAAAAdAjLAAAA0CEsAwAAQIewDAAAAB3CMgAAAHQIywAAANAhLAMAAEDH\n/EYv4Knmz268bmxjjfMNvWfJY9u35vbXPLu37pm/8N6xzXnnj72ot+b8P9nXW1OfunEcy5mYtryS\nlc8/2Fs3d84X9w926NCgOZ/16383lrHqwGOD5psaVanNp/SWze17qLfm4h/4xKAp5889p7dm5eFH\nemvqodn5u2vNz2du+1m9dcv33d9bM7/j3EFznvnm9/XWtE1zg8aaOSvLvSW18xm9Nef/9LDz+dxZ\nZ/bWLO97YNBYs+Ti5z2Ud+zq79GQxw01P+yh4fL+/b01cxc9q3++2/vPe9Pk7z729Hzzl35tb13N\n9Z8728Dfi/mLnf1jXfHZ/nFaGzbflKitWzJ34cW9dVfu6B9r7vTTBs25vP/zYxmrvjA75/R24GCW\nP9H/2GvuuZf01ly5o/+cnww7zwz++Vhns/MIBwAAACZEWAYAAICOqQ/LVXV+VT1UVXOjz99VVd8z\nxvFfUlWfHNd4AAAAzL6pCstVtaeqHh2F44eq6qEkh1prT2+tPeEi+Kp6dVWd0ItAW2vvaa31vwAW\nAACAp4ypCssj/3gUjg/f7lqviarKPzgDAADgCaYxLP8DVbVQVa0bbKvqOUnelOSFo2eh94+2b6mq\nN1TV7VV1b1W9qaq2jb52eVXdWVU/WlX3JHnz4W1rxt1TVT9SVTdV1eer6u1VtXXN1/99Vd1dVXdV\n1feM1nbRRJoBAADAREx9WD6a1tonkrw2yftGz0CfPvrSLyS5JMnzk1yU5JlJfnLNXc9NcmaSC5J8\n31GGf2WSb0ryrCRfnuTVSVJV35Tkh5JcMRr78rF9QwAAAEyNaQzL11XV/tHtmN6UuKoqqwH437XW\n9rXWHkzyc0m+fU3ZSpL/0Fo72Fp79ChD/Wpr7a7W2r4kf5LV4J2shug3t9Y+1lp7JMk1x7I+AAAA\nZsM0vmb3qtbaDYc/qaqFY7jv2UlOTbK0mptXh0iy9p3B72utHegZ5541Hz+S5PDbne9IsnvN1+74\n+0m2nJL5nf1LfdmL+2uG29NbMX/hwPluPaGFTNQpew/k/N+8ub/wjDP6a9rKoDlXtvTX1N2f6y96\n/PFB802N1tIO9a/5f//19b01Vz7zKwdNOXdm/35b3r+/t6atPOF/Ak61S77s4bxz1wd7667c8fze\nmv+8572D5vzBC1/SWzP37Av7B/r05kHzTYN26FCW7+3/Wd1110d6a4bsiySZO+v
M3prl+/cNGmuW\nXPLlj2TXriF97B/r9+4Ydkx/+3kv6q2ZP/ecQWPl7mFl0+CWm04ddDwuX/6Pemvm3vXhQXN+yYe2\n9tbc+lWf6a1p7bFB802NzZuTnef2lrWP39JbM3/ezkFTHrris/1jDTiu675pfNj/JA4+lvaZO/rr\nBlje//lBdXMDHjuebI9BatvWbLrkS3vrlm/qf+y98uJhvxc33dj/u2Hoz0duH1Z2vGbsp+YJWufz\nvUkeTXJpa+1oZ5bufY7F3UnW7rnzTmAsAAAAptQ0XoZ9LO5NsrOqTkmS1tpKkt9M8stV9cVJUlXP\nrKorxzTftUleU1XPqapTk7x+TOMCAAAwRWY9LP9lko8luaeq9o62/WiSTyV5f1V9IckNScbyPsqt\ntXck+dUkf3V4jtGXDo5jfAAAAKbDVF2G3VpbOMK2PVl93fHhzy9f8/FjSV7WqT+Q5MdHt+5Y78o/\nvIz6Cdu6a2itXdP5/OeT/Hzy929ftZKZemUTAAAAfWb9meWJq6p/Mnov5zOS/GKSP2mtHdrodQEA\nADA+wvKx+/4kn8vq/49eTvKvNnY5AAAAjNtUXYY9C1pr37TRawAAAGB9eWYZAAAAOoRlAAAA6HAZ\n9phc/Oz9+bNd1/XWXbnj+b018xcuDJpz110fGVA1pCaZe8agsqnQDi1n+f59vXVD+vPNl379oDnP\n/6n39dYsDxinrQypmh61dUvmLrqkt+7KnQP+7lb9Jat1/YU1v7l/nMeHTjgdbrnp1EHnh7nn9u+P\nH/6qs4dN2vb2lizf/Kn+YVZm593zan4+c9u/uLduyL4YcqwmyfK+B/qLNs0NGmvQiWZKDD6mL3pW\nb823XzCwP7XSW3LonnuHjTVLavXY7jP37r/urZk/b2dvTZLc+lV3Dqo72bQDB7P88Vt66x6/YrF/\nsBuWhk064Fxz6O57emtm7f/RtiSttbGMVZc9b1Dd8u6PDhhsth5f9GmPHsjKTTf31tXmU3prNt04\nLHcMGevQHdNxjvHMMgAAAHQIywAAANAhLAMAAECHsAwAAAAdwjIAAAB0CMsAAADQISwDAABAh7AM\nAAAAHf3vYM8gf/fJ0/OyF1/VW7frrut6a1724oVBcw6Zb7g3jHGs9VVbTsn8zoXeuit39I81d/aw\nH4G5M8/oLxrwJvX1wNyg+aZFO3Awyx/7ZG/d3KXP7q0ZMk6SLO+9f1Bdr9bGM86UWf74Lb01ddnz\nBo01d+hQb83KI4/0D3Sw/9ifFu3QoSzf+7neuprvPze0Af0brC2Pb6wZs/ypz/TWbNq6ddBYKwcP\n9tYM2bdJkseHlU2FNr7j8dAddw4rHPA772Q9Dw+x+Yal/qIhPUyeun1sLW1MP9Nt90eHzblpwOO0\nlafm+bo9/lhvzabnfemgsVY+enNvzbScqz2zDAAAAB3CMgAAAHQIywAAANAhLAMAAECHsAwAAAAd\nwjIAAAB0CMsAAADQISwDAABAh7AMAAAAHdVa2+g1nBSq6r4kt230Ok7ABa21szd6EUPMeK9nps+J\nXk+SXk/GjPc50etJ0uvJmJk+J3o9SXo9GTPe52Sdey0sAwAAQIfLsAEAAKBDWAYAAIAOYRkAAAA6\nhGUAAADoEJYBAACgQ1gGAACADmEZAAAAOuY3egEni6qa9Tes3jtDb54+071urdVGr2EovZ4cvZ6M\nWe9znKsnSa8nZFbOH0myffv2trCw0Fu3tLS0/os5DrPU66HH9eLiYm/NRuyPWen1OM8fQ/ZFMvb9\nsa7namGZw27b6AUA0Mu5enL0midYWFjI7t27e+uqZiInnRTsj+kxZF8kY98f63qudhk2AAAAdAjL\nAAAA0CEsAwAAQIewDAAAAB3CMgAAAHQIywAAANAhLAMAAECHsAwAMMMWFxfTWuu9ceKWlpZSVb03\nJsf+mB5D9sWs7Q9hGQAAADqEZQAAAOgQlgE
AAKBjpsNyVb2pql4/xvEWqqpV1fxRvv7jVfXfxzUf\nAAAA0+mIoXBaVNU7k3ywtfaTne2vSPIbSXa21g6Ntl2e5G2ttZ3rtZ7W2s+t19gAAABMj2l/Zvkt\nSV5VT/y3ad+Z5LcPB+VxONqzyQAAADz1THtYvi7JWUlecnhDVZ2R5OVJ3lpVv1VVP1NVT0vyjiQ7\nquqh0W1HVW2qqtdV1a1VdX9VXVtVZ47GOXzJ9b+sqtuT/OWaeb+7qu6qqrur6kfWzH1NVb1tEt84\nAAAAG2eqw3Jr7dEk1ya5es3mVya5ubX2N2vqHk7y0iR3tdaePrrdleQHklyV5OuS7EjyQJJf60zz\ndUmek+TKNdu+PsnFSb4xyY9W1RVj/cYAAACYalMdlkfekuTbqmrr6POrR9uGeG2Sn2it3dlaO5jk\nmtFYay+5vqa19vAomB/2U6Ntf5vkzUn++Yl9C3Bya6313nhyi4uL+jhF7IsT55ienKWlpVRV7w2A\nYzP1r9Ntrd1YVXuTXFVVH0ry1Um+deDdL0jyh1W1smbbcpJz1nx+xxHut3bbbUm+7BiWDAAAwIyb\nhWeWk+StWX1G+VVJdrXW7j1CzZH+PH1Hkpe21k5fc9vaWvtsz/3OW/Px+UnuOt6FAwAAMHtmKSxf\nkeR7c/RLsO9NclZVnbZm25uS/GxVXZAkVXX26G2n+ry+qk6tqkuTvCbJ249/6QAAAMyaqb8MO0la\na3uq6r1JviLJHx+l5uaq+t0kn66quSTPTfIrSSrJ9VW1I8nnshp8/6hnyncn+VRW/5jwhtba9eP5\nTgAAAJgFMxGWk6S1dvkRtr268/l3H+Gubxzduvfdk9UgfbRt/+0I97lm2GoBAACYZbNyGTYAAABM\njLAMAAAAHcIyAAAAdAjLAAAA0CEsj8ni4mJaa703Jsf+mJyq6r3x5JaWlvRxitgXJ84xDcCsE5YB\nAACgQ1gGAACADmEZAAAAOoRlAAAA6BCWAQAAoENYBgAAgA5hGQAAADqEZQAAAOiY3+gFnCyWlpZS\nVRu9DNawPwAAgOPlmWUAAADoEJYBAACgQ1gGAACADmEZAAAAOoRlAAAA6BCWAQAAoENYBgAAgA5h\nGQAAADqEZQAAAOiY3+gFnET2JrltoxdxAi7Y6AUcg1nu9Sz1OdHrSdLryZjlPid6PUl6PRmz1OdE\nrydJrydjlvucrHOvq7W2nuMDAADAzHEZNgAAAHQIywAAANAhLAMAAECHsAwAAAAdwjIAAAB0CMsA\nAADQISwDAABAh7AMAAAAHcIyAAAAdPw/Q4Vz0iX4lakAAAAASUVORK5CYII=\n", 165 | "text/plain": [ 166 | "" 167 | ] 168 | }, 169 | "metadata": {}, 170 | "output_type": "display_data" 171 | } 172 | ], 173 | "source": [ 174 | "fig, all_axs = plt.subplots(4, prev_post.timesteps, figsize=(16, 4))\n", 175 | "all_axs = all_axs.T\n", 176 | "\n", 177 | "VISIBLES = 0\n", 178 | "TRUE_STATES = 1\n", 179 | "FILTERING = 2\n", 180 | "VITERBI = 3\n", 181 | "\n", 182 | "all_axs[0][VISIBLES].set_title('Visibles', x=-0.5, y=0.2)\n", 183 | "all_axs[0][TRUE_STATES].set_title('Actual', x=-0.5, y=0.4)\n", 184 | "all_axs[0][FILTERING].set_title('Filtering', x=-0.5, y=0.4)\n", 185 | "all_axs[0][VITERBI].set_title('Viterbi', x=-0.5, 
y=0.4)\n", 186 | "\n", 187 | "for i, (axs, hidden, visible, alpha, viterbi) in enumerate(zip(all_axs, hiddens, visibles, alphas, most_likely_states)):\n", 188 | " axs[VISIBLES].imshow([prev_post.map_visible_state_to_bump_creak[visible]], cmap='gray', vmin=0)\n", 189 | " hide_ticks(axs[VISIBLES]) \n", 190 | " \n", 191 | " axs[TRUE_STATES].imshow(prev_post.plot_state_in_room(hidden, width=width, height=height), cmap='gray')\n", 192 | " hide_ticks(axs[TRUE_STATES])\n", 193 | " \n", 194 | " axs[FILTERING].imshow(alpha.reshape(height, width))\n", 195 | " hide_ticks(axs[FILTERING]) \n", 196 | " \n", 197 | " axs[VITERBI].imshow(prev_post.plot_state_in_room(viterbi, width=width, height=height), cmap='gray')\n", 198 | " hide_ticks(axs[VITERBI])\n", 199 | " \n", 200 | "plt.show()" 201 | ] 202 | } 203 | ], 204 | "metadata": { 205 | "kernelspec": { 206 | "display_name": "Python 3", 207 | "language": "python", 208 | "name": "python3" 209 | }, 210 | "language_info": { 211 | "codemirror_mode": { 212 | "name": "ipython", 213 | "version": 3 214 | }, 215 | "file_extension": ".py", 216 | "mimetype": "text/x-python", 217 | "name": "python", 218 | "nbconvert_exporter": "python", 219 | "pygments_lexer": "ipython3", 220 | "version": "3.6.1" 221 | } 222 | }, 223 | "nbformat": 4, 224 | "nbformat_minor": 2 225 | } 226 | -------------------------------------------------------------------------------- /2018-05-13-viterbi-message-passing.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Inference in Hidden Markov Models continued: Viterbi Algorithm\n", 8 | "\n", 9 | "This is a short post that continues from [the more-detailed alpha recursion HMM post](2018-05-02-hmm-alpha-recursion.ipynb). 
In this post I'll implement the [Viterbi algorithm](https://en.wikipedia.org/wiki/Viterbi_algorithm) like Barber does in [\"Bayesian Reasoning and Machine Learning\"](http://www.cs.ucl.ac.uk/staff/d.barber/brml/). Like before, I'm porting the MatLab code from the textbook." 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 1, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import numpy as np\n", 19 | "import matplotlib.pyplot as plt\n", 20 | "\n", 21 | "import nb_code.hmm_alpha_recursion as prev_post" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 2, 27 | "metadata": { 28 | "collapsed": true 29 | }, 30 | "outputs": [], 31 | "source": [ 32 | "# helper functions you can skip over :D\n", 33 | "SAVE = True\n", 34 | "def maybe_save_plot(filename):\n", 35 | " if SAVE:\n", 36 | " plt.tight_layout()\n", 37 | " plt.savefig('images/' + filename, bbox_inches=\"tight\")\n", 38 | "\n", 39 | "def hide_ticks(plot):\n", 40 | " plot.axes.get_xaxis().set_visible(False)\n", 41 | " plot.axes.get_yaxis().set_visible(False)" 42 | ] 43 | }, 44 | { 45 | "cell_type": "markdown", 46 | "metadata": {}, 47 | "source": [ 48 | "## Viterbi algorithm in HMMs using message passing\n", 49 | "\n", 50 | "The Viterbi algorithm finds the most-likely path $h_{1:T}$ for the visibles $v_{1:T}$, where $T$ is the timestep of the last observed visible.\n", 51 | "The algorithm takes in the visibles $v_{1:T}$, the initial state $p(h_1)$, the transition probabilities $p(h_{t} \\mid h_{t - 1})$, and the emission probabilities $p(v_t \\mid h_t)$, and returns the list of most-likely hidden states $h_{1:T}$. 
\n", 52 | "\n", 53 | "I generated the required probabilities in the [alpha recursion HMM post](2018-05-02-hmm-alpha-recursion.ipynb).\n", 54 | "\n", 55 | "### Algorithm\n", 56 | "\n", 57 | "Barber frames the Viterbi algorithm as message passing using the max-product algorithm.\n", 58 | "\n", 59 | "This version of the algorithm begins at the end of the hidden states ($h_T$), and computes an incoming message from future states. The message is meant to represent the effect of maximizing over those states.\n", 60 | "Barber gives the messages as:\n", 61 | "\n", 62 | "$$\\mu(h_T) = 1$$\n", 63 | "\n", 64 | "$$\\mu(h_{t - 1}) = \\max_{h_t} p(v_t \\mid h_t)p(h_t \\mid h_{t - 1})\\mu(h_t).$$\n", 65 | "\n", 66 | "Once the messages are computed, the algorithm then computes the most-likely state for $h_1$, and uses that to compute the most-likely state for $h_2$ and so on. It basically maximizes the marginal of $p(h_t \\mid v_{1:T})$ and then uses the most-likely state for $h_t$ in the transition matrix for computing $p(h_{t + 1} \\mid v_{1:T})$ so it returns a valid path.\n", 67 | "\n", 68 | "$$h_1^* = \\arg\\max_{h_1} p(v_1 \\mid h_1)p(h_1)\\mu(h_1)$$\n", 69 | "\n", 70 | "$$h_t^* = \\arg\\max_{h_t} p(v_t \\mid h_t)p(h_t \\mid h_{t - 1}^*)\\mu(h_t).$$\n", 71 | "\n", 72 | "Now in Python!" 
73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": 3, 78 | "metadata": {}, 79 | "outputs": [ 80 | { 81 | "name": "stdout", 82 | "output_type": "stream", 83 | "text": [ 84 | "[15 9 3 2 1 2 3 4 10 9]\n" 85 | ] 86 | } 87 | ], 88 | "source": [ 89 | "def viterbi(visibles, p_hidden_start, p_transition, p_emission):\n", 90 | " num_timestamps = visibles.shape[0]\n", 91 | " num_hidden_states = p_transition.shape[0]\n", 92 | " \n", 93 | " # messages[t] corresponds to mu(h_t), which is the message coming into h_t\n", 94 | " messages = np.zeros((num_timestamps, num_hidden_states))\n", 95 | " \n", 96 | " most_likely_states = np.zeros((num_timestamps,), dtype=int)\n", 97 | "\n", 98 | " # The message coming into the last node is 1 for all states\n", 99 | " messages[-1] = np.ones(num_hidden_states)\n", 100 | "\n", 101 | " # normalize!\n", 102 | " messages[-1] /= np.sum(messages[-1])\n", 103 | " \n", 104 | " # Compute the messages!\n", 105 | " for t in range(num_timestamps - 1, 0, -1):\n", 106 | " # use the data at time t to make mu[h_{t - 1}]\n", 107 | " \n", 108 | " # compute max p(v|h)p(h|h)mu(h)!\n", 109 | " \n", 110 | " # compute p(v|h)mu(h)\n", 111 | " message_and_emission = messages[t] * p_emission[visibles[t]]\n", 112 | " \n", 113 | " # compute p(v|h)p(h|h)mu(h)\n", 114 | " # message_and_emission.reshape(-1, 1): new_state x 1\n", 115 | " # np.tile(...): new_state x old_state\n", 116 | " # p_transition: new_state x old_state\n", 117 | " # np.tile(...) * p_transition: new_state x old_state\n", 118 | " all_h_ts = np.tile(\n", 119 | " message_and_emission.reshape(-1, 1),\n", 120 | " (1, num_hidden_states)\n", 121 | " ) * p_transition\n", 122 | " \n", 123 | " # the message is the value from the highest h_t\n", 124 | " messages[t - 1] = np.max(all_h_ts, axis=0)\n", 125 | " \n", 126 | " # and normalize\n", 127 | " messages[t - 1] /= np.sum(messages[t - 1])\n", 128 | " \n", 129 | " # now from the beginning! 
compute h_t* using these messages\n", 130 | " \n", 131 | " # argmax will give us the state.\n", 132 | " # argmax p(v_1|h_1)p(h_1)mu(h_1)\n", 133 | " most_likely_states[0] = np.argmax(\n", 134 | " p_hidden_start \n", 135 | " * p_emission[visibles[0]] \n", 136 | " * messages[0]\n", 137 | " )\n", 138 | " \n", 139 | " for t in range(1, num_timestamps):\n", 140 | " # argmax_h_t p(v_t|h_t)p(h_t|h_{t - 1})mu(h_t)\n", 141 | " most_likely_states[t] = np.argmax(\n", 142 | " p_emission[visibles[t], :]\n", 143 | " * p_transition[:, most_likely_states[t - 1]] \n", 144 | " * messages[t]\n", 145 | " )\n", 146 | " \n", 147 | " return most_likely_states\n", 148 | "\n", 149 | "most_likely_states = viterbi(\n", 150 | " prev_post.visibles, \n", 151 | " prev_post.p_hidden_start,\n", 152 | " prev_post.p_transition,\n", 153 | " prev_post.p_emission,\n", 154 | ")\n", 155 | "\n", 156 | "print(most_likely_states)" 157 | ] 158 | }, 159 | { 160 | "cell_type": "markdown", 161 | "metadata": {}, 162 | "source": [ 163 | "## Visualizing\n", 164 | "\n", 165 | "I can plot the most-likely states." 
166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": 4, 171 | "metadata": {}, 172 | "outputs": [ 173 | { 174 | "data": { 175 | "image/png": "iVBORw0KGgoAAAANSUhEUgAABIAAAAEYCAYAAAAkpY3TAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAHa9JREFUeJzt3XuUZWdZJ+DfW1Xp7lwgITdCEkgLhBAY1JlEbsIICwYS\nghJZLHQQIiAoumR0vIyIAxNmeZ1BHB1xQB25yEXQ0ahESIYlMkRASDuCMoFAkk4CCSFNEjrXvlR9\n80dVa9F2c77q7upTe9fzrFVrde16691f1Vt71zm/3vtUtdYCAAAAwHjNTHsBAAAAAKwuARAAAADA\nyAmAAAAAAEZOAAQAAAAwcgIgAAAAgJETAAEAAACMnAAIAAAAYOQEQAAAAAAjJwACAAAAGLm5aS9g\naKqqTXsN60VrrQ62x4knntg2b958CFZz+GzZsmXaSzhQ21prJx1MgyEeX+ecc860l3BAtmzZctDz\nSszscNm6dWu2bdt20OfEIc5rwJwTB8Q5cVicEwfpkBxjHtsfPp6LDU7XMSYAYtQ2b96cK6+8ctrL\nWJGqgz7XTsv1017ANAzt52uPqlqX80qGObNzzz132ktg5dblMTbE4ytxThwa58RBOiTHmMf2w2Je\nh1XXMeYWMAAAAICREwABAAAAjJwACAAAAGDkBEAAAAAAIycAAgAAABg5ARAAAADAyAmAAAAAAEZO\nAAQAAAAwcgIgAAAAgJETAAEAAACMnAAIAAAAYOQEQAAAAAAjJwACAAAAGDkBEAAAAMDICYAAAAAA\nRk4ABAAAADByAiAAAACAkRMAAQAAAIycAAgAAABg5A4qAKqqN1XVazrqPlNVT1n698VV9Y5vULu1\nqp5+MOsCAAAA4J/MfaMPVtUHknyitfbavbY/J8mbk5zeWts9aSettUcf1CoBAAAAOGCTrgB6W5IX\nVlXttf1FSd7ZE/4AAAAAMF2TAqBLkpyQ5Ml7NlTVA5I8O8nbq+qtVfXzS9tPrKr3VdUdVXVbVX2k\nqmaWPrb3bV2bquo9VXVnVf1tVX3LvnZeVTNV9aqquqaqvlpV762q45c+tqmq3rG0/Y6q+mRVPfDA\nvxUAAAAA4/QNA6DW2r1J3pvkomWbn5/ks621T+1V/pNJvpjkpCQPTPLqJG0/rZ+T5A+THJ/kXUku\nqaoj9lH3yiQXJvmOJKcmuT3JG5c+9v1Jjk3y4CyGVK9Icu83+noAAAAA1qOeF4F+W5LnVdWmpfcv\nWtq2t11JHpTkjNbartbaR1pr+wuAtrTW/qi1tivJG5JsSvL4fdS9IsnPtda+2FrbkeTipbXMLe3v\nhCQPb63Nt9a2tNa2d3w9AAAAAOvKxACotXZFkm1JLqyqhyV5bBav2tnbf03yhSSXV9W1VfWqb9D2\nxmX9F7J45dCp+6g7I8mfLN3idUeSq5LMZ/EKo99PclmSP6iqm6rqv+znKiIAAACAda33z8C/PYtX\n/rwwyWWttVv2Lmit3dla+8nW2kOTfFeSn6iqp+2n34P3/GPpdYJOT3LTPupuTHJ+a+24ZW+bWmtf\nWrrK6HWttUcleWIWX5foon30AAAAAFjXVhIAPT3Jy7Pv279SVc+uqocv/cWwr2XxSp2F/fQ7p6qe\nu3Qr148n2ZHk4/uoe1OSX6iqM5b2cdLSn6BPVT21qh5TVbNJtmfxlrD97Q8AAABg3eoKgFprW5N8\nNMnRSf5sP2VnJvlgkruSfCzJb7XWPrSf2j9N8j1ZfFHnFyV57tLrAe3t15f2d3lV3ZnFkOhxSx87\nJckfZTH8uSrJh7N4WxgAAAAAy8z1FrbWnrKPbS9e9
u9fS/Jr+/nczcv+ffGE/SyvXcjii0S/YR91\n707y7gnLBgAAAFj3em8BAwAAAGCgBEAAAAAAIycAAgAAABg5ARAAAADAyAmAAAAAAEZOAAQAAAAw\ncgIgAAAAgJETAAEAAACMnAAIAAAAYOQEQAAAAAAjJwACAAAAGDkBEAAAAMDICYAAAAAARq5aa9Ne\nw6BU1a1Jrp/2OtaBM1prJx1sE/M6rA56ZuZ1WDnGhsW8hsc5cVgcY8NiXsNjZsNiXsPTNTMBEAAA\nAMDIuQUMAAAAYOQEQAAAAAAjJwACAAAAGDkBEAAAAMDICYAAAAAARk4ABAAAADByAiAAAACAkRMA\nAQAAAIycAAgAAABg5ARAAAAAACMnAAIAAAAYOQEQAAAAwMgJgAAAAABGTgAEAAAAMHICIAAAAICR\nEwABAAAAjJwACAAAAGDkBEAAAAAAIycAAgAAABg5ARAAAADAyM1NewHrQVW1aa9hira11k6a9iJW\nysyGNbPVmNc555zTVbdly5ZDveuVGty8kukeY72zTVZnvq21OuRNV9lQzomrdNwO7hgbyrx6rfCY\nHdy8kvHNbIUGNzPzMq8BGdy8EjNbzZkJgFht1097AayYmSW58soru+qqpv5c3rxWqHe2yZqYLyuw\nSsetY2zKVnjMmtfwmNmwmNewmNfwrOrM3AIGAAAAMHICIAAAAICREwABAAAAjJwACAAAAGDkBEAA\nAAAAIycAAgAAABg5ARAAAADAyAmAAAAAAEZubtoLgPWitdZVV1WrvBJ6mMN4me14me04mevw9D7m\nScwX4HByBRAAAADAyAmAAAAAAEZOAAQAAAAwcgIgAAAAgJETAAEAAACMnACoQ1U9paq+OO11AAAA\nAByIwQRAVfVXVXV7VW3sqN1cVa2q/Jl7AAAAYN0bRABUVZuTPDlJS/JdU10MAAAAwMAMIgBKclGS\njyd5a5Lv37Oxqo6sql+tquur6mtVdUVVHZnk/yyV3FFVd1XVE6rq4qp6x7LP/bqrhKrqJVV1VVXd\nWVXXVtUPHb4vDwAAAGD1DOUWqYuSvCHJ3yT5eFU9sLV2S5LXJ3l0kicm+XKSxyVZSPKvk1yX5LjW\n2u4kqapnTtjHV5I8O8m1S5///qr6ZGvtb1fh6zlorbWuuqpa5ZXQyywAgPXAY55h6X1ekZjtavL8\njsNhzV8BVFVPSnJGkve21rYkuSbJC6pqJslLk/xYa+1LrbX51tpHW2s7DmQ/rbVLW2vXtEUfTnJ5\nFm87AwAAABi0NR8AZfGWr8tba9uW3n/X0rYTk2zKYiB00Krq/Kr6eFXdVlV3JHnW0j4AAAAABm1N\n3wK29Ho+z08yW1VfXtq8MclxSR6U5L4kD0vyqb0+dV/Xz92d5Khl75+ybD8bk/yvLN5q9qettV1V\ndUkS19cBAAAAg7fWrwC6MMl8kkcl+dalt7OTfCSLYc3vJXlDVZ1aVbNLL/a8McmtWXwtoIcu6/V3\nSf51VT2kqo5N8rPLPrYhi8HSrUl2V9X5SZ6xul8aAAAAwOGx1gOg70/yltbaDa21L+95S/KbSb4v\nyauS/H2STya5LcmvJJlprd2T5BeS/HVV3VFVj2+t/e8k70ny6SRbkrxvz05aa3cm+XdJ3pvk9iQv\nSPJnh+uLBAAAAFhNtZJXfefAVNUh/yYP6FXit7TWzp32IlZqNWY2IIObmXkNa17J+p5Za23qJ+aV\nWs/zygCPMfMa1rwSMxvazMY2rxX+FTDzWiWr9PxucPNKhjOzVbKqM1vrVwABAAAAcJAEQAAAAAAj\nJwACAAAAGDkBEAAAAMDIzU17ARyYNfDizgAAwMB5XrE2mMPwDOgPM/0jVwABAAAAjJwACAAAAGDk\nBEAAAAAAIycAAgAAABg5ARAAAADAyAmAAAAAAEZOAAQAAAAwcgIgAAAAgJETAAEAAACM3Ny0FwAA\nAAAwJFU17SWsm
CuAAAAAAEZOAAQAAAAwcgIgAAAAgJETAAEAAACMnAAIAAAAYOQEQAAAAAAjJwAC\nAAAAGDkBEAAAAMDICYAAAAAARk4ABAAAADByc9NewDqxLcn1017ElJwx7QUcIDMbFvManvU6M/Ma\nniHOzLyGx8yGxbyGxbyGx8xWSbXWVrM/AAAAAFPmFjAAAACAkRMAAQAAAIycAAgAAABg5ARAAAAA\nACMnAAIAAAAYOQEQAAAAwMgJgAAAAABGTgAEAAAAMHICIAAAAICREwABAAAAjJwACAAAAGDkBEAA\nAAAAIycAAgAAABg5ARAAAADAyAmAAAAAAEZOAAQAAAAwcnPTXsB6sGH2yHbkEcdOrDvzrDu6e37+\nc8cdzJIOev+9tnx6x7bW2kmHvPEq21Ab26YcPbmwVtC0HfBy9r/7I47oLOxf6PadtwxuZhtmNrUj\nZ46ZWNfmF7p77nxQx/yTtBWcRTfduruvcMfO7p7b222Dm1eydF6cm3xeTOs/cHacuKGvcAXH7cZb\n7uuqawt9P1v3tbuzs923kjPHmrChNrZNNfmY2HBW//8r7fpCb23/z8B9p27sqtt08wqOsV23Du4Y\nm73f0W3uxAdMrNu49Z7unjtP6zsnzvR/azN3691ddTXT/3O1feGrg5tXkmyoTV3H2I7TjuruefQx\nfeevh264s7vn5z/b9/iz7ez/Qbgztw9uZouPO+43sW7HyZu6e278auf3bAWPZe47rfOceNOO7p7b\nd28b3rw6H9fXbP+5pvsx5Qp+41dv8YbOx/9Jtu8Y3uP6ZAXPxaas5vqeCLTdnc8BsvrnRAHQYXDk\nEcfmiae/aGLdpZdd0t3zgiddeDBLOuj995p90BeuP+RND4NNOTqPm3n6xLqane3uuZIDv9fcKaf1\nFR7Rf6h/4NpfHdzMjpw5Jk845jkT6+a3b+/uecMPPrGrbscJ/Q/EHvlb27rq2vVf7O55+b3vGNy8\nkuTIuWPzxNO+b3Lhrv7j5roXb+6qm9/UHyg87A2f7apbuKfvifTHd7y/e99ryaY6Oo+fe+bEutPe\n1v9k5+bndD6RnZ/v7nnVzz2sq+5Rr7uhu+cHbvrNwR1jcyc+IA963Y9OrDvzJX/b3fO6H318V90x\nN3a3zElv/kRX3czR/aHH5dvfMrh5JUvH2BHnTay75sfO6e75bd/ed/561zd9qLvnBd8++Xdtkuze\n2n+MfXDhDwc3syNn7pcnHPvdE+u2vuzs7p7f9Pt9B0/bfld3z6tefWZX3dmvuaa752W3vnlw89qU\no/O42WdMrJs9pj9wmL+rN8DuT4B6w4SZU0/p7vmBa14/uHklSzOrpx3apr3/Ib6C/3ycPfHkrrr5\nW77S3fOD7Y9WdWZuAQMAAAAYOQEQAAAAwMgJgAAAAABGTgAEAAAAMHICIAAAAICREwABAAAAjJwA\nCAAAAGDkBEAAAAAAIzc37QWsB2eedUcuveySiXUXPOnCw7Cag9//pVdM/loGr5KanZ1Y1hZad8vr\nfvkJB7Oifbr6ov/RVfesf/M9h3zfa8rcEckpJ00su+a1j+pu+bCf/lhXXc/PyR4L1Ze5zzzim7p7\n5h/6S9eU3buz8NXbJ5bN/4uHdrd88C/2zWwl5lvfMV5zvb9O+88Za8n88Ufntu/8tsmFT97S3XP2\nhKO76lbyHXvEj3Tu/+QTV9B1eDZuvSdnvnjy9+Kym/6uu+czT12Fn92qrrKFu+859PteY3adfFRu\n/r5zJ9ad+fP9J/3b5+e76p557zndPbNwfVdZ/zkxyUJ/6ZrRWtrOnRPLNv/OF/pb7tzVV7iC7+0j\nfuSTXXV9PynDVbMzmT1m8u+chXv6zzXvu/ETXXXnv/Dl3T0/+M7f66o774zHdvccqpqdzez9j51Y\nN/+17d09Z775kV1173//u7t7nvddL+yqm13Bc8Z8pb/0QLgCCAAAAGDkBEAAAAA
AIycAAgAAABg5\nARAAAADAyAmAAAAAAEZuTQdAVfWQqrqrqmaX3v+rqnrZIez/5Kr63KHqBwAAALAWrZkAqKq2VtW9\nS4HPXVV1V5LdrbVjWmv/7K8TVtWLq+qKg9lna+0jrbWzDqYHAAAAwFq3ZgKgJd+5FPjsebtptXZU\nVXOr1RsAAABgLVlrAdDXqarNVdX2Dmuq6uwkb0ryhKWrhe5Y2r6xql5fVTdU1S1V9aaqOnLpY0+p\nqi9W1c9U1ZeTvGXPtmV9t1bVT1XVp6vqa1X1nqratOzj/6Gqbq6qm6rqZUtre/hh+WYAAAAAHKBB\nXgXTWruqql6R5GWttSct+9AvJ3lYkm9NsivJu5K8NsnPLn38lCTHJzkji+HX4/bR/vlJzktyX5K/\nTvLiJG+qqvOS/ESSpyW5Lslv96738587Lhc86cKJdZdecUlvy1XRs8b14oRH7cwL/njrxLo/eMJj\nunue+Zs3HMSK9u3S522aXJSk7t1xyPe9lrQdOzJ/9TUT6x7xP1eQeZ9wfN++T39gd8t21eQ1Jkl2\n7uruOVRtYSELd98zsW7uM9f199y4sa9u9+7+np21tWFDX938mv5/l/2ave3uHP/OT06s+8ANV3b3\nfOpLX95Vd+THr+7umbbQVbaw/c7+nkNVNbHkvIec291u9oT7d9Ut3HV3d8+2o/d3U99ch+yIW+7O\nKf/tYxPrbv/efT103bfj3veZg1nSQVnJeXaQ5uYyc9IJE8tueN7p3S3vemjf9+yRP72CubbWVzcz\n299zgHZu3pAbXv/giXUP+Y/9j78uPPfErrojTr63u+cF/+qZXXVt1y3dPQdrbjbpOMZmawWPq778\n1a6y8854bHfLh3/08111n3/szu6eq22tPRK9pKruWHpbURpSVZXkB5P8+9baba21O5P8YpLvXVa2\nkOQ/tdZ2tNb2dzT+RmvtptbabUn+PIthUrIYDL2ltfaZ1to9SS5eyfoAAAAApmWtXQF0YWvtg3ve\nqarNK/jck5IclWRL/dP/elWS5ZH2ra21+yb0+fKyf9+T5NSlf5+aZPl/bd64grUBAAAATM1aC4BW\nYu9rGrcluTfJo1trX+r8nJW4Ocny6zgnX0cIAAAAsAastVvAVuKWJKdX1YYkaa0tJPmdJL9WVScn\nSVWdVlV9N1NO9t4kL6mqs6vqqCSvOUR9AQAAAFbVkAOgv0zymSRfrqptS9t+JskXkny8qrYn+WCS\nsw7Fzlpr70/yG0k+tGcfSx8a96vrAgAAAIO3Zm4Ba61t3se2rVl8HZ897z9l2b93Jrlgr/r7krx6\n6W3vXn+Vr7+F659t23sNrbWL93r/l5L8UvKPf4p+IYu3hgEAAACsWUO+Auiwq6rvrqqNVfWAJL+S\n5M9bayP/O5cAAADA0AmAVuaHknwlyTVJ5pP88HSXAwAAADDZmrkFbAhaa+dNew0AAAAAK+UKIAAA\nAICRcwXQGnLBky485D0vveKSQ95zPfjqZzbknY88fXJhbu9ventf7ewDHtDd8su7juuqa9vv7O45\nRI/45nty2WV/N7HugnNO6e65cNrJnYUL3T3brr6XDNv1wPt398zV/aVrSc3OZvbYyV/nX3zmQ909\nn/bCH+iqO+Ijf9/dM1WTa5Is3Nf3ByHbCn5e1pLasDEzm8+YWHf+WUd29zzqxK901e18zEO7e858\n5P921dXcuB/+1JGbMvPIsyfWffZHj+ru+fC39p2/Zv76U909u7V26HuuMQ95zF3575deMbHu2l3/\nr7vneb/ad146/6wnd/dcuHPcjye6zc+n3fG1iWWnXtH3OC1JZn/3+q66G175Ld09T/vlj/YVLsx3\n9xyiDdftzINftHVi3fw993T3rLkjuuraLbd29/zcr5/bVXfmK2/p7jlUbcfOzH/huo7CQ//7Ye5B\n/c8Xrn1a38/MzMa+x5NJknv7Sw+EK4AAAAA
ARk4ABAAAADByAiAAAACAkRMAAQAAAIycAAgAAABg\n5ARAAAAAACMnAAIAAAAYOQEQAAAAwMgJgAAAAABGbm7aC1gPzjzrjlx62SUT6y540oXdPXdfu7Wr\nbto9k9evoHaNqZpc01p/v5nZrrKbX3B2d8s/fMydXXXX/udHdPfMz/aXrhVXf/qoPPPUb51YN3v/\ne7p71vX3dtXN3/G17p67nnFuV90Rl1/Z3XOwZmdTx95vYtkFj3t2d8sjbv5UV13bvbu7Z7c2f+h7\nriFttrJw9KaJdTMPOrm75+6rr+mqm7nu+u6evea3bz/kPdeSdt99af9w9cS6R/7Gw7t7LvzD5zt3\nvoLfi/yjG686Lj/+2O+eWPfGT/xxd89n/cvndtUt3PmV7p4sOvPRd+YvLvvwxLpnPfqp3T3nt9/V\nVXfar3ysuyeLWmtpO3dNrJs56qj+njt2dNVd/cZzuns+4of/pruWlZu9//276ua33dbdc+b447rq\nFtbQ4w5XAAEAAACMnAAIAAAAYOQEQAAAAAAjJwACAAAAGDkBEAAAAMDICYAAAAAARk4ABAAAADBy\nAiAAAACAkRMAAQAAAIycAAgAAABg5Kq1Nu01jF5V3Zrk+mmvY0rOaK2dNO1FrJSZDWtm5jWseSXr\nembmNTyDm5l5DWteiZkNbWbmZV4DMrh5JWa2mjMTAAEAAACMnFvAAAAAAEZOAAQAAAAwcgIgAAAA\ngJETAAEAAACMnAAIAAAAYOQEQAAAAAAjJwACAAAAGDkBEAAAAMDICYAAAAAARk4ABAAAADByAiAA\nAACAkRMAAQAAAIycAAgAAABg5ARAAAAAACMnAAIAAAAYOQEQAAAAwMjNTXsB60FVtWmvYYq2tdZO\nmvYiVsrMhjWzac7rnHPO6a7dsmXLaixhcPNKVmdmvbNYpTl0a63VVBdwAHrntQaOh9UwuGNsPR9f\nGeC8Er/HhjazoTxOXKXj1ryGZXDzSsxsNWdWra3n7+3hsc5/gLe01s6d9iJWysyGNbNpzmsl59Cq\nVXneP7h5Jaszs95ZrNIcuo05AFoDx8NqGNwxtp6PrwxwXonfY0Ob2VAeJ67ScWtewzK4eSVmtpoz\ncwsYAAAAwMgJgAAAAABGTgAEAAAAMHICIAAAAICREwABAAAAjJwACAAAAGDkBEAAAAAAIycAAgAA\nABi5uWkvAPh6rbXu2qpaxZXQwwzWDrOYPjMYL7MdL7MdL7Mdlt7nAOa6usb+XMwVQAAAAAAjJwAC\nAAAAGDkBEAAAAMDICYAAAAAARk4ABAAAADBygw2AqupNVfWaQ9hvc1W1qtrnX0arqldX1e8eqv0B\nAAAAHC5r9s/AV9UHknyitfbavbY/J8mbk5zeWtu9tO0pSd7RWjt9tdbTWvvF1eoNAAAAsJrW8hVA\nb0vywqqqvba/KMk794Q/h8L+rvoBAAAAGIO1HABdkuSEJE/es6GqHpDk2UneXlVvraqfr6qjk7w/\nyalVddfS26lVNVNVr6qqa6rqq1X13qo6fqnPntu9fqCqbkjyl8v2+9Kquqmqbq6qn1q274ur6h2H\n4wsHAAAAOJTWbADUWrs3yXuTXLRs8/OTfLa19qlldXcnOT/JTa21Y5bebkryyiQXJvmOJKcmuT3J\nG/fazXckOTvJM5dte2qSM5M8I8nPVNXTD+kXdpi11rrfWBuqqvsNAAAYPo//14axPxdbswHQkrcl\neV5VbVp6/6KlbT1ekeTnWmtfbK3tSHLxUq/lt3td3Fq7eyls2uN1S9v+Pslbkvzbg/sSAAAAAKZr\nTb/2TWvtiqraluTCqvpkkscmeW7np5+R5E+qamHZtvkkD1z2/o37+Lzl265P8pgVLBkAAABgzVnr\nVwAlyduzeOXPC5Nc1lq7ZR81+7p/6cYk57fWjlv2tqm19qUJn/fgZf9+SJKbDnThAAAAAGvBUAKg\npyd5efZ
/+9ctSU6oqmOXbXtTkl+oqjOSpKpOWvoT8pO8pqqOqqpHJ3lJkvcc+NIBAAAApm9N3wKW\nJK21rVX10STfkuTP9lPz2ap6d5Jrq2o2yaOS/HqSSnJ5VZ2a5CtZDHP+dMIuP5zkC1kMx17fWrv8\n0HwlAAAAANOx5gOgJGmtPWUf21681/sv3cenvmHpbe/P3ZrFcGh/2357H59zcd9qAQAAANaWIdwC\nBgAAAMBBEAABAAAAjJwACAAAAGDkBEAAAAAAIzeIF4HmwFXV5CIAAABg1FwBBAAAADByAiAAAACA\nkRMAAQAAAIycAAgAAABg5ARAAAAAACMnAAIAAAAYOQEQAAAAwMgJgAAAAABGTgAEAAAAMHICIAAA\nAICREwABAAAAjJwACAAAAGDkBEAAAAAAIycAAgAAABg5ARAAAADAyAmAAAAAAEZOAAQAAAAwcgIg\nAAAAgJETAAEAAACMnAAIAAAAYOTmpr2AdWJbkuunvYgpOWPaCzhAZjYs5jU863Vm5jU8Q5yZeQ2P\nmQ2LeQ2LeQ2Pma2Saq2tZn8AAAAApswtYAAAAAAjJwACAAAAGDkBEAAAAMDICYAAAAAARk4ABAAA\nADByAiAAAACAkRMAAQAAAIycAAgAAABg5ARAAAAAACP3/wGTk9cofOnEdwAAAABJRU5ErkJggg==\n", 176 | "text/plain": [ 177 | "" 178 | ] 179 | }, 180 | "metadata": {}, 181 | "output_type": "display_data" 182 | } 183 | ], 184 | "source": [ 185 | "fig, all_axs = plt.subplots(4, prev_post.timesteps, figsize=(16, 4))\n", 186 | "all_axs = all_axs.T\n", 187 | "\n", 188 | "VISIBLES = 0\n", 189 | "TRUE_STATES = 1\n", 190 | "FILTERING = 2\n", 191 | "VITERBI = 3\n", 192 | "\n", 193 | "all_axs[0][VISIBLES].set_title('Visibles', x=-0.5, y=0.2)\n", 194 | "all_axs[0][TRUE_STATES].set_title('Actual', x=-0.5, y=0.4)\n", 195 | "all_axs[0][FILTERING].set_title('Filtering', x=-0.5, y=0.4)\n", 196 | "all_axs[0][VITERBI].set_title('Viterbi', x=-0.5, y=0.4)\n", 197 | "\n", 198 | "for i, (axs, hidden, visible, alpha, viterbi) in enumerate(zip(\n", 199 | " all_axs, \n", 200 | " prev_post.hiddens, \n", 201 | " prev_post.visibles, \n", 202 | " prev_post.alphas,\n", 203 | " most_likely_states,\n", 204 | ")):\n", 205 | " axs[VISIBLES].imshow([prev_post.map_visible_state_to_bump_creak[visible]], cmap='gray', vmin=0)\n", 206 | " hide_ticks(axs[VISIBLES]) \n", 207 | " \n", 208 | " axs[TRUE_STATES].imshow(prev_post.plot_state_in_room(hidden), cmap='gray')\n", 209 | " hide_ticks(axs[TRUE_STATES])\n", 210 | " \n", 211 | " axs[FILTERING].imshow(alpha.reshape(prev_post.height, prev_post.width))\n", 212 | " 
hide_ticks(axs[FILTERING]) \n", 213 | "\n", 214 | " axs[VITERBI].imshow(prev_post.plot_state_in_room(viterbi), cmap='gray')\n", 215 | " hide_ticks(axs[VITERBI]) \n", 216 | " \n", 217 | "maybe_save_plot('2018-05-13-viterbi')\n", 218 | "plt.show()" 219 | ] 220 | }, 221 | { 222 | "cell_type": "markdown", 223 | "metadata": {}, 224 | "source": [ 225 | "## See Also\n", 226 | "\n", 227 | " - [This post](2018-05-13-viterbi-message-passing.ipynb) builds on this post to show the Viterbi algorithm.\n", 228 | " - [The more-detailed alpha recursion HMM post](2018-05-02-hmm-alpha-recursion.ipynb).\n", 229 | " - [This notebook](https://github.com/jessstringham/notebooks/blob/master/2018-05-13-hmm-check-results.ipynb) runs this code using the same example from Barber." 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": null, 235 | "metadata": { 236 | "collapsed": true 237 | }, 238 | "outputs": [], 239 | "source": [] 240 | } 241 | ], 242 | "metadata": { 243 | "kernelspec": { 244 | "display_name": "Python 3", 245 | "language": "python", 246 | "name": "python3" 247 | }, 248 | "language_info": { 249 | "codemirror_mode": { 250 | "name": "ipython", 251 | "version": 3 252 | }, 253 | "file_extension": ".py", 254 | "mimetype": "text/x-python", 255 | "name": "python", 256 | "nbconvert_exporter": "python", 257 | "pygments_lexer": "ipython3", 258 | "version": "3.6.1" 259 | } 260 | }, 261 | "nbformat": 4, 262 | "nbformat_minor": 2 263 | } 264 | -------------------------------------------------------------------------------- /2018-05-25-quick-example-of-dirichlet-distribution.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Samples from Dirichlet distribution\n", 8 | "\n", 9 | "The Dirichlet distribution is a distribution over distributions!\n", 10 | "In Bayesian methods, it is used as a prior for categorical and multinomial 
distributions. The Dirichlet distribution appears in natural language processing in [Latent Dirichlet allocation](https://en.wikipedia.org/wiki/Latent_Dirichlet_allocation) and [Bayesian HMMs](http://www.aclweb.org/anthology/P07-1094).\n", 11 | "\n", 12 | "In this quick post, I'll sample from `pymc3`'s `Dirichlet` distribution using different values of concentration parameters and plot the corresponding distributions. \n", 13 | "\n", 14 | "For more detailed information:\n", 15 | " - [Wikipedia article](https://en.wikipedia.org/wiki/Dirichlet_distribution).\n", 16 | " - [This document](http://mayagupta.org/publications/FrigyikKapilaGuptaIntroToDirichlet.pdf) seems to give a good background on the mathematics." 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 1, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "import warnings\n", 26 | "warnings.filterwarnings('ignore')\n", 27 | "\n", 28 | "import pymc3 as pm\n", 29 | "\n", 30 | "import matplotlib.pyplot as plt\n", 31 | "import numpy as np" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 2, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "# helper functions you can skip over :D\n", 41 | "def hide_ticks(plot):\n", 42 | " plot.axes.get_xaxis().set_visible(False)\n", 43 | " plot.axes.get_yaxis().set_visible(False)\n", 44 | "\n", 45 | "SAVE = True\n", 46 | "def maybe_save_plot(filename):\n", 47 | " if SAVE:\n", 48 | " plt.tight_layout()\n", 49 | " plt.savefig('images/' + filename, bbox_inches=\"tight\")" 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "metadata": {}, 55 | "source": [ 56 | "Eh, a story might help explain. I'm randomly choosing one of three types of cookies, and I might want to weight some cookies higher than others. I don't know exactly how they should be weighted, but I have an idea of how I want them to be weighted. 
For example, it might be one of:\n", 57 | "\n", 58 | " - One type of cookie (I don't know which) is more frequent than the others. \n", 59 | " - Each type of cookie is equally likely.\n", 60 | " - One particular cookie type is weighted higher than the others.\n", 61 | " - I have no idea. \n", 62 | "\n", 63 | "By adjusting the concentration parameters `a` in the `pymc3.Dirichlet` distribution, I can add this information to my prior.\n", 64 | "\n", 65 | "To try this out, I'll sample a few examples from four Dirichlet distributions with different concentration parameters. \n", 66 | "\n", 67 | "Heads up, in Bayesian statistics, the prior involves the distribution that these samples come from, not the samples individually. If you saw my post about [Bayesian linear regression](2018-01-03-bayesian-linreg.ipynb), each of these samples is like a single line from the samples from the prior. The real prior distribution includes all of them!" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 3, 73 | "metadata": {}, 74 | "outputs": [ 75 | { 76 | "name": "stderr", 77 | "output_type": "stream", 78 | "text": [ 79 | "Only 20 samples in chain.\n", 80 | "Auto-assigning NUTS sampler...\n", 81 | "Initializing NUTS using jitter+adapt_diag...\n", 82 | "Multiprocess sampling (2 chains in 2 jobs)\n", 83 | "NUTS: [dir_skewed_stickbreaking__, dir_ten_stickbreaking__, dir_one_stickbreaking__, dir_tenth_stickbreaking__]\n", 84 | "100%|██████████| 520/520 [00:09<00:00, 53.02it/s]\n", 85 | "There were 3 divergences after tuning. Increase `target_accept` or reparameterize.\n", 86 | "The acceptance probability does not match the target. It is 0.939145954445564, but should be close to 0.99. Try to increase the number of tuning steps.\n", 87 | "There were 5 divergences after tuning. 
Increase `target_accept` or reparameterize.\n" 88 | ] 89 | } 90 | ], 91 | "source": [ 92 | "# how many categories to sample from\n", 93 | "values = 3\n", 94 | "\n", 95 | "with pm.Model() as model:\n", 96 | " # borrowing these numbers from http://mayagupta.org/publications/FrigyikKapilaGuptaIntroToDirichlet.pdf\n", 97 | " pm.Dirichlet('dir_tenth', a=0.1 * np.ones(values))\n", 98 | " pm.Dirichlet('dir_one', a=1 * np.ones(values))\n", 99 | " pm.Dirichlet('dir_ten', a=10 * np.ones(values))\n", 100 | " pm.Dirichlet('dir_skewed', a=np.array([2, 5, 15]))\n", 101 | " \n", 102 | " # just sample a small number.\n", 103 | " trace = pm.sample(20, nuts_kwargs=dict(target_accept=.99))" 104 | ] 105 | }, 106 | { 107 | "cell_type": "markdown", 108 | "metadata": {}, 109 | "source": [ 110 | "For each of the four distributions, I'll grab ten samples. Each sample is itself a distribution! I asked for three numbers, so each sample is three numbers that sum to one.\n", 111 | "I'll plot these ten sampled distributions as a stacked bar plot. They should all sum to one. They will also reflect the concentration parameter values:\n", 112 | "\n", 113 | " - `dir_tenth`: One type of cookie (I don't know which) is more frequent than the others. \n", 114 | " - `dir_one`: I have no idea. The *probabilities* are chosen from a uniform distribution.\n", 115 | " - `dir_ten`: Each type of cookie is equally likely.\n", 116 | " - `dir_skewed`: One particular cookie type is weighted higher than the others.\n", 117 | "\n", 118 | "These can be used as the values of a discrete probability distribution, for example as the probabilities in a multinomial distribution." 
119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": 4, 124 | "metadata": {}, 125 | "outputs": [ 126 | { 127 | "data": { 128 | "image/png": "iVBORw0KGgoAAAANSUhEUgAABHgAAAEYCAYAAAAnPkG+AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAEPtJREFUeJzt3X2MZXddx/HPF7ftClQQQbSV7a4kYkmWyB9IjUAbNRKEgkFAXMWFP1SMDzEBYX34Yw00FOIfaoRIiELDoyhigCWmiGHlSZQoOvGhBnRgAfEJ2iAgCP35xz2lw6a7zOzM3XO/c1+vpNmZnXvP+d12zrdn33Pu2RpjBAAAAIC+7jb3AgAAAADYHYEHAAAAoDmBBwAAAKA5gQcAAACgOYEHAAAAoDmBBwAAAKA5gQcAAACgOYFnTVXVK6rq+VX1yKq6Ze713JWqenpVvWvudQDL1WEeAfuP2QPs1jLnSFWdrKpX7eU2d+OO1zr3Ojg/gWfNjTHeOcZ40IU8t6o2q+r79mIdVXW4qkZVHdiL7QH97GYeAVyoVTkXAvpyDsOqEHg4J7EFWBXmETAHswfYLXOEi0ngWRNV9dCq+uuq+nRV/X6Sg9PvX1dVH93yuM2qem5V/V2Sz5xrIFXVK5McSvLmqvqfqnrO9PvXVNV7qurWqvrbqrpuy3PeUVXPq6p3T+u4uaruO335z6dfb522911bnvfrVfWpqvrXqnrMHv5rAWaw1/NoeuzV04y5tar+vqoev+Vrr6iqF1fVqWmf76uqB275+rdX1duq6pNVdUtVPWU5rxyYU4NzIWDFLekc5rlV9bFpm7dU1ffexWMuqarXVtUbqurSqrpbVZ2oqg9V1X9X1eur6j7TY2+qqmdNH185vUviZ6bPHzid79xt+vxxVfWBaV69p6oe8tVeK6tN4FkDVXVpkj9O8sok90nyB0l+6DxP+ZEkj01y7zHGF+/qAWOMpyX5SJLrxxj3HGO8qKquTHIqyfOn/Tw7yRuq6n5bnnosyTOSfGOSS6fHJMmjpl/vPW3vvdPnD09yS5L7JnlRkt+tqtr2iwdWyjLmUVVdkuTNSW7OYrb8XJJXV9XWS6WfmuTXknx9kg8muWF67j2SvC3Ja6bnPjXJS6rqwRf4EoEV1ORcCFhhSzqHeVCSn03ysDHG5UkenWTzrMd87bTfzyd5yhjjC1mc6/xgkmuTXJHkU0lePD3ldJLrpo+vTfIvufPPWtcmeecY4/aqemiS30vyU0m+IclLk7ypqi67gNfKihB41sM1SS5J8htjjP8bY/xhkr86z+N/a4xxZozxuR3u58eSvHWM8dYxxu1jjLcleX+SH9jymJePMf552vbrk3zHV9nmh8cYLxtjfCnJTUm+Ocn9d7guYHUsYx5dk+SeSW4cY3xhjPFnSd6SxYnVHd44xvjL6QTr1blz9jwuyeYY4+VjjC+OMf4myRuSPPnCXh6wojqfCwGrYRlz5EtJLkvy4Kq6ZIyxOcb40Javf12SP0nyoSTPmP5MlCTPTPIrY4yPjjE+n+RkkidNVwqdTvKI6SqdR2XxQ/Lvnp537fT1JPnJJC8dY7xvjPGlMcZNWUSkay7gtbIiBJ71cEWSj40xxpbf+/B5Hn/mAvdzVZInT5f43VpVtyZ5RBZR5g6f2PLxZ7P4Q9n5fPnxY4zPTh9+tecAq2sZ8+iKJGfGGLeftc0rt3x+rtlzVZKHnzW3fjTJN21jv0Afnc+FgNWw53NkjPHBJL+QRaD5j6p6XVVdseUh1yR5SBY/xNq636uSvHHLnPnHLGLR/adA9Jks4vEjs/ih18e
nq4W2Bp6rkjzrrHn1gOl17vS1siIEnvXwb0muPOutTYfO8/hxnq+d73FnkrxyjHHvLf/cY4xx4wVsC9ifljGPPp7kAXe8n3zLNj+2jeeeSXL6rLl1zzHGT2/juUAfHc6FgNW2lDkyxnjNGOMRWQSXkeSFW758c5IXJHl7VW19F8OZJI85a9YcHGPcce5zOsmTklw6/d7pJMezeKv6B7Zs44aztnH3McZrL+C1siIEnvXw3iRfTPLz0w26npjkO/dgu/+e5Fu3fP6qJNdX1aOr6muq6uB0w7Fv2ca2/jPJ7WdtD9h/ljGP3pfFT8GfM23zuiTXJ3ndNp77liTfVlVPm557SVU9rKqu3uWagNXS4VwIWG17Pkeq6kFV9T1VdVmS/03yuSz+TPRlY4wXZXGvwLdvuSn77yS5oaqumrZzv6p6wpannc7i3j53/EU275g+f9eWt3m9LMkzq+rhtXCPqnpsVV2+jNfKxSHwrIHpRlxPTPL0JJ9M8sNJ/mgPNv2CJL86XdL37DHGmSRPSPLLWQSbM0l+Mdv4PpvefnVDkndP27tmD9YHrJhlzKNpm9cneUyS/0rykiQ/Psb4p20899NJvj+Lmyt/PIu3Trwwi/fDA/tEh3MhYLUtaY5cluTGLM5fPpHFzdd/6S72/bwsbnr8p9PflvWbSd6U5Oaq+nSSv8jiL6e5w+kkl+fOwPOuJHff8nnGGO9P8hNJfjuLmzR/cHpty5yZLFl95dvqAAAAAOjGTxMAAAAAmjsw9wJYXVV1KMk/nOPLDx5jfORirgdYX+YRMAezB9gtc4SLyVu0AAAAAJrzFi0AAACA5gQeAAAAgOYEHgAAAIDmBB4AAACA5gQeAAAAgOYEHgAAAIDmBB4AAACA5gQeAAAAgOYEHgAAAIDmBB4AAACA5gQeAAAAgOYOLGWrJ+81tvvQo0cOLWUJwMW1cXyj5tz/4ROntj13YFVsHjy2lO12/X/r3HNkp8ydOy3re5nt6XrMr4rus8fxd26ODVbVsuaOK3gAAAAAmhN4AAAAAJoTeAAAAACaE3gAAAAAmhN4AAAAAJoTeAAAAACaE3gAAAAAmhN4AAAAAJoTeAAAAACaE3gAAAAAmhN4AAAAAJoTeAAAAACaE3gAAAAAmhN4AAAAAJoTeAAAAACaE3gAAAAAmqsxxtxrAAAAAGAXXMEDAAAA0JzAAwAAANCcwAMAAADQnMADAAAA0NyBZWz08IlT7twMa2bzxsfWnPvfztzZPHjsYiwF2jl65NDcS0iSbBzfmHWO7JTznfMzc3duVY7FdWP27L0Ox7/jjTkta+64ggcAAACgOYEHAAAAoDmBBwAAAKA5gQcAAACgOYEHAAAAoDmBBwAAAKA5gQcAAACgOYEHAAAAoDmBBwAAAKA5gQcAAACgOYEHAAAAoDmBBwAAAKA5gQcAAACgOYEHAAAAoDmBBwAAAKA5gQcAAACguRpjzL0GAAAAAHbBFTwAAAAAzQk8AAAAAM0JPAAAAADNHVjGRo/edNSNfWDNbBzfqDn3f/jEKXOnmc2Dx5ay3aNHDi1luyzf3HNkp8yd5R3HXZg3+8N+nD1djk3HEOtqWXPHFTwAAAAAzQk8AAAAAM0JPAAAAADNCTwAAAAAzQk8AAAAAM0JPAAAAADNCTwAAAAAzQk8AAAAAM0JPAAAAADNCTwAAAAAzQk8AAAAAM0JPAAAAADNCTwAAAAAzQk8AAAAAM0JPAAAAADNCTwAAAAAzQk8AAAAAM3VGGPuNQAAAACwC67gAQAAAGhO4AEAAABoTuABAAAAaO7AMjZ6+MQpN/bZI5sHj829BNiek7fVnLvvOHcc38zl6JFDcy/hLm0c35h1juxUx7mzrszbnVnVGbEsZs/eWtXjbd2+r1lty5o7ruABAAAAaE7gAQAAAGhO4AEAAABoTuABAAAAaE7gAQAAAGhO4AEAAABoTuABAAAAaE7gAQAAAGhO4AEAAABoTuABAAAAaE7gAQAAAGh
O4AEAAABoTuABAAAAaE7gAQAAAGhO4AEAAABoTuABAAAAaK7GGHOvAQAAAIBdcAUPAAAAQHMCDwAAAEBzAg8AAABAcwIPAAAAQHMHlrHRwydOuXMzzGTz4LF5dnzytppnxwv7Ye7M9t+Ofe/okUNzL2FbNo5vzDpHdmo/zJ11Yb7uTpcZcqHMnp7mPK73+zHB8i1r7riCBwAAAKA5gQcAAACgOYEHAAAAoDmBBwAAAKA5gQcAAACgOYEHAAAAoDmBBwAAAKA5gQcAAACgOYEHAAAAoDmBBwAAAKA5gQcAAACgOYEHAAAAoDmBBwAAAKA5gQcAAACgOYEHAAAAoDmBBwAAAKC5GmPMvQYAAAAAdsEVPAAAAADNCTwAAAAAzQk8AAAAAM0dWMZGD5845cY+K2jz4LG5l8B+dvK2mnP3qzJ3HGfr6eiRQ3MvYV/YOL4x6xzZqVWZO+yMOb039tPcM3uWZ5WOt/30PUt/y5o7ruABAAAAaE7gAQAAAGhO4AEAAABoTuABAAAAaE7gAQAAAGhO4AEAAABoTuABAAAAaE7gAQAAAGhO4AEAAABoTuABAAAAaE7gAQAAAGhO4AEAAABoTuABAAAAaE7gAQAAAGhO4AEAAABoTuABAAAAaE7gAQAAAGiuxhhzrwEAAACAXXAFDwAAAEBzAg8AAABAcwIPAAAAQHMHlrLVk/dyYx9YNydvqzl3f/jEKXOHPbV58Ni2Hnf0yKElr2R9bBzfmHWO7JS508d2j2fObT/POrNnd7odX/v5e5k+ljV3XMEDAAAA0JzAAwAAANCcwAMAAADQnMADAAAA0JzAAwAAANCcwAMAAADQnMADAAAA0JzAAwAAANCcwAMAAADQnMADAAAA0JzAAwAAANCcwAMAAADQnMADAAAA0JzAAwAAANCcwAMAAADQnMADAAAA0FyNMeZeAwAAAAC74AoeAAAAgOYEHgAAAIDmBB4AAACA5gQeAAAAgOYOLGOjh0+ccudmWCOXX30iG8c3as41mDvnt3nw2NxLWIqjRw7NvQT20NxzZKfMndWzX2fduZiBe2OdZ89+O2YcE3SxrLnjCh4AAACA5gQeAAAAgOYEHgAAAIDmBB4AAACA5gQeAAAAgOYEHgAAAIDmBB4AAACA5gQeAAAAgOYEHgAAAIDmBB4AAACA5gQeAAAAgOYEHgAAAIDmBB4AAACA5gQeAAAAgOYEHgAAAIDmBB4AAACA5mqMMfcaAAAAANgFV/AAAAAANCfwAAAAADQn8AAAAAA0d2ApWz15Lzf2gXVz8raac/eHT5wyd/bQ5sFjcy+BbTp65NDcS9gzG8c3Zp0jO2Xu3MnMWB37aSZcLOs+exy/u+OY40Isa+64ggcAAACgOYEHAAAAoDmBBwAAAKA5gQcAAACgOYEHAAAAoDmBBwAAAKA5gQcAAACgOYEHAAAAoDmBBwAAAKA5gQcAAACgOYEHAAAAoDmBBwAAAKA5gQcAAACgOYEHAAAAoDmBBwAAAKA5gQcAAACgOYEHAAAAoLkaY8y9BgAAAAB2wRU8AAAAAM0JPAAAAADNCTwAAAAAzR1YxkYPnzjlxj6wRi6/+kQ2jm/UnGtYt7mzefDY3Eu4aI4eOTT3ErhI5p4jO7Vuc2dO6zTzdsJ83Btmz1datePN9zn70bLmjit4AAAAAJoTeAAAAACaE3gAAAAAmhN4AAAAAJoTeAAAAACaE3gAAAAAmhN4AAAAAJoTeAAAAACaE3gAAAAAmhN4AAAAAJoTeAAAAACaE3gAAAAAmhN4AAAAAJoTeAAAAACaE3gAAAAAmhN4AAAAAJqrMcbcawAAAABgF1zBAwAAANCcwAMAAADQnMADAAAA0JzAAwAAANCcwAMAAADQnMADAAAA0JzAAwAAANCcwAMAAADQnMADAAAA0JzAAwAAANCcwAMAAADQnMADAAAA0JzAAwAAANCcwAMAAADQnMADAAAA0JzAAwAAANCcwAMAAADQnMADAAAA0Jz
AAwAAANCcwAMAAADQnMADAAAA0JzAAwAAANCcwAMAAADQnMADAAAA0JzAAwAAANCcwAMAAADQnMADAAAA0JzAAwAAANCcwAMAAADQnMADAAAA0JzAAwAAANCcwAMAAADQnMADAAAA0JzAAwAAANDc/wPaTSwuHMc3sgAAAABJRU5ErkJggg==\n", 129 | "text/plain": [ 130 | "
" 131 | ] 132 | }, 133 | "metadata": {}, 134 | "output_type": "display_data" 135 | } 136 | ], 137 | "source": [ 138 | "sample_count = 10\n", 139 | "\n", 140 | "fig, axs = plt.subplots(1, 4, figsize=(16, 4))\n", 141 | "\n", 142 | "names = ['dir_tenth', 'dir_one', 'dir_ten', 'dir_skewed']\n", 143 | "\n", 144 | "locations = np.arange(sample_count)\n", 145 | "for plot_i, name in enumerate(names):\n", 146 | " examples = trace[name][:sample_count]\n", 147 | " left_offset = np.zeros(sample_count)\n", 148 | " for i in range(values):\n", 149 | " vals = examples[:, i]\n", 150 | " axs[plot_i].barh(\n", 151 | " locations,\n", 152 | " vals,\n", 153 | " left=left_offset,\n", 154 | " )\n", 155 | " left_offset += vals\n", 156 | " axs[plot_i].set_title(name)\n", 157 | " axs[plot_i].axis('off')\n", 158 | " \n", 159 | "maybe_save_plot('2018-05-26-dist') \n", 160 | "plt.show()" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": null, 166 | "metadata": {}, 167 | "outputs": [], 168 | "source": [] 169 | } 170 | ], 171 | "metadata": { 172 | "kernelspec": { 173 | "display_name": "Python 3", 174 | "language": "python", 175 | "name": "python3" 176 | }, 177 | "language_info": { 178 | "codemirror_mode": { 179 | "name": "ipython", 180 | "version": 3 181 | }, 182 | "file_extension": ".py", 183 | "mimetype": "text/x-python", 184 | "name": "python", 185 | "nbconvert_exporter": "python", 186 | "pygments_lexer": "ipython3", 187 | "version": "3.6.1" 188 | } 189 | }, 190 | "nbformat": 4, 191 | "nbformat_minor": 2 192 | } 193 | -------------------------------------------------------------------------------- /2018-06-22-spacy.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Quick post on spaCy\n", 8 | "\n", 9 | "It's been a few days since I've posted, so this is a quick post about what I've been experimenting with: `spaCy`, a natural language 
processing library.\n", 10 | "\n", 11 | "\n", 12 | "## Why use a natural language processing library like `spaCy`\n", 13 | "\n", 14 | "Natural language processing gets complicated fast.\n", 15 | "For example, it's surprisingly tricky to divide text into sentences and words. A naive approach would be to split on whitespace and periods. It's easy to find a sentence that breaks these rules, such as\n", 16 | "\n", 17 | " I'll drive to Mt. Hood on Friday!\n", 18 | "\n", 19 | "An NLP library like `spaCy` can divide `I'll` into two separate tokens `I` and `'ll`. The library can also tell that not all periods are the end of a sentence (e.g. `Mt.`), and that there is punctuation other than `.` (e.g. `!`). These rules will depend on the language; `spaCy` has an English model that works for my purposes.\n", 20 | "\n", 21 | "Aside from sentence boundary detection and tokenization, `spaCy` can tag parts-of-speech of words (`drive` is a `VERB`), recognize that `'ll` is the same as `will`, parse the sentence (the `drive` is `on Friday`), along with other [linguistic features](https://spacy.io/usage/linguistic-features). It also has a nice set-up for adding custom attributes using [pipelines](https://spacy.io/usage/processing-pipelines). \n", 22 | "\n", 23 | "\n", 24 | "An alternative natural language processing library to `spaCy` is [`nltk`](https://www.nltk.org). `nltk` also comes with a lovely free [book](https://www.nltk.org/book/) on natural language processing." 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "metadata": {}, 30 | "source": [ 31 | "### Installing spaCy\n", 32 | "\n", 33 | "The lovely [documentation](https://spacy.io/usage/) explains how to install the package and a language model. I installed the English model." 
34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 1, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "import spacy\n", 43 | "\n", 44 | "nlp = spacy.load('en')" 45 | ] 46 | }, 47 | { 48 | "cell_type": "markdown", 49 | "metadata": {}, 50 | "source": [ 51 | "### Sentence boundary detection and tokenization\n", 52 | "\n", 53 | "I can use `nlp` to parse the text and get a `Doc`. This takes a bit of time, but then further processing is fast. The length of the `Doc` (`len(doc)`) gives the number of words (`Tokens`). To get the number of sentences, I can count the sentences (`Span`) from `doc.sents`." 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 2, 59 | "metadata": { 60 | "scrolled": true 61 | }, 62 | "outputs": [ 63 | { 64 | "name": "stdout", 65 | "output_type": "stream", 66 | "text": [ 67 | "\n", 68 | "words\t\t47\n", 69 | "sentences\t3\n", 70 | "\n" 71 | ] 72 | } 73 | ], 74 | "source": [ 75 | "doc = nlp('''\\\n", 76 | "I wondered if I could write a program to automatically catch clumsy style mistakes I often make.\n", 77 | "I'll try using spaCy!\n", 78 | "It turns out style-checking is a little complicated, so this post is actually just about spaCy.\n", 79 | "''')\n", 80 | "\n", 81 | "print('''\n", 82 | "words\\t\\t{num_words}\n", 83 | "sentences\\t{num_sent}\n", 84 | "'''.format(\n", 85 | " num_words=len(doc), # grab number of tokens\n", 86 | " num_sent=len(list(doc.sents)),\n", 87 | "))" 88 | ] 89 | }, 90 | { 91 | "cell_type": "markdown", 92 | "metadata": {}, 93 | "source": [ 94 | " words\t\t47\n", 95 | " sentences\t3" 96 | ] 97 | }, 98 | { 99 | "cell_type": "markdown", 100 | "metadata": {}, 101 | "source": [ 102 | "### Tokenization\n", 103 | "\n", 104 | "I can also see how `spaCy` tokenizes my example from above:\n", 105 | "\n", 106 | " I'll drive to Mt. Hood on Friday!\n", 107 | " \n", 108 | "Indeed, `spaCy` doesn't split the sentence at `Mt.` and does split `I'll` into the tokens `I` and `'ll`." 
109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": 3, 114 | "metadata": {}, 115 | "outputs": [ 116 | { 117 | "name": "stdout", 118 | "output_type": "stream", 119 | "text": [ 120 | "I\t'll\tdrive\tto\tMt.\tHood\ton\tFriday\t!\n" 121 | ] 122 | } 123 | ], 124 | "source": [ 125 | "doc = nlp(\"I'll drive to Mt. Hood on Friday!\")\n", 126 | "\n", 127 | "for sentence in doc.sents:\n", 128 | " print('\\t'.join(str(token) for token in sentence))" 129 | ] 130 | }, 131 | { 132 | "cell_type": "markdown", 133 | "metadata": {}, 134 | "source": [ 135 | " I\t'll\tdrive\tto\tMt.\tHood\ton\tFriday\t!" 136 | ] 137 | }, 138 | { 139 | "cell_type": "markdown", 140 | "metadata": {}, 141 | "source": [ 142 | "### Lemmatization\n", 143 | "\n", 144 | "I can get root words by checking out what `token.lemma_` gives (`.lemma` without the underscore is a special ID.)\n", 145 | "It converts `'ll` into `will` and `Mt.` into `Mount`!" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": 4, 151 | "metadata": {}, 152 | "outputs": [ 153 | { 154 | "name": "stdout", 155 | "output_type": "stream", 156 | "text": [ 157 | "I\t'll\tdrive\tto\tMt.\tHood\ton\tFriday\t!\n", 158 | "-PRON-\twill\tdrive\tto\tMount\thood\ton\tfriday\t!\n" 159 | ] 160 | } 161 | ], 162 | "source": [ 163 | "for sentence in doc.sents:\n", 164 | " print('\\t'.join(str(token) for token in sentence))\n", 165 | " print('\\t'.join(token.lemma_ for token in sentence))" 166 | ] 167 | }, 168 | { 169 | "cell_type": "markdown", 170 | "metadata": {}, 171 | "source": [ 172 | " I\t'll\tdrive\tto\tMt.\tHood\ton\tFriday\t!\n", 173 | " -PRON-\twill\tdrive\tto\tMount\thood\ton\tfriday\t!" 
174 | ] 175 | }, 176 | { 177 | "cell_type": "markdown", 178 | "metadata": {}, 179 | "source": [ 180 | "### Detour: highlighting words\n", 181 | "\n", 182 | "Switching gears for a moment, I can use `IPython.display` to make funner output in Jupyter notebook [like before](https://jessicastringham.net/2018/05/07/reading-jupyter-notebooks-into-Python.html). `highlight_doc` will take a `Doc` and a function that says whether a given token should be highlighted." 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": 5, 188 | "metadata": {}, 189 | "outputs": [], 190 | "source": [ 191 | "from IPython.display import display, Markdown\n", 192 | "\n", 193 | "\n", 194 | "def _highlight_word(word):\n", 195 | " return '**{}**'.format(word)\n", 196 | "\n", 197 | "def highlight_doc(doc, should_highlight_func):\n", 198 | " '''Display a word.\n", 199 | "\n", 200 | " doc: spacy.Doc that should be highlighted\n", 201 | " should_highlight_func: a function that takes in a spacy.Token and returns True\n", 202 | " or False depending on if the token should be highlighted\n", 203 | " '''\n", 204 | " for sentence in doc.sents:\n", 205 | " markdown_sentence = []\n", 206 | " for token in sentence:\n", 207 | " markdown_word = token.text\n", 208 | "\n", 209 | " if should_highlight_func(token):\n", 210 | " markdown_word = _highlight_word(markdown_word)\n", 211 | "\n", 212 | " markdown_sentence.append(markdown_word)\n", 213 | " display(Markdown(' '.join(markdown_sentence)))" 214 | ] 215 | }, 216 | { 217 | "cell_type": "markdown", 218 | "metadata": {}, 219 | "source": [ 220 | "### Highlighting verbs\n", 221 | "\n", 222 | "To test the UI, I can highlight verbs by checking the [`token`'s `pos` attribute](https://spacy.io/api/annotation#pos-tagging). 
(In this case, I can use `.pos` instead of `.pos_` so I can compare with `spacy.symbols.VERB`.)" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": 6, 228 | "metadata": {}, 229 | "outputs": [ 230 | { 231 | "data": { 232 | "text/markdown": [ 233 | "I **wondered** if I **could** **write** a program to automatically **catch** clumsy style mistakes I often **make** . \n" 234 | ], 235 | "text/plain": [ 236 | "" 237 | ] 238 | }, 239 | "metadata": {}, 240 | "output_type": "display_data" 241 | }, 242 | { 243 | "data": { 244 | "text/markdown": [ 245 | "I **'ll** **try** **using** spaCy ! \n" 246 | ], 247 | "text/plain": [ 248 | "" 249 | ] 250 | }, 251 | "metadata": {}, 252 | "output_type": "display_data" 253 | }, 254 | { 255 | "data": { 256 | "text/markdown": [ 257 | "It **turns** out style - checking **is** a little complicated , so this post **is** actually just about spaCy . \n" 258 | ], 259 | "text/plain": [ 260 | "" 261 | ] 262 | }, 263 | "metadata": {}, 264 | "output_type": "display_data" 265 | } 266 | ], 267 | "source": [ 268 | "from spacy.symbols import VERB\n", 269 | "\n", 270 | "doc = nlp('''\\\n", 271 | "I wondered if I could write a program to automatically catch clumsy style mistakes I often make.\n", 272 | "I'll try using spaCy!\n", 273 | "It turns out style-checking is a little complicated, so this post is actually just about spaCy.\n", 274 | "''')\n", 275 | "\n", 276 | "highlight_doc(doc, lambda token: token.pos == VERB)" 277 | ] 278 | }, 279 | { 280 | "cell_type": "markdown", 281 | "metadata": {}, 282 | "source": [ 283 | "I **wondered** if I **could** **write** a program to automatically **catch** clumsy style mistakes I often **make** . \n", 284 | "\n", 285 | "I **'ll** **try** **using** spaCy ! \n", 286 | "\n", 287 | "It **turns** out style - checking **is** a little complicated , so this post **is** actually just about spaCy . 
" 288 | ] 289 | }, 290 | { 291 | "cell_type": "markdown", 292 | "metadata": {}, 293 | "source": [ 294 | "### Named entities\n", 295 | "\n", 296 | "`spaCy` also extracts a few neat natural language processing. For example, it can highlight [named entities](https://spacy.io/usage/linguistic-features#section-named-entities), which is often hard to do!\n", 297 | "It says Mt. Hood is a \"Buildings, airports, highways, bridges, etc.\" Neat!" 298 | ] 299 | }, 300 | { 301 | "cell_type": "code", 302 | "execution_count": 7, 303 | "metadata": {}, 304 | "outputs": [ 305 | { 306 | "data": { 307 | "text/markdown": [ 308 | "I 'll drive to **Mt.** **Hood** on **Friday** !" 309 | ], 310 | "text/plain": [ 311 | "" 312 | ] 313 | }, 314 | "metadata": {}, 315 | "output_type": "display_data" 316 | }, 317 | { 318 | "name": "stdout", 319 | "output_type": "stream", 320 | "text": [ 321 | "Mt. Hood FAC\n", 322 | "Friday DATE\n" 323 | ] 324 | } 325 | ], 326 | "source": [ 327 | "# this will be a little hard to read if noun chunks are near each other\n", 328 | "doc = nlp(\"I'll drive to Mt. Hood on Friday!\")\n", 329 | "\n", 330 | "# get a list of token indexes that are in a noun_chunk\n", 331 | "is_in_named_entity = set(sum((list(range(entity.start, entity.end)) for entity in doc.ents), []))\n", 332 | "\n", 333 | "highlight_doc(doc, lambda token: token.i in is_in_named_entity)\n", 334 | "\n", 335 | "for entity in doc.ents:\n", 336 | " print(entity, entity.label_)" 337 | ] 338 | }, 339 | { 340 | "cell_type": "markdown", 341 | "metadata": {}, 342 | "source": [ 343 | "I 'll drive to **Mt.** **Hood** on **Friday** !\n", 344 | "\n", 345 | " Mt. Hood FAC\n", 346 | " Friday DATE" 347 | ] 348 | }, 349 | { 350 | "cell_type": "markdown", 351 | "metadata": {}, 352 | "source": [ 353 | "## Etc\n", 354 | "\n", 355 | "This was a quick post introducing a few features of `spaCy`. Assembling them into a real project is another challenge!\n", 356 | "\n", 357 | "`spaCy` is an interesting project. 
It's neat to see how NLP and AI can be used in a usable package. \n", 358 | "The [`spaCy`](https://spacy.io) documentation is lots of fun. One tip is to jump between similarly-named sections, like POS tagging, in \"Usage\", \"Models\", and \"API\"." 359 | ] 360 | }, 361 | { 362 | "cell_type": "code", 363 | "execution_count": null, 364 | "metadata": {}, 365 | "outputs": [], 366 | "source": [] 367 | } 368 | ], 369 | "metadata": { 370 | "kernelspec": { 371 | "display_name": "Python 3", 372 | "language": "python", 373 | "name": "python3" 374 | }, 375 | "language_info": { 376 | "codemirror_mode": { 377 | "name": "ipython", 378 | "version": 3 379 | }, 380 | "file_extension": ".py", 381 | "mimetype": "text/x-python", 382 | "name": "python", 383 | "nbconvert_exporter": "python", 384 | "pygments_lexer": "ipython3", 385 | "version": "3.6.1" 386 | } 387 | }, 388 | "nbformat": 4, 389 | "nbformat_minor": 2 390 | } 391 | -------------------------------------------------------------------------------- /2019-01-09-sum-product-message-passing.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Implementing Belief Propagation in Python\n", 8 | "\n", 9 | "As part of reviewing the ML concepts I learned last year, I implemented the _sum-product message passing_, or [belief propagation](https://en.wikipedia.org/wiki/Belief_propagation), that we learned in our probabilistic modeling course.\n", 10 | "\n", 11 | "Belief propagation (or sum-product message passing) is a method that can do inference on [probabilistic graphical models](https://en.wikipedia.org/wiki/Graphical_model). 
I'll focus on the algorithm that can perform exact inference on tree-like factor graphs.\n", 12 | "\n", 13 | "This post assumes you know about probabilistic graphical models (perhaps through [the Coursera course](https://www.coursera.org/learn/probabilistic-graphical-models)) and have maybe heard of belief propagation. I'll freely use terms such as \"factor graph\" and \"exact inference.\"\n", 14 | "\n", 15 | "\n", 16 | "## Belief Propagation\n", 17 | "\n", 18 | "Belief propagation, or sum-product message passing, is an algorithm for efficiently applying the sum rules and product rules of probability to compute different distributions. For example, if a discrete probability distribution $p(h_1, h_2, v_1, v_2)$ can be factorized as\n", 19 | "\n", 20 | "$$p(h_1, h_2, v_1, v_2) = p(h_1)p(h_2 \\mid h_1)p(v_1 \\mid h_1)p(v_2 \\mid h_2),$$\n", 21 | "\n", 22 | "I could compute marginals, for example, $p(v_1)$, by multiplying the terms and summing over the other variables.\n", 23 | "\n", 24 | "$$p(v_1) = \\sum_{h_1, h_2, v_2} p(h_1)p(h_2 \\mid h_1)p(v_1 \\mid h_1)p(v_2 \\mid h_2).$$\n", 25 | "\n", 26 | "With marginals, one can compute distributions such as $p(v_1)$ and $p(v_1, v_2)$, which means that one can also compute terms like $p(v_2 \\mid v_1)$. Belief propagation provides an efficient method for computing these marginals. \n", 27 | "\n", 28 | "This version will only work on discrete distributions. I'll code it with directed graphical models in mind, though it should also work with undirected models with few changes." 
29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 1, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "import numpy as np\n", 38 | "from collections import namedtuple" 39 | ] 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "metadata": {}, 44 | "source": [ 45 | "## Part 1: (Digression) Representing probability distributions as numpy arrays\n", 46 | "\n", 47 | "The sum-product message passing involves representing, summing, and multiplying discrete distributions. I think it's pretty fun to try to implement this with numpy arrays; I gained more intuition about probability distributions and numpy.\n", 48 | "\n", 49 | "A discrete conditional distribution $p(v_1 \\mid h_1)$ can be represented as an array with two axes, such as\n", 50 | "\n", 51 | "| | $h_1$ = a | $h_1$ = b | $h_1$ = c |\n", 52 | "|-|-|-|-|\n", 53 | "| $v_1$ = 0 | 0.4 | 0.8 | 0.9 |\n", 54 | "| $v_1$ = 1 | 0.6 | 0.2 | 0.1 |\n", 55 | "\n", 56 | "\n", 57 | "Using an axis for each variable can generalize to more variables. For example, the 5-variable $p(h_5 \\mid h_4, h_3, h_2, h_1)$ could be represented by an array with five axes. \n", 58 | "\n", 59 | "It's useful to label axes with variable names. I'll do this in my favorite way, a little `namedtuple`! 
(It's kind of like a janky version of the [NamedTensor](http://nlp.seas.harvard.edu/NamedTensor).)" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 2, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "LabeledArray = namedtuple('LabeledArray', [\n", 69 | " 'array',\n", 70 | " 'axes_labels',\n", 71 | "])\n", 72 | "\n", 73 | "def name_to_axis_mapping(labeled_array):\n", 74 | " return {\n", 75 | " name: axis\n", 76 | " for axis, name in enumerate(labeled_array.axes_labels)\n", 77 | " }\n", 78 | "\n", 79 | "def other_axes_from_labeled_axes(labeled_array, axis_label):\n", 80 | " # returns the indexes of the axes that are not axis label\n", 81 | " return tuple(\n", 82 | " axis\n", 83 | " for axis, name in enumerate(labeled_array.axes_labels)\n", 84 | " if name != axis_label\n", 85 | " )" 86 | ] 87 | }, 88 | { 89 | "cell_type": "markdown", 90 | "metadata": {}, 91 | "source": [ 92 | "### Checking that a numpy array is a valid discrete distribution\n", 93 | "\n", 94 | "It's easy to accidentally swap axes when creating numpy arrays representing distributions. I'll also write code to verify they are valid distributions.\n", 95 | "\n", 96 | "To check that a multidimensional array is a _joint_ distribution, the entire array should sum to one.\n", 97 | "\n", 98 | "To check that a 2D array is a _conditional_ distribution, when all of the right-hand-side variables have been assigned, such as $p(v_1 \\mid h_1 = a)$, the resulting vector represents a distribution. The vector should have the length of the number of states of $v_1$ and should sum to one. Computing this in numpy involves summing along the axis corresponding to the $v_1$ variable.\n", 99 | "\n", 100 | "To generalize conditional distribution arrays to the multi-dimensional example, again, when all of the right-hand-side variables have been assigned, such as $p(h_5 \\mid h_4=a, h_3=b, h_2=a, h_1=a)$, the resulting vector represents a distribution. 
The vector should have a length equal to the number of states of $h_5$ and should sum to one." 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": 3, 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 | "def is_conditional_prob(labeled_array, var_name):\n", 110 | " '''\n", 111 | " labeled_array (LabeledArray)\n", 112 | " variable (str): name of variable, i.e. 'a' in p(a|b)\n", 113 | " '''\n", 114 | " return np.all(np.isclose(np.sum(\n", 115 | " labeled_array.array,\n", 116 | " axis=name_to_axis_mapping(labeled_array)[var_name]\n", 117 | " ), 1.0))\n", 118 | " \n", 119 | "def is_joint_prob(labeled_array):\n", 120 | " return np.all(np.isclose(np.sum(labeled_array.array), 1.0))" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": 4, 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [ 129 | "p_v1_given_h1 = LabeledArray(np.array([[0.4, 0.8, 0.9], [0.6, 0.2, 0.1]]), ['v1', 'h1'])\n", 130 | "\n", 131 | "p_h1 = LabeledArray(np.array([0.6, 0.3, 0.1]), ['h1'])\n", 132 | "\n", 133 | "p_v1_given_many = LabeledArray(np.array(\n", 134 | " [[[0.9, 0.2], [0.3, 0.2]],\n", 135 | " [[0.1, 0.8], [0.7, 0.8]]]\n", 136 | "), ['v1', 'h1', 'h2'])\n", 137 | "\n", 138 | "assert is_conditional_prob(p_v1_given_h1, 'v1')\n", 139 | "assert not is_joint_prob(p_v1_given_h1)\n", 140 | "\n", 141 | "assert is_conditional_prob(p_h1, 'h1')\n", 142 | "assert is_joint_prob(p_h1)\n", 143 | "\n", 144 | "assert is_conditional_prob(p_v1_given_many, 'v1')\n", 145 | "assert not is_joint_prob(p_v1_given_many)" 146 | ] 147 | }, 148 | { 149 | "cell_type": "markdown", 150 | "metadata": {}, 151 | "source": [ 152 | "### Multiplying distributions\n", 153 | "\n", 154 | "In belief propagation, I also need to compute the product of distributions, such as $p(h_2 \\mid h_1)p(h_1)$.\n", 155 | "\n", 156 | "In this case, I'll only need to multiply a multidimensional array by a 1D array and occasionally a scalar. 
The way I ended up implementing this was to align the axis of the 1D array with its corresponding axis from the other distribution. Then I tile the 1D array to be the size of $p(h_2 \\mid h_1)$. This gives me the joint distribution $p(h_1, h_2)$." 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": 5, 162 | "metadata": {}, 163 | "outputs": [], 164 | "source": [ 165 | "def tile_to_shape_along_axis(arr, target_shape, target_axis):\n", 166 | " # get a list of all axes\n", 167 | " raw_axes = list(range(len(target_shape)))\n", 168 | " tile_dimensions = [target_shape[a] for a in raw_axes if a != target_axis]\n", 169 | " if len(arr.shape) == 0:\n", 170 | " # If given a scalar, also tile it in the target dimension (so it's a bunch of 1s)\n", 171 | " tile_dimensions += [target_shape[target_axis]]\n", 172 | " elif len(arr.shape) == 1:\n", 173 | " # If given an array, it should be the same shape as the target axis\n", 174 | " assert arr.shape[0] == target_shape[target_axis]\n", 175 | " tile_dimensions += [1]\n", 176 | " else:\n", 177 | " raise NotImplementedError()\n", 178 | " tiled = np.tile(arr, tile_dimensions)\n", 179 | "\n", 180 | " # Tiling only adds prefix axes, so rotate this one back into place\n", 181 | " shifted_axes = raw_axes[:target_axis] + [raw_axes[-1]] + raw_axes[target_axis:-1]\n", 182 | " transposed = np.transpose(tiled, shifted_axes)\n", 183 | "\n", 184 | " # Double-check this code tiled it to the correct shape\n", 185 | " assert transposed.shape == target_shape\n", 186 | " return transposed\n", 187 | "\n", 188 | "def tile_to_other_dist_along_axis_name(tiling_labeled_array, target_array):\n", 189 | " assert len(tiling_labeled_array.axes_labels) == 1\n", 190 | " target_axis_label = tiling_labeled_array.axes_labels[0]\n", 191 | " \n", 192 | " return LabeledArray(\n", 193 | " tile_to_shape_along_axis(\n", 194 | " tiling_labeled_array.array,\n", 195 | " target_array.array.shape,\n", 196 | " 
name_to_axis_mapping(target_array)[target_axis_label]\n", 197 | " ),\n", 198 | " axes_labels=target_array.axes_labels\n", 199 | " )" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": 6, 205 | "metadata": {}, 206 | "outputs": [], 207 | "source": [ 208 | "tiled_p_h1 = tile_to_other_dist_along_axis_name(p_h1, p_v1_given_h1)\n", 209 | "\n", 210 | "# Check that the product is a joint distribution (p(v1, h1))\n", 211 | "assert np.isclose(np.sum(p_v1_given_h1.array * tiled_p_h1.array), 1.0)" 212 | ] 213 | }, 214 | { 215 | "cell_type": "markdown", 216 | "metadata": {}, 217 | "source": [ 218 | "## Part 2: Factor Graphs\n", 219 | "\n", 220 | "Factor graphs are used to represent a distribution for sum-product message passing.\n", 221 | "One factor graph that represents $p(h_1, h_2, v_1, v_2)$ is\n", 222 | "\n", 223 | "![](images/2019-01-09-factor-graph.png)\n", 224 | "\n", 225 | "Factors, such as $p(h_1)$, are represented by black squares and represent a factor (or function, such as a probability distribution.) Variables, such as $h_1$, are represented by white circles. Variables only neighbor factors, and factors only neighbor variables.\n", 226 | "\n", 227 | "In code, \n", 228 | " - There are two classes in the graph: Variable and Factor. Both classes have a string representing the name and a list of neighbors.\n", 229 | " - A Variable can only have Factors in its list of neighbors. A Factor can only have Variables.\n", 230 | " - To represent the probability distribution, Factors also have a field for data." 
231 | ] 232 | }, 233 | { 234 | "cell_type": "code", 235 | "execution_count": 7, 236 | "metadata": {}, 237 | "outputs": [], 238 | "source": [ 239 | "class Node(object):\n", 240 | " def __init__(self, name):\n", 241 | " self.name = name\n", 242 | " self.neighbors = []\n", 243 | "\n", 244 | " def __repr__(self):\n", 245 | " return \"{classname}({name}, [{neighbors}])\".format(\n", 246 | " classname=type(self).__name__,\n", 247 | " name=self.name,\n", 248 | " neighbors=', '.join([n.name for n in self.neighbors])\n", 249 | " )\n", 250 | "\n", 251 | " def is_valid_neighbor(self, neighbor):\n", 252 | " raise NotImplemented()\n", 253 | "\n", 254 | " def add_neighbor(self, neighbor):\n", 255 | " assert self.is_valid_neighbor(neighbor)\n", 256 | " self.neighbors.append(neighbor)\n", 257 | "\n", 258 | "\n", 259 | "class Variable(Node):\n", 260 | " def is_valid_neighbor(self, factor):\n", 261 | " return isinstance(factor, Factor) # Variables can only neighbor Factors\n", 262 | "\n", 263 | "\n", 264 | "class Factor(Node):\n", 265 | " def is_valid_neighbor(self, variable):\n", 266 | " return isinstance(variable, Variable) # Factors can only neighbor Variables\n", 267 | "\n", 268 | " def __init__(self, name):\n", 269 | " super(Factor, self).__init__(name)\n", 270 | " self.data = None" 271 | ] 272 | }, 273 | { 274 | "cell_type": "markdown", 275 | "metadata": {}, 276 | "source": [ 277 | "## Part 3: Parsing distributions into graphs\n", 278 | "\n", 279 | "Defining a graph can be a little verbose. I can hack together a parser for probability distributions that can interpret a string like `p(h1)p(h2|h1)p(v1|h1)p(v2|h2)` as a factor graph for me.\n", 280 | "\n", 281 | "(This is pretty fragile and not user-friendly. 
For example, be sure to use `|` character rather than the indistinguishable `∣` character!)" 282 | ] 283 | }, 284 | { 285 | "cell_type": "code", 286 | "execution_count": 8, 287 | "metadata": {}, 288 | "outputs": [], 289 | "source": [ 290 | "ParsedTerm = namedtuple('ParsedTerm', [\n", 291 | " 'term',\n", 292 | " 'var_name',\n", 293 | " 'given',\n", 294 | "])\n", 295 | "\n", 296 | "\n", 297 | "def _parse_term(term):\n", 298 | " # Given a term like (a|b,c), returns a list of variables\n", 299 | " # and conditioned-on variables\n", 300 | " assert term[0] == '(' and term[-1] == ')'\n", 301 | " term_variables = term[1:-1]\n", 302 | "\n", 303 | " # Handle conditionals\n", 304 | " if '|' in term_variables:\n", 305 | " var, given = term_variables.split('|')\n", 306 | " given = given.split(',')\n", 307 | " else:\n", 308 | " var = term_variables\n", 309 | " given = []\n", 310 | "\n", 311 | " return var, given\n", 312 | "\n", 313 | "\n", 314 | "def _parse_model_string_into_terms(model_string):\n", 315 | " return [\n", 316 | " ParsedTerm('p' + term, *_parse_term(term))\n", 317 | " for term in model_string.split('p')\n", 318 | " if term\n", 319 | " ]\n", 320 | "\n", 321 | "def parse_model_into_variables_and_factors(model_string):\n", 322 | " # Takes in a model_string such as p(h1)p(h2∣h1)p(v1∣h1)p(v2∣h2) and returns a\n", 323 | " # dictionary of variable names to variables and a list of factors.\n", 324 | " \n", 325 | " # Split model_string into ParsedTerms\n", 326 | " parsed_terms = _parse_model_string_into_terms(model_string)\n", 327 | " \n", 328 | " # First, extract all of the variables from the model_string (h1, h2, v1, v2). 
\n", 329 | " # These each will be a new Variable that are referenced from Factors below.\n", 330 | " variables = {}\n", 331 | " for parsed_term in parsed_terms:\n", 332 | " # if the variable name wasn't seen yet, add it to the variables dict\n", 333 | " if parsed_term.var_name not in variables:\n", 334 | " variables[parsed_term.var_name] = Variable(parsed_term.var_name)\n", 335 | "\n", 336 | " # Now extract factors from the model. Each term (e.g. \"p(v1|h1)\") corresponds to \n", 337 | " # a factor. \n", 338 | " # Then find all variables in this term (\"v1\", \"h1\") and add the corresponding Variables\n", 339 | " # as neighbors to the new Factor, and this Factor to the Variables' neighbors.\n", 340 | " factors = []\n", 341 | " for parsed_term in parsed_terms:\n", 342 | " # This factor will be neighbors with all \"variables\" (left-hand side variables) and given variables\n", 343 | " new_factor = Factor(parsed_term.term)\n", 344 | " all_var_names = [parsed_term.var_name] + parsed_term.given\n", 345 | " for var_name in all_var_names:\n", 346 | " new_factor.add_neighbor(variables[var_name])\n", 347 | " variables[var_name].add_neighbor(new_factor)\n", 348 | " factors.append(new_factor)\n", 349 | "\n", 350 | " return factors, variables" 351 | ] 352 | }, 353 | { 354 | "cell_type": "code", 355 | "execution_count": 9, 356 | "metadata": {}, 357 | "outputs": [ 358 | { 359 | "data": { 360 | "text/plain": [ 361 | "([Factor(p(h1), [h1]),\n", 362 | " Factor(p(h2|h1), [h2, h1]),\n", 363 | " Factor(p(v1|h1), [v1, h1]),\n", 364 | " Factor(p(v2|h2), [v2, h2])],\n", 365 | " {'h1': Variable(h1, [p(h1), p(h2|h1), p(v1|h1)]),\n", 366 | " 'h2': Variable(h2, [p(h2|h1), p(v2|h2)]),\n", 367 | " 'v1': Variable(v1, [p(v1|h1)]),\n", 368 | " 'v2': Variable(v2, [p(v2|h2)])})" 369 | ] 370 | }, 371 | "execution_count": 9, 372 | "metadata": {}, 373 | "output_type": "execute_result" 374 | } 375 | ], 376 | "source": [ 377 | "parse_model_into_variables_and_factors(\"p(h1)p(h2|h1)p(v1|h1)p(v2|h2)\")" 
378 | ] 379 | }, 380 | { 381 | "cell_type": "markdown", 382 | "metadata": {}, 383 | "source": [ 384 | " ([Factor(p(h1), [h1]),\n", 385 | " Factor(p(h2|h1), [h2, h1]),\n", 386 | " Factor(p(v1|h1), [v1, h1]),\n", 387 | " Factor(p(v2|h2), [v2, h2])],\n", 388 | " {'h1': Variable(h1, [p(h1), p(h2|h1), p(v1|h1)]),\n", 389 | " 'h2': Variable(h2, [p(h2|h1), p(v2|h2)]),\n", 390 | " 'v1': Variable(v1, [p(v1|h1)]),\n", 391 | " 'v2': Variable(v2, [p(v2|h2)])})" 392 | ] 393 | }, 394 | { 395 | "cell_type": "markdown", 396 | "metadata": {}, 397 | "source": [ 398 | "## Part 4: Adding distributions to the graph\n", 399 | "\n", 400 | "Before I can run the algorithm, I need to associate LabeledArrays with each Factor. At this point, I'll create a class to hold onto the Variables and Factors.\n", 401 | "\n", 402 | "While I'm here, I can do a few checks to make sure the provided data matches the graph." 403 | ] 404 | }, 405 | { 406 | "cell_type": "code", 407 | "execution_count": 10, 408 | "metadata": {}, 409 | "outputs": [], 410 | "source": [ 411 | "class PGM(object):\n", 412 | " def __init__(self, factors, variables):\n", 413 | " self._factors = factors\n", 414 | " self._variables = variables\n", 415 | "\n", 416 | " @classmethod\n", 417 | " def from_string(cls, model_string):\n", 418 | " factors, variables = parse_model_into_variables_and_factors(model_string)\n", 419 | " return PGM(factors, variables)\n", 420 | "\n", 421 | " def set_data(self, data):\n", 422 | " # Keep track of variable dimensions to check for shape mistakes\n", 423 | " var_dims = {}\n", 424 | " for factor in self._factors:\n", 425 | " factor_data = data[factor.name]\n", 426 | "\n", 427 | " if set(factor_data.axes_labels) != set(v.name for v in factor.neighbors):\n", 428 | " missing_axes = set(v.name for v in factor.neighbors) - set(data[factor.name].axes_labels)\n", 429 | " raise ValueError(\"data[{}] is missing axes: {}\".format(factor.name, missing_axes))\n", 430 | " \n", 431 | " for var_name, dim in 
zip(factor_data.axes_labels, factor_data.array.shape):\n", 432 | " if var_name not in var_dims:\n", 433 | " var_dims[var_name] = dim\n", 434 | " \n", 435 | " if var_dims[var_name] != dim:\n", 436 | " raise ValueError(\"data[{}] axes is wrong size, {}. Expected {}\".format(factor.name, dim, var_dims[var_name])) \n", 437 | " \n", 438 | " factor.data = data[factor.name]\n", 439 | " \n", 440 | " def variable_from_name(self, var_name):\n", 441 | " return self._variables[var_name]" 442 | ] 443 | }, 444 | { 445 | "cell_type": "markdown", 446 | "metadata": {}, 447 | "source": [ 448 | "I can now try to add distributions to a graph." 449 | ] 450 | }, 451 | { 452 | "cell_type": "code", 453 | "execution_count": 11, 454 | "metadata": {}, 455 | "outputs": [], 456 | "source": [ 457 | "p_h1 = LabeledArray(np.array([[0.2], [0.8]]), ['h1'])\n", 458 | "p_h2_given_h1 = LabeledArray(np.array([[0.5, 0.2], [0.5, 0.8]]), ['h2', 'h1'])\n", 459 | "p_v1_given_h1 = LabeledArray(np.array([[0.6, 0.1], [0.4, 0.9]]), ['v1', 'h1'])\n", 460 | "p_v2_given_h2 = LabeledArray(p_v1_given_h1.array, ['v2', 'h2'])\n", 461 | "\n", 462 | "assert is_joint_prob(p_h1)\n", 463 | "assert is_conditional_prob(p_h2_given_h1, 'h2')\n", 464 | "assert is_conditional_prob(p_v1_given_h1, 'v1')\n", 465 | "assert is_conditional_prob(p_v2_given_h2, 'v2')" 466 | ] 467 | }, 468 | { 469 | "cell_type": "code", 470 | "execution_count": 12, 471 | "metadata": {}, 472 | "outputs": [], 473 | "source": [ 474 | "pgm = PGM.from_string(\"p(h1)p(h2|h1)p(v1|h1)p(v2|h2)\")\n", 475 | "\n", 476 | "pgm.set_data({\n", 477 | " \"p(h1)\": p_h1,\n", 478 | " \"p(h2|h1)\": p_h2_given_h1,\n", 479 | " \"p(v1|h1)\": p_v1_given_h1,\n", 480 | " \"p(v2|h2)\": p_v2_given_h2,\n", 481 | "})" 482 | ] 483 | }, 484 | { 485 | "cell_type": "markdown", 486 | "metadata": {}, 487 | "source": [ 488 | "## Part 5: Belief Propagation\n", 489 | "\n", 490 | "We made it! Now we can implement sum-product message passing. 
\n", 491 | "\n", 492 | "Sum-product message passing will compute values (\"messages\") for every edge in the factor graph.\n", 493 | "\n", 494 | "![](images/2019-01-09-factor-graph.png)\n", 495 | "\n", 496 | "The algorithm will compute a message from the Factor $f$ to the Variable $x$, notated as $\\mu_{f \\to x}(x)$. It will also compute the value from Variable $x$ to the Factor $f$, $\\mu_{x \\to f}(x)$. As is common in graph algorithms, these are defined recursively.\n", 497 | "\n", 498 | "(I'm using the equations as given in Barber p84.)\n", 499 | "\n", 500 | "### Variable-to-Factor Message\n", 501 | "\n", 502 | "The variable-to-factor message is given by:\n", 503 | "\n", 504 | "$$\\mu_{x \\to f}(x) = \\prod_{g \\in \\{ne(x) \\setminus f\\}} \\mu_{g \\to x}(x)$$\n", 505 | "\n", 506 | "where $ne(x)$ are the neighbors of $x$." 507 | ] 508 | }, 509 | { 510 | "cell_type": "code", 511 | "execution_count": 13, 512 | "metadata": {}, 513 | "outputs": [], 514 | "source": [ 515 | "def _variable_to_factor_messages(variable, factor):\n", 516 | " # Take the product over all incoming factors into this variable except the variable\n", 517 | " incoming_messages = [\n", 518 | " _factor_to_variable_message(neighbor_factor, variable)\n", 519 | " for neighbor_factor in variable.neighbors\n", 520 | " if neighbor_factor.name != factor.name\n", 521 | " ]\n", 522 | "\n", 523 | " # If there are no incoming messages, this is 1\n", 524 | " return np.prod(incoming_messages, axis=0)" 525 | ] 526 | }, 527 | { 528 | "cell_type": "markdown", 529 | "metadata": {}, 530 | "source": [ 531 | "### Factor-to-Variable Message\n", 532 | "\n", 533 | "The factor-to-variable message is given by:\n", 534 | "\n", 535 | "$$\\mu_{f \\to x}(x) = \\sum_{\\chi_f \\setminus x}\\phi_f(\\chi_f) \\prod_{y \\in \\{ne(f) \\setminus x \\}} \\mu_{y \\to f}(y)$$\n", 536 | "\n", 537 | "In the case of probabilities, $\\phi_f(\\chi_f)$ is the probability distribution associated with the factor, and $\\sum_{\\chi_f 
\\setminus x}$ sums over all variables except $x$.\n" 538 | ] 539 | }, 540 | { 541 | "cell_type": "code", 542 | "execution_count": 14, 543 | "metadata": {}, 544 | "outputs": [], 545 | "source": [ 546 | "def _factor_to_variable_messages(factor, variable):\n", 547 | " # Compute the product\n", 548 | " factor_dist = np.copy(factor.data.array)\n", 549 | " for neighbor_variable in factor.neighbors:\n", 550 | " if neighbor_variable.name == variable.name:\n", 551 | " continue\n", 552 | " incoming_message = variable_to_factor_messages(neighbor_variable, factor)\n", 553 | " factor_dist *= tile_to_other_dist_along_axis_name(\n", 554 | " LabeledArray(incoming_message, [neighbor_variable.name]),\n", 555 | " factor.data\n", 556 | " ).array\n", 557 | " # Sum over the axes that aren't `variable`\n", 558 | " other_axes = other_axes_from_labeled_axes(factor.data, variable.name)\n", 559 | " return np.squeeze(np.sum(factor_dist, axis=other_axes))" 560 | ] 561 | }, 562 | { 563 | "cell_type": "markdown", 564 | "metadata": {}, 565 | "source": [ 566 | "### Marginal\n", 567 | "\n", 568 | "The marginal of a variable $x$ is given by\n", 569 | "\n", 570 | "$$p(x) \\propto \\prod_{f \\in ne(x)}\\mu_{f \\to x}(x)$$" 571 | ] 572 | }, 573 | { 574 | "cell_type": "code", 575 | "execution_count": 15, 576 | "metadata": {}, 577 | "outputs": [], 578 | "source": [ 579 | "def marginal(variable):\n", 580 | " # p(variable) is proportional to the product of incoming messages to variable.\n", 581 | " unnorm_p = np.prod([\n", 582 | " self.factor_to_variable_message(neighbor_factor, variable)\n", 583 | " for neighbor_factor in variable.neighbors\n", 584 | " ], axis=0)\n", 585 | "\n", 586 | " # At this point, we can normalize this distribution\n", 587 | " return unnorm_p/np.sum(unnorm_p)" 588 | ] 589 | }, 590 | { 591 | "cell_type": "markdown", 592 | "metadata": {}, 593 | "source": [ 594 | "## Adding to PGM\n", 595 | "\n", 596 | "A source of message passing's efficiency is that messages from one computation can 
be reused by other computations. I'll create an object to store `Messages`." 597 | ] 598 | }, 599 | { 600 | "cell_type": "code", 601 | "execution_count": 16, 602 | "metadata": {}, 603 | "outputs": [], 604 | "source": [ 605 | "class Messages(object):\n", 606 | " def __init__(self):\n", 607 | " self.messages = {}\n", 608 | " \n", 609 | " def _variable_to_factor_messages(self, variable, factor):\n", 610 | " # Take the product over all incoming factors into this variable except the variable\n", 611 | " incoming_messages = [\n", 612 | " self.factor_to_variable_message(neighbor_factor, variable)\n", 613 | " for neighbor_factor in variable.neighbors\n", 614 | " if neighbor_factor.name != factor.name\n", 615 | " ]\n", 616 | "\n", 617 | " # If there are no incoming messages, this is 1\n", 618 | " return np.prod(incoming_messages, axis=0)\n", 619 | " \n", 620 | " def _factor_to_variable_messages(self, factor, variable):\n", 621 | " # Compute the product\n", 622 | " factor_dist = np.copy(factor.data.array)\n", 623 | " for neighbor_variable in factor.neighbors:\n", 624 | " if neighbor_variable.name == variable.name:\n", 625 | " continue\n", 626 | " incoming_message = self.variable_to_factor_messages(neighbor_variable, factor)\n", 627 | " factor_dist *= tile_to_other_dist_along_axis_name(\n", 628 | " LabeledArray(incoming_message, [neighbor_variable.name]),\n", 629 | " factor.data\n", 630 | " ).array\n", 631 | " # Sum over the axes that aren't `variable`\n", 632 | " other_axes = other_axes_from_labeled_axes(factor.data, variable.name)\n", 633 | " return np.squeeze(np.sum(factor_dist, axis=other_axes))\n", 634 | " \n", 635 | " def marginal(self, variable):\n", 636 | " # p(variable) is proportional to the product of incoming messages to variable.\n", 637 | " unnorm_p = np.prod([\n", 638 | " self.factor_to_variable_message(neighbor_factor, variable)\n", 639 | " for neighbor_factor in variable.neighbors\n", 640 | " ], axis=0)\n", 641 | "\n", 642 | " # At this point, we can 
normalize this distribution\n", 643 | " return unnorm_p/np.sum(unnorm_p)\n", 644 | " \n", 645 | " def variable_to_factor_messages(self, variable, factor):\n", 646 | " message_name = (variable.name, factor.name)\n", 647 | " if message_name not in self.messages:\n", 648 | " self.messages[message_name] = self._variable_to_factor_messages(variable, factor)\n", 649 | " return self.messages[message_name]\n", 650 | " \n", 651 | " def factor_to_variable_message(self, factor, variable):\n", 652 | " message_name = (factor.name, variable.name)\n", 653 | " if message_name not in self.messages:\n", 654 | " self.messages[message_name] = self._factor_to_variable_messages(factor, variable)\n", 655 | " return self.messages[message_name] " 656 | ] 657 | }, 658 | { 659 | "cell_type": "code", 660 | "execution_count": 17, 661 | "metadata": { 662 | "scrolled": false 663 | }, 664 | "outputs": [ 665 | { 666 | "data": { 667 | "text/plain": [ 668 | "array([0.23, 0.77])" 669 | ] 670 | }, 671 | "execution_count": 17, 672 | "metadata": {}, 673 | "output_type": "execute_result" 674 | } 675 | ], 676 | "source": [ 677 | "pgm = PGM.from_string(\"p(h1)p(h2|h1)p(v1|h1)p(v2|h2)\")\n", 678 | "\n", 679 | "pgm.set_data({\n", 680 | " \"p(h1)\": p_h1,\n", 681 | " \"p(h2|h1)\": p_h2_given_h1,\n", 682 | " \"p(v1|h1)\": p_v1_given_h1,\n", 683 | " \"p(v2|h2)\": p_v2_given_h2,\n", 684 | "})\n", 685 | "\n", 686 | "m = Messages()\n", 687 | "m.marginal(pgm.variable_from_name('v2'))" 688 | ] 689 | }, 690 | { 691 | "cell_type": "markdown", 692 | "metadata": {}, 693 | "source": [ 694 | " array([0.23, 0.77])" 695 | ] 696 | }, 697 | { 698 | "cell_type": "code", 699 | "execution_count": 18, 700 | "metadata": {}, 701 | "outputs": [ 702 | { 703 | "data": { 704 | "text/plain": [ 705 | "{('p(h1)', 'h1'): array([0.2, 0.8]),\n", 706 | " ('v1', 'p(v1|h1)'): 1.0,\n", 707 | " ('p(v1|h1)', 'h1'): array([1., 1.]),\n", 708 | " ('h1', 'p(h2|h1)'): array([0.2, 0.8]),\n", 709 | " ('p(h2|h1)', 'h2'): array([0.26, 0.74]),\n", 710 | " 
('h2', 'p(v2|h2)'): array([0.26, 0.74]),\n", 711 | " ('p(v2|h2)', 'v2'): array([0.23, 0.77])}" 712 | ] 713 | }, 714 | "execution_count": 18, 715 | "metadata": {}, 716 | "output_type": "execute_result" 717 | } 718 | ], 719 | "source": [ 720 | "m.messages" 721 | ] 722 | }, 723 | { 724 | "cell_type": "markdown", 725 | "metadata": {}, 726 | "source": [ 727 | " {('p(h1)', 'h1'): array([0.2, 0.8]),\n", 728 | " ('v1', 'p(v1|h1)'): 1.0,\n", 729 | " ('p(v1|h1)', 'h1'): array([1., 1.]),\n", 730 | " ('h1', 'p(h2|h1)'): array([0.2, 0.8]),\n", 731 | " ('p(h2|h1)', 'h2'): array([0.26, 0.74]),\n", 732 | " ('h2', 'p(v2|h2)'): array([0.26, 0.74]),\n", 733 | " ('p(v2|h2)', 'v2'): array([0.23, 0.77])}" 734 | ] 735 | }, 736 | { 737 | "cell_type": "code", 738 | "execution_count": 19, 739 | "metadata": {}, 740 | "outputs": [ 741 | { 742 | "data": { 743 | "text/plain": [ 744 | "array([0.2, 0.8])" 745 | ] 746 | }, 747 | "execution_count": 19, 748 | "metadata": {}, 749 | "output_type": "execute_result" 750 | } 751 | ], 752 | "source": [ 753 | "m.marginal(pgm.variable_from_name('v1'))" 754 | ] 755 | }, 756 | { 757 | "cell_type": "markdown", 758 | "metadata": {}, 759 | "source": [ 760 | " array([0.2, 0.8])" 761 | ] 762 | }, 763 | { 764 | "cell_type": "markdown", 765 | "metadata": {}, 766 | "source": [ 767 | "#### Example from book\n", 768 | "\n", 769 | "Example 5.1 on p79 of Barber has a numerical example. I can make sure I get the same values (`[0.5746, 0.318 , 0.1074]`)." 
770 | ] 771 | }, 772 | { 773 | "cell_type": "code", 774 | "execution_count": 20, 775 | "metadata": {}, 776 | "outputs": [ 777 | { 778 | "data": { 779 | "text/plain": [ 780 | "array([0.5746, 0.318 , 0.1074])" 781 | ] 782 | }, 783 | "execution_count": 20, 784 | "metadata": {}, 785 | "output_type": "execute_result" 786 | } 787 | ], 788 | "source": [ 789 | "pgm = PGM.from_string(\"p(x5|x4)p(x4|x3)p(x3|x2)p(x2|x1)p(x1)\")\n", 790 | "\n", 791 | "p_x5_given_x4 = LabeledArray(np.array([[0.7, 0.5, 0], [0.3, 0.3, 0.5], [0, 0.2, 0.5]]), ['x5', 'x4'])\n", 792 | "assert is_conditional_prob(p_x5_given_x4, 'x5')\n", 793 | "p_x4_given_x3 = LabeledArray(p_x5_given_x4.array, ['x4', 'x3'])\n", 794 | "p_x3_given_x2 = LabeledArray(p_x5_given_x4.array, ['x3', 'x2'])\n", 795 | "p_x2_given_x1 = LabeledArray(p_x5_given_x4.array, ['x2', 'x1'])\n", 796 | "p_x1 = LabeledArray(np.array([1, 0, 0]), ['x1'])\n", 797 | "\n", 798 | "pgm.set_data({\n", 799 | " \"p(x5|x4)\": p_x5_given_x4,\n", 800 | " \"p(x4|x3)\": p_x4_given_x3,\n", 801 | " \"p(x3|x2)\": p_x3_given_x2,\n", 802 | " \"p(x2|x1)\": p_x2_given_x1,\n", 803 | " \"p(x1)\": p_x1,\n", 804 | "})\n", 805 | "\n", 806 | "Messages().marginal(pgm.variable_from_name('x5'))" 807 | ] 808 | }, 809 | { 810 | "cell_type": "markdown", 811 | "metadata": {}, 812 | "source": [ 813 | "## See Also\n", 814 | "\n", 815 | " - In my previous post [HMM](2018-05-02-hmm-alpha-recursion.ipynb), I implemented a form of belief propagation for Hidden Markov Models called Alpha Recursion.\n", 816 | " - Python library [pgmpy](https://github.com/pgmpy/pgmpy) does probabilistic graphical models and has nice code!" 
817 | ] 818 | }, 819 | { 820 | "cell_type": "code", 821 | "execution_count": null, 822 | "metadata": {}, 823 | "outputs": [], 824 | "source": [] 825 | } 826 | ], 827 | "metadata": { 828 | "kernelspec": { 829 | "display_name": "Python 3", 830 | "language": "python", 831 | "name": "python3" 832 | }, 833 | "language_info": { 834 | "codemirror_mode": { 835 | "name": "ipython", 836 | "version": 3 837 | }, 838 | "file_extension": ".py", 839 | "mimetype": "text/x-python", 840 | "name": "python", 841 | "nbconvert_exporter": "python", 842 | "pygments_lexer": "ipython3", 843 | "version": "3.6.1" 844 | } 845 | }, 846 | "nbformat": 4, 847 | "nbformat_minor": 2 848 | } 849 | -------------------------------------------------------------------------------- /2019-06-02-climate-with-keras.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Climate classification with Keras\n", 8 | "\n", 9 | "One of my favorite hack projects was trying to create a [climate classification](2018-06-11-climate-classification-with-neural-nets.ipynb) by clustering learned embeddings of weather stations.\n", 10 | "\n", 11 | "The original model was written in TensorFlow. Since then, I've started to experiment with [Keras](https://keras.io). Because the climate classifier is pretty simple neural network, I rewrote the model using Keras and saved many lines of code.\n", 12 | "\n", 13 | "For the problem description and data preparation, see the [original post](2018-06-11-climate-classification-with-neural-nets.ipynb)." 
14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 9, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "import numpy as np\n", 23 | "import matplotlib.pyplot as plt\n", 24 | "from keras.models import Model\n", 25 | "from keras.layers import Dense, Flatten, Embedding, Input, Reshape, Concatenate\n", 26 | "from sklearn.cluster import KMeans" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "metadata": {}, 32 | "source": [ 33 | "## Load the data\n", 34 | "\n", 35 | "I'm using the same files that I created for the [original post](2018-06-11-climate-classification-with-neural-nets.ipynb)." 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 2, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "DATA_PATH = 'data/weather'\n", 45 | "\n", 46 | "matrix_file = os.path.join(DATA_PATH, 'data/gsn-2000-TMAX-TMIN-PRCP.npz')\n", 47 | "\n", 48 | "# column labels\n", 49 | "STATION_ID_COL = 0\n", 50 | "MONTH_COL = 1\n", 51 | "DAY_COL = 2\n", 52 | "VAL_COLS_START = 3\n", 53 | "TMAX_COL = 3\n", 54 | "TMIN_COL = 4\n", 55 | "PRCP_COL = 5\n", 56 | "\n", 57 | "with np.load(matrix_file) as npz_data:\n", 58 | " weather_data = npz_data['data'].astype(np.int32)\n", 59 | "\n", 60 | "# I decided to switch over to using the day of the year instead of two \n", 61 | "# eh, this isn't perfect (it assumes all months have 31 days), but it helps differentiate \n", 62 | "# the first of the month vs the last. 
\n", 63 | "weather_data_day_of_year_data = 31 * (weather_data[:, MONTH_COL] - 1) + (weather_data[:, DAY_COL] - 1)" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 3, 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "station_id_data = weather_data[:, STATION_ID_COL].reshape(-1, 1)\n", 73 | "weather_data_day_of_year_data = weather_data_day_of_year_data.reshape(-1, 1)\n", 74 | "weather_prediction = weather_data[:, VAL_COLS_START:]" 75 | ] 76 | }, 77 | { 78 | "cell_type": "markdown", 79 | "metadata": {}, 80 | "source": [ 81 | "## Define the network\n", 82 | "\n", 83 | "I'll use the same network as [before](2018-06-11-climate-classification-with-neural-nets.ipynb).\n", 84 | "\n", 85 | "![](images/2018-06-11-nn.png)" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 4, 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "# network parameters\n", 95 | "BATCH_SIZE = 50\n", 96 | "EMBEDDING_SIZE = 20\n", 97 | "HIDDEN_UNITS = 40\n", 98 | "\n", 99 | "# and classification parameters. 
How many climates I want.\n", 100 | "CLUSTER_NUMBER = 6\n", 101 | "\n", 102 | "# count how many stations there are\n", 103 | "NUM_STATIONS = np.max(weather_data[:, STATION_ID_COL]) + 1" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 5, 109 | "metadata": {}, 110 | "outputs": [ 111 | { 112 | "name": "stdout", 113 | "output_type": "stream", 114 | "text": [ 115 | "WARNING:tensorflow:From /Users/jessica/miniconda3/envs/blog/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n", 116 | "Instructions for updating:\n", 117 | "Colocations handled automatically by placer.\n", 118 | "__________________________________________________________________________________________________\n", 119 | "Layer (type) Output Shape Param # Connected to \n", 120 | "==================================================================================================\n", 121 | "station_id_input (InputLayer) (None, 1) 0 \n", 122 | "__________________________________________________________________________________________________\n", 123 | "embedded_stations (Embedding) (None, 1, 20) 24360 station_id_input[0][0] \n", 124 | "__________________________________________________________________________________________________\n", 125 | "embedded_station_reshape (Resha (None, 20) 0 embedded_stations[0][0] \n", 126 | "__________________________________________________________________________________________________\n", 127 | "month_day_input (InputLayer) (None, 1) 0 \n", 128 | "__________________________________________________________________________________________________\n", 129 | "station_and_day (Concatenate) (None, 21) 0 embedded_station_reshape[0][0] \n", 130 | " month_day_input[0][0] \n", 131 | "__________________________________________________________________________________________________\n", 132 | "hidden (Dense) (None, 40) 880 
station_and_day[0][0] \n", 133 | "__________________________________________________________________________________________________\n", 134 | "prediction (Dense) (None, 3) 123 hidden[0][0] \n", 135 | "==================================================================================================\n", 136 | "Total params: 25,363\n", 137 | "Trainable params: 25,363\n", 138 | "Non-trainable params: 0\n", 139 | "__________________________________________________________________________________________________\n" 140 | ] 141 | } 142 | ], 143 | "source": [ 144 | "station_id_input = Input(shape=(1,), name='station_id_input')\n", 145 | "month_day_input = Input(shape=(1,), name='month_day_input')\n", 146 | "\n", 147 | "# Feed the station id through the embedding. This embeddings variable\n", 148 | "# is the whole point of this network!\n", 149 | "embedded_stations = Embedding(\n", 150 | " output_dim=EMBEDDING_SIZE, \n", 151 | " input_dim=NUM_STATIONS,\n", 152 | " name='embedded_stations'\n", 153 | ")(station_id_input)\n", 154 | "\n", 155 | "embedded_station_reshape = Reshape((EMBEDDING_SIZE,), name='embedded_station_reshape')(embedded_stations)\n", 156 | "\n", 157 | "station_and_day = Concatenate(name='station_and_day')([embedded_station_reshape, month_day_input])\n", 158 | "\n", 159 | "# Now build a little network that can learn to predict the weather\n", 160 | "hidden = Dense(HIDDEN_UNITS, activation='relu', name='hidden')(station_and_day)\n", 161 | "\n", 162 | "prediction = Dense(\n", 163 | " 3, # Output for each of the attributes of the weather prediction (max, min, precipitation)\n", 164 | " activation=None, # don't use an activation on predictions\n", 165 | " name='prediction'\n", 166 | ")(hidden)\n", 167 | "\n", 168 | "\n", 169 | "model = Model(inputs=[station_id_input, month_day_input], outputs=prediction)\n", 170 | "\n", 171 | "model.compile(optimizer='adam',\n", 172 | " loss='mean_squared_error')\n", 173 | "\n", 174 | "model.summary()" 175 | ] 176 | }, 177 | { 
178 | "cell_type": "markdown", 179 | "metadata": {}, 180 | "source": [ 181 | " __________________________________________________________________________________________________\n", 182 | " Layer (type) Output Shape Param # Connected to \n", 183 | " ==================================================================================================\n", 184 | " station_id_input (InputLayer) (None, 1) 0 \n", 185 | " __________________________________________________________________________________________________\n", 186 | " embedded_stations (Embedding) (None, 1, 20) 24360 station_id_input[0][0] \n", 187 | " __________________________________________________________________________________________________\n", 188 | " embedded_station_reshape (Resha (None, 20) 0 embedded_stations[0][0] \n", 189 | " __________________________________________________________________________________________________\n", 190 | " month_day_input (InputLayer) (None, 1) 0 \n", 191 | " __________________________________________________________________________________________________\n", 192 | " station_and_day (Concatenate) (None, 21) 0 embedded_station_reshape[0][0] \n", 193 | " month_day_input[0][0] \n", 194 | " __________________________________________________________________________________________________\n", 195 | " hidden (Dense) (None, 40) 880 station_and_day[0][0] \n", 196 | " __________________________________________________________________________________________________\n", 197 | " prediction (Dense) (None, 3) 123 hidden[0][0] \n", 198 | " ==================================================================================================\n", 199 | " Total params: 25,363\n", 200 | " Trainable params: 25,363\n", 201 | " Non-trainable params: 0\n", 202 | " __________________________________________________________________________________________________" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": 6, 208 | "metadata": {}, 209 | "outputs": [ 210 | { 211 
| "name": "stdout", 212 | "output_type": "stream", 213 | "text": [ 214 | "WARNING:tensorflow:From /Users/jessica/miniconda3/envs/blog/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n", 215 | "Instructions for updating:\n", 216 | "Use tf.cast instead.\n", 217 | "Epoch 1/1\n", 218 | "2417066/2417066 [==============================] - 178s 74us/step - loss: 4390.3217\n" 219 | ] 220 | }, 221 | { 222 | "data": { 223 | "text/plain": [ 224 | "" 225 | ] 226 | }, 227 | "execution_count": 6, 228 | "metadata": {}, 229 | "output_type": "execute_result" 230 | } 231 | ], 232 | "source": [ 233 | "model.fit(\n", 234 | " [\n", 235 | " station_id_data,\n", 236 | " weather_data_day_of_year_data,\n", 237 | " ], \n", 238 | " weather_prediction.reshape(-1, 3), \n", 239 | " epochs=1,\n", 240 | ")" 241 | ] 242 | }, 243 | { 244 | "cell_type": "markdown", 245 | "metadata": {}, 246 | "source": [ 247 | " Epoch 1/1\n", 248 | " 2417066/2417066 [==============================] - 178s 74us/step - loss: 4390.3217\n", 249 | "\n", 250 | "\n", 251 | "There is less boilerplate in the Keras code compared to my TensorFlow implementation in the [original post](2018-06-11-climate-classification-with-neural-nets.ipynb). I think it's cool that most of the code is doing work describing the network." 252 | ] 253 | }, 254 | { 255 | "cell_type": "markdown", 256 | "metadata": {}, 257 | "source": [ 258 | "## Classify the embeddings\n", 259 | "\n", 260 | "Finally, I run KMeans on the trained embeddings to assign a \"climate\"." 
261 | ] 262 | }, 263 | { 264 | "cell_type": "code", 265 | "execution_count": 7, 266 | "metadata": {}, 267 | "outputs": [], 268 | "source": [ 269 | "trained_embeddings = model.get_layer('embedded_stations').get_weights()[0]" 270 | ] 271 | }, 272 | { 273 | "cell_type": "code", 274 | "execution_count": 8, 275 | "metadata": {}, 276 | "outputs": [], 277 | "source": [ 278 | "with open(os.path.join(DATA_PATH, 'data/stations') )as f:\n", 279 | " list_of_stations = [line.strip() for line in f.readlines()]\n", 280 | "\n", 281 | "kmeans = KMeans(n_clusters=CLUSTER_NUMBER, random_state=0).fit(trained_embeddings)\n", 282 | "\n", 283 | "# I can export the classification here:\n", 284 | "for station, label in zip(list_of_stations, kmeans.labels_):\n", 285 | " #print('{}\\t{}'.format(station, label))\n", 286 | " pass" 287 | ] 288 | }, 289 | { 290 | "cell_type": "markdown", 291 | "metadata": {}, 292 | "source": [ 293 | "![](images/2019-06-02-new-map.png)\n", 294 | "\n", 295 | "I get similar climates as before. Since there's randomness in the neural network initializations and batches and in the KMeans, I wouldn't expect to get exactly the same. For example, the latitude boundaries on the East coast have shifted compared to before.\n", 296 | "\n", 297 | "For this post, I also switched to [Cartopy](https://github.com/SciTools/cartopy) from [Basemap](https://matplotlib.org/basemap/users/intro.html#cartopy-new-management-and-eol-announcement)." 
298 | ] 299 | }, 300 | { 301 | "cell_type": "code", 302 | "execution_count": null, 303 | "metadata": {}, 304 | "outputs": [], 305 | "source": [] 306 | } 307 | ], 308 | "metadata": { 309 | "jessica": {}, 310 | "kernelspec": { 311 | "display_name": "Python 3", 312 | "language": "python", 313 | "name": "python3" 314 | }, 315 | "language_info": { 316 | "codemirror_mode": { 317 | "name": "ipython", 318 | "version": 3 319 | }, 320 | "file_extension": ".py", 321 | "mimetype": "text/x-python", 322 | "name": "python", 323 | "nbconvert_exporter": "python", 324 | "pygments_lexer": "ipython3", 325 | "version": "3.7.3" 326 | } 327 | }, 328 | "nbformat": 4, 329 | "nbformat_minor": 2 330 | } 331 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Jessica Stringham 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Notebooks 2 | 3 | This repo contains the Jupyter notebooks I post to my blog, [jessicastringham.net](https://jessicastringham.net). 4 | 5 | # create environment of your choice, for example 6 | virtualenv -p python3 venv 7 | 8 | # install requirements 9 | pip install -r requirements.txt 10 | 11 | # run notebook 12 | jupyter notebook 13 | -------------------------------------------------------------------------------- /data/.git_keep_folder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/data/.git_keep_folder -------------------------------------------------------------------------------- /data/coffee_comparisons.json: -------------------------------------------------------------------------------- 1 | {"metric": "meet", "first": "artisan_broughton", "last": "cairngorm_george", "weight": 0.5} 2 | {"metric": "meet", "first": "wellington", "last": "twelve_triangles_portobello", "weight": 0.5} 3 | {"metric": "group", "first": "machina_espresso_uni", "last": "milkman", "weight": 1.0} 4 | {"metric": "group", "first": "brewlab", "last": "cairngorm", "weight": 0.5} 5 | {"metric": "group", "first": "soderberg_meadows", "last": "press", "weight": 1.0} 6 | {"metric": "group", "first": "brewlab", "last": "fortitude", "weight": 1.0} 7 | {"metric": "group", "first": "soderberg_meadows", "last": "castello", "weight": 1.0} 8 | {"metric": "group", "first": "machina_espresso_uni", "last": "cult", 
"weight": 0} 9 | {"metric": "group", "first": "brewlab", "last": "cairngorm", "weight": 0.5} 10 | {"metric": "group", "first": "brewlab", "last": "machina_espresso_bruntsfield", "weight": 0.5} 11 | {"metric": "group", "first": "artisan_broughton", "last": "milkman", "weight": 0} 12 | {"metric": "group", "first": "levels", "last": "cairngorm", "weight": 0.5} 13 | {"metric": "group", "first": "levels", "last": "machina_espresso_uni", "weight": 0.5} 14 | {"metric": "group", "first": "artisan_stockbridge", "last": "twelve_triangles_portobello", "weight": 1.0} 15 | {"metric": "meet", "first": "machina_espresso_uni", "last": "lowdown", "weight": 0.5} 16 | {"metric": "meet", "first": "machina_espresso_uni", "last": "levels", "weight": 0.5} 17 | {"metric": "meet", "first": "cult", "last": "wellington", "weight": 1.0} 18 | {"metric": "meet", "first": "cult", "last": "artisan_stockbridge", "weight": 0.5} 19 | {"metric": "meet", "first": "press", "last": "wellington", "weight": 0.5} 20 | {"metric": "meet", "first": "black_medicine", "last": "wellington", "weight": 0.5} 21 | {"metric": "meet", "first": "cult", "last": "castello", "weight": 1.0} 22 | {"metric": "meet", "first": "artisan_broughton", "last": "wellington", "weight": 0.5} 23 | {"metric": "meet", "first": "machina_espresso_uni", "last": "cult", "weight": 0} 24 | {"metric": "meet", "first": "fortitude", "last": "wellington", "weight": 0.5} 25 | {"metric": "meet", "first": "fortitude", "last": "soderberg_broughton", "weight": 0.5} 26 | {"metric": "meet", "first": "black_medicine", "last": "machina_espresso_bruntsfield", "weight": 0.5} 27 | {"metric": "meet", "first": "levels", "last": "wellington", "weight": 1.0} 28 | {"metric": "meet", "first": "milkman", "last": "twelve_triangles_portobello", "weight": 0.5} 29 | {"metric": "meet", "first": "soderberg_broughton", "last": "castello", "weight": 1.0} 30 | {"metric": "meet", "first": "cairngorm_george", "last": "wellington", "weight": 0.5} 31 | {"metric": "meet", 
"first": "soderberg_meadows", "last": "filament", "weight": 1.0} 32 | {"metric": "meet", "first": "levels", "last": "soderberg_broughton", "weight": 0.5} 33 | {"metric": "meet", "first": "cult", "last": "lowdown", "weight": 0.5} 34 | {"metric": "meet", "first": "cairngorm", "last": "wellington", "weight": 0.5} 35 | {"metric": "meet", "first": "press", "last": "soderberg_broughton", "weight": 0} 36 | {"metric": "meet", "first": "artisan_stockbridge", "last": "fortitude", "weight": 0.5} 37 | {"metric": "meet", "first": "cairngorm_george", "last": "fortitude", "weight": 0.5} 38 | {"metric": "meet", "first": "artisan_stockbridge", "last": "soderberg_broughton", "weight": 0.5} 39 | {"metric": "meet", "first": "levels", "last": "twelve_triangles_portobello", "weight": 1.0} 40 | {"metric": "meet", "first": "cairngorm_george", "last": "twelve_triangles_portobello", "weight": 1.0} 41 | {"metric": "meet", "first": "lowdown", "last": "soderberg_broughton", "weight": 0} 42 | {"metric": "meet", "first": "soderberg_broughton", "last": "lowdown", "weight": 0} 43 | {"metric": "meet", "first": "soderberg_meadows", "last": "soderberg_broughton", "weight": 1.0} 44 | {"metric": "meet", "first": "black_medicine", "last": "cairngorm", "weight": 0.5} 45 | {"metric": "meet", "first": "soderberg_meadows", "last": "machina_espresso_bruntsfield", "weight": 0.5} 46 | {"metric": "meet", "first": "brewlab", "last": "wellington", "weight": 1.0} 47 | {"metric": "meet", "first": "filament", "last": "wellington", "weight": 1.0} 48 | {"metric": "meet", "first": "filament", "last": "wellington", "weight": 1.0} 49 | {"metric": "meet", "first": "brewlab", "last": "machina_espresso_uni", "weight": 0.5} 50 | {"metric": "meet", "first": "levels", "last": "machina_espresso_uni", "weight": 0.5} 51 | {"weight": 1.0, "first": "brewlab", "last": "wellington", "metric": "laptop"} 52 | {"weight": 1.0, "first": "machina_espresso_uni", "last": "castello", "metric": "laptop"} 53 | {"weight": 0.5, "first": "levels", 
"last": "machina_espresso_bruntsfield", "metric": "laptop"} 54 | {"weight": 1.0, "first": "cairngorm", "last": "twelve_triangles_portobello", "metric": "laptop"} 55 | {"weight": 0.5, "first": "black_medicine", "last": "cairngorm_george", "metric": "laptop"} 56 | {"weight": 1.0, "first": "soderberg_meadows", "last": "castello", "metric": "laptop"} 57 | {"weight": 1.0, "first": "machina_espresso_uni", "last": "twelve_triangles_portobello", "metric": "laptop"} 58 | {"weight": 0, "first": "cult", "last": "machina_espresso_bruntsfield", "metric": "laptop"} 59 | {"weight": 0.5, "first": "soderberg_meadows", "last": "artisan_stockbridge", "metric": "laptop"} 60 | {"weight": 0.5, "first": "black_medicine", "last": "machina_espresso_bruntsfield", "metric": "laptop"} 61 | {"weight": 1.0, "first": "cairngorm", "last": "twelve_triangles_portobello", "metric": "laptop"} 62 | {"weight": 0, "first": "cairngorm_george", "last": "machina_espresso_bruntsfield", "metric": "laptop"} 63 | {"weight": 0, "first": "wellington", "last": "castello", "metric": "laptop"} 64 | {"weight": 1.0, "first": "fortitude", "last": "castello", "metric": "laptop"} 65 | {"weight": 1.0, "first": "press", "last": "twelve_triangles_portobello", "metric": "laptop"} 66 | {"weight": 1.0, "first": "machina_espresso_bruntsfield", "last": "artisan_broughton", "metric": "laptop"} 67 | {"weight": 0.5, "first": "filament", "last": "press", "metric": "laptop"} 68 | {"weight": 1.0, "first": "lowdown", "last": "artisan_broughton", "metric": "laptop"} 69 | {"weight": 0.5, "first": "black_medicine", "last": "machina_espresso_uni", "metric": "laptop"} 70 | {"weight": 0.5, "first": "artisan_broughton", "last": "twelve_triangles_portobello", "metric": "laptop"} 71 | {"weight": 1.0, "first": "soderberg_meadows", "last": "soderberg_broughton", "metric": "laptop"} 72 | {"weight": 1.0, "first": "artisan_broughton", "last": "levels", "metric": "reading"} 73 | 
-------------------------------------------------------------------------------- /data/coffee_metadata.yaml: -------------------------------------------------------------------------------- 1 | - name: BrewLab 2 | id: brewlab 3 | location: university 4 | 5 | - name: Cult Espresso 6 | id: cult 7 | location: university 8 | 9 | - name: Wellington 10 | id: wellington 11 | location: new_town 12 | 13 | - name: Filament 14 | id: filament 15 | location: meadows 16 | 17 | - name: Artisan Roast (Broughton) 18 | id: artisan_broughton 19 | location: meadows 20 | 21 | - name: Artisan Roast (Stockbridge) 22 | id: artisan_stockbridge 23 | location: stockbridge 24 | 25 | - name: Twelve Triangles 26 | id: twelve_triangles_portobello 27 | location: portobello 28 | 29 | - name: Söderberg (Meadows) 30 | id: soderberg_meadows 31 | location: university 32 | 33 | - name: Söderberg (Broughton) 34 | id: soderberg_broughton 35 | location: new_town 36 | 37 | - name: Fortitude 38 | id: fortitude 39 | location: new_town 40 | 41 | - name: Black Medicine 42 | id: black_medicine 43 | location: university 44 | 45 | - name: Press 46 | id: press 47 | location: university 48 | 49 | - name: The Milkman 50 | id: milkman 51 | location: old_town 52 | 53 | - name: Castello 54 | id: castello 55 | location: new_town 56 | 57 | - name: Machina Espresso (Meadows) 58 | id: machina_espresso_bruntsfield 59 | location: bruntsfield 60 | 61 | - name: Machine espresso close 62 | id: machina_espresso_uni 63 | location: university 64 | 65 | - name: Cairngorm close 66 | id: cairngorm_george 67 | location: new_town 68 | 69 | - name: Cairngorn (Queensferry) 70 | id: cairngorm 71 | location: west_end 72 | 73 | - name: Lowdown 74 | id: lowdown 75 | location: new_town 76 | 77 | - name: Levels 78 | id: levels 79 | location: cowgate 80 | -------------------------------------------------------------------------------- /images/2018-01-03-different-w.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-01-03-different-w.png -------------------------------------------------------------------------------- /images/2018-01-08-linear-sample-example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-01-08-linear-sample-example.png -------------------------------------------------------------------------------- /images/2018-01-08-sigmoid.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-01-08-sigmoid.png -------------------------------------------------------------------------------- /images/2018-01-09-mean-far-away.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-01-09-mean-far-away.png -------------------------------------------------------------------------------- /images/2018-01-09-mean.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-01-09-mean.png -------------------------------------------------------------------------------- /images/2018-01-10-five-degrees-uncertainty-another-point.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-01-10-five-degrees-uncertainty-another-point.png 
-------------------------------------------------------------------------------- /images/2018-01-10-five-degrees-uncertainty-few.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-01-10-five-degrees-uncertainty-few.png -------------------------------------------------------------------------------- /images/2018-01-10-five-degrees.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-01-10-five-degrees.png -------------------------------------------------------------------------------- /images/2018-01-10-sample-with-error.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-01-10-sample-with-error.png -------------------------------------------------------------------------------- /images/2018-01-10-samples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-01-10-samples.png -------------------------------------------------------------------------------- /images/2018-01-10-uncertainty-zoom-out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-01-10-uncertainty-zoom-out.png -------------------------------------------------------------------------------- /images/2018-01-10-uncertainty.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-01-10-uncertainty.png -------------------------------------------------------------------------------- /images/2018-05-02-bumps-creaks.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-02-bumps-creaks.png -------------------------------------------------------------------------------- /images/2018-05-02-filtering.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-02-filtering.png -------------------------------------------------------------------------------- /images/2018-05-02-hmm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-02-hmm.png -------------------------------------------------------------------------------- /images/2018-05-02-room-coordinates.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-02-room-coordinates.png -------------------------------------------------------------------------------- /images/2018-05-02-sample-paths.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-02-sample-paths.png -------------------------------------------------------------------------------- /images/2018-05-02-transition.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-02-transition.png -------------------------------------------------------------------------------- /images/2018-05-02-v-given-h.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-02-v-given-h.png -------------------------------------------------------------------------------- /images/2018-05-02-v-over-time.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-02-v-over-time.png -------------------------------------------------------------------------------- /images/2018-05-04-bayes-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-04-bayes-1.png -------------------------------------------------------------------------------- /images/2018-05-04-bayes-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-04-bayes-2.png -------------------------------------------------------------------------------- /images/2018-05-04-sigmoid.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-04-sigmoid.png -------------------------------------------------------------------------------- /images/2018-05-12-posterior.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-12-posterior.png -------------------------------------------------------------------------------- /images/2018-05-12-simulated-data.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-12-simulated-data.png -------------------------------------------------------------------------------- /images/2018-05-12-trace.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-12-trace.png -------------------------------------------------------------------------------- /images/2018-05-12-weight-samples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-12-weight-samples.png -------------------------------------------------------------------------------- /images/2018-05-13-viterbi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-13-viterbi.png -------------------------------------------------------------------------------- /images/2018-05-16-4d-plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-16-4d-plot.png -------------------------------------------------------------------------------- /images/2018-05-16-ex.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-16-ex.png -------------------------------------------------------------------------------- /images/2018-05-16-labeled-plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-16-labeled-plot.png -------------------------------------------------------------------------------- /images/2018-05-16-svd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-16-svd.png -------------------------------------------------------------------------------- /images/2018-05-18-cov-element.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-18-cov-element.png -------------------------------------------------------------------------------- /images/2018-05-18-cov-similarity.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-18-cov-similarity.png -------------------------------------------------------------------------------- /images/2018-05-18-cov.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-18-cov.png -------------------------------------------------------------------------------- /images/2018-05-18-different-cov.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-18-different-cov.png -------------------------------------------------------------------------------- /images/2018-05-18-large-ell.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-18-large-ell.png -------------------------------------------------------------------------------- /images/2018-05-18-large-sigma.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-18-large-sigma.png -------------------------------------------------------------------------------- /images/2018-05-18-observations.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-18-observations.png -------------------------------------------------------------------------------- /images/2018-05-18-poor-prediction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-18-poor-prediction.png -------------------------------------------------------------------------------- /images/2018-05-18-prediction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-18-prediction.png -------------------------------------------------------------------------------- /images/2018-05-18-prior-beliefs.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-18-prior-beliefs.png -------------------------------------------------------------------------------- /images/2018-05-18-small-ell.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-18-small-ell.png -------------------------------------------------------------------------------- /images/2018-05-20-comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-20-comparison.png -------------------------------------------------------------------------------- /images/2018-05-20-distances.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-20-distances.png -------------------------------------------------------------------------------- /images/2018-05-20-from-gram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-20-from-gram.png -------------------------------------------------------------------------------- /images/2018-05-20-north-america-per-state-province.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-20-north-america-per-state-province.png -------------------------------------------------------------------------------- /images/2018-05-20-north-america.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-20-north-america.png -------------------------------------------------------------------------------- /images/2018-05-20-pc-scores-from-x.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-20-pc-scores-from-x.png -------------------------------------------------------------------------------- /images/2018-05-22-2d-sigmoid.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-22-2d-sigmoid.png -------------------------------------------------------------------------------- /images/2018-05-22-bad-fit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-22-bad-fit.png -------------------------------------------------------------------------------- /images/2018-05-22-fitted-weights.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-22-fitted-weights.png -------------------------------------------------------------------------------- /images/2018-05-22-fuzzy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-22-fuzzy.png -------------------------------------------------------------------------------- /images/2018-05-22-less-steep.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-22-less-steep.png -------------------------------------------------------------------------------- /images/2018-05-22-lin-sep.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-22-lin-sep.png -------------------------------------------------------------------------------- /images/2018-05-22-optimizer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-22-optimizer.png -------------------------------------------------------------------------------- /images/2018-05-22-reg-vs-no.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-22-reg-vs-no.png -------------------------------------------------------------------------------- /images/2018-05-22-sigmoid.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-22-sigmoid.png -------------------------------------------------------------------------------- /images/2018-05-26-dist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-26-dist.png -------------------------------------------------------------------------------- /images/2018-05-27-cairngorm.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-27-cairngorm.png -------------------------------------------------------------------------------- /images/2018-05-27-comparisons-box.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-27-comparisons-box.png -------------------------------------------------------------------------------- /images/2018-05-27-pairwise-comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-27-pairwise-comparison.png -------------------------------------------------------------------------------- /images/2018-05-27-posteriors.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-27-posteriors.png -------------------------------------------------------------------------------- /images/2018-05-27-predictions.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-27-predictions.png -------------------------------------------------------------------------------- /images/2018-05-27-ranking.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-27-ranking.png -------------------------------------------------------------------------------- /images/2018-05-27-spoon.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-27-spoon.png -------------------------------------------------------------------------------- /images/2018-05-27-traceplot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-27-traceplot.png -------------------------------------------------------------------------------- /images/2018-05-27-twelve-triangles.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-05-27-twelve-triangles.png -------------------------------------------------------------------------------- /images/2018-06-11-data.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-06-11-data.png -------------------------------------------------------------------------------- /images/2018-06-11-koppen.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-06-11-koppen.png -------------------------------------------------------------------------------- /images/2018-06-11-nn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-06-11-nn.png -------------------------------------------------------------------------------- /images/2018-06-11-silly-nn-1.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-06-11-silly-nn-1.png -------------------------------------------------------------------------------- /images/2018-06-11-silly-nn-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-06-11-silly-nn-2.png -------------------------------------------------------------------------------- /images/2018-06-11-silly-nn-3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-06-11-silly-nn-3.png -------------------------------------------------------------------------------- /images/2018-06-11-stations.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-06-11-stations.png -------------------------------------------------------------------------------- /images/2018-06-11-tensorboard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-06-11-tensorboard.png -------------------------------------------------------------------------------- /images/2018-06-11-us.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-06-11-us.png -------------------------------------------------------------------------------- /images/2018-06-11-year-precipitation-predictions.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-06-11-year-precipitation-predictions.png -------------------------------------------------------------------------------- /images/2018-06-11-year-temp-predictions.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-06-11-year-temp-predictions.png -------------------------------------------------------------------------------- /images/2018-06-11-year-temp-single-pred.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-06-11-year-temp-single-pred.png -------------------------------------------------------------------------------- /images/2018-12-27-comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-12-27-comparison.png -------------------------------------------------------------------------------- /images/2018-12-27-divergence-ex-bimodal-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-12-27-divergence-ex-bimodal-1.png -------------------------------------------------------------------------------- /images/2018-12-27-divergence-ex-bimodal-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-12-27-divergence-ex-bimodal-2.png 
-------------------------------------------------------------------------------- /images/2018-12-27-divergence-ex.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-12-27-divergence-ex.png -------------------------------------------------------------------------------- /images/2018-12-27-examples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-12-27-examples.png -------------------------------------------------------------------------------- /images/2018-12-27-learning.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-12-27-learning.png -------------------------------------------------------------------------------- /images/2018-12-27-minimizing.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2018-12-27-minimizing.png -------------------------------------------------------------------------------- /images/2019-01-09-factor-graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2019-01-09-factor-graph.png -------------------------------------------------------------------------------- /images/2019-03-17-content-convs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2019-03-17-content-convs.png 
-------------------------------------------------------------------------------- /images/2019-03-17-ghost-church.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2019-03-17-ghost-church.png -------------------------------------------------------------------------------- /images/2019-03-17-initial-image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2019-03-17-initial-image.png -------------------------------------------------------------------------------- /images/2019-03-17-mt-hood.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2019-03-17-mt-hood.jpg -------------------------------------------------------------------------------- /images/2019-03-17-orig-content.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2019-03-17-orig-content.png -------------------------------------------------------------------------------- /images/2019-03-17-orig-convs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2019-03-17-orig-convs.png -------------------------------------------------------------------------------- /images/2019-03-17-orig-primordial-chaos.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2019-03-17-orig-primordial-chaos.jpg -------------------------------------------------------------------------------- /images/2019-03-17-orig-scream.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2019-03-17-orig-scream.png -------------------------------------------------------------------------------- /images/2019-03-17-orig-start.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2019-03-17-orig-start.png -------------------------------------------------------------------------------- /images/2019-03-17-orig-style.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2019-03-17-orig-style.png -------------------------------------------------------------------------------- /images/2019-03-17-orig-tsunami.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2019-03-17-orig-tsunami.jpg -------------------------------------------------------------------------------- /images/2019-03-17-primordial-chaos.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2019-03-17-primordial-chaos.png -------------------------------------------------------------------------------- /images/2019-03-17-result.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2019-03-17-result.png -------------------------------------------------------------------------------- /images/2019-03-17-scream.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2019-03-17-scream.png -------------------------------------------------------------------------------- /images/2019-03-17-style-convs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2019-03-17-style-convs.png -------------------------------------------------------------------------------- /images/2019-03-17-sutro.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2019-03-17-sutro.png -------------------------------------------------------------------------------- /images/2019-03-17-three-images.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2019-03-17-three-images.png -------------------------------------------------------------------------------- /images/2019-03-17-tsunami.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2019-03-17-tsunami.png -------------------------------------------------------------------------------- /images/2019-06-02-new-map.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2019-06-02-new-map.png -------------------------------------------------------------------------------- /images/2019-07-01-delay.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2019-07-01-delay.png -------------------------------------------------------------------------------- /images/2019-07-01-delays_system.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2019-07-01-delays_system.png -------------------------------------------------------------------------------- /images/2019-07-01-no-delay.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2019-07-01-no-delay.png -------------------------------------------------------------------------------- /images/2019-07-01-overfishing-equal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2019-07-01-overfishing-equal.png -------------------------------------------------------------------------------- /images/2019-07-01-overfishing-oscillations-spiral.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2019-07-01-overfishing-oscillations-spiral.png -------------------------------------------------------------------------------- /images/2019-07-01-overfishing-oscillations.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2019-07-01-overfishing-oscillations.png -------------------------------------------------------------------------------- /images/2019-07-01-overfishing.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2019-07-01-overfishing.png -------------------------------------------------------------------------------- /images/2019-07-01-renewable-equations.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2019-07-01-renewable-equations.png -------------------------------------------------------------------------------- /images/2019-07-01-renewable_resource.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2019-07-01-renewable_resource.png -------------------------------------------------------------------------------- /images/2019-07-01-response-time.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2019-07-01-response-time.png -------------------------------------------------------------------------------- /images/2019-07-01-spiral.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2019-07-01-spiral.png -------------------------------------------------------------------------------- 
/images/2019-07-01-temperatures.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2019-07-01-temperatures.png -------------------------------------------------------------------------------- /images/2019-07-01-thermostat_system.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2019-07-01-thermostat_system.png -------------------------------------------------------------------------------- /images/2019-07-01-yield-per-unit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2019-07-01-yield-per-unit.png -------------------------------------------------------------------------------- /images/2019-11-06-autoencoder-diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2019-11-06-autoencoder-diagram.png -------------------------------------------------------------------------------- /images/2019-11-06-autoencoder-example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2019-11-06-autoencoder-example.png -------------------------------------------------------------------------------- /images/2019-11-06-chernoff-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2019-11-06-chernoff-1.png 
-------------------------------------------------------------------------------- /images/2019-11-06-chernoff-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2019-11-06-chernoff-2.png -------------------------------------------------------------------------------- /images/2019-11-06-clipping.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2019-11-06-clipping.png -------------------------------------------------------------------------------- /images/2019-11-06-dataset-examples-points.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2019-11-06-dataset-examples-points.png -------------------------------------------------------------------------------- /images/2019-11-06-dataset-examples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2019-11-06-dataset-examples.png -------------------------------------------------------------------------------- /images/2019-11-06-ex.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2019-11-06-ex.png -------------------------------------------------------------------------------- /images/2019-11-06-face-changing.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2019-11-06-face-changing.png 
-------------------------------------------------------------------------------- /images/2019-11-06-face-interpolate.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2019-11-06-face-interpolate.png -------------------------------------------------------------------------------- /images/2019-11-06-not-fixed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2019-11-06-not-fixed.png -------------------------------------------------------------------------------- /images/2019-11-06-other-faces.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2019-11-06-other-faces.png -------------------------------------------------------------------------------- /images/2020-01-06-facemap-labeled.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2020-01-06-facemap-labeled.png -------------------------------------------------------------------------------- /images/2020-01-06-facemap-many.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2020-01-06-facemap-many.png -------------------------------------------------------------------------------- /images/2020-01-06-neighbors.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/images/2020-01-06-neighbors.png 
-------------------------------------------------------------------------------- /nb_code/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessstringham/notebooks/942af7a1b0200dd0d5de351ab3d6359c1edede94/nb_code/__init__.py -------------------------------------------------------------------------------- /nb_code/hmm_alpha_recursion.py: -------------------------------------------------------------------------------- 1 | '''This code was generated from the notebook 2018-05-02-hmm-alpha-recursion.ipynb''' 2 | 3 | width = 6 4 | height = 5 5 | num_hidden_states = width * height 6 | 7 | # prob of starting starting locations 8 | p_hidden_start = np.ones(num_hidden_states) / num_hidden_states 9 | 10 | # verify it's a valid probability distribution 11 | assert np.all(np.isclose(np.sum(p_hidden_start), 1)) 12 | assert np.all(p_hidden_start >= 0) 13 | 14 | def create_transition_joint(width, height): 15 | num_hidden_states = width * height 16 | 17 | # begin by building an unnormalized matrix with 1s for all legal moves. 18 | unnormalized_transition_joint = np.zeros((num_hidden_states, num_hidden_states)) 19 | 20 | # This will help me map from height and width to the state 21 | map_x_y_to_hidden_state_id = np.arange(num_hidden_states).reshape(height, width).T 22 | 23 | for x in range(width): 24 | for y in range(height): 25 | h_t = map_x_y_to_hidden_state_id[x, y] 26 | 27 | # hax to go through each possible direction 28 | for d in range(4): 29 | new_x = x 30 | new_y = y 31 | if d // 2 == 0: 32 | # move left or right! 33 | new_x = x + ((d % 2) * 2 - 1) 34 | else: 35 | # move up or down! 
36 | new_y = y + ((d % 2) * 2 - 1) 37 | 38 | # make sure they don't walk through walls 39 | if any(( 40 | new_x > width - 1, 41 | new_x < 0, 42 | new_y > height - 1, 43 | new_y < 0 44 | )): 45 | continue 46 | 47 | h_t_minus_1 = map_x_y_to_hidden_state_id[new_x, new_y] 48 | unnormalized_transition_joint[h_t_minus_1][h_t] = 1 49 | 50 | # normalize! 51 | p_transition_joint = unnormalized_transition_joint / np.sum(unnormalized_transition_joint) 52 | 53 | # make sure this is a joint probability 54 | assert np.isclose(np.sum(p_transition_joint), 1) 55 | # not super necessary, but eh 56 | assert np.all(p_transition_joint >= 0) 57 | 58 | return p_transition_joint 59 | 60 | def create_transition(width, height): 61 | p_transition_joint = create_transition_joint(width, height) 62 | 63 | num_hidden_states = width * height 64 | 65 | p_transition = np.zeros((num_hidden_states, num_hidden_states)) 66 | 67 | for old_state in range(num_hidden_states): 68 | p_transition[:, old_state] = p_transition_joint[:, old_state] / np.sum(p_transition_joint[:, old_state]) 69 | 70 | # verify it's a conditional distribution 71 | assert np.all(np.sum(p_transition, axis=0)) == 1 72 | 73 | return p_transition 74 | 75 | p_transition = create_transition(width, height) 76 | 77 | def plot_state_in_room(state_id, width=width, height=height): 78 | h = np.zeros(width * height) 79 | h[state_id] = 1 80 | return h.reshape(height, width) 81 | 82 | def make_sound_map(): 83 | NUM_SOUNDS = 10 84 | LOW_PROB = 0.1 85 | HIGH_PROB = 0.9 86 | 87 | # everything has at least LOW_PROB of triggering the sound 88 | grid = LOW_PROB * np.ones(num_hidden_states) 89 | # select NUM_BUMP_CREAKS to make HIGH_PROB 90 | locs = np.random.choice( 91 | num_hidden_states, 92 | size=NUM_SOUNDS, 93 | replace=False 94 | ) 95 | grid[locs] = HIGH_PROB 96 | 97 | return grid 98 | 99 | prob_bump_true_given_location = make_sound_map() 100 | prob_creak_true_given_location = make_sound_map() 101 | 102 | num_visible_states = 4 103 | 104 | def 
get_emission_matrix(prob_bump_true_given_location, prob_creak_true_given_location): 105 | # prob_bump_given_state[v][state] = p(v | state) 106 | p_emission = np.vstack(( 107 | prob_bump_true_given_location * prob_creak_true_given_location, 108 | prob_bump_true_given_location * (1 - prob_creak_true_given_location), 109 | (1 - prob_bump_true_given_location) * prob_creak_true_given_location, 110 | (1 - prob_bump_true_given_location) * (1 - prob_creak_true_given_location), 111 | )) 112 | 113 | assert np.all(np.sum(p_emission, axis=0)) == 1 114 | 115 | return p_emission 116 | 117 | p_emission = get_emission_matrix(prob_bump_true_given_location, prob_creak_true_given_location) 118 | 119 | # 1 means True. ex: [1, 0] means bump=True, creak=False 120 | map_visible_state_to_bump_creak = np.vstack(( 121 | [1, 1], 122 | [1, 0], 123 | [0, 1], 124 | [0, 0], 125 | )) 126 | 127 | timesteps = 10 128 | 129 | hiddens = np.zeros(timesteps, dtype=int) 130 | visibles = np.zeros(timesteps, dtype=int) 131 | 132 | hiddens[0] = np.random.choice(num_hidden_states, p=p_hidden_start) 133 | visibles[0] = np.random.choice( 134 | num_visible_states, 135 | p=p_emission[:, hiddens[0]] 136 | ) 137 | 138 | for t in range(1, timesteps): 139 | hiddens[t] = np.random.choice( 140 | num_hidden_states, 141 | p=p_transition[:, hiddens[t - 1]] 142 | ) 143 | 144 | visibles[t] = np.random.choice( 145 | num_visible_states, 146 | p=p_emission[:, hiddens[t]] 147 | ) 148 | 149 | def alpha_recursion(visibles, p_hidden_start, p_transition, p_emission): 150 | num_timestamps = visibles.shape[0] 151 | num_hidden_states = p_transition.shape[0] 152 | 153 | # There will be one alpha for each timestamp 154 | alphas = np.zeros((num_timestamps, num_hidden_states)) 155 | 156 | # alpha(h_1) = p(h_1) * p(v_1 | h_1) 157 | alphas[0] = p_hidden_start * p_emission[visibles[0]] 158 | 159 | # normalize to avoid overflow 160 | alphas[0] /= np.sum(alphas[0]) 161 | for t in range(1, num_timestamps): 162 | # p(v_s | h_s) 163 | # size: 
new_states 164 | corrector = p_emission[visibles[t]] 165 | 166 | # sum over all hidden states for the previous timestep and multiply the 167 | # transition prob by the previous alpha 168 | # transition_matrix size: new_state x old_state 169 | # alphas[t_minus_1].T size: old_state x 1 170 | # predictor size: new_state x 1, 171 | predictor = p_transition @ alphas[t - 1, None].T 172 | 173 | # alpha(h_s) 174 | alphas[t, :] = corrector * predictor[:, 0] 175 | 176 | # normalize 177 | alphas[t] /= np.sum(alphas[t]) 178 | 179 | return alphas 180 | 181 | alphas = alpha_recursion( 182 | visibles, 183 | p_hidden_start, 184 | p_transition, 185 | p_emission, 186 | ) 187 | 188 | assert np.all(np.isclose(np.sum(alphas, axis=1), 1)) -------------------------------------------------------------------------------- /nb_code/viterbi.py: -------------------------------------------------------------------------------- 1 | '''This code was generated from the notebook 2018-05-13-viterbi-message-passing.ipynb''' 2 | 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | 6 | import nb_code.hmm_alpha_recursion as prev_post 7 | 8 | def viterbi(visibles, p_hidden_start, p_transition, p_emission): 9 | num_timestamps = visibles.shape[0] 10 | num_hidden_states = p_transition.shape[0] 11 | 12 | # messages[t] corresponds to mu(h_t), which is the message coming into h_t 13 | messages = np.zeros((num_timestamps, num_hidden_states)) 14 | 15 | most_likely_states = np.zeros((num_timestamps,), dtype=int) 16 | 17 | # The message coming into the last node is 1 for all states 18 | messages[-1] = np.ones(num_hidden_states) 19 | 20 | # normalize! 21 | messages[-1] /= np.sum(messages[-1]) 22 | 23 | # Compute the messages! 24 | for t in range(num_timestamps - 1, 0, -1): 25 | # use the data at time t to make mu[h_{t - 1}] 26 | 27 | # compute max p(v|h)p(h|h)mu(h)! 
28 | 29 | # compute p(v|h)mu(h) 30 | message_and_emission = messages[t] * p_emission[visibles[t]] 31 | 32 | # compute p(v|h)p(h|h)mu(h) 33 | # message_and_emission.reshape(-1, 1): new_state x 1 34 | # np.tile(...): new_state x old_state 35 | # p_transition: new_state x old_state 36 | # np.tile(...) * p_transition: new_state x old_state 37 | all_h_ts = np.tile( 38 | message_and_emission.reshape(-1, 1), 39 | (1, num_hidden_states) 40 | ) * p_transition 41 | 42 | # the message is the value from the highest h_t 43 | messages[t - 1] = np.max(all_h_ts, axis=0) 44 | 45 | # and normalize 46 | messages[t - 1] /= np.sum(messages[t - 1]) 47 | 48 | # now from the beginning! compute h_t* using these messages 49 | 50 | # argmax will give us the state. 51 | # argmax p(v_1|h_1)p(h_1)mu(h_1) 52 | most_likely_states[0] = np.argmax( 53 | p_hidden_start 54 | * p_emission[visibles[0]] 55 | * messages[0] 56 | ) 57 | 58 | for t in range(1, num_timestamps): 59 | # argmax_h_t p(v_t|h_t)p(h_t|h_{t - 1})mu(h_t) 60 | most_likely_states[t] = np.argmax( 61 | p_emission[visibles[t], :] 62 | * p_transition[:, most_likely_states[t - 1]] 63 | * messages[t] 64 | ) 65 | 66 | return most_likely_states 67 | 68 | most_likely_states = viterbi( 69 | prev_post.visibles, 70 | prev_post.p_hidden_start, 71 | prev_post.p_transition, 72 | prev_post.p_emission, 73 | ) 74 | 75 | print(most_likely_states) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | jupyter 2 | keras 3 | matplotlib 4 | numpy 5 | pillow 6 | pymc3 7 | requests 8 | scipy 9 | scikit-learn 10 | tensorflow 11 | -------------------------------------------------------------------------------- /scripts/model_simulation.py: -------------------------------------------------------------------------------- 1 | class Parent(object): 2 | '''Fancy accessor helper that ensures components can't see the future and 3 | helps 
return either the recent value or a full history. 4 | 5 | Parent.get_value is important for avoiding loops, while still being 6 | deterministic about which timestep an answer is gotten from. 7 | 8 | :param prev_timestep: If True, get_value will access the previous 9 | timestep. 10 | :param with_history: If True, returns a list of history. If used 11 | with prev_timestep, will return up until the previous timestep. 12 | Otherwise will return up until the current timestep. 13 | ''' 14 | 15 | def __init__( 16 | self, 17 | component_name, 18 | prev_timestep=False, 19 | with_history=False, 20 | ): 21 | self.component_name = component_name 22 | self.prev_timestep = prev_timestep 23 | self.with_history = with_history 24 | 25 | def get_value(self, t, timestep_results): 26 | """Based on the Parent definition, access the history of 27 | the parent. 28 | 29 | :param t: current timestep 30 | :param timestep_results: full dictionary of all component's history 31 | 32 | :returns: float or list of floats 33 | """ 34 | 35 | # Splice list up until `t + 1` to include the value at time step t 36 | last_timestep = t + 1 37 | if self.prev_timestep: 38 | last_timestep = last_timestep - 1 39 | 40 | time_scope = timestep_results[self.component_name][:last_timestep] 41 | 42 | # For nice function definitions, either return the last element, or return the full history. 43 | if self.with_history: 44 | return time_scope 45 | else: 46 | return time_scope[-1] 47 | 48 | 49 | class Component(object): 50 | '''Base class for components. 
Shouldn't initiate this directly.''' 51 | _allowed_parents = [] # A list of Component types that can be a parent 52 | 53 | def reset(self): 54 | self.value = None if self.initial_value is None else float(self.initial_value) 55 | 56 | def __init__( 57 | self, 58 | name, 59 | equation=None, 60 | initial_value=None, 61 | parents=None, 62 | ): 63 | # save static attributes 64 | self.name = name 65 | self.initial_value = initial_value # If simulation is reran 66 | self.equation = equation 67 | self.parents = parents or [] 68 | 69 | # initialize 70 | self.reset() 71 | 72 | def __str__(self): 73 | return '{}'.format(self.name) 74 | 75 | def get_inputs(self, t, timestep_results): 76 | return { 77 | p.component_name: p.get_value(t, timestep_results) 78 | for p in self.parents 79 | } 80 | 81 | def _new_value(self, t, dt, prev_value, inputs): 82 | '''Should be implemented for each subtype. Calls the component's equation 83 | and returns the new value.''' 84 | raise NotImplementedError() 85 | 86 | def simulate_timestep(self, t, dt, timestep_results): 87 | inputs = self.get_inputs(t, timestep_results) 88 | # make sure that the new value is float 89 | self.value = float(self._new_value(t, dt, self.value, inputs)) 90 | 91 | 92 | class FlowComponent(Component): 93 | """A flow component.""" 94 | _allowed_parents = ['InfoComponent', 'StockComponent'] 95 | 96 | def _new_value(self, t, dt, prev_value, inputs): 97 | return self.equation(t, **inputs) 98 | 99 | 100 | class InfoComponent(FlowComponent): 101 | """An information component passes information to the flows of a system.""" 102 | _allowed_parents = ['StockComponent', 'FlowComponent', 'InfoComponent'] 103 | 104 | 105 | class StockComponent(Component): 106 | """A stock component. 
This represents something that grows and shrinks.""" 107 | _allowed_parents = ['FlowComponent'] 108 | 109 | def __init__( 110 | self, 111 | name, 112 | initial_value, 113 | inflow, 114 | outflow, 115 | min_value 116 | ): 117 | self.min_value = min_value 118 | self.inflow = inflow 119 | self.outflow = outflow 120 | 121 | super().__init__( 122 | name, 123 | equation=None, 124 | initial_value=initial_value, 125 | parents=[Parent(inflow, prev_timestep=True), Parent(outflow, prev_timestep=True)], 126 | ) 127 | 128 | def _new_value(self, t, dt, prev_value, inputs): 129 | # The equation is always given by the inflow - outflow 130 | return max(self.min_value, prev_value + dt * (inputs[self.inflow] - inputs[self.outflow])) 131 | 132 | 133 | class System(object): 134 | def __init__(self, components_list): 135 | self.components = { 136 | component.name: component 137 | for component in components_list 138 | } 139 | 140 | self._integrity_checks() 141 | 142 | def tree(self, with_prev_timestamp=True): 143 | return { 144 | node.name: [ 145 | p.component_name 146 | for p in node.parents 147 | if with_prev_timestamp or not p.prev_timestep 148 | ] 149 | for node in self.components.values() 150 | } 151 | 152 | def _integrity_checks(self): 153 | # all parents are defined 154 | for component_name, component in self.components.items(): 155 | for parent in component.parents: 156 | assertion_failed_str = "Component {}({}) parent `{}` not defined".format( 157 | component.name, 158 | type(component).__name__, 159 | parent.component_name, 160 | ) 161 | assert parent.component_name in self.components, assertion_failed_str 162 | 163 | # A time step shouldn't have loops. Also stores a topological sort. 
164 | self.sorted_nodes = topological_sort(self.tree(with_prev_timestamp=False)) 165 | 166 | # Make sure connections are allowed 167 | for component in self.components.values(): 168 | for parent in component.parents: 169 | parent_type = type(self.components[parent.component_name]).__name__ 170 | assertion_failed_str = "Component {}({}) can't have parent {}({})".format( 171 | component.name, 172 | type(component).__name__, 173 | parent.component_name, 174 | parent_type 175 | ) 176 | assert parent_type in component._allowed_parents, assertion_failed_str 177 | 178 | # If something is a prev_parent, it needs an initial value. 179 | for component_name, component in self.components.items(): 180 | for parent in component.parents: 181 | if parent.prev_timestep: 182 | assertion_failed_str = "Component {} prev_parent `{}` needs initial value".format( 183 | component.name, 184 | parent.component_name, 185 | ) 186 | assert self.components[parent.component_name].value is not None, assertion_failed_str 187 | 188 | def __str__(self): 189 | return ', '.join(map(str, self.components)) 190 | 191 | def simulate(self, t_num, dt): 192 | # Clear history of components 193 | for c in self.components.values(): 194 | c.reset() 195 | 196 | # Load initial values 197 | history = { 198 | name: [node.value] 199 | for name, node in self.components.items() 200 | } 201 | 202 | # Simulate for t_num timesteps 203 | for t in range(1, t_num + 1): 204 | for name in self.sorted_nodes: 205 | self.components[name].simulate_timestep(t, dt, history) 206 | # update with new values 207 | history[name].append(self.components[name].value) 208 | 209 | # Strip the initial value 210 | history = { 211 | name: timesteps[1:] 212 | for name, timesteps in history.items() 213 | } 214 | 215 | return history 216 | 217 | 218 | def topological_sort(tree): 219 | '''Returns list of nodes in the provided tree in topological sorted.''' 220 | 221 | sorted_nodes = [] 222 | 223 | # loop until tree is empty. 
But use a for-loop so don't loop forever: 224 | # an element should be removed each loop, so it shouldn't iterate more than 225 | # len(tree) + 1 times before exiting. 226 | for _ in range(len(tree) + 1): 227 | # Done if cleared out tree! 228 | if not tree: 229 | break 230 | 231 | # Find all of the nodes with no more parents 232 | for name in list(tree): 233 | if not tree[name]: 234 | sorted_nodes.append(name) 235 | del tree[name] 236 | 237 | # Clean out solved parents 238 | tree = { 239 | node_name: [p for p in parents if p in tree] 240 | for node_name, parents in tree.items() 241 | } 242 | else: 243 | # if nodes remain, something is wrong 244 | raise Exception("Not all nodes are reachable", tree) 245 | 246 | return sorted_nodes 247 | 248 | -------------------------------------------------------------------------------- /scripts/process_weather_data.py: -------------------------------------------------------------------------------- 1 | '''Goes with the post from 2018-06-11. 2 | Hacky script for extracting a (hopefully) random subset of weather data. 3 | I'm using this for a tiny side project, and I haven't verified it totally 4 | works, so be careful if you're doing anything serious with it! 5 | 6 | This takes the un-tar'd download from 7 | 8 | ftp://ftp.ncdc.noaa.gov/pub/data/ghcn/daily/ghcnd_gsn.tar.gz 9 | 10 | or 11 | 12 | ftp://ftp.ncdc.noaa.gov/pub/data/ghcn/daily/ghcnd_hcn.tar.gz 13 | 14 | and writes a file to DESTINATION_NPZ. It tries to extract TRAIN_EXAMPLES_PER_STATION 15 | examples per station, and will remove the station if there are fewer than 16 | MIN_COUNT_PER_STATION examples for the station. 17 | Each example contains the station_id, month, day, and values for each of 18 | INCLUDED_ELEMENTS. 19 | 20 | It also writes a file STATIONS_FILE, where the line index is the station_id 21 | used in DESTINATION_NPZ, and the contents of the line are the station name 22 | used in the other GHCN data, including ghcnd-stations.txt. 
23 | 24 | It throws out data that is low quality or missing. 25 | 26 | This is preprocessed with a script which removes data from before 1990. 27 | 28 | #!/bin/bash 29 | 30 | 31 | SRC_FOLDER=$1 32 | DEST_FOLDER=$2 33 | 34 | year_awk_program='{ 35 | if ( substr($0, 12, 4) > 1990 ) 36 | print $0 37 | }' 38 | 39 | 40 | for filename in ${SRC_FOLDER}/*.dly; do 41 | echo 'before' 42 | wc -l $filename 43 | cat $filename | awk "$year_awk_program" > "${filename}-tmp" 44 | echo 'after' 45 | wc -l "${filename}-tmp" 46 | done 47 | 48 | mkdir -p $DEST_FOLDER 49 | 50 | mv ${SRC_FOLDER}/*.dly-tmp ${DEST_FOLDER} 51 | 52 | # and remove the -tmp part of the name 53 | for filename in ${DEST_FOLDER}/*.dly-tmp; do 54 | mv "${filename}" "${filename%%-tmp}" 55 | done 56 | 57 | 58 | ''' 59 | 60 | import os 61 | from collections import namedtuple 62 | import numpy as np 63 | import random 64 | import sys 65 | 66 | EXAMPLE_COUNT_PER_STATION = 2000 67 | MIN_COUNT_PER_STATION = 0 68 | 69 | INCLUDED_ELEMENTS = ['TMAX', 'TMIN', 'PRCP'] 70 | 71 | SOURCE_FOLDER = sys.argv[1] # .dly files, maybe trimmed by script in comments above 72 | DESTINATION_FOLDER = sys.argv[2] # output directory, should exist 73 | DESTINATION_NPZ = os.path.join( 74 | DESTINATION_FOLDER, 75 | 'gsn-{}-{}.npz'.format( 76 | EXAMPLE_COUNT_PER_STATION, 77 | '-'.join(INCLUDED_ELEMENTS) 78 | ) 79 | ) 80 | STATIONS_FILE = os.path.join(DESTINATION_FOLDER, 'stations') 81 | 82 | 83 | RawStationData = namedtuple('RawStationData', [ 84 | 'station_id', 85 | 'year', 86 | 'month', 87 | 'day', 88 | 'element', 89 | 'value', 90 | ]) 91 | 92 | DailyWeatherReport = namedtuple('DailyWeatherReport', [ 93 | 'station_id', 94 | 'month', 95 | 'day', 96 | ] + INCLUDED_ELEMENTS) 97 | 98 | 99 | def count_lines_in_file(filename): 100 | total = 0 101 | with open(filename) as f: 102 | for line in f: 103 | total += 1 104 | 105 | return total 106 | 107 | 108 | def parse_line_and_filter(line): 109 | '''Generator that gives valid RawStationDatas, skips missing 
and low-quality data.''' 110 | 111 | # magic numbers are from the readme! 112 | station_id = station_id=line[:11] 113 | year = line[11:15] 114 | month = line[15:17] 115 | element = line[17:21] 116 | 117 | # Don't bother processing the line if I don't need the data from it 118 | if element not in INCLUDED_ELEMENTS: 119 | return None 120 | 121 | # also from the readme, extract daily data for the 31 days! 122 | total_low_qual = 0 123 | for day in range(31): 124 | offset = 21 + (day * 8) 125 | value = line[offset:offset+5] 126 | mflag = line[offset+5] 127 | qflag = line[offset+6] 128 | sflag = line[offset+7] 129 | 130 | # throw out missing and low-quality data 131 | if value == '-9999': 132 | continue 133 | if qflag != ' ': 134 | # double check it's looking at the right column 135 | assert qflag in 'DGIKLMNORSTWXZ' 136 | continue 137 | 138 | yield RawStationData( 139 | station_id=station_id, 140 | year=int(year), 141 | month=int(month), 142 | day=day, 143 | element=element, 144 | value=int(value), 145 | ) 146 | 147 | 148 | 149 | def read_station_data(station_id_to_index, station_id, number_to_return): 150 | filename = os.path.join(SOURCE_FOLDER, '{}.dly'.format(station_id)) 151 | 152 | # I'll use this dictionary to build up each DailyWeatherReport, representing a weather 153 | # report from a day/month/year. 
154 | station_day = {} 155 | with open(filename) as f: 156 | for line in f: 157 | for entry in parse_line_and_filter(line): 158 | key = (entry.station_id, entry.year, entry.month, entry.day) 159 | if key not in station_day: 160 | station_day[key] = DailyWeatherReport( 161 | station_id=station_id_to_index[entry.station_id], 162 | month=entry.month, 163 | day=entry.day, 164 | **{element: None for element in INCLUDED_ELEMENTS}, # initialize elements to None 165 | ) 166 | station_day[key] = station_day[key]._replace(**{entry.element: entry.value}) 167 | 168 | # If no lines contained valid data, give up 169 | if not station_day: 170 | print('*{} no data!'.format(station_id)) 171 | return 172 | 173 | # Quality control: count how many of each column are null or not. 174 | print(', '.join( 175 | '{} {:06f}'.format( 176 | element, 177 | sum(1 for d in station_day.values() if getattr(d, element) is not None)/len(station_day) 178 | ) 179 | for element in INCLUDED_ELEMENTS 180 | )) 181 | 182 | # Filter out partial DailyWeatherReport, for example if the high temperature is low quality 183 | data = [ 184 | list(weather) 185 | for weather in station_day.values() 186 | if None not in weather # check that all values of DailyWeatherReport are non-null 187 | ] 188 | 189 | # Quality control, how many were filtered? 190 | print('{}: total valid {}, fraction valid {:06f}'.format( 191 | station_id, 192 | len(data), 193 | len(data)/len(station_day), 194 | )) 195 | 196 | if not data: 197 | print("*{}: No valid data for station!".format(station_id)) 198 | return 199 | 200 | # convert to numpy array 201 | valid_days = np.vstack(data).astype(np.int32) 202 | 203 | if valid_days.shape[0] < MIN_COUNT_PER_STATION: 204 | print("*{}: Not enough valid data for station!".format(station_id)) 205 | return 206 | if valid_days.shape[0] < number_to_return: 207 | # eh, probably doesn't matter, but shuffle them just in case. 
208 | np.random.shuffle(valid_days) 209 | return valid_days 210 | else: 211 | # now choose a random subset 212 | indexes = np.random.choice(valid_days.shape[0], replace=False, size=number_to_return) 213 | return valid_days[indexes] 214 | 215 | 216 | def make_dataset(): 217 | print('writing to', DESTINATION_NPZ) 218 | 219 | # read in the stations, and write file containing the list of them 220 | stations = [ 221 | name.split('.')[0] 222 | for name in os.listdir(SOURCE_FOLDER) 223 | if name.endswith('.dly') 224 | ] 225 | station_mapper = list(stations) 226 | 227 | # make a mapping from the station name to index 228 | station_name_to_index = { 229 | station_name: i 230 | for i, station_name in enumerate(station_mapper) 231 | } 232 | 233 | with open(STATIONS_FILE, 'w') as f: 234 | f.write('\n'.join(station_mapper)) 235 | 236 | # do the big slow process of grabbing data from each station 237 | train_dataset_entries = ( 238 | read_station_data(station_name_to_index, station, number_to_return=EXAMPLE_COUNT_PER_STATION) 239 | for station in station_mapper 240 | ) 241 | 242 | data = np.vstack( 243 | entries 244 | for entries in train_dataset_entries 245 | if entries is not None 246 | ).astype(np.int16) 247 | 248 | np.savez( 249 | DESTINATION_NPZ, 250 | data=data, 251 | ) 252 | 253 | 254 | if __name__ == '__main__': 255 | make_dataset() 256 | -------------------------------------------------------------------------------- /scripts/run_simulation_for_d3.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Script for caching three graphs from one of the examples from the post [1], 3 | for use in an interactive post. 4 | 5 | [1] https://jessicastringham.net/2019/07/01/systems-modeling-from-scratch/ 6 | ''' 7 | import numpy as np 8 | import json 9 | 10 | 11 | from model_simulation import * 12 | 13 | 14 | # In this model, fish regenerate slower if there aren't many other fish, or if there are too many other fish. 
15 | def regeneration_rate_given_resource(resource): 16 | scaled_resource = (resource/1000) 17 | 18 | if scaled_resource < 0.5: 19 | adjusted_resource = scaled_resource 20 | else: 21 | adjusted_resource = (1 - scaled_resource) 22 | 23 | rate = np.tanh(12 * adjusted_resource - 3) 24 | rate = (rate + 1)/4 25 | return max(0, rate) 26 | 27 | 28 | # People require fish, and are willing to pay more for fish if it is scarce. 29 | def price_given_yield_per_unit_capital(yield_per_unit_capital): 30 | return 8.8 * np.exp(-yield_per_unit_capital*4) + 1.2 31 | 32 | 33 | def yield_per_unit_capital_given_resource(resource, some_measure_of_efficiency): 34 | return min(1, max(0, (np.tanh(resource/1000*6 - 3 + some_measure_of_efficiency))/1.9 + 0.5)) 35 | 36 | 37 | def renewable_resource(some_measure_of_efficiency): 38 | return System([ 39 | StockComponent( 40 | name='resource', 41 | initial_value=1000, 42 | inflow='regeneration', 43 | outflow='harvest', 44 | min_value=0, 45 | ), 46 | FlowComponent( 47 | name='regeneration', 48 | initial_value=0, 49 | equation=lambda t, resource, regeneration_rate: resource * regeneration_rate, 50 | parents=[Parent('resource'), Parent('regeneration_rate')] 51 | ), 52 | FlowComponent( 53 | name='harvest', 54 | initial_value=0, 55 | equation=lambda t, resource, capital, yield_per_unit_capital: min(resource, capital * yield_per_unit_capital), 56 | parents=[Parent('resource'), Parent('capital', prev_timestep=True), Parent('yield_per_unit_capital')] 57 | ), 58 | 59 | StockComponent( 60 | name='capital', 61 | initial_value=5, 62 | inflow='investment', 63 | outflow='depreciation', 64 | min_value=0, 65 | ), 66 | FlowComponent( 67 | name='investment', 68 | equation=lambda t, profit, growth_goal: max(0, min(profit, growth_goal)), 69 | initial_value=0, 70 | parents=[Parent('profit'), Parent('growth_goal')] 71 | ), 72 | FlowComponent( 73 | name='depreciation', 74 | equation=lambda t, capital, capital_lifetime: capital/capital_lifetime, 75 | initial_value=0, 76 
| parents=[Parent('capital', prev_timestep=True), Parent('capital_lifetime')] 77 | ), 78 | 79 | InfoComponent( 80 | name='capital_lifetime', 81 | equation=lambda _: 20), 82 | 83 | InfoComponent( 84 | name='growth_goal', 85 | equation=lambda t, capital: capital * .1, 86 | parents=[Parent('capital', prev_timestep=True)]), 87 | 88 | InfoComponent( 89 | name='profit', 90 | equation=lambda t, price, harvest, capital: (price * harvest) - capital, 91 | parents=[Parent('price'), Parent('harvest'), Parent('capital', prev_timestep=True)] 92 | ), 93 | InfoComponent( 94 | name='price', 95 | equation=lambda t, yield_per_unit_capital: price_given_yield_per_unit_capital(yield_per_unit_capital), 96 | parents=[Parent('yield_per_unit_capital')] 97 | ), 98 | InfoComponent( 99 | name='regeneration_rate', 100 | equation=lambda t, resource: regeneration_rate_given_resource(resource), 101 | parents=[Parent('resource')]), 102 | InfoComponent( 103 | name='yield_per_unit_capital', 104 | equation=lambda t, resource: yield_per_unit_capital_given_resource(resource, some_measure_of_efficiency), 105 | parents=[Parent('resource')] 106 | ) 107 | ]) 108 | 109 | 110 | if __name__ == '__main__': 111 | years = 200 112 | samples_per_year = 1 113 | t_num = years * samples_per_year 114 | dt = 1/samples_per_year 115 | 116 | # First generate all data. For each some_measure_of_efficiency, 117 | # compute the three graphs. 
118 | xs_yield = np.linspace(0, 1000, 100) 119 | xs_simulated_resource = np.linspace(0, years, t_num + 1) 120 | xs_simulated_capital = xs_simulated_resource # same as resource 121 | 122 | graph_data_yield = {} 123 | graph_data_simulated_resource = {} 124 | graph_data_simulated_capital = {} 125 | 126 | tap_location_values = np.linspace(-0.5, 1.5, 20) 127 | 128 | # Cache the function for each tap_locations 129 | for yield_parameter in tap_location_values: 130 | s = renewable_resource(yield_parameter) 131 | simulation = s.simulate(t_num, dt) 132 | 133 | graph_data_yield[yield_parameter] = [ 134 | yield_per_unit_capital_given_resource(x, yield_parameter) 135 | for x in xs_yield 136 | ] 137 | 138 | graph_data_simulated_resource[yield_parameter] = simulation['resource'] 139 | graph_data_simulated_capital[yield_parameter] = simulation['capital'] 140 | 141 | 142 | database = { 143 | 'yield_parameters': list(sorted(graph_data_yield.keys())), 144 | 'yield_graph': { 145 | 'xs': list(xs_yield), 146 | 'ys_by_yield_parameter': graph_data_yield, 147 | 'x_domain': [min(xs_yield), max(xs_yield)], 148 | 'y_domain': [0, 1], 149 | }, 150 | 'yield_simulated_resource': { 151 | 'xs': list(xs_simulated_resource), 152 | 'ys_by_yield_parameter': graph_data_simulated_resource, 153 | 'x_domain': [min(xs_simulated_resource), max(xs_simulated_resource)], 154 | 'y_domain': [0, 1200], 155 | }, 156 | 'yield_simulated_capital': { 157 | 'xs': list(xs_simulated_capital), 158 | 'ys_by_yield_parameter': graph_data_simulated_capital, 159 | 'x_domain': [min(xs_simulated_capital), max(xs_simulated_capital)], 160 | 'y_domain': [0, 1200], 161 | }, 162 | } 163 | 164 | print(json.dumps(database)) 165 | 166 | --------------------------------------------------------------------------------