├── Bayesian CNN - Experiments ├── Effect of KL-Weight.ipynb └── bayesian_cnn_real_data.ipynb ├── Brief Introduction to Uncertainty └── Brief Introduction to Uncertainty - Medium.ipynb ├── README.md ├── Simple Fully Probabilistic Bayesian CNN └── HelloWorld_FullyProbabilistic_BayesianCNN.ipynb ├── Uncertainty in DL - Aleatoric Uncertainty └── Aleatoric Uncertainty.ipynb └── Uncertainty in DL - Epistemic Uncertainty └── Modeling Epistemic Uncertainty .ipynb /README.md: -------------------------------------------------------------------------------- 1 | # Medium_Notebooks_English 2 | 3 | This repo only contains the code used in the articles. 4 | 5 | ## Brief Introduction to Uncertainty 6 | https://towardsdatascience.com/uncertainty-in-deep-learning-brief-introduction-1f9a5de3ae04 7 | 8 | ## Uncertainty in DL - Aleatoric Uncertainty 9 | https://towardsdatascience.com/uncertainty-in-deep-learning-aleatoric-uncertainty-and-maximum-likelihood-estimation-c7449ee13712 10 | 11 | ## Uncertainty in DL - Epistemic Uncertainty 12 | https://towardsdatascience.com/uncertainty-in-deep-learning-epistemic-uncertainty-and-bayes-by-backprop-e6353eeadebb 13 | 14 | ## Uncertainty In Deep Learning — Bayesian CNN | TensorFlow Probability 15 | https://towardsdatascience.com/uncertainty-in-deep-learning-bayesian-cnn-tensorflow-probability-758d7482bef6 16 | -------------------------------------------------------------------------------- /Simple Fully Probabilistic Bayesian CNN/HelloWorld_FullyProbabilistic_BayesianCNN.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "f3b2807e", 6 | "metadata": { 7 | "id": "f3b2807e" 8 | }, 9 | "source": [ 10 | "# Imports" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 1, 16 | "id": "mHTvZop4bz-Q", 17 | "metadata": { 18 | "colab": { 19 | "base_uri": "https://localhost:8080/" 20 | }, 21 | "id": "mHTvZop4bz-Q", 22 | "outputId": 
"881ee4f8-a939-49dc-a520-8d140c7eb365" 23 | }, 24 | "outputs": [ 25 | { 26 | "name": "stdout", 27 | "output_type": "stream", 28 | "text": [ 29 | "Sat Mar 12 19:50:45 2022 \n", 30 | "+-----------------------------------------------------------------------------+\n", 31 | "| NVIDIA-SMI 460.32.03 Driver Version: 460.32.03 CUDA Version: 11.2 |\n", 32 | "|-------------------------------+----------------------+----------------------+\n", 33 | "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", 34 | "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", 35 | "| | | MIG M. |\n", 36 | "|===============================+======================+======================|\n", 37 | "| 0 Tesla P100-PCIE... Off | 00000000:00:04.0 Off | 0 |\n", 38 | "| N/A 57C P0 32W / 250W | 0MiB / 16280MiB | 0% Default |\n", 39 | "| | | N/A |\n", 40 | "+-------------------------------+----------------------+----------------------+\n", 41 | " \n", 42 | "+-----------------------------------------------------------------------------+\n", 43 | "| Processes: |\n", 44 | "| GPU GI CI PID Type Process name GPU Memory |\n", 45 | "| ID ID Usage |\n", 46 | "|=============================================================================|\n", 47 | "| No running processes found |\n", 48 | "+-----------------------------------------------------------------------------+\n" 49 | ] 50 | } 51 | ], 52 | "source": [ 53 | "!nvidia-smi" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 2, 59 | "id": "caed0563", 60 | "metadata": { 61 | "colab": { 62 | "base_uri": "https://localhost:8080/" 63 | }, 64 | "id": "caed0563", 65 | "outputId": "879d0bed-5b89-444c-98e6-b5fc4d821e0e" 66 | }, 67 | "outputs": [ 68 | { 69 | "data": { 70 | "text/plain": [ 71 | "('2.8.0', '0.16.0')" 72 | ] 73 | }, 74 | "execution_count": 2, 75 | "metadata": {}, 76 | "output_type": "execute_result" 77 | } 78 | ], 79 | "source": [ 80 | "import tensorflow as tf\n", 81 | "import tensorflow_datasets as 
tfds\n", 82 | "\n", 83 | "import tensorflow_probability as tfp\n", 84 | "\n", 85 | "tfd = tfp.distributions\n", 86 | "tfpl = tfp.layers\n", 87 | "\n", 88 | "import numpy as np\n", 89 | "import matplotlib.pyplot as plt\n", 90 | "\n", 91 | "tf.__version__, tfp.__version__" 92 | ] 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "id": "c581280d", 97 | "metadata": { 98 | "id": "c581280d" 99 | }, 100 | "source": [ 101 | "# Reparameterization Layers" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": 3, 107 | "id": "fef0862d", 108 | "metadata": { 109 | "id": "fef0862d" 110 | }, 111 | "outputs": [], 112 | "source": [ 113 | "# divergence_fn = lambda q, p, _: tfd.kl_divergence(q, p) / total_samples\n", 114 | "\n", 115 | "# tfpl.Convolution2DReparameterization(\n", 116 | "# input_shape = (128,6), \n", 117 | "# filters = 8, \n", 118 | "# kernel_size = 16,\n", 119 | "# activation = 'relu',\n", 120 | "# kernel_prior_fn = tfpl.default_multivariate_normal_fn,\n", 121 | "# kernel_posterior_fn = tfpl.default_mean_field_normal_fn(is_singular=False),\n", 122 | "# kernel_divergence_fn = divergence_fn,\n", 123 | "# bias_prior_fn = None,\n", 124 | "# bias_posterior_fn = tfpl.default_mean_field_normal_fn(is_singular=True),\n", 125 | "# bias_divergence_fn = divergence_fn)" 126 | ] 127 | }, 128 | { 129 | "cell_type": "markdown", 130 | "id": "p3-wVAnR7jDT", 131 | "metadata": { 132 | "id": "p3-wVAnR7jDT" 133 | }, 134 | "source": [ 135 | "## Load MNIST Data" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": 4, 141 | "id": "ae0bde05", 142 | "metadata": { 143 | "colab": { 144 | "base_uri": "https://localhost:8080/" 145 | }, 146 | "id": "ae0bde05", 147 | "outputId": "bc83c949-f1bc-4785-a450-e3b9441a36a3" 148 | }, 149 | "outputs": [ 150 | { 151 | "data": { 152 | "text/plain": [ 153 | "(TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32, name=None),\n", 154 | " TensorSpec(shape=(None, 10), dtype=tf.float32, name=None))" 155 | ] 156 | }, 157 | 
"execution_count": 4, 158 | "metadata": {}, 159 | "output_type": "execute_result" 160 | } 161 | ], 162 | "source": [ 163 | "train_ds, test_ds = tfds.load('mnist',\n", 164 | " split = ['train', 'test'],\n", 165 | " as_supervised = True)\n", 166 | "\n", 167 | "# Normalize and one-hot\n", 168 | "def ohe_normalize(images, labels):\n", 169 | " images = tf.cast(images, tf.float32)\n", 170 | " images = tf.divide(images, 255.0)\n", 171 | "\n", 172 | " labels = tf.one_hot(labels, 10)\n", 173 | "\n", 174 | " return images, labels\n", 175 | "\n", 176 | "train_ds = train_ds.batch(128).map(ohe_normalize).shuffle(128) \\\n", 177 | " .prefetch(tf.data.AUTOTUNE)\n", 178 | "test_ds = test_ds.batch(128).map(ohe_normalize) \\\n", 179 | " .prefetch(tf.data.AUTOTUNE)\n", 180 | "\n", 181 | "# Outputs a tuple of two elements, first --> Images, second --> Labels\n", 182 | "train_ds.element_spec " 183 | ] 184 | }, 185 | { 186 | "cell_type": "markdown", 187 | "id": "y70CWoDN9D9L", 188 | "metadata": { 189 | "id": "y70CWoDN9D9L" 190 | }, 191 | "source": [ 192 | "## First Normal CNN" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": null, 198 | "id": "89255ad4", 199 | "metadata": { 200 | "colab": { 201 | "base_uri": "https://localhost:8080/" 202 | }, 203 | "id": "89255ad4", 204 | "outputId": "b85b32d5-fcbb-4848-8227-cf7e83b9dbe5" 205 | }, 206 | "outputs": [], 207 | "source": [ 208 | "normal_cnn = tf.keras.Sequential([\n", 209 | " tf.keras.layers.Conv2D(16, 3, activation = 'swish', \n", 210 | " input_shape = (28, 28, 1),\n", 211 | " padding = 'same'),\n", 212 | " tf.keras.layers.MaxPooling2D(2),\n", 213 | " \n", 214 | " tf.keras.layers.Conv2D(32, 3, activation = 'swish',\n", 215 | " padding = 'same'),\n", 216 | " tf.keras.layers.MaxPooling2D(2),\n", 217 | "\n", 218 | " tf.keras.layers.Conv2D(64, 3, activation = 'swish',\n", 219 | " padding = 'same'),\n", 220 | " tf.keras.layers.MaxPooling2D(2),\n", 221 | "\n", 222 | " tf.keras.layers.Conv2D(128, 3, activation = 
'swish',\n", 223 | " padding = 'same'),\n", 224 | " tf.keras.layers.GlobalMaxPooling2D(),\n", 225 | " \n", 226 | " tf.keras.layers.Dense(10, activation = 'softmax')\n", 227 | "])\n", 228 | "\n", 229 | "normal_cnn.compile(optimizer = 'adam', loss = 'categorical_crossentropy',\n", 230 | " metrics = ['acc'])\n", 231 | "\n", 232 | "normal_cnn.summary() # Total params: 98,442" 233 | ] 234 | }, 235 | { 236 | "cell_type": "markdown", 237 | "id": "PVg7Wctu-DhJ", 238 | "metadata": { 239 | "id": "PVg7Wctu-DhJ" 240 | }, 241 | "source": [ 242 | "## Bayesian CNN" 243 | ] 244 | }, 245 | { 246 | "cell_type": "code", 247 | "execution_count": 6, 248 | "id": "2AyekIbf4D-c", 249 | "metadata": { 250 | "colab": { 251 | "base_uri": "https://localhost:8080/" 252 | }, 253 | "id": "2AyekIbf4D-c", 254 | "outputId": "44be2a02-f616-4b28-bb16-822b24dc9493" 255 | }, 256 | "outputs": [ 257 | { 258 | "name": "stdout", 259 | "output_type": "stream", 260 | "text": [ 261 | "reinterpreted_batch_ndims: 0:\n", 262 | "batch_shape: (4,) event_shape: () Sample shape: (4,)\n", 263 | "Samples: [-1.4488162 0.28503537 -2.6792827 1.6637591 ] \n", 264 | "\n", 265 | "reinterpreted_batch_ndims: 1:\n", 266 | "batch_shape: () event_shape: (4,) Sample shape: (4,)\n", 267 | "Samples: [-0.59052044 0.00105313 -0.59601015 0.643041 ] \n", 268 | "\n" 269 | ] 270 | } 271 | ], 272 | "source": [ 273 | "# Revise reinterpreted_batch_ndims arg. 
Needed for custom prior & posterior.\n", 274 | "shape = (4, )\n", 275 | "dtype = tf.float32\n", 276 | "distribution = tfd.Normal(loc = tf.zeros(shape, dtype),\n", 277 | " scale = tf.ones(shape, dtype))\n", 278 | "# batch_ndims = tf.size(distribution.batch_shape_tensor())\n", 279 | "\n", 280 | "for i in range(len(shape) + 1):\n", 281 | " print('reinterpreted_batch_ndims: %d:' %(i))\n", 282 | " independent_dist = tfd.Independent(distribution,\n", 283 | " reinterpreted_batch_ndims = i)\n", 284 | " samples = independent_dist.sample()\n", 285 | " print('batch_shape: {}' \n", 286 | " ' event_shape: {}' \n", 287 | " ' Sample shape: {}'.format(independent_dist._batch_shape(),\n", 288 | " independent_dist._event_shape(),\n", 289 | " samples.shape))\n", 290 | " print('Samples:', samples.numpy(), '\\n')" 291 | ] 292 | }, 293 | { 294 | "cell_type": "markdown", 295 | "id": "zMF0aS_rsVfW", 296 | "metadata": { 297 | "id": "zMF0aS_rsVfW" 298 | }, 299 | "source": [ 300 | "### How to Provide Custom Prior?" 301 | ] 302 | }, 303 | { 304 | "cell_type": "code", 305 | "execution_count": 7, 306 | "id": "e3631cc2", 307 | "metadata": { 308 | "id": "e3631cc2" 309 | }, 310 | "outputs": [], 311 | "source": [ 312 | "# For Reparameterization Layers\n", 313 | "def custom_mvn_prior(dtype, shape, name, trainable, add_variable_fn):\n", 314 | " distribution = tfd.Normal(loc = 0.1 * tf.ones(shape, dtype),\n", 315 | " scale = 0.003 * tf.ones(shape, dtype))\n", 316 | " batch_ndims = tf.size(distribution.batch_shape_tensor())\n", 317 | " \n", 318 | " independent_distribution = tfd.Independent(distribution,\n", 319 | " reinterpreted_batch_ndims = batch_ndims)\n", 320 | " return independent_distribution" 321 | ] 322 | }, 323 | { 324 | "cell_type": "markdown", 325 | "id": "wxFvgCGhsaXn", 326 | "metadata": { 327 | "id": "wxFvgCGhsaXn" 328 | }, 329 | "source": [ 330 | "### What if KL cannot be Computed Analytically?" 
331 | ] 332 | }, 333 | { 334 | "cell_type": "code", 335 | "execution_count": 8, 336 | "id": "PG-aAU7WFdmT", 337 | "metadata": { 338 | "id": "PG-aAU7WFdmT" 339 | }, 340 | "outputs": [], 341 | "source": [ 342 | "# The default posterior is Normal and if we use a laplace prior, we need to\n", 343 | "# approximate the KL. If we try to compute KL with that 2 distributions, we will get:\n", 344 | "# Error:\n", 345 | "# No KL(distribution_a || distribution_b) registered for distribution_a type Normal and distribution_b type Laplace\n", 346 | "\n", 347 | "# Call arguments received:\n", 348 | "# • inputs=tf.Tensor(shape=(None, 28, 28, 1), dtype=float32)\n", 349 | "\n", 350 | "def approximate_kl(q, p, q_tensor):\n", 351 | " return tf.reduce_mean(q.log_prob(q_tensor) - p.log_prob(q_tensor))\n", 352 | "\n", 353 | "total_samples = 60000\n", 354 | "divergence_fn = lambda q, p, q_tensor : approximate_kl(q, p, q_tensor) / total_samples" 355 | ] 356 | }, 357 | { 358 | "cell_type": "code", 359 | "execution_count": 9, 360 | "id": "DoR1GIWOZ46u", 361 | "metadata": { 362 | "id": "DoR1GIWOZ46u" 363 | }, 364 | "outputs": [], 365 | "source": [ 366 | "def conv_reparameterization_layer(filters, kernel_size, activation):\n", 367 | " # For simplicity, we use default prior and posterior.\n", 368 | " # In the next parts, we will use custom mixture prior and posteriors.\n", 369 | " return tfpl.Convolution2DReparameterization(\n", 370 | " filters = filters,\n", 371 | " kernel_size = kernel_size,\n", 372 | " activation = activation, \n", 373 | " padding = 'same',\n", 374 | " kernel_posterior_fn = tfpl.default_mean_field_normal_fn(is_singular=False),\n", 375 | " kernel_prior_fn = tfpl.default_multivariate_normal_fn,\n", 376 | " \n", 377 | " bias_prior_fn = tfpl.default_multivariate_normal_fn,\n", 378 | " bias_posterior_fn = tfpl.default_mean_field_normal_fn(is_singular=False),\n", 379 | " \n", 380 | " kernel_divergence_fn = divergence_fn,\n", 381 | " bias_divergence_fn = divergence_fn)" 382 | ] 383 
| }, 384 | { 385 | "cell_type": "markdown", 386 | "id": "117ed43a", 387 | "metadata": {}, 388 | "source": [ 389 | "### Create BCNN" 390 | ] 391 | }, 392 | { 393 | "cell_type": "code", 394 | "execution_count": 10, 395 | "id": "4WF_lDEcBOml", 396 | "metadata": { 397 | "colab": { 398 | "base_uri": "https://localhost:8080/" 399 | }, 400 | "id": "4WF_lDEcBOml", 401 | "outputId": "bee399a7-a51a-4efa-9cc5-866226583ab0" 402 | }, 403 | "outputs": [ 404 | { 405 | "name": "stderr", 406 | "output_type": "stream", 407 | "text": [ 408 | "/usr/local/lib/python3.7/dist-packages/tensorflow_probability/python/layers/util.py:102: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use `layer.add_weight` method instead.\n", 409 | " trainable=trainable)\n", 410 | "/usr/local/lib/python3.7/dist-packages/tensorflow_probability/python/layers/util.py:112: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use `layer.add_weight` method instead.\n", 411 | " trainable=trainable)\n" 412 | ] 413 | }, 414 | { 415 | "name": "stdout", 416 | "output_type": "stream", 417 | "text": [ 418 | "Model: \"sequential_1\"\n", 419 | "_________________________________________________________________\n", 420 | " Layer (type) Output Shape Param # \n", 421 | "=================================================================\n", 422 | " conv2d_reparameterization ( (None, 28, 28, 16) 320 \n", 423 | " Conv2DReparameterization) \n", 424 | " \n", 425 | " max_pooling2d_3 (MaxPooling (None, 14, 14, 16) 0 \n", 426 | " 2D) \n", 427 | " \n", 428 | " conv2d_reparameterization_1 (None, 14, 14, 32) 9280 \n", 429 | " (Conv2DReparameterization) \n", 430 | " \n", 431 | " max_pooling2d_4 (MaxPooling (None, 7, 7, 32) 0 \n", 432 | " 2D) \n", 433 | " \n", 434 | " conv2d_reparameterization_2 (None, 7, 7, 64) 36992 \n", 435 | " (Conv2DReparameterization) \n", 436 | " \n", 437 | " max_pooling2d_5 (MaxPooling (None, 3, 3, 64) 0 \n", 438 | " 2D) \n", 
439 | " \n", 440 | " conv2d_reparameterization_3 (None, 3, 3, 128) 147712 \n", 441 | " (Conv2DReparameterization) \n", 442 | " \n", 443 | " global_max_pooling2d_1 (Glo (None, 128) 0 \n", 444 | " balMaxPooling2D) \n", 445 | " \n", 446 | " dense_reparameterization (D (None, 10) 2580 \n", 447 | " enseReparameterization) \n", 448 | " \n", 449 | " one_hot_categorical (OneHot ((None, 10), 0 \n", 450 | " Categorical) (None, 10)) \n", 451 | " \n", 452 | "=================================================================\n", 453 | "Total params: 196,884\n", 454 | "Trainable params: 196,884\n", 455 | "Non-trainable params: 0\n", 456 | "_________________________________________________________________\n" 457 | ] 458 | } 459 | ], 460 | "source": [ 461 | "bayesian_cnn = tf.keras.Sequential([\n", 462 | " tf.keras.layers.InputLayer((28, 28, 1)),\n", 463 | " \n", 464 | " conv_reparameterization_layer(16, 3, 'swish'),\n", 465 | " tf.keras.layers.MaxPooling2D(2),\n", 466 | " \n", 467 | " conv_reparameterization_layer(32, 3, 'swish'),\n", 468 | " tf.keras.layers.MaxPooling2D(2),\n", 469 | "\n", 470 | " conv_reparameterization_layer(64, 3, 'swish'),\n", 471 | " tf.keras.layers.MaxPooling2D(2),\n", 472 | "\n", 473 | " conv_reparameterization_layer(128, 3, 'swish'),\n", 474 | " tf.keras.layers.GlobalMaxPooling2D(),\n", 475 | " \n", 476 | " tfpl.DenseReparameterization(\n", 477 | " units = tfpl.OneHotCategorical.params_size(10), activation = None,\n", 478 | " kernel_posterior_fn = tfpl.default_mean_field_normal_fn(is_singular=False),\n", 479 | " kernel_prior_fn = tfpl.default_multivariate_normal_fn,\n", 480 | " \n", 481 | " bias_prior_fn = tfpl.default_multivariate_normal_fn,\n", 482 | " bias_posterior_fn = tfpl.default_mean_field_normal_fn(is_singular=False),\n", 483 | " \n", 484 | " kernel_divergence_fn = divergence_fn,\n", 485 | " bias_divergence_fn = divergence_fn),\n", 486 | " tfpl.OneHotCategorical(10)\n", 487 | "])\n", 488 | "\n", 489 | "def nll(y_true, y_pred):\n", 490 | " return 
-y_pred.log_prob(y_true)\n", 491 | "\n", 492 | "bayesian_cnn.compile(loss=nll,\n", 493 | " optimizer=tf.keras.optimizers.Adam(0.001),\n", 494 | " metrics=['accuracy'])\n", 495 | "\n", 496 | "bayesian_cnn.summary()" 497 | ] 498 | }, 499 | { 500 | "cell_type": "code", 501 | "execution_count": 11, 502 | "id": "wagyyldNFJD_", 503 | "metadata": { 504 | "colab": { 505 | "base_uri": "https://localhost:8080/" 506 | }, 507 | "id": "wagyyldNFJD_", 508 | "outputId": "2b342338-4686-4af6-9976-0213e87ceac1" 509 | }, 510 | "outputs": [ 511 | { 512 | "name": "stdout", 513 | "output_type": "stream", 514 | "text": [ 515 | "Epoch 1/12\n", 516 | "469/469 [==============================] - 17s 18ms/step - loss: 4.7674 - accuracy: 0.6792 - val_loss: 4.1597 - val_accuracy: 0.9095\n", 517 | "Epoch 2/12\n", 518 | "469/469 [==============================] - 5s 11ms/step - loss: 4.0503 - accuracy: 0.9244 - val_loss: 3.9239 - val_accuracy: 0.9428\n", 519 | "Epoch 3/12\n", 520 | "469/469 [==============================] - 5s 11ms/step - loss: 3.8208 - accuracy: 0.9484 - val_loss: 3.7356 - val_accuracy: 0.9456\n", 521 | "Epoch 4/12\n", 522 | "469/469 [==============================] - 5s 11ms/step - loss: 3.6138 - accuracy: 0.9596 - val_loss: 3.5012 - val_accuracy: 0.9679\n", 523 | "Epoch 5/12\n", 524 | "469/469 [==============================] - 5s 11ms/step - loss: 3.4163 - accuracy: 0.9673 - val_loss: 3.3117 - val_accuracy: 0.9726\n", 525 | "Epoch 6/12\n", 526 | "469/469 [==============================] - 5s 11ms/step - loss: 3.2263 - accuracy: 0.9712 - val_loss: 3.1341 - val_accuracy: 0.9744\n", 527 | "Epoch 7/12\n", 528 | "469/469 [==============================] - 5s 11ms/step - loss: 3.0467 - accuracy: 0.9740 - val_loss: 2.9602 - val_accuracy: 0.9758\n", 529 | "Epoch 8/12\n", 530 | "469/469 [==============================] - 5s 11ms/step - loss: 2.8846 - accuracy: 0.9754 - val_loss: 2.8000 - val_accuracy: 0.9793\n", 531 | "Epoch 9/12\n", 532 | "469/469 [==============================] - 
6s 12ms/step - loss: 2.7258 - accuracy: 0.9773 - val_loss: 2.6527 - val_accuracy: 0.9770\n", 533 | "Epoch 10/12\n", 534 | "469/469 [==============================] - 5s 11ms/step - loss: 2.5832 - accuracy: 0.9777 - val_loss: 2.5219 - val_accuracy: 0.9793\n", 535 | "Epoch 11/12\n", 536 | "469/469 [==============================] - 5s 11ms/step - loss: 2.4495 - accuracy: 0.9786 - val_loss: 2.3883 - val_accuracy: 0.9798\n", 537 | "Epoch 12/12\n", 538 | "469/469 [==============================] - 5s 11ms/step - loss: 2.3281 - accuracy: 0.9791 - val_loss: 2.2741 - val_accuracy: 0.9795\n" 539 | ] 540 | }, 541 | { 542 | "data": { 543 | "text/plain": [ 544 | "" 545 | ] 546 | }, 547 | "execution_count": 11, 548 | "metadata": {}, 549 | "output_type": "execute_result" 550 | } 551 | ], 552 | "source": [ 553 | "bayesian_cnn.fit(train_ds, epochs = 12, validation_data = test_ds)" 554 | ] 555 | }, 556 | { 557 | "cell_type": "code", 558 | "execution_count": 12, 559 | "id": "-B_y_YNFcLMQ", 560 | "metadata": { 561 | "id": "-B_y_YNFcLMQ" 562 | }, 563 | "outputs": [], 564 | "source": [ 565 | "example_images = []\n", 566 | "example_labels = []\n", 567 | "\n", 568 | "for x, y in test_ds.take(10):\n", 569 | " example_images.append(x.numpy())\n", 570 | " example_labels.append(y.numpy())\n", 571 | "\n", 572 | "example_images = np.concatenate(example_images, axis = 0) \n", 573 | "example_labels = np.concatenate(example_labels, axis = 0) " 574 | ] 575 | }, 576 | { 577 | "cell_type": "code", 578 | "execution_count": 13, 579 | "id": "rgVSJidtGdVu", 580 | "metadata": { 581 | "id": "rgVSJidtGdVu" 582 | }, 583 | "outputs": [], 584 | "source": [ 585 | "def analyse_model_prediction(image, label = None, forward_passes = 10):\n", 586 | " if label is not None:\n", 587 | " label = np.argmax(label, axis = -1)\n", 588 | " \n", 589 | " extracted_probabilities = np.empty(shape=(forward_passes, 10))\n", 590 | " extracted_std = np.empty(shape=(forward_passes, 10))\n", 591 | " for i in 
range(forward_passes):\n", 592 | " model_output_distribution = bayesian_cnn(tf.expand_dims(image, \n", 593 | " axis = 0))\n", 594 | " extracted_probabilities[i] = model_output_distribution.mean().numpy().flatten()\n", 595 | " extracted_std[i] = model_output_distribution.stddev().numpy().flatten()\n", 596 | "\n", 597 | " fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(16, 6),\n", 598 | " gridspec_kw={'width_ratios': [2, 4]})\n", 599 | " plt.xticks(fontsize = 16, rotation = 45)\n", 600 | " plt.yticks(fontsize = 16)\n", 601 | "\n", 602 | " # Show the image and the true label if provided.\n", 603 | " ax1.imshow(image.squeeze(), cmap='gray')\n", 604 | " ax1.axis('off')\n", 605 | " if label is not None:\n", 606 | " ax1.set_title('True Label: {}'.format(str(label)), fontsize = 20)\n", 607 | " else:\n", 608 | " ax1.set_title('True Label Not Given', fontsize = 20)\n", 609 | " \n", 610 | " # Obtain the 95% prediction interval.\n", 611 | " # extracted_probabilities.shape = (forward_passes, 10)\n", 612 | " # So if we sample from the model 100 times, there will be 100 different\n", 613 | " # values for each of the 10 classes. \n", 614 | " # We get the interval for each of the classes independently.\n", 615 | " pct_2p5 = np.array([np.percentile(extracted_probabilities[:, i], \n", 616 | " 2.5) for i in range(10)])\n", 617 | " pct_97p5 = np.array([np.percentile(extracted_probabilities[:, i], \n", 618 | " 97.5) for i in range(10)]) \n", 619 | "\n", 620 | " # Std also contains 100 different values. 
We take median across the column\n", 621 | " # to obtain a single value for each of the class label.\n", 622 | " extracted_std = np.median(extracted_std, axis = 0)\n", 623 | " highest_var_label = np.argmax(extracted_std, axis = -1)\n", 624 | " if label is not None:\n", 625 | " print('Label %d has the highest std in this'\n", 626 | " ' prediction with the value %.3f' %(highest_var_label,\n", 627 | " extracted_std[highest_var_label]))\n", 628 | " else:\n", 629 | " print('Std Array:', extracted_std) \n", 630 | " \n", 631 | " bar = ax2.bar(np.arange(10), pct_97p5, color='red')\n", 632 | " if label is not None:\n", 633 | " bar[int(label)].set_color('green')\n", 634 | " \n", 635 | " ax2.bar(np.arange(10), pct_2p5-0.02, color='white', \n", 636 | " linewidth=4, edgecolor='white')\n", 637 | " ax2.set_xticks(np.arange(10))\n", 638 | " \n", 639 | " ax2.set_ylim([0, 1])\n", 640 | " ax2.set_ylabel('Probability', fontsize = 18)\n", 641 | " ax2.set_title(\"Model's Probabilities\", fontsize = 20)\n", 642 | " plt.show()" 643 | ] 644 | }, 645 | { 646 | "cell_type": "code", 647 | "execution_count": 18, 648 | "id": "ADKmTDUQsL55", 649 | "metadata": { 650 | "colab": { 651 | "base_uri": "https://localhost:8080/", 652 | "height": 420 653 | }, 654 | "id": "ADKmTDUQsL55", 655 | "outputId": "2cdf0bb3-4384-4c46-a10b-0cb21900b935" 656 | }, 657 | "outputs": [ 658 | { 659 | "name": "stdout", 660 | "output_type": "stream", 661 | "text": [ 662 | "Label 8 has the highest std in this prediction with the value 0.157\n" 663 | ] 664 | }, 665 | { 666 | "data": { 667 | "image/png": "\n", 668 | "text/plain": [ 669 | "
" 670 | ] 671 | }, 672 | "metadata": { 673 | "needs_background": "light" 674 | }, 675 | "output_type": "display_data" 676 | } 677 | ], 678 | "source": [ 679 | "analyse_model_prediction(example_images[284], example_labels[284])" 680 | ] 681 | }, 682 | { 683 | "cell_type": "code", 684 | "execution_count": 14, 685 | "id": "lDnT6Koxc2vE", 686 | "metadata": { 687 | "colab": { 688 | "base_uri": "https://localhost:8080/", 689 | "height": 420 690 | }, 691 | "id": "lDnT6Koxc2vE", 692 | "outputId": "f5b61032-98e2-4733-996f-6944c2e721fa" 693 | }, 694 | "outputs": [ 695 | { 696 | "name": "stdout", 697 | "output_type": "stream", 698 | "text": [ 699 | "Label 0 has the highest std in this prediction with the value 0.001\n" 700 | ] 701 | }, 702 | { 703 | "data": { 704 | "image/png": "\n", 705 | "text/plain": [ 706 | "
" 707 | ] 708 | }, 709 | "metadata": { 710 | "needs_background": "light" 711 | }, 712 | "output_type": "display_data" 713 | } 714 | ], 715 | "source": [ 716 | "analyse_model_prediction(example_images[50], example_labels[50])" 717 | ] 718 | }, 719 | { 720 | "cell_type": "code", 721 | "execution_count": 15, 722 | "id": "o2J1vKp2c-io", 723 | "metadata": { 724 | "colab": { 725 | "base_uri": "https://localhost:8080/", 726 | "height": 420 727 | }, 728 | "id": "o2J1vKp2c-io", 729 | "outputId": "d5fe9a66-3880-468b-fbd2-82bddbea559c" 730 | }, 731 | "outputs": [ 732 | { 733 | "name": "stdout", 734 | "output_type": "stream", 735 | "text": [ 736 | "Label 0 has the highest std in this prediction with the value 0.278\n" 737 | ] 738 | }, 739 | { 740 | "data": { 741 | "image/png": "\n", 742 | "text/plain": [ 743 | "
" 744 | ] 745 | }, 746 | "metadata": { 747 | "needs_background": "light" 748 | }, 749 | "output_type": "display_data" 750 | } 751 | ], 752 | "source": [ 753 | "noise_vector = np.random.uniform(size = (28, 28, 1), low = 0, high = 0.5)\n", 754 | "noisy_image = np.clip(example_images[50] + noise_vector, 0, 1)\n", 755 | "analyse_model_prediction(noisy_image, example_labels[50])" 756 | ] 757 | }, 758 | { 759 | "cell_type": "code", 760 | "execution_count": 16, 761 | "id": "4wtcYDJmoXYN", 762 | "metadata": { 763 | "colab": { 764 | "base_uri": "https://localhost:8080/", 765 | "height": 420 766 | }, 767 | "id": "4wtcYDJmoXYN", 768 | "outputId": "88ec4ab7-1b15-4e6f-fbb5-9f560e401e64" 769 | }, 770 | "outputs": [ 771 | { 772 | "name": "stdout", 773 | "output_type": "stream", 774 | "text": [ 775 | "Label 0 has the highest std in this prediction with the value 0.434\n" 776 | ] 777 | }, 778 | { 779 | "data": { 780 | "image/png": "\n", 781 | "text/plain": [ 782 | "
" 783 | ] 784 | }, 785 | "metadata": { 786 | "needs_background": "light" 787 | }, 788 | "output_type": "display_data" 789 | } 790 | ], 791 | "source": [ 792 | "noise_vector = np.random.uniform(size = (28, 28, 1), low = 0, high = 0.5)\n", 793 | "noisy_image = np.clip(example_images[50] + noise_vector*2, 0, 1)\n", 794 | "analyse_model_prediction(noisy_image, example_labels[50])" 795 | ] 796 | }, 797 | { 798 | "cell_type": "code", 799 | "execution_count": 17, 800 | "id": "cKDgiPHZotP1", 801 | "metadata": { 802 | "colab": { 803 | "base_uri": "https://localhost:8080/", 804 | "height": 438 805 | }, 806 | "id": "cKDgiPHZotP1", 807 | "outputId": "42e4fad2-9e9b-418e-d5a7-1227f391bc80" 808 | }, 809 | "outputs": [ 810 | { 811 | "name": "stdout", 812 | "output_type": "stream", 813 | "text": [ 814 | "Std Array: [0.27027504 0.22355586 0.19433676 0.08276099 0.1712302 0.14369398\n", 815 | " 0.31018993 0.13080781 0.47434729 0.18379491]\n" 816 | ] 817 | }, 818 | { 819 | "data": { 820 | "image/png": "\n", 821 | "text/plain": [ 822 | "
" 823 | ] 824 | }, 825 | "metadata": { 826 | "needs_background": "light" 827 | }, 828 | "output_type": "display_data" 829 | } 830 | ], 831 | "source": [ 832 | "analyse_model_prediction(np.random.uniform(size = (28, 28, 1), low = 0, \n", 833 | " high = 1))" 834 | ] 835 | }, 836 | { 837 | "cell_type": "code", 838 | "execution_count": 17, 839 | "id": "PJHtwIOAoyAW", 840 | "metadata": { 841 | "id": "PJHtwIOAoyAW" 842 | }, 843 | "outputs": [], 844 | "source": [] 845 | } 846 | ], 847 | "metadata": { 848 | "accelerator": "GPU", 849 | "colab": { 850 | "collapsed_sections": [], 851 | "machine_shape": "hm", 852 | "name": "Bayesian CNN.ipynb", 853 | "provenance": [] 854 | }, 855 | "kernelspec": { 856 | "display_name": "Python 3", 857 | "language": "python", 858 | "name": "python3" 859 | }, 860 | "language_info": { 861 | "codemirror_mode": { 862 | "name": "ipython", 863 | "version": 3 864 | }, 865 | "file_extension": ".py", 866 | "mimetype": "text/x-python", 867 | "name": "python", 868 | "nbconvert_exporter": "python", 869 | "pygments_lexer": "ipython3", 870 | "version": "3.8.8" 871 | } 872 | }, 873 | "nbformat": 4, 874 | "nbformat_minor": 5 875 | } 876 | --------------------------------------------------------------------------------