├── .gitignore ├── 01_beta_vae.ipynb ├── 02_diffusion.ipynb ├── 03_normalizing_flows.ipynb ├── 04_continuous_normalizing_flows.ipynb ├── 05_consistency_models.ipynb ├── 06_flow_matching.ipynb ├── 07_diffusion_distillation.ipynb ├── 08_discrete_walk_jump_sampling.ipynb ├── LICENSE.md ├── README.md ├── assets └── midwit.png └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | 162 | .DS_Store 163 | data/* 164 | data/*.hdf5 165 | notebooks/ckpts 166 | notebooks/ckpts* 167 | scripts/*.out 168 | logging/ 169 | wandb 170 | wandb/* 171 | notebooks/data -------------------------------------------------------------------------------- /03_normalizing_flows.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 5, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from functools import partial\n", 10 | "\n", 11 | "import jax\n", 12 | "\n", 13 | "from jax import config\n", 14 | "config.update(\"jax_enable_x64\", True)\n", 15 | "\n", 16 | "import jax.numpy as np\n", 17 | "import flax.linen as nn\n", 18 | "import optax\n", 19 | "import diffrax as dfx\n", 20 | "import math\n", 21 | "from tensorflow_probability.substrates import jax as tfp\n", 22 | "\n", 23 | "from sklearn import datasets, preprocessing\n", 24 | "\n", 25 | "import matplotlib.pyplot as plt\n", 26 | "from tqdm import trange" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "metadata": {}, 32 | "source": [ 33 | "## The Dataset" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 6, 39 | "metadata": {}, 40 | "outputs": [ 41 | { 42 | "data": { 43 | "text/plain": [ 44 | "(-2.0, 2.0)" 45 | ] 46 | }, 47 | "execution_count": 6, 48 | "metadata": {}, 49 | "output_type": "execute_result" 50 | }, 51 | { 52 | "data": { 53 | "image/png": "", 54 | "text/plain": [ 55 | "
" 56 | ] 57 | }, 58 | "metadata": {}, 59 | "output_type": "display_data" 60 | } 61 | ], 62 | "source": [ 63 | "n_samples = 100_000\n", 64 | "\n", 65 | "x, _ = datasets.make_moons(n_samples=n_samples, noise=.06)\n", 66 | "\n", 67 | "scaler = preprocessing.StandardScaler()\n", 68 | "x = scaler.fit_transform(x)\n", 69 | "\n", 70 | "plt.hist2d(x[:, 0], x[:, 1], bins=100)\n", 71 | "plt.xlim(-2 ,2)\n", 72 | "plt.ylim(-2, 2)" 73 | ] 74 | }, 75 | { 76 | "cell_type": "markdown", 77 | "metadata": {}, 78 | "source": [ 79 | "## Implementation" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": 10, 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "class MLP(nn.Module):\n", 89 | " \"\"\" A simple MLP in Flax.\n", 90 | " \"\"\"\n", 91 | " hidden_dim: int = 32\n", 92 | " out_dim: int = 2\n", 93 | " n_layers: int = 3\n", 94 | "\n", 95 | " @nn.compact\n", 96 | " def __call__(self, x):\n", 97 | " for _ in range(self.n_layers):\n", 98 | " x = nn.Dense(features=self.hidden_dim)(x)\n", 99 | " x = nn.gelu(x)\n", 100 | " x = nn.Dense(features=self.out_dim)(x)\n", 101 | " return x" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": 59, 107 | "metadata": {}, 108 | "outputs": [], 109 | "source": [ 110 | "class AffineBijector:\n", 111 | " def __init__(self, shift_and_log_scale):\n", 112 | " self.shift_and_log_scale = shift_and_log_scale\n", 113 | "\n", 114 | " def forward_and_log_det(self, x):\n", 115 | " shift, log_scale = np.split(self.shift_and_log_scale, 2, axis=-1)\n", 116 | " y = x * np.exp(log_scale) + shift\n", 117 | " log_det = log_scale\n", 118 | " return y, log_det\n", 119 | "\n", 120 | " def inverse_and_log_det(self, y):\n", 121 | " shift, log_scale = np.split(self.shift_and_log_scale, 2, axis=-1)\n", 122 | " x = (y - shift) * np.exp(-log_scale)\n", 123 | " log_det = -log_scale\n", 124 | " return x, log_det\n", 125 | "\n", 126 | "class MaskedCoupling:\n", 127 | " def __init__(self, mask, conditioner, bijector):\n", 128 | "
\"\"\"Coupling layer with masking and conditioner.\"\"\"\n", 129 | " self.mask = mask\n", 130 | " self.conditioner = conditioner\n", 131 | " self.bijector = bijector\n", 132 | "\n", 133 | " def forward_and_log_det(self, x):\n", 134 | " \"\"\"Transforms masked indices of `x` conditioned on unmasked indices using bijector.\"\"\"\n", 135 | " x_cond = np.where(self.mask, 0.0, x)\n", 136 | " bijector_params = self.conditioner(x_cond)\n", 137 | " y, log_det = self.bijector(bijector_params).forward_and_log_det(x)\n", 138 | " log_det = np.where(self.mask, log_det, 0.0)\n", 139 | " y = np.where(self.mask, y, x)\n", 140 | " return y, np.sum(log_det, axis=-1)\n", 141 | "\n", 142 | " def inverse_and_log_det(self, y):\n", 143 | " \"\"\"Transforms masked indices of `y` conditioned on unmasked indices using bijector.\"\"\"\n", 144 | " y_cond = np.where(self.mask, 0.0, y)\n", 145 | " bijector_params = self.conditioner(y_cond)\n", 146 | " x, log_det = self.bijector(bijector_params).inverse_and_log_det(y)\n", 147 | " log_det = np.where(self.mask, log_det, 0.0)\n", 148 | " x = np.where(self.mask, x, y)\n", 149 | " return x, np.sum(log_det, axis=-1)" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": 102, 155 | "metadata": {}, 156 | "outputs": [], 157 | "source": [ 158 | "class RealNVP(nn.Module):\n", 159 | " n_transforms: int = 4\n", 160 | " d_params: int = 2\n", 161 | " d_hidden: int = 128\n", 162 | " n_layers: int = 4\n", 163 | "\n", 164 | " def setup(self):\n", 165 | " self.mask_list = [np.arange(self.d_params) % 2 == i % 2 for i in range(self.n_transforms)]\n", 166 | " self.conditioner_list = [MLP(self.d_hidden, 2 * self.d_params, self.n_layers) for _ in range(self.n_transforms)]\n", 167 | " self.base_dist = tfp.distributions.Normal(loc=np.zeros(self.d_params), scale=np.ones(self.d_params))\n", 168 | " \n", 169 | " def log_prob(self, x):\n", 170 | " log_prob = np.zeros(x.shape[:-1])\n", 171 | " for mask, conditioner in zip(self.mask_list[::-1], 
self.conditioner_list[::-1]):\n", 172 | " x, ldj = MaskedCoupling(mask, conditioner, AffineBijector).inverse_and_log_det(x)\n", 173 | " log_prob += ldj\n", 174 | " return log_prob + self.base_dist.log_prob(x).sum(-1)\n", 175 | "\n", 176 | " def sample(self, sample_shape, key, n_transforms=None):\n", 177 | " x = self.base_dist.sample(key, sample_shape)\n", 178 | " for mask, conditioner in zip(self.mask_list[:n_transforms], self.conditioner_list[:n_transforms]):\n", 179 | " x, _ = MaskedCoupling(mask, conditioner, AffineBijector).forward_and_log_det(x)\n", 180 | " return x\n", 181 | "\n", 182 | " def __call__(self, x):\n", 183 | " return self.log_prob(x)" 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": 103, 189 | "metadata": {}, 190 | "outputs": [ 191 | { 192 | "data": { 193 | "text/plain": [ 194 | "Array([-3.35979128, -2.8513697 , -2.8622152 , -3.03087519], dtype=float64)" 195 | ] 196 | }, 197 | "execution_count": 103, 198 | "metadata": {}, 199 | "output_type": "execute_result" 200 | } 201 | ], 202 | "source": [ 203 | "model = RealNVP()\n", 204 | "\n", 205 | "key = jax.random.PRNGKey(0)\n", 206 | "params = model.init(key, x[:2])\n", 207 | "\n", 208 | "# Test log_prob\n", 209 | "model.apply(params, x[:4])" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": 104, 215 | "metadata": {}, 216 | "outputs": [], 217 | "source": [ 218 | "opt = optax.adam(learning_rate=1e-3)\n", 219 | "opt_state = opt.init(params)" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": 105, 225 | "metadata": {}, 226 | "outputs": [], 227 | "source": [ 228 | "@jax.jit\n", 229 | "def train_step(params, opt_state, x):\n", 230 | " def loss_fn(params):\n", 231 | " return -model.apply(params, x).mean()\n", 232 | " loss, grad = jax.value_and_grad(loss_fn)(params)\n", 233 | " updates, opt_state = opt.update(grad, opt_state)\n", 234 | " params = optax.apply_updates(params, updates)\n", 235 | " return loss, params, opt_state" 236 
| ] 237 | }, 238 | { 239 | "cell_type": "code", 240 | "execution_count": 106, 241 | "metadata": {}, 242 | "outputs": [ 243 | { 244 | "name": "stderr", 245 | "output_type": "stream", 246 | "text": [ 247 | "100%|██████████| 10000/10000 [01:31<00:00, 108.79it/s, val=1.407515638835097]\n" 248 | ] 249 | } 250 | ], 251 | "source": [ 252 | "n_steps = 10_000\n", 253 | "n_batch = 128\n", 254 | "\n", 255 | "with trange(n_steps) as steps:\n", 256 | " for step in steps:\n", 257 | "\n", 258 | " # Draw a random batch from x\n", 259 | " key, subkey = jax.random.split(key)\n", 260 | " idx = jax.random.choice(subkey, x.shape[0], shape=(n_batch,))\n", 261 | " \n", 262 | " x_batch = x[idx]\n", 263 | "\n", 264 | " loss, params, opt_state = train_step(params, opt_state, x_batch)\n", 265 | "\n", 266 | " steps.set_postfix(val=loss)" 267 | ] 268 | }, 269 | { 270 | "cell_type": "markdown", 271 | "metadata": {}, 272 | "source": [ 273 | "Generate some samples." 274 | ] 275 | }, 276 | { 277 | "cell_type": "code", 278 | "execution_count": 107, 279 | "metadata": {}, 280 | "outputs": [], 281 | "source": [ 282 | "n_samples = 100_000\n", 283 | "x_sample = model.apply(params, key, (n_samples,), method=model.sample)" 284 | ] 285 | }, 286 | { 287 | "cell_type": "code", 288 | "execution_count": 108, 289 | "metadata": {}, 290 | "outputs": [ 291 | { 292 | "data": { 293 | "text/plain": [ 294 | "(-2.0, 2.0)" 295 | ] 296 | }, 297 | "execution_count": 108, 298 | "metadata": {}, 299 | "output_type": "execute_result" 300 | }, 301 | { 302 | "data": { 303 | "image/png": "", 304 | "text/plain": [ 305 | "
" 306 | ] 307 | }, 308 | "metadata": {}, 309 | "output_type": "display_data" 310 | } 311 | ], 312 | "source": [ 313 | "\n", 314 | "plt.hist2d(x_sample[:, 0], x_sample[:, 1], bins=100)\n", 315 | "plt.xlim(-2 ,2)\n", 316 | "plt.ylim(-2, 2)" 317 | ] 318 | }, 319 | { 320 | "cell_type": "code", 321 | "execution_count": null, 322 | "metadata": {}, 323 | "outputs": [], 324 | "source": [] 325 | } 326 | ], 327 | "metadata": { 328 | "kernelspec": { 329 | "display_name": "torch-mps", 330 | "language": "python", 331 | "name": "python3" 332 | }, 333 | "language_info": { 334 | "codemirror_mode": { 335 | "name": "ipython", 336 | "version": 3 337 | }, 338 | "file_extension": ".py", 339 | "mimetype": "text/x-python", 340 | "name": "python", 341 | "nbconvert_exporter": "python", 342 | "pygments_lexer": "ipython3", 343 | "version": "3.9.13" 344 | } 345 | }, 346 | "nbformat": 4, 347 | "nbformat_minor": 2 348 | } 349 | -------------------------------------------------------------------------------- /04_continuous_normalizing_flows.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Continuous normalizing flows" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [ 15 | { 16 | "name": "stderr", 17 | "output_type": "stream", 18 | "text": [ 19 | "/opt/homebrew/Caskroom/miniforge/base/envs/torch-mps/lib/python3.9/site-packages/equinox/_ad.py:753: UserWarning: As of Equinox 0.10.7, `equinox.filter_custom_vjp.defvjp` is deprecated in favour of `.def_fwd` and `.def_bwd`. This new API supports symbolic zeros, which allow for more efficient autodifferentiation rules. In particular:\n", 20 | "- the fwd and bwd functions take an extra `perturbed` argument, which indicates which primals actually need a gradient. You can use this to skip computing the gradient for any unperturbed value. 
(You can also safely just ignore this if you wish.)\n", 21 | "- `None` was previously passed to indicate a symbolic zero gradient for all objects that weren't inexact arrays, but all inexact arrays always had an array-valued gradient. Now, `None` may also be passed to indicate that an inexact array has a symbolic zero gradient.\n", 22 | " warnings.warn(\n" 23 | ] 24 | } 25 | ], 26 | "source": [ 27 | "from functools import partial\n", 28 | "\n", 29 | "import jax\n", 30 | "\n", 31 | "from jax import config\n", 32 | "config.update(\"jax_enable_x64\", True)\n", 33 | "\n", 34 | "import jax.numpy as np\n", 35 | "import flax.linen as nn\n", 36 | "import optax\n", 37 | "import diffrax as dfx\n", 38 | "import math\n", 39 | "from tensorflow_probability.substrates import jax as tfp\n", 40 | "\n", 41 | "from sklearn import datasets, preprocessing\n", 42 | "\n", 43 | "import matplotlib.pyplot as plt\n", 44 | "from tqdm import trange" 45 | ] 46 | }, 47 | { 48 | "attachments": {}, 49 | "cell_type": "markdown", 50 | "metadata": {}, 51 | "source": [ 52 | "## The Dataset" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 2, 58 | "metadata": {}, 59 | "outputs": [ 60 | { 61 | "data": { 62 | "text/plain": [ 63 | "(-2.0, 2.0)" 64 | ] 65 | }, 66 | "execution_count": 2, 67 | "metadata": {}, 68 | "output_type": "execute_result" 69 | }, 70 | { 71 | "data": { 72 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAjoAAAGiCAYAAADulWxzAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy89olMNAAAACXBIWXMAAA9hAAAPYQGoP6dpAABPhElEQVR4nO3df3RV1Zk//ncguRci+UUMScCQBrSEFgVEgaR+JXyMEnRZow7jr1XBsWgd6FeE1oJTYaHjJ6NicepY0WkVnak/ylTsYBUHosEvGkCDqFCIBSLhRxJsyG80AXK+fzje9uznSbJzc29yc/J+rZW1vNtzzt0399ybzX728+wox3EcEBEREXnQoL7uABEREVG4cKBDREREnsWBDhEREXkWBzpERETkWRzoEBERkWdxoENERESexYEOEREReRYHOkRERORZHOgQERGRZ3GgQ0RERJ4V1oFOUVERLr74YsTFxWHEiBEoLCxEeXl5l+etW7cO2dnZGDJkCM4//3y88cYb4ewmEREReVRYBzpbtmzBggULsG3bNmzatAmnTp3CFVdcgZaWlg7Pef/993HTTTfh9ttvx0cffYTCwkIUFhZi9+7d4ewqEREReVBUb27q+cUXX2DEiBHYsmULLr30UvWYG264AS0tLXj99dcDbdOnT8ekSZOwZs2a3uoqEREReUB0bz5ZQ0MDAGD48OEdHlNaWorFixe72mbNmoXXXntNPb61tRWtra2Bx+3t7Thx4gSSk5MRFRXV804TERFR2DmOg6amJowcORKDBoUu4NRrA5329nYsWrQI3/ve9zBhwoQOj6uurkZqaqqrLTU1FdXV1erxRUVFWLlyZUj7SkRERH3j8OHDOOecc0J2vV4b6CxYsAC7d+/G1q1bQ3rdZcuWuWaAGhoaMHr0aBw+fBjx8fEhfS4iIiIKj8bGRmRkZCAuLi6k1+2Vgc7ChQvx+uuv49133+1ylJaWloaamhpXW01NDdLS0tTj/X4//H6/aI+Pj+dAh4iIqJ8J9bKTsGZdOY6DhQsXYv369Xj77beRlZXV5Tk5OTkoLi52tW3atAk5OTnh6iYRERF5VFhndBYsWIAXX3wRf/jDHxAXFxdYZ5OQkIChQ4cCAG699VaMGjUKRUVFAIC7774bM2bMwGOPPYarrroKL7/8Mj788EM888wz4ewqEREReVBYZ3SeeuopNDQ0IC8vD+np6YGfV155JXBMZWUlqqqqAo9zc3Px4osv4plnnsHEiRPxX//1X3jttdc6XcBMREREpOnVOjq9obGxEQkJCWhoaOAaHSIion4iXH+/udcVEREReRYHOkRERORZHOgQERGRZ3GgQ0RERJ7FgQ4RERF5Fgc6RERE5Fkc6BAREZFncaBDREREnsWBDhEREXkWBzpERETkWWHd1LO/uHzQnL7uAnUiOuVs0Xb6i7/0+bUi5Tmjx53rvnb5/l7vAxFRVza1r+uT5+WMDhEREXkWBzpERETkWQxdUcQLZYilJ9cywz/hvpZtuEkLVdmcZ9OvnlyLiCgScEaHiIiIPIsDHSIiIvIsDnSIiIjIs7hGh/qlvkiXNq+vrmcZnijbTtQH94TKtdQPrHGczZodwG49ji2u7SGiSMUZHSIiIvIsDnSIiIjIsxi6oogX0srIRhXhjthUFz49dpQ8cdvHoqmt4GLRFltW0WW/6i6UrzvuJaVfZohr+kTZh0SfPM/oQ0/YpscTEfU2zugQERGRZ3GgQ0RERJ7FgQ4RERF5FtfoUMQL5fYFJ7OSRFtTZoxoS7VICY+uaxFtp5XjzPU4AHBySlaX14+r+FI2KutvYPSjKWuo3bVCiLujE1Gk4owOEREReRYHOkRERORZDF1Rv6SGSizSvX318pjkjR/I8yzS0GvyRoi2OCU0VjVBhsYyNhx3PVZDakoI6vg02Y+Mje7UcS1MpV0r6YC8lvk71MJz4f7SYBiMiEKJMzpERETkWRzoEBERkWcxdEW9xjYkISoQK8d
oYSotzALjWjUTzxKHpB6Q/Tp4swxLjXnRHW5qi5NPp2VwnVKOM69vXhsAWpNkvxL3yn+b+OrdoarK2fI1jn5T/m60zC+z/8my0DOitYw0bTNT4z1Ss8Fe2mZ1LfOLiqEsIrLFGR0iIiLyLA50iIiIyLM40CEiIiLP4hod6jW2O1zbrL9QqxInybUpbUbadvLHynnKep+zjsnnNHcTz3j+M3HMwYXfFm0xTfJaybtPuc9T1gRpTiW0i7aWke7XnVZ6Shxz4Ea5PmbcU/L3HGsUcdZ2UEeWTL1P2imvZe6Yrq3H4Q7nRBRuYZ3Reffdd3H11Vdj5MiRiIqKwmuvvdbp8SUlJYiKihI/1dXV4ewmEREReVRYBzotLS2YOHEinnzyyW6dV15ejqqqqsDPiBF2/9olIiIi+lthDV3Nnj0bs2fP7vZ5I0aMQGJiYug7RH3KtppxtJFefLp8v7yYluKshK5akwYbj+02vPQ1yhBRdf4Z1+PmUTJMpYWWtH9PxFbUuRty5GBeC5/hmLxW/Xj3c2op7u1xMpylhaXqsrv+t4+Wqq6FDcXv/qbp4hgt5KW+3wbbkBfT0IkoIhcjT5o0Cenp6bj88svx3nvvdXpsa2srGhsbXT9EREREQIQNdNLT07FmzRr8/ve/x+9//3tkZGQgLy8PO3fu7PCcoqIiJCQkBH4yMjJ6scdEREQUyaIcx3F65YmiorB+/XoUFhZ267wZM2Zg9OjR+I//+A/1/7e2tqK1tTXwuLGxERkZGWhoaEB8fLzVc1w+aE63+kRdi1Y2xdRCEm0FF4s2M+ShhVO08IlaedcIS9UqlZG1CsdaNWMze2rYURmmav9BrWg7/YcU0dYy0v1Yq4x8+GoZzvIpGVwnprrDUsP2ytCVxuZa5/17m9W1NFWXuH/X5kamgL6ZaWxZhWhTKy8btPuLG4QSRY5N7es6/f+NjY1ISEjo1t9vGxGfXj516lRs3bq1w//v9/vh9/t7sUdERETUX0RU6Eqza9cupKen93U3iIiIqB8K64xOc3Mz9u//63RyRUUFdu3aheHDh2P06NFYtmwZjh49ihdeeAEA8PjjjyMrKwvf/e538dVXX+HXv/413n77bfzP//xPOLtJREREHhXWgc6HH36ImTNnBh4vXrwYADB37lysXbsWVVVVqKysDPz/trY2LFmyBEePHkVsbCwuuOACbN682XUN6h9s10ugXq4BMVPAbdevaOtq6rLd60SS9sl1NW1xcmJTS+02U879dWfkMQ/EyhMh1xPFHXJXDd57r1yrMv4R+bpr8uTrzhRhb5lKrq17qX1ePuewt831RPL9aVh+UrSdFOfJNTnaa0zbPFi0xVqsx1HX9pxQ7i9tJ3SlzSalnYj6p7AOdPLy8tDZWue1a9e6Ht9777249957w9klIiIiGkAifo0OERERUbAiPuuK+qcmyyq42uaccds+dj2uU66lhak0ZqhKS1XXNsHUqgu3xXf97wJfvWzTXqO54aUMP+nVhrVNSc20eu33XHOdrOLc3Nwsn/Mc9+/L7CcAJN+jVEbOE0048KA7jJe2XoaptPCfFp5LLXGHwdQUdI1WQZuIBhTO6BAREZFncaBDREREnsWBDhEREXkW1+hQ902fKJrMtRzajuBaSrC2FsZn7KrdPMou/dvcVkHrR/Moue6ldoLsQ/pWuQ5FnKdsJ6G1pb56VLS1Gq9RW1ej0X6H5mu0XdsTd0hWFPfVd71lBqD04ZBc59S2Y5j7sVLR3XbN1MGb3et2xryodEuhpqGbO8dDfhFymwgi7+CMDhEREXkWBzpERETkWQxdUae0XchxQIZitGqzJi1VWQs3JX9shr3kruSa+vFdP6e2g7YZFtHO05gpz0AHFXaV36FNuEkLG2khqGD5lKrU0cZ7G5eYJftlGerzNbmP095rLUx1aI48bthe2WYqv0tWRh67aJtoO62EXhFkZWTujk4U+TijQ0RERJ7FgQ4RERF5FkNX1CktpFK
lbKiZ8fxnrsdaJV4t7HLWMRkGqZx9lnGM7FebUhl53FMylGTSXs/oN5VwkFGdGQBq78x1PfbVy2tpoQw188es7KuE/pKfln3QmOFF7TVq1Zm1qsEnp7hDVVoF4qZM+d5WXSKf0wxnmRuZAnqYSqsS3ZTpfqxVTx77snyNamhJXl4cZxt+YpiKKPJxRoeIiIg8iwMdIiIi8iwOdIiIiMizuEZnABNrOyxTbIcdbRdt2poc04EbZZp4xkalom6ce92ObUr1YW3tkJJObjJTqgEAytqO5KffdzdoacoKrRKvubZDS8+3Tl0219po76N2rbGjRJvNruBaFWTtvLp89/2l7VSur8eR67ZMNru4A0DzJcpasd2y/+YXofa713CNDlHk44wOEREReRYHOkRERORZDF0NEOpUvBHyUFOjLSoEA4CvUYazTOMekiEVM7wB6KExk5Yu7WuSKc5mardWDVgLP2i/CxHqU1LQ1RCRRUjQNmyonmsRPlGP0doswpm+ehny0t5HM1TVmjS4y2MAvXSAGYI88GCsOGbs/fL1NI+S4UyNTVq9VgJA/R0SUUThjA4RERF5Fgc6RERE5Fkc6BAREZFncY1OP2ebBmtFWYOgrauwUZctx9Bt8TLVN/XVz0SbmaqupQ37lbVDIv1boe3Grm0JoLFaa9PP12xYrRVS1ibFoetUe1+9bNN2iTe3E1GvtUOuvam7UK7b0VLJq3O6Tl8fU5HY5TFE1D9wRoeIiIg8iwMdIiIi8iyGrvo561CJUsXXDBtoKbVJO+WltJ2jzfRyZ3yzPGafDC2Yab0AkFrSdTVjjU0ISgvNWFcgpo4p4SwtTGhDq9hshi9PniNLEMT+n1rR9qVy/dFX/lm0tRVc7O6D7Q7wRBTxOKNDREREnsWBDhEREXkWQ1cDhLpxpREiKP8nGWrI2Cgr1zZfJsNSw9e7w1Knj8kwlVYFV2VuUqlkg6mhhSBDUMGGqbwY8jJfU49ej/k+arQsLyXMmrTZfZyWwXfySIpVt07fKY+rH+8OhcU0yPsrrVRmiMWe6Drrsb/fE0T9HWd0iIiIyLM40CEiIiLP4kCHiIiIPItrdPo5dcdxJWVb7tktxR7Rxr1yXU36s37R1ureJBwxDXbX0iov+8w1DZa7i/f2Wggvrr0I5WsKeu2Ttp7MYJYzAIDkj2UyedUlcq2NJnGv+36NOyQrKsdW1MkTtR3NTR68T4j6E87oEBERkWdxoENERESeFdbQ1bvvvotHH30UZWVlqKqqwvr161FYWNjpOSUlJVi8eDH27NmDjIwM/PznP8e8efPC2c2IEGwoRgtTaRWOtWqzpmFHZRhJCy01j5LjY3PzxDEvKn1X0o1jtdRxM71YqbrrxbBRfxHsvapuqKqkl6vXMu4JrVRB7UQZpvI1yUu1xcm2U6JNbvzZlCkrgmub0or+K+ny2j1NROER1hmdlpYWTJw4EU8++aTV8RUVFbjqqqswc+ZM7Nq1C4sWLcIPf/hDvPXWW+HsJhEREXlUWGd0Zs+ejdmzZ1sfv2bNGmRlZeGxxx4DAIwfPx5bt27F6tWrMWvWrHB1k4iIiDwqorKuSktLkZ+f72qbNWsWFi1a1OE5ra2taG1tDTxubGwMV/fCSpuu10IE4hglE0QLU5kbeAJAdY57ej5pn8xkSdqphSRkv0RGihKm0sJsvnolH4zT+hEt2LChFqayZtwTsUoYLFZGbHH4ahluGnZU3udt8e7JbW1jWW0z27p82Y+4l4zfTw/uZ5vvAIZxiToXUYuRq6urkZqa6mpLTU1FY2MjvvxS24cYKCoqQkJCQuAnIyOjN7pKRERE/UBEDXSCsWzZMjQ0NAR+Dh8+3NddIiIioggRUaGrtLQ01NTUuNpqamoQHx+PoUOHquf4/X74/bKAHREREVFEDXRycnLwxhtvuNo2bdqEnJycPupReNim51qt21HWwlRZrkvQ1uSYyu+SfR37sgwjmruJN10oz9NSgrUquKe77BU
NNGJXdWW9j5a+nrFBrrWpU+5Nk3k/A8CJqbJactpmWX6hreBi12Ot3IP6nBZrbbgeh6j7whq6am5uxq5du7Br1y4AX6eP79q1C5WVlQC+DjvdeuutgeN/9KMf4eDBg7j33nuxb98+/OpXv8Lvfvc73HPPPeHsJhEREXlUWAc6H374ISZPnozJkycDABYvXozJkydj+fLlAICqqqrAoAcAsrKy8Mc//hGbNm3CxIkT8dhjj+HXv/41U8uJiIgoKFGO4zh93YlQamxsREJCAhoaGhAfH291zuWD5oS5V6EjpvC18JYyha9N15sVjrUqsskft9j1ywxBaRWPLdOLbV4jkckMGQF62Egrc3C4wB2Cih55Uhzj2zFMtimfGZGaroSXeU/TQLSpfV2n/z+Yv982+n3WFREREVFHONAhIiIiz+JAh4iIiDwrotLLyc2m/Lu2M3Jdlqw5dHyaPDVxr/txy0h5zP9zxx7R9tnN3+q6X8q6BFtcv0DBsE3jFtuVAMjYmGQcI9fo7L1X1uvKVJYcmOvhkjbXi2Nsd3Inop7jjA4RERF5Fgc6RERE5FkMXfUBq5BUB8zdkrXdxdsmyqqu0SObRVs9Yo1j5HT9//fMVNGWeuIz0WYVblLCbNypnELFNuSpfulluUNX2k7lmetkZeTaCTGiLXm3+zhth/OkzUplZ8uK6UTUPZzRISIiIs/iQIeIiIg8i6GrPmBbzVhjboypbT5YP15u1hmzV1Z1zSh1T7FrmSZ1F8aKNqv+a1lX3MCTQqQnYR7tuNiKRNdjX738XDUp2Yyn4uT1zXBW+qr35UFa1pXyWTa/oBnKIuo+zugQERGRZ3GgQ0RERJ7FgQ4RERF5FtfohJntWgI1Pl8ndw731be5Hv95vk8ck7nujGjz1X8p2toS5bkmc00Q0MFrYlVX6kU9Watisx5O+2xU58vPwrC9Xf9bsXJlrmhLK5Wp6raVnYlCxfwu9+oaMM7oEBERkWdxoENERESexdBVmNmmkmtp1ubmgADQPMo9Nh3/yHFxzOGrZVXX5vHy+uMfkZsbmnwbPxBtNinhrPJK4WR7f6lVyJXSB2b14riXtoljMnGxaGtNkqUcbLQmDRZtspCDfE38XFEoDZR7hzM6RERE5Fkc6BAREZFnMXQVKbTNLbOmd3na3nuTRNvwHfK4tM1yqtzM9IpWKhcHa6BMiVLf6EkVZC10bG6Oe1rZgNbMeASA2AqZGVl+lzu8FHtE/nsyfas8T9/80/2YnysKJZusKy+ESzmjQ0RERJ7FgQ4RERF5Fgc6RERE5FlcoxNiajqrSUlvbSuQqatxFbKaMWDuoBwjjmhTdlRui5Nj2qSdcp2AjWBjtl6I9YaL1X0D/r5CwaaKt/Z51EotHFSqHsceMc5rktevnSgroaeWyFIRJ6dkua9dJq/Fe4Js2Hz/2n4P9bfvcs7oEBERkWdxoENERESexdBViFlN31mmrjZlmWEqucmmHt4KTo9Sdi3SFCN5arO3WU0RD0+U5ymHsXpu99j8frQNNk8q4ayzjsnrJ3/sDgnHPlItjtn3zlh5/SxZKsKsoByr3BPge0shYl3JXwn/RvL3Dmd0iIiIyLM40CEiIiLP4kCHiIiIPItrdPqCst2Dls7aFi/HoaItU6aXa5Kffl82WqYS2oiUWGwkUmPXY0e5j9G231DKENhcn+9F54L9/WjrdmIrEuWBxvu2751vi0O0tT1Nymc57tAp12Nz2xbyNpu1fLbbNtis+dO3Iem6HENH/YgUnNEhIiIiz+JAh4iIiDyLoatwU1LJo+tkReKqCXLaOnn3KdGmpaGb2hJ9Vv04bYTQIjk9MBLZ/r7UqWWL66u/e+V9NEOhfB97zvb3VXudDEulvlrvehyjVEb2NbaLtuZR8t+d5udd++443UUfA+caacI2FaKpbwVbcV4LU2mhcDOEHvfSNnmM9jdM6Uckf8dwRoeIiIg8iwMdIiIi8qxeCV09+eSTePTRR1FdXY2
JEyfiiSeewNSpU9Vj165di9tuu83V5vf78dVXX/VGV0NOCyNpbRkblA39lEqpZrVks1IyAMRW1Ik2LVsjlNk6zPz5X5ahyroL3b8vrcK1+uHUsrP4u++xYLNbzKwoQGauDDsqw1Qam80/U19V3n9LDFV5g9W9alm52Px7klQns65OakshLL9jbDcJDbewz+i88sorWLx4MVasWIGdO3di4sSJmDVrFo4fl3/YvxEfH4+qqqrAz6FDh8LdTSIiIvKgsA90fvGLX2D+/Pm47bbb8J3vfAdr1qxBbGwsnn322Q7PiYqKQlpaWuAnNTW1w2NbW1vR2Njo+iEiIiICwjzQaWtrQ1lZGfLz8//6hIMGIT8/H6WlpR2e19zcjMzMTGRkZOCaa67Bnj17Ojy2qKgICQkJgZ+MjIyQvgYiIiLqv8K6Rucvf/kLzpw5I2ZkUlNTsW/fPvWccePG4dlnn8UFF1yAhoYGrFq1Crm5udizZw/OOecccfyyZcuwePHiwOPGxsZeG+zYxB+jlWqq5f8k46Ajtstr1WXLcWhaqXtNgFatVa2AqaztCOVaDq4L+V9K1Wst/TfJSPU00zwBoEZJXTZ3xgbke8v08u4LabVk47H2eR/3kFxDob3fqa9+1mUf+H4PLMG+t1rVYxvq3xjLtTeRch9GXB2dnJwc5OTkBB7n5uZi/PjxePrpp/Hggw+K4/1+P/x+f292kYiIiPqJsIauzj77bAwePBg1NTWu9pqaGqSlpVldIyYmBpMnT8b+/cwYICIiou4J64yOz+fDlClTUFxcjMLCQgBAe3s7iouLsXDhQqtrnDlzBp9++imuvPLKMPa0d417Sk7nHbx5hGhL2ifTUs2N/5oylenuko4z2v4WU8J7xnYzPS0sBSPl3EwjBjp4H202+tSqovK9DQ+LCrQZG2UJiIML5edWq6BsE25I2mlXeZufb2+yrYyslbAw08vVTWO1NiVEH8nCHrpavHgx5s6di4suughTp07F448/jpaWlkCtnFtvvRWjRo1CUVERAOCBBx7A9OnTce6556K+vh6PPvooDh06hB/+8Ifh7ioRERF5TNgHOjfccAO++OILLF++HNXV1Zg0aRI2btwYWKBcWVmJQYP+GkGrq6vD/PnzUV1djaSkJEyZMgXvv/8+vvOd74S7q0REROQxvbIYeeHChR2GqkpKSlyPV69ejdWrV/dCr4iIiMjrIi7rKhL0pGy1iINrWwIoqd4xTXKNzolrT4q2sfe727RtIrS2WG3nWsbse8Q2Nq5tAVGT536/ta0E1G07lH6Y77eaDkrhoXyuzHU12hqamAny865tFZG02Z2EcXJKlny+C+V9GPcSkze8INitSWotS1No63ZMlbPl99CYA/2rpAE39SQiIiLP4kCHiIiIPIuhK4Vt2rDGPO6gNu33YqJoO3mOnLaO3TFM9i0pyvXYV98mO6Gl/kXILrKeZxki9BnhhtoJMeKYU3Gy7axj8n4yp6QjeQp5IDBDVWp4Wfm8q//utEgvtwk/UP9k81nW/jZpoXCNWdbC1yjvy7OOBdevjvrWFzijQ0RERJ7FgQ4RERF5FkNXloINByTtk1PP2lR2TIPdmNOsZKlOWyuZXqf7WSXLvma1UaKSYXW6XGa7RI+T94D5vjWP0kKcsjKydu+YWXzaJqIUHmqY27gvtPDy2JfltczPNgC0xXf9vaBW1a6T95x2b1L/YxMO0u45LXuq9Rx3iGvYXhku9ykVu7XvNE2k3HOc0SEiIiLP4kCHiIiIPIsDHSIiIvIsrtGxZLVmA0BbwcWux9X5cufitM2DRZu2c7HGrJRqu06EusdmTZbt71k7ruonuV2eZ1ZPBux3pqc+ZJQYUHcSV6peN4+S/+4009AT98pjtIq3WpkD8zusJ2UIQnkt6h7zd62tl9Gqscc0yXtu9L+71/K0Jcrns620Hsn3AGd0iIiIyLM40CEiIiLPYujKkm215NpMd3resL3yWs2jZFvybln
JUksRNDf14waOfcd2U08tJdxM2dTCD1WXyKlmLeTRZFRZTtosuxXJ08qeY9wD6vuvfLZPxclLmWUnWkbKY5ItK0cEW2VXO4/3UwSx2FgWADI2yLC39n1i4/RY+UdMDdFGyH3CGR0iIiLyLA50iIiIyLMYulLYTt9qYQpzUzRfYw/6oayct+kDImS60OvUe0LbwLNcHpZqZErUXSjvOS2c2ZboE21mJp4Z3gQA30a70GukTDVHIuuwjpFlF3vCLsSZVirf26p/aHU9jtorN/rVvifU0ILxnFo2IN//yGK1KaZyL/nrZLavFkKNrahzPdaqc0P5PulvSyY4o0NERESexYEOEREReRYHOkRERORZXKOjsE0l1yTtdJ+rrb3QmLFSQE/9EztVM6Ye8bTKpeZ7e+Lak+IY3w65HmPY0XbRZsbQtfi5tqM5753uCfr3pa2jU1KCW5XvivRn/e5jkuT7r33HmN9DHT0nedPhAll9X6uqXZ3jrr5u7mYO6JX8Yy3v6UjBGR0iIiLyLA50iIiIyLMYurJlWfHWFFfxpWirnK1t6Cc3cNQqWZoVL+NeYvihPzJDkOnPyhTOpky7a5lVdrXUYpYc6DtaGrcWCm+LD+7fncenyTZ/nZJKbDbwnoh46jIKMxSuhIwyNmp/m2TK+Ymp7sfjH5FLKNSQlLaZdATfT5zRISIiIs/iQIeIiIg8iwMdIiIi8iyu0YGMl6uxRst0UEmWdT/rmDwq7pBM61PTRo1y/1raMPUO25IDNjHu2gkx4hBzh/OOmGXbtfLv8i6kcLFZQ2GzfQwg1+2klsh1e9r6vtiKrr/DuAVI/2SWprD9I26u5QOAzHXubwa9HIpsi3tpm2iz/j7sA5zRISIiIs/iQIeIiIg8i6GrHtBCBK1Jgzt93BGtMnJsWX1Q/aLIooUDmvJltWQbddny3yZppe6wp7qzsFadWUl7NjG80X3m79V2Sl/7Pmke5X6/tdCCVi1bOy6uwgh5GCUOqJ/Y9rH7sXJ/iVICHTBLnYx+s0Uco94nlt8LkRLO4owOEREReRYHOkRERORZDF3BbipeOya2IlFpcz8+eLPMiIhRsmn0jfm67Barm/Yh66na6RNFU3W+O0yRuU7JlFKyJJI/Fk2oneiefvZtVPolT7MKSzFM1XPqfaJkXWkhx9Ykd8hRq4Ice0T+e1Xd/NUCQ5WRT7xHWpViZUNorSJ/0j73fRJdJ0NXtpvSaiLl3uGMDhEREXlWrwx0nnzySXzrW9/CkCFDMG3aNOzYsaPT49etW4fs7GwMGTIE559/Pt54443e6CYRERF5TNgHOq+88goWL16MFStWYOfOnZg4cSJmzZqF48dl4SsAeP/993HTTTfh9ttvx0cffYTCwkIUFhZi9+7d4e4qEREReUyU4zhOOJ9g2rRpuPjii/Fv//ZvAID29nZkZGTgxz/+MZYuXSqOv+GGG9DS0oLXX3890DZ9+nRMmjQJa9asEce3traitbU18LixsREZGRloaGhAfHy8VR8vHzSny2Ns0+Rqrvt2l8do1U1tq6eax0VKDJQ6pt075f/kXnsxYrs878S1J0XbsOJhos2sqJu0U94TNqnkGq7Z6Dnb32FbwcWizVynpa2zOJUg1+OMe8pi3WGQ9wT1HvXeGTuqy/PMdXsdaRnpfmyWqgA6KFehsPle2NS+rtP/39jYiISEhG79/bYR1hmdtrY2lJWVIT8//69POGgQ8vPzUVpaqp5TWlrqOh4AZs2a1eHxRUVFSEhICPxkZGSE7gUQERFRvxbWgc5f/vIXnDlzBqmpqa721NRUVFdXq+dUV1d36/hly5ahoaEh8HP48OHQdJ6IiIj6vX6fXu73++H3+3t0DatNPS1T+FJf/cz1+OBCGcqqyZMp53GHkkSbVi2ZYYP+R3vPxj3kfnxySpY4Zvh6Wd80rkKmf1ZdcpZxjLwvxUaTgBouZXp579BCErKYgAxBaKUpnPEyxKl9x5ghc4YlI0uwVYT
VlHDYha7M9PJQhqkiSVhndM4++2wMHjwYNTU1rvaamhqkpaWp56SlpXXreCIiIqKOhHWg4/P5MGXKFBQXFwfa2tvbUVxcjJycHPWcnJwc1/EAsGnTpg6PJyIiIupI2ENXixcvxty5c3HRRRdh6tSpePzxx9HS0oLbbrsNAHDrrbdi1KhRKCoqAgDcfffdmDFjBh577DFcddVVePnll/Hhhx/imWeeCVsfrabhlGn+A3fJqcb2OHcIwn9EXqp+vMySaIuLEW2+jcyK8AJ1StoIhdZOUN5/JUxh9XzKVLYWZgWzbsIi2BCEFjbw1bszbLSsK9sQp1l9XQtxstJ637ENB4k/2sqyirY4eZ72fWJuGutXQuhqOEup9i42G40gYR/o3HDDDfjiiy+wfPlyVFdXY9KkSdi4cWNgwXFlZSUGDfrrLzs3Nxcvvvgifv7zn+O+++7Deeedh9deew0TJkwId1eJiIjIY3plMfLChQuxcOFC9f+VlJSItjlz5mDOnK5r2xARERF1hntdERERkWf1+/TycNBi7Fr677iH5BoHM51crTSppY0rayhsYv39Lc1vINLeo1qjgnb6VrmmQqtueuDGoaItY6O8x0zRB47KRiXlnNVye87mM2mzbgsAmrLc7/foN7suLwAAzaNkW8YGoyK7Vl5A9or6obapzaJtmLKWy5zraE0aLI7wKfezOnCI4HIFnNEhIiIiz+JAh4iIiDyLoStLWoqdFs4yK00eLpBTgYl7u65a2pFImQoke1qYIvljdwiiLdEnjtFSRLXNP83U9FituKnFBrHUe7TPcbTyHmkbtJrSt8o2M+QFwGpDYFZL7kNKyrYacjbvE+VzPPZ+edree+UOApnrzrge+zZ+0FkPA2zvnUjBGR0iIiLyLA50iIiIyLM40CEiIiLP4hodS9p6HM3xae7H456SsczDV8s1OmZ5dgBI2lxv9ZzUD5nl0gsuFodoJdur88+ItrTN7scns5LEMdoaM5s4O9dn9FxP1r2YZSe0kgPa+r5mJeVc3hXK8/H97jNW63Eg/35kbJCnaX9P/EfkvEZs2WfuBqXkhFqGQLlPIvne4YwOEREReRYHOkRERORZDF1Zsk0vH/vyl67HB2+WYSrbasnatKX5hkXydCF9zSZE1KRVJG2Uu9xnrtOeQYazBIt7CZB9Zbpxz/Xk92VWPdYqaGthCq0yLjYkuh/zfew1NiFh9T5R2tKNcOYf3/m9OGbmP/xQtJ11TJnX0MpOmP0aO8qqX5GMMzpERETkWRzoEBERkWcxdGVL23AvM0a01RuVkGOPyEv56ttko1alVnlOhg28SauAq2VPafecWUE5eXfXm3wCdlPSvN96znYDT+07wMyo0cJUmvRnZRXcugvdmzomWWbTUM/ZhIRtwkgAcNKoon7RyrvEMb4kGfaOvuYL0VYD99KK1Fc/E8fAAxv9ckaHiIiIPIsDHSIiIvIsDnSIiIjIs7hGR6PESrX1Elr6r1kJuSZPppdrOwsnHZDdOO2B2CjpxFoIZb2MtpbLp6Sh1493P07fKs/T7iU1ddzmGK7j6BbbtGFt92pzV3t/nSwloK3b0iRt7vr7JFqpjMvvodBT7wnlOyC6TpYTML8XkpXviQM3yr8xsW+niLbkQ+71fFrJFN/G/v9554wOEREReRYHOkRERORZDF0ptBTOuIovlSN9sslI2Yw7ZJci7IXqk2RPVEpVjtE2cNRkbHRPP2uh0TjIsAiUTQS5qWfo2Yb/tDCF2aaFwltGyudsi5P/hm277tuux1oqMcNUvUNNL7fc1NMMS43YLk/TNpPW7h2t4r/ggfA1Z3SIiIjIszjQISIiIs9i6ApyGjHupW1dHgPov7yDC93Tw0n7ZGaWWckWkNkVgBoYI48wp37blGyHuEOywnGrknVlbgjbmiTvVXNzSADIqEvsqptWG39S53ry+zptbOCo3ROAXdZVaslxd0OQG71SeGi/51oj3AgAI7a7/6Y0j5LzFXX
ZMkyl/S2qMcOZ5j3SkX52T3BGh4iIiDyLAx0iIiLyLA50iIiIyLO4Rgd2MWgt/VtL441pcj/W4qfpW2UaKbZ93GUfyDvMNV/RxjobQK7P+JpcuVW7Osr1uOEDec+Z9yWgV/s21/uoadAeSDfta2p6sbKbeJNR6kKrjKyt+cvYINdamO+3+V4DfL97jbI+Slunl/yx/Fthruesy5afd209jk1lbI0X3mvO6BAREZFncaBDREREnsXQVQ9oYalhR91ThlpF5WitAqY2la1Mb5rT216YVhyIzPctWnuvFWol01+6p7xPZ8pDtLRkNXRhhMsYtggP2xCRKHWhbPyZ8bxFdVsAPuO91SrlJrMycliIiuPK7zn2hHz/6/LlJqvm3x0tLK2FOLUNO20qI3vhO4AzOkRERORZHOgQERGRZ3GgQ0RERJ4V1jU6J06cwI9//GNs2LABgwYNwvXXX49//dd/xbBhwzo8Jy8vD1u2bHG13XnnnVizZk04u+qixiSV9O/kxIu7vJZWej85UcZKfRs/kCf3szgo2Yse5469azF77T7U4uyHC9zbQsQekc+nrsfRnnOcXBNAoaeml2tbMhht2i73Gu0+acp0bxWh7l6uXKu/rceIRGJNnvI508pJaGttzHRyLZVcXXuj3F+iXx5Yj6MJ60DnlltuQVVVFTZt2oRTp07htttuwx133IEXX3yx0/Pmz5+PBx54IPA4NjY2nN0kIiIijwrbQGfv3r3YuHEjPvjgA1x00UUAgCeeeAJXXnklVq1ahZEjR3Z4bmxsLNLS0sLVNSIiIhogwjbQKS0tRWJiYmCQAwD5+fkYNGgQtm/fjmuvvbbDc3/729/iP//zP5GWloarr74a999/f4ezOq2trWhtbQ08bmxsDN2L6IIWDjDTv331sqJydJ1SGZnp5QOKFjYSlPdfm5LOgDtMcbhAXkpNJdYqL5ulD7R7kPdcWFiFL9Vq2ZL23RRbVh9MtygExPuoVMFuUyqVtyYNFm1nHXM/1sJbWiV/7e+OSHv36Gc7bAOd6upqjBjh/nKNjo7G8OHDUV1d3eF5N998MzIzMzFy5Eh88skn+NnPfoby8nK8+uqr6vFFRUVYuXJlSPtORERE3tDtgc7SpUvx8MMPd3rM3r17g+7QHXfcEfjv888/H+np6bjssstw4MABjB07Vhy/bNkyLF68OPC4sbERGRkZQT8/EREReUe3BzpLlizBvHnzOj1mzJgxSEtLw/Hj7o3lTp8+jRMnTnRr/c20adMAAPv371cHOn6/H36/3/p6wVJXo1tkrWjThVZhC4AhgoFOmd5WQ0mGjI1yKrsp066ShJmtY1M5lbrPtjKyWRk3aWcPvhPMe0e7vyg8zOw55W9AU+a3rS5lbuJa9Q+t4pjh6+VSD2XvV0QPkHug2wOdlJQUpKSkdHlcTk4O6uvrUVZWhilTpgAA3n77bbS3twcGLzZ27doFAEhPT+9uV4mIiGiAC1vBwPHjx6OgoADz58/Hjh078N5772HhwoW48cYbAxlXR48eRXZ2Nnbs2AEAOHDgAB588EGUlZXh888/x3//93/j1ltvxaWXXooLLrggXF0lIiIijwprZeTf/va3yM7OxmWXXYYrr7wSl1xyCZ555pnA/z916hTKy8tx8uRJAIDP58PmzZtxxRVXIDs7G0uWLMH111+PDRs2hLObRERE5FFRjuM4fd2JUGpsbERCQgIaGhoQHx9vdc7lg+aEuVduavVZJVbq1VQ/smN7n9Rc13Vs39coq6eaFVYBYPSb7jVlPVljNlBSV/ua7X1ig+9Rz9lUFz+ppJKbFc4BYOzLX4q2ytnuEgPa7uXJu0+JNptyKOraMa2Kc5C73G9qX9fp/w/m77cN7nVFREREnsWBDhEREXlWWPe66i9sptjVTfiCpG3eNlDS/KgbLMOZvkYjBXmznFbWwlvaZoChxDBIL1HuE21TT3XjYOoR9e+CRVkIreKxWhYia6hoM0NVGRuOi2O
0vzE24SbbMir9DWd0iIiIyLM40CEiIiLP4kCHiIiIPItrdGC3liCkaXfKtU53fRaRes+1xbv/vaLtXNwyUl6rZaT8d05dtju2f9YxGevXdj2PNnc9V3DNTufU9RFBfjf5Nna9zpDvR89pv8Omm6aLNm2HcRvNo+RnVEsdN6llISyez6v3BGd0iIiIyLM40CEiIiLPYuiKqJ9Lfvp91+Oqn+SKY846ZnctLcRlakv0ibZobVd1lkzolmDDBqy03jtsQ4taeQcZTpafodoJMaJNSx03Hb56hDzv+c+6PG8g4YwOEREReRYHOkRERORZDF31gBcqRlLkUkMNSltbwcWuxz5lkz9tU09NW5z73z5xh2SGR1OmnGIH5CaFPptKrCHcMHAg8Grl2t4Uyuw2zeG5sgq5GUpqeF5+XjLukftrH3gwVrSlP+t3P94qM6zMSswA1O+OgYIzOkRERORZHOgQERGRZ3GgQ0RERJ7FNTpEEcp2LUFsWYXrce0EuUbA1yivr1Vdjf0/X7geHz4nWRyTsVGu2/HVt4k2Ne3ZwPUl3RPuFPFg16/0J7avx7x/tR3BtQrEWmp3zXXuz2T9Z3LN3Ok8+Xn07dB65v78aVXJvfae9RRndIiIiMizONAhIiIiz2LoiihCBTv9rE2d1+V3HUYCgJNvp7geJyqp6q0yMxaxFUqKKysj9zsMefyVGVa1LoUwfaJoMiuOt8fJTT6TP/5StFVdIsNlZphY28TX9g+7zfvthXAmZ3SIiIjIszjQISIiIs9i6IrIY7Rp5bgKOb1dly2nxWOMUJW+yaf895GWkVKb595sMPljGd4K5RQ7UbC08IxZXVgLU2nhrIOz5WdhzIvuzTlr8uRGnE1Zg0WbuqmnERLWPhun5VlB88JnjzM6RERE5Fkc6BAREZFncaBDREREnsU1OkT9nBlDV9NglfNGr3hftB14fLrrsZYGO/pNWQVZS4MddrTrHdPV1Fil2qz5ReWFdQMUeuZam57cJ+a6s6qf5IpjtDU05jo3QK7JaYuTx8Qdkp81tV/m590D6d/hxhkdIiIi8iwOdIiIiMizGLoi8hqtIrGS/q1NxY992R02akv0iWOasoaKNi1MlbTTPX2upaBr10riJoUUpGDDOlpb1Vz3RpzpW2VIte5CeX2fEroySyton6vaCTGiLbZCNAla3xnOcuOMDhEREXkWBzpERETkWRzoEBERkWdxjQ6Rx6ixeKUtfZs8zExNj1bWy2BKlmg6NEce1jzKnVKrruPZrOz+bMF6J+kQ4rqHyGfeF9o90VZwsdW1Mp7/zH0tpRSCrdqJ7vVpWnq5lqqubjuhbVdhnsf70oUzOkRERORZYRvoPPTQQ8jNzUVsbCwSExOtznEcB8uXL0d6ejqGDh2K/Px8/PnPfw5XF4mIiMjjwha6amtrw5w5c5CTk4Pf/OY3Vuc88sgj+OUvf4nnn38eWVlZuP/++zFr1iz86U9/wpAhQ8LVVSLPs9mdWXPYSLEFgOTdp0SbL65VtDWPdz9OX/WBOKbmTpniru5ybvZVSaEPd2iJ4YDIZxO+9G2U96F275w0QrRNmTL9O+OWA6Kt/v+O7rIPGq38AsOloRG2gc7KlSsBAGvXrrU63nEcPP744/j5z3+Oa665BgDwwgsvIDU1Fa+99hpuvPHGcHWViIiIPCpi1uhUVFSguroa+fn5gbaEhARMmzYNpaWlHZ7X2tqKxsZG1w8REREREEFZV9XV1QCA1NRUV3tqamrg/2mKiooCs0dEpFOrp1qcl75VTqebGSQAkPkLudFnU9Zg9+ObpotjtOwTlRGqqsuXWVdmJWYAarYZRTabrCIAaujVDP9oWYPavRNX8aVoMysVnzxHZg1mKN1qTRqstLq1xcs5htQyizLICO3GpQNFt2Z0li5diqioqE5/9u3bF66+qpYtW4aGhobAz+HDh3v1+YmIiChydWtGZ8mSJZg3b16nx4wZMyaojqSlpQEAampqkJ6eHmi
vqanBpEmTOjzP7/fD7/cH9ZxERETkbd0a6KSkpCAlJSUsHcnKykJaWhqKi4sDA5vGxkZs374dd911V1iek4iIiLwtbGt0KisrceLECVRWVuLMmTPYtWsXAODcc8/FsGHDAADZ2dkoKirCtddei6ioKCxatAj//M//jPPOOy+QXj5y5EgUFhaGq5tE9DfMeH+bUgVZW1cTXSdTwo/f6N6ZfOzLch1EnN2yBLGuQl2Po5k+UbZt+9juXAtM/+07NlWDbdfjaMxK3ifPkccc+O+xos0XL487MdVdkiFz3RmrPvBeCo2wDXSWL1+O559/PvB48uTJAIB33nkHeXl5AIDy8nI0NDQEjrn33nvR0tKCO+64A/X19bjkkkuwceNG1tAhIiKioIRtoLN27doua+g4juN6HBUVhQceeAAPPPBAuLpFREREA0jEpJcTUe9SU87NzTKVKrLJkJsialVdxz3lvv7hq0eIY5rHyyrL5/27TFX31xlT/UplZLXScwjDVBqGFrqnt0N92qax2uacbYm+Lq/VHifDTRnPy+tr1cTHP1LnenwyK0n2y/L3wPTy7ouYgoFEREREocaBDhEREXkWQ1dEA5RagdYICYlQFoBopYLr3n+R2VnDd7jDWVpl2cx1sgtaGEFuqCifLzbIyrKAnP7XXrfNhpH0V7a/Q6tQjJY9p4Sg6rLcmX5ahlWTcQygVyo2swvTNssuHFwow1Sj35QZiGZoV9tY1BZDVd3HGR0iIiLyLA50iIiIyLM40CEiIiLP4hodogHKJtYfre0QrayNMNNnAbkuoS1OpqAfLpDPGdMg//0V0+R+rO0Q3apUwfXXybU8qJfp62K9h1LpuS+EMpW4t9OSbdbjaLS1PSeVdVuxFfKe8xvHaetxRKmCDtrMe0xb72Nb2dvcRf10mCt2kxtndIiIiMizONAhIiIiz2LoiogCRGhBqUCsfmkoxzVd6L6WuUkiAKRvleGAytkyxJW+1R1K0lLQayeYKej65p9aFWdzU1LrVPIQhiBCmdJuVYFY6bsZYlHPs3w+jRb2NNlUKQb06sJmGQJfo7zntPtEq9DtP+KeB0jaqaSN274/5u+HYapexRkdIiIi8iwOdIiIiMizONAhIiIiz+IaHSIKsEo51nY9V9ZomCm7WjrwwZvljubZMw+Itpp97jRxrWR/xobjsq+KqkvkGp3k3ca6kCy5Q7u2xcRp5foijVtbl6Ks0Qg2HVu9vrLWRlwryPU4ah8st23Q0r1NtjuO106U72Pqq5+5HtcpJQd8TaIJaZtluQJ/nXvdjrYmyJcUunVOFD6c0SEiIiLP4kCHiIiIPIuhKyIKC7FDsxKGGf2mDD/sw1h5seyun88m3RjQ09y1Ssum2rlyp2rtWkl1ia7HaigjyNRuLQVdc3JK17u7a+EU2zRxmBWzlfICWjVjrbqwGYLSwk112V1Xy9bO1SoeH58mz4s9Iq9v9tUsQQBAfd0MU0UezugQERGRZ3GgQ0RERJ7F0BUR9ZjVdL2yQahWBTetVFap9RkbcWrnmccAQPM/tIq2qL3DRJszvrnLY846JprUyssmLYwkwnoAoISlaq9zh8u0Sr9aOMjq+komkxZu0n6vZphNCzdpmXFQQon1492vSdvUdfSbMmx0aLEj2pp3uN83LXQ1YrvsVlyFvL4aqjLYhv8YzupbnNEhIiIiz+JAh4iIiDyLAx0iIiLyLK7RIaJuCXoNgpKKG6scplYINtaXaGt0tF3Ph6+Xa1qq8+UaoLT17p6YVXEBPQVd2wndRluBrLwsV8IAqSUW1Z6V32vNnbnBXUuhrVUxKxVr63GSP5bnae9by0j3uh1tPY5WUXn4etnXtnjjPGVNkNavYHE9Tv/AGR0iIiLyLA50iIiIyLMYuiKibgk6pVZJL9eo1zJCRFpIImmfDFNpIZW0zV33oTrH7vpqNWMjrHO4QIa8tEq8WpVlwF3tWdtYVPu9Jj/9vjzOCP9pISktRNSUKTdebRn
pfqyl3uubp8qQ4JgX3SE1rcK1RkurN8OXWr/UUJwWglRCguI8ppf3C5zRISIiIs/iQIeIiIg8iwMdIiIi8iyu0SGibgl2DYKWNq7RUq9Naqq0lmZ9ndxxXFvbEftIteux//+OFsdo6eXl/yS3Pkjc6/73Y+wRcYi6Hkfrl7nOpTVJPl/zKPnvVV+eXFdjbh+h7Qiu0dYm1Y93P25rsltzVDtBrn0y1wBp21xoW21oa3nM7UPUkgAW5QsAiLVPtvcv1+NEHs7oEBERkWdxoENERESexdAVEXVLuKfmYyvqZKMRljJTuAEASopw3CGZzqylUB+qTnU99ikhFl+TfMoR27UQlDt9WXs+Le29dqLsvxn+OT5N9iFjY9e7vQMyDKZVINZSwrUQV+Je9+MTU2UfMjbI97F5lAypme+RVjrANuXcvHdilWz808p5ajhLCdFS/xS2GZ2HHnoIubm5iI2NRWJiotU58+bNQ1RUlOunoKAgXF0kIiIijwvbjE5bWxvmzJmDnJwc/OY3v7E+r6CgAM8991zgsd/vD0f3iIiIaAAI20Bn5cqVAIC1a9d26zy/34+0tLQw9IiI+pJtKMAmuyVaq7KsZF1pX3Cxyrn+OndopClTubwSnhm+Q4ZZzI0r/XVnxDE33f+maPvVuqtEW2qJO0yohX5alahOdY4MQZkZSRqtcrHGzGbSfg91F8r3+1Rc19fWMurUysXbPpZtZvaUdk/Y3ofMnvKMiFuMXFJSghEjRmDcuHG46667UFtb2+nxra2taGxsdP0QERERARE20CkoKMALL7yA4uJiPPzww9iyZQtmz56NM2fkv4i+UVRUhISEhMBPRkZGL/aYiIiIIlm3BjpLly4Vi4XNn3379gXdmRtvvBHf//73cf7556OwsBCvv/46PvjgA5SUlHR4zrJly9DQ0BD4OXz4cNDPT0RERN7SrTU6S5Yswbx58zo9ZsyYMT3pj7jW2Wefjf379+Oyyy5Tj/H7/VywTNQP2K550KrUinU7FjtLd9gPbb2HQUtLTy1R0t4Ve+91L5o5799lqvfGGeeJtrQp8jkPX+1ek3PyHJnOfipO/nv1VELXVYnTt8p+aZWENWal4vK75LqXcU/J99tcC6XR1mip1bK1to0fuB5yPQ4B3RzopKSkICUlJVx9EY4cOYLa2lqkp6f32nMSERGRd4RtjU5lZSV27dqFyspKnDlzBrt27cKuXbvQ3NwcOCY7Oxvr168HADQ3N+OnP/0ptm3bhs8//xzFxcW45pprcO6552LWrFnh6iYRERF5WNjSy5cvX47nn38+8Hjy5MkAgHfeeQd5eXkAgPLycjQ0NAAABg8ejE8++QTPP/886uvrMXLkSFxxxRV48MEHGZoiGkBs0su1UINtmCL6wFHZZjw+OSVLnmi5aej4R9zp0errUcJzWkVos7KvViH4y/9XZqYO+o9k0dZsFJPWKjFrG2pq6fFm+C9jY9fHAHrFZmH6RNEUW6aUONZY3AMMUw08YRvorF27tssaOo7jBP576NCheOutt8LVHSIiIhqAIiq9nIiIiCiUuKknEUU8MyylhR96Es4yNwlVQyVKRWUtO8sMcVX9JFcckrHBrvpvdJ17400tKyr5Hke0ncySoSRzI1F1w1NlQ01N7QR3X7WKylqIUPsdCrYZdRbVsW3vCYazvI0zOkRERORZHOgQERGRZ3GgQ0RERJ7FNTpEFPHMNRQ9WWdhrscBIHfC1tb2KGtHYpXrmynn6VtblKMkc9dzAGjKGup6nLRZpqprqfBaqrqv3r2uRnu+1Fc/E23a7yvOMttbUH6HQa+PCfI8rscZeDijQ0RERJ7FgQ4RERF5FkNXRNTvWIeptONCGPLQvkBTS4zzLDYRBTqojFxW3+V5Wkq4WVEZUMJgO+XrOTxXVnrOeF6Gs0zq71mp/myTOt6TsKRNGQIaeDijQ0RERJ7FgQ4RERF5Fgc6RERE5Flco0NEA4rNGhDrtR0W2xBE225
pYEHrV/LTdutjtDU5JnVrCovfhfo7tdiFXju3J+tquCaHNJzRISIiIs/iQIeIiIg8i6ErIhpQQhne0MIzNqEY613VLSpCB90vpeKxtuN4SEN9CoabKNw4o0NERESexYEOEREReRZDV0REIWQTigl3pV+bcFO0kjEWyjBSTyocE4USZ3SIiIjIszjQISIiIs/iQIeIiIg8i2t0iIgiVLBrWqzWCQVZudj6+txxnCIEZ3SIiIjIszjQISIiIs9i6ArApvZ1fd0FIiIiCgPO6BAREZFncaBDREREnsWBDhEREXkWBzpERETkWRzoEBERkWdxoENERESexYEOEREReRYHOkRERORZHOgQERGRZ3GgQ0RERJ4VtoHO559/jttvvx1ZWVkYOnQoxo4dixUrVqCtra3T87766issWLAAycnJGDZsGK6//nrU1NSEq5tERETkYWEb6Ozbtw/t7e14+umnsWfPHqxevRpr1qzBfffd1+l599xzDzZs2IB169Zhy5YtOHbsGK677rpwdZOIiIg8LMpxHKe3nuzRRx/FU089hYMHD6r/v6GhASkpKXjxxRfxd3/3dwC+HjCNHz8epaWlmD59ujintbUVra2trmuMHj0ahw8fRnx8fHheCBEREYVUY2MjMjIyUF9fj4SEhJBdt1d3L29oaMDw4cM7/P9lZWU4deoU8vPzA23Z2dkYPXp0hwOdoqIirFy5UrRnZGSEptNERETUa2pra/vnQGf//v144oknsGrVqg6Pqa6uhs/nQ2Jioqs9NTUV1dXV6jnLli3D4sWLA4/r6+uRmZmJysrKkP6iIt03I+GBNpPF183XPRDwdfN1DwTfRGQ6mxAJRrcHOkuXLsXDDz/c6TF79+5FdnZ24PHRo0dRUFCAOXPmYP78+d3vZSf8fj/8fr9oT0hIGFA3yDfi4+P5ugcQvu6Bha97YBmor3vQoNAuH+72QGfJkiWYN29ep8eMGTMm8N/Hjh3DzJkzkZubi2eeeabT89LS0tDW1ob6+nrXrE5NTQ3S0tK621UiIiIa4Lo90ElJSUFKSorVsUePHsXMmTMxZcoUPPfcc12O0qZMmYKYmBgUFxfj+uuvBwCUl5ejsrISOTk53e0qERERDXBhSy8/evQo8vLyMHr0aKxatQpffPEFqqurXWttjh49iuzsbOzYsQPA1+Gm22+/HYsXL8Y777yDsrIy3HbbbcjJyVEXImv8fj9WrFihhrO8jK+br3sg4Ovm6x4I+LpD+7rDll6+du1a3Hbbber/++YpP//8c2RlZeGdd95BXl4egK8LBi5ZsgQvvfQSWltbMWvWLPzqV79i6IqIiIi6rVfr6BARERH1Ju51RURERJ7FgQ4RERF5Fgc6RERE5Fkc6BAREZFn9fuBzueff47bb78dWVlZGDp0KMaOHYsVK1agra2t0/O++uorLFiwAMnJyRg2bBiuv/561NTU9FKvQ+Ohhx5Cbm4uYmNjxbYZHZk3bx6ioqJcPwUFBeHtaIgF87odx8Hy5cuRnp6OoUOHIj8/H3/+85/D29EQO3HiBG655RbEx8cjMTERt99+O5qbmzs9Jy8vT7zfP/rRj3qpx8F58skn8a1vfQtDhgzBtGnTAuUnOrJu3TpkZ2djyJAhOP/88/HGG2/0Uk9Dqzuve+3ateJ9HTJkSC/2NjTeffddXH311Rg5ciSioqLw2muvdXlOSUkJLrzwQvj9fpx77rlYu3Zt2PsZat193SUlJeL9joqK6nBrpEhUVFSEiy++GHFxcRgxYgQKCwtRXl7e5Xmh+Hz3+4HOvn370N7ejqeffhp79uzB6tWrsWbNGtx3332dnnfPPfdgw4YNWLduHbZs2YJjx47huuuu66Veh0ZbWxvmzJmDu+66q1vnFRQUoKqqKvDz0ksvhamH4RHM637kkUfwy1/+EmvWrMH27dtx1llnYdasWfjqq6/C2NPQuuWWW7Bnzx5s2rQJr7/+Ot59913ccccdXZ43f/581/v
9yCOP9EJvg/PKK69g8eLFWLFiBXbu3ImJEydi1qxZOH78uHr8+++/j5tuugm33347PvroIxQWFqKwsBC7d+/u5Z73THdfN/D19gB/+74eOnSoF3scGi0tLZg4cSKefPJJq+MrKipw1VVXYebMmdi1axcWLVqEH/7wh3jrrbfC3NPQ6u7r/kZ5ebnrPR8xYkSYehh6W7ZswYIFC7Bt2zZs2rQJp06dwhVXXIGWlpYOzwnZ59vxoEceecTJysrq8P/X19c7MTExzrp16wJte/fudQA4paWlvdHFkHruueechIQEq2Pnzp3rXHPNNWHtT2+xfd3t7e1OWlqa8+ijjwba6uvrHb/f77z00kth7GHo/OlPf3IAOB988EGg7c0333SioqKco0ePdnjejBkznLvvvrsXehgaU6dOdRYsWBB4fObMGWfkyJFOUVGRevzf//3fO1dddZWrbdq0ac6dd94Z1n6GWndfd3c+8/0FAGf9+vWdHnPvvfc63/3ud11tN9xwgzNr1qww9iy8bF73O++84wBw6urqeqVPveH48eMOAGfLli0dHhOqz3e/n9HRNDQ0dLr7aVlZGU6dOoX8/PxAW3Z2NkaPHo3S0tLe6GKfKikpwYgRIzBu3DjcddddqK2t7esuhVVFRQWqq6td73dCQgKmTZvWb97v0tJSJCYm4qKLLgq05efnY9CgQdi+fXun5/72t7/F2WefjQkTJmDZsmU4efJkuLsblLa2NpSVlbnep0GDBiE/P7/D96m0tNR1PADMmjWr37yvQHCvGwCam5uRmZmJjIwMXHPNNdizZ09vdLdPeeH97olJkyYhPT0dl19+Od57772+7k6PNDQ0AECnf6tD9X53e6+rSLd//3488cQTWLVqVYfHVFdXw+fzifUdqamp/SrmGYyCggJcd911yMrKwoEDB3Dfffdh9uzZKC0txeDBg/u6e2HxzXuamprqau9P73d1dbWYpo6Ojsbw4cM7fQ0333wzMjMzMXLkSHzyySf42c9+hvLycrz66qvh7nK3/eUvf8GZM2fU92nfvn3qOdXV1f36fQWCe93jxo3Ds88+iwsuuAANDQ1YtWoVcnNzsWfPHpxzzjm90e0+0dH73djYiC+//BJDhw7to56FV3p6OtasWYOLLroIra2t+PWvf428vDxs374dF154YV93r9va29uxaNEifO9738OECRM6PC5Un++IndFZunSpuvjqb3/ML4GjR4+ioKAAc+bMwfz58/uo5z0TzOvujhtvvBHf//73cf7556OwsBCvv/46PvjgA5SUlITuRQQh3K87UoX7dd9xxx2YNWsWzj//fNxyyy144YUXsH79ehw4cCCEr4J6W05ODm699VZMmjQJM2bMwKuvvoqUlBQ8/fTTfd01CoNx48bhzjvvxJQpU5Cbm4tnn30Wubm5WL16dV93LSgLFizA7t278fLLL/fK80XsjM6SJUswb968To8ZM2ZM4L+PHTuGmTNnIjc3F88880yn56WlpaGtrQ319fWuWZ2ampo+31Oru6+7p8aMGYOzzz4b+/fvx2WXXRay63ZXOF/3N+9pTU0N0tPTA+01NTWYNGlSUNcMFdvXnZaWJhamnj59GidOnOjWPTtt2jQAX898jh07ttv9Daezzz4bgwcPFtmPnX0u09LSunV8JArmdZtiYmIwefJk7N+/PxxdjBgdvd/x8fGenc3pyNSpU7F169a+7ka3LVy4MJBM0dXsY6g+3xE70ElJSUFKSorVsUePHsXMmTMxZcoUPPfccxg0qPOJqilTpiAmJgbFxcW4/vrrAXy9mr2yshI5OTk97ntPdOd1h8KRI0dQW1vrGgD0hXC+7qysLKSlpaG4uDgwsGlsbMT27du7nbEWaravOycnB/X19SgrK8OUKVMAAG+//Tba29sDgxcbu3btAoA+f781Pp8PU6ZMQXFxMQoLCwF8PcVdXFyMhQsXqufk5OSguLgYixYtCrRt2rSpzz/H3RHM6za
dOXMGn376Ka688sow9rTv5eTkiPTi/vZ+h8quXbsi8nPcEcdx8OMf/xjr169HSUkJsrKyujwnZJ/vYFZLR5IjR4445557rnPZZZc5R44ccaqqqgI/f3vMuHHjnO3btwfafvSjHzmjR4923n77befDDz90cnJynJycnL54CUE7dOiQ89FHHzkrV650hg0b5nz00UfORx995DQ1NQWOGTdunPPqq686juM4TU1Nzk9+8hOntLTUqaiocDZv3uxceOGFznnnned89dVXffUyuq27r9txHOdf/uVfnMTEROcPf/iD88knnzjXXHONk5WV5Xz55Zd98RKCUlBQ4EyePNnZvn27s3XrVue8885zbrrppsD/N+/z/fv3Ow888IDz4YcfOhUVFc4f/vAHZ8yYMc6ll17aVy+hSy+//LLj9/udtWvXOn/605+cO+64w0lMTHSqq6sdx3GcH/zgB87SpUsDx7/33ntOdHS0s2rVKmfv3r3OihUrnJiYGOfTTz/tq5cQlO6+7pUrVzpvvfWWc+DAAaesrMy58cYbnSFDhjh79uzpq5cQlKampsDnF4Dzi1/8wvnoo4+cQ4cOOY7jOEuXLnV+8IMfBI4/ePCgExsb6/z0pz919u7d6zz55JPO4MGDnY0bN/bVSwhKd1/36tWrnddee83585//7Hz66afO3Xff7QwaNMjZvHlzX72EbrvrrruchIQEp6SkxPV3+uTJk4FjwvX57vcDneeee84BoP58o6KiwgHgvPPOO4G2L7/80vnHf/xHJykpyYmNjXWuvfZa1+CoP5g7d676uv/2dQJwnnvuOcdxHOfkyZPOFVdc4aSkpDgxMTFOZmamM3/+/MCXaX/R3dftOF+nmN9///1Oamqq4/f7ncsuu8wpLy/v/c73QG1trXPTTTc5w4YNc+Lj453bbrvNNbgz7/PKykrn0ksvdYYPH+74/X7n3HPPdX760586DQ0NffQK7DzxxBPO6NGjHZ/P50ydOtXZtm1b4P/NmDHDmTt3ruv43/3ud863v/1tx+fzOd/97nedP/7xj73c49DozutetGhR4NjU1FTnyiuvdHbu3NkHve6Zb9KmzZ9vXuvcuXOdGTNmiHMmTZrk+Hw+Z8yYMa7PeX/R3df98MMPO2PHjnWGDBniDB8+3MnLy3Pefvvtvul8kDr6O/2371+4Pt9R/9sBIiIiIs+J2KwrIiIiop7iQIeIiIg8iwMdIiIi8iwOdIiIiMizONAhIiIiz+JAh4iIiDyLAx0iIiLyLA50iIiIyLM40CEiIiLP4kCHiIiIPIsDHSIiIvKs/x+yPq5UFwV/FwAAAABJRU5ErkJggg==", 73 | "text/plain": [ 74 | "
" 75 | ] 76 | }, 77 | "metadata": {}, 78 | "output_type": "display_data" 79 | } 80 | ], 81 | "source": [ 82 | "n_samples = 10_000\n", 83 | "\n", 84 | "x, _ = datasets.make_moons(n_samples=n_samples, noise=.06)\n", 85 | "\n", 86 | "scaler = preprocessing.StandardScaler()\n", 87 | "x = scaler.fit_transform(x)\n", 88 | "\n", 89 | "plt.hist2d(x[:, 0], x[:, 1], bins=100)\n", 90 | "plt.xlim(-2 ,2)\n", 91 | "plt.ylim(-2, 2)" 92 | ] 93 | }, 94 | { 95 | "attachments": {}, 96 | "cell_type": "markdown", 97 | "metadata": {}, 98 | "source": [ 99 | "## CNFs\n", 100 | "\n", 101 | "_TODO: Harmonize notation with code and add details._\n", 102 | "\n", 103 | "The evolution of the log-density follows the instantaneous change-of-variables formula:\n", 104 | "$$\\frac{\\partial \\log p({z}(t))}{\\partial t}=-\\operatorname{Tr}\\left(\\frac{\\partial f}{\\partial {z}(t)}\\right)$$\n", 105 | "\n", 106 | "We get the total change in log-density by integrating across time:\n", 107 | "$$\\log p_1\\left({z}\\left(t_1\\right)\\right)=\\log p_0\\left({z}\\left(t_0\\right)\\right)-\\int_{t_0}^{t_1} \\operatorname{Tr}\\left(\\frac{\\partial f}{\\partial {z}(t)}\\right) d t$$\n", 108 | "\n", 109 | "We can get an unbiased estimate of the trace of a matrix $A$ (Hutchinson's estimator) by sandwiching it between a random noise vector ${\\epsilon}$ with zero mean and identity covariance, e.g. a standard Gaussian:\n", 110 | "$$\\operatorname{Tr}(A)=E_{p({\\epsilon})}\\left[{\\epsilon}^T A {\\epsilon}\\right]$$\n", 111 | "\n", 112 | "Typically we'd also need to implement backpropagation through the ODE solve ourselves, e.g. with the adjoint method, but Diffrax will take care of this for us here."
113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": 3, 118 | "metadata": {}, 119 | "outputs": [], 120 | "source": [ 121 | "class MLP(nn.Module):\n", 122 | " \"\"\" A simple MLP in Flax.\n", 123 | " \"\"\"\n", 124 | " hidden_dim: int = 32\n", 125 | " out_dim: int = 2\n", 126 | " n_layers: int = 3\n", 127 | "\n", 128 | " @nn.compact\n", 129 | " def __call__(self, x):\n", 130 | " for _ in range(self.n_layers):\n", 131 | " x = nn.Dense(features=self.hidden_dim)(x)\n", 132 | " x = nn.gelu(x)\n", 133 | " x = nn.Dense(features=self.out_dim)(x)\n", 134 | " return x" 135 | ] 136 | }, 137 | { 138 | "cell_type": "markdown", 139 | "metadata": {}, 140 | "source": [ 141 | "## Implementation\n", 142 | "\n", 143 | "Adapted from [Diffrax](https://docs.kidger.site/diffrax/examples/continuous_normalising_flow/)." 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": 4, 149 | "metadata": {}, 150 | "outputs": [], 151 | "source": [ 152 | "def logp_exact(t, y, args):\n", 153 | " \"\"\" Compute trace directly.\n", 154 | " \"\"\"\n", 155 | " y, _ = y\n", 156 | " _, func = args\n", 157 | " t = np.atleast_1d(t)\n", 158 | "\n", 159 | " fn = lambda y: func(np.concatenate([y, t])) # Augmented function\n", 160 | " f, f_vjp = jax.vjp(fn, y) # VJPs can be computed at the ~same cost as computing f through reverse-mode AD\n", 161 | "\n", 162 | " # Compute trace\n", 163 | " (size,) = y.shape\n", 164 | " (dfdy,) = jax.vmap(f_vjp)(np.eye(size))\n", 165 | " logp = np.trace(dfdy)\n", 166 | " return f, logp\n", 167 | "\n", 168 | "def logp_approx(t, y, args):\n", 169 | " \"\"\" Approx. 
trace using Hutchinson's trace estimator.\n", 170 | " \"\"\"\n", 171 | " y, _ = y\n", 172 | " z, func = args\n", 173 | " t = np.atleast_1d(t)\n", 174 | " \n", 175 | " fn = lambda y: func(np.concatenate([y, t])) # Augmented function\n", 176 | " f, f_vjp = jax.vjp(fn, y) # VJPs can be computed at the ~same cost as computing f through reverse-mode AD\n", 177 | " \n", 178 | " # Trace estimator\n", 179 | " (z_dfdy,) = f_vjp(z)\n", 180 | " logp = np.sum(z_dfdy * z)\n", 181 | " return f, logp" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": 5, 187 | "metadata": {}, 188 | "outputs": [], 189 | "source": [ 190 | "key = jax.random.PRNGKey(0)\n", 191 | "t = np.ones((x.shape[0], 1))\n", 192 | "\n", 193 | "f = MLP(hidden_dim=64, out_dim=2, n_layers=3)\n", 194 | "params = f.init(key, np.concatenate([x, t], axis=1))" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": 6, 200 | "metadata": {}, 201 | "outputs": [ 202 | { 203 | "data": { 204 | "text/plain": [ 205 | "Array([3.156627 , 2.2985032, 3.1344635, 3.863559 , 2.96205 , 2.404113 ,\n", 206 | " 2.4348328, 3.7143767, 3.1841269, 3.1162772, 3.385356 , 3.1589055,\n", 207 | " 3.0098429, 2.634412 , 2.9587815, 3.777234 , 3.6070356, 3.055339 ,\n", 208 | " 3.1169815, 3.5339994, 1.9420253, 2.847947 , 1.9530888, 3.5562863,\n", 209 | " 3.1166618, 2.3222103, 2.065152 , 2.128791 , 3.3884706, 1.9760624,\n", 210 | " 3.2089186, 2.6790004], dtype=float32)" 211 | ] 212 | }, 213 | "execution_count": 6, 214 | "metadata": {}, 215 | "output_type": "execute_result" 216 | } 217 | ], 218 | "source": [ 219 | "t0 = 0.0\n", 220 | "t1 = 1.0\n", 221 | "dt0 = 1e-2\n", 222 | "logp = 'exact'\n", 223 | "\n", 224 | "# Runs backward-in-time to train the CNF\n", 225 | "def loss_fn(params, y, f):\n", 226 | " if logp == 'exact':\n", 227 | " term = dfx.ODETerm(logp_exact)\n", 228 | " elif logp == 'approx':\n", 229 | " term = dfx.ODETerm(logp_approx)\n", 230 | " else:\n", 231 | " raise NotImplementedError\n", 232 | " 
solver = dfx.Heun()\n", 233 | " eps = jax.random.normal(key, y.shape)\n", 234 | " delta_log_likelihood = 0.0\n", 235 | " y = (y, delta_log_likelihood)\n", 236 | " func = lambda x: f.apply(params, x)\n", 237 | " sol = dfx.diffeqsolve(term, solver, t1, t0, -dt0, y, (eps, func))\n", 238 | " (z,), (delta_log_likelihood,) = sol.ys\n", 239 | " log_prob = delta_log_likelihood + tfp.distributions.Normal(loc=0., scale=1.).log_prob(z).sum()\n", 240 | " return - log_prob\n", 241 | "\n", 242 | "jax.vmap(loss_fn, in_axes=(None, 0, None))(params, x[:32], f)" 243 | ] 244 | }, 245 | { 246 | "attachments": {}, 247 | "cell_type": "markdown", 248 | "metadata": {}, 249 | "source": [ 250 | "## Train" 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": 16, 256 | "metadata": {}, 257 | "outputs": [], 258 | "source": [ 259 | "opt = optax.adamw(learning_rate=3e-4, weight_decay=1e-4)\n", 260 | "opt_state = opt.init(params)\n", 261 | "\n", 262 | "@partial(jax.jit, static_argnums=(2,))\n", 263 | "def loss_fn_vmapped(params, x_batch, f):\n", 264 | " loss = jax.vmap(loss_fn, in_axes=(None, 0, None))(params, x_batch, f)\n", 265 | " return loss.mean()\n", 266 | "\n", 267 | "@partial(jax.jit, static_argnums=(3,))\n", 268 | "def train_step(params, opt_state, x_batch, f):\n", 269 | " loss, grad = jax.value_and_grad(loss_fn_vmapped)(params, x_batch, f)\n", 270 | " updates, opt_state = opt.update(grad, opt_state, params)\n", 271 | " params = optax.apply_updates(params, updates)\n", 272 | " return loss, params, opt_state" 273 | ] 274 | }, 275 | { 276 | "cell_type": "code", 277 | "execution_count": 17, 278 | "metadata": {}, 279 | "outputs": [ 280 | { 281 | "name": "stderr", 282 | "output_type": "stream", 283 | "text": [ 284 | "100%|██████████| 2000/2000 [06:38<00:00, 5.02it/s, val=1.4153568]\n" 285 | ] 286 | } 287 | ], 288 | "source": [ 289 | "n_steps = 2000\n", 290 | "n_batch = 32\n", 291 | "\n", 292 | "with trange(n_steps) as steps:\n", 293 | " for step in steps:\n", 294 | 
"\n", 295 | " # Draw a random batches from x\n", 296 | " key, _ = jax.random.split(key)\n", 297 | " idx = jax.random.choice(key, x.shape[0], shape=(n_batch,))\n", 298 | "\n", 299 | " x_batch = x[idx]\n", 300 | "\n", 301 | " loss, params, opt_state = train_step(params, opt_state, x_batch, f)\n", 302 | " \n", 303 | " # loss, grads = jax.value_and_grad(loss_fn_vmapped)(params, x_batch, f)\n", 304 | " # updates, opt_state = opt.update(grads, opt_state, params)\n", 305 | "\n", 306 | " # params = optax.apply_updates(params, updates)\n", 307 | "\n", 308 | " steps.set_postfix(val=loss)" 309 | ] 310 | }, 311 | { 312 | "attachments": {}, 313 | "cell_type": "markdown", 314 | "metadata": {}, 315 | "source": [ 316 | "## Sampling" 317 | ] 318 | }, 319 | { 320 | "cell_type": "code", 321 | "execution_count": 18, 322 | "metadata": {}, 323 | "outputs": [ 324 | { 325 | "data": { 326 | "text/plain": [ 327 | "Array([ 0.60950206, -0.50103378], dtype=float64, weak_type=True)" 328 | ] 329 | }, 330 | "execution_count": 18, 331 | "metadata": {}, 332 | "output_type": "execute_result" 333 | } 334 | ], 335 | "source": [ 336 | "def single_sample_fn(params, key, n_dim=2):\n", 337 | " \"\"\" Produce single sample from the CNF by integrating forward.\n", 338 | " \"\"\"\n", 339 | " z = jax.random.normal(key, (n_dim,))\n", 340 | " def func(t, x, args):\n", 341 | " t = np.atleast_1d(t)\n", 342 | " return f.apply(params, np.concatenate([x, t]))\n", 343 | " term = dfx.ODETerm(func)\n", 344 | " solver = dfx.Heun()\n", 345 | " sol = dfx.diffeqsolve(term, solver, t0, t1, dt0, z)\n", 346 | " (y,) = sol.ys\n", 347 | " return y\n", 348 | "\n", 349 | "single_sample_fn(params, key)" 350 | ] 351 | }, 352 | { 353 | "cell_type": "code", 354 | "execution_count": 19, 355 | "metadata": {}, 356 | "outputs": [], 357 | "source": [ 358 | "sample_fn = partial(single_sample_fn, params)\n", 359 | "\n", 360 | "n_samples = 100\n", 361 | "sample_key = jax.random.split(key, n_samples ** 2)\n", 362 | "x_sample = 
jax.vmap(sample_fn)(sample_key)" 363 | ] 364 | }, 365 | { 366 | "cell_type": "code", 367 | "execution_count": 20, 368 | "metadata": {}, 369 | "outputs": [ 370 | { 371 | "data": { 372 | "image/png": "", 373 | "text/plain": [ 374 | "
" 375 | ] 376 | }, 377 | "metadata": {}, 378 | "output_type": "display_data" 379 | } 380 | ], 381 | "source": [ 382 | "\n", 383 | "plt.hist2d(x_sample[:, 0], x_sample[:, 1], bins=100);" 384 | ] 385 | }, 386 | { 387 | "cell_type": "code", 388 | "execution_count": null, 389 | "metadata": {}, 390 | "outputs": [], 391 | "source": [] 392 | }, 393 | { 394 | "cell_type": "code", 395 | "execution_count": null, 396 | "metadata": {}, 397 | "outputs": [], 398 | "source": [] 399 | } 400 | ], 401 | "metadata": { 402 | "kernelspec": { 403 | "display_name": "torch-mps", 404 | "language": "python", 405 | "name": "python3" 406 | }, 407 | "language_info": { 408 | "codemirror_mode": { 409 | "name": "ipython", 410 | "version": 3 411 | }, 412 | "file_extension": ".py", 413 | "mimetype": "text/x-python", 414 | "name": "python", 415 | "nbconvert_exporter": "python", 416 | "pygments_lexer": "ipython3", 417 | "version": "3.9.13" 418 | }, 419 | "orig_nbformat": 4 420 | }, 421 | "nbformat": 4, 422 | "nbformat_minor": 2 423 | } 424 | -------------------------------------------------------------------------------- /07_diffusion_distillation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "af32d529", 6 | "metadata": {}, 7 | "source": [ 8 | "# Diffusion distillation (WiP)" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "8d2c8c5d", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "from functools import partial\n", 19 | "\n", 20 | "import jax\n", 21 | "import jax.numpy as np\n", 22 | "import flax.linen as nn\n", 23 | "import optax\n", 24 | "import diffrax as dfx\n", 25 | "\n", 26 | "from sklearn import datasets, preprocessing\n", 27 | "\n", 28 | "import matplotlib.pyplot as plt\n", 29 | "from tqdm import trange" 30 | ] 31 | }, 32 | { 33 | "attachments": {}, 34 | "cell_type": "markdown", 35 | "id": "6fcdfbd3", 36 | "metadata": {}, 37 | 
"source": [ 38 | "## The dataset" 39 | ] 40 | }, 41 | { 42 | "attachments": {}, 43 | "cell_type": "markdown", 44 | "id": "c6de9e86", 45 | "metadata": {}, 46 | "source": [ 47 | "We'll use two moons to keep things simple." 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 34, 53 | "id": "cdab2cc6-b2dc-4834-b343-0386a25fd9a9", 54 | "metadata": {}, 55 | "outputs": [ 56 | { 57 | "data": { 58 | "text/plain": [ 59 | "(-2.0, 2.0)" 60 | ] 61 | }, 62 | "execution_count": 34, 63 | "metadata": {}, 64 | "output_type": "execute_result" 65 | }, 66 | { 67 | "data": { 68 | "image/png": "", 69 | "text/plain": [ 70 | "
" 71 | ] 72 | }, 73 | "metadata": {}, 74 | "output_type": "display_data" 75 | } 76 | ], 77 | "source": [ 78 | "n_samples = 100_000\n", 79 | "\n", 80 | "x, _ = datasets.make_moons(n_samples=n_samples, noise=.06)\n", 81 | "\n", 82 | "scaler = preprocessing.StandardScaler()\n", 83 | "x = scaler.fit_transform(x)\n", 84 | "\n", 85 | "plt.hist2d(x[:, 0], x[:, 1], bins=100)\n", 86 | "plt.xlim(-2 ,2)\n", 87 | "plt.ylim(-2, 2)" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": 35, 93 | "id": "de9f55f9", 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "sigma = 1.\n", 98 | "\n", 99 | "def weight(t):\n", 100 | " \"\"\" \\lambda(t)\n", 101 | " \"\"\"\n", 102 | " return 0.5 / np.log(sigma) * (sigma ** (2 * t) - 1)\n", 103 | "\n", 104 | "@partial(jax.jit, static_argnums=(3,))\n", 105 | "def loss_fn(params, x, t, score, key):\n", 106 | "\n", 107 | " mu = x # x(0)\n", 108 | " std = np.sqrt(0.5 / np.log(sigma) * (sigma ** (2 * t) - 1)) # std of noise at time t\n", 109 | "\n", 110 | " eps = jax.random.normal(key, shape=x.shape) # Sampled noise\n", 111 | " y = mu + std * eps # x(t) = x(0) + std * eps # Corrupted data\n", 112 | "\n", 113 | " # Predicted score\n", 114 | " pred = score.apply(params, np.concatenate([t, y], -1)) \n", 115 | "\n", 116 | " # Score matching loss\n", 117 | " loss = weight(t) * np.mean((pred + eps / std) ** 2)\n", 118 | "\n", 119 | " return loss.mean()" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": 36, 125 | "id": "e544f2d0-ae93-42d5-affa-9fb85bd5bc5c", 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [ 129 | "class MLP(nn.Module):\n", 130 | " \"\"\" A simple MLP in Flax. 
This is the score function.\n", 131 | " \"\"\"\n", 132 | " hidden_dim: int = 32\n", 133 | " out_dim: int = 2\n", 134 | " n_layers: int = 2\n", 135 | "\n", 136 | " @nn.compact\n", 137 | " def __call__(self, x):\n", 138 | " for _ in range(self.n_layers):\n", 139 | " x = nn.Dense(features=self.hidden_dim)(x)\n", 140 | " x = nn.gelu(x)\n", 141 | " x = nn.Dense(features=self.out_dim)(x)\n", 142 | " return x" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": 37, 148 | "id": "337bbdfd", 149 | "metadata": {}, 150 | "outputs": [], 151 | "source": [ 152 | "def int_beta(t):\n", 153 | " return t\n", 154 | "\n", 155 | "def weight(t):\n", 156 | " return 1 - np.exp(-int_beta(t))\n", 157 | "\n", 158 | "@partial(jax.jit, static_argnums=(3,4,))\n", 159 | "def loss_fn(params, x, t, int_beta, score, key):\n", 160 | " mu = x * np.exp(-0.5 * int_beta(t))\n", 161 | " sigma = np.sqrt(1 - np.exp(-int_beta(t)))\n", 162 | " eps = jax.random.normal(key, shape=x.shape)\n", 163 | " y = mu + sigma * eps\n", 164 | "\n", 165 | " pred = score.apply(params, np.concatenate([y, t], -1))\n", 166 | " loss = weight(t) * np.mean((pred + eps / sigma) ** 2)\n", 167 | " return loss.mean()" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": 38, 173 | "id": "7455e305", 174 | "metadata": {}, 175 | "outputs": [], 176 | "source": [ 177 | "key = jax.random.PRNGKey(0)\n", 178 | "t = np.ones((x.shape[0], 1))\n", 179 | "\n", 180 | "score = MLP(hidden_dim=128, out_dim=2, n_layers=5)\n", 181 | "params = score.init(key, np.concatenate([x, t], axis=1))" 182 | ] 183 | }, 184 | { 185 | "attachments": {}, 186 | "cell_type": "markdown", 187 | "id": "1ec7fbad", 188 | "metadata": {}, 189 | "source": [ 190 | "## Training" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": 39, 196 | "id": "c98ab327", 197 | "metadata": {}, 198 | "outputs": [], 199 | "source": [ 200 | "opt = optax.adamw(learning_rate=3e-4, weight_decay=1e-4)\n", 201 | "opt_state = 
opt.init(params)" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": 40, 207 | "id": "8c67c255", 208 | "metadata": {}, 209 | "outputs": [ 210 | { 211 | "name": "stderr", 212 | "output_type": "stream", 213 | "text": [ 214 | "100%|██████████| 2000/2000 [00:17<00:00, 112.91it/s, val=2.2799215] \n" 215 | ] 216 | } 217 | ], 218 | "source": [ 219 | "n_steps = 2_000\n", 220 | "n_batch = 128\n", 221 | "T = 1.\n", 222 | "\n", 223 | "with trange(n_steps) as steps:\n", 224 | " for step in steps:\n", 225 | "\n", 226 | " # Draw a random batches from x\n", 227 | " key, subkey = jax.random.split(key)\n", 228 | " idx = jax.random.choice(key, x.shape[0], shape=(n_batch,))\n", 229 | " \n", 230 | " x_batch = x[idx]\n", 231 | " t_batch = jax.random.uniform(key, shape=(x_batch.shape[0], 1), minval=0., maxval=T)\n", 232 | "\n", 233 | " loss, grads = jax.value_and_grad(loss_fn)(params, x_batch, t_batch, int_beta, score, key)\n", 234 | " updates, opt_state = opt.update(grads, opt_state, params)\n", 235 | "\n", 236 | " params = optax.apply_updates(params, updates)\n", 237 | "\n", 238 | " steps.set_postfix(val=loss)" 239 | ] 240 | }, 241 | { 242 | "cell_type": "markdown", 243 | "id": "a018c3d0", 244 | "metadata": {}, 245 | "source": [ 246 | "## Distillation" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": 41, 252 | "id": "90b560b3", 253 | "metadata": {}, 254 | "outputs": [], 255 | "source": [ 256 | "key = jax.random.PRNGKey(0)\n", 257 | "t = np.ones((x.shape[0], 1))\n", 258 | "\n", 259 | "score_student = MLP(hidden_dim=128, out_dim=2, n_layers=5)\n", 260 | "params_student = score_student.init(key, np.concatenate([x, t], axis=1))" 261 | ] 262 | }, 263 | { 264 | "cell_type": "code", 265 | "execution_count": 45, 266 | "id": "f8ce7cb4", 267 | "metadata": {}, 268 | "outputs": [ 269 | { 270 | "name": "stderr", 271 | "output_type": "stream", 272 | "text": [ 273 | "100%|██████████| 5000/5000 [00:50<00:00, 98.17it/s, val=4.15465] \n" 274 | ] 275 
| } 276 | ], 277 | "source": [ 278 | "n_steps = 5_000\n", 279 | "n_batch = 128\n", 280 | "\n", 281 | "def alpha(t):\n", 282 | " return np.exp(-0.5 * int_beta(t))\n", 283 | "\n", 284 | "def sigma(t):\n", 285 | " return np.sqrt(1 - np.exp(-int_beta(t)))\n", 286 | "\n", 287 | "@partial(jax.jit, static_argnums=(4,5,))\n", 288 | "def loss_distillation_fn(params_student, params, x, t, score_student, score, key):\n", 289 | " eps = jax.random.normal(key, shape=x.shape)\n", 290 | "\n", 291 | " t = t[:, None] \n", 292 | "\n", 293 | " z_t = alpha(t) * x + sigma(t) * eps\n", 294 | "\n", 295 | " t_p = t - 0.5 / N\n", 296 | " t_pp = t - 1 / N\n", 297 | "\n", 298 | " x_hat_t = score.apply(params, np.concatenate([z_t, t], -1))\n", 299 | " z_tp = alpha(t_p) * x_hat_t + sigma(t_p) / sigma(t) * (z_t - x_hat_t * alpha(t))\n", 300 | "\n", 301 | " x_hat_tp = score.apply(params, np.concatenate([z_tp, t], -1))\n", 302 | " z_tpp = alpha(t_pp) * x_hat_tp + sigma(t_pp) / sigma(t_p) * (z_t - x_hat_tp * alpha(t_p))\n", 303 | "\n", 304 | " x_tilde = (z_tpp - (sigma(t_pp) / sigma(t)) * z_t) / (alpha(t_pp) - (sigma(t_pp) / sigma(t)) * alpha(t))\n", 305 | "\n", 306 | " pred = score_student.apply(params_student, np.concatenate([z_t, t], -1))\n", 307 | " loss = weight(t) * np.mean((pred + (x_tilde - x * alpha(t)) / sigma(t) ** 2) ** 2)\n", 308 | " return loss.mean()\n", 309 | "\n", 310 | "N = 100\n", 311 | "\n", 312 | "with trange(n_steps) as steps:\n", 313 | " for step in steps:\n", 314 | "\n", 315 | " # Draw a random batches from x\n", 316 | " key, subkey = jax.random.split(key)\n", 317 | " idx = jax.random.choice(key, x.shape[0], shape=(n_batch,))\n", 318 | " \n", 319 | " x_batch = x[idx]\n", 320 | " t_batch = jax.random.choice(key, np.arange(1, N + 1), shape=(n_batch,)) / N\n", 321 | "\n", 322 | " loss, grads = jax.value_and_grad(loss_distillation_fn)(params_student, params, x_batch, t_batch, score_student, score, key)\n", 323 | " updates, opt_state = opt.update(grads, opt_state, 
params_student)\n", 324 | "\n", 325 | " params_student = optax.apply_updates(params_student, updates)\n", 326 | "\n", 327 | " steps.set_postfix(val=loss)" 328 | ] 329 | }, 330 | { 331 | "attachments": {}, 332 | "cell_type": "markdown", 333 | "id": "ca323255", 334 | "metadata": {}, 335 | "source": [ 336 | "## Sampling" 337 | ] 338 | }, 339 | { 340 | "cell_type": "code", 341 | "execution_count": null, 342 | "id": "148b0f99", 343 | "metadata": {}, 344 | "outputs": [], 345 | "source": [] 346 | } 347 | ], 348 | "metadata": { 349 | "kernelspec": { 350 | "display_name": "Python 3 (ipykernel)", 351 | "language": "python", 352 | "name": "python3" 353 | }, 354 | "language_info": { 355 | "codemirror_mode": { 356 | "name": "ipython", 357 | "version": 3 358 | }, 359 | "file_extension": ".py", 360 | "mimetype": "text/x-python", 361 | "name": "python", 362 | "nbconvert_exporter": "python", 363 | "pygments_lexer": "ipython3", 364 | "version": "3.9.13" 365 | } 366 | }, 367 | "nbformat": 4, 368 | "nbformat_minor": 5 369 | } 370 | -------------------------------------------------------------------------------- /08_discrete_walk_jump_sampling.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": {}, 7 | "source": [ 8 | "# [Discrete Walk-Jump Sampling](https://arxiv.org/abs/2306.12360)\n" 9 | ] 10 | }, 11 | { 12 | "attachments": {}, 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "## Imports" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "from typing import Any, Callable, Iterable, Optional, Tuple, Union\n", 26 | "import functools\n", 27 | "from tqdm import trange\n", 28 | "\n", 29 | "import jax\n", 30 | "import jax.config\n", 31 | "jax.config.update(\"jax_enable_x64\", True)\n", 32 | "\n", 33 | "import chex\n", 34 | "import jax.numpy as 
jnp\n", 35 | "import flax.linen as nn\n", 36 | "from clu import parameter_overview\n", 37 | "import optax\n", 38 | "import tqdm\n", 39 | "from tensorflow_probability.substrates import jax as tfp\n", 40 | "import matplotlib.pyplot as plt\n", 41 | "import matplotlib.animation\n", 42 | "\n", 43 | "plt.rcParams[\"animation.html\"] = \"jshtml\"\n" 44 | ] 45 | }, 46 | { 47 | "attachments": {}, 48 | "cell_type": "markdown", 49 | "metadata": {}, 50 | "source": [ 51 | "## Problem Setup\n", 52 | "\n", 53 | "Fundamentally, discrete walk-jump sampling (dWJS) is a method for sampling noisy latents from an energy-based model, and denoising them with a neural network.\n", 54 | "\n", 55 | "Thus, to understand dWJS, we first need to understand energy-based models and neural empirical Bayes." 56 | ] 57 | }, 58 | { 59 | "attachments": {}, 60 | "cell_type": "markdown", 61 | "metadata": {}, 62 | "source": [ 63 | "## Energy-Based Models\n", 64 | "\n", 65 | "Energy-based models (EBMs) model the probability of a data point $y$ with an energy function $E_\\theta(y)$ parameterized by $\\theta$:\n", 66 | "$$\n", 67 | "p_\\theta(y) = \\frac{\\exp(-E_\\theta(y))}{Z(\\theta)}\n", 68 | "$$\n", 69 | "where $Z(\\theta)$ is the normalizing constant (called the partition function):\n", 70 | "$$\n", 71 | "Z(\\theta) = \\int \\exp(-E_\\theta(y)) dy\n", 72 | "$$\n", 73 | "Low energy samples have high probability, and vice versa." 74 | ] 75 | }, 76 | { 77 | "attachments": {}, 78 | "cell_type": "markdown", 79 | "metadata": {}, 80 | "source": [ 81 | "\n", 82 | "### How to train an EBM? \n", 83 | "\n", 84 | "We can train an EBM by minimizing the KL divergence between the true distribution $p$ and the model distribution $p_\\theta$. 
This is equivalent to maximizing the expected log-likelihood of data sampled from $p(y)$:\n", 85 | "$$\n", 86 | "\\theta^* = \\argmin_\\theta \\text{KL}(p \\ || \\ p_\\theta) = \\argmax_\\theta \\mathbb{E}_{y\\sim p(y)} [\\log p_\\theta(y)] = \\argmin_\\theta \\mathbb{E}_{y\\sim p(y)} [-\\log p_\\theta(y)]\n", 87 | "$$\n", 88 | "\n", 89 | "This is all standard so far. The problem is that $Z(\\theta)$ is intractable to compute:\n", 90 | "$$\n", 91 | "\\mathbb{E}_{y\\sim p(y)} [-\\log p_\\theta(y)] = \\mathbb{E}_{y\\sim p(y)} [E_\\theta(y)] + \\log Z(\\theta) \n", 92 | "$$ \n", 93 | "If we use gradient descent to optimize $\\theta$, we need to compute:\n", 94 | "$$\n", 95 | "\\nabla_\\theta \\mathbb{E}_{y\\sim p(y)} [-\\log p_\\theta(y)] = \\mathbb{E}_{y\\sim p(y)} [\\nabla_\\theta E_\\theta(y)] + \\nabla_\\theta\\log Z(\\theta) \n", 96 | "$$ \n", 97 | "The first term is easy to compute, but the second term is (usually) intractable. The common approach is to rewrite the second term as an expectation under the model distribution, which can then be approximated with MCMC samples:\n", 98 | "$$\n", 99 | "\\begin{aligned}\n", 100 | "\\nabla_\\theta\\log Z(\\theta) &= \\frac{\\nabla_\\theta Z(\\theta)}{Z(\\theta)} \\\\\n", 101 | "&= \\frac{\\nabla_\\theta \\int \\exp(-E_\\theta(y)) dy}{Z(\\theta)} \\\\\n", 102 | "&= \\frac{\\int \\nabla_\\theta \\exp(-E_\\theta(y)) dy}{Z(\\theta)} \\\\\n", 103 | "&= \\frac{\\int - \\nabla_\\theta E_\\theta(y) \\exp(-E_\\theta(y)) dy}{Z(\\theta)} \\\\\n", 104 | "&= \\frac{\\int - \\nabla_\\theta E_\\theta(y) Z(\\theta) p_\\theta(y) dy}{Z(\\theta)} \\\\\n", 105 | "&= \\int - \\nabla_\\theta E_\\theta(y) \\ p_\\theta(y) dy \\\\\n", 106 | "&= \\mathbb{E}_{y\\sim p_\\theta(y)} [- \\nabla_\\theta E_\\theta(y)] \\\\\n", 107 | "\\\\\n", 108 | "\\end{aligned}\n", 109 | "$$\n", 110 | "\n", 111 | "Thus, the gradient of the loss we seek to minimize is:\n", 112 | "$$\n", 113 | "\\nabla_\\theta \\mathbb{E}_{y\\sim p(y)} [-\\log p_\\theta(y)] = \n", 114 | "\\mathbb{E}_{y\\sim p(y)} [\\nabla_\\theta E_\\theta(y)] - \\mathbb{E}_{y\\sim p_\\theta(y)} [
\\nabla_\\theta E_\\theta(y)] \n", 115 | "$$\n", 116 | "We are seeking to minimize the energy of positive samples (from the data distribution) and maximize the energy of negative samples (from the model distribution). This is why this approach is also called contrastive divergence." 117 | ] 118 | }, 119 | { 120 | "attachments": {}, 121 | "cell_type": "markdown", 122 | "metadata": {}, 123 | "source": [ 124 | "## Langevin MCMC\n", 125 | "\n", 126 | "We have computed an estimator for the gradient of the negative log-likelihood (NLL) loss.\n", 127 | "\n", 128 | "Note that this estimator requires us to sample from the model distribution $p_\\theta(y)$ at each iteration. We perform this sampling from $p_\\theta(y)$ via Langevin MCMC, which can be viewed as a noisy version of gradient ascent on the log-likelihood.\n", 129 | "\n", 130 | "Initialize a sample $y_0$ randomly.\n", 131 | "Then, for each iteration $t$, compute:\n", 132 | "$$\n", 133 | "y_{t+1} = y_t + \\delta \\nabla_{y_t} \\log p_\\theta(y_t) + \\sqrt{2\\delta} \\epsilon_t\n", 134 | "$$\n", 135 | "where $\\epsilon_t \\sim \\mathcal{N}(0, I)$.\n", 136 | "Then, as $t \\rightarrow \\infty$ (and for a sufficiently small step size $\\delta$), the distribution of $y_t$ approaches $p_\\theta(y)$.\n", 137 | "\n", 138 | "Note that the partition function $Z(\\theta)$ does not show up in the sampling procedure:\n", 139 | "$$\n", 140 | "\\nabla_{y_t} \\log p_\\theta(y_t) = - \\nabla_{y_t} E_\\theta(y_t)\n", 141 | "$$" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": null, 147 | "metadata": {}, 148 | "outputs": [], 149 | "source": [ 150 | "@functools.partial(jax.jit, static_argnames=(\"grad_log_prob_fn\", \"num_steps\"))\n", 151 | "def langevin_sample(grad_log_prob_fn: Callable[[chex.Array], float], init: chex.Array, delta: float, rng: chex.PRNGKey, num_steps: int):\n", 152 | " \"\"\"Langevin sampling from a given log probability function.\"\"\"\n", 153 | "\n", 154 | " def one_step_langevin(y_t: chex.Array, rng: chex.PRNGKey):\n", 155 | " eps 
= jax.random.normal(rng, y_t.shape)\n", 156 | " y_next = y_t + delta * grad_log_prob_fn(y_t) + jnp.sqrt(2 * delta) * eps\n", 157 | " return y_next, y_next\n", 158 | "\n", 159 | " sampling_rngs = jax.random.split(rng, num_steps)\n", 160 | " _, samples = jax.lax.scan(one_step_langevin, init, xs=sampling_rngs, length=len(sampling_rngs))\n", 161 | " return samples" 162 | ] 163 | }, 164 | { 165 | "attachments": {}, 166 | "cell_type": "markdown", 167 | "metadata": {}, 168 | "source": [ 169 | "## Example Time!\n", 170 | "\n", 171 | "To illustrate the math, we will use a simple 1D example. Let's assume that the data distribution is a mixture of two Gaussians centered at -1 and 1, respectively:\n", 172 | "$$\n", 173 | "p(y) = \\frac{1}{2} \\mathcal{N}(y; -1, 0.5) + \\frac{1}{2} \\mathcal{N}(y; 1, 0.5)\n", 174 | "$$" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": null, 180 | "metadata": {}, 181 | "outputs": [], 182 | "source": [ 183 | "tfd = tfp.distributions\n", 184 | "px = tfd.Categorical(probs=[0.5, 0.5])\n", 185 | "p = tfd.Mixture(\n", 186 | " cat=px,\n", 187 | " components=[\n", 188 | " tfd.Normal(loc=-1., scale=0.5),\n", 189 | " tfd.Normal(loc=+1., scale=0.5),\n", 190 | " ]\n", 191 | ")\n", 192 | "\n", 193 | "# Plot the PDF.\n", 194 | "y = jnp.linspace(-5., 5., int(1e4))\n", 195 | "plt.grid()\n", 196 | "plt.plot(y, p.prob(y))\n", 197 | "plt.xlabel('y')\n", 198 | "plt.ylabel('p(y)')\n", 199 | "plt.title('True PDF')\n", 200 | "plt.show();" 201 | ] 202 | }, 203 | { 204 | "attachments": {}, 205 | "cell_type": "markdown", 206 | "metadata": {}, 207 | "source": [ 208 | "We can visualize the Langevin MCMC sampling process below:" 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": null, 214 | "metadata": {}, 215 | "outputs": [], 216 | "source": [ 217 | "rng = jax.random.PRNGKey(0)\n", 218 | "grad_log_prob_fn = jax.jit(jax.grad(lambda y: p.log_prob(y).squeeze()))" 219 | ] 220 | }, 221 | { 222 | "cell_type": "code", 223 | 
"execution_count": null, 224 | "metadata": {}, 225 | "outputs": [], 226 | "source": [ 227 | "sampling_rng, rng = jax.random.split(rng)\n", 228 | "delta = 0.1\n", 229 | "langevin_samples_from_2 = langevin_sample(grad_log_prob_fn, init=2 * jnp.ones((1,)), delta=delta, rng=sampling_rng, num_steps=500)\n", 230 | "\n", 231 | "fig, ax = plt.subplots()\n", 232 | "ax.grid()\n", 233 | "ax.set_xlim(-5., 5.)\n", 234 | "ax.set_ylim(0., 1.)\n", 235 | "ax.set_xlabel('y')\n", 236 | "ax.set_ylabel('p(y)')\n", 237 | "scatter = ax.scatter([], [], lw=2, color='C0')\n", 238 | "\n", 239 | "def animate(i: int):\n", 240 | " offsets = [langevin_samples_from_2[:i], p.prob(langevin_samples_from_2[:i])]\n", 241 | " offsets = jnp.stack(offsets, axis=-1).squeeze()\n", 242 | " scatter.set_offsets(offsets)\n", 243 | " # Adjust opacity.\n", 244 | " if i > 0:\n", 245 | " scatter.set_alpha(jnp.arange(i) ** 2 / i ** 2)\n", 246 | " ax.set_title(r'Langevin Sampling Starting from 2 with $\\delta={}$: Step {}'.format(delta, i))\n", 247 | " return (scatter,)\n", 248 | "\n", 249 | "anim = matplotlib.animation.FuncAnimation(fig, animate, frames=100, interval=100, blit=True)\n", 250 | "plt.close()\n", 251 | "anim" 252 | ] 253 | }, 254 | { 255 | "cell_type": "code", 256 | "execution_count": null, 257 | "metadata": {}, 258 | "outputs": [], 259 | "source": [ 260 | "sampling_rng, rng = jax.random.split(rng)\n", 261 | "langevin_samples_from_neg_2 = langevin_sample(grad_log_prob_fn, init=-2*jnp.ones((1,)), delta=delta, rng=sampling_rng, num_steps=500)\n", 262 | "\n", 263 | "fig, ax = plt.subplots()\n", 264 | "ax.grid()\n", 265 | "ax.set_xlim(-5., 5.)\n", 266 | "ax.set_ylim(0., 1.)\n", 267 | "ax.set_xlabel('y')\n", 268 | "ax.set_ylabel('p(y)')\n", 269 | "scatter = ax.scatter([], [], lw=2, c='C1')\n", 270 | "\n", 271 | "def animate(i: int):\n", 272 | " offsets = [langevin_samples_from_neg_2[:i], p.prob(langevin_samples_from_neg_2[:i])]\n", 273 | " offsets = jnp.stack(offsets, axis=-1).squeeze()\n", 274 | " 
scatter.set_offsets(offsets)\n", 275 | " # Adjust opacity.\n", 276 | " # scatter.set_sizes(100 * jnp.ones(i))\n", 277 | " if i > 0:\n", 278 | " scatter.set_alpha(jnp.arange(i) ** 2 / i ** 2)\n", 279 | " ax.set_title(r'Langevin Sampling Starting from -2 with $\\delta={}$: Step {}'.format(delta, i))\n", 280 | " return (scatter,)\n", 281 | "\n", 282 | "anim = matplotlib.animation.FuncAnimation(fig, animate, frames=100, interval=100, blit=True)\n", 283 | "plt.close()\n", 284 | "anim" 285 | ] 286 | }, 287 | { 288 | "attachments": {}, 289 | "cell_type": "markdown", 290 | "metadata": {}, 291 | "source": [ 292 | "We can check that the histograms of the samples from Langevin MCMC match the data distribution fairly closely:" 293 | ] 294 | }, 295 | { 296 | "cell_type": "code", 297 | "execution_count": null, 298 | "metadata": {}, 299 | "outputs": [], 300 | "source": [ 301 | "plt.hist(langevin_samples_from_2.flatten(), bins=20, density=True, alpha=0.5, color='C0', label='Starting from 2')\n", 302 | "plt.hist(langevin_samples_from_neg_2.flatten(), bins=20, density=True, alpha=0.5, color='C1', label='Starting from -2')\n", 303 | "plt.grid()\n", 304 | "plt.xlabel('y')\n", 305 | "plt.ylabel('p(y)')\n", 306 | "plt.legend()\n", 307 | "plt.title('Histograms of Langevin Samples')\n", 308 | "plt.show();" 309 | ] 310 | }, 311 | { 312 | "attachments": {}, 313 | "cell_type": "markdown", 314 | "metadata": {}, 315 | "source": [ 316 | "For our EBM, we will use a small neural network with two softplus hidden layers for the energy function:\n", 317 | "$$\n", 318 | "E_\\theta(y) = W_2 \\, \\text{softplus}(W_1 \\, \\text{softplus}(W_0 y + b_0) + b_1) + b_2\n", 319 | "$$" 320 | ] 321 | }, 322 | { 323 | "cell_type": "code", 324 | "execution_count": null, 325 | "metadata": {}, 326 | "outputs": [], 327 | "source": [ 328 | "class EnergyBasedModel(nn.Module):\n", 329 | " \"\"\"A simple energy-based model.\"\"\"\n", 330 | " hidden_size: int\n", 331 | "\n", 332 | " @nn.compact\n", 333 | " def __call__(self, y: chex.Array) -> chex.Array:\n", 334 | " if 
len(y.shape) <= 1:\n", 335 | " y = jnp.expand_dims(y, axis=0)\n", 336 | " y = nn.Dense(self.hidden_size)(y)\n", 337 | " y = jax.nn.softplus(y)\n", 338 | " y = nn.Dense(self.hidden_size)(y)\n", 339 | " y = jax.nn.softplus(y)\n", 340 | " y = nn.Dense(1)(y)\n", 341 | " y = jnp.squeeze(y, axis=-1)\n", 342 | " return y\n", 343 | "\n", 344 | "# Initialize the model.\n", 345 | "model = EnergyBasedModel(hidden_size=10)\n", 346 | "dummy_input = jnp.ones((1,))\n", 347 | "init_params = model.init(rng, dummy_input)\n", 348 | "energy_fn = jax.jit(model.apply)\n", 349 | "\n", 350 | "# Overview of the model parameters.\n", 351 | "print(parameter_overview.get_parameter_overview(init_params))" 352 | ] 353 | }, 354 | { 355 | "attachments": {}, 356 | "cell_type": "markdown", 357 | "metadata": {}, 358 | "source": [ 359 | "We can visualize the unnormalized probability distribution of the EBM below:" 360 | ] 361 | }, 362 | { 363 | "cell_type": "code", 364 | "execution_count": null, 365 | "metadata": {}, 366 | "outputs": [], 367 | "source": [ 368 | "y = jnp.linspace(-10., 10., int(1e4))\n", 369 | "y_probs = jax.vmap(lambda y: jnp.exp(-energy_fn(init_params, y)))(y)\n", 370 | "y_probs /= y_probs.sum() * (y[-1] - y[0]) / len(y)\n", 371 | "plt.grid()\n", 372 | "plt.plot(y, y_probs, color='C2')\n", 373 | "plt.xlabel('y')\n", 374 | "plt.ylabel('p_model(y)')\n", 375 | "plt.title('(Approximately Normalized) Model PDF')\n", 376 | "plt.show();" 377 | ] 378 | }, 379 | { 380 | "cell_type": "code", 381 | "execution_count": null, 382 | "metadata": {}, 383 | "outputs": [], 384 | "source": [ 385 | "def create_grad_log_prob_fn(params: optax.Params, energy_fn: Callable[[optax.Params, chex.Array], float]) -> Callable[[chex.Array], chex.Array]:\n", 386 | " \"\"\"Creates a function that computes the gradient of the log probability under the EBM.\"\"\"\n", 387 | " def grad_log_prob_fn(y: chex.Array) -> chex.Array:\n", 388 | " return -jax.grad(lambda y: energy_fn(params, y).squeeze())(y)\n", 389 | " return 
grad_log_prob_fn\n", 390 | "\n", 391 | "sampling_rng, rng = jax.random.split(rng)\n", 392 | "langevin_samples_from_model = langevin_sample(create_grad_log_prob_fn(init_params, energy_fn), init=jnp.zeros((1,)), delta=1, rng=sampling_rng, num_steps=10000)\n", 393 | "plt.hist(langevin_samples_from_model.flatten(), bins=20, density=True, alpha=0.5, color='C2', label='Model Samples')\n", 394 | "plt.grid()\n", 395 | "plt.xlabel('y')\n", 396 | "plt.ylabel('p(y)')\n", 397 | "plt.legend()\n", 398 | "plt.title('Histograms of Langevin Samples')\n", 399 | "plt.show();" 400 | ] 401 | }, 402 | { 403 | "attachments": {}, 404 | "cell_type": "markdown", 405 | "metadata": {}, 406 | "source": [ 407 | "We now train the EBM using the gradient estimator from above. As negative samples, we run a short Langevin MCMC chain started from Gaussian noise, discard the first `burn_in_samples` steps, and keep every `take_every_sample`-th sample after that.\n", 408 | "Because the negative samples are treated as constants (via `stop_gradient`), we can simply use automatic differentiation on the surrogate loss (mean energy of data samples minus mean energy of model samples) to obtain this gradient estimator!" 409 | ] 410 | }, 411 | { 412 | "cell_type": "code", 413 | "execution_count": null, 414 | "metadata": {}, 415 | "outputs": [], 416 | "source": [ 417 | "@functools.partial(jax.jit, static_argnames=(\"energy_fn\", \"num_sampling_steps\", \"take_every_sample\", \"burn_in_samples\"))\n", 418 | "def ebm_loss_fn(params: optax.Params, energy_fn: Callable[[optax.Params, chex.Array], float], y_true_samples: chex.Array, rng: chex.PRNGKey,\n", 419 | " delta: float, num_sampling_steps: int, take_every_sample: int, burn_in_samples: int) -> float:\n", 420 | " \"\"\"Computes the EBM loss function.\"\"\"\n", 421 | " init_rng, rng = jax.random.split(rng)\n", 422 | " init = jax.random.normal(init_rng, y_true_samples[0].shape)\n", 423 | " sampling_rng, rng = jax.random.split(rng)\n", 424 | " y_model_samples = langevin_sample(create_grad_log_prob_fn(params, energy_fn), init=init, delta=delta, rng=sampling_rng, num_steps=num_sampling_steps)\n", 425 | " y_model_samples = y_model_samples[burn_in_samples:]\n", 426 | " y_model_samples = y_model_samples[::take_every_sample]\n", 427 | " # We don't 
differentiate through the sampling procedure.\n", 428 | " y_model_samples = jax.lax.stop_gradient(y_model_samples)\n", 429 | " energy_fn_vmapped = jax.vmap(lambda y: energy_fn(params, y))\n", 430 | " return energy_fn_vmapped(y_true_samples).mean() - energy_fn_vmapped(y_model_samples).mean()" 431 | ] 432 | }, 433 | { 434 | "attachments": {}, 435 | "cell_type": "markdown", 436 | "metadata": {}, 437 | "source": [ 438 | "We are ready to train our model!" 439 | ] 440 | }, 441 | { 442 | "cell_type": "code", 443 | "execution_count": null, 444 | "metadata": {}, 445 | "outputs": [], 446 | "source": [ 447 | "def train_ebm_model(init_params: optax.Params, rng: chex.PRNGKey, num_training_steps: int, **loss_kwargs) -> optax.Params:\n", 448 | " \"\"\"Train the EBM model using the Adam optimizer.\"\"\"\n", 449 | " @jax.jit\n", 450 | " def train_step(params: optax.Params, opt_state: optax.OptState, y_true_samples: chex.Array, rng: chex.PRNGKey) -> Tuple[optax.Params, optax.OptState, float]:\n", 451 | " loss, grad = jax.value_and_grad(ebm_loss_fn, has_aux=False)(params, energy_fn, y_true_samples, rng, **loss_kwargs)\n", 452 | " updates, opt_state = tx.update(grad, opt_state)\n", 453 | " params = optax.apply_updates(params, updates)\n", 454 | " return params, opt_state, loss\n", 455 | "\n", 456 | " tx = optax.adam(5e-4)\n", 457 | " opt_state = tx.init(init_params)\n", 458 | "\n", 459 | " params_at_steps = {\n", 460 | " 0: init_params,\n", 461 | " }\n", 462 | " \n", 463 | " params = init_params\n", 464 | " for step in tqdm.trange(num_training_steps):\n", 465 | " step_rng, samples_rng, rng = jax.random.split(rng, num=3)\n", 466 | " y_true_samples = p.sample(32, seed=samples_rng)\n", 467 | "\n", 468 | " params, opt_state, loss = train_step(params, opt_state, y_true_samples, step_rng)\n", 469 | "\n", 470 | " # Log the training progress. 
\n", 471 | " if step % 500 == 0:\n", 472 | " params_at_steps[step] = params\n", 473 | " \n", 474 | " if step % 2000 == 0:\n", 475 | " print('Step {}: Loss = {}'.format(step, loss))\n", 476 | " \n", 477 | " return params, params_at_steps\n", 478 | "\n", 479 | "train_rng = jax.random.PRNGKey(0)\n", 480 | "energy_params, energy_params_at_steps = train_ebm_model(init_params, train_rng, num_training_steps=30000, delta=0.1, num_sampling_steps=10, take_every_sample=2, burn_in_samples=1)" 481 | ] 482 | }, 483 | { 484 | "attachments": {}, 485 | "cell_type": "markdown", 486 | "metadata": {}, 487 | "source": [ 488 | "We can visualize the learnt (unnormalized) probability distribution of the EBM:" 489 | ] 490 | }, 491 | { 492 | "cell_type": "code", 493 | "execution_count": null, 494 | "metadata": {}, 495 | "outputs": [], 496 | "source": [ 497 | "fig, ax = plt.subplots()\n", 498 | "ax.grid()\n", 499 | "ax.set_xlim(-5., 5.)\n", 500 | "ax.set_ylim(0., 1.)\n", 501 | "ax.set_xlabel('y')\n", 502 | "ax.set_ylabel('p_model(y)')\n", 503 | "line, = ax.plot([], [], lw=2, color='C2')\n", 504 | "steps = sorted(energy_params_at_steps.keys())\n", 505 | "\n", 506 | "def animate(i: int):\n", 507 | " y = jnp.linspace(-5., 5., int(1e3))\n", 508 | " y_probs = jax.vmap(lambda y: jnp.exp(-energy_fn(energy_params_at_steps[steps[i]], y)))(y)\n", 509 | " y_probs /= y_probs.sum() * (y[-1] - y[0]) / len(y)\n", 510 | " line.set_data(y, y_probs)\n", 511 | " ax.set_title('(Approximately Normalized) Model PDF: Step {}'.format(steps[i]))\n", 512 | " return (line,)\n", 513 | "\n", 514 | "anim = matplotlib.animation.FuncAnimation(fig, animate, frames=len(steps), interval=100, blit=True)\n", 515 | "plt.close()\n", 516 | "anim" 517 | ] 518 | }, 519 | { 520 | "attachments": {}, 521 | "cell_type": "markdown", 522 | "metadata": {}, 523 | "source": [ 524 | "## Neural Empirical Bayes" 525 | ] 526 | }, 527 | { 528 | "attachments": {}, 529 | "cell_type": "markdown", 530 | "metadata": {}, 531 | "source": [ 532 | "This 
section follows the work of [Saeed Saremi and Aapo Hyvarinen](https://arxiv.org/abs/1903.02334).\n", 533 | "\n", 534 | "Consider an observation denoted by the random variable $X \\in \\mathbb{R}^d$, and a noisy observation of $X$ denoted by $Y \\in \\mathbb{R}^d$:\n", 535 | "$$\n", 536 | "Y = X + \\epsilon, \\quad \\epsilon \\sim \\mathcal{N}(0, \\sigma^2 I_d)\n", 537 | "$$\n", 538 | "Thus, we are given that $p_{Y|X}$ is a Gaussian.\n", 539 | "\n", 540 | "Given the observation $Y = y$, the Bayes' least squares estimator for $X$ is:\n", 541 | "$$\n", 542 | "\\hat{X}(y) = \\mathbb{E}[X \\ | \\ Y = y]\n", 543 | "$$\n", 544 | "Computing this estimator requires knowledge of $p_{X|Y} \\propto p_{Y|X}\\cdot p_X$ by Bayes' rule.\n", 545 | "Thus, it seems that we need to know $p_X$ to compute this estimator.\n", 546 | "\n", 547 | "The trick (figured out by [Robbins](https://link.springer.com/chapter/10.1007/978-1-4612-0919-5_26) and [Miyasawa](https://mit.primo.exlibrisgroup.com/discovery/openurl?institution=01MIT_INST&vid=01MIT_INST:MIT&rft.epage=188&rft_val_fmt=info:ofi%2Ffmt:kev:mtx:journal&rft.stitle=B%20INT%20STATIST%20INST&rft.volume=38&rfr_id=info:sid%2Fwebofscience.com:WOS:WOS&rft.jtitle=BULLETIN%20OF%20THE%20INTERNATIONAL%20STATISTICAL%20INSTITUTE&rft.aufirst=K&rft.genre=article&rft.issue=4&rft.pages=181-188&url_ctx_fmt=info:ofi%2Ffmt:kev:mtx:ctx&rft.aulast=MIYASAWA&url_ver=Z39.88-2004&rft.auinit=K&rft.date=1960&rft.spage=181&rft.atitle=AN%20EMPIRICAL%20BAYES%20ESTIMATOR%20OF%20THE%20MEAN%20OF%20A%20NORMAL%20POPULATION&rft.issn=0074-8609)) turns out that we can compute this estimator without knowing $p_X$!\n", 548 | "\n", 549 | "For all $x$ and $y$, we have:\n", 550 | "$$\n", 551 | "p_{Y|X}(y|x) = \\mathcal{N}(y; x, \\sigma^2) = \\frac{1}{(2\\pi\\sigma^2)^{\\frac{d}{2}}} \\exp\\left(-\\frac{\\|y - x\\|^2}{2\\sigma^2} \\right)\n", 552 | "$$\n", 553 | "so:\n", 554 | "$$\n", 555 | "\\begin{aligned}\n", 556 | "\\nabla_y p_{Y|X}(y|x) &= -\\frac{y - x}{\\sigma^2} 
p_{Y|X}(y|x) \\\\\n", 557 | "\\implies (x - y) p_{Y|X}(y|x) &= \\sigma^2 \\nabla_y p_{Y|X}(y|x) \\\\\n", 558 | "\\implies \\int (x - y) p_{Y|X}(y|x) p_X(x) dx &= \\sigma^2 \\int \\nabla_y p_{Y|X}(y|x) p_X(x) dx\n", 559 | "\\end{aligned}\n", 560 | "$$\n", 561 | "Note that, by Bayes' rule: $p_{Y|X}(y|x) p_X(x) = p_{X,Y}(x, y) = p_{X|Y}(x|y) p_Y(y)$. \n", 562 | "Also, by definition of the marginals: $\\int p_{X,Y}(x, y) dx = p_{Y}(y)$. \n", 563 | "For the left-hand side, we have:\n", 564 | "$$\n", 565 | "\\begin{aligned}\n", 566 | "\\int (x - y) p_{Y|X}(y|x) p_X(x) dx &= \\int x p_{Y|X}(y|x) p_X(x) dx - \\int y p_{Y|X}(y|x) p_X(x) dx \\\\\n", 567 | "&= \\int x p_{X,Y}(x, y) dx - \\int y p_{X,Y}(x, y) dx \\\\\n", 568 | "&= p_Y(y) \\left(\\int x p_{X|Y}(x|y) dx - y \\int p_{X|Y}(x|y) dx\\right) \\\\\n", 569 | "&= p_Y(y) \\left(\\mathbb{E}[X \\ | \\ Y = y] - y \\right) \\\\\n", 570 | "&= p_Y(y) \\left(\\hat{X}(y) - y \\right)\n", 571 | "\\end{aligned}\n", 572 | "$$\n", 573 | "For the right-hand side, we have:\n", 574 | "$$\n", 575 | "\\begin{aligned}\n", 576 | "\\sigma^2 \\int \\nabla_y p_{Y|X}(y|x) p_X(x) dx &= \\sigma^2 \\nabla_y \\int p_{Y|X}(y|x) p_X(x) dx \\\\\n", 577 | "&= \\sigma^2 \\nabla_y \\int p_{X,Y}(x, y) dx \\\\\n", 578 | "&= \\sigma^2 \\nabla_y p_{Y}(y)\n", 579 | "\\end{aligned}\n", 580 | "$$\n", 581 | "Thus,\n", 582 | "$$\n", 583 | "\\begin{aligned}\n", 584 | "p_Y(y) \\left(\\hat{X}(y) - y \\right) &= \\sigma^2 \\nabla_y p_{Y}(y)\n", 585 | "\\\\\n", 586 | "\\implies \\hat{X}(y) &= y + \\sigma^2 \\frac{\\nabla_y p_{Y}(y)}{p_Y(y)}\n", 587 | "\\\\\n", 588 | "\\implies \\hat{X}(y) &= y + \\sigma^2 \\nabla_y \\log p_{Y}(y)\n", 589 | "\\end{aligned}\n", 590 | "$$\n", 591 | "Thus, the estimator $\\hat{X}(y)$ can be computed without knowledge of $p_X$, only the knowledge of the score $\\nabla_y \\log p_{Y}(y)$ is required.\n", 592 | "\n", 593 | "Now, there are two approaches to learning the score function $\\nabla_y \\log p_{Y}(y)$.\n", 594 | "The first is to 
approximate $p_{Y}(y)$ by an EBM:\n", 595 | "$$\n", 596 | "p_{Y}(y) \\approx \\frac{\\exp(-E_\\theta(y))}{Z(\\theta)} \\implies \\nabla_y \\log p_{Y}(y) = - \\nabla_y E_\\theta(y)\n", 597 | "$$\n", 598 | "The EBM can be learned using the contrastive-divergence procedure described above.\n", 599 | "Then, the learned EBM can be used as a denoiser, by denoising $Y$ to obtain an estimate of $X$:\n", 600 | "$$\n", 601 | "\\hat{X}(y) = y - \\sigma^2 \\nabla_y E_\\theta(y)\n", 602 | "$$\n", 603 | "\n", 604 | "The second approach, proposed in this Discrete Walk-Jump Sampling paper, is to directly parametrize the score function by a 'denoising' neural network:\n", 605 | "$$\n", 606 | "g_\\phi(y) \\approx \\nabla_y \\log p_{Y}(y)\n", 607 | "$$\n", 608 | "Denoising $Y$ as before:\n", 609 | "$$\n", 610 | "\\hat{X}(y) = y + \\sigma^2 g_\\phi(y)\n", 611 | "$$\n", 612 | "The denoising network $g_\\phi$ can be trained by taking observations of $X$ and adding noise to obtain examples $Y$:\n", 613 | "* Sample $X_i \\sim p_X$.\n", 614 | "* Sample $\\epsilon_j \\sim \\mathcal{N}(0, \\sigma^2 I_d)$.\n", 615 | "* Compute $Y_{ij} = X_i + \\epsilon_j$.\n", 616 | "* Optimize $\\phi$:\n", 617 | "$$\n", 618 | "\\phi^* = \\argmin_\\phi \\sum_{i,j} \\|X_i - (Y_{ij} + \\sigma^2 g_\\phi(Y_{ij})) \\|^2\n", 619 | "$$" 620 | ] 621 | }, 622 | { 623 | "cell_type": "code", 624 | "execution_count": null, 625 | "metadata": {}, 626 | "outputs": [], 627 | "source": [ 628 | "# Define the score model that predicts the score.\n", 629 | "class ScoreNetwork(nn.Module):\n", 630 | " \"\"\"A simple score neural network.\"\"\"\n", 631 | " hidden_size: int\n", 632 | "\n", 633 | " @nn.compact\n", 634 | " def __call__(self, y: chex.Array) -> chex.Array:\n", 635 | " if len(y.shape) <= 1:\n", 636 | " y = jnp.expand_dims(y, axis=0)\n", 637 | " init_dims = y.shape[-1]\n", 638 | " y = nn.Dense(self.hidden_size)(y)\n", 639 | " y = jax.nn.softplus(y)\n", 640 | " y = nn.Dense(self.hidden_size)(y)\n", 641 | " y = jax.nn.softplus(y)\n", 642 | " y = 
nn.Dense(init_dims)(y)\n", 643 | " return y\n", 644 | "\n", 645 | "# Initialize the score model.\n", 646 | "score_model = ScoreNetwork(hidden_size=10)\n", 647 | "dummy_input = jnp.ones((1,))\n", 648 | "init_score_params = score_model.init(rng, dummy_input)\n", 649 | "score_fn = jax.jit(score_model.apply)\n", 650 | "\n", 651 | "# Overview of the model parameters.\n", 652 | "print(parameter_overview.get_parameter_overview(init_score_params))" 653 | ] 654 | }, 655 | { 656 | "cell_type": "code", 657 | "execution_count": null, 658 | "metadata": {}, 659 | "outputs": [], 660 | "source": [ 661 | "@functools.partial(jax.jit, static_argnames=(\"score_fn\", \"num_noise_samples\"))\n", 662 | "def score_loss_fn(\n", 663 | " params: optax.Params, score_fn: Callable[[optax.Params, chex.Array], chex.Array], x_true_samples: chex.Array, rng: chex.PRNGKey, noise_std: float, num_noise_samples: int) -> float:\n", 664 | " \"\"\"Computes the mean squared denoising error of the estimate y + noise_std**2 * score(y).\"\"\"\n", 665 | " assert len(x_true_samples.shape) == 2\n", 666 | " num_true_samples, num_dims = x_true_samples.shape\n", 667 | "\n", 668 | " noise_rng, rng = jax.random.split(rng)\n", 669 | " noise = noise_std * jax.random.normal(noise_rng, (num_noise_samples, num_dims))\n", 670 | " y_samples = x_true_samples[:, None, :] + noise[None, ...]\n", 671 | " assert y_samples.shape == (num_true_samples, num_noise_samples, num_dims)\n", 672 | "\n", 673 | " predictions = score_fn(params, y_samples)\n", 674 | " assert predictions.shape == (num_true_samples, num_noise_samples, num_dims)\n", 675 | "\n", 676 | " l2_loss = jax.vmap(lambda x, ys, preds: jnp.sum((x - (ys + noise_std ** 2 * preds)) ** 2, axis=-1).mean())(x_true_samples, y_samples, predictions)  # squared error: the least-squares objective that yields E[X | Y = y]\n", 677 | " assert l2_loss.shape == (num_true_samples,)\n", 678 | "\n", 679 | " return l2_loss.mean()\n", 680 | " " 681 | ] 682 | }, 683 | { 684 | "cell_type": "code", 685 | "execution_count": null, 686 | "metadata": {}, 687 | "outputs": [], 688 | "source": [ 689 | "def 
train_score_model(init_params: optax.Params, rng: chex.PRNGKey, num_training_steps: int, **loss_kwargs) -> Tuple[optax.Params, dict]:\n", 690 | " \"\"\"Train the score model using the Adam optimizer; returns the final parameters and a dict of parameter snapshots keyed by step.\"\"\"\n", 691 | " @jax.jit\n", 692 | " def train_step(params: optax.Params, opt_state: optax.OptState, x_true_samples: chex.Array, rng: chex.PRNGKey) -> Tuple[optax.Params, optax.OptState, float]:\n", 693 | " loss, grad = jax.value_and_grad(score_loss_fn, has_aux=False)(params, score_fn, x_true_samples, rng, **loss_kwargs)\n", 694 | " updates, opt_state = tx.update(grad, opt_state)\n", 695 | " params = optax.apply_updates(params, updates)\n", 696 | " return params, opt_state, loss\n", 697 | "\n", 698 | " tx = optax.adam(5e-4)\n", 699 | " opt_state = tx.init(init_params)\n", 700 | "\n", 701 | " params_at_steps = {\n", 702 | " 0: init_params,\n", 703 | " }\n", 704 | " \n", 705 | " params = init_params\n", 706 | " for step in tqdm.trange(num_training_steps + 1):\n", 707 | " step_rng, samples_rng, rng = jax.random.split(rng, num=3)\n", 708 | " x_true_samples = 2. * px.sample(32, seed=samples_rng) - 1.  # px draws mixture labels {0, 1}; map them to the clean modes {-1, +1} of p\n", 709 | " x_true_samples = x_true_samples[:, None]\n", 710 | "\n", 711 | " params, opt_state, loss = train_step(params, opt_state, x_true_samples, step_rng)\n", 712 | "\n", 713 | " # Log the training progress. \n", 714 | " if step % 500 == 0:\n", 715 | " params_at_steps[step] = params\n", 716 | " \n", 717 | " if step % 2000 == 0:\n", 718 | " print('Step {}: Loss = {}'.format(step, loss))\n", 719 | " \n", 720 | " return params, params_at_steps\n", 721 | "\n", 722 | "train_rng = jax.random.PRNGKey(0)\n", 723 | "noise_std = 0.5\n", 724 | "score_params, score_params_at_steps = train_score_model(init_score_params, train_rng, num_training_steps=30000, noise_std=noise_std, num_noise_samples=10)" 725 | ] 726 | }, 727 | { 728 | "attachments": {}, 729 | "cell_type": "markdown", 730 | "metadata": {}, 731 | "source": [ 732 | "Let's check that the score function works!" 
733 | ] 734 | }, 735 | { 736 | "cell_type": "code", 737 | "execution_count": null, 738 | "metadata": {}, 739 | "outputs": [], 740 | "source": [ 741 | "noise_rng, rng = jax.random.split(rng)\n", 742 | "noise = noise_std * jax.random.normal(noise_rng, (10, 1))\n", 743 | "x_true = jnp.asarray([[-1.], [1.]])\n", 744 | "y = x_true[None, :] + noise[:, None]\n", 745 | "y = y.transpose((1, 0, 2)).reshape((-1, 1))\n", 746 | "preds = score_fn(score_params, y)\n", 747 | "x = y + noise_std ** 2 * preds\n", 748 | "labels = jnp.where(x < 0.5, 0, 1)\n", 749 | "plt.grid()\n", 750 | "plt.scatter(y, x, color=['C0' if label == 0 else 'C1' for label in labels])\n", 751 | "plt.xlabel('y')\n", 752 | "plt.ylabel('x')\n", 753 | "plt.title('Denoising Model Predictions')\n", 754 | "plt.show();" 755 | ] 756 | }, 757 | { 758 | "attachments": {}, 759 | "cell_type": "markdown", 760 | "metadata": {}, 761 | "source": [ 762 | "## Walk-Jump Sampling\n", 763 | "\n", 764 | "The idea behind Walk-Jump Sampling is that it is easier to walk in the space of noisy observations $Y$ than in the space of clean observations $X$. The noise helps connect different modes of the distribution. Given any noisy observation $Y$, we can always go back to the clean observation $X$ by denoising.\n", 765 | "\n", 766 | "* Walk in noisy observation space with Langevin MCMC:\n", 767 | "$$\n", 768 | " y_t = y_{t-1} + \\delta \\nabla_{y_{t-1}} \\log p_Y(y_{t-1}) + \\sqrt{2\\delta} \\epsilon_t\n", 769 | "$$\n", 770 | "* Jump to clean observation (at any time $\\tau$):\n", 771 | "$$\n", 772 | " x_\\tau = y_\\tau + \\sigma^2 \\nabla_{y_\\tau} \\log p_Y(y_\\tau)\n", 773 | "$$\n", 774 | "\n", 775 | "Note that both the walk and jump steps need an estimate of the score function $\\nabla_{y} \\log p_Y(y)$.\n", 776 | "We have choices for how we parametrize these in each step. 
Here, they find that using an EBM for the walker, and a denoising network for the jumper works best:\n", 777 | "* EBM:\n", 778 | "$$\n", 779 | "p_Y(y) = \\frac{\\exp(-E_\\theta(y))}{Z(\\theta)}\n", 780 | "$$\n", 781 | "* Denoiser:\n", 782 | "$$\n", 783 | "\\nabla_y \\log p_{Y}(y) \\approx g_\\phi(y)\n", 784 | "$$\n", 785 | "\n", 786 | "Note that unlike diffusion, every single sample from walk-jump sampling is approximately from the data distribution $p_X$." 787 | ] 788 | }, 789 | { 790 | "cell_type": "code", 791 | "execution_count": null, 792 | "metadata": {}, 793 | "outputs": [], 794 | "source": [ 795 | "@functools.partial(jax.jit, static_argnames=(\"energy_fn\", \"score_fn\", \"num_steps\", \"noise_std\"))\n", 796 | "def walk_jump_sampling(energy_fn_params: optax.Params, energy_fn: Callable[[optax.Params, chex.Array], float], score_fn_params: optax.Params, score_fn: Callable[[optax.Params, chex.Array], chex.Array], rng: chex.PRNGKey, delta: float, num_steps: int, noise_std: float):\n", 797 | " \"\"\"Performs walk-jump sampling.\"\"\"\n", 798 | " grad_log_prob_fn = create_grad_log_prob_fn(energy_fn_params, energy_fn)\n", 799 | " noisy_observations = langevin_sample(grad_log_prob_fn, init=jnp.zeros((1,)), delta=delta, rng=rng, num_steps=num_steps)\n", 800 | " scores = score_fn(score_fn_params, noisy_observations)\n", 801 | " denoised_observations = noisy_observations + noise_std ** 2 * scores\n", 802 | " return noisy_observations, denoised_observations\n", 803 | "\n", 804 | "\n", 805 | "walk_jump_sampling_rng, rng = jax.random.split(rng)\n", 806 | "noisy_observations, denoised_observations = walk_jump_sampling(energy_params, energy_fn, score_params, score_fn, rng=walk_jump_sampling_rng, delta=0.1, num_steps=1000, noise_std=noise_std)" 807 | ] 808 | }, 809 | { 810 | "cell_type": "code", 811 | "execution_count": null, 812 | "metadata": {}, 813 | "outputs": [], 814 | "source": [ 815 | "plt.hist(noisy_observations.flatten(), bins=20, density=True, alpha=0.5, 
color='C0', label='Noisy Observations')\n", 816 | "plt.legend()\n", 817 | "plt.grid()\n", 818 | "plt.title('Histogram of Noisy Observations')\n", 819 | "plt.show();" 820 | ] 821 | }, 822 | { 823 | "cell_type": "code", 824 | "execution_count": null, 825 | "metadata": {}, 826 | "outputs": [], 827 | "source": [ 828 | "plt.hist(denoised_observations.flatten(), bins=20, density=True, alpha=0.5, color='C1', label='Denoised Observations')\n", 829 | "plt.legend()\n", 830 | "plt.grid()\n", 831 | "plt.title('Histogram of Denoised Observations')\n", 832 | "plt.show();" 833 | ] 834 | } 835 | ], 836 | "metadata": { 837 | "kernelspec": { 838 | "display_name": ".venv", 839 | "language": "python", 840 | "name": "python3" 841 | }, 842 | "language_info": { 843 | "codemirror_mode": { 844 | "name": "ipython", 845 | "version": 3 846 | }, 847 | "file_extension": ".py", 848 | "mimetype": "text/x-python", 849 | "name": "python", 850 | "nbconvert_exporter": "python", 851 | "pygments_lexer": "ipython3", 852 | "version": "3.11.4" 853 | }, 854 | "orig_nbformat": 4 855 | }, 856 | "nbformat": 4, 857 | "nbformat_minor": 2 858 | } 859 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Copyright 2023 Siddharth Mishra-Sharma 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 
6 | 7 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Minified generative models 2 | 3 | [![License: MIT](https://img.shields.io/badge/License-MIT-red.svg)](https://opensource.org/licenses/MIT) 4 | 5 | Bare-bones, minified versions of some common (and not-so-common) generative models, for pedagogical purposes. 6 | 7 | ## Installation 8 | 9 | First, install JAX following these [instructions](https://jax.readthedocs.io/en/latest/installation.html). For CPU-only, this is as simple as: 10 | ```bash 11 | pip install "jax[cpu]" 12 | ``` 13 | Additional libraries: 14 | ```bash 15 | pip install flax optax diffrax tensorflow_probability scikit-learn tqdm matplotlib 16 | ``` 17 | 18 | ## List of notebooks 19 | 20 | 1. [β-VAEs](01_beta_vae.ipynb): Variational autoencoders and basic rate-distortion theory. 21 | 2. [Diffusion models](02_diffusion.ipynb): Diffusion models, covering likelihood-based and score-matching interpretations. 22 | 3. [Normalizing flows](03_normalizing_flows.ipynb) (WiP annotations): Normalizing flows, specifically [RealNVP](https://arxiv.org/abs/1605.08803). 23 | 4. [Continuous normalizing flows](04_continuous_normalizing_flows.ipynb): Continuous-time normalizing flows from, e.g., [Grathwohl et al 2018](https://arxiv.org/abs/1810.01367). 24 | 5.
[Consistency models](05_consistency_models.ipynb) (WiP annotations): Consistency models from [Song et al 2023](https://arxiv.org/abs/2303.01469). 25 | 6. [Flow matching](06_flow_matching.ipynb) (WiP annotations): From [Lipman et al 2022](https://arxiv.org/abs/2210.02747); see also [Albergo et al 2023](https://arxiv.org/abs/2303.08797). 26 | 7. [Diffusion distillation](07_diffusion_distillation.ipynb) (WiP): Progressive ([Salimans et al 2022](https://arxiv.org/abs/2202.00512)) and consistency ([Song et al 2023](https://arxiv.org/abs/2303.01469)) distillation. 27 | 8. [Discrete walk-jump sampling](08_discrete_walk_jump_sampling.ipynb) (WiP): From [Frey et al 2023](https://arxiv.org/abs/2306.12360). 28 | 29 | ## Inspiration 30 | 31 | ![assets/midwit.png](assets/midwit.png) 32 | -------------------------------------------------------------------------------- /assets/midwit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smsharma/minified-generative-models/3354e368f83f81694bb833ba4492fa4f93ae18b5/assets/midwit.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.0.0 2 | appnope==0.1.3 3 | asttokens==2.4.0 4 | backcall==0.2.0 5 | cached-property==1.5.2 6 | chex==0.1.7 7 | cloudpickle==3.0.0 8 | clu==0.0.10 9 | comm==0.1.4 10 | contextlib2==21.6.0 11 | contourpy==1.1.1 12 | cycler==0.12.1 13 | debugpy==1.8.0 14 | decorator==5.1.1 15 | diffrax==0.4.1 16 | dm-tree==0.1.8 17 | equinox==0.11.1 18 | etils==1.5.1 19 | executing==2.0.0 20 | flax==0.7.4 21 | fonttools==4.43.1 22 | fsspec==2023.9.2 23 | gast==0.5.4 24 | importlib-resources==6.1.0 25 | ipykernel==6.25.2 26 | ipython==8.16.1 27 | jax==0.4.14 28 | jaxlib==0.4.14 29 | jaxtyping==0.2.23 30 | jedi==0.19.1 31 | joblib==1.3.2 32 | jupyter_client==8.4.0 33 | jupyter_core==5.4.0 34 | kiwisolver==1.4.5 35 |
markdown-it-py==3.0.0 36 | matplotlib==3.8.0 37 | matplotlib-inline==0.1.6 38 | mdurl==0.1.2 39 | ml-collections==0.1.1 40 | ml-dtypes==0.3.1 41 | msgpack==1.0.7 42 | nest-asyncio==1.5.8 43 | numpy==1.26.1 44 | opt-einsum==3.3.0 45 | optax==0.1.7 46 | orbax-checkpoint==0.1.6 47 | packaging==23.2 48 | parso==0.8.3 49 | pexpect==4.8.0 50 | pickleshare==0.7.5 51 | Pillow==10.1.0 52 | platformdirs==3.11.0 53 | prompt-toolkit==3.0.39 54 | protobuf==4.24.4 55 | psutil==5.9.6 56 | ptyprocess==0.7.0 57 | pure-eval==0.2.2 58 | Pygments==2.16.1 59 | pyparsing==3.1.1 60 | python-dateutil==2.8.2 61 | PyYAML==6.0.1 62 | pyzmq==25.1.1 63 | rich==13.6.0 64 | scikit-learn==1.3.1 65 | scipy==1.11.3 66 | six==1.16.0 67 | stack-data==0.6.3 68 | tensorflow-probability==0.22.0 69 | tensorstore==0.1.45 70 | threadpoolctl==3.2.0 71 | toolz==0.12.0 72 | tornado==6.3.3 73 | tqdm==4.66.1 74 | traitlets==5.11.2 75 | typeguard==2.13.3 76 | typing_extensions==4.5.0 77 | wcwidth==0.2.8 78 | wrapt==1.15.0 79 | zipp==3.17.0 80 | --------------------------------------------------------------------------------