├── .gitignore ├── README.md ├── analyze_aleatoric_uncertainty.ipynb ├── batch_data └── .keep_me ├── bin ├── create_batch_data.py ├── download_model_info.py ├── predict.py └── train.py ├── blog_images ├── aleatoric_variance_loss_function_analysis.png ├── aleatoric_variance_loss_values.png ├── alex_kendall_uncertainty_types.jpg ├── augmented_vs_original_uncertainty.png ├── bayesian-deep-learning.jpg ├── blank-wall.jpg ├── catdog.png ├── catdog_just_cat.png ├── catdog_just_dog.png ├── change_logit_loss_analysis.png ├── elu.jpg ├── example_images.png ├── gammas.png ├── max_aleatoric_uncertainty_test.png ├── max_epistemic_uncertainty_test.png ├── semi-truck-glare.jpg ├── softmax_categorical_crossentropy_v_logit_difference.png ├── stanford_occlusions.png ├── test_first_second_rest_stats.png ├── test_stats.png └── thanks-for-all-the-fish.jpg ├── bnn ├── __init__.py ├── data.py ├── loss_equations.py ├── model.py ├── predict.py └── util.py ├── data └── .keep_me ├── hold_images ├── gamma_aleatoric_uncertainty.png └── gamma_prediction_score.png ├── medium-header.md ├── model_training_logs_resnet50_cifar10_256_201_100.csv ├── predictions └── .keep_me └── scratchwork.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | data/* 2 | batch_data/* 3 | predictions/* 4 | .DS_Store 5 | model.ckpt 6 | bnn/__pycache__ 7 | *.zip 8 | *.ckpt 9 | .ipynb_checkpoints/* 10 | 11 | !batch_data/.keep_me 12 | !data/.keep_me 13 | !predictions/.keep_me 14 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Building a Bayesian deep learning classifier 2 | 3 | [//]: # (Image References) 4 | [remoteimage1]: https://www.new-york-city-travel-tips.com/wp-content/uploads/2014/01/manhattanhenge-2-590x394.jpg "Under/over exposed example" 5 | [remoteimage2]: https://neil.fraser.name/writing/tank/tank-yes.jpg "Tank" 6 | 
[remoteimage3]: https://neil.fraser.name/writing/tank/tank-no.jpg "No tank" 7 | [remoteimage4]: https://cdn-images-1.medium.com/max/2000/1*m0T_vjg4mOJNIvel1JXGqQ.png "Kalman filter" 8 | [remoteimage5]: https://motivationdedication.files.wordpress.com/2013/03/workoverload.jpg?w=300&h=225 "Work overload" 9 | 10 | [image1]: https://github.com/kyle-dorman/bayesian-neural-network-blogpost/blob/master/blog_images/aleatoric_variance_loss_function_analysis.png "Aleatoric variance vs loss for different 'wrong' logit values" 11 | [image2]: https://github.com/kyle-dorman/bayesian-neural-network-blogpost/blob/master/blog_images/catdog.png "Ambiguity example" 12 | [image3]: https://github.com/kyle-dorman/bayesian-neural-network-blogpost/blob/master/blog_images/example_images.png "Example Cifar10 images" 13 | [image4]: https://github.com/kyle-dorman/bayesian-neural-network-blogpost/blob/master/blog_images/gammas.png "Example image with different gamma values" 14 | [image5]: https://github.com/kyle-dorman/bayesian-neural-network-blogpost/blob/master/blog_images/max_aleatoric_uncertainty_test.png "Max Aleatoric Uncertainty" 15 | [image6]: https://github.com/kyle-dorman/bayesian-neural-network-blogpost/blob/master/blog_images/max_epistemic_uncertainty_test.png "Max Epistemic Uncertainty" 16 | [image7]: https://github.com/kyle-dorman/bayesian-neural-network-blogpost/blob/master/blog_images/softmax_categorical_crossentropy_v_logit_difference.png "Softmax categorical cross entropy vs. 
logit difference" 17 | [image8]: https://github.com/kyle-dorman/bayesian-neural-network-blogpost/blob/master/blog_images/test_stats.png "Stats" 18 | [image9]: https://github.com/kyle-dorman/bayesian-neural-network-blogpost/blob/master/blog_images/test_first_second_rest_stats.png "Stats by correct label logit position" 19 | [image10]: https://github.com/kyle-dorman/bayesian-neural-network-blogpost/blob/master/blog_images/change_logit_loss_analysis.png "Change in logit loss" 20 | [image11]: https://github.com/kyle-dorman/bayesian-neural-network-blogpost/blob/master/blog_images/catdog_just_dog.png "Just dog" 21 | [image12]: https://github.com/kyle-dorman/bayesian-neural-network-blogpost/blob/master/blog_images/catdog_just_cat.png "Just cat" 22 | [image13]: https://github.com/kyle-dorman/bayesian-neural-network-blogpost/blob/master/blog_images/augmented_vs_original_uncertainty.png "Uncertainty: augmented vs original images" 23 | [image14]: https://github.com/kyle-dorman/bayesian-neural-network-blogpost/blob/master/blog_images/aleatoric_variance_loss_values.png "Minimum aleatoric variance and minimum loss for different incorrect logit values" 24 | [image15]: https://github.com/kyle-dorman/bayesian-neural-network-blogpost/blob/master/blog_images/elu.jpg "ELU activation function" 25 | [image16]: https://github.com/kyle-dorman/bayesian-neural-network-blogpost/blob/master/blog_images/blank-wall.jpg "Lack of visual features example" 26 | [image17]: https://github.com/kyle-dorman/bayesian-neural-network-blogpost/blob/master/blog_images/semi-truck-glare.jpg "Truck with glare" 27 | [image18]: https://github.com/kyle-dorman/bayesian-neural-network-blogpost/blob/master/blog_images/alex_kendall_uncertainty_types.jpg "Segmentation uncertainty" 28 | [image19]: https://github.com/kyle-dorman/bayesian-neural-network-blogpost/blob/master/blog_images/stanford_occlusions.png "Occlusion example" 29 | [image20]: 
https://github.com/kyle-dorman/bayesian-neural-network-blogpost/blob/master/blog_images/thanks-for-all-the-fish.jpg "Thanks for all the fish" 30 | [image21]: https://github.com/kyle-dorman/bayesian-neural-network-blogpost/blob/master/blog_images/bayesian-deep-learning.jpg "Bayesian deep learning" 31 | 32 | ### Intro 33 | In this blog post, I am going to teach you how to train a Bayesian deep learning classifier using [Keras](https://keras.io/) and [tensorflow](https://www.tensorflow.org/). Before diving into the specific training example, I will cover a few important high level concepts: 34 | 1. What is Bayesian deep learning? 35 | 2. What is uncertainty? 36 | 3. Why is uncertainty important? 37 | 38 | I will then cover two techniques for including uncertainty in a deep learning model and will go over a specific example using Keras to train fully connected layers over a frozen [ResNet50](https://arxiv.org/abs/1512.03385) encoder on the [cifar10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset. With this example, I will also discuss methods of exploring the uncertainty predictions of a Bayesian deep learning classifier and provide suggestions for improving the model in the future. 39 | 40 | ### Acknowledgments 41 | This post is based on material from two blog posts ([here](http://alexgkendall.com/computer_vision/bayesian_deep_learning_for_safe_ai/) and [here](http://mlg.eng.cam.ac.uk/yarin/blog_3d801aa532c1ce.html)) and a [white paper](https://arxiv.org/pdf/1703.04977.pdf) on Bayesian deep learning from the University of Cambridge machine learning group. If you want to learn more about Bayesian deep learning after reading this post, I encourage you to check out all three of these resources. Thank you to the University of Cambridge machine learning group for your amazing blog posts and papers. 42 | 43 | ### What is Bayesian deep learning? 
44 | Bayesian statistics is a theory in the field of statistics in which the evidence about the true state of the world is expressed in terms of degrees of belief. The combination of Bayesian statistics and deep learning in practice means including uncertainty in your deep learning model predictions. The idea of including uncertainty in neural networks was proposed as early as [1991](http://papers.nips.cc/paper/419-transforming-neural-net-output-levels-to-probability-distributions.pdf). Put simply, Bayesian deep learning adds a prior distribution over each weight and bias parameter found in a typical neural network model. In the past, Bayesian deep learning models were not used very often because they require more parameters to optimize, which can make the models difficult to work with. However, more recently, Bayesian deep learning has become more popular, and new techniques are being developed to include uncertainty in a model while using the same number of parameters as a traditional model. 45 | 46 | ![alt image][image21] 47 | > Visualizing a Bayesian deep learning model. 48 | 49 | ### What is [uncertainty](https://en.wikipedia.org/wiki/Uncertainty)? 50 | Uncertainty is the state of having limited knowledge where it is impossible to exactly describe the existing state, a future outcome, or more than one possible outcome. As it pertains to deep learning and classification, uncertainty also includes ambiguity: uncertainty about human definitions and concepts, not an objective fact of nature. 51 | 52 | ![alt image][image2] 53 | > An example of ambiguity. What should the model predict? 54 | 55 | ### Types of uncertainty 56 | There are several different types of uncertainty, and I will only cover two important types in this post. 57 | 58 | #### Aleatoric uncertainty 59 | Aleatoric uncertainty measures what you can't understand from the data. It can be explained away with the ability to observe all explanatory variables with increased precision.
Think of aleatoric uncertainty as sensing uncertainty. There are actually two types of aleatoric uncertainty, heteroscedastic and homoscedastic, but I am only covering heteroscedastic uncertainty in this post. Homoscedastic uncertainty is covered in more depth in [this](http://alexgkendall.com/computer_vision/bayesian_deep_learning_for_safe_ai/) blog post. 60 | 61 | Concrete examples of aleatoric uncertainty in stereo imagery are occlusions (parts of the scene a camera can't see), lack of visual features (e.g. a blank wall), or over/under exposed areas (glare & shading). 62 | 63 | ![alt image][image19] 64 | > Occlusions example 65 | 66 | ![alt image][image16] 67 | > Lack of visual features example 68 | 69 | ![alt image][remoteimage1] 70 | > Under/over exposed example 71 | 72 | #### Epistemic uncertainty 73 | Epistemic uncertainty measures what your model doesn't know due to lack of training data. It can be explained away with infinite training data. Think of epistemic uncertainty as model uncertainty. 74 | 75 | An easy way to observe epistemic uncertainty in action is to train one model on 25% of your dataset and to train a second model on the entire dataset. The model trained on only 25% of the dataset will have higher average epistemic uncertainty than the model trained on the entire dataset because it has seen fewer examples. 76 | 77 | A fun example of epistemic uncertainty was uncovered in the now-famous Not Hotdog app. From my own experiences with the app, the model performs very well. But upon closer inspection, it seems like the network was never trained on "not hotdog" images that included ketchup on the item in the image. So if the model is shown a picture of your leg with ketchup on it, the model is fooled into thinking it is a hotdog. A Bayesian deep learning model would predict high epistemic uncertainty in these situations. 78 |
79 | 80 | ### Why is uncertainty important? 81 | In machine learning, we are trying to create approximate representations of the real world. Popular deep learning models created today produce a point estimate but not an uncertainty value. Understanding whether your model is under-confident or falsely over-confident can help you reason about your model and your dataset. The two types of uncertainty explained above are important for different reasons. 82 | 83 | Note: In a classification problem, the softmax output gives you a probability value for each class, but this is not the same as uncertainty. The softmax probability is the probability that an input is a given class relative to the other classes. Because the probability is relative to the other classes, it does not help explain the model’s overall confidence. 84 | 85 | #### Why is Aleatoric uncertainty important? 86 | Aleatoric uncertainty is important in cases where parts of the observation space have higher noise levels than others. For example, aleatoric uncertainty played a role in the first fatality involving a self-driving car. Tesla has said that during this incident, the car's autopilot failed to recognize the white truck against a bright sky. An image segmentation classifier that is able to predict aleatoric uncertainty would recognize that this particular area of the image was difficult to interpret and would predict a high uncertainty. In the case of the Tesla incident, although the car's radar could "see" the truck, the radar data was inconsistent with the image classifier data and the car's path planner ultimately ignored the radar data (radar data is known to be noisy). If the image classifier had included a high uncertainty with its prediction, the path planner would have known to ignore the image classifier prediction and use the radar data instead (this is oversimplified, but it is effectively what would happen; see Kalman filters below).
87 | 88 | ![alt image][image17] 89 | > Even for a human, driving when roads have lots of glare is difficult 90 | 91 | #### Why is Epistemic uncertainty important? 92 | Epistemic uncertainty is important because it identifies situations the model was never trained to understand because the situations were not in the training data. Machine learning engineers hope our models generalize well to situations that are different from the training data; however, in safety-critical applications of deep learning, hope is not enough. High epistemic uncertainty is a red flag that a model is much more likely to make inaccurate predictions, and when this occurs in safety-critical applications, the model should not be trusted. 93 | 94 | Epistemic uncertainty is also helpful for exploring your dataset. For example, epistemic uncertainty would have been helpful with [this](https://neil.fraser.name/writing/tank/) particular neural network mishap from the 1980s. In this case, researchers trained a neural network to recognize tanks hidden in trees versus trees without tanks. After training, the network performed incredibly well on the training set and the test set. The only problem was that all of the images of the tanks were taken on cloudy days and all of the images without tanks were taken on sunny days. The classifier had actually learned to identify sunny versus cloudy days. Whoops. 95 | 96 | ![alt image][remoteimage2] ![alt image][remoteimage3] 97 | > Tank & cloudy vs no tank & sunny 98 | 99 | Uncertainty predictions in deep learning models are also important in robotics. I am currently enrolled in the Udacity self-driving car nanodegree and have been learning about techniques cars/robots use to recognize and track objects around them. Self-driving cars use a powerful technique called [Kalman filters](https://en.wikipedia.org/wiki/Kalman_filter) to track objects.
Kalman filters combine a series of measurement data containing statistical noise and produce estimates that tend to be more accurate than any single measurement. Traditional deep learning models are not able to contribute to Kalman filters because they only predict an outcome and do not include an uncertainty term. In theory, Bayesian deep learning models could contribute to Kalman filter tracking. 100 | 101 | ![alt image][remoteimage4] 102 | > Radar and lidar data merged into the Kalman filter. Image data could be incorporated as well. 103 | 104 | ### Calculating uncertainty in deep learning classification models 105 | Aleatoric and epistemic uncertainty are different and, as such, they are calculated differently. 106 | 107 | #### Calculating aleatoric uncertainty 108 | Aleatoric uncertainty is a function of the input data. Therefore, a deep learning model can learn to predict aleatoric uncertainty by using a modified loss function. For a classification task, instead of only predicting the softmax values, the Bayesian deep learning model will have two outputs, the softmax values and the input variance. Teaching the model to predict aleatoric variance is an example of unsupervised learning because the model doesn't have variance labels to learn from. Below is the standard categorical cross entropy loss function and a function to calculate the Bayesian categorical cross entropy loss. 109 | 110 | ```python 111 | import numpy as np 112 | from keras import backend as K 113 | from tensorflow.contrib import distributions 114 | 115 | # standard categorical cross entropy 116 | # N data points, C classes 117 | # true - true values. Shape: (N, C) 118 | # pred - predicted values. Shape: (N, C) 119 | # returns - loss (N) 120 | def categorical_cross_entropy(true, pred): 121 | return np.sum(true * np.log(pred), axis=1) 122 | 123 | # Bayesian categorical cross entropy. 124 | # N data points, C classes, T monte carlo simulations 125 | # true - true values. 
Shape: (N, C) 126 | # pred_var - predicted logit values and variance. Shape: (N, C + 1) 127 | # returns - loss (N,) 128 | def bayesian_categorical_crossentropy(T, num_classes): 129 | def bayesian_categorical_crossentropy_internal(true, pred_var): 130 | # shape: (N,) 131 | std = K.sqrt(pred_var[:, num_classes:]) 132 | # shape: (N,) 133 | variance = pred_var[:, num_classes] 134 | variance_depressor = K.exp(variance) - K.ones_like(variance) 135 | # shape: (N, C) 136 | pred = pred_var[:, 0:num_classes] 137 | # shape: (N,) 138 | undistorted_loss = K.categorical_crossentropy(pred, true, from_logits=True) 139 | # shape: (T,) 140 | iterable = K.variable(np.ones(T)) 141 | dist = distributions.Normal(loc=K.zeros_like(std), scale=std) 142 | monte_carlo_results = K.map_fn(gaussian_categorical_crossentropy(true, pred, dist, undistorted_loss, num_classes), iterable, name='monte_carlo_results') 143 | 144 | variance_loss = K.mean(monte_carlo_results, axis=0) * undistorted_loss 145 | 146 | return variance_loss + undistorted_loss + variance_depressor 147 | 148 | return bayesian_categorical_crossentropy_internal 149 | 150 | # for a single monte carlo simulation, 151 | # calculate categorical_crossentropy of 152 | # predicted logit values plus gaussian 153 | # noise vs true values. 154 | # true - true values. Shape: (N, C) 155 | # pred - predicted logit values. Shape: (N, C) 156 | # dist - normal distribution to sample from. Shape: (N, C) 157 | # undistorted_loss - the crossentropy loss without variance distortion. Shape: (N,) 158 | # num_classes - the number of classes. 
C 159 | # returns - total differences for all classes (N,) 160 | def gaussian_categorical_crossentropy(true, pred, dist, undistorted_loss, num_classes): 161 | def map_fn(i): 162 | std_samples = K.transpose(dist.sample(num_classes)) 163 | distorted_loss = K.categorical_crossentropy(pred + std_samples, true, from_logits=True) 164 | diff = undistorted_loss - distorted_loss 165 | return -K.elu(diff) 166 | return map_fn 167 | ``` 168 | The loss function I created is based on the loss function in [this](https://arxiv.org/pdf/1703.04977.pdf) paper. In the paper, the loss function creates a normal distribution with a mean of zero and the predicted variance. It distorts the predicted logit values by sampling from the distribution and computes the softmax categorical cross entropy using the distorted predictions. The loss function runs T Monte Carlo samples and then takes the average of the T samples as the loss. 169 | 170 | ![alt image][image7] 171 | > Figure 1: Softmax categorical cross entropy vs. logit difference for binary classification 172 | 173 | In Figure 1, the y axis is the softmax categorical cross entropy. The x axis is the difference between the 'right' logit value and the 'wrong' logit value. 'right' means the correct class for this prediction. 'wrong' means the incorrect class for this prediction. I will use the term 'logit difference' to mean the x axis of Figure 1. When the 'logit difference' is positive in Figure 1, the softmax prediction will be correct. When 'logit difference' is negative, the prediction will be incorrect. I will continue to use the terms 'logit difference', 'right' logit, and 'wrong' logit this way as I explain the aleatoric loss function. 174 | 175 | Figure 1 is helpful for understanding the results of the normal distribution distortion. 
When the logit values (in a binary classification) are distorted using a normal distribution, the distortion is effectively creating a normal distribution with a mean of the original predicted 'logit difference' and a variance proportional to the predicted variance (twice the predicted variance, because each logit is distorted by its own independent noise sample). Applying softmax cross entropy to the distorted logit values is the same as sampling along the line in Figure 1 for a 'logit difference' value. 176 | 177 | Taking the categorical cross entropy of the distorted logits should ideally result in a few interesting properties. 178 | 1. When the 'right' logit value is much larger than any other logit value (the right half of Figure 1), increasing the variance should only increase the loss. This is true because the curve flattens out (it is convex) on the right half of the graph, i.e. increasing the 'logit difference' results in only a slightly smaller decrease in softmax categorical cross entropy compared to the increase caused by an equal decrease in 'logit difference'. The minimum loss should be close to 0 in this case. 179 | 2. When the 'wrong' logit is much larger than the 'right' logit (the left half of the graph) and the variance is ~0, the loss should be ~`wrong_logit-right_logit`. You can see this on the left half of Figure 1. When the 'logit difference' is -4, the softmax cross entropy is 4. The slope on this part of the graph is ~ -1, so this should remain true as the 'logit difference' continues to decrease. 180 | 3. To enable the model to learn aleatoric uncertainty, when the 'wrong' logit value is greater than the 'right' logit value (the left half of the graph), the loss function should be minimized for a variance value greater than 0. For an image that has high aleatoric uncertainty (i.e. it is difficult for the model to make an accurate prediction on this image), this feature encourages the model to find a local loss minimum during training by increasing its predicted variance.
181 | 182 | I was able to use the loss function suggested in the paper to decrease the loss when the 'wrong' logit value is greater than the 'right' logit value by increasing the variance, but the decrease in loss due to increasing the variance was extremely small (<0.1). During training, my model had a hard time picking up on this slight local minimum and the aleatoric variance predictions from my model did not make sense. I believe this happens because the slope of Figure 1 on the left half of the graph is ~ -1. Sampling a normal distribution along a line with a slope of -1 will result in another normal distribution and the mean will be about the same as it was before but what we want is for the mean of the T samples to decrease as the variance increases. 183 | 184 | To make the model easier to train, I wanted to create a more significant loss change as the variance increases. Just like in the paper, my loss function above distorts the logits for T Monte Carlo samples using a normal distribution with a mean of 0 and the predicted variance and then computes the categorical cross entropy for each sample. To get a more significant loss change as the variance increases, the loss function needed to weight the Monte Carlo samples where the loss decreased more than the samples where the loss increased. My solution is to use the [elu](http://image-net.org/challenges/posters/JKU_EN_RGB_Schwarz_poster.pdf) activation function, which is a non-linear function centered around 0. 185 | 186 | ![alt image][image15] 187 | > ELU activation function 188 | 189 | I applied the elu function to the change in categorical cross entropy, i.e. the original undistorted loss compared to the distorted loss, `undistorted_loss - distorted_loss`. The elu shifts the mean of the normal distribution away from zero for the left half of Figure 1. The elu is also ~linear for very small values near 0 so the mean for the right half of Figure 1 stays the same. 
190 | 191 | ![alt image][image10] 192 | > Figure 2: Average change in loss & distorted average change in loss. 193 | 194 | In Figure 2, `right < wrong` corresponds to a point on the left half of Figure 1 and `wrong < right` corresponds to a point on the right half of Figure 1. You can see that the distribution of outcomes from the 'wrong' logit case looks similar to a normal distribution and the 'right' case is mostly small values near zero. After applying `-elu` to the change in loss, the mean of the `right < wrong` case becomes much larger. In this example, it changes from -0.16 to 0.25. The mean of the `wrong < right` case stays about the same. I call the mean of the lower graphs in Figure 2 the 'distorted average change in loss'. The 'distorted average change in loss' should stay near 0 as the variance increases on the right half of Figure 1 and should always increase when the variance increases on the left half of Figure 1. 195 | 196 | I then scaled the 'distorted average change in loss' by the original undistorted categorical cross entropy. This is done because the distorted average change in loss for the wrong logit case is about the same for all logit differences below -3 (because the slope of the line is roughly constant there). To ensure the loss is greater than zero, I add the undistorted categorical cross entropy. The ‘distorted average change in loss’ always decreases as the variance increases, but the loss function should be minimized for a variance value less than infinity. To ensure the variance that minimizes the loss is less than infinity, I add the exponential of the variance term. As Figure 3 shows, the exponential of the variance is the dominant characteristic after the variance passes 2.
197 | 198 | ![alt image][image1] 199 | > Figure 3: Aleatoric variance vs loss for different 'wrong' logit values 200 | 201 | ![alt image][image14] 202 | > Figure 4: Minimum aleatoric variance and minimum loss for different 'wrong' logit values 203 | 204 | These are the results of calculating the above loss function for a binary classification example where the 'right' logit value is held constant at 1.0 and the 'wrong' logit value changes for each line. When the 'wrong' logit value is less than 1.0 (and thus less than the 'right' logit value), the variance that minimizes the loss is 0.0. As the 'wrong' logit value increases, the variance that minimizes the loss increases. 205 | 206 | Note: When generating these graphs, I ran 10,000 Monte Carlo simulations to create smooth lines. When training the model, I only ran 100 Monte Carlo simulations, as this should be sufficient to get a reasonable mean. 207 | 208 | ![alt image][remoteimage5] 209 | > Brain overload? Grab a time-appropriate beverage before continuing. 210 | 211 | #### Calculating epistemic uncertainty 212 | One way of modeling epistemic uncertainty is using Monte Carlo dropout sampling (a type of variational inference) at test time. For a full explanation of why dropout can model uncertainty, check out [this](http://mlg.eng.cam.ac.uk/yarin/blog_3d801aa532c1ce.html) blog and [this](https://arxiv.org/pdf/1703.04977.pdf) white paper. In practice, Monte Carlo dropout sampling means including dropout in your model and running your model multiple times with dropout turned on at test time to create a distribution of outcomes. You can then calculate the predictive entropy (the average amount of information contained in the predictive distribution). 213 | 214 | To understand using dropout to calculate epistemic uncertainty, think about splitting the cat-dog image above in half vertically. 215 | 216 | ![alt image][image11] ![alt image][image12] 217 | 218 | If you saw the left half, you would predict dog.
If you saw the right half, you would predict cat. A perfect 50-50 split. This image would have high epistemic uncertainty because the image exhibits features that you associate with both a cat class and a dog class. 219 | 220 | Below are two ways of calculating epistemic uncertainty. They do the exact same thing, but the first is simpler and only uses numpy. The second uses additional Keras layers (and gets GPU acceleration) to make the predictions. 221 | 222 | ```python 223 | # model - the trained classifier(C classes) 224 | # where the last layer applies softmax 225 | # X_data - a list of input data(size N) 226 | # T - the number of monte carlo simulations to run 227 | def montecarlo_prediction(model, X_data, T): 228 | # shape: (T, N, C) 229 | predictions = np.array([model.predict(X_data) for _ in range(T)]) 230 | 231 | # shape: (N, C) 232 | prediction_probabilities = np.mean(predictions, axis=0) 233 | 234 | # shape: (N) 235 | prediction_variances = np.apply_along_axis(predictive_entropy, axis=1, arr=prediction_probabilities) 236 | return (prediction_probabilities, prediction_variances) 237 | 238 | # prob - prediction probability for each class(C). Shape: (N, C) 239 | # returns - Shape: (N) 240 | def predictive_entropy(prob): 241 | return -1 * np.sum(np.log(prob) * prob, axis=1) 242 | ``` 243 | 244 | ```python 245 | from keras.models import Model 246 | from keras.layers import Input, RepeatVector 247 | from keras.engine.topology import Layer 248 | from keras.layers.wrappers import TimeDistributed 249 | 250 | # Take a mean of the results of a TimeDistributed layer. 251 | # Applying TimeDistributedMean()(TimeDistributed(T)(x)) to an 252 | # input of shape (None, ...) returns output of same size. 253 | class TimeDistributedMean(Layer): 254 | def build(self, input_shape): 255 | super(TimeDistributedMean, self).build(input_shape) 256 | 257 | # input shape (None, T, ...) 258 | # output shape (None, ...)
259 | def compute_output_shape(self, input_shape): 260 | return (input_shape[0],) + input_shape[2:] 261 | 262 | def call(self, x): 263 | return K.mean(x, axis=1) 264 | 265 | 266 | # Apply the predictive entropy function for input with C classes. 267 | # Input of shape (None, C, ...) returns output with shape (None, ...) 268 | # Input should be predictive means for the C classes. 269 | # In the case of a single classification, output will be (None,). 270 | class PredictiveEntropy(Layer): 271 | def build(self, input_shape): 272 | super(PredictiveEntropy, self).build(input_shape) 273 | 274 | # input shape (None, C, ...) 275 | # output shape (None, ...) 276 | def compute_output_shape(self, input_shape): 277 | return (input_shape[0],) 278 | 279 | # x - prediction probability for each class(C) 280 | def call(self, x): 281 | return -1 * K.sum(K.log(x) * x, axis=1) 282 | 283 | 284 | def create_epistemic_uncertainty_model(checkpoint, epistemic_monte_carlo_simulations): 285 | model = load_saved_model(checkpoint) 286 | inpt = Input(shape=(model.input_shape[1:])) 287 | x = RepeatVector(epistemic_monte_carlo_simulations)(inpt) 288 | # Keras TimeDistributed can only handle a single output from a model :( 289 | # and we technically only need the softmax outputs. 290 | hacked_model = Model(inputs=model.inputs, outputs=model.outputs[1]) 291 | x = TimeDistributed(hacked_model, name='epistemic_monte_carlo')(x) 292 | # predictive probabilities for each class 293 | softmax_mean = TimeDistributedMean(name='epistemic_softmax_mean')(x) 294 | variance = PredictiveEntropy(name='epistemic_variance')(softmax_mean) 295 | epistemic_model = Model(inputs=inpt, outputs=[variance, softmax_mean]) 296 | 297 | return epistemic_model 298 | 299 | # 1. Load the model 300 | # 2. compile the model 301 | # 3. Set learning phase to train 302 | # 4. predict 303 | def predict(): 304 | model = create_epistemic_uncertainty_model('model.ckpt', 100) 305 | model.compile(...) 
306 | 307 | # set learning phase to 1 so that Dropout is on. In keras master you can set this 308 | # on the TimeDistributed layer 309 | K.set_learning_phase(1) 310 | 311 | epistemic_predictions = model.predict(data) 312 | 313 | ``` 314 | Note: Epistemic uncertainty is not used to train the model. It is only calculated at test time (but with the Keras learning phase set to training, so that dropout stays active) when evaluating test/real-world examples. This is different from aleatoric uncertainty, which is predicted as part of the training process. Also, in my experience, it is easier to produce reasonable epistemic uncertainty predictions than aleatoric uncertainty predictions. 315 | 316 | ### Training a Bayesian deep learning classifier 317 | Besides the code above, training a Bayesian deep learning classifier to predict uncertainty doesn't require much additional code beyond what is typically used to train a classifier. 318 | 319 | ```python 320 | def resnet50(input_shape): 321 | input_tensor = Input(shape=input_shape) 322 | base_model = ResNet50(include_top=False, input_tensor=input_tensor) 323 | # freeze encoder layers to prevent over fitting 324 | for layer in base_model.layers: 325 | layer.trainable = False 326 | 327 | output_tensor = Flatten()(base_model.output) 328 | return Model(inputs=input_tensor, outputs=output_tensor) 329 | ``` 330 | For this experiment, I used the frozen convolutional layers from ResNet50 with the weights for [ImageNet](http://www.image-net.org/) to encode the images. I initially attempted to train the model without freezing the convolutional layers, but found the model quickly became overfit.
331 | 332 | ```python 333 | def create_bayesian_model(encoder, input_shape, output_classes): 334 | encoder_model = resnet50(input_shape) 335 | input_tensor = Input(shape=encoder_model.output_shape[1:]) 336 | x = BatchNormalization(name='post_encoder')(input_tensor) 337 | x = Dropout(0.5)(x) 338 | x = Dense(500, activation='relu')(x) 339 | x = BatchNormalization()(x) 340 | x = Dropout(0.5)(x) 341 | x = Dense(100, activation='relu')(x) 342 | x = BatchNormalization()(x) 343 | x = Dropout(0.5)(x) 344 | 345 | logits = Dense(output_classes)(x) 346 | variance_pre = Dense(1)(x) 347 | variance = Activation('softplus', name='variance')(variance_pre) 348 | logits_variance = concatenate([logits, variance], name='logits_variance') 349 | softmax_output = Activation('softmax', name='softmax_output')(logits) 350 | 351 | model = Model(inputs=input_tensor, outputs=[logits_variance,softmax_output]) 352 | 353 | return model 354 | ``` 355 | The trainable part of my model is two sets of `BatchNormalization`, `Dropout`, `Dense`, and `relu` layers on top of the ResNet50 output. The logits and variance are calculated using separate `Dense` layers. Note that the variance layer applies a `softplus` activation function to ensure the model always predicts variance values greater than zero. The logit and variance layers are then recombined for the aleatoric loss function and the softmax is calculated using just the logit layer. 356 | 357 | ```python 358 | model.compile( 359 | optimizer=Adam(lr=1e-3, decay=0.001), 360 | loss={ 361 | 'logits_variance': bayesian_categorical_crossentropy(100, 10), 362 | 'softmax_output': 'categorical_crossentropy' 363 | }, 364 | metrics={'softmax_output': metrics.categorical_accuracy}, 365 | loss_weights={'logits_variance': .2, 'softmax_output': 1.}) 366 | ``` 367 | I trained the model using two losses, one is the aleatoric uncertainty loss function and the other is the standard categorical cross entropy function. 
This allows the last `Dense` layer, which creates the logits, to learn only how to produce better logit values while the `Dense` layer that creates the variance learns only about predicting variance. The two prior `Dense` layers will train on both of these losses. The aleatoric uncertainty loss function is weighted less than the categorical cross entropy loss because the aleatoric uncertainty loss includes the categorical cross entropy loss as one of its terms. 368 | 369 | I used 100 Monte Carlo simulations for calculating the Bayesian loss function. It took about 70 seconds per epoch. I found increasing the number of Monte Carlo simulations from 100 to 1,000 added about four minutes to each training epoch. 370 | 371 | I added augmented data to the training set by randomly applying a gamma value of 0.5 or 2.0 to decrease or increase the brightness of each image. In practice, I found the cifar10 dataset did not have many images that would in theory exhibit high aleatoric uncertainty. This is probably by design. By adding images with adjusted gamma values to images in the training set, I am attempting to give the model more images that should have high aleatoric uncertainty. 372 | 373 | 374 | ![alt image][image4] 375 | > Example image with gamma value distortion. A gamma of 1.0 means no distortion. 376 | 377 | Unfortunately, predicting epistemic uncertainty takes a considerable amount of time. It takes about 2-3 seconds on my Mac CPU for the fully connected layers to predict all 50,000 images in the training set but over five minutes for the epistemic uncertainty predictions. This isn't that surprising because epistemic uncertainty requires running Monte Carlo simulations on each image. I ran 100 Monte Carlo simulations, so it is reasonable to expect the prediction process to take about 100 times longer to predict epistemic uncertainty than aleatoric uncertainty.
378 | 379 | Lastly, my [project](https://github.com/kyle-dorman/bayesian-neural-network-blogpost) is set up to easily switch out the underlying encoder network and train models for other datasets in the future. Feel free to play with it if you want a deeper dive into training your own Bayesian deep learning classifier. 380 | 381 | ### Results 382 | ![alt image][image3] 383 | > Example of each class in cifar10 384 | 385 | My model's categorical accuracy on the test dataset is 86.4%. This is not an amazing score by any means. I was able to produce scores higher than 93%, but only by sacrificing the accuracy of the aleatoric uncertainty. There are a few different hyperparameters I could play with to increase my score. I spent very little time tuning the weights of the two loss functions, and I suspect that changing these hyperparameters could greatly increase my model accuracy. I could also unfreeze the Resnet50 layers and train those as well. While getting better accuracy scores on this dataset is interesting, Bayesian deep learning is about both the predictions and the uncertainty estimates, so I will spend the rest of the post evaluating the validity of the uncertainty predictions of my model. 386 | 387 | ![alt image][image8] 388 | > Figure 5: uncertainty mean and standard deviation for test set 389 | 390 | The aleatoric uncertainty values tend to be much smaller than the epistemic uncertainty. These two values can't be compared directly on the same image. They can, however, be compared against the uncertainty values the model predicts for other images in this dataset. 391 | 392 | ![alt image][image9] 393 | > Figure 6: Uncertainty to relative rank of 'right' logit value. 394 | 395 | To further explore the uncertainty, I broke the test data into three groups based on the relative value of the correct logit. In Figure 6, 'first' includes all of the correct predictions (i.e., the logit value for the 'right' label was the largest value).
'second' includes all of the cases where the 'right' label is the second-largest logit value. 'rest' includes all of the other cases. 86.4% of the samples are in the 'first' group, 8.7% are in the 'second' group, and 4.9% are in the 'rest' group. Figure 6 shows the mean and standard deviation of the aleatoric and epistemic uncertainty for the test set broken out by these three groups. As I was hoping, the epistemic and aleatoric uncertainties are correlated with the relative rank of the 'right' logit. This indicates the model is more likely to identify incorrect labels as situations it is unsure about. Additionally, the model predicts greater-than-zero uncertainty even when its prediction is correct. I expected the model to exhibit this characteristic because the model can be uncertain even if its prediction is correct. 396 | 397 | ![alt image][image5] 398 | > Images with the highest aleatoric uncertainty 399 | 400 | ![alt image][image6] 401 | > Images with the highest epistemic uncertainty 402 | 403 | Above are the images with the highest aleatoric and epistemic uncertainty. While it is interesting to look at the images, it is not exactly clear to me why these images have high aleatoric or epistemic uncertainty. This is one downside to training an image classifier to produce uncertainty: the uncertainty for the entire image is reduced to a single value. It is oftentimes much easier to understand uncertainty in an image segmentation model because it is easier to compare the results for each pixel in an image. 404 | 405 | ![alt image][image18] 406 | > "Illustrating the difference between aleatoric and epistemic uncertainty for semantic segmentation. You can notice that aleatoric uncertainty captures object boundaries where labels are noisy. The bottom row shows a failure case of the segmentation model, when the model is unfamiliar with the footpath, and the corresponding increased epistemic uncertainty."
[link](http://alexgkendall.com/computer_vision/bayesian_deep_learning_for_safe_ai/) 407 | 408 | If my model understands aleatoric uncertainty well, my model should predict larger aleatoric uncertainty values for images with low contrast, high brightness/darkness, or high occlusions. To test this theory, I applied a range of gamma values to my test images to increase/decrease the pixel intensity and predicted outcomes for the augmented images. 409 | 410 | ![alt image][image13] 411 | > Figure 7: 412 | > Left side: Images & uncertainties with gamma values applied. 413 | > Right side: Images & uncertainties of original image. 414 | 415 | The model's accuracy on the augmented images is 5.5%. This means the gamma images completely tricked my model. The model wasn't trained to score well on these gamma distortions, so that is to be expected. Figure 7 shows the predicted uncertainty for eight of the augmented images on the left and eight original uncertainties and images on the right. The first four images have the highest predicted aleatoric uncertainty of the augmented images and the last four have the lowest aleatoric uncertainty of the augmented images. I am excited to see that the model predicts higher aleatoric and epistemic uncertainties for each augmented image compared with the original image! The aleatoric uncertainty should be larger because the mock adverse lighting conditions make the images harder to understand, and the epistemic uncertainty should be larger because the model has not been trained on images with larger gamma distortions. 416 | 417 | ### Next Steps 418 | The model detailed in this post explores only the tip of the Bayesian deep learning iceberg, and going forward there are several ways in which I believe I could improve the model's predictions.
For example, I could continue to play with the loss weights and unfreeze the Resnet50 convolutional layers to see if I can get a better accuracy score without losing the uncertainty characteristics detailed above. I could also try training a model on a dataset that has more images that exhibit high aleatoric uncertainty. One candidate is the [German Traffic Sign Recognition Benchmark](http://benchmark.ini.rub.de/?section=gtsrb&subsection=dataset) dataset, which I've worked with in one of my Udacity projects. This dataset is specifically meant to make the classifier "cope with large variations in visual appearances due to illumination changes, partial occlusions, rotations, weather conditions". Sounds like aleatoric uncertainty to me! 419 | 420 | In addition to trying to improve my model, I could also explore my trained model further. One approach would be to see how my model handles adversarial examples. To do this, I could use a library like [CleverHans](https://github.com/tensorflow/cleverhans), created by Ian Goodfellow and Nicolas Papernot. This library provides reference implementations of adversarial-example attacks to help explore model vulnerabilities. It would be interesting to see if adversarial examples produced by CleverHans also result in high uncertainties. 421 | 422 | Another library I am excited to explore is Edward, a Python library for probabilistic modeling, inference, and criticism. [Edward](http://edwardlib.org/) supports the creation of network layers with probability distributions and makes it easy to perform variational inference. [This](https://alpha-i.co/blog/MNIST-for-ML-beginners-The-Bayesian-Way.html) blog post uses Edward to train a Bayesian deep learning classifier on the MNIST dataset. 423 | 424 | If you've made it this far, I am very impressed and appreciative. Hopefully this post has inspired you to include uncertainty in your next deep learning project.
425 | 426 | ![alt image][image20] 427 | -------------------------------------------------------------------------------- /batch_data/.keep_me: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyle-dorman/bayesian-neural-network-blogpost/5fc764b18e4826b90b1082e5142fdbbb23b0395d/batch_data/.keep_me -------------------------------------------------------------------------------- /bin/create_batch_data.py: -------------------------------------------------------------------------------- 1 | #!/bin/python 2 | 3 | import os 4 | import sys 5 | 6 | project_path, x = os.path.split(os.path.dirname(os.path.realpath(__file__))) 7 | sys.path.append(project_path) 8 | 9 | import tensorflow as tf 10 | 11 | from bnn.model import create_encoder_model, encoder_min_input_size 12 | from bnn.util import isAWS, upload_s3, stop_instance, save_pickle_file, BatchConfig, full_path 13 | from bnn.data import test_train_data 14 | from math import ceil 15 | 16 | flags = tf.app.flags 17 | FLAGS = flags.FLAGS 18 | 19 | flags.DEFINE_string('dataset', 'cifar10', 'The dataset to train the model on.') 20 | flags.DEFINE_string('encoder', 'resnet50', 'The encoder model to train from.') 21 | flags.DEFINE_integer('batch_size', 32, 'The batch size for the generator') 22 | flags.DEFINE_boolean('debug', False, 'If this is for debugging the model/training process or not.') 23 | flags.DEFINE_integer('verbose', 0, 'Whether to use verbose logging when constructing the data object.') 24 | flags.DEFINE_boolean('augment', False, 'Whether to add augmented data to the initial data.') 25 | flags.DEFINE_boolean('stop', True, 'Stop aws instance after finished running.') 26 | 27 | def main(_): 28 | config = BatchConfig(FLAGS.encoder, FLAGS.dataset) 29 | config.info() 30 | 31 | if os.path.exists(full_path(config.batch_folder())) == False: 32 | os.makedirs(full_path(config.batch_folder)) 33 | 34 | min_image_size = encoder_min_input_size(FLAGS.encoder) 35 | 36 
| ((x_train, y_train), (x_test, y_test)) = test_train_data(FLAGS.dataset, min_image_size, FLAGS.debug, 37 | augment_data=FLAGS.augment, batch_size=FLAGS.batch_size) 38 | 39 | input_shape = list(min_image_size) 40 | input_shape.append(3) 41 | 42 | encoder = create_encoder_model(FLAGS.encoder, input_shape) 43 | 44 | print("Compiling model.") 45 | encoder.compile(optimizer='sgd', loss='mean_squared_error') 46 | 47 | print("Encoding training data.") 48 | x_train_encoded = encoder.predict_generator(x_train, 49 | int(ceil(len(y_train)/FLAGS.batch_size)), 50 | verbose=FLAGS.verbose) 51 | 52 | print("Encoding test data.") 53 | x_test_encoded = encoder.predict_generator(x_test, 54 | int(ceil(len(y_test)/FLAGS.batch_size)), 55 | verbose=FLAGS.verbose) 56 | 57 | print("Finished encoding data.") 58 | 59 | if FLAGS.augment: 60 | train_file_name = "/augment-train.p" 61 | test_file_name = "/augment-test.p" 62 | else: 63 | train_file_name = "/train.p" 64 | test_file_name = "/test.p" 65 | 66 | train_file = config.batch_folder() + train_file_name 67 | test_file = config.batch_folder() + test_file_name 68 | save_pickle_file(train_file, (x_train_encoded, y_train)) 69 | save_pickle_file(test_file, (x_test_encoded, y_test)) 70 | 71 | if isAWS() and FLAGS.debug == False: 72 | upload_s3(train_file) 73 | upload_s3(test_file) 74 | 75 | if isAWS() and FLAGS.stop: 76 | stop_instance() 77 | 78 | 79 | if __name__ == '__main__': 80 | tf.app.run() 81 | -------------------------------------------------------------------------------- /bin/download_model_info.py: -------------------------------------------------------------------------------- 1 | #!/bin/python 2 | 3 | import os 4 | import sys 5 | 6 | project_path, x = os.path.split(os.path.dirname(os.path.realpath(__file__))) 7 | sys.path.append(project_path) 8 | 9 | import tensorflow as tf 10 | 11 | from bnn.util import BayesianConfig, BatchConfig, download_s3 12 | 13 | flags = tf.app.flags 14 | FLAGS = flags.FLAGS 15 | 16 | 
flags.DEFINE_string('dataset', 'cifar10', 'The dataset to train the model on.') 17 | flags.DEFINE_string('encoder', 'resnet50', 'The encoder model to train from.') 18 | flags.DEFINE_integer('epochs', 1, 'Number of training examples.') 19 | flags.DEFINE_integer('monte_carlo_simulations', 100, 'The number of monte carlo simulations to run for the aleatoric categorical crossentroy loss function.') 20 | flags.DEFINE_integer('batch_size', 32, 'The batch size for the generator') 21 | 22 | def main(_): # Fetch the encoded batch pickles, prediction results, model file and CSV training log via download_s3. 23 | bayesian_config = BayesianConfig(FLAGS.encoder, FLAGS.dataset, FLAGS.batch_size, FLAGS.epochs, FLAGS.monte_carlo_simulations) 24 | bayesian_config.info() 25 | 26 | batch_config = BatchConfig(FLAGS.encoder, FLAGS.dataset) 27 | batch_config.info() 28 | 29 | print("Downloading model info") 30 | 31 | download_s3(batch_config.batch_folder()+"/train.p") 32 | download_s3(batch_config.batch_folder()+"/test.p") 33 | download_s3(batch_config.batch_folder()+"/augment-train.p") 34 | download_s3(batch_config.batch_folder()+"/augment-test.p") 35 | download_s3(batch_config.predictions_folder()+"/results.p") # NOTE(review): bin/predict.py saves results.p under "predictions/<encoder>_<dataset>"; confirm predictions_folder() points at the same folder 36 | download_s3(bayesian_config.model_file()) 37 | download_s3(bayesian_config.csv_log_file()) 38 | 39 | print("Done downloading model info") 40 | 41 | if __name__ == '__main__': 42 | tf.app.run() 43 | -------------------------------------------------------------------------------- /bin/predict.py: -------------------------------------------------------------------------------- 1 | #!/bin/python 2 | 3 | import os 4 | import sys 5 | 6 | project_path, x = os.path.split(os.path.dirname(os.path.realpath(__file__))) 7 | sys.path.append(project_path) 8 | 9 | import tensorflow as tf 10 | from bnn.predict import predict 11 | from bnn.util import save_pickle_file, full_path, isAWS, upload_s3 12 | 13 | flags = tf.app.flags 14 | FLAGS = flags.FLAGS 15 | 16 | flags.DEFINE_string('dataset', 'cifar10', 'The dataset to train the model on.') 17 | flags.DEFINE_string('encoder', 'resnet50', 'The
encoder model to train from.') 18 | flags.DEFINE_integer('model_epochs', 1, 'Number of training examples for the saved model.') 19 | flags.DEFINE_integer('train_monte_carlo_simulations', 100, 'The number of monte carlo simulations to run for the aleatoric categorical crossentroy loss function.') 20 | flags.DEFINE_integer('epistemic_monte_carlo_simulations', 100, 'The number of monte carlo simulations to run for the epistemic uncertainty calculation.') 21 | flags.DEFINE_integer('model_batch_size', 32, 'The batch size for the saved model.') 22 | flags.DEFINE_integer('batch_size', 32, 'The batch size for evaluating model.') 23 | flags.DEFINE_integer('verbose', 0, 'Whether to use verbose logging when constructing the data object.') 24 | flags.DEFINE_boolean('debug', False, 'If this is for debugging the model/training process or not.') 25 | flags.DEFINE_boolean('full_model', False, 'Whether to load the end to end model or just the dense layers.') 26 | 27 | def main(_): # Predict on the train and test sets; unless debugging, pickle (train_results, test_results) to results.p and upload it when running on AWS. 28 | (train_results, test_results) = predict(FLAGS.batch_size, 29 | FLAGS.verbose, FLAGS.epistemic_monte_carlo_simulations, 30 | FLAGS.debug, FLAGS.full_model, 31 | FLAGS.encoder, FLAGS.dataset, FLAGS.model_batch_size, 32 | FLAGS.model_epochs, FLAGS.train_monte_carlo_simulations) 33 | 34 | print("Done predicting test & train results.") 35 | 36 | if FLAGS.debug == False: 37 | folder = "predictions/{}_{}".format(FLAGS.encoder, FLAGS.dataset) 38 | if os.path.isdir(full_path(folder)) == False: 39 | os.mkdir(full_path(folder)) # NOTE(review): os.mkdir does not create a missing parent "predictions" folder; os.makedirs would 40 | 41 | save_pickle_file(folder + "/results.p", (train_results, test_results)) 42 | 43 | if isAWS() and FLAGS.debug == False: 44 | upload_s3(folder + "/results.p") 45 | 46 | if __name__ == '__main__': 47 | tf.app.run() 48 | -------------------------------------------------------------------------------- /bin/train.py: -------------------------------------------------------------------------------- 1 | #!/bin/python 2 | 3 | import os 4 | import sys 5 | 6 | project_path, x =
os.path.split(os.path.dirname(os.path.realpath(__file__))) 7 | sys.path.append(project_path) 8 | 9 | import tensorflow as tf 10 | from keras.optimizers import Adam 11 | from keras.callbacks import ModelCheckpoint, EarlyStopping, CSVLogger 12 | from keras import metrics 13 | import numpy as np 14 | 15 | from bnn.model import create_bayesian_model, encoder_min_input_size 16 | from bnn.loss_equations import bayesian_categorical_crossentropy 17 | from bnn.util import isAWS, upload_s3, stop_instance, BayesianConfig 18 | from bnn.data import test_train_batch_data 19 | 20 | flags = tf.app.flags 21 | FLAGS = flags.FLAGS 22 | 23 | flags.DEFINE_string('dataset', 'cifar10', 'The dataset to train the model on.') 24 | flags.DEFINE_string('encoder', 'resnet50', 'The encoder model to train from.') 25 | flags.DEFINE_integer('epochs', 1, 'Number of training examples.') 26 | flags.DEFINE_integer('monte_carlo_simulations', 100, 'The number of monte carlo simulations to run for the aleatoric categorical crossentroy loss function.') 27 | flags.DEFINE_integer('batch_size', 32, 'The batch size for the generator') 28 | flags.DEFINE_boolean('debug', False, 'If this is for debugging the model/training process or not.') 29 | flags.DEFINE_integer('verbose', 0, 'Whether to use verbose logging when constructing the data object.') 30 | flags.DEFINE_boolean('stop', True, 'Stop aws instance after finished running.') 31 | flags.DEFINE_float('min_delta', 0.005, 'Early stopping minimum change value.') 32 | flags.DEFINE_integer('patience', 20, 'Early stopping epochs patience to wait before stopping.') 33 | 34 | def main(_): # Train the Bayesian head on pre-encoded batch data using the aleatoric (logits_variance) and softmax cross-entropy losses. 35 | config = BayesianConfig(FLAGS.encoder, FLAGS.dataset, FLAGS.batch_size, FLAGS.epochs, FLAGS.monte_carlo_simulations) 36 | config.info() 37 | 38 | min_image_size = encoder_min_input_size(FLAGS.encoder) 39 | 40 | ((x_train, y_train), (x_test, y_test)) = test_train_batch_data(FLAGS.dataset, FLAGS.encoder, FLAGS.debug, augment_data=True) # always reads the augment-*.p pickles, which bin/create_batch_data.py writes only when its augment flag is set 41 | 42 | min_image_size =
list(min_image_size) 43 | min_image_size.append(3) 44 | num_classes = y_train.shape[-1] # labels are stored one-hot, so the last axis is the class count 45 | 46 | model = create_bayesian_model(FLAGS.encoder, min_image_size, num_classes) 47 | 48 | if FLAGS.debug: 49 | print(model.summary()) 50 | callbacks = None 51 | else: 52 | callbacks = [ 53 | ModelCheckpoint(config.model_file(), verbose=FLAGS.verbose, save_best_only=True), 54 | CSVLogger(config.csv_log_file()), 55 | EarlyStopping(monitor='val_logits_variance_loss', min_delta=FLAGS.min_delta, patience=FLAGS.patience, verbose=1) 56 | ] 57 | 58 | print("Compiling model.") 59 | model.compile( 60 | optimizer=Adam(lr=1e-3, decay=0.001), 61 | loss={ 62 | 'logits_variance': bayesian_categorical_crossentropy(FLAGS.monte_carlo_simulations, num_classes), 63 | 'softmax_output': 'categorical_crossentropy' 64 | }, 65 | metrics={'softmax_output': metrics.categorical_accuracy}, 66 | loss_weights={'logits_variance': .2, 'softmax_output': 1.}) 67 | 68 | print("Starting model train process.") 69 | model.fit(x_train, 70 | {'logits_variance':y_train, 'softmax_output':y_train}, 71 | callbacks=callbacks, 72 | verbose=FLAGS.verbose, 73 | epochs=FLAGS.epochs, 74 | batch_size=FLAGS.batch_size, 75 | validation_data=(x_test, {'logits_variance':y_test, 'softmax_output':y_test})) 76 | 77 | print("Finished training model.") 78 | 79 | if isAWS() and FLAGS.debug == False: 80 | upload_s3(config.model_file()) 81 | upload_s3(config.csv_log_file()) 82 | 83 | if isAWS() and FLAGS.stop: 84 | stop_instance() 85 | 86 | 87 | if __name__ == '__main__': 88 | tf.app.run() 89 | -------------------------------------------------------------------------------- /blog_images/aleatoric_variance_loss_function_analysis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyle-dorman/bayesian-neural-network-blogpost/5fc764b18e4826b90b1082e5142fdbbb23b0395d/blog_images/aleatoric_variance_loss_function_analysis.png
-------------------------------------------------------------------------------- /blog_images/aleatoric_variance_loss_values.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyle-dorman/bayesian-neural-network-blogpost/5fc764b18e4826b90b1082e5142fdbbb23b0395d/blog_images/aleatoric_variance_loss_values.png -------------------------------------------------------------------------------- /blog_images/alex_kendall_uncertainty_types.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyle-dorman/bayesian-neural-network-blogpost/5fc764b18e4826b90b1082e5142fdbbb23b0395d/blog_images/alex_kendall_uncertainty_types.jpg -------------------------------------------------------------------------------- /blog_images/augmented_vs_original_uncertainty.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyle-dorman/bayesian-neural-network-blogpost/5fc764b18e4826b90b1082e5142fdbbb23b0395d/blog_images/augmented_vs_original_uncertainty.png -------------------------------------------------------------------------------- /blog_images/bayesian-deep-learning.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyle-dorman/bayesian-neural-network-blogpost/5fc764b18e4826b90b1082e5142fdbbb23b0395d/blog_images/bayesian-deep-learning.jpg -------------------------------------------------------------------------------- /blog_images/blank-wall.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyle-dorman/bayesian-neural-network-blogpost/5fc764b18e4826b90b1082e5142fdbbb23b0395d/blog_images/blank-wall.jpg -------------------------------------------------------------------------------- /blog_images/catdog.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyle-dorman/bayesian-neural-network-blogpost/5fc764b18e4826b90b1082e5142fdbbb23b0395d/blog_images/catdog.png -------------------------------------------------------------------------------- /blog_images/catdog_just_cat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyle-dorman/bayesian-neural-network-blogpost/5fc764b18e4826b90b1082e5142fdbbb23b0395d/blog_images/catdog_just_cat.png -------------------------------------------------------------------------------- /blog_images/catdog_just_dog.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyle-dorman/bayesian-neural-network-blogpost/5fc764b18e4826b90b1082e5142fdbbb23b0395d/blog_images/catdog_just_dog.png -------------------------------------------------------------------------------- /blog_images/change_logit_loss_analysis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyle-dorman/bayesian-neural-network-blogpost/5fc764b18e4826b90b1082e5142fdbbb23b0395d/blog_images/change_logit_loss_analysis.png -------------------------------------------------------------------------------- /blog_images/elu.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyle-dorman/bayesian-neural-network-blogpost/5fc764b18e4826b90b1082e5142fdbbb23b0395d/blog_images/elu.jpg -------------------------------------------------------------------------------- /blog_images/example_images.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyle-dorman/bayesian-neural-network-blogpost/5fc764b18e4826b90b1082e5142fdbbb23b0395d/blog_images/example_images.png 
-------------------------------------------------------------------------------- /blog_images/gammas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyle-dorman/bayesian-neural-network-blogpost/5fc764b18e4826b90b1082e5142fdbbb23b0395d/blog_images/gammas.png -------------------------------------------------------------------------------- /blog_images/max_aleatoric_uncertainty_test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyle-dorman/bayesian-neural-network-blogpost/5fc764b18e4826b90b1082e5142fdbbb23b0395d/blog_images/max_aleatoric_uncertainty_test.png -------------------------------------------------------------------------------- /blog_images/max_epistemic_uncertainty_test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyle-dorman/bayesian-neural-network-blogpost/5fc764b18e4826b90b1082e5142fdbbb23b0395d/blog_images/max_epistemic_uncertainty_test.png -------------------------------------------------------------------------------- /blog_images/semi-truck-glare.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyle-dorman/bayesian-neural-network-blogpost/5fc764b18e4826b90b1082e5142fdbbb23b0395d/blog_images/semi-truck-glare.jpg -------------------------------------------------------------------------------- /blog_images/softmax_categorical_crossentropy_v_logit_difference.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyle-dorman/bayesian-neural-network-blogpost/5fc764b18e4826b90b1082e5142fdbbb23b0395d/blog_images/softmax_categorical_crossentropy_v_logit_difference.png -------------------------------------------------------------------------------- /blog_images/stanford_occlusions.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyle-dorman/bayesian-neural-network-blogpost/5fc764b18e4826b90b1082e5142fdbbb23b0395d/blog_images/stanford_occlusions.png -------------------------------------------------------------------------------- /blog_images/test_first_second_rest_stats.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyle-dorman/bayesian-neural-network-blogpost/5fc764b18e4826b90b1082e5142fdbbb23b0395d/blog_images/test_first_second_rest_stats.png -------------------------------------------------------------------------------- /blog_images/test_stats.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyle-dorman/bayesian-neural-network-blogpost/5fc764b18e4826b90b1082e5142fdbbb23b0395d/blog_images/test_stats.png -------------------------------------------------------------------------------- /blog_images/thanks-for-all-the-fish.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyle-dorman/bayesian-neural-network-blogpost/5fc764b18e4826b90b1082e5142fdbbb23b0395d/blog_images/thanks-for-all-the-fish.jpg -------------------------------------------------------------------------------- /bnn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyle-dorman/bayesian-neural-network-blogpost/5fc764b18e4826b90b1082e5142fdbbb23b0395d/bnn/__init__.py -------------------------------------------------------------------------------- /bnn/data.py: -------------------------------------------------------------------------------- 1 | #!/bin/python 2 | 3 | from bnn.util import open_pickle_file, download_file, unzip_data, BatchConfig 4 | from keras.datasets import cifar10 5 | from keras.applications.resnet50 import 
preprocess_input  # completes the "from keras.applications.resnet50 import" on the line above
import numpy as np
import cv2
import random


def get_traffic_sign_data():
    """Download, unzip and unpickle the traffic-sign dataset.

    Returns:
        ((x_train, y_train), (x_test, y_test), (x_valid, y_valid)), built from
        the ``'features'`` / ``'labels'`` entries of the train/test/valid
        pickles extracted into ``data/traffic-sign``.
    """
    url = "https://d17h27t6h515a5.cloudfront.net/topher/2017/February/5898cd6f_traffic-signs-data/traffic-signs-data.zip"
    zip_file = "traffic-sign-data.zip"

    download_file(url, zip_file)
    unzip_data(zip_file, "data/traffic-sign")

    train = open_pickle_file("data/traffic-sign/train.p")
    test = open_pickle_file("data/traffic-sign/test.p")
    valid = open_pickle_file("data/traffic-sign/valid.p")

    return ((train['features'], train['labels']),
            (test['features'], test['labels']),
            (valid['features'], valid['labels']))


def test_train_data(dataset, min_image_size, is_debug, augment_data=True, batch_size=32):
    """Load raw images for ``dataset`` wrapped in resizing batch generators.

    Args:
        dataset: dataset name; only ``'cifar10'`` is supported.
        min_image_size: target size handed to ``cv2.resize`` for every image.
        is_debug: if true, keep only the first 128 train and test examples.
        augment_data: if true, append a gamma-adjusted copy of every training
            image (see ``augment_images``).
        batch_size: number of images each generator returns per ``next()``.

    Returns:
        ((x_train, y_train), (x_test, y_test)) where the x's are
        ``ResizeGenerator`` instances and the y's are one-hot label arrays.

    Raises:
        ValueError: for an unsupported dataset name.
    """
    if dataset == 'cifar10':
        (x_train, y_train), (x_test, y_test) = cifar10.load_data()

        if is_debug:
            x_train = x_train[0:128]
            x_test = x_test[0:128]
            y_train = y_train[0:128]
            y_test = y_test[0:128]

        if augment_data:
            augment_images_train, augment_labels_train = augment_images(x_train, y_train)

            x_train = np.concatenate([x_train, augment_images_train])
            y_train = np.concatenate([y_train, augment_labels_train])

        x_train = ResizeGenerator(x_train, batch_size, min_image_size)
        x_test = ResizeGenerator(x_test, batch_size, min_image_size)

        # The debug truncation already happened above, so is_debug=False here.
        (y_train, y_test) = clean_label_dataset(y_train, y_test, False)
        return ((x_train, y_train), (x_test, y_test))
    # todo: add more datasets
    else:
        raise ValueError("Unexpected dataset " + dataset + ".")


def test_train_batch_data(dataset, encoder, is_debug, augment_data=False):
    """Load pre-computed ``(features, labels)`` batch files for ``dataset``.

    Args:
        dataset: dataset name; only ``'cifar10'`` is supported.
        encoder: encoder name used by ``BatchConfig`` to locate the batch folder.
        is_debug: if true, keep only the first 256 train and test examples.
        augment_data: if true, read ``augment-train.p`` / ``augment-test.p``
            instead of ``train.p`` / ``test.p``.

    Returns:
        ((x_train, y_train), (x_test, y_test)) as unpickled from the files.

    Raises:
        ValueError: for an unsupported dataset name.
    """
    if dataset == 'cifar10':
        if augment_data:
            train_file_name = "/augment-train.p"
            test_file_name = "/augment-test.p"
        else:
            train_file_name = "/train.p"
            test_file_name = "/test.p"
        config = BatchConfig(encoder, dataset)
        x_train, y_train = open_pickle_file(config.batch_folder() + train_file_name)
        x_test, y_test = open_pickle_file(config.batch_folder() + test_file_name)
        if is_debug:
            x_train = x_train[0:256]
            x_test = x_test[0:256]
            y_train = y_train[0:256]
            y_test = y_test[0:256]
        return ((x_train, y_train), (x_test, y_test))
    # todo: add more datasets
    else:
        raise ValueError("Unexpected dataset " + dataset + ".")


class ResizeGenerator():
    """Endless batch iterator that resizes and preprocesses images on demand.

    Each ``next()`` call takes the next ``batch_size`` items of ``data``,
    resizes them to ``image_size`` with ``cv2.resize``, converts them to a
    float64 array and applies ``preprocess_input``. After the last (possibly
    shorter) batch it wraps around to the start, so it never stops.
    """

    def __init__(self, data, batch_size, image_size):
        self.data = data
        self.batch_size = batch_size
        self.image_size = image_size
        self.index = 0

    def __iter__(self):
        # Make the object a proper iterator so iter() / next() both work on it.
        return self

    def __next__(self):
        return self.next()

    def next(self):
        start = self.index
        end = min(self.index + self.batch_size, len(self.data))
        resized = [cv2.resize(image, self.image_size) for image in self.data[start:end]]
        result = preprocess_input(np.array(resized, dtype=np.float64))

        if end == len(self.data):
            # Wrap around instead of raising StopIteration.
            self.index = 0
        else:
            self.index += self.batch_size

        return result


def clean_feature_dataset(x_train, x_test, min_image_size, is_debug):
    """Resize all train/test images to ``min_image_size`` and preprocess them.

    ``is_debug`` is accepted for signature compatibility but is not used.

    Returns:
        (x_train, x_test) as preprocessed float64 arrays.
    """
    print("Resizing images from", x_train.shape[1:-1], "to", min_image_size)
    x_train = np.array([cv2.resize(i, min_image_size) for i in x_train], dtype=np.float64)
    print("Done resizing train images.")
    x_test = np.array([cv2.resize(i, min_image_size) for i in x_test], dtype=np.float64)
    print("Done resizing test images.")
    return (preprocess_input(x_train), preprocess_input(x_test))


def augment_images(images, labels):
    """Randomly darken or brighten images with gamma to create bad examples.

    Every image is assigned at random to one of a fixed set of gammas (fixed
    so each gamma's lookup table is built only once, which speeds things up)
    and adjusted with ``augment_gamma``.

    Returns:
        (result_images, result_labels) as lists, grouped by gamma (all images
        of the first gamma, then the second) rather than in input order.
    """
    # gammas that increase & decrease brightness
    gammas = [0.7, 2.]
    buckets = [[] for _ in gammas]

    for image, label in zip(images, labels):
        buckets[random.randint(0, len(gammas) - 1)].append((image, label))

    result_images = []
    result_labels = []
    for gamma, bucket in zip(gammas, buckets):
        result_images.extend(augment_gamma([image for image, _ in bucket], gamma))
        result_labels.extend(label for _, label in bucket)

    return (result_images, result_labels)


def augment_gamma(images, gamma=1.0):
    """Apply gamma correction to each image via a 256-entry uint8 lookup table."""
    # build a lookup table mapping the pixel values [0, 255] to
    # their adjusted gamma values
    invGamma = 1.0 / gamma
    table = np.array([((i / 255.0) ** invGamma) * 255 for i in np.arange(0, 256)]).astype("uint8")

    # apply gamma correction using the lookup table
    return [cv2.LUT(image, table) for image in images]


def clean_label_dataset(y_train, y_test, is_debug):
    """One-hot encode train and test labels with a shared number of classes.

    The width is derived from the largest label in *either* split, so train
    and test one-hot arrays always have the same number of columns even when
    a split (e.g. a small debug subset) lacks the highest class.

    Returns:
        (y_train, y_test) one-hot arrays, truncated to 128 rows if ``is_debug``.
    """
    num_classes = int(max(np.max(y_train), np.max(y_test))) + 1
    y_train = one_hot(y_train, num_classes)
    y_test = one_hot(y_test, num_classes)

    if is_debug:
        y_train = y_train[0:128]
        y_test = y_test[0:128]

    return (y_train, y_test)


def one_hot(labels, num_classes=None):
    """Return the one-hot encoding of integer ``labels``.

    A trailing axis of size 1 (e.g. shape (N, 1)) is flattened first.
    ``num_classes`` defaults to ``max(labels) + 1``.
    """
    labels = np.asarray(labels)
    if labels.shape[-1] == 1:
        labels = np.reshape(labels, (-1))
    if num_classes is None:
        num_classes = np.max(labels) + 1
    return np.eye(num_classes)[labels]


def add_zeros(labels):
    """Append a column of zeros along the last axis of a 2-D ``labels`` array."""
    shape = list(labels.shape)
    shape[-1] = 1
    return np.hstack((labels, np.zeros(shape)))


def category_examples(dataset):
    """Return one test-set example image per category of ``dataset``.

    Returns:
        A list of ``{'label', 'label_name', 'example'}`` dicts, one per class,
        using the first test image with that label.

    Raises:
        ValueError: for an unsupported dataset name.
    """
    if dataset == 'cifar10':
        (_, _), (x_test, y_test) = cifar10.load_data()
    else:
        raise ValueError("Unexpected dataset " + dataset + ".")

    categories = category_names(dataset)
    results = []
    for i in range(len(categories)):
        idx = find_index(y_test, lambda x: x == i)
        results.append({'label': i, 'label_name': categories[i], 'example': x_test[idx]})

    return results


def find_index(arr, predicate):
    """Return the index of the first item of ``arr`` satisfying ``predicate``.

    Raises:
        ValueError: if no item satisfies the predicate.
    """
    for i, item in enumerate(arr):
        if predicate(item):
            return i
    raise ValueError("could not satisfy predicate.")


def category_names(dataset):
    """Return the ordered class names for ``dataset`` (index == label)."""
    if dataset == 'cifar10':
        return [
            'airplane',
            'automobile',
            'bird',
            'cat',
            'deer',
            'dog',
            'frog',
            'horse',
            'ship',
            'truck'
        ]
    else:
        raise ValueError("Unexpected dataset " + dataset + ".")


# --------------------------------------------------------------------------------
# /bnn/loss_equations.py:
# --------------------------------------------------------------------------------
#!/bin/python

import numpy as np
from keras import backend as K
from tensorflow.contrib import distributions


def montecarlo_prediction(model, X_data, T):
    """Run ``T`` forward passes of ``model`` and summarise them.

    Args:
        model: the trained classifier (C classes) whose last layer applies
            softmax; ``model.predict(X_data)`` is called ``T`` times.
        X_data: a list of input data (size N).
        T: the number of monte carlo simulations to run.

    Returns:
        (prediction_means, prediction_variances): the mean class
        probabilities, shape (N, C), and the predictive entropy of each
        mean, shape (N,).
    """
    # shape: (T, N, C)
    predictions = np.array([model.predict(X_data) for _ in range(T)])

    # shape: (N, C)
    prediction_means = np.mean(predictions, axis=0)

    # shape: (N)
    prediction_variances = np.apply_along_axis(predictive_entropy, axis=1, arr=prediction_means)
    return (prediction_means, prediction_variances)


def predictive_entropy(prob):
    """Return the entropy ``-sum(p * log(p))`` of the class probabilities ``prob``.

    Zero probabilities contribute 0 (the limit of ``p * log(p)`` as p -> 0)
    instead of turning the whole result into NaN via ``0 * log(0)``.
    """
    prob = np.asarray(prob)
    nonzero = prob > 0
    return -np.sum(np.log(prob[nonzero]) * prob[nonzero])


# standard regression RMSE loss function
# N data points
# true - true values. Shape: (N)
# pred - predicted values. Shape: (N)
# returns - losses.
#   Shape: (N)
# NOTE(review): despite the header above, `loss` returns one scalar -- the mean
# *squared* error over all N points (no square root, no per-point losses).
def loss(true, pred):
    """Return the mean squared error between ``pred`` and ``true``.

    Args:
        true: true values, shape (N).
        pred: predicted values, shape (N).

    Returns:
        A scalar: ``mean((pred - true) ** 2)``.
    """
    return np.mean(np.square(pred - true))


# Bayesian regression loss function
def loss_with_uncertainty(true, pred):
    """Return the learned-variance regression loss as a scalar.

    Computes ``mean((mu - true) ** 2 * exp(-s) + s)`` with the predicted mean
    ``mu = pred[:, :, 0]`` and predicted log variance ``s = pred[:, :, 1]``.

    NOTE(review): the original header documented ``pred`` as shape (N, 2), but
    ``pred[:, :, k]`` requires at least a 3-D array, e.g. (N, M, 2), with
    ``true`` broadcastable against (N, M) -- confirm against callers.
    """
    return np.mean((pred[:, :, 0] - true)**2. * np.exp(-pred[:, :, 1]) + pred[:, :, 1])


def categorical_cross_entropy(true, pred):
    """Return the standard per-example categorical cross entropy.

    Args:
        true: one-hot (or soft) targets, shape (N, C).
        pred: predicted probabilities, shape (N, C).

    Returns:
        ``-sum(true * log(pred), axis=1)``: non-negative losses, shape (N,).
    """
    # The minus sign turns the (non-positive) log-likelihood into the loss.
    return -np.sum(true * np.log(pred), axis=1)


def bayesian_categorical_crossentropy(T, num_classes):
    """Build the Bayesian categorical cross entropy Keras loss.

    Args:
        T: number of monte carlo simulations evaluated per call.
        num_classes: number of classes C.

    Returns:
        ``bayesian_categorical_crossentropy_internal(true, pred_var)`` where
        ``true`` is (N, C) and ``pred_var`` holds the C predicted logits
        followed by one variance column, shape (N, C + 1); it returns losses
        of shape (N,). The inner function's name matters: bnn/model.py
        registers the loss under "bayesian_categorical_crossentropy_internal"
        when loading saved models.
    """
    def bayesian_categorical_crossentropy_internal(true, pred_var):
        # shape: (N, 1) -- the `num_classes:` slice keeps the column axis
        std = K.sqrt(pred_var[:, num_classes:])
        # shape: (N,) -- integer indexing drops the column axis
        variance = pred_var[:, num_classes]
        # exp(v) - 1 grows with the variance, discouraging it from blowing up
        variance_depressor = K.exp(variance) - K.ones_like(variance)
        # shape: (N, C)
        pred = pred_var[:, 0:num_classes]
        # shape: (N,)
        # NOTE(review): (output, target) argument order matches older Keras
        # backends; Keras >= 2.0.7 expects (target, output) -- confirm version.
        undistorted_loss = K.categorical_crossentropy(pred, true, from_logits=True)
        # shape: (T,) -- only its length matters; it drives T map_fn iterations
        iterable = K.variable(np.ones(T))
        dist = distributions.Normal(loc=K.zeros_like(std), scale=std)
        monte_carlo_results = K.map_fn(gaussian_categorical_crossentropy(true, pred, dist, undistorted_loss, num_classes), iterable, name='monte_carlo_results')

        variance_loss = K.mean(monte_carlo_results, axis=0) * undistorted_loss

        return variance_loss + undistorted_loss + variance_depressor

    return bayesian_categorical_crossentropy_internal


# for a single monte carlo simulation,
#
calculate categorical_crossentropy of 82 | # predicted logit values plus gaussian 83 | # noise vs true values. 84 | # true - true values. Shape: (N, C) 85 | # pred - predicted logit values. Shape: (N, C) 86 | # dist - normal distribution to sample from. Shape: (N, C) 87 | # undistorted_loss - the crossentropy loss without variance distortion. Shape: (N,) 88 | # num_classes - the number of classes. C 89 | # returns - total differences for all classes (N,) 90 | def gaussian_categorical_crossentropy(true, pred, dist, undistorted_loss, num_classes): 91 | def map_fn(i): 92 | std_samples = K.transpose(dist.sample(num_classes)) 93 | distorted_loss = K.categorical_crossentropy(pred + std_samples, true, from_logits=True) 94 | diff = undistorted_loss - distorted_loss 95 | return -K.elu(diff) 96 | return map_fn 97 | 98 | 99 | class MonteCarloTestModel: 100 | def __init__(self, C): 101 | self.C = C 102 | 103 | def predict(self, X_data): 104 | return np.array([self._predict(data) for data in X_data]) 105 | 106 | def _predict(self, data): 107 | return self.softmax([i for i in range(self.C)]) 108 | 109 | def softmax(self, predictions): 110 | vals = np.exp(predictions) 111 | return vals / np.sum(vals) 112 | 113 | 114 | 115 | 116 | -------------------------------------------------------------------------------- /bnn/model.py: -------------------------------------------------------------------------------- 1 | #!/bin/python 2 | 3 | from keras.applications.resnet50 import ResNet50 4 | import numpy as np 5 | # from keras.applications.resnet50 import preprocess_input, decode_predictions 6 | from keras.models import Model, load_model 7 | from keras.layers import Dense, Input, Flatten, Dropout, Activation, Lambda, RepeatVector 8 | from keras.layers.normalization import BatchNormalization 9 | from keras.layers.merge import concatenate 10 | from keras.engine.topology import Layer 11 | from keras.layers.wrappers import TimeDistributed 12 | from keras import backend as K 13 | from 
keras.utils.generic_utils import get_custom_objects 14 | from bnn.loss_equations import bayesian_categorical_crossentropy 15 | 16 | # Take a mean of the results of a TimeDistributed layer. 17 | # Applying TimeDistributedMean()(TimeDistributed(T)(x)) to an 18 | # input of shape (None, ...) returns putpur of same size. 19 | class TimeDistributedMean(Layer): 20 | def build(self, input_shape): 21 | super(TimeDistributedMean, self).build(input_shape) 22 | 23 | # input shape (None, T, ...) 24 | # output shape (None, ...) 25 | def compute_output_shape(self, input_shape): 26 | return (input_shape[0],) + input_shape[2:] 27 | 28 | def call(self, x): 29 | return K.mean(x, axis=1) 30 | 31 | 32 | # Apply the predictive entropy function for input with C classes. 33 | # Input of shape (None, C, ...) returns output with shape (None, ...) 34 | # Input should be predictive means for the C classes. 35 | # In the case of a single classification, output will be (None,). 36 | class PredictiveEntropy(Layer): 37 | def build(self, input_shape): 38 | super(PredictiveEntropy, self).build(input_shape) 39 | 40 | # input shape (None, C, ...) 41 | # output shape (None, ...) 
42 | def compute_output_shape(self, input_shape): 43 | return (input_shape[0],) 44 | 45 | # x - prediction probability for each class(C) 46 | def call(self, x): 47 | return -1 * K.sum(K.log(x) * x, axis=1) 48 | 49 | 50 | def load_full_model(encoder, checkpoint, input_shape): 51 | encoder_model = create_encoder_model(encoder, input_shape) 52 | bayesian_model = load_bayesian_model(checkpoint) 53 | outputs = bayesian_model(encoder_model.outputs) 54 | # hack to rename outputs 55 | logits_variance = Lambda(lambda x: x, name='logits_variance')(outputs[0]) 56 | softmax_output = Lambda(lambda x: x, name='softmax_output')(outputs[1]) 57 | 58 | return Model(inputs=encoder_model.inputs, outputs=[logits_variance, softmax_output]) 59 | 60 | 61 | def load_bayesian_model(checkpoint, monte_carlo_simulations=100, classes=10): 62 | get_custom_objects().update({"bayesian_categorical_crossentropy_internal": bayesian_categorical_crossentropy(monte_carlo_simulations, classes)}) 63 | return load_model(checkpoint) 64 | 65 | 66 | def load_epistemic_uncertainty_model(checkpoint, epistemic_monte_carlo_simulations): 67 | model = load_bayesian_model(checkpoint) 68 | inpt = Input(shape=(model.input_shape[1:])) 69 | x = RepeatVector(epistemic_monte_carlo_simulations)(inpt) 70 | # Keras TimeDistributed can only handle a single output from a model :( 71 | # and we technically only need the softmax outputs. 
72 | hacked_model = Model(inputs=model.inputs, outputs=model.outputs[1]) 73 | x = TimeDistributed(hacked_model, name='epistemic_monte_carlo')(x) 74 | # predictive probabilties for each class 75 | softmax_mean = TimeDistributedMean(name='epistemic_softmax_mean')(x) 76 | variance = PredictiveEntropy(name='epistemic_variance')(softmax_mean) 77 | epistemic_model = Model(inputs=inpt, outputs=[variance, softmax_mean]) 78 | 79 | return epistemic_model 80 | 81 | 82 | def load_full_epistemic_uncertainty_model(encoder, input_shape, checkpoint, epistemic_monte_carlo_simulations): 83 | encoder_model = create_encoder_model(encoder, input_shape) 84 | bayesian_model = load_epistemic_uncertainty_model(checkpoint, epistemic_monte_carlo_simulations) 85 | outputs = bayesian_model(encoder_model.outputs) 86 | 87 | return Model(inputs=encoder_model.inputs, outputs=outputs) 88 | 89 | 90 | def create_bayesian_model(encoder, input_shape, output_classes): 91 | encoder_model = resnet50(encoder, input_shape) 92 | input_tensor = Input(shape=encoder_model.output_shape[1:]) 93 | x = BatchNormalization(name='post_encoder')(input_tensor) 94 | x = Dropout(0.5)(x) 95 | x = Dense(500, activation='relu')(x) 96 | x = BatchNormalization()(x) 97 | x = Dropout(0.5)(x) 98 | x = Dense(100, activation='relu')(x) 99 | x = BatchNormalization()(x) 100 | x = Dropout(0.5)(x) 101 | 102 | logits = Dense(output_classes)(x) 103 | variance_pre = Dense(1)(x) 104 | variance = Activation('softplus', name='variance')(variance_pre) 105 | logits_variance = concatenate([logits, variance], name='logits_variance') 106 | softmax_output = Activation('softmax', name='softmax_output')(logits) 107 | 108 | model = Model(inputs=input_tensor, outputs=[logits_variance,softmax_output]) 109 | 110 | return model 111 | 112 | 113 | def create_encoder_model(encoder, input_shape): 114 | input_tensor = Input(shape=input_shape) 115 | 116 | if encoder == 'resnet50': 117 | base_model = ResNet50(include_top=False, input_tensor=input_tensor) 118 | 
else: 119 | raise ValueError('Unexpected encoder model ' + encoder + ".") 120 | 121 | # freeze encoder layers to prevent over fitting 122 | for layer in base_model.layers: 123 | layer.trainable = False 124 | 125 | output_tensor = Flatten()(base_model.output) 126 | 127 | model = Model(inputs=input_tensor, outputs=output_tensor) 128 | return model 129 | 130 | def encoder_min_input_size(encoder): 131 | if encoder == 'resnet50': 132 | return (197, 197) 133 | else: 134 | raise ValueError('Unexpected encoder model ' + encoder + ".") 135 | 136 | 137 | 138 | def extract_last_row(shape, dtype=None): 139 | extractor = np.zeros(shape) 140 | extractor[-1][-1] = 0 141 | return K.constant(extractor) 142 | 143 | def drop_last_row(shape, dtype=None): 144 | extractor = np.zeros(shape) 145 | for i in range(np.min(shape)): 146 | extractor[i][i] = 1 147 | return K.constant(extractor) 148 | 149 | def extract_variance(prev_layer): 150 | layer = Dense(1, name='extract_variance', kernel_initializer=extract_last_row) 151 | layer.trainable = False 152 | return layer(prev_layer) 153 | 154 | def extract_logits(prev_layer, output_classes): 155 | layer = Dense(output_classes, name='extract_logits', kernel_initializer=drop_last_row) 156 | layer.trainable = False 157 | return layer(prev_layer) 158 | -------------------------------------------------------------------------------- /bnn/predict.py: -------------------------------------------------------------------------------- 1 | #!/bin/python 2 | 3 | from keras.utils.generic_utils import get_custom_objects 4 | from keras import backend as K 5 | import numpy as np 6 | import math 7 | 8 | from bnn.model import load_bayesian_model, load_full_model, encoder_min_input_size, load_epistemic_uncertainty_model, load_full_epistemic_uncertainty_model 9 | from bnn.data import test_train_batch_data, test_train_data 10 | from bnn.util import BayesianConfig 11 | from bnn.loss_equations import bayesian_categorical_crossentropy 12 | 13 | 14 | def 
load_testable_model(encoder, config, monte_carlo_simulations, num_classes, min_image_size, full_model): 15 | if full_model: 16 | model = load_full_model(encoder, config.model_file(), min_image_size) 17 | print("Compiling full testable model.") 18 | model.compile( 19 | optimizer='adam', 20 | loss={'logits_variance': bayesian_categorical_crossentropy(monte_carlo_simulations, num_classes)}, 21 | metrics={'softmax_output': ['categorical_accuracy', 'top_k_categorical_accuracy']}) 22 | else: 23 | model = load_bayesian_model(config.model_file()) 24 | 25 | return model 26 | 27 | 28 | def load_testable_epistemic_uncertainty_model(full_model, min_image_size, config, epistemic_monte_carlo_simulations): 29 | if full_model: 30 | model = load_full_epistemic_uncertainty_model(config.encoder, min_image_size, config.model_file(), epistemic_monte_carlo_simulations) 31 | else: 32 | model = load_epistemic_uncertainty_model(config.model_file(), epistemic_monte_carlo_simulations) 33 | 34 | # the model won't be used for training 35 | model.compile('adam', 'categorical_crossentropy') 36 | return model 37 | 38 | 39 | def predict_epistemic_uncertainties(batch_size, verbose, epistemic_monte_carlo_simulations, debug, full_model, 40 | x_train, y_train, x_test, y_test, 41 | encoder, dataset, model_batch_size, model_epochs, model_monte_carlo_simulations): 42 | # set learning phase to 1 so that Dropout is on. 
In keras master you can set this 43 | # on the TimeDistributed layer 44 | K.set_learning_phase(1) 45 | min_image_size = list(encoder_min_input_size(encoder)) 46 | min_image_size.append(3) 47 | 48 | config = BayesianConfig(encoder, dataset, model_batch_size, model_epochs, model_monte_carlo_simulations) 49 | epistemic_model = load_testable_epistemic_uncertainty_model(full_model, min_image_size, config, epistemic_monte_carlo_simulations) 50 | 51 | # Shape (N) 52 | print("Predicting epistemic_uncertainties.") 53 | if hasattr(x_train, 'shape'): 54 | epistemic_uncertainties_train = epistemic_model.predict(x_train, batch_size=batch_size, verbose=verbose)[0] 55 | epistemic_uncertainties_test = epistemic_model.predict(x_test, batch_size=batch_size, verbose=verbose)[0] 56 | else: 57 | # generator 58 | epistemic_uncertainties_train = epistemic_model.predict_generator(x_train, int(math.ceil(len(y_train/batch_size))), verbose=verbose)[0] 59 | epistemic_uncertainties_test = epistemic_model.predict_generator(x_test, int(math.ceil(len(y_test/batch_size))), verbose=verbose)[0] 60 | 61 | return (epistemic_uncertainties_train, epistemic_uncertainties_test) 62 | 63 | 64 | def predict_softmax_aleatoric_uncertainties(batch_size, verbose, debug, full_model, 65 | x_train, y_train, x_test, y_test, 66 | encoder, dataset, model_batch_size, model_epochs, model_monte_carlo_simulations): 67 | 68 | num_classes = len(y_train[0]) 69 | min_image_size = encoder_min_input_size(encoder) 70 | min_image_size = list(min_image_size) 71 | min_image_size.append(3) 72 | config = BayesianConfig(encoder, dataset, model_batch_size, model_epochs, model_monte_carlo_simulations) 73 | model = load_testable_model(encoder, config, model_monte_carlo_simulations, num_classes, min_image_size, full_model) 74 | 75 | print("Predicting softmax and aleatoric_uncertainties.") 76 | if hasattr(x_train, 'shape'): 77 | predictions_train = model.predict(x_train, batch_size=batch_size, verbose=verbose) 78 | predictions_test = 
model.predict(x_test, batch_size=batch_size, verbose=verbose) 79 | else: 80 | # generator 81 | predictions_train = model.predict_generator(x_train, int(math.ceil(len(y_train/batch_size))), verbose=verbose) 82 | predictions_test = model.predict_generator(x_test, int(math.ceil(len(y_test/batch_size))), verbose=verbose) 83 | 84 | # Shape (N) 85 | aleatoric_uncertainties_train = np.reshape(predictions_train[0][:,num_classes:], (-1)) 86 | aleatoric_uncertainties_test = np.reshape(predictions_test[0][:,num_classes:], (-1)) 87 | 88 | logits_train = predictions_train[0][:,0:num_classes] 89 | logits_test = predictions_test[0][:,0:num_classes] 90 | 91 | # Shape (N, C) 92 | softmax_train = predictions_train[1] 93 | softmax_test = predictions_test[1] 94 | 95 | p_train = np.argmax(softmax_train, axis=1) 96 | p_test = np.argmax(softmax_test, axis=1) 97 | l_train = np.argmax(y_train, axis=1) 98 | l_test = np.argmax(y_test, axis=1) 99 | # Shape (N) 100 | prediction_comparision_train = np.equal(p_train,l_train).astype(int) 101 | prediction_comparision_test = np.equal(p_test,l_test).astype(int) 102 | 103 | train_results = [{ 104 | 'softmax_raw':softmax_train[i], 105 | 'softmax':p_train[i], 106 | 'logits_raw': logits_train[i], 107 | 'label': np.argmax(y_train[i]), 108 | 'label_expanded':y_train[i], 109 | 'aleatoric_uncertainty':aleatoric_uncertainties_train[i], 110 | 'is_correct':prediction_comparision_train[i] 111 | } for i in range(len(prediction_comparision_train))] 112 | 113 | test_results = [{ 114 | 'softmax_raw':softmax_test[i], 115 | 'softmax':p_test[i], 116 | 'logits_raw': logits_test[i], 117 | 'label': np.argmax(y_test[i]), 118 | 'label_expanded':y_test[i], 119 | 'aleatoric_uncertainty':aleatoric_uncertainties_test[i], 120 | 'is_correct':prediction_comparision_test[i] 121 | } for i in range(len(prediction_comparision_test))] 122 | 123 | return (train_results, test_results) 124 | 125 | def predict_on_data(batch_size, verbose, epistemic_monte_carlo_simulations, debug, 
full_model, 126 | x_train, y_train, x_test, y_test, 127 | encoder, dataset, model_batch_size, model_epochs, model_monte_carlo_simulations, include_epistemic_uncertainty=True): 128 | 129 | # epistemic_uncertainty takes a long time to predict 130 | if include_epistemic_uncertainty: 131 | (epistemic_uncertainties_train, epistemic_uncertainties_test) = predict_epistemic_uncertainties( 132 | batch_size, verbose, epistemic_monte_carlo_simulations, debug, full_model, 133 | x_train, y_train, x_test, y_test, 134 | encoder, dataset, model_batch_size, model_epochs, model_monte_carlo_simulations) 135 | 136 | (train_results, test_results) = predict_softmax_aleatoric_uncertainties(batch_size, verbose, debug, full_model, 137 | x_train, y_train, x_test, y_test, 138 | encoder, dataset, model_batch_size, model_epochs, model_monte_carlo_simulations) 139 | 140 | if include_epistemic_uncertainty: 141 | for i in range(len(epistemic_uncertainties_train)): 142 | train_results[i]['epistemic_uncertainty'] = epistemic_uncertainties_train[i] 143 | 144 | for i in range(len(epistemic_uncertainties_test)): 145 | test_results[i]['epistemic_uncertainty'] = epistemic_uncertainties_test[i] 146 | 147 | return (train_results, test_results) 148 | 149 | 150 | def predict(batch_size, verbose, epistemic_monte_carlo_simulations, debug, full_model, 151 | encoder, dataset, model_batch_size, model_epochs, model_monte_carlo_simulations): 152 | 153 | min_image_size = encoder_min_input_size(encoder) 154 | if full_model: 155 | ((x_train, y_train), (x_test, y_test)) = test_train_data(dataset, min_image_size[0:2], 156 | debug, augment_data=False, batch_size=batch_size) 157 | else: 158 | ((x_train, y_train), (x_test, y_test)) = test_train_batch_data(dataset, encoder, debug, augment_data=False) 159 | 160 | return predict_on_data(batch_size, verbose, epistemic_monte_carlo_simulations, debug, full_model, 161 | x_train, y_train, x_test, y_test, 162 | encoder, dataset, model_batch_size, model_epochs, 
model_monte_carlo_simulations) 163 | 164 | 165 | -------------------------------------------------------------------------------- /bnn/util.py: -------------------------------------------------------------------------------- 1 | #!/bin/python 2 | 3 | import boto3 4 | import os 5 | import os.path 6 | import zipfile 7 | from urllib.request import urlretrieve 8 | import pickle 9 | import threading 10 | import sys 11 | 12 | # Get full path to a resource underneath this project (bayesian-neural-network-blogpost) 13 | def full_path(name): 14 | base_dir_name = "bayesian-neural-network-blogpost" 15 | base_dir_list = os.getcwd().split("/") 16 | i = base_dir_list.index(base_dir_name) 17 | return "/".join(base_dir_list[0:i+1]) + "/" + name 18 | 19 | # Save and fetch data and saved models from S3. Useful for working between AWS and local machine. 20 | 21 | # TODO: change these to command line inputs. 22 | bucket_name = 'kd-carnd' 23 | key_name = 'bayesian-neural-network-blogpost/' 24 | region_name = 'us-east-2' 25 | 26 | def upload_s3(rel_path): 27 | bucket = boto3.resource('s3', region_name=region_name).Bucket(bucket_name) 28 | print("Uploading file", rel_path) 29 | bucket.upload_file(full_path(rel_path), key_name + rel_path, Callback=UploadProgressPercentage(rel_path)) 30 | print("Finished uploading file", rel_path) 31 | 32 | def download_s3(rel_path): 33 | bucket = boto3.resource('s3', region_name=region_name).Bucket(bucket_name) 34 | 35 | print("Downloading file", rel_path) 36 | try: 37 | bucket.download_file(key_name + rel_path, full_path(rel_path), Callback=DownloadProgressPercentage(rel_path)) 38 | print("Finished downloading file", rel_path) 39 | except ClientError: 40 | print("Unable to find file", rel_path, ".") 41 | 42 | class UploadProgressPercentage(object): 43 | def __init__(self, filename): 44 | self._filename = filename 45 | self._size = float(os.path.getsize(filename)) 46 | self._seen_so_far = 0 47 | self._lock = threading.Lock() 48 | def __call__(self, 
bytes_amount): 49 | # To simplify we'll assume this is hooked up 50 | # to a single filename. 51 | with self._lock: 52 | self._seen_so_far += bytes_amount 53 | percentage = (self._seen_so_far / self._size) * 100 54 | sys.stdout.write( 55 | "\r%s %s / %s (%.2f%%)" % ( 56 | self._filename, self._seen_so_far, self._size, 57 | percentage)) 58 | sys.stdout.flush() 59 | 60 | class DownloadProgressPercentage(object): 61 | def __init__(self, filename): 62 | self._filename = filename 63 | self._seen_so_far = 0 64 | self._lock = threading.Lock() 65 | def __call__(self, bytes_amount): 66 | # To simplify we'll assume this is hooked up 67 | # to a single filename. 68 | with self._lock: 69 | self._seen_so_far += bytes_amount 70 | sys.stdout.write( 71 | "\r%s --> %s bytes transferred" % ( 72 | self._filename, self._seen_so_far)) 73 | sys.stdout.flush() 74 | 75 | def download_file(url, file): 76 | """ 77 | Download file fromIf there's ketchup, it's a hotdog @FunnyAsianDude #nothotdog #NotHotdogchallenge pic.twitter.com/ZOQPqChADU
— David (@david__kha) May 18, 2017