├── README.md ├── requirements.txt ├── SimpleGaussian.ipynb ├── Captcha.ipynb └── MarkovPath.ipynb /README.md: -------------------------------------------------------------------------------- 1 | # MLHEP-pyprob 2 | 3 | [![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/lukasheinrich/MLHEP-pyprob/master) 4 | 5 | Tutorial Code for MLHEP pyprob 6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | appnope==0.1.0 2 | attrs==19.1.0 3 | backcall==0.1.0 4 | bleach==3.1.0 5 | cycler==0.10.0 6 | decorator==4.4.0 7 | defusedxml==0.6.0 8 | entrypoints==0.3 9 | flatbuffers==1.10 10 | ipykernel==5.1.1 11 | ipython==7.6.0 12 | ipython-genutils==0.2.0 13 | ipywidgets==7.5.0 14 | jedi==0.14.0 15 | Jinja2==2.10.1 16 | jsonschema==3.0.1 17 | jupyter==1.0.0 18 | jupyter-client==5.2.4 19 | jupyter-console==6.0.0 20 | jupyter-core==4.5.0 21 | kiwisolver==1.1.0 22 | MarkupSafe==1.1.1 23 | matplotlib==3.1.0 24 | mistune==0.8.4 25 | nbconvert==5.5.0 26 | nbformat==4.4.0 27 | notebook==5.7.8 28 | numpy==1.16.4 29 | pandocfilters==1.4.2 30 | parso==0.5.0 31 | pexpect==4.7.0 32 | pickleshare==0.7.5 33 | Pillow==6.0.0 34 | prometheus-client==0.7.1 35 | prompt-toolkit==2.0.9 36 | ptyprocess==0.6.0 37 | pydotplus==2.0.2 38 | Pygments==2.4.2 39 | pyparsing==2.4.0 40 | pyprob==0.13.0 41 | pyrsistent==0.15.2 42 | python-dateutil==2.8.0 43 | pyzmq==18.0.2 44 | qtconsole==4.5.1 45 | scipy==1.3.0 46 | Send2Trash==1.5.0 47 | six==1.12.0 48 | termcolor==1.1.0 49 | terminado==0.8.2 50 | testpath==0.4.2 51 | tornado==6.0.3 52 | traitlets==4.3.2 53 | wcwidth==0.1.7 54 | webencodings==0.5.1 55 | widgetsnbextension==3.5.0 56 | awkward 57 | uproot_methods 58 | -------------------------------------------------------------------------------- /SimpleGaussian.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | 
"cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "\u001b[1m\u001b[31mWarning: Empirical distributions on disk may perform slow because GNU DBM is not available. Please install and configure gdbm library for Python for better speed.\u001b[0m\n" 13 | ] 14 | } 15 | ], 16 | "source": [ 17 | "import pyprob\n", 18 | "%matplotlib inline\n", 19 | "import matplotlib.pyplot as plt\n", 20 | "from pyprob import Model\n", 21 | "import numpy as np" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "# Intro to `pyprob`\n", 29 | "\n", 30 | "`pyprob` is very easy to use. You only need to subclass `pyprob.Model` and add a `forward()` method that implements your generative model. Within the body of the model you can use `pyprob.sample` and `pyprob.observe`, the two essential keywords of a probabilistic programming language.\n", 31 | "\n", 32 | "In this case it's a model which draws two random numbers `a` and `b` and then draws the final value from a Gaussian with mean `a+b`, whatever their values were, and a small standard deviation.\n", 33 | "\n", 34 | "You can try thinking about what the prior p(z) and the posterior p(z|x) look like." 
35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 4, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "import math\n", 44 | "import pyprob\n", 45 | "from pyprob import Model\n", 46 | "from pyprob.distributions import Normal\n", 47 | "\n", 48 | "class SimpleGaussian(Model):\n", 49 | "    def __init__(self):\n", 50 | "        super().__init__(name=\"Simple Gaussian\") # give the model a name\n", 51 | "        self.prior_mean = 1\n", 52 | "        self.prior_std = 4\n", 53 | "        self.likelihood_std = 0.5\n", 54 | "\n", 55 | "    def forward(self): # Needed to specify how the generative model is run forward\n", 56 | "        # sample the two latent variables to be inferred:\n", 57 | "        a = pyprob.sample(Normal(self.prior_mean, self.prior_std), name = 'input1') # NOTE: sample -> denotes latent variables\n", 58 | "        b = pyprob.sample(Normal(self.prior_mean, self.prior_std), name = 'input2') # NOTE: sample -> denotes latent variables\n", 59 | "\n", 60 | "        mu = a+b\n", 61 | "\n", 62 | "        pyprob.observe(Normal(mu, self.likelihood_std), name='obs0') # NOTE: observe -> denotes observable variables\n", 63 | "        return a,b\n", 64 | "\n", 65 | "model = SimpleGaussian()" 66 | ] 67 | }, 68 | { 69 | "cell_type": "markdown", 70 | "metadata": {}, 71 | "source": [ 72 | "## Learning to Infer\n", 73 | "\n", 74 | "Inference means that we can efficiently draw samples that follow the posterior `p(z|x)`. In the \"inference compilation\" method, we train a network that, while `forward()` is running, proposes appropriate distributions to sample from, such that the resulting traces approximate the true posterior. Let's train this \"smart ML agent\" that will steer our simulation code in the right direction." 
75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 5, 80 | "metadata": {}, 81 | "outputs": [ 82 | { 83 | "name": "stdout", 84 | "output_type": "stream", 85 | "text": [ 86 | "Creating new inference network...\n", 87 | "Observable obs0: observe embedding not specified, using the default FEEDFORWARD.\n", 88 | "Observe embedding dimension: 32\n", 89 | "Train. time | Epoch| Trace | Init. loss| Min. loss | Curr. loss| T.since min | Traces/sec\n", 90 | "New layers, address: 20__forward__a__Normal__1, distribution: Normal\n", 91 | "New layers, address: 44__forward__b__Normal__1, distribution: Normal\n", 92 | "Total addresses: 2, parameters: 6,926\n", 93 | "0d:00:00:06 | 1 | 10,048 | +5.77e+00 | +4.50e+00 | \u001b[31m+5.08e+00\u001b[0m | 0d:00:00:02 | 1,053.0 \n" 94 | ] 95 | } 96 | ], 97 | "source": [ 98 | "model.learn_inference_network(\n", 99 | " num_traces=10000,\n", 100 | " observe_embeddings={'obs0': {'dim': 32, 'depth': 3}}\n", 101 | ")" 102 | ] 103 | }, 104 | { 105 | "cell_type": "markdown", 106 | "metadata": {}, 107 | "source": [ 108 | "Let's now collect some samples from both the prior and the posterior so we can compare" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": 7, 114 | "metadata": {}, 115 | "outputs": [ 116 | { 117 | "name": "stdout", 118 | "output_type": "stream", 119 | "text": [ 120 | "Time spent | Time remain.| Progress | Trace | Traces/sec\n", 121 | "0d:00:00:00 | 0d:00:00:00 | #################### | 1000/1000 | 2,388.31 \n", 122 | "Time spent | Time remain.| Progress | Trace | Traces/sec\n", 123 | "0d:00:00:03 | 0d:00:00:00 | #################### | 1000/1000 | 332.45 \n" 124 | ] 125 | } 126 | ], 127 | "source": [ 128 | "condition = {'obs0': 2}\n", 129 | "\n", 130 | "prior = model.prior_distribution(\n", 131 | " num_traces=1000,\n", 132 | ")\n", 133 | "posterior = model.posterior_distribution(\n", 134 | " num_traces=1000,\n", 135 | " 
inference_engine=pyprob.InferenceEngine.IMPORTANCE_SAMPLING_WITH_INFERENCE_NETWORK,\n", 136 | " observe=condition\n", 137 | ")" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 8, 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [ 146 | "edges = np.linspace(-10,10, 25)" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": 9, 152 | "metadata": {}, 153 | "outputs": [], 154 | "source": [ 155 | "prior = np.asarray([[x.item() for x in prior.sample()] for _ in range(1000)])\n", 156 | "posterior = np.asarray([[x.item() for x in posterior.sample()] for _ in range(1000)])" 157 | ] 158 | }, 159 | { 160 | "cell_type": "markdown", 161 | "metadata": {}, 162 | "source": [ 163 | "## The conditioned posterior\n", 164 | "\n", 165 | "\n", 166 | "As we plot the latent state of the simulator (the numbers `a` and `b`) below, we see the effect of the conditioning. Conditioning on a particular value `c` forces the relation `a+b = c` to hold approximately (within the range of the standard deviation). What will happen if you decrease the standard deviation?" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": 11, 172 | "metadata": {}, 173 | "outputs": [ 174 | { 175 | "data": { 176 | "image/png": "\n", 177 | "text/plain": [ 178 | "
" 179 | ] 180 | }, 181 | "metadata": { 182 | "needs_background": "light" 183 | }, 184 | "output_type": "display_data" 185 | } 186 | ], 187 | "source": [ 188 | "plt.scatter(prior[:,0],prior[:,1], alpha = 0.2)\n", 189 | "plt.scatter(posterior[:,0],posterior[:,1], alpha = 0.2)\n", 190 | "plt.xlim(-20,20)\n", 191 | "plt.ylim(-20,20)\n", 192 | "plt.gcf().set_size_inches(5,5)" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": null, 198 | "metadata": {}, 199 | "outputs": [], 200 | "source": [] 201 | } 202 | ], 203 | "metadata": { 204 | "kernelspec": { 205 | "display_name": "Python 3", 206 | "language": "python", 207 | "name": "python3" 208 | }, 209 | "language_info": { 210 | "codemirror_mode": { 211 | "name": "ipython", 212 | "version": 3 213 | }, 214 | "file_extension": ".py", 215 | "mimetype": "text/x-python", 216 | "name": "python", 217 | "nbconvert_exporter": "python", 218 | "pygments_lexer": "ipython3", 219 | "version": "3.6.6" 220 | } 221 | }, 222 | "nbformat": 4, 223 | "nbformat_minor": 2 224 | } 225 | -------------------------------------------------------------------------------- /Captcha.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Probabilistic Programming to break CAPTCHAS\n", 8 | "\n", 9 | "Captchas are puzzles usually given as part of a verification procedure to ensure that an internet user is indeed a human and not a bot. They involve identifying a string of letters (or numbers).\n", 10 | "\n", 11 | "\n", 12 | "In this example we will write a probabilistic CAPTCHA generator and instrument it using the probabilistic programming library `pyprob` in order to run amortized inference." 13 | ] 14 | }, 15 | { 16 | "cell_type": "markdown", 17 | "metadata": {}, 18 | "source": [ 19 | "First, we import some basic libraries we will need later on." 
20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 1, 25 | "metadata": {}, 26 | "outputs": [ 27 | { 28 | "name": "stdout", 29 | "output_type": "stream", 30 | "text": [ 31 | "\u001b[1m\u001b[31mWarning: Empirical distributions on disk may perform slow because GNU DBM is not available. Please install and configure gdbm library for Python for better speed.\u001b[0m\n" 32 | ] 33 | } 34 | ], 35 | "source": [ 36 | "%matplotlib inline\n", 37 | "import numpy as np\n", 38 | "import torch\n", 39 | "import time\n", 40 | "import matplotlib.pyplot as plt\n", 41 | "\n", 42 | "import pyprob\n", 43 | "import pyprob.distributions\n", 44 | "import IPython\n", 45 | "\n", 46 | "from PIL import Image, ImageFilter, ImageDraw, ImageFont\n" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "metadata": {}, 52 | "source": [ 53 | "## Writing a Generator\n", 54 | "\n", 55 | "Here we use the `Pillow` library to write a generator. This is a good example of how probabilistic programs allow you to use general-purpose host languages (in this case Python, but it could also be C++) with all their nice libraries, instead of restricting you to a statistical modeling framework (try writing this in RooFit :) )" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 2, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "# prepare the set of letters from which we sample\n", 65 | "\n", 66 | "alphabetorder = list('abcdefghijklmnopqrstuvw')\n", 67 | "alphabet = dict(zip(range(len(alphabetorder)),alphabetorder))\n", 68 | "\n", 69 | "\n", 70 | "def message_to_picture(width, height, message, blur):\n", 71 | "    canvas = Image.new('L', (width, height), \"white\")\n", 72 | "    font = ImageFont.load_default()\n", 73 | "    ImageDraw.Draw(canvas).text((2,0), message, 'black', font)\n", 74 | "    canvas = canvas.filter(ImageFilter.GaussianBlur(blur))\n", 75 | "    d = np.asarray(list(canvas.getdata()), dtype = float) # np.float was only an alias for the builtin float\n", 76 | "    return d\n", 77 | "\n", 78 | "def random_captcha(width, 
height, blur = 0.5):\n", 79 | "    letter_distr = pyprob.distributions.Categorical([1/len(alphabet)]*len(alphabet))\n", 80 | "    length = 3\n", 81 | "    word = ''.join([alphabet[letter_distr.sample().item()] for i in range(length)])\n", 82 | "    d = message_to_picture(width, height, word,blur)\n", 83 | "    return word, torch.tensor(d, dtype = torch.float).view(-1)/255-0.5\n", 84 | "\n", 85 | "def likelihoood(observation,eps):\n", 86 | "    lhood = pyprob.distributions.Normal(observation,eps*torch.ones(observation.shape))\n", 87 | "    return lhood" 88 | ] 89 | }, 90 | { 91 | "cell_type": "markdown", 92 | "metadata": {}, 93 | "source": [ 94 | "# An Example CAPTCHA\n", 95 | "\n", 96 | "Note that our CAPTCHA generator is very unsophisticated, but it has the correct components: sampling letters, adding some blur, etc.\n", 97 | "\n", 98 | "Let's generate a CAPTCHA to see how it looks." 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": 3, 104 | "metadata": {}, 105 | "outputs": [ 106 | { 107 | "data": { 108 | "text/plain": [ 109 | "Text(0.5, 1.0, 'ground truth: ian')" 110 | ] 111 | }, 112 | "execution_count": 3, 113 | "metadata": {}, 114 | "output_type": "execute_result" 115 | }, 116 | { 117 | "data": { 118 | "image/png": 
"\n", 119 | "text/plain": [ 120 | "
" 121 | ] 122 | }, 123 | "metadata": { 124 | "needs_background": "light" 125 | }, 126 | "output_type": "display_data" 127 | } 128 | ], 129 | "source": [ 130 | "word, img = random_captcha(20,12, blur = 0.2)\n", 131 | "plt.imshow(img.reshape(12,20))\n", 132 | "plt.title('ground truth: {}'.format(word))" 133 | ] 134 | }, 135 | { 136 | "cell_type": "markdown", 137 | "metadata": {}, 138 | "source": [ 139 | "... and another one" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": 4, 145 | "metadata": {}, 146 | "outputs": [ 147 | { 148 | "data": { 149 | "text/plain": [ 150 | "Text(0.5, 1.0, 'ground truth: ckm')" 151 | ] 152 | }, 153 | "execution_count": 4, 154 | "metadata": {}, 155 | "output_type": "execute_result" 156 | }, 157 | { 158 | "data": { 159 | "image/png": "\n", 160 | "text/plain": [ 161 | "
" 162 | ] 163 | }, 164 | "metadata": { 165 | "needs_background": "light" 166 | }, 167 | "output_type": "display_data" 168 | } 169 | ], 170 | "source": [ 171 | "word, img = random_captcha(20,12, blur = 0.7)\n", 172 | "plt.imshow(img.reshape(12,20))\n", 173 | "plt.title('ground truth: {}'.format(word))" 174 | ] 175 | }, 176 | { 177 | "cell_type": "markdown", 178 | "metadata": {}, 179 | "source": [ 180 | "# Integrating into `pyprob`\n", 181 | "\n", 182 | "Now we will see how we integrate this existing generative model into `pyprob`.\n", 183 | "\n", 184 | "In `pyprob` we only need to write a `forward()` method and inherit from the `Model` class. So integrating it is straightforward!\n", 185 | "\n", 186 | "In the end we want to `observe` the resulting image (we allow the pixel values of the observed image to be sampled from a Normal distribution whose mean is determined by the latent rendered image)." 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": 5, 192 | "metadata": {}, 193 | "outputs": [], 194 | "source": [ 195 | "import math\n", 196 | "import pyprob\n", 197 | "from pyprob import Model\n", 198 | "from pyprob.distributions import Normal\n", 199 | "\n", 200 | "class CaptchaModel(Model):\n", 201 | "    def __init__(self):\n", 202 | "        super().__init__(name=\"CAPTCHA model\") # give the model a name\n", 203 | "\n", 204 | "    def forward(self): # Needed to specify how the generative model is run forward\n", 205 | "        mu = pyprob.sample(pyprob.distributions.Normal(0,1))\n", 206 | "        blur = pyprob.sample(pyprob.distributions.Normal(0.5,0.3))\n", 207 | "        word, d = random_captcha(20,12, blur = blur)\n", 208 | "        obs_distr = likelihoood(d,1.0)\n", 209 | "        pyprob.observe(pyprob.distributions.Normal(d,0.5*torch.ones(d.shape)), name = 'obs0')\n", 210 | "        return {\n", 211 | "            'word': word,\n", 212 | "            'image': d,\n", 213 | "            'blur': blur\n", 214 | "        }\n", 215 | "\n", 216 | "model = CaptchaModel()" 217 | ] 218 | }, 219 | { 220 | "cell_type": "markdown", 221 | "metadata": {}, 222 | 
"source": [ 223 | "# The prior distribution (unconditioned)\n", 224 | "\n", 225 | "let's see what type of CAPTCHAS we get if we just run the generator. This is basically the same way we ran it above, but now it's integrated in `pyprob`." 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": 7, 231 | "metadata": {}, 232 | "outputs": [ 233 | { 234 | "name": "stdout", 235 | "output_type": "stream", 236 | "text": [ 237 | "Time spent | Time remain.| Progress | Trace | Traces/sec\n", 238 | "0d:00:00:00 | 0d:00:00:00 | #################### | 10/10 | 566.42 \n" 239 | ] 240 | }, 241 | { 242 | "data": { 243 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAAD4CAYAAAD1jb0+AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjAsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+17YcXAAARsElEQVR4nO3dfZAkdX3H8fcnd6ABLYHjimcEDEVFU1GplaCioYToQRTUsiyMGFAjZSUmkmgMxigaUzEaYyWmfChUBJUAEVGQQh58ilpBdMEDQVBO5PDggBMURI14+M0f04fLsk+30zezP+79qpranu5fT3/3t72f6fnNTHeqCklSe35r3AVIkhbHAJekRhngktQoA1ySGmWAS1KjDHBJapQBrq1Ckrcm+cSIt7lPkkqyfJHrfzDJmxfQ7qYkhy9mG2qbAa6tXpJDk6zr4XF6DdKqenVVvb2vx9PDjwGusVjsUem4tFavtg4GuHqT5MAk30ry0ySfTHJ2kn/qlh2aZF2Sv0tyG/DRbv6rkqxJcleS85Ps3s1/yPBDki8n+bNu+vgkX0vy7iQ/TvKDJEdMabtvkv/parkU2HmWmrcHPgfsnuTe7rZ7N+RyTpJPJLkHOD7JaZt+n6m/Uzf9cWBv4LPdY7xhymZemuTmJD9K8qbN6M8Htpdk5yQXJPlJ11dfTeL/71bOHUC9SLIt8GngNGAn4EzgBdOa7doteyxwQpJnAe8AXgzsBqwFztqMzf4B8F0G4fwu4CNJ0i37L+CKbtnbgeNmeoCq+hlwBHBrVT2qu93aLT4aOAfYAThjrkKq6mXAzcDzusd415TFhwAHAIcBb0nyuwBJDknykwX+rq8D1gErgV2Avwc8D8ZWzpeF6svBDPan99bgBDvnJvnGtDa/Bk6uql8CJHkpcGpVXdndfyPw4yT7LHCba6vqQ926pwPvB3bpnkyeAhzebesrST67iN/psqr6TDf9i988N2y2t1XVL4CrklwFPBG4rqq+xuDJYSF+xeBJ7rFVtQb46mKL0cOHR+Dqy+7ALfXgs6P9cFqbDVX1f9PWWbvpTlXdC9wJ7LHAbd42Zd2fd5OP6h73x93R9SZr2XzT61+s26ZM/5xBjZvrX4E1wCVJbkxyUi+VqWkGuPqyHtgjDz5M3Wtam+kv+W9lMJwCPDAevQK4BdgUvttNab/rZtSyY/d4m+w9R/vZhiKmz//ZPPVssSGNqvppVb2uqvYDjgL+JslhW2p7aoMBrr5cBtwPvCbJ8iRHAwfNs86ZwMuTPCnJI4B/Bi6vqpuqagODID82ybIkrwAet5BCqmotMAm8Lcm2SQ4BnjfHKrcDK5I8
Zp6HXg0cmWSnJLsCJ87wOPstpMbNleS5SX6ne4K8m0Ff/3pLbEvtMMDVi6q6D3gh8ErgJ8CxwAXAL+dY5/PAm4FPMThqfhxwzJQmrwL+lsGwyhOA/92Mkv6EwZucdwEnAx+bo47rGTyZ3Nh9ymP3WZp+HLgKuAm4BDh72vJ3AP/QPcbr5yswyTOS3Dtfu87+wOeBexk8Wb6/qr60wHX1MBUv6KAtJcnlwAer6qPjrkV6OPIIXL1J8odJdu2GUI4Dfh+4aNx1SQ9XfoxQfToA+G9ge+BG4EVVtX68JUkPXw6hSFKjHEKRpEYZ4JLUqJGOge+807LaZ69tRrlJSWreFVf/8kdVtXL6/JEG+D57bcM3Lp7+5TxJ0lyW7bZmxlNBOIQiSY0ywCWpUQa4JDVqqABPsirJd7srqnh6S0kaoUUHeJJlwPsYXM3k8cBLkjy+r8IkSXMb5gj8IGBNVd3YnYnuLAaXoJIkjcAwAb4HD75iyTpmuJJKkhOSTCaZ3HDn/UNsTpI01RZ/E7OqTqmqiaqaWLli2ZbenCRtNYYJ8Ft48CWz9uzmSZJGYJgA/yawf5J9u6uAHwOc309ZkqT5LPqr9FW1MclrgIuBZcCpVXVtb5VJkuY01LlQqupC4MKeapEkbQa/iSlJjTLAJalRBrgkNcoAl6RGGeCS1CgDXJIaZYBLUqMMcElqlAEuSY0ywCWpUQa4JDXKAJekRhngktQoA1ySGmWAS1KjDHBJapQBLkmNMsAlqVEGuCQ1aqhrYmrxVu09Me4SuOjmyXGXoCVoKeybALVx47hL4OJbV4+7hDl5BC5JjTLAJalRBrgkNcoAl6RGLTrAk+yV5EtJvpPk2iSv7bMwSdLchvkUykbgdVV1ZZJHA1ckubSqvtNTbZKkOSz6CLyq1lfVld30T4HrgD36KkySNLdexsCT7AM8Gbi8j8eTJM1v6ABP8ijgU8CJVXXPDMtPSDKZZHLDnfcPuzlJUmeoAE+yDYPwPqOqzp2pTVWdUlUTVTWxcsWyYTYnSZpimE+hBPgIcF1Vvae/kiRJCzHMEfjTgZcBz0qyursd2VNdkqR5LPpjhFX1NSA91iJJ2gx+E1OSGmWAS1KjDHBJapQXdBiT89ZeNu4SgG3GXYCWoKVyoY/n7P6kcZew5HkELkmNMsAlqVEGuCQ1ygCXpEYZ4JLUKANckhplgEtSowxwSWqUAS5JjTLAJalRBrgkNcoAl6RGGeCS1CgDXJIaZYBLUqMMcElq1FZ3QYdVe0+MuwQAauPGcZfAxbeuHncJvVkKJ//P8qXx77RULsigLc8jcElqlAEuSY0ywCWpUQa4JDVq6ABPsizJt5Jc0EdBkqSF6eMI/LXAdT08jiRpMwwV4En2BP4Y+HA/5UiSFmrYI/B/B94A/Hq2BklOSDKZZHLDnfcPuTlJ0iaLDvAkzwXuqKor5mpXVadU1URVTaxcsWyxm5MkTTPMEfjTgaOS3AScBTwrySd6qUqSNK9FB3hVvbGq9qyqfYBjgC9W1bG9VSZJmpOfA5ekRvVy9p2q+jLw5T4eS5K0MB6BS1KjDHBJapQBLkmNWhpnoB+hpXKy+6VwAYKlop+LbIz/Ahnnrb1s3CV0thl3ARoRj8AlqVEGuCQ1ygCXpEYZ4JLUKANckhplgEtSowxwSWqUAS5JjTLAJalRBrgkNcoAl6RGGeCS1CgDXJIaZYBLUqMMcElq1FZ3PnAtPX2co30pnF/9qD2eMu4SALj41tXjLkEj4hG4JDXKAJekRhngktQoA1ySGjVUgCfZIck5Sa5Pcl2Sp/ZVmCRpbsN+CuU/gIuq6kVJtgW266EmSdICLDrAkzwGeCZwPEBV3Qfc109ZkqT5DDOEsi+wAfhokm8l+XCS7XuqS5I0j2ECfDlwIPCBqnoy8DPgpOmNkpyQZDLJ5IY77x9ic5KkqYYJ8HXAuqq6vLt/DoNAf5CqOqWqJqpqYuWKZUNsTpI01aIDvKpuA36Y5IBu1mHA
d3qpSpI0r2E/hfKXwBndJ1BuBF4+fEmSpIUYKsCrajUw0VMtkqTN4DcxJalRBrgkNcoAl6RGeUEHPSycf8s3x13Ckrmgw6q9h3tbqo8LbGg0PAKXpEYZ4JLUKANckhplgEtSowxwSWqUAS5JjTLAJalRBrgkNcoAl6RGGeCS1CgDXJIaZYBLUqMMcElqlAEuSY0ywCWpUQa4JDVqq7ugw7Anu+/PxnEXsIT64uFi/H9TgPPWXjbU+qv2fmpPlQxr/P25dP5H1sw41yNwSWqUAS5JjTLAJalRQwV4kr9Ocm2Sa5KcmeSRfRUmSZrbogM8yR7AXwETVfV7wDLgmL4KkyTNbdghlOXAbydZDmwH3Dp8SZKkhVh0gFfVLcC7gZuB9cDdVXVJX4VJkuY2zBDKjsDRwL7A7sD2SY6dod0JSSaTTG648/7FVypJepBhhlAOB35QVRuq6lfAucDTpjeqqlOqaqKqJlauWDbE5iRJUw0T4DcDByfZLkmAw4Dr+ilLkjSfYcbALwfOAa4Evt091ik91SVJmsdQ50KpqpOBk3uqRZK0GfwmpiQ1ygCXpEYZ4JLUKANckhq11V3Q4aKbJ8ddgrSFbTPU2v6PLD3Ldpt5vkfgktQoA1ySGmWAS1KjDHBJapQBLkmNMsAlqVEGuCQ1ygCXpEYZ4JLUKANckhplgEtSowxwSWqUAS5JjTLAJalRBrgkNcoAl6RGGeCS1CgDXJIaZYBLUqMMcElq1LwBnuTUJHckuWbKvJ2SXJrkhu7njlu2TEnSdAs5Aj8NWDVt3knAF6pqf+AL3X1J0gjNG+BV9RXgrmmzjwZO76ZPB57fc12SpHksdgx8l6pa303fBuzSUz2SpAUa+k3MqiqgZlue5IQkk0kmN9x5/7CbkyR1FhvgtyfZDaD7ecdsDavqlKqaqKqJlSuWLXJzkqTpFhvg5wPHddPHAef1U44kaaEW8jHCM4HLgAOSrEvySuBfgD9KcgNweHdfkjRCy+drUFUvmWXRYT3XIknaDH4TU5IaZYBLUqMMcElqlAEuSY0ywCWpUQa4JDXKAJekRhngktQoA1ySGmWAS1KjDHBJapQBLkmNMsAlqVEGuCQ1ygCXpEYZ4JLUKANckhplgEtSowxwSWqUAS5JjUpVjW5jyQZg7RxNdgZ+NKJyhmGd/bLO/rRQI1jn5npsVa2cPnOkAT6fJJNVNTHuOuZjnf2yzv60UCNYZ18cQpGkRhngktSopRbgp4y7gAWyzn5ZZ39aqBGssxdLagxckrRwS+0IXJK0QGMJ8CSrknw3yZokJ82w/BFJzu6WX55knzHUuFeSLyX5TpJrk7x2hjaHJrk7yeru9pZR19nVcVOSb3c1TM6wPEne2/Xn1UkOHEONB0zpp9VJ7kly4rQ2Y+nPJKcmuSPJNVPm7ZTk0iQ3dD93nGXd47o2NyQ5bsQ1/muS67u/6aeT7DDLunPuHyOo861Jbpnydz1ylnXnzIUR1Hn2lBpvSrJ6lnVH1p/zqqqR3oBlwPeB/YBtgauAx09r8+fAB7vpY4Czx1DnbsCB3fSjge/NUOehwAWjrm2GWm8Cdp5j+ZHA54AABwOXj7neZcBtDD7bOvb+BJ4JHAhcM2Xeu4CTuumTgHfOsN5OwI3dzx276R1HWOOzgeXd9DtnqnEh+8cI6nwr8PoF7BNz5sKWrnPa8n8D3jLu/pzvNo4j8IOANVV1Y1XdB5wFHD2tzdHA6d30OcBhSTLCGqmq9VV1ZTf9U+A6YI9R1tCjo4GP1cDXgR2S7DbGeg4Dvl9Vc32pa2Sq6ivAXdNmT90HTweeP8OqzwEuraq7qurHwKXAqlHVWFWXVNXG7u7XgT23xLY3xyx9uRALyYXezFVnlzUvBs7cUtvvyzgCfA/gh1Pur+OhwfhAm24HvRtYMZLqZtAN4TwZuHyGxU9NclWSzyV5wkgL+40CLklyRZITZli+kD4fpWOY/Z9jKfQnwC5Vtb6bvg3YZYY2S6lfX8Hg
VdZM5ts/RuE13VDPqbMMRy2lvnwGcHtV3TDL8qXQn4BvYs4ryaOATwEnVtU90xZfyWAY4InAfwKfGXV9nUOq6kDgCOAvkjxzTHXMK8m2wFHAJ2dYvFT680Fq8Lp5yX5cK8mbgI3AGbM0Gff+8QHgccCTgPUMhieWspcw99H3uPvzAeMI8FuAvabc37ObN2ObJMuBxwB3jqS6KZJswyC8z6iqc6cvr6p7qurebvpCYJskO4+4TKrqlu7nHcCnGbwcnWohfT4qRwBXVtXt0xcslf7s3L5pmKn7eccMbcber0mOB54LvLR7onmIBewfW1RV3V5V91fVr4EPzbL9sfclPJA3LwTOnq3NuPtzqnEE+DeB/ZPs2x2NHQOcP63N+cCmd/RfBHxxtp1zS+nGwT4CXFdV75mlza6bxuaTHMSgP0f6RJNk+ySP3jTN4I2ta6Y1Ox/40+7TKAcDd08ZHhi1WY9ulkJ/TjF1HzwOOG+GNhcDz06yYzcs8Oxu3kgkWQW8ATiqqn4+S5uF7B9b1LT3W14wy/YXkgujcDhwfVWtm2nhUujPBxnHO6cMPhXxPQbvOr+pm/ePDHZEgEcyeIm9BvgGsN8YajyEwcvmq4HV3e1I4NXAq7s2rwGuZfCO+deBp42hzv267V/V1bKpP6fWGeB9XX9/G5gY0999ewaB/Jgp88benwyeUNYDv2Iw9vpKBu+5fAG4Afg8sFPXdgL48JR1X9Htp2uAl4+4xjUMxo037Z+bPrm1O3DhXPvHiOv8eLffXc0glHebXmd3/yG5MMo6u/mnbdofp7QdW3/Od/ObmJLUKN/ElKRGGeCS1CgDXJIaZYBLUqMMcElqlAEuSY0ywCWpUQa4JDXq/wHaOD560wYBSAAAAABJRU5ErkJggg==\n", 244 | "text/plain": [ 245 | "
" 246 | ] 247 | }, 248 | "metadata": { 249 | "needs_background": "light" 250 | }, 251 | "output_type": "display_data" 252 | }, 253 | { 254 | "data": { 255 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAAD4CAYAAAD1jb0+AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjAsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+17YcXAAAT40lEQVR4nO3de9AldX3n8ffHZ0Dl4jAMs9yRi4aspLxQs0QjcSlBA6wR41oRVhO8bNjUhl3dTWKw3FWSTSXxErJryo1BRYmyQIKoaIECRrTcKDggKBcVAsPNAcbhJsgKDN/94/TomYfnNk+fOef5Me9X1amnT/evf/09/fTzOf306dOdqkKS1J6nTboASdLiGOCS1CgDXJIaZYBLUqMMcElqlAEuSY0ywPWUk+TUJJ8a8zL3T1JJlo1zudq2GeDapiQ5IskdI+hnbZKjRlGTtFgGuLa61vZKW6tX2y4DXIuS5NAk307y4yT/kOTcJH/aTTsiyR1J/ijJXcDHu/G/k+SmJPcmuSDJXt34Jx1+SHJZkn/fDb8pydeTfCDJfUluSXLMUNsDkny1q+USYLdZat4RuAjYK8lD3WOv7pDLeUk+leRB4E1JPrHp9Qy/pm74k8B+wOe7Pt4xtJg3JLktyY+SvGsL1ufTu9d3W5K7k3w4yTOHX/+09pXkOd3wyiSfT/Jgkm8l+dPp7fXUZIBriyXZHvgM8AlgV+Bs4DemNdujm/Zs4KQkLwf+HPhNYE/gVuCcLVjsLwPfZxDO7wM+liTdtP8DXNlN+x/AiTN1UFUPA8cAP6yqnbrHD7vJxwHnAbsAZ81VSFX9FnAb8OtdH+8bmnw4cDBwJPDuJP8SIMnhSe6fo9u/AH4BeCHwHGBv4N1z1THkQ8DDDNb5iczy+vXU47+KWowXM9h2PliDi+mcn+SKaW2eAN5TVT8FSPIG4Iyquqp7/k7gviT7L3CZt1bVR7p5zwT+N7B792byr4CjumV9LcnnF/GavlFVn+2GH/n5e8MW++OqegS4Jsk1wAuAG6rq6wzeHJ6keyM6CXh+Vd3bjfszBm9M75xrYUmmgH8L/FJV/QS4vls/Ryz2BagdBrgWYy/gztr8Smi3T2uzvqr+37R5rtr0pKoeSrKBwZ7mnQtY5l1D8/6kC9idGOx139ftXW9yK7DvQl7IkOn1L9ZdQ8M/YVDjfFYBOwBXDr1xBJha4LzL2Lz+Ub0WLXEeQtFirAP2zua7qdMDc/plLn/I4HAK8LPj0SsZhPem8N1hqP0eW1DLiq6/Tfabo/1sl9+cPv7heeoZ5WU8fwQ8AhxSVbt0j+VVtSn8N6slyXAt64HHgX2Gxm3pm5caZYBrMb4BbAROTrIsyXHAYfPMczbw5iQvTPJ04M+Ay6tqbVWtZxDkb0wyleQtwEELKaSqbgXWAH+cZPskhwO/PscsdwMrkyyfp+urgWOT7NoF5ttn6OfAhdQ4n6p6AvgI8FdJ/gVAkr2T/FrX5BrgkG7dPQM4dWjejcD5wKlJdkjyi8Bvj6IuLX0GuLZYVT0KvBZ4K3A/8EbgC8BP55jnUuC/A59msNd8EHD8UJPfAf4Q2AAcAvzTFpT07xh8yHkv8B7g7+ao43sM3kxuTnL/pjNhZvBJBsG5FrgYOHfa9D8H/lvXxx/MV2CSX03y0BxN/gi4CfhmdybMpQw+DKWqfgD8STfuRmD6GSYnA8sZHL75ZPf6Zv1d6Kkj3tBBo5DkcuDDVfXxSdeyrUvyXmCPqvJslKc498C1KEn+dZI9ukMoJwLPB7446bq2RUl+McnzM3AYg/+MPjPpurT1eRaKF
utg4O+BHYGbgddV1brJlrTN2pnBYZO9GByb/0vgcxOtSGPhIRRJapSHUCSpUQa4JDVqrMfAd9t1qvbfd7txLlKSmnfld376o6paNX38WAN8/32344ov+SUxSdoSU3vedOtM4z2EIkmNMsAlqVEGuCQ1qleAJzk6yfe7u6ycMqqiJEnzW3SAdxeS/xCDO5w8DzghyfNGVZgkaW599sAPA26qqpu7q9Odw+C2VJKkMegT4Huz+Z0/7ujGbSbJSUnWJFmzfsPGHouTJA3b6h9iVtXpVbW6qlavWrmQO0RJkhaiT4Dfyea3btqHhd3bUJI0An0C/FvAc5Mc0N0Z/HjggtGUJUmaz6K/Sl9Vjyc5GfgSg7tnn1FV142sMknSnHpdC6WqLgQuHFEtkqQt4DcxJalRBrgkNcoAl6RGGeCS1CgDXJIaZYBLUqMMcElqlAEuSY0ywCWpUQa4JDXKAJekRhngktQoA1ySGmWAS1KjDHBJapQBLkmNMsAlqVEGuCQ1ygCXpEYZ4JLUKANckhplgEtSowxwSWqUAS5JjVp0gCfZN8lXklyf5LokbxtlYZKkuS3rMe/jwO9X1VVJdgauTHJJVV0/otokSXNY9B54Va2rqqu64R8DNwB7j6owSdLcRnIMPMn+wIuAy0fRnyRpfr0DPMlOwKeBt1fVgzNMPynJmiRr1m/Y2HdxkqROrwBPsh2D8D6rqs6fqU1VnV5Vq6tq9aqVU30WJ0ka0ucslAAfA26oqtNGV5IkaSH67IG/FPgt4OVJru4ex46oLknSPBZ9GmFVfR3ICGuRJG0Bv4kpSY0ywCWpUQa4JDWqz1fpt1m3Pf5Q7z42Vv86pnp+ArHn1DN717Bu4yO9+1gK62K/ZTv1rmEU28UojOK19LVU1kXfbXypbN+zcQ9ckhplgEtSowxwSWqUAS5JjTLAJalRBrgkNcoAl6RGGeCS1CgDXJIaZYBLUqMMcElqlAEuSY0ywCWpUQa4JDXKAJekRhngktSobe6GDpc90v896/1HvWEElUze6y/6v737OPv4V/buI7ff1b+PZ+3ca/7/eukXetdw2ii2i/S/T/gHv/KpXvOvfXx57xpOe8Ube/fxxLN26N3HCedc3Gv+c489vHcNo/GBGce6By5JjTLAJalRBrgkNcoAl6RG9Q7wJFNJvp2k/6dAkqQFG8Ue+NuAG0bQjyRpC/QK8CT7AP8G+OhoypEkLVTfPfD/CbwDeGK2BklOSrImyZr1Gzb2XJwkaZNFB3iSVwH3VNWVc7WrqtOranVVrV61cmqxi5MkTdNnD/ylwKuTrAXOAV6epN9XwCRJC7boAK+qd1bVPlW1P3A88I9V1f/7s5KkBfE8cElq1EguZlVVlwGXjaIvSdLCuAcuSY0ywCWpUQa4JDWquRs63PLYQ73mf/8ILjRf9z/Qu49fvuzu3n1cfsTuveY/95iX9q7h9Rdd2ruPv3/JIb37qPsf7N3HU6EG6H9DhtOOelXvGkbxN/L6C/+pdx99t/F6sP/v9Pev+GrvPr544Mzj3QOXpEYZ4JLUKANckhplgEtSowxwSWqUAS5JjTLAJalRBrgkNcoAl6RGGeCS1CgDXJIaZYBLUqMMcElqlAEuSY0ywCWpUc1dD7yvuq//dYo33ndf7z6uOO45vfvgaf2ujT4K+263YdIlaMSWynXNR7Ft9X4tT0vvGvZf1j9zZuMeuCQ1ygCXpEYZ4JLUKANckhrVK8CT7JLkvCTfS3JDkpeMqjBJ0tz6noXyv4AvVtXrkmwP7DCCmiRJC7DoAE+yHHgZ8CaAqnoUeHQ0ZUmS5tPnEMoBwHrg40m+neSjSXYcUV2SpHn0CfBlwKHA31TVi4CHgVOmN0pyUpI1Sdas37Cxx+IkScP6BPgdwB1VdXn3/DwGgb6Zqjq9qlZX1epVK6d6LE6SNGzRAV5VdwG3Jzm4G3UkcP1IqpIkzavvWSj/CTirOwPlZuDN/UuSJC1ErwCvqquB1SOqRZK0B
fwmpiQ1ygCXpEYZ4JLUqG3uhg5Zsbx3H1MjuMj7ay+6oncf+2+/vtf8U1TvGp69rP/F/0fxO+l7o461j67qXcNScftjK3vNn12e1buGUdwU4qn0O9la3AOXpEYZ4JLUKANckhplgEtSowxwSWqUAS5JjTLAJalRBrgkNcoAl6RGGeCS1CgDXJIaZYBLUqMMcElqlAEuSY0ywCWpUQa4JDUqVf0v6r9Qq1/wjLriS/uObXkzueyR/u9Z73/Fq0dQyeRtXL5j7z4+9Nm/7d3H2sf739DhtKNe1buPvurBH0+6BABq3z16zX/CORf3ruHcY17au49RePyWW3vNP7ViRe8a/vrqz/fu4xf2u+vKqnrSDeTdA5ekRhngktQoA1ySGtUrwJP8lyTXJbk2ydlJnjGqwiRJc1t0gCfZG/jPwOqq+iVgCjh+VIVJkubW9xDKMuCZSZYBOwA/7F+SJGkhFh3gVXUn8AHgNmAd8EBV9T//SJK0IH0OoawAjgMOAPYCdkzyxhnanZRkTZI16zdsXHylkqTN9DmEchRwS1Wtr6rHgPOBX5neqKpOr6rVVbV61cqpHouTJA3rE+C3AS9OskOSAEcCN4ymLEnSfPocA78cOA+4Cvhu19fpI6pLkjSPZX1mrqr3AO8ZUS2SpC3gNzElqVEGuCQ1ygCXpEYZ4JLUqF4fYrboiGc+0buPAy/7VO8+No7vPhqzmkr/PvZbtlPvPg7arv8XvJ771bN6zb8Ufh9LRf+/ENj70vt697H20VW9+/j04c/rNX+WP6t3DVuTe+CS1CgDXJIaZYBLUqMMcElqlAEuSY0ywCWpUQa4JDXKAJekRhngktQoA1ySGmWAS1KjDHBJapQBLkmNMsAlqVEGuCQ1apu7HvgojOIa2Botfyc/d8tjD/Wa//de8x961zD1wMO9+xiFLO83/x9eckHvGg7abuttm+6BS1KjDHBJapQBLkmNMsAlqVHzBniSM5Lck+TaoXG7JrkkyY3dzxVbt0xJ0nQL2QP/BHD0tHGnAF+uqucCX+6eS5LGaN4Ar6qvAfdOG30ccGY3fCbwmhHXJUmax2KPge9eVeu64buA3UdUjyRpgXp/iFlVBdRs05OclGRNkjXrN2zsuzhJUmexAX53kj0Bup/3zNawqk6vqtVVtXrVyqlFLk6SNN1iA/wC4MRu+ETgc6MpR5K0UAs5jfBs4BvAwUnuSPJW4C+AVyS5ETiqey5JGqN5L2ZVVSfMMunIEdciSdoCfhNTkhplgEtSowxwSWqUN3SQnmIO6HkDgdM/97e9a9g46zdDxmsq/eZf6jcKcQ9ckhplgEtSowxwSWqUAS5JjTLAJalRBrgkNcoAl6RGGeCS1CgDXJIaZYBLUqMMcElqlAEuSY0ywCWpUQa4JDXKAJekRhngktSoVI3vyutJ1gO3ztFkN+BHYyqnD+scLescnRZqBOvcUs+uqlXTR441wOeTZE1VrZ50HfOxztGyztFpoUawzlHxEIokNcoAl6RGLbUAP33SBSyQdY6WdY5OCzWCdY7EkjoGLklauKW2By5JWqCJBHiSo5N8P8lNSU6ZYfrTk5zbTb88yf4TqHHfJF9Jcn2S65K8bYY2RyR5IMnV3ePd466zq2Ntku92NayZYXqSfLBbn99JcugEajx4aD1dneTBJG+f1mYi6zPJGUnuSXLt0Lhdk1yS5Mbu54pZ5j2xa3NjkhPHXOP7k3yv+51+Jskus8w75/YxhjpPTXLn0O/12FnmnTMXxlDnuUM1rk1y9Szzjm19zquqxvoApoB/Bg4EtgeuAZ43rc1/BD7cDR8PnDuBOvcEDu2GdwZ+MEOdRwBfGHdtM9S6FthtjunHAhcBAV4MXD7heqeAuxic2zrx9Qm8DDgUuHZo3PuAU7rhU4D3zjDfrsDN3c8V3fCKMdb4SmBZN/zemWpcyPYxhjpPBf5gAdvEnLmwteucNv0vgXdPen3O95jEHvhhwE1VdXNVPQqcAxw3rc1xw
Jnd8HnAkUkyxhqpqnVVdVU3/GPgBmDvcdYwQscBf1cD3wR2SbLnBOs5EvjnqprrS11jU1VfA+6dNnp4GzwTeM0Ms/4acElV3VtV9wGXAEePq8aquriqHu+efhPYZ2sse0vMsi4XYiG5MDJz1dllzW8CZ2+t5Y/KJAJ8b+D2oed38ORg/FmbbgN9AFg5lupm0B3CeRFw+QyTX5LkmiQXJTlkrIX9XAEXJ7kyyUkzTF/IOh+n45n9j2MprE+A3atqXTd8F7D7DG2W0np9C4P/smYy3/YxDid3h3rOmOVw1FJal78K3F1VN84yfSmsT8APMeeVZCfg08Dbq+rBaZOvYnAY4AXAXwOfHXd9ncOr6lDgGOD3krxsQnXMK8n2wKuBf5hh8lJZn5upwf/NS/Z0rSTvAh4HzpqlyaS3j78BDgJeCKxjcHhiKTuBufe+J70+f2YSAX4nsO/Q8326cTO2SbIMWA5sGEt1Q5JsxyC8z6qq86dPr6oHq+qhbvhCYLsku425TKrqzu7nPcBnGPw7Omwh63xcjgGuqqq7p09YKuuzc/emw0zdz3tmaDPx9ZrkTcCrgDd0bzRPsoDtY6uqqruramNVPQF8ZJblT3xdws/y5rXAubO1mfT6HDaJAP8W8NwkB3R7Y8cDF0xrcwGw6RP91wH/ONvGubV0x8E+BtxQVafN0maPTcfmkxzGYH2O9Y0myY5Jdt40zOCDrWunNbsA+O3ubJQXAw8MHR4Yt1n3bpbC+hwyvA2eCHxuhjZfAl6ZZEV3WOCV3bixSHI08A7g1VX1k1naLGT72Kqmfd7yG7MsfyG5MA5HAd+rqjtmmrgU1udmJvHJKYOzIn7A4FPnd3Xj/oTBhgjwDAb/Yt8EXAEcOIEaD2fwb/N3gKu7x7HA7wK/27U5GbiOwSfm3wR+ZQJ1Htgt/5qulk3rc7jOAB/q1vd3gdUT+r3vyCCQlw+Nm/j6ZPCGsg54jMGx17cy+Mzly8CNwKXArl3b1cBHh+Z9S7ed3gS8ecw13sTguPGm7XPTmVt7ARfOtX2Muc5PdtvddxiE8p7T6+yePykXxllnN/4Tm7bHobYTW5/zPfwmpiQ1yg8xJalRBrgkNcoAl6RGGeCS1CgDXJIaZYBLUqMMcElqlAEuSY36//Dg0z1otvM3AAAAAElFTkSuQmCC\n", 256 | "text/plain": [ 257 | "
" 258 | ] 259 | }, 260 | "metadata": { 261 | "needs_background": "light" 262 | }, 263 | "output_type": "display_data" 264 | } 265 | ], 266 | "source": [ 267 | "prior = model.prior_distribution()\n", 268 | "sample = prior.sample()\n", 269 | "plt.imshow(sample['image'].numpy().reshape(12,-1))\n", 270 | "plt.title('ground truth: {}'.format(sample['word']))\n", 271 | "plt.show()\n", 272 | "\n", 273 | "sample = prior.sample()\n", 274 | "plt.imshow(sample['image'].numpy().reshape(12,-1))\n", 275 | "plt.title('ground truth: {}'.format(sample['word']))\n", 276 | "plt.show()" 277 | ] 278 | }, 279 | { 280 | "cell_type": "markdown", 281 | "metadata": {}, 282 | "source": [ 283 | "## Learning the inference network\n", 284 | "\n", 285 | "\n", 286 | "During inference we want to steer the simulator with a \"smart ML agent\" such that we only get samples that match our observed captcha. This way we can inspect the latent state of the sampled traces and quickly extract the solution.\n", 287 | "\n", 288 | "We will now train the ML agent on 5000 traces." 289 | ] 290 | }, 291 | { 292 | "cell_type": "code", 293 | "execution_count": 8, 294 | "metadata": {}, 295 | "outputs": [ 296 | { 297 | "name": "stdout", 298 | "output_type": "stream", 299 | "text": [ 300 | "Creating new inference network...\n", 301 | "Observable obs0: observe embedding not specified, using the default FEEDFORWARD.\n", 302 | "Observe embedding dimension: 32\n", 303 | "Train. time | Epoch| Trace | Init. loss| Min. loss | Curr. 
loss| T.since min | Traces/sec\n", 304 | "New layers, address: 16__forward__mu__Normal__1, distribution: Normal\n", 305 | "New layers, address: 36__forward__blur__Normal__1, distribution: Normal\n", 306 | "Total addresses: 2, parameters: 61,870\n", 307 | "0d:00:00:07 | 1 | 5,056 | +1.44e+00 | +1.31e+00 | \u001b[32m+1.32e+00\u001b[0m | 0d:00:00:00 | 617.0 \n" 308 | ] 309 | } 310 | ], 311 | "source": [ 312 | "model.learn_inference_network(\n", 313 | " num_traces=5000,\n", 314 | " observe_embeddings={'obs0': {'dim': 32, 'depth': 3}}\n", 315 | ")" 316 | ] 317 | }, 318 | { 319 | "cell_type": "markdown", 320 | "metadata": {}, 321 | "source": [ 322 | "## Sampling from the Posterior\n", 323 | "\n", 324 | "Now we will generate the posterior distribution. `pyprob` will generate 10,000 weighted traces from which we can sample according to their weights, which gives us unweighted samples from the posterior." 325 | ] 326 | }, 327 | { 328 | "cell_type": "code", 329 | "execution_count": 9, 330 | "metadata": {}, 331 | "outputs": [ 332 | { 333 | "data": { 334 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXAAAAD4CAYAAAD1jb0+AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjAsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+17YcXAAAVeUlEQVR4nO3de5RdZX3G8efJTC7kYu4NuRoQjAUriCMCcolAKaRCBF0KggJSU1elha6qxWUFtF21WutqtS5dURGQWwqiIgvkpoAXCExCCASQBAiQKyGEhHBNMr/+cfbg4TCXM/PunDMvfD9rzZp99t7ve36zZ89z9rxnn70dEQIA5GdQswsAAPQPAQ4AmSLAASBTBDgAZIoAB4BMEeAAkCkCHG8ots+3fUmDn3Om7bDd2sjnBQhwvGnYnm17VQn9rLR9ZBk1ASkIcOxUuR2V5lYv3twIcPSZ7f1s32P7OdtX2l5g+9+KZbNtr7L9z7bXSfpxMf/TtlfYfsb2NbanFPNfN/xg+1bbf1NMn2b7d7a/aXuT7cdsH1O17m62bytquUnShG5qHiHpeklTbG8tvqYUQy5X2b7E9hZJp9m+sPPnqf6ZiumfSJoh6ZdFH1+oepqTbT9h+2nbX6pzW77P9jrbLVXzjre9tJjurG9B8TMutr1Psex027+sarfc9pVVj5+0vW89dSBPBDj6xPYQST+TdKGkcZIul3R8zWq7FsveKmme7cMlfU3SRyVNlvS4pCv68LTvk/RHVcL5G5J+ZNvFssskLSqW/aukU7vqICKel3SMpDURMbL4WlMsnivpKkljJF3aUyER8QlJT0g6tujjG1WLD5Y0S9IRks61/eeSZPtg2892099CSc9LOrxq9seLn6vTXElXqrJNL5P0c9uDJd0m6RDbg4oXxCGSDiyec3dJIyUt7ennQd4IcPTVAZJaJX07IrZFxNWS7qpZp0PSeRHxckS8KOlkSRdExOKIeFnSFyUdaHtmnc/5eET8ICJ2SLpIlReBSbZnSHqvpC8Xz3W7pF/21FE37oiIn0dER1Fvf30lIl6MiHsl3StpH0mKiN9FxJge2l0u6SRJsj1K0pxiXqdFEXFVRGyT9C1JwyQdEBGPSnpO0r6SDpV0g6Q1tt8h6TBJv42IjoSfBwMcAY6+miJpdbz2KmhP1qyzISJeqmnzeOeDiNgqaaOkqXU+57qqti8UkyOLfjcVR9edHlff1dbfX+uqpl9QpcZ6XCbpBNtDJZ0gaXFEVP8cr9ZXBPIqVX52qXIUPluVAL9N0q2qhPdhxWO8gRHg6Ku1kqZWDWFI0vSadWovcblGleEUSa+OR4+XtFqV4QNJGl61/q59qGVs0V+nGT2s392lN2vnP99LPaVewjMiHlDlhecYvX74RKravrYHSZqmyjaV/hTghxTTt4kAf9MgwNFXd0jaIelM262250rav5c2l0s63fa+xVHmv0taGBErI2KDKkF+iu0W25+S9LZ6CimOUtslfcX2ENsHSzq2hybrJY23PbqXrpdImmN7nO1dJZ3dRT+711NjH1wm6SxVjqSvrFn2HtsnFG/0ni3pZUl3Fstuk/QBSbtExCpJv5V0tCovkPeUXCMGGAIcfRIRr6jyb/4Zkp6VdIqka1UJle7a3Czpy5J+qspR89sknVi1yqclfV6VYZW9Jf2hDyV9XJU3OZ+RdJ6ki3uo4yFVXkwetf1s55kwXfiJKmPYKyXdKGlBzfKvSfqXoo/P9Vag7UNsb+1ltctVOWr+dUQ8XbPsF5I+JmmTpE9IOqEYD1dEPCxpqyrBrYjYIulRSb8v3jPAG5i5oQNS2V4o6fsR8eNm1/JGY/t8SXtExCnNrgUDD0fg6DPbh9netRhCOVXSuyT9qtl1AW82fOoM/TFL0v9JGqHKv+sfiYi1zS0JePNhCAUAMsUQCgBkigAHgEw1dAx8wriWmDl9cCOfEgCyt2jpy09HxMTa+Q0N8JnTB+u
uG2o/tAcA6EnL5BVdXiKCIRQAyBQBDgCZIsABIFNJAW77aNt/LO60ck5ZRQEAetfvAC9uAfVdVS6BuZekk2zvVVZhAICepRyB7y9pRUQ8Wlyh7gpVbv0EAGiAlACfqtfeyWSVurjDiu15ttttt2/YyNUtAaAsO/1NzIiYHxFtEdE2cXxL7w0AAHVJCfDVeu2ttKYV8wAADZAS4HdL2tP2braHqHKHlWvKKQsA0Jt+f5Q+IrbbPlPSDZJaJF0QEctKqwwA0KOka6FExHWSriupFgBAH/BJTADIFAEOAJkiwAEgUwQ4AGSKAAeATBHgAJApAhwAMkWAA0CmCHAAyBQBDgCZIsABIFMEOABkigAHgEwR4ACQKQIcADJFgANApghwAMgUAQ4AmSLAASBTSffEbIYd0ZHUfrt2lFRJmkElvHZ2KG1bbIuBsS3eKAa7JbmPgbBflGGg/BypdZRRw878O+MIHAAyRYADQKYIcADIFAEOAJnqd4Dbnm77N7YfsL3M9lllFgYA6FnKWSjbJf1TRCy2PUrSIts3RcQDJdUGAOhBv4/AI2JtRCwupp+T9KCkqWUVBgDoWSlj4LZnSnq3pIVl9AcA6F1ygNseKemnks6OiC1dLJ9nu912+4aNfHAEAMqSFOC2B6sS3pdGxNVdrRMR8yOiLSLaJo5P/6QaAKAi5SwUS/qRpAcj4lvllQQAqEfKEfj7JX1C0uG2lxRfc0qqCwDQi36fRhgRv5PkEmsBAPQBn8QEgEwR4ACQKQIcADKV3Q0dNna8mNT+yufekVzDtkg/HfKQ4Q8n9/HbF96e1P7pbaOSa2hxCRe872j+6aVDB21P7mN06wvJfRywyyPJfSx5aWZS+22RHgtl7N/LXpmS3MfeQ9YktU/9G5OkZ7aPSO5DWtXlXI7AASBTBDgAZIoAB4BMEeAAkCkCHAAyRYADQKYIcADIFAEOAJkiwAEgUwQ4AGSKAAeATBHgAJApAhwAMkWAA0CmCHAAyBQBDgCZaugNHXaoQ1s7XkrqY8GWvZLaX/zNOUntJanllUju4+KPvS+5j2GXjk1qP2LNy8k1vDRhSHIfI1al3wgh1dYZw5P76GhNv8f3t4/ZltzHyKVDk9oP3ZS+f1/20bbkPjb/flJyH6Pfvz6p/Y4r/iy5hhLueSLpmi7ncgQOAJkiwAEgUwQ4AGSKAAeATCUHuO0W2/fYvraMggAA9SnjCPwsSQ+W0A8AoA+SAtz2NEl/LemH5ZQDAKhX6hH4f0v6gqRuz3S0Pc92u+32jRtLOSESAKCEALf9QUlPRcSintaLiPkR0RYRbePH854pAJQlJVHfL+k42yslXSHpcNuXlFIVAKBX/Q7wiPhiREyLiJmSTpT064g4pbTKAAA9YkwDADJVysWsIuJWSbeW0RcAoD4cgQNApghwAMgUAQ4AmWroDR3KsENpF81vfSn9YvVj7ns2uY9H3jkuuY9JyzYltX9x+qjkGlbPTr+JwTu+szW5D3ek/V7Xf3hkcg17XJL2+5CkYZvS61h53I6k9mNuTv8bWfPAxOQ+pi7ZntzH6tFpN2SYsjltW0rS5tOeS+5DF3c9myNwAMgUAQ4AmSLAASBTBDgAZIoAB4BMEeAAkCkCHAAyRYADQKYIcADIFAEOAJkiwAEgUwQ4AGSKAAeATBHgAJApAhwAMtXQ64G3aJBGDhqW1MeckcuS2i/45HuS2kuSPj84uYs9//fx5D46xr4lqf3mz6Rfp/hdY9Kvgf3K0DHJfcS2tGtHb5+4LbkG70i/dvTQp15I72NSS1L7jpbhyTW8/Xvrk/vYseKx5D5mLdo1qf326enXNZ81+YnkPu7vZj5H4ACQKQIcADJFgANApghwAMhUUoDbHmP7KtsP2X7Q9oFlFQYA6FnqWSj/I+lXEfER20Mkpb99DQCoS78D3PZoSYdKOk2SIuIVSa+UUxYAoDcpQyi7Sdog6ce277H9Q9sjSqoLANCLlABvlbSfpO9FxLslPS/
pnNqVbM+z3W67fcPG9A86AAAqUgJ8laRVEbGweHyVKoH+GhExPyLaIqJt4vi0T4gBAP6k3wEeEeskPWl7VjHrCEkPlFIVAKBXqWeh/L2kS4szUB6VdHp6SQCAeiQFeEQskdRWUi0AgD7gk5gAkCkCHAAyRYADQKYaekOHMkxpSTsV8dip3V0avX7X7zk7uY9Rt6xJ7mPr/tOT2h8z447kGp7dnn71hMdaxyX3MWjLy0ntB29Iv0lHJO6bkvTcHqOS+/jAzHuT2i88bUZyDSPPTd8vWien3YxBkjo2b0lq//xBb02uYdqw9JuedIcjcADIFAEOAJkiwAEgUwQ4AGSKAAeATBHgAJApAhwAMkWAA0CmCHAAyBQBDgCZIsABIFMEOABkigAHgEwR4ACQKQIcADJFgANAprK7ocPIQcOS2p80elFyDX84a/fkPp7Y/Z3Jfcw49rGk9ieOvSu5hgs2Hpzch7ftSO5j+6rVSe2n3TwpuYZn/2JMch+jP/1kch+fnfibpPZ77LJ3cg0XHHV0ch/Pz9oluY+3LB2S1H76h9L+xiTppDF3J/fx1W7mcwQOAJkiwAEgUwQ4AGQqKcBt/6PtZbbvt3257bQBagBA3fod4LanSvoHSW0R8U5JLZJOLKswAEDPUodQWiXtYrtV0nBJa9JLAgDUo98BHhGrJX1T0hOS1kraHBE3llUYAKBnKUMoYyXNlbSbpCmSRtg+pYv15tlut92+YWP6+b4AgIqUIZQjJT0WERsiYpukqyUdVLtSRMyPiLaIaJs4viXh6QAA1VIC/AlJB9gebtuSjpD0YDllAQB6kzIGvlDSVZIWS7qv6Gt+SXUBAHqRdC2UiDhP0nkl1QIA6AM+iQkAmSLAASBTBDgAZIoAB4BMZXdDh1TTWtMvEv+d3a5M7uPuv52a3Md7h6bdxGCbnFxD+4YZyX28ZUdHch+t09K25/Lj0/8Ujj8g/QYZZ064PbmPKa1Dk9pPHH1fcg3TztiY3MfMwU8n9/HQQZOT2h+yy8rkGsrInO5wBA4AmSLAASBTBDgAZIoAB4BMEeAAkCkCHAAyRYADQKYIcADIFAEOAJkiwAEgUwQ4AGSKAAeATBHgAJApAhwAMkWAA0Cm3nTXAx/sluQ+ZrQOT+5jSmv6tY53xJCk9te+MDG5hk2375rcx/Dxzyf3oXEj0mp4In2/2OvwNcl9TGpJ+51K0lAPTmo/flD6tpg7In3/HlTC8eU+Q9YmtW9V+t96i3fecTJH4ACQKQIcADJFgANApghwAMhUrwFu+wLbT9m+v2reONs32V5efB+7c8sEANSq5wj8QklH18w7R9ItEbGnpFuKxwCABuo1wCPidknP1MyeK+miYvoiSR8quS4AQC/6OwY+KSI6T7BcJ2lSSfUAAOqU/CZmRISk6G657Xm22223b9i4I/XpAACF/gb4etuTJan4/lR3K0bE/Ihoi4i2iePTP+EFAKjob4BfI+nUYvpUSb8opxwAQL3qOY3wckl3SJple5XtMyT9h6S/tL1c0pHFYwBAA/V6MauIOKmbRUeUXAsAoA/4JCYAZIoAB4BMEeAAkKk33Q0dylDGBdpbynjtdFrz2buk34BgzcnXJfex8sMTkvtINWdY+g0Ijh35SHIfwwel3ZiiDANm/y7FG/vU5YGylQEAfUSAA0CmCHAAyBQBDgCZIsABIFMEOABkigAHgEwR4ACQKQIcADJFgANApghwAMgUAQ4AmSLAASBTBDgAZIoAB4BMEeAAkClHROOezN4g6fEeVpkgKf3K+jsfdZaLOsuTQ40SdfbVWyNiYu3MhgZ4b2y3R0Rbs+voDXWWizrLk0ONEnWWhSEUAMgUAQ4AmRpoAT6/2QXUiTrLRZ3lyaFGiTpLMaDGwAEA9RtoR+AAgDo1JcBtH237j7ZX2D6ni+VDbS8oli+0PbMJNU63/RvbD9heZvusLtaZbXuz7SXF17mNrrOoY6Xt+4oa2rtYbtvfLrbnUtv7NaH
GWVXbaYntLbbPrlmnKdvT9gW2n7J9f9W8cbZvsr28+D62m7anFusst31qg2v8T9sPFb/Tn9ke003bHvePBtR5vu3VVb/XOd207TEXGlDngqoaV9pe0k3bhm3PXkVEQ78ktUh6RNLukoZIulfSXjXr/J2k7xfTJ0pa0IQ6J0var5geJenhLuqcLenaRtfWRa0rJU3oYfkcSddLsqQDJC1scr0tktapcm5r07enpEMl7Sfp/qp535B0TjF9jqSvd9FunKRHi+9ji+mxDazxKEmtxfTXu6qxnv2jAXWeL+lzdewTPebCzq6zZvl/STq32duzt69mHIHvL2lFRDwaEa9IukLS3Jp15kq6qJi+StIRtt3AGhURayNicTH9nKQHJU1tZA0lmivp4qi4U9IY25ObWM8Rkh6JiJ4+1NUwEXG7pGdqZlfvgxdJ+lAXTf9K0k0R8UxEbJJ0k6SjG1VjRNwYEduLh3dKmrYznrsvutmW9agnF0rTU51F1nxU0uU76/nL0owAnyrpyarHq/T6YHx1nWIH3SxpfEOq60IxhPNuSQu7WHyg7XttX29774YW9ich6Ubbi2zP62J5Pdu8kU5U938cA2F7StKkiFhbTK+TNKmLdQbSdv2UKv9ldaW3/aMRziyGei7oZjhqIG3LQyStj4jl3SwfCNtTEm9i9sr2SEk/lXR2RGypWbxYlWGAfSR9R9LPG11f4eCI2E/SMZI+a/vQJtXRK9tDJB0n6couFg+U7fkaUfm/ecCermX7S5K2S7q0m1WavX98T9LbJO0raa0qwxMD2Unq+ei72dvzVc0I8NWSplc9nlbM63Id262SRkva2JDqqtgerEp4XxoRV9cuj4gtEbG1mL5O0mDbExpcpiJidfH9KUk/U+Xf0Wr1bPNGOUbS4ohYX7tgoGzPwvrOYabi+1NdrNP07Wr7NEkflHRy8ULzOnXsHztVRKyPiB0R0SHpB908f9O3pfRq3pwgaUF36zR7e1ZrRoDfLWlP27sVR2MnSrqmZp1rJHW+o/8RSb/ubufcWYpxsB9JejAivtXNOrt2js3b3l+V7dnQFxrbI2yP6pxW5Y2t+2tWu0bSJ4uzUQ6QtLlqeKDRuj26GQjbs0r1PniqpF90sc4Nko6yPbYYFjiqmNcQto+W9AVJx0XEC92sU8/+sVPVvN9yfDfPX08uNMKRkh6KiFVdLRwI2/M1mvHOqSpnRTysyrvOXyrmfVWVHVGShqnyL/YKSXdJ2r0JNR6syr/NSyUtKb7mSPqMpM8U65wpaZkq75jfKemgJtS5e/H89xa1dG7P6jot6bvF9r5PUluTfu8jVAnk0VXzmr49VXlBWStpmypjr2eo8p7LLZKWS7pZ0rhi3TZJP6xq+6liP10h6fQG17hClXHjzv2z88ytKZKu62n/aHCdPyn2u6WqhPLk2jqLx6/LhUbWWcy/sHN/rFq3aduzty8+iQkAmeJNTADIFAEOAJkiwAEgUwQ4AGSKAAeATBHgAJApAhwAMkWAA0Cm/h/E20iwI8vrogAAAABJRU5ErkJggg==\n", 335 | "text/plain": [ 336 | "
" 337 | ] 338 | }, 339 | "metadata": { 340 | "needs_background": "light" 341 | }, 342 | "output_type": "display_data" 343 | } 344 | ], 345 | "source": [ 346 | "condition = prior.sample()\n", 347 | "plt.imshow(condition['image'].numpy().reshape(12,-1))\n", 348 | "plt.title('ground truth: {}'.format(condition['word']))\n", 349 | "plt.show()" 350 | ] 351 | }, 352 | { 353 | "cell_type": "code", 354 | "execution_count": 10, 355 | "metadata": {}, 356 | "outputs": [ 357 | { 358 | "name": "stdout", 359 | "output_type": "stream", 360 | "text": [ 361 | "Time spent | Time remain.| Progress | Trace | Traces/sec\n", 362 | "0d:00:00:53 | 0d:00:00:00 | #################### | 10000/10000 | 188.66 \n" 363 | ] 364 | } 365 | ], 366 | "source": [ 367 | "posterior = model.posterior_distribution(\n", 368 | " num_traces=10000,\n", 369 | " inference_engine=pyprob.InferenceEngine.IMPORTANCE_SAMPLING_WITH_INFERENCE_NETWORK,\n", 370 | " observe={'obs0': condition['image'].numpy()}\n", 371 | ")" 372 | ] 373 | }, 374 | { 375 | "cell_type": "markdown", 376 | "metadata": {}, 377 | "source": [ 378 | "## Making the Animations\n", 379 | "\n", 380 | "Now that we have the prior and posterior distributions we can sample a few captchas to observe the effect of the conditioning on the above \"observation\"\n", 381 | "\n", 382 | "We will draw 100 captchas each. 
If everything worked, the prior samples will change rapidly from frame to frame, while the posterior samples should almost always yield a CAPTCHA that looks much like the condition." 383 | ] 384 | }, 385 | { 386 | "cell_type": "code", 387 | "execution_count": 12, 388 | "metadata": { 389 | "scrolled": false 390 | }, 391 | "outputs": [ 392 | { 393 | "name": "stdout", 394 | "output_type": "stream", 395 | "text": [ 396 | "0\n", 397 | "10\n", 398 | "20\n", 399 | "30\n", 400 | "40\n", 401 | "50\n", 402 | "60\n", 403 | "70\n", 404 | "80\n", 405 | "90\n", 406 | "make animation\n", 407 | "display image\n" 408 | ] 409 | }, 410 | { 411 | "data": { 412 | "text/html": [ 413 | "" 414 | ], 415 | "text/plain": [ 416 | "" 417 | ] 418 | }, 419 | "execution_count": 12, 420 | "metadata": {}, 421 | "output_type": "execute_result" 422 | }, 423 | { 424 | "data": { 425 | "text/plain": [ 426 | "
" 427 | ] 428 | }, 429 | "metadata": {}, 430 | "output_type": "display_data" 431 | } 432 | ], 433 | "source": [ 434 | "for i in range(100):\n", 435 | " if i%10 == 0:\n", 436 | " print(i)\n", 437 | " sample = prior.sample()\n", 438 | " plt.imshow(sample['image'].reshape(12,-1))\n", 439 | " plt.title('generated {}'.format(sample['word']))\n", 440 | " plt.savefig('foranim_prior_{}.png'.format(str(i).zfill(6)))\n", 441 | "plt.clf()\n", 442 | "print('make animation')\n", 443 | "uniq = time.time()\n", 444 | "!convert -delay 5 -loop 0 foranim_prior_*.png prior_{uniq}.gif\n", 445 | "!rm foranim_prior_*png\n", 446 | "print('display image')\n", 447 | "IPython.display.Image(url='prior_{}.gif'.format(uniq).format(uniq))" 448 | ] 449 | }, 450 | { 451 | "cell_type": "code", 452 | "execution_count": 13, 453 | "metadata": { 454 | "scrolled": false 455 | }, 456 | "outputs": [ 457 | { 458 | "name": "stdout", 459 | "output_type": "stream", 460 | "text": [ 461 | "0\n", 462 | "10\n", 463 | "20\n", 464 | "30\n", 465 | "40\n", 466 | "50\n", 467 | "60\n", 468 | "70\n", 469 | "80\n", 470 | "90\n", 471 | "make animation\n", 472 | "display image\n" 473 | ] 474 | }, 475 | { 476 | "data": { 477 | "text/html": [ 478 | "" 479 | ], 480 | "text/plain": [ 481 | "" 482 | ] 483 | }, 484 | "execution_count": 13, 485 | "metadata": {}, 486 | "output_type": "execute_result" 487 | }, 488 | { 489 | "data": { 490 | "text/plain": [ 491 | "
" 492 | ] 493 | }, 494 | "metadata": {}, 495 | "output_type": "display_data" 496 | } 497 | ], 498 | "source": [ 499 | "for i in range(100):\n", 500 | " if i%10 == 0:\n", 501 | " print(i)\n", 502 | " sample = posterior.sample()\n", 503 | " plt.imshow(sample['image'].reshape(12,-1))\n", 504 | " plt.title('ground truth: {}, generated: {}'.format(condition['word'],sample['word']))\n", 505 | " plt.savefig('foranim_posterior_{}.png'.format(str(i).zfill(6)))\n", 506 | "plt.clf()\n", 507 | "print('make animation')\n", 508 | "uniq = time.time()\n", 509 | "!convert -delay 5 -loop 0 foranim_posterior_*.png posterior_{uniq}.gif\n", 510 | "!rm -rf foranim_posterior_*png\n", 511 | "print('display image')\n", 512 | "IPython.display.Image(url='posterior_{}.gif'.format(uniq))" 513 | ] 514 | }, 515 | { 516 | "cell_type": "markdown", 517 | "metadata": {}, 518 | "source": [ 519 | "## Inspecting the Latent State\n", 520 | "\n", 521 | "For evey sample we draw from the posterior we have full access to the trace. We can e.g. 
construct marginal posterior distributions for specific letters of the captcha" 522 | ] 523 | }, 524 | { 525 | "cell_type": "code", 526 | "execution_count": 14, 527 | "metadata": {}, 528 | "outputs": [ 529 | { 530 | "data": { 531 | "text/plain": [ 532 | "" 533 | ] 534 | }, 535 | "execution_count": 14, 536 | "metadata": {}, 537 | "output_type": "execute_result" 538 | }, 539 | { 540 | "data": { 541 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD4CAYAAAAXUaZHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjAsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+17YcXAAATnklEQVR4nO3de7QdZXnH8e9TAoT7LZECST1EWAiVJcQISJRagqsIoWAbKEUEXNRURagoImILlEVpVCxSbFOD2ISL3FKQm6IS5C6BXMAAAY0BJFkIATFc0iiRp3/M0HXAcPbMyTlndsbvZ6299lze2fPsfeb8zrvfPXtOZCaSpHb5o6YLkCQNPMNdklrIcJekFjLcJamFDHdJaqFhTRcAMGLEiOzp6Wm6DElaq8ydO/fZzBy5unVdEe49PT3MmTOn6TIkaa0SEU+82TqHZSSphQx3SWohw12SWshwl6QWMtwlqYUMd0lqIcNdklrIcJekFjLcJamFDHdJquPuu4tbl+uKyw9I0lpj772brqASe+6SVIc9d0lqoVNPLe5vvbXRMjqx5y5JLWS4S1ILGe6S1EKGuyS1kB+oSlIdX/ta0xVUYrhLUh277dZ0BZU4LCNJddx8c3HrcvbcJamOs84q7vfbr9k6OrDnLkktZLhLUgsZ7pLUQoa7JLWQH6hKUh3f+EbTFVRiuEtSHTvt1HQFlTgsI0l1XH99cety9twlqY6vfrW4P+igZuvowJ67JLWQ4S5JLWS4S1ILGe6S1EJ+oCpJdVx8cdMVVGK4S1Ido0c3XUElDstIUh1XXFHculylcI+IEyPioYh4MCIui4jhEbF9RMyOiEURcUVErFe2Xb+cX1Su7xnMJyBJQ2rq1OLW5TqGe0RsB5wAjMvMdwDrAIcDXwLOzcwdgOeBY8tNjgWeL5efW7aTJA2hqsMyw4ANImIYsCHwFLAvMLNcPwM4pJw+uJynXD8hImJgypUkVdEx3DNzKXAO8AuKUF8OzAV+nZmrymZLgO3K6e2AJ8ttV5Xtt3rj40bE5IiYExFzli1btqbPQ5LUS5VhmS0oeuPbA9sCGwH7r+mOM3NaZo7LzHEjR45c04eTJPVS5VTI/YDHMnMZQERcDYwHNo+IYWXvfBSwtGy/FBgNLCmHcTYDnhvwyiWpCTNndm7TBaqMuf8C2CsiNizHzicADwM/AiaVbY4Gri2nryvnKdffkpk5cCVLUoNGjChuXa5jzz0zZ0fETGAesAqYD0wDbgQuj4izymUXlptcCFwcEYuAX1GcWSNJa5WeU25c7fJJC24GYOau+/3eusenHDioNdVR6RuqmXk6cPobFi8G9lhN25XAoWtemiR1n77CvZv4DVVJaiHDXZJayHCXpBYy3CWphbzkryTVcMyhZzRdQiWGuyTVsHLd4U2XUInDMpJUw5HzbuTIeas/B76bGO6SVMPER+5g4iN3NF1GR4a7JLWQ4S5JLWS4S1ILGe6S1EKeCilJNRx+xJSmS6jEnrsktZDhLkk1fGz21Xxs9tVNl9GR4S5JNUz4+b1M+Pm9TZfRkeEuSS1kuEtSCxnu
ktRCngopSTWsHLZ+0yVUYrhLUg3HHPbPTZdQicMyktRChrsk1XD8XZdx/F2XNV1GR4a7JNUw/okHGP/EA02X0ZHhLkktZLhLUgsZ7pLUQp4KKUk1PL/Bpk2XUInhLkk1fOJDpzZdQiUOy0hSCxnuklTDybdN5+TbpjddRkcOy0hSDWOXPtJ0CZXYc5ekFjLcJamFDHdJaqFK4R4Rm0fEzIh4JCIWRsR7ImLLiPhhRPysvN+ibBsR8e8RsSgifhIRYwf3KUjS0HlqkxE8tcmIpsvoqOoHqucBN2XmpIhYD9gQOBWYlZlTIuIU4BTg88AHgR3L257A1PJektZ6Jx50UtMlVNKx5x4RmwH7ABcCZOZvM/PXwMHAjLLZDOCQcvpg4KIs3ANsHhHbDHjlkqQ3VWVYZntgGfDfETE/Ir4ZERsBW2fmU2WbXwJbl9PbAU/22n5Juex1ImJyRMyJiDnLli3r/zOQpCF02s3TOO3maU2X0VGVcB8GjAWmZubuwMsUQzD/LzMTyDo7zsxpmTkuM8eNHDmyzqaS1JhdnlnMLs8sbrqMjqqE+xJgSWbOLudnUoT9068Nt5T3z5TrlwKje20/qlwmSRoiHcM9M38JPBkRO5WLJgAPA9cBR5fLjgauLaevA44qz5rZC1jea/hGkjQEqp4tczxwaXmmzGLgoxR/GK6MiGOBJ4DDyrbfBQ4AFgEryraSpCFUKdwz835g3GpWTVhN2wSOW8O6JKkrLd7y984P6UpeOEySajh1/+ObLqESLz8gSS1kuEtSDWffdD5n33R+02V05LCMJNUw5ldrx5nd9twlqYUMd0lqIcNdklrIMXdJquHht4xpuoRKDHdJquHM/SY3XUIlDstIUgsZ7pJUw7nXn8O515/TdBkdOSwjSTVs8+KzTZdQiT13SWohw12SWshwl6QWcsxdkmqYt93bmy6hEsNdkmr48p8d03QJlTgsI0ktZLhLUg1Trzmbqdec3XQZHTksI0k1bPG/LzRdQiX23CWphQx3SWohw12SWsgxd0mq4a63vrPpEiox3CWphvPH/23TJVTisIwktZDhLkk1TL/ydKZfeXrTZXTksIwk1TB81W+aLqESe+6S1EKGuyS1kOEuSS3kmLsk1TDrbXs0XUIlhrsk1XDBnn/VdAmVOCwjSS1kuEtSDZd/+xQu//YpTZfRUeVwj4h1ImJ+RNxQzm8fEbMjYlFEXBER65XL1y/nF5XrewandEnSm6nTc/8HYGGv+S8B52bmDsDzwLHl8mOB58vl55btJElDqFK4R8Qo4EDgm+V8APsCM8smM4BDyumDy3nK9RPK9pKkIVK15/414GTg1XJ+K+DXmbmqnF8CbFdObwc8CVCuX162f52ImBwRcyJizrJly/pZviRpdTqeChkRE4FnMnNuRLx/oHacmdOAaQDjxo3LgXpcSRpMN7z9fU2XUEmV89zHA38ZEQcAw4FNgfOAzSNiWNk7HwUsLdsvBUYDSyJiGLAZ8NyAVy5JDbhk7IFNl1BJx2GZzPxCZo7KzB7gcOCWzPww8CNgUtnsaODacvq6cp5y/S2Zac9cUisMf2Ulw19Z2XQZHa3Jee6fBz4TEYsoxtQvLJdfCGxVLv8M0P0nhEpSRdOvOoPpV53RdBkd1br8QGbeCtxaTi8Gfu8iC5m5Ejh0AGqTJPWT31CVpBYy3CWphQx3SWohL/krSTXM3HW/pkuoxHCXpBrWlnB3WEaSathixXK2WLG86TI6sucuSTVM/c6/AnD4EVMarqRv9twlqYUMd0lqIcNdklrIcJekFvIDVUmq4ZLdD2i6hEoMd0mq4Yad92m6hEoclpGkGrZ5YRnbvND9/xrUnrsk1XDuDV8FPM9dktQAw12SWshwl6QWMtwlqYX8QFWSarhgjw81XUIlhrsk1TBrhz2bLqESh2UkqYYxzy1hzHNLmi6jI3vuklTD2d//OuB57pKkBhjuktRChrsktZDhLkkt5AeqklTD+Xsf3nQJ
\n", 542 | "text/plain": [ 543 | "
" 544 | ] 545 | }, 546 | "metadata": { 547 | "needs_background": "light" 548 | }, 549 | "output_type": "display_data" 550 | } 551 | ], 552 | "source": [ 553 | "letter_num = 2\n", 554 | "\n", 555 | "reverse_alphabet = {v:k for k,v in alphabet.items()} \n", 556 | "c,_,_ = plt.hist([reverse_alphabet[posterior.sample()['word'][letter_num]] for i in range(1000)], bins = np.linspace(-0.5,len(alphabet)+0.5,len(alphabet)+2))\n", 557 | "\n", 558 | "plt.xticks(range(len(alphabetorder)),list(alphabetorder));\n", 559 | "plt.vlines(reverse_alphabet[condition['word'][letter_num]],0,1.2*max(c), colors = 'r', linestyles = 'dashed')" 560 | ] 561 | }, 562 | { 563 | "cell_type": "markdown", 564 | "metadata": {}, 565 | "source": [ 566 | "## Solving the CAPTCHA\n", 567 | "\n", 568 | "A good solution is the MAP estimate, the maximum a posteriori value, i.e. the mode of the posterior. If we sample traces from the posterior and pick the solution that appears most often, it has a good chance of being the true solution." 569 | ] 570 | }, 571 | { 572 | "cell_type": "code", 573 | "execution_count": 15, 574 | "metadata": {}, 575 | "outputs": [ 576 | { 577 | "data": { 578 | "text/plain": [ 579 | "'wpw'" 580 | ] 581 | }, 582 | "execution_count": 15, 583 | "metadata": {}, 584 | "output_type": "execute_result" 585 | } 586 | ], 587 | "source": [ 588 | "solutions, counts = np.unique([posterior.sample()['word'] for i in range(1000)], return_counts = True)\n", 589 | "solutions[np.argmax(counts)]" 590 | ] 591 | }, 592 | { 593 | "cell_type": "code", 594 | "execution_count": null, 595 | "metadata": {}, 596 | "outputs": [], 597 | "source": [] 598 | } 599 | ], 600 | "metadata": { 601 | "kernelspec": { 602 | "display_name": "Python 3", 603 | "language": "python", 604 | "name": "python3" 605 | }, 606 | "language_info": { 607 | "codemirror_mode": { 608 | "name": "ipython", 609 | "version": 3 610 | }, 611 | "file_extension": ".py", 612 | "mimetype": "text/x-python", 613 | "name": "python", 614 | "nbconvert_exporter": "python", 615 | 
"pygments_lexer": "ipython3", 616 | "version": "3.7.3" 617 | } 618 | }, 619 | "nbformat": 4, 620 | "nbformat_minor": 2 621 | } 622 | -------------------------------------------------------------------------------- /MarkovPath.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Forcing a Markov Chain to arrive at the right place\n", 8 | "\n", 9 | "\n", 10 | "Markov Chains are sequences of random variables in which the value at the current index depends only on the value just before it. Given the current position z, the next position z' is drawn from a transition probability distribution p(z,z').\n", 11 | "\n", 12 | "Here we look at a Markov Chain starting at `[0,0]` that at each point can make exactly one of two moves: \"up\" or \"down\".\n", 13 | "\n", 14 | "If you just let the Markov Chain run freely, where will it be after 10 iterations? Or 20? Or 100? You might have heard about the Galton Board." 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "Before we start, we will import some libraries as always." 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 1, 27 | "metadata": {}, 28 | "outputs": [ 29 | { 30 | "name": "stdout", 31 | "output_type": "stream", 32 | "text": [ 33 | "\u001b[1m\u001b[31mWarning: Empirical distributions on disk may perform slow because GNU DBM is not available. 
Please install and configure gdbm library for Python for better speed.\u001b[0m\n" 34 | ] 35 | } 36 | ], 37 | "source": [ 38 | "import pyprob\n", 39 | "%matplotlib inline\n", 40 | "import matplotlib.pyplot as plt\n", 41 | "from pyprob import Model\n", 42 | "import numpy as np\n", 43 | "\n", 44 | "import math\n", 45 | "import pyprob\n", 46 | "from pyprob import Model\n", 47 | "from pyprob.distributions import Normal, Uniform, Categorical\n", 48 | "import torch\n", 49 | "import IPython\n" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 2, 55 | "metadata": {}, 56 | "outputs": [ 57 | { 58 | "data": { 59 | "text/html": [ 60 | "" 61 | ], 62 | "text/plain": [ 63 | "" 64 | ] 65 | }, 66 | "execution_count": 2, 67 | "metadata": {}, 68 | "output_type": "execute_result" 69 | } 70 | ], 71 | "source": [ 72 | "IPython.display.Image(url = 'https://thumbs.gfycat.com/QuaintTidyCockatiel-max-1mb.gif')" 73 | ] 74 | }, 75 | { 76 | "cell_type": "markdown", 77 | "metadata": {}, 78 | "source": [ 79 | "## Implementing the Galton Board in Code\n", 80 | "\n", 81 | "We will build a Galton board whose path has 20 points: the starting position plus 19 moves. At each move the position goes either up or down with equal probability."
82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": 3, 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "class MarkovChainPath(Model):\n", 91 | "    def __init__(self):\n", 92 | "        super().__init__(name=\"Markov Chain Path\") # give the model a name\n", 93 | "\n", 94 | "    def forward(self): # Needed to specify how the generative model is run forward\n", 95 | "        # build the path move by move; the sampled moves are the latent variables to be inferred:\n", 96 | "        coords = [[0,0]]\n", 97 | "        moves = {0: -1, 1: 1} # Categorical outcome 0 -> down, 1 -> up\n", 98 | "        for i in range(1,20):\n", 99 | "            last = coords[-1][1]\n", 100 | "            move = pyprob.sample(Categorical([1/2.,1/2.]), name = 'input{}'.format(i))\n", 101 | "            move = moves[move.item()]\n", 102 | "            coords.append([i,last+move])\n", 103 | "\n", 104 | "        obs_distr = Normal(coords[-1][1], 0.1)\n", 105 | "        pyprob.observe(obs_distr, name='obs0') # NOTE: observe -> denotes observable variables\n", 106 | "        return coords\n", 107 | "\n", 108 | "model = MarkovChainPath()" 109 | ] 110 | }, 111 | { 112 | "cell_type": "markdown", 113 | "metadata": {}, 114 | "source": [ 115 | "## Learning to Guide the Chain\n", 116 | "\n", 117 | "As usual we will be training our inference network!" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": 4, 123 | "metadata": {}, 124 | "outputs": [ 125 | { 126 | "name": "stdout", 127 | "output_type": "stream", 128 | "text": [ 129 | "Creating new inference network...\n", 130 | "Observable obs0: observe embedding not specified, using the default FEEDFORWARD.\n", 131 | "Observe embedding dimension: 100\n", 132 | "Train. time | Epoch| Trace | Init. loss| Min. loss | Curr. 
loss| T.since min | Traces/sec\n", 133 | "New layers, address: 74__forward__move__Categorical(len_probs:2)__1, distribution: Categorical\n", 134 | "New layers, address: 74__forward__move__Categorical(len_probs:2)__2, distribution: Categorical\n", 135 | "New layers, address: 74__forward__move__Categorical(len_probs:2)__3, distribution: Categorical\n", 136 | "New layers, address: 74__forward__move__Categorical(len_probs:2)__4, distribution: Categorical\n", 137 | "New layers, address: 74__forward__move__Categorical(len_probs:2)__5, distribution: Categorical\n", 138 | "New layers, address: 74__forward__move__Categorical(len_probs:2)__6, distribution: Categorical\n", 139 | "New layers, address: 74__forward__move__Categorical(len_probs:2)__7, distribution: Categorical\n", 140 | "New layers, address: 74__forward__move__Categorical(len_probs:2)__8, distribution: Categorical\n", 141 | "New layers, address: 74__forward__move__Categorical(len_probs:2)__9, distribution: Categorical\n", 142 | "New layers, address: 74__forward__move__Categorical(len_probs:2)__10, distribution: Categorical\n", 143 | "New layers, address: 74__forward__move__Categorical(len_probs:2)__11, distribution: Categorical\n", 144 | "New layers, address: 74__forward__move__Categorical(len_probs:2)__12, distribution: Categorical\n", 145 | "New layers, address: 74__forward__move__Categorical(len_probs:2)__13, distribution: Categorical\n", 146 | "New layers, address: 74__forward__move__Categorical(len_probs:2)__14, distribution: Categorical\n", 147 | "New layers, address: 74__forward__move__Categorical(len_probs:2)__15, distribution: Categorical\n", 148 | "New layers, address: 74__forward__move__Categorical(len_probs:2)__16, distribution: Categorical\n", 149 | "New layers, address: 74__forward__move__Categorical(len_probs:2)__17, distribution: Categorical\n", 150 | "New layers, address: 74__forward__move__Categorical(len_probs:2)__18, distribution: Categorical\n", 151 | "New layers, address: 
74__forward__move__Categorical(len_probs:2)__19, distribution: Categorical\n", 152 | "Total addresses: 19, parameters: 132,895\n", 153 | "0d:00:00:47 | 1 | 10,048 | +1.32e+01 | +1.23e+01 | \u001b[32m+1.27e+01\u001b[0m | 0d:00:00:29 | 216.0 \n" 154 | ] 155 | } 156 | ], 157 | "source": [ 158 | "model.learn_inference_network(\n", 159 | "    num_traces=10000,\n", 160 | "    observe_embeddings={'obs0': {'dim': 100, 'depth': 5}}\n", 161 | ")" 162 | ] 163 | }, 164 | { 165 | "cell_type": "markdown", 166 | "metadata": {}, 167 | "source": [ 168 | "## Generating Prior and Posterior Traces\n", 169 | "\n", 170 | "As in the other examples we have prior and posterior traces. What will they look like? Have a guess." 171 | ] 172 | }, 173 | { 174 | "cell_type": "code", 175 | "execution_count": 5, 176 | "metadata": {}, 177 | "outputs": [ 178 | { 179 | "name": "stdout", 180 | "output_type": "stream", 181 | "text": [ 182 | "Time spent | Time remain.| Progress | Trace | Traces/sec\n", 183 | "0d:00:00:04 | 0d:00:00:00 | #################### | 1000/1000 | 229.92 \n" 184 | ] 185 | } 186 | ], 187 | "source": [ 188 | "prior = model.prior_traces(\n", 189 | "    num_traces=1000,\n", 190 | ")" 191 | ] 192 | }, 193 | { 194 | "cell_type": "markdown", 195 | "metadata": {}, 196 | "source": [ 197 | "We will also generate some samples from the **conditioned** model. Feel free to change the condition value from 5 to a number you like; note that after 19 moves only odd final positions between -19 and 19 can be reached."
198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": 14, 203 | "metadata": {}, 204 | "outputs": [ 205 | { 206 | "name": "stdout", 207 | "output_type": "stream", 208 | "text": [ 209 | "Time spent | Time remain.| Progress | Trace | Traces/sec\n", 210 | "0d:00:00:18 | 0d:00:00:00 | #################### | 1000/1000 | 53.91 \n" 211 | ] 212 | } 213 | ], 214 | "source": [ 215 | "condition = {'obs0': 5}\n", 216 | "posterior = model.posterior_traces(\n", 217 | "    num_traces=1000,\n", 218 | "    inference_engine=pyprob.InferenceEngine.IMPORTANCE_SAMPLING_WITH_INFERENCE_NETWORK,\n", 219 | "    observe=condition\n", 220 | ")" 221 | ] 222 | }, 223 | { 224 | "cell_type": "markdown", 225 | "metadata": {}, 226 | "source": [ 227 | "Let's get a representative set of paths for both the **conditioned** model and the **unconditioned** one." 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": 15, 233 | "metadata": {}, 234 | "outputs": [], 235 | "source": [ 236 | "post_paths = [posterior.sample().result for i in range(1000)]\n", 237 | "prior_paths = [prior.sample().result for i in range(1000)]" 238 | ] 239 | }, 240 | { 241 | "cell_type": "markdown", 242 | "metadata": {}, 243 | "source": [ 244 | "## The Plots!\n", 245 | "\n", 246 | "As expected, the conditioned paths always arrive at the same spot, no matter where they wandered off to in the middle of their path. At some point the proposals from the inference network steer them in the correct direction.\n", 247 | "\n", 248 | "\n", 249 | "We can also plot the final position distribution. 
As expected, the unconditioned one is approximately normal, while the conditioned one is essentially a delta distribution at the conditioned value." 250 | ] 251 | }, 252 | { 253 | "cell_type": "code", 254 | "execution_count": 16, 255 | "metadata": {}, 256 | "outputs": [ 257 | { 258 | "data": { 259 | "text/plain": [ 260 | "Text(0, 0.5, 'position')" 261 | ] 262 | }, 263 | "execution_count": 16, 264 | "metadata": {}, 265 | "output_type": "execute_result" 266 | }, 267 | { 268 | "data": { 269 | "image/png": "\n", 270 | "text/plain": [ 271 | "
" 272 | ] 273 | }, 274 | "metadata": { 275 | "needs_background": "light" 276 | }, 277 | "output_type": "display_data" 278 | } 279 | ], 280 | "source": [ 281 | "for p in prior_paths:\n", 282 | "    p = np.asarray(p)\n", 283 | "    plt.plot(p[:,0],p[:,1], c = 'orangered', alpha = 0.1)\n", 284 | "\n", 285 | "\n", 286 | "for p in post_paths:\n", 287 | "    p = np.asarray(p)\n", 288 | "    plt.plot(p[:,0],p[:,1] + 0.5, c = 'steelblue', alpha = 0.1)\n", 289 | "plt.xlabel('step number')\n", 290 | "plt.ylabel('position')" 291 | ] 292 | }, 293 | { 294 | "cell_type": "code", 295 | "execution_count": 17, 296 | "metadata": {}, 297 | "outputs": [ 298 | { 299 | "data": { 300 | "text/plain": [ 301 | "Text(0.5, 0, 'final position')" 302 | ] 303 | }, 304 | "execution_count": 17, 305 | "metadata": {}, 306 | "output_type": "execute_result" 307 | }, 308 | { 309 | "data": { 310 | "image/png": "\n", 311 | "text/plain": [ 312 | "
" 313 | ] 314 | }, 315 | "metadata": { 316 | "needs_background": "light" 317 | }, 318 | "output_type": "display_data" 319 | } 320 | ], 321 | "source": [ 322 | "c1,_,_ = plt.hist([p[-1][1] for p in prior_paths], np.linspace(-15.5,15.5,32), alpha = 0.2, label = 'prior');\n", 323 | "c2,_,_ = plt.hist([p[-1][1] for p in post_paths], np.linspace(-15.5,15.5,32), alpha = 0.2, label = 'posterior');\n", 324 | "plt.vlines(condition['obs0'],0,np.max([c1.max(),c2.max()])*1.2, linestyle = 'dashed', color = 'red')\n", 325 | "plt.legend()\n", 326 | "plt.gcf().set_size_inches(5,5)\n", 327 | "plt.xlabel('final position')" 328 | ] 329 | }, 330 | { 331 | "cell_type": "code", 332 | "execution_count": null, 333 | "metadata": {}, 334 | "outputs": [], 335 | "source": [] 336 | } 337 | ], 338 | "metadata": { 339 | "kernelspec": { 340 | "display_name": "Python 3", 341 | "language": "python", 342 | "name": "python3" 343 | }, 344 | "language_info": { 345 | "codemirror_mode": { 346 | "name": "ipython", 347 | "version": 3 348 | }, 349 | "file_extension": ".py", 350 | "mimetype": "text/x-python", 351 | "name": "python", 352 | "nbconvert_exporter": "python", 353 | "pygments_lexer": "ipython3", 354 | "version": "3.7.3" 355 | } 356 | }, 357 | "nbformat": 4, 358 | "nbformat_minor": 2 359 | } 360 | --------------------------------------------------------------------------------
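The conditioning step in MarkovPath.ipynb can be cross-checked without pyprob. The sketch below is not part of the repository: it simulates the same 19-move fair random walk with plain NumPy and conditions on the endpoint by rejection, that is, by keeping only the paths that end at the target. The names `simulate_paths`, `condition_on_endpoint`, and `endpoint_probability` are hypothetical helpers introduced for illustration, and the exact endpoint match stands in for the narrow `Normal(..., 0.1)` observation used in the notebook.

```python
from math import factorial

import numpy as np


def simulate_paths(num_paths, num_moves=19, seed=0):
    """Simulate fair +/-1 random walks that start at position 0.

    Returns an integer array of shape (num_paths, num_moves + 1).
    """
    rng = np.random.RandomState(seed)
    moves = rng.choice([-1, 1], size=(num_paths, num_moves))
    start = np.zeros((num_paths, 1), dtype=int)
    # Prepend the start position, then accumulate the moves along each path.
    return np.concatenate([start, np.cumsum(moves, axis=1)], axis=1)


def condition_on_endpoint(paths, target):
    """Rejection step: keep only the paths whose final position equals target."""
    return paths[paths[:, -1] == target]


def endpoint_probability(target, num_moves=19):
    """Exact probability that the walk ends at target (binomial formula)."""
    if abs(target) > num_moves or (target + num_moves) % 2 != 0:
        return 0.0  # unreachable: wrong parity or too far away
    ups = (target + num_moves) // 2
    ways = factorial(num_moves) // (factorial(ups) * factorial(num_moves - ups))
    return ways / 2 ** num_moves


paths = simulate_paths(20000)
finals = paths[:, -1]
accepted = condition_on_endpoint(paths, target=5)

print("prior mean and variance of the final position:", finals.mean(), finals.var())
print("accepted fraction:", len(accepted) / len(paths))
print("exact probability of ending at 5:", endpoint_probability(5))
```

Ending at 5 after 19 moves requires 12 up moves and 7 down moves, so the exact probability is C(19,12)/2^19, roughly 0.096. Rejection therefore discards about 90% of the simulated paths, which is the inefficiency that the trained inference network is meant to avoid by proposing moves that already head toward the observed endpoint.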