├── .gitignore ├── README.md ├── tfp_resnet.py └── tfp_bgmm_MultivariateNormalTriL_MCMC.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | saved_models/ 2 | .ipynb_checkpoints 3 | *.pdf 4 | *.png 5 | *.h5 -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # tfp-tutorial 2 | My own implementations of some examples using TensorFlow Probability for tutorial purposes 3 | 4 | ## Tutorials 5 | 6 | * [Bayesian Neural Networks](tfp_bnn.ipynb) 7 | * [Bayesian Gaussian Mixture Model (MultivariateNormalDiag components) (Stochastic Variational Inference)](tfp_bgmm_MultivariateNormalDiag_SVI.ipynb) 8 | * [Bayesian Gaussian Mixture Model (MultivariateNormalTriL components) (Stochastic Variational Inference)](tfp_bgmm_MultivariateNormalTriL_SVI.ipynb) 9 | * [Bayesian Gaussian Mixture Model (MultivariateNormalTriL components) (Markov chain Monte Carlo)](tfp_bgmm_MultivariateNormalTriL_MCMC.ipynb) 10 | * [ResNet](tfp_resnet.py) 11 | 12 | ## References 13 | 14 | ### TensorFlow Probability 15 | * [Project](https://www.tensorflow.org/probability) 16 | * [GitHub](https://github.com/tensorflow/probability) 17 | 18 | ### ResNet 19 | * v1: [Deep Residual Learning for Image Recognition](https://arxiv.org/pdf/1512.03385.pdf) 20 | * v2: [Identity Mappings in Deep Residual Networks](https://arxiv.org/pdf/1603.05027.pdf) -------------------------------------------------------------------------------- /tfp_resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Trains a ResNet on the CIFAR10 dataset using Keras and Tensorflow Probability. 
3 | 4 | ResNet v1: 5 | [Deep Residual Learning for Image Recognition](https://arxiv.org/pdf/1512.03385.pdf) 6 | 7 | ResNet v2: 8 | [Identity Mappings in Deep Residual Networks](https://arxiv.org/pdf/1603.05027.pdf) 9 | 10 | Model parameter 11 | ----------------------------------------------------------------------------------- 12 | | | 200-epoch | Orig Paper| 200-epoch | Orig Paper| sec/epoch 13 | Model | n_res_block | ResNet v1 | ResNet v1 | ResNet v2 | ResNet v2 | GTX1080Ti 14 | | v1(v2) | %Accuracy | %Accuracy | %Accuracy | %Accuracy | v1 (v2) 15 | ----------------------------------------------------------------------------------- 16 | ResNet20 | 3 (2) | 92.16 | 91.25 | ----- | ----- | 35 (---) 17 | ResNet32 | 5(NA) | 92.46 | 92.49 | NA | NA | 50 ( NA) 18 | ResNet44 | 7(NA) | 92.50 | 92.83 | NA | NA | 70 ( NA) 19 | ResNet56 | 9 (6) | 92.71 | 93.03 | 93.01 | NA | 90 (100) 20 | ResNet110 | 18(12) | 92.65 | 93.39+-.16| 93.15 | 93.63 | 165(180) 21 | ResNet164 | 27(18) | ----- | 94.07 | ----- | 94.54 | ---(---) 22 | ResNet1001| NA(111) | ----- | 92.39 | ----- | 95.08+-.14| ---(---) 23 | ----------------------------------------------------------------------------------- 24 | """ 25 | from __future__ import print_function 26 | import tensorflow as tf 27 | import tensorflow_probability as tfp 28 | import os 29 | os.environ['KERAS_BACKEND'] = 'tensorflow' # set up tensorflow backend for keras 30 | import keras 31 | import numpy as np 32 | 33 | 34 | def lr_schedule(epoch): 35 | """ 36 | Learning Rate Schedule 37 | 38 | Learning rate is scheduled to be reduced after 80, 120, 160, 180 epochs. 39 | Called automatically every epoch as part of callbacks during training. 
40 | 41 | # Arguments 42 | epoch (int): The number of epochs 43 | 44 | # Returns 45 | lr (float32): learning rate 46 | """ 47 | lr = 1e-3 48 | if epoch > 180: 49 | lr *= 0.5e-3 50 | elif epoch > 160: 51 | lr *= 1e-3 52 | elif epoch > 120: 53 | lr *= 1e-2 54 | elif epoch > 80: 55 | lr *= 1e-1 56 | print('Learning rate: ', lr) 57 | return lr 58 | 59 | 60 | def get_kernel_posterior_fn(kernel_posterior_scale_mean=-9.0, 61 | kernel_posterior_scale_stddev=0.1, 62 | kernel_posterior_scale_constraint=0.2): 63 | """ 64 | Get the kernel posterior distribution 65 | 66 | # Arguments 67 | kernel_posterior_scale_mean (float): kernel posterior's scale mean. 68 | kernel_posterior_scale_stddev (float): the initial kernel posterior's scale stddev. 69 | ``` 70 | q(W|x) ~ N(mu, var), 71 | log_var ~ N(kernel_posterior_scale_mean, kernel_posterior_scale_stddev) 72 | ```` 73 | kernel_posterior_scale_constraint (float): the log value to constrain the log variance throughout training. 74 | i.e. log_var <= log(kernel_posterior_scale_constraint). 
75 | 76 | # Returns 77 | kernel_posterior_fn: kernel posterior distribution 78 | """ 79 | 80 | def _untransformed_scale_constraint(t): 81 | return tf.clip_by_value(t, -1000, tf.math.log(kernel_posterior_scale_constraint)) 82 | 83 | kernel_posterior_fn = tfp.layers.default_mean_field_normal_fn( 84 | untransformed_scale_initializer=tf.random_normal_initializer( 85 | mean=kernel_posterior_scale_mean, 86 | stddev=kernel_posterior_scale_stddev), 87 | untransformed_scale_constraint=_untransformed_scale_constraint) 88 | return kernel_posterior_fn 89 | 90 | 91 | def get_kernel_divergence_fn(train_size, w=1.0): 92 | """ 93 | Get the kernel Kullback-Leibler divergence function 94 | 95 | # Arguments 96 | train_size (int): size of the training dataset for normalization 97 | w (float): weight to the function 98 | 99 | # Returns 100 | kernel_divergence_fn: kernel Kullback-Leibler divergence function 101 | """ 102 | def kernel_divergence_fn(q, p, _): # need the third ignorable argument 103 | kernel_divergence = tfp.distributions.kl_divergence(q, p) / tf.cast(train_size, tf.float32) 104 | return w * kernel_divergence 105 | return kernel_divergence_fn 106 | 107 | 108 | def get_neg_log_likelihood_fn(bayesian=False): 109 | """ 110 | Get the negative log-likelihood function 111 | # Arguments 112 | bayesian(bool): Bayesian neural network (True) or point-estimate neural network (False) 113 | 114 | # Returns 115 | a negative log-likelihood function 116 | """ 117 | if bayesian: 118 | def neg_log_likelihood_bayesian(y_true, y_pred): 119 | labels_distribution = tfp.distributions.Categorical(logits=y_pred) 120 | log_likelihood = labels_distribution.log_prob(tf.argmax(input=y_true, axis=1)) 121 | loss = -tf.reduce_mean(input_tensor=log_likelihood) 122 | return loss 123 | return neg_log_likelihood_bayesian 124 | else: 125 | def neg_log_likelihood(y_true, y_pred): 126 | y_pred_softmax = keras.layers.Activation('softmax')(y_pred) # logits to softmax 127 | loss = 
keras.losses.categorical_crossentropy(y_true, y_pred_softmax) 128 | return loss 129 | return neg_log_likelihood 130 | 131 | 132 | def get_categorical_accuracy_fn(y_true, y_pred): 133 | y_pred_softmax = keras.layers.Activation('softmax')(y_pred) # logits to softmax 134 | acc = keras.metrics.categorical_accuracy(y_true, y_pred_softmax) 135 | return acc 136 | 137 | 138 | class KLLossScheduler(tf.keras.callbacks.Callback): 139 | def __init__(self, update_per_batch=False, n_silent_epoch=5, n_annealing_epoch=50, verbose=0): 140 | self.update_per_batch = update_per_batch 141 | self.n_silent_epoch = n_silent_epoch 142 | self.n_annealing_epoch = n_annealing_epoch 143 | self.verbose = verbose 144 | super(KLLossScheduler, self).__init__() 145 | def on_batch_begin(self, batch, logs=None): 146 | if self.update_per_batch: 147 | n_batch_per_epoch = int(np.ceil(self.params['samples'] / self.params['batch_size'])) 148 | idx_total_batch = (self.epoch - self.n_silent_epoch) * n_batch_per_epoch + batch + 1 149 | kl_weight = (idx_total_batch / n_batch_per_epoch) / self.n_annealing_epoch 150 | kl_weight = np.maximum(0.0, np.minimum(kl_weight, 1.0)) 151 | self.kl_weight = kl_weight 152 | if self.verbose > 0: 153 | print('\nBatch: {}, KL Divergence Loss Weight = {:.6f}'.format(batch+1, kl_weight)) 154 | for l in self.model.layers: 155 | for id_w, w in enumerate(l.weights): 156 | if 'kl_loss_weight' in w.name: 157 | l_weights = l.get_weights() 158 | l.set_weights([*l_weights[:id_w], kl_weight, *l_weights[id_w+1:]]) 159 | def on_epoch_begin(self, epoch, logs=None): 160 | self.epoch = epoch 161 | if not self.update_per_batch: 162 | kl_weight = (epoch - self.n_silent_epoch + 1) / self.n_annealing_epoch 163 | kl_weight = np.maximum(0.0, np.minimum(kl_weight, 1.0)) 164 | self.kl_weight = kl_weight 165 | if self.verbose > 0: 166 | print('\nEpoch: {}, KL Divergence Loss Weight = {:.6f}'.format(epoch+1, kl_weight)) 167 | for l in self.model.layers: 168 | for id_w, w in enumerate(l.weights): 169 | 
if 'kl_loss_weight' in w.name: 170 | l_weights = l.get_weights() 171 | l.set_weights([*l_weights[:id_w], kl_weight, *l_weights[id_w+1:]]) 172 | def on_epoch_end(self, epoch, logs={}): 173 | print('KL Divergence Weight = {:.6f}, KL Divergence Loss = {:.4f}'.format(self.kl_weight, 174 | sum(self.model.losses).eval(session=tf.keras.backend.get_session()))) 175 | 176 | 177 | def resnet_layer(inputs, train_size, 178 | n_filter=16, 179 | kernel_size=3, 180 | strides=1, 181 | activation='relu', 182 | batch_normalization=True, 183 | conv_first=True, 184 | bayesian=False): 185 | """2D Convolution-Batch Normalization-Activation stack builder 186 | 187 | # Arguments 188 | inputs (tensor): input tensor from input image or previous layer 189 | n_filter (int): Conv2D number of filters 190 | kernel_size (int): Conv2D square kernel dimensions 191 | strides (int): Conv2D square stride dimensions 192 | activation (string): activation name 193 | batch_normalization (bool): whether to include batch normalization 194 | conv_first (bool): conv-bn-activation (True) or bn-activation-conv (False) 195 | bayesian (bool): implement Bayesian neural network (True) or point-estimate neural network (False) 196 | 197 | # Returns 198 | x (tensor): tensor as input to the next layer 199 | """ 200 | if bayesian: 201 | # scale the KL divergence function to avoid the loss function being over-regularized 202 | conv = tfp.layers.Convolution2DFlipout(n_filter, 203 | kernel_size=kernel_size, 204 | strides=strides, 205 | padding='same', 206 | kernel_posterior_fn=get_kernel_posterior_fn(), 207 | kernel_divergence_fn=None) 208 | w = conv.add_weight(name=conv.name+'/kl_loss_weight', shape=(), initializer=tf.initializers.constant(0.0), trainable=False) 209 | conv.kernel_divergence_fn = get_kernel_divergence_fn(train_size, w) 210 | else: 211 | conv = keras.layers.Conv2D(n_filter, 212 | kernel_size=kernel_size, 213 | strides=strides, 214 | padding='same', 215 | kernel_initializer='he_normal', 216 | 
kernel_regularizer=keras.regularizers.l2(1e-4)) 217 | x = inputs 218 | if conv_first: 219 | x = conv(x) 220 | if batch_normalization: 221 | x = keras.layers.BatchNormalization()(x) 222 | if activation is not None: 223 | x = keras.layers.Activation(activation)(x) 224 | else: 225 | if batch_normalization: 226 | x = keras.layers.BatchNormalization()(x) 227 | if activation is not None: 228 | x = keras.layers.Activation(activation)(x) 229 | x = conv(x) 230 | return x 231 | 232 | 233 | def resnet_v1(input_shape, n_res_block, train_size, n_class=10, bayesian=False): 234 | """ResNet Version 1 Model builder [a] 235 | 236 | Stacks of 2 x (3 x 3) Conv2D-BN-ReLU 237 | Last ReLU is after the shortcut connection. 238 | At the beginning of each stage, the feature map size is halved (downsampled) 239 | by a convolutional layer with strides=2, while the number of filters is 240 | doubled. Within each stage, the layers have the same number filters and the 241 | same number of filters. 242 | Features maps sizes: 243 | stage 0: 32x32, 16 244 | stage 1: 16x16, 32 245 | stage 2: 8x8, 64 246 | The Number of parameters is approx the same as Table 6 of [a]: 247 | ResNet20 0.27M 248 | ResNet32 0.46M 249 | ResNet44 0.66M 250 | ResNet56 0.85M 251 | ResNet110 1.7M 252 | 253 | # Arguments 254 | input_shape (tensor): shape of input image tensor 255 | n_res_block (int): number of residual blocks 256 | n_class (int): number of classes (CIFAR10 has 10) 257 | bayesian (bool): implement Bayesian neural network (True) or point-estimate neural network (False) 258 | 259 | # Returns 260 | model (Model): Keras model instance 261 | """ 262 | n_filter = 16 263 | 264 | inputs = keras.layers.Input(shape=input_shape) 265 | x = resnet_layer(inputs=inputs, train_size=train_size, bayesian=bayesian) 266 | # Instantiate the stack of residual units 267 | for stack in range(3): 268 | for res_block in range(n_res_block): 269 | strides = 1 270 | if stack > 0 and res_block == 0: # first layer but not first stack 271 | 
strides = 2 # downsample 272 | y = resnet_layer(inputs=x, train_size=train_size, 273 | n_filter=n_filter, 274 | strides=strides, 275 | bayesian=bayesian) 276 | y = resnet_layer(inputs=y, train_size=train_size, 277 | n_filter=n_filter, 278 | activation=None, 279 | bayesian=bayesian) 280 | if stack > 0 and res_block == 0: # first layer but not first stack 281 | # linear projection residual shortcut connection to match 282 | # changed dims 283 | x = resnet_layer(inputs=x, train_size=train_size, 284 | n_filter=n_filter, 285 | kernel_size=1, 286 | strides=strides, 287 | activation=None, 288 | batch_normalization=False, 289 | bayesian=bayesian) 290 | x = keras.layers.add([x, y]) 291 | x = keras.layers.Activation('relu')(x) 292 | n_filter *= 2 293 | 294 | # Add classifier on top. 295 | # v1 does not use BN after last shortcut connection-ReLU 296 | x = keras.layers.AveragePooling2D(pool_size=8)(x) 297 | y = keras.layers.Flatten()(x) 298 | if bayesian: 299 | # scale the KL divergence function to avoid the loss function being over-regularized 300 | dense = tfp.layers.DenseFlipout(n_class, 301 | activation=None, 302 | kernel_posterior_fn=get_kernel_posterior_fn(), 303 | kernel_divergence_fn=None) 304 | w = dense.add_weight(name=dense.name+'/kl_loss_weight', shape=(), initializer=tf.initializers.constant(0.0), trainable=False) 305 | dense.kernel_divergence_fn = get_kernel_divergence_fn(train_size, w) 306 | logits = dense(y) 307 | else: 308 | logits = keras.layers.Dense(n_class, 309 | activation=None, 310 | kernel_initializer='he_normal')(y) 311 | # Instantiate model. 312 | model = keras.Model(inputs=inputs, outputs=logits) 313 | 314 | return model 315 | 316 | 317 | def resnet_v2(input_shape, n_res_block, train_size, n_class=10, bayesian=False): 318 | """ResNet Version 2 Model builder [b] 319 | 320 | Stacks of (1 x 1)-(3 x 3)-(1 x 1) BN-ReLU-Conv2D or also known as 321 | bottleneck layer 322 | First shortcut connection per layer is 1 x 1 Conv2D. 
323 | Second and onwards shortcut connection is identity. 324 | At the beginning of each stage, the feature map size is halved (downsampled) 325 | by a convolutional layer with strides=2, while the number of filter maps is 326 | doubled. Within each stage, the layers have the same number filters and the 327 | same filter map sizes. 328 | Features maps sizes: 329 | conv1 : 32x32, 16 330 | stage 0: 32x32, 64 331 | stage 1: 16x16, 128 332 | stage 2: 8x8, 256 333 | 334 | # Arguments 335 | input_shape (tensor): shape of input image tensor 336 | n_res_block (int): number of residual blocks 337 | n_class (int): number of classes (CIFAR10 has 10) 338 | bayesian (bool): implement Bayesian neural network (True) or point-estimate neural network (False) 339 | 340 | # Returns 341 | model (Model): Keras model instance 342 | """ 343 | n_filter_in = 16 344 | 345 | inputs = keras.layers.Input(shape=input_shape) 346 | # v2 performs Conv2D with BN-ReLU on input before splitting into 2 paths 347 | x = resnet_layer(inputs=inputs, train_size=train_size, 348 | n_filter=n_filter_in, 349 | conv_first=True, 350 | bayesian=bayesian) 351 | 352 | # Instantiate the stack of residual units 353 | for stage in range(3): 354 | for res_block in range(n_res_block): 355 | activation = 'relu' 356 | batch_normalization = True 357 | strides = 1 358 | if stage == 0: 359 | n_filter_out = n_filter_in * 4 360 | if res_block == 0: # first layer and first stage 361 | activation = None 362 | batch_normalization = False 363 | else: 364 | n_filter_out = n_filter_in * 2 365 | if res_block == 0: # first layer but not first stage 366 | strides = 2 # downsample 367 | 368 | # bottleneck residual unit 369 | y = resnet_layer(inputs=x, train_size=train_size, 370 | n_filter=n_filter_in, 371 | kernel_size=1, 372 | strides=strides, 373 | activation=activation, 374 | batch_normalization=batch_normalization, 375 | conv_first=False, 376 | bayesian=bayesian) 377 | y = resnet_layer(inputs=y, train_size=train_size, 378 | 
n_filter=n_filter_in, 379 | conv_first=False, 380 | bayesian=bayesian) 381 | y = resnet_layer(inputs=y, train_size=train_size, 382 | n_filter=n_filter_out, 383 | kernel_size=1, 384 | conv_first=False, 385 | bayesian=bayesian) 386 | if res_block == 0: 387 | # linear projection residual shortcut connection to match 388 | # changed dims 389 | x = resnet_layer(inputs=x, train_size=train_size, 390 | n_filter=n_filter_out, 391 | kernel_size=1, 392 | strides=strides, 393 | activation=None, 394 | batch_normalization=False, 395 | bayesian=bayesian) 396 | x = keras.layers.add([x, y]) 397 | 398 | n_filter_in = n_filter_out 399 | 400 | # Add classifier on top. 401 | # v2 has BN-ReLU before Pooling 402 | x = keras.layers.BatchNormalization()(x) 403 | x = keras.layers.Activation('relu')(x) 404 | x = keras.layers.AveragePooling2D(pool_size=8)(x) 405 | y = keras.layers.Flatten()(x) 406 | if bayesian: 407 | # scale the KL divergence function to avoid the loss function being over-regularized 408 | dense = tfp.layers.DenseFlipout(n_class, 409 | activation=None, 410 | kernel_posterior_fn=get_kernel_posterior_fn(), 411 | kernel_divergence_fn=None) 412 | w = dense.add_weight(name=dense.name+'/kl_loss_weight', shape=(), initializer=tf.initializers.constant(0.0), trainable=False) 413 | dense.kernel_divergence_fn = get_kernel_divergence_fn(train_size, w) 414 | logits = dense(y) 415 | else: 416 | logits = keras.layers.Dense(n_class, 417 | activation=None, 418 | kernel_initializer='he_normal')(y) 419 | # Instantiate model. 
420 | model = keras.Model(inputs=inputs, outputs=logits) 421 | 422 | return model 423 | 424 | 425 | if __name__ == '__main__': 426 | 427 | # Bayesian mode setting 428 | bayesian = True 429 | if bayesian: 430 | keras = tf.keras 431 | 432 | n_mc_run = 20 if bayesian else 1 433 | 434 | # Training parameters 435 | batch_size = 128 # orig paper trained all networks with batch_size=128 436 | epochs = 200 437 | data_augmentation = False 438 | n_class = 10 439 | 440 | # Subtracting pixel mean improves accuracy 441 | subtract_pixel_mean = True 442 | 443 | n_res_block = 3 444 | 445 | # Model version 446 | # Orig paper: version = 1 (ResNet v1), Improved ResNet: version = 2 (ResNet v2) 447 | version = 2 448 | assert version in [1, 2], 'ResNet version must be 1 or 2.' 449 | 450 | # Computed depth from supplied model parameter n_res_block 451 | depth = n_res_block * 6 + 2 if version == 1 else n_res_block * 9 + 2 452 | 453 | # Model name, depth and version 454 | model_type = 'ResNet%dv%d' % (depth, version) 455 | if bayesian: 456 | model_type += '_Bayesian' 457 | 458 | # Load the CIFAR10 data. 459 | (x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data() 460 | 461 | # Input image dimensions. 462 | input_shape = x_train.shape[1:] 463 | 464 | # Normalize data. 465 | x_train = x_train.astype('float32') / 255 466 | x_test = x_test.astype('float32') / 255 467 | 468 | # If subtract pixel mean is enabled 469 | if subtract_pixel_mean: 470 | x_train_mean = np.mean(x_train, axis=0) 471 | x_train -= x_train_mean 472 | x_test -= x_train_mean 473 | 474 | print('x_train shape:', x_train.shape) 475 | print(x_train.shape[0], 'train samples') 476 | print(x_test.shape[0], 'test samples') 477 | print('y_train shape:', y_train.shape) 478 | 479 | # Convert class vectors to binary class matrices. 
480 | y_train = keras.utils.to_categorical(y_train, n_class) 481 | y_test = keras.utils.to_categorical(y_test, n_class) 482 | 483 | if version == 1: 484 | model = resnet_v1(input_shape=input_shape, 485 | n_res_block=n_res_block, 486 | train_size=len(x_train), 487 | bayesian=bayesian) 488 | else: 489 | model = resnet_v2(input_shape=input_shape, 490 | n_res_block=n_res_block, 491 | train_size=len(x_train), 492 | bayesian=bayesian) 493 | 494 | model.compile(loss=get_neg_log_likelihood_fn(bayesian=bayesian), 495 | optimizer=keras.optimizers.Adam(lr=lr_schedule(0)), 496 | metrics=[get_categorical_accuracy_fn]) 497 | model.summary() 498 | print(model_type) 499 | 500 | # Prepare model model saving directory. 501 | save_dir = os.path.join(os.getcwd(), 'saved_models') 502 | model_vis_name = 'cifar10_%s_model.png' % model_type 503 | model_name = 'cifar10_%s_model.h5' % model_type 504 | if not os.path.isdir(save_dir): 505 | os.makedirs(save_dir) 506 | model_vis_filepath = os.path.join(save_dir, model_vis_name) 507 | model_filepath = os.path.join(save_dir, model_name) 508 | 509 | # Plot the model 510 | keras.utils.plot_model(model, to_file=model_vis_filepath, show_shapes=True) 511 | 512 | # Prepare callbacks for model saving and for learning rate adjustment. 513 | checkpoint = keras.callbacks.ModelCheckpoint(filepath=model_filepath, 514 | monitor='val_get_categorical_accuracy_fn', 515 | verbose=1, 516 | save_best_only=True, 517 | save_weights_only=True) 518 | 519 | lr_scheduler = keras.callbacks.LearningRateScheduler(lr_schedule) 520 | 521 | lr_reducer = keras.callbacks.ReduceLROnPlateau(factor=np.sqrt(0.1), 522 | cooldown=0, 523 | patience=5, 524 | min_lr=0.5e-6) 525 | 526 | callbacks = [checkpoint, lr_reducer, lr_scheduler] 527 | if bayesian: 528 | kl_loss_scheduler = KLLossScheduler(update_per_batch=True) 529 | callbacks += [kl_loss_scheduler] 530 | 531 | # Run training, with or without data augmentation. 
532 | if not os.path.isfile(model_filepath): 533 | if not data_augmentation: 534 | print('Not using data augmentation.') 535 | model.fit(x_train, y_train, 536 | batch_size=batch_size, 537 | epochs=epochs, 538 | validation_data=(x_test, y_test), 539 | shuffle=True, 540 | callbacks=callbacks) 541 | else: 542 | print('Using real-time data augmentation.') 543 | # This will do preprocessing and realtime data augmentation: 544 | datagen = keras.preprocessing.image.ImageDataGenerator( 545 | # set input mean to 0 over the dataset 546 | featurewise_center=False, 547 | # set each sample mean to 0 548 | samplewise_center=False, 549 | # divide inputs by std of dataset 550 | featurewise_std_normalization=False, 551 | # divide each input by its std 552 | samplewise_std_normalization=False, 553 | # apply ZCA whitening 554 | zca_whitening=False, 555 | # epsilon for ZCA whitening 556 | zca_epsilon=1e-06, 557 | # randomly rotate images in the range (deg 0 to 180) 558 | rotation_range=0, 559 | # randomly shift images horizontally 560 | width_shift_range=0.1, 561 | # randomly shift images vertically 562 | height_shift_range=0.1, 563 | # set range for random shear 564 | shear_range=0., 565 | # set range for random zoom 566 | zoom_range=0., 567 | # set range for random channel shifts 568 | channel_shift_range=0., 569 | # set mode for filling points outside the input boundaries 570 | fill_mode='nearest', 571 | # value used for fill_mode = "constant" 572 | cval=0., 573 | # randomly flip images 574 | horizontal_flip=True, 575 | # randomly flip images 576 | vertical_flip=False, 577 | # set rescaling factor (applied before any other transformation) 578 | rescale=None, 579 | # set function that will be applied on each input 580 | preprocessing_function=None, 581 | # image data format, either "channels_first" or "channels_last" 582 | data_format=None, 583 | # fraction of images reserved for validation (strictly between 0 and 1) 584 | validation_split=0.0) 585 | 586 | # Compute quantities 
required for featurewise normalization 587 | # (std, mean, and principal components if ZCA whitening is applied). 588 | datagen.fit(x_train) 589 | 590 | # Fit the model on the batches generated by datagen.flow(). 591 | model.fit_generator(datagen.flow(x_train, y_train, batch_size=batch_size), 592 | steps_per_epoch=len(x_train) / batch_size, 593 | validation_data=(x_test, y_test), 594 | epochs=epochs, verbose=1, workers=4, 595 | callbacks=callbacks) 596 | model.load_weights(model_filepath) # load the optimal model with the lowest validation loss 597 | 598 | # apply the model on test data 599 | y_pred_logits = [model.predict(x_test) for _ in range(n_mc_run)] 600 | y_pred_logits = np.concatenate([y[np.newaxis, :, :] for y in y_pred_logits], axis=0) 601 | y_pred_logits_mean = np.mean(y_pred_logits, axis=0) 602 | y_pred_logits_std = np.std(y_pred_logits, axis=0) 603 | 604 | y_pred_softmax = keras.layers.Activation('softmax')(keras.backend.variable(y_pred_logits_mean)).eval(session=keras.backend.get_session()) 605 | print('Test accuracy: ', sum(np.equal(np.argmax(y_test, axis=-1), np.argmax(y_pred_softmax, axis=-1))) / len(y_test)) -------------------------------------------------------------------------------- /tfp_bgmm_MultivariateNormalTriL_MCMC.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Bayesian Gaussian Mixture Modeling using TensorFlow Probability" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [ 15 | { 16 | "name": "stdout", 17 | "output_type": "stream", 18 | "text": [ 19 | "TensorFlow version: 2.1.0-dev20191015\n", 20 | "TensorFlow Probability version: 0.9.0-dev20191016\n" 21 | ] 22 | } 23 | ], 24 | "source": [ 25 | "import tensorflow as tf\n", 26 | "import tensorflow_probability as tfp\n", 27 | "tfb = tfp.bijectors\n", 28 | "tfd = tfp.distributions\n", 29 | 
"tf.compat.v1.disable_eager_execution() # Eager execution is disabled as it significantly slows down MCMC\n", 30 | "\n", 31 | "import sys\n", 32 | "import time\n", 33 | "import functools\n", 34 | "import numpy as np\n", 35 | "import pandas as pd\n", 36 | "import seaborn as sns\n", 37 | "import matplotlib.pyplot as plt\n", 38 | "from mpl_toolkits.mplot3d import Axes3D\n", 39 | "from operator import mul\n", 40 | "from sklearn.cluster import KMeans\n", 41 | "\n", 42 | "%matplotlib inline\n", 43 | "\n", 44 | "print('TensorFlow version:', tf.__version__)\n", 45 | "print('TensorFlow Probability version:', tfp.__version__)" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 2, 51 | "metadata": {}, 52 | "outputs": [ 53 | { 54 | "name": "stdout", 55 | "output_type": "stream", 56 | "text": [ 57 | "Device mapping:\n", 58 | "/job:localhost/replica:0/task:0/device:GPU:0 -> device: 0, name: Quadro P1000, pci bus id: 0000:01:00.0, compute capability: 6.1\n", 59 | "\n" 60 | ] 61 | } 62 | ], 63 | "source": [ 64 | "def session_options(enable_gpu_ram_resizing=True):\n", 65 | " \"\"\"Convenience function which sets common `tf.Session` options.\"\"\"\n", 66 | " config = tf.compat.v1.ConfigProto()\n", 67 | " config.log_device_placement = True\n", 68 | " if enable_gpu_ram_resizing:\n", 69 | " # `allow_growth=True` makes it possible to connect multiple colabs to your\n", 70 | " # GPU. 
Otherwise the colab malloc's all GPU ram.\n", 71 | " config.gpu_options.allow_growth = True\n", 72 | " return config\n", 73 | "\n", 74 | "def reset_sess(config=None):\n", 75 | " \"\"\"Convenience function to create the TF graph and session, or reset them.\"\"\"\n", 76 | " if config is None:\n", 77 | " config = session_options()\n", 78 | " tf.compat.v1.reset_default_graph()\n", 79 | " global sess\n", 80 | " try:\n", 81 | " sess.close()\n", 82 | " except:\n", 83 | " pass\n", 84 | " sess = tf.compat.v1.InteractiveSession(config=config)\n", 85 | "\n", 86 | "reset_sess()" 87 | ] 88 | }, 89 | { 90 | "cell_type": "markdown", 91 | "metadata": {}, 92 | "source": [ 93 | "## Generate some data" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 3, 99 | "metadata": {}, 100 | "outputs": [ 101 | { 102 | "data": { 103 | "image/png": "\n", 104 | "text/plain": [ 105 | "
" 106 | ] 107 | }, 108 | "metadata": { 109 | "needs_background": "light" 110 | }, 111 | "output_type": "display_data" 112 | } 113 | ], 114 | "source": [ 115 | "eps = 1e-6\n", 116 | "\n", 117 | "n_samples_per_component = 1000\n", 118 | "n_dims = 3\n", 119 | "X = []\n", 120 | "X.append(np.random.multivariate_normal([5, 5, 0],\n", 121 | " [[0.1, 0, 0], [0, 1, 0.6], [0, 0.6, 2]],\n", 122 | " n_samples_per_component).astype('float64'))\n", 123 | "X.append(np.random.multivariate_normal([-5, -5, 0],\n", 124 | " [[0.1, 0, 0], [0, 1, 0.6], [0, 0.6, 2]],\n", 125 | " n_samples_per_component).astype('float64'))\n", 126 | "X.append(np.random.multivariate_normal([0, 5, 5],\n", 127 | " [[1, 0.6, 0], [0.6, 1, 0], [0, 0, 0.2]],\n", 128 | " n_samples_per_component).astype('float64'))\n", 129 | "X.append(np.random.multivariate_normal([0, -5, -5],\n", 130 | " [[1, 0.6, 0], [0.6, 1, 0], [0, 0, 0.2]],\n", 131 | " n_samples_per_component).astype('float64'))\n", 132 | "n_components = len(X)\n", 133 | "n_samples = n_samples_per_component * n_components\n", 134 | "X = np.concatenate(X)\n", 135 | "\n", 136 | "# Plot the data\n", 137 | "fig = plt.figure(figsize=(10, 10))\n", 138 | "ax = plt.axes(projection='3d')\n", 139 | "ax.scatter(X[:, 0], X[:, 1], X[:, 2], marker='.')\n", 140 | "ax.set_xlabel('$x_0$')\n", 141 | "ax.set_ylabel('$x_1$')\n", 142 | "ax.set_zlabel('$x_2$')\n", 143 | "plt.show()" 144 | ] 145 | }, 146 | { 147 | "cell_type": "markdown", 148 | "metadata": {}, 149 | "source": [ 150 | "## Build a Gaussian Mixture Model\n", 151 | "$$\n", 152 | "\\begin{aligned}\n", 153 | "\\mathbf{x}_n &\\sim \\sum\\limits_k \\pi_k \\mathcal{N}\\left(\\pmb{\\mu_x}_k, \\pmb{\\Sigma_x}_k\\right)\\\\\n", 154 | "\\pi_k &\\sim \\mathcal{D}\\left(K, \\mathbf{c}\\right)\\\\\n", 155 | "\\forall k \\quad&\\left\\{\n", 156 | " \\begin{aligned}\n", 157 | " \\pmb{\\Sigma_x}_k^{-1} &\\sim \\mathcal{W}\\left(\\mathbf{W}, \\nu\\right)\\\\\n", 158 | " \\pmb{\\mu_x}_k &\\sim \\mathcal{N}\\left(\\pmb{\\mu_0}, 
\\pmb{\\Sigma_0}\\right)\n", 159 | " \\end{aligned}\n", 160 | "\\right.\n", 161 | "\\end{aligned}\n", 162 | "$$" 163 | ] 164 | }, 165 | { 166 | "cell_type": "markdown", 167 | "metadata": {}, 168 | "source": [ 169 | "### prior distributions" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": 4, 175 | "metadata": {}, 176 | "outputs": [ 177 | { 178 | "name": "stdout", 179 | "output_type": "stream", 180 | "text": [ 181 | "WARNING:tensorflow:From C:\\ProgramData\\Anaconda3\\envs\\nightly\\lib\\site-packages\\tensorflow_core\\python\\ops\\linalg\\linear_operator_lower_triangular.py:158: calling LinearOperator.__init__ (from tensorflow.python.ops.linalg.linear_operator) with graph_parents is deprecated and will be removed in a future version.\n", 182 | "Instructions for updating:\n", 183 | "Do not pass `graph_parents`. They will no longer be used.\n" 184 | ] 185 | } 186 | ], 187 | "source": [ 188 | "theta_prior = tfd.Dirichlet(\n", 189 | " concentration=2 * np.ones(n_components, dtype=np.float64), name='theta_prior',\n", 190 | " validate_args=True, allow_nan_stats=False)\n", 191 | "\n", 192 | "mu_prior = tfd.Independent(\n", 193 | " tfd.Normal(loc=np.stack([np.zeros(n_dims, dtype=np.float64)] * n_components),\n", 194 | " scale=tf.ones((n_components, n_dims), dtype=np.float64),\n", 195 | " validate_args=True, allow_nan_stats=False),\n", 196 | " reinterpreted_batch_ndims=1,\n", 197 | " name='mu_prior',\n", 198 | " validate_args=True)\n", 199 | "\n", 200 | "invcov_chol_prior = tfd.WishartTriL(df=n_dims+2,\n", 201 | " scale_tril=np.stack([np.eye(n_dims, dtype=np.float64)] * n_components),\n", 202 | " input_output_cholesky=True,\n", 203 | " name='invcov_chol_prior',\n", 204 | " validate_args=True, allow_nan_stats=False)" 205 | ] 206 | }, 207 | { 208 | "cell_type": "markdown", 209 | "metadata": {}, 210 | "source": [ 211 | "### joint log probability" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": 5, 217 | "metadata": {}, 
218 | "outputs": [], 219 | "source": [ 220 | "def joint_log_prob(x, theta, mu, invcov_chol):\n", 221 | " \"\"\"BGMM with priors: theta=Dirichlet, mu=Normal, invcov_chol=Wishart\n", 222 | "\n", 223 | " Args:\n", 224 | " x: `[n, d]`-shaped `Tensor` representing Bayesian Gaussian Mixture model draws.\n", 225 | " Each sample is a length-`d` vector.\n", 226 | " theta: `[K]`-shaped `Tensor` representing random draw from `SoftmaxInverse(Dirichlet)` prior.\n", 227 | " mu: `[K, d]`-shaped `Tensor` representing the location parameter of the `K` components.\n", 228 | " invcov_chol: `[K, d, d]`-shaped `Tensor` representing `K` lower triangular `cholesky(Precision)` matrices,\n", 229 | " each being sampled from a Wishart distribution.\n", 230 | "\n", 231 | " Returns:\n", 232 | " log_prob: `Tensor` representing joint log-density over all inputs.\n", 233 | " \"\"\"\n", 234 | " cov = tf.linalg.inv(tf.matmul(invcov_chol, tf.linalg.matrix_transpose(invcov_chol)))\n", 235 | "\n", 236 | " gmm = tfd.MixtureSameFamily(\n", 237 | " mixture_distribution=tfd.Categorical(probs=theta),\n", 238 | " components_distribution=tfd.MultivariateNormalTriL(loc=mu,\n", 239 | " scale_tril=tf.linalg.cholesky(cov)))\n", 240 | " log_prob_parts = [\n", 241 | " gmm.log_prob(x), # log probabilities (summed to be log-likelihoods)\n", 242 | " theta_prior.log_prob(theta)[..., tf.newaxis], # prior probabilities of theta\n", 243 | " mu_prior.log_prob(mu), # prior probabilities of mu\n", 244 | " invcov_chol_prior.log_prob(invcov_chol) # prior probabilities of invcov_chol\n", 245 | " ]\n", 246 | " sum_log_prob = tf.reduce_sum(tf.concat(log_prob_parts, axis=-1), axis=-1) # joint log probabilities\n", 247 | " return sum_log_prob" 248 | ] 249 | }, 250 | { 251 | "cell_type": "markdown", 252 | "metadata": {}, 253 | "source": [ 254 | "## Bayesian Inference using Markov chain Monte Carlo (MCMC) algorithms" 255 | ] 256 | }, 257 | { 258 | "cell_type": "code", 259 | "execution_count": 6, 260 | "metadata": {}, 261 | 
"outputs": [], 262 | "source": [ 263 | "unnormalized_posterior_log_prob = functools.partial(joint_log_prob, X)" 264 | ] 265 | }, 266 | { 267 | "cell_type": "markdown", 268 | "metadata": {}, 269 | "source": [ 270 | "### set up initial states" 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": 7, 276 | "metadata": {}, 277 | "outputs": [], 278 | "source": [ 279 | "initial_state = [\n", 280 | " (1 / n_components) * tf.ones(n_components, dtype=tf.float64, name='theta'),\n", 281 | " tf.convert_to_tensor(np.array([[5, 5, 0],\n", 282 | " [-5, -5, 0],\n", 283 | " [0, 5, 5],\n", 284 | " [0, -5, -5]], dtype=np.float64),\n", 285 | " name='mu'),\n", 286 | " tf.linalg.diag(tf.ones((n_components, n_dims), dtype=tf.float64), name='invcov_chol')\n", 287 | "]" 288 | ] 289 | }, 290 | { 291 | "cell_type": "markdown", 292 | "metadata": {}, 293 | "source": [ 294 | "### unconstrained representation\n", 295 | "Gradient-based MCMC algorithms such as HMC and MALA require the target log-probability function to be differentiable with respect to its arguments (Random Walk Metropolis does not). Furthermore, MCMC can exhibit dramatically higher statistical efficiency if the state-space is unconstrained.\n", 296 | "\n", 297 | "To address this requirement we'll need to:\n", 298 | "\n", 299 | "1. transform the constrained variables to an unconstrained space;\n", 300 | "2. run the MCMC in unconstrained space;\n", 301 | "3. transform the unconstrained variables back to the constrained space." 
302 | ] 303 | }, 304 | { 305 | "cell_type": "code", 306 | "execution_count": 8, 307 | "metadata": {}, 308 | "outputs": [], 309 | "source": [ 310 | "# bijectors transforms the unconstrained variables back to the constrained space\n", 311 | "unconstraining_bijectors = [\n", 312 | " tfb.SoftmaxCentered(validate_args=True), # transform unconstrained theta values to discrete probability vectors\n", 313 | " tfb.Identity(validate_args=True), # identity transformation (no transformation) for mu values\n", 314 | " tfb.Chain([\n", 315 | " tfb.TransformDiagonal(tfb.Softplus(), validate_args=True),\n", 316 | " tfb.FillTriangular(validate_args=True)\n", 317 | " ]) # transforms unconstrained invcov_chol values to lower triangular matrices with positive diagonal\n", 318 | "]" 319 | ] 320 | }, 321 | { 322 | "cell_type": "markdown", 323 | "metadata": {}, 324 | "source": [ 325 | "### graph for Random Walk Metropolis (RWM) sampling (Method 1)" 326 | ] 327 | }, 328 | { 329 | "cell_type": "code", 330 | "execution_count": 9, 331 | "metadata": {}, 332 | "outputs": [ 333 | { 334 | "name": "stdout", 335 | "output_type": "stream", 336 | "text": [ 337 | "WARNING:tensorflow:From C:\\ProgramData\\Anaconda3\\envs\\nightly\\lib\\site-packages\\tensorflow_probability\\python\\distributions\\mvn_linear_operator.py:193: AffineLinearOperator.__init__ (from tensorflow_probability.python.bijectors.affine_linear_operator) is deprecated and will be removed after 2020-01-01.\n", 338 | "Instructions for updating:\n", 339 | "`AffineLinearOperator` bijector is deprecated; please use `tfb.Shift(loc)(tfb.MatvecLinearOperator(...))`.\n" 340 | ] 341 | }, 342 | { 343 | "name": "stderr", 344 | "output_type": "stream", 345 | "text": [ 346 | "C:\\ProgramData\\Anaconda3\\envs\\nightly\\lib\\site-packages\\tensorflow_probability\\python\\mcmc\\sample.py:333: UserWarning: Tracing all kernel results by default is deprecated. 
Set the `trace_fn` argument to None (the future default value) or an explicit callback that traces the values you are interested in.\n", 347 | " warnings.warn(\"Tracing all kernel results by default is deprecated. Set \"\n" 348 | ] 349 | } 350 | ], 351 | "source": [ 352 | "n_results = 16000\n", 353 | "n_burnin_steps = 8000\n", 354 | "scale = 1e-3\n", 355 | "[theta, mu, invcov_chol], kernel_results = tfp.mcmc.sample_chain(num_results=n_results,\n", 356 | " num_burnin_steps=n_burnin_steps,\n", 357 | " current_state=initial_state,\n", 358 | " kernel=tfp.mcmc.TransformedTransitionKernel(\n", 359 | " inner_kernel=tfp.mcmc.RandomWalkMetropolis(\n", 360 | " target_log_prob_fn=unnormalized_posterior_log_prob,\n", 361 | " new_state_fn=tfp.mcmc.random_walk_normal_fn(scale=scale)),\n", 362 | " bijector=unconstraining_bijectors),\n", 363 | " parallel_iterations=100)" 364 | ] 365 | }, 366 | { 367 | "cell_type": "markdown", 368 | "metadata": {}, 369 | "source": [ 370 | "### launch the graph and display the inferred statistics" 371 | ] 372 | }, 373 | { 374 | "cell_type": "code", 375 | "execution_count": 10, 376 | "metadata": {}, 377 | "outputs": [ 378 | { 379 | "name": "stdout", 380 | "output_type": "stream", 381 | "text": [ 382 | "elapsed time: 9595.37s\n", 383 | "acceptance_rate: 0.862625\n", 384 | "avg mix probs: [0.24844049 0.25512503 0.24762867 0.24880581]\n", 385 | "\n", 386 | "avg loc:\n", 387 | " [[ 4.99136968 5.02519328 0.12765716]\n", 388 | " [-4.98041703 -4.97772828 -0.0572409 ]\n", 389 | " [ 0.00996063 5.02680586 5.00870718]\n", 390 | " [-0.0314585 -4.98221856 -4.99639094]]\n", 391 | "\n", 392 | "avg precision cholesky:\n", 393 | " [[[ 2.47709883 0. 0. ]\n", 394 | " [ 0.05957447 1.10980288 0. ]\n", 395 | " [-0.01921852 -0.35457543 0.70271206]]\n", 396 | "\n", 397 | " [[ 2.46778401 0. 0. ]\n", 398 | " [-0.0616935 1.0843033 0. ]\n", 399 | " [ 0.00980305 -0.31878146 0.69907693]]\n", 400 | "\n", 401 | " [[ 1.23256135 0. 0. ]\n", 402 | " [-0.67339941 0.97340591 0. 
]\n", 403 | " [-0.15097277 0.21016192 2.05898233]]\n", 404 | "\n", 405 | " [[ 1.25751358 0. 0. ]\n", 406 | " [-0.76347918 0.99049323 0. ]\n", 407 | " [-0.10465556 -0.03029713 2.09267267]]]\n", 408 | "\n", 409 | "avg covariance matrix:\n", 410 | " [[[ 1.65442030e-01 -1.91825730e-02 -3.96328617e-05]\n", 411 | " [-1.91825730e-02 1.01935527e+00 6.47246881e-01]\n", 412 | " [-3.96328617e-05 6.47246881e-01 2.02602839e+00]]\n", 413 | "\n", 414 | " [[ 1.66775947e-01 2.28697436e-02 6.58647346e-03]\n", 415 | " [ 2.28697436e-02 1.02893153e+00 6.01939233e-01]\n", 416 | " [ 6.58647346e-03 6.01939233e-01 2.04737284e+00]]\n", 417 | "\n", 418 | " [[ 9.74369775e-01 5.76760515e-01 1.29443252e-03]\n", 419 | " [ 5.76760515e-01 1.06695074e+00 -5.08535367e-02]\n", 420 | " [ 1.29443252e-03 -5.08535367e-02 2.36544302e-01]]\n", 421 | "\n", 422 | " [[ 1.01238103e+00 6.21230738e-01 2.36046567e-02]\n", 423 | " [ 6.21230738e-01 1.02236233e+00 6.93620343e-03]\n", 424 | " [ 2.36046567e-02 6.93620343e-03 2.29015027e-01]]]\n" 425 | ] 426 | } 427 | ], 428 | "source": [ 429 | "acceptance_rate = tf.reduce_mean(tf.cast(kernel_results.inner_results.is_accepted, tf.float64))\n", 430 | "\n", 431 | "theta = theta[n_burnin_steps:]\n", 432 | "mu = mu[n_burnin_steps:]\n", 433 | "invcov_chol = invcov_chol[n_burnin_steps:]\n", 434 | "cov = tf.linalg.inv(tf.matmul(invcov_chol, tf.linalg.matrix_transpose(invcov_chol)))\n", 435 | "\n", 436 | "mean_theta = tf.reduce_mean(theta, axis=0)\n", 437 | "mean_mu = tf.reduce_mean(mu, axis=0)\n", 438 | "mean_invcov_chol = tf.reduce_mean(invcov_chol, axis=0)\n", 439 | "mean_cov = tf.reduce_mean(cov, axis=0)\n", 440 | "\n", 441 | "start_time = time.time()\n", 442 | "[val_acceptance_rate, val_mean_theta, val_mean_mu, val_mean_invcov_chol, val_mean_cov] = sess.run(\n", 443 | " [acceptance_rate, mean_theta, mean_mu, mean_invcov_chol, mean_cov])\n", 444 | "elapsed_time = time.time() - start_time\n", 445 | "print('elapsed time: {:.2f}s'.format(elapsed_time))\n", 446 | "\n", 447 | 
"print('acceptance_rate:', val_acceptance_rate)\n", 448 | "print('avg mix probs:', val_mean_theta)\n", 449 | "print('\\navg loc:\\n', val_mean_mu)\n", 450 | "print('\\navg precision cholesky:\\n', val_mean_invcov_chol)\n", 451 | "print('\\navg covariance matrix:\\n', val_mean_cov)" 452 | ] 453 | }, 454 | { 455 | "cell_type": "markdown", 456 | "metadata": {}, 457 | "source": [ 458 | "### graph for Hamiltonian Monte Carlo (HMC) sampling (Method 2)" 459 | ] 460 | }, 461 | { 462 | "cell_type": "code", 463 | "execution_count": 11, 464 | "metadata": {}, 465 | "outputs": [ 466 | { 467 | "name": "stderr", 468 | "output_type": "stream", 469 | "text": [ 470 | "C:\\ProgramData\\Anaconda3\\envs\\nightly\\lib\\site-packages\\tensorflow_probability\\python\\mcmc\\sample.py:333: UserWarning: Tracing all kernel results by default is deprecated. Set the `trace_fn` argument to None (the future default value) or an explicit callback that traces the values you are interested in.\n", 471 | " warnings.warn(\"Tracing all kernel results by default is deprecated. 
Set \"\n" 472 | ] 473 | } 474 | ], 475 | "source": [ 476 | "n_results = 4000\n", 477 | "n_burnin_steps = 2000\n", 478 | "step_size = 1e-2\n", 479 | "num_leapfrog_steps = 10\n", 480 | "[theta, mu, invcov_chol], kernel_results = tfp.mcmc.sample_chain(num_results=n_results,\n", 481 | " num_burnin_steps=n_burnin_steps,\n", 482 | " current_state=initial_state,\n", 483 | " kernel=tfp.mcmc.TransformedTransitionKernel(\n", 484 | " inner_kernel=tfp.mcmc.HamiltonianMonteCarlo(\n", 485 | " target_log_prob_fn=unnormalized_posterior_log_prob,\n", 486 | " step_size=step_size,\n", 487 | " num_leapfrog_steps=num_leapfrog_steps),\n", 488 | " bijector=unconstraining_bijectors),\n", 489 | " parallel_iterations=100)" 490 | ] 491 | }, 492 | { 493 | "cell_type": "markdown", 494 | "metadata": {}, 495 | "source": [ 496 | "### launch the graph and display the inferred statistics" 497 | ] 498 | }, 499 | { 500 | "cell_type": "code", 501 | "execution_count": 12, 502 | "metadata": {}, 503 | "outputs": [ 504 | { 505 | "name": "stdout", 506 | "output_type": "stream", 507 | "text": [ 508 | "elapsed time: 1529.73s\n", 509 | "acceptance_rate: 0.87025\n", 510 | "avg mix probs: [0.2498434 0.25009479 0.25014161 0.2499202 ]\n", 511 | "\n", 512 | "avg loc:\n", 513 | " [[ 4.99109595 5.01778177 0.10850991]\n", 514 | " [-4.983237 -4.97558152 -0.03743145]\n", 515 | " [-0.0129128 5.01482 5.00213167]\n", 516 | " [ 0.02351644 -4.97109919 -5.005098 ]]\n", 517 | "\n", 518 | "avg precision cholesky:\n", 519 | " [[[ 3.06681467 0. 0. ]\n", 520 | " [ 0.07086182 1.1076119 0. ]\n", 521 | " [-0.0261507 -0.34831598 0.70353157]]\n", 522 | "\n", 523 | " [[ 3.05631866 0. 0. ]\n", 524 | " [-0.05208389 1.0710462 0. ]\n", 525 | " [ 0.00821412 -0.31225647 0.70647403]]\n", 526 | "\n", 527 | " [[ 1.24444161 0. 0. ]\n", 528 | " [-0.74107702 1.01483969 0. ]\n", 529 | " [-0.08706177 0.13554366 2.22257539]]\n", 530 | "\n", 531 | " [[ 1.24503975 0. 0. ]\n", 532 | " [-0.7727238 0.9953581 0. 
]\n", 533 | " [ 0.01202825 -0.06615687 2.25012359]]]\n", 534 | "\n", 535 | "avg covariance matrix:\n", 536 | " [[[ 1.07148881e-01 -1.80524840e-02 2.55286089e-03]\n", 537 | " [-1.80524840e-02 1.01724448e+00 6.36348312e-01]\n", 538 | " [ 2.55286089e-03 6.36348312e-01 2.02299863e+00]]\n", 539 | "\n", 540 | " [[ 1.07701479e-01 1.62243932e-02 4.59128723e-03]\n", 541 | " [ 1.62243932e-02 1.04463059e+00 5.85522613e-01]\n", 542 | " [ 4.59128723e-03 5.85522613e-01 2.00849777e+00]]\n", 543 | "\n", 544 | " [[ 9.93131442e-01 5.79904979e-01 -1.95259448e-03]\n", 545 | " [ 5.79904979e-01 9.77100814e-01 -2.70934357e-02]\n", 546 | " [-1.95259448e-03 -2.70934357e-02 2.02742248e-01]]\n", 547 | "\n", 548 | " [[ 1.03734163e+00 6.28533826e-01 6.26399375e-03]\n", 549 | " [ 6.28533826e-01 1.01304587e+00 1.31660492e-02]\n", 550 | " [ 6.26399375e-03 1.31660492e-02 1.97802389e-01]]]\n" 551 | ] 552 | } 553 | ], 554 | "source": [ 555 | "acceptance_rate = tf.reduce_mean(tf.cast(kernel_results.inner_results.is_accepted, tf.float64))\n", 556 | "\n", 557 | "theta = theta[n_burnin_steps:]\n", 558 | "mu = mu[n_burnin_steps:]\n", 559 | "invcov_chol = invcov_chol[n_burnin_steps:]\n", 560 | "cov = tf.linalg.inv(tf.matmul(invcov_chol, tf.linalg.matrix_transpose(invcov_chol)))\n", 561 | "\n", 562 | "mean_theta = tf.reduce_mean(theta, axis=0)\n", 563 | "mean_mu = tf.reduce_mean(mu, axis=0)\n", 564 | "mean_invcov_chol = tf.reduce_mean(invcov_chol, axis=0)\n", 565 | "mean_cov = tf.reduce_mean(cov, axis=0)\n", 566 | "\n", 567 | "start_time = time.time()\n", 568 | "[val_acceptance_rate, val_mean_theta, val_mean_mu, val_mean_invcov_chol, val_mean_cov] = sess.run(\n", 569 | " [acceptance_rate, mean_theta, mean_mu, mean_invcov_chol, mean_cov])\n", 570 | "elapsed_time = time.time() - start_time\n", 571 | "print('elapsed time: {:.2f}s'.format(elapsed_time))\n", 572 | "\n", 573 | "print('acceptance_rate:', val_acceptance_rate)\n", 574 | "print('avg mix probs:', val_mean_theta)\n", 575 | "print('\\navg loc:\\n', 
val_mean_mu)\n", 576 | "print('\\navg precision cholesky:\\n', val_mean_invcov_chol)\n", 577 | "print('\\navg covariance matrix:\\n', val_mean_cov)" 578 | ] 579 | }, 580 | { 581 | "cell_type": "markdown", 582 | "metadata": {}, 583 | "source": [ 584 | "### graph for Metropolis-adjusted Langevin algorithm (MALA) sampling (Method 3)" 585 | ] 586 | }, 587 | { 588 | "cell_type": "code", 589 | "execution_count": 13, 590 | "metadata": {}, 591 | "outputs": [ 592 | { 593 | "name": "stderr", 594 | "output_type": "stream", 595 | "text": [ 596 | "C:\\ProgramData\\Anaconda3\\envs\\nightly\\lib\\site-packages\\tensorflow_probability\\python\\mcmc\\sample.py:333: UserWarning: Tracing all kernel results by default is deprecated. Set the `trace_fn` argument to None (the future default value) or an explicit callback that traces the values you are interested in.\n", 597 | " warnings.warn(\"Tracing all kernel results by default is deprecated. Set \"\n" 598 | ] 599 | } 600 | ], 601 | "source": [ 602 | "n_results = 4000\n", 603 | "n_burnin_steps = 2000\n", 604 | "step_size = 2e-4\n", 605 | "[theta, mu, invcov_chol], kernel_results = tfp.mcmc.sample_chain(num_results=n_results,\n", 606 | " num_burnin_steps=n_burnin_steps,\n", 607 | " current_state=initial_state,\n", 608 | " kernel=tfp.mcmc.TransformedTransitionKernel(\n", 609 | " inner_kernel=tfp.mcmc.MetropolisAdjustedLangevinAlgorithm(\n", 610 | " target_log_prob_fn=unnormalized_posterior_log_prob,\n", 611 | " step_size=step_size),\n", 612 | " bijector=unconstraining_bijectors),\n", 613 | " parallel_iterations=100)" 614 | ] 615 | }, 616 | { 617 | "cell_type": "markdown", 618 | "metadata": {}, 619 | "source": [ 620 | "### launch the graph and display the inferred statistics" 621 | ] 622 | }, 623 | { 624 | "cell_type": "code", 625 | "execution_count": 14, 626 | "metadata": {}, 627 | "outputs": [ 628 | { 629 | "name": "stdout", 630 | "output_type": "stream", 631 | "text": [ 632 | "elapsed time: 759.22s\n", 633 | "acceptance_rate: 
0.6435\n", 634 | "avg mix probs: [0.24883278 0.24951683 0.25063191 0.25101848]\n", 635 | "\n", 636 | "avg loc:\n", 637 | " [[ 4.99119668 5.01439659 0.10412371]\n", 638 | " [-4.98318399 -4.97404597 -0.03899816]\n", 639 | " [-0.01031765 5.01830767 5.00120453]\n", 640 | " [ 0.02402447 -4.969517 -5.00546928]]\n", 641 | "\n", 642 | "avg precision cholesky:\n", 643 | " [[[ 3.05036607 0. 0. ]\n", 644 | " [ 0.07515016 1.10424112 0. ]\n", 645 | " [-0.0280067 -0.34487696 0.70550916]]\n", 646 | "\n", 647 | " [[ 3.03637912 0. 0. ]\n", 648 | " [-0.06450821 1.07119057 0. ]\n", 649 | " [ 0.01557511 -0.31028529 0.70706238]]\n", 650 | "\n", 651 | " [[ 1.25064862 0. 0. ]\n", 652 | " [-0.7510701 1.0170769 0. ]\n", 653 | " [-0.07914385 0.1028017 2.22949531]]\n", 654 | "\n", 655 | " [[ 1.24028647 0. 0. ]\n", 656 | " [-0.76607108 0.99803328 0. ]\n", 657 | " [ 0.0452747 -0.02509068 2.24655694]]]\n", 658 | "\n", 659 | "avg covariance matrix:\n", 660 | " [[[ 1.08363041e-01 -1.93511679e-02 2.95535839e-03]\n", 661 | " [-1.93511679e-02 1.01813455e+00 6.28355810e-01]\n", 662 | " [ 2.95535839e-03 6.28355810e-01 2.01194212e+00]]\n", 663 | "\n", 664 | " [[ 1.09282689e-01 1.91413964e-02 2.11489311e-03]\n", 665 | " [ 1.91413964e-02 1.04140895e+00 5.80251421e-01]\n", 666 | " [ 2.11489311e-03 5.80251421e-01 2.00310135e+00]]\n", 667 | "\n", 668 | " [[ 9.90906538e-01 5.81840954e-01 5.75234315e-04]\n", 669 | " [ 5.81840954e-01 9.70897765e-01 -2.04146856e-02]\n", 670 | " [ 5.75234315e-04 -2.04146856e-02 2.01515665e-01]]\n", 671 | "\n", 672 | " [[ 1.03684055e+00 6.21611624e-01 -4.42933867e-03]\n", 673 | " [ 6.21611624e-01 1.00620622e+00 4.85717982e-03]\n", 674 | " [-4.42933867e-03 4.85717982e-03 1.98521125e-01]]]\n" 675 | ] 676 | } 677 | ], 678 | "source": [ 679 | "acceptance_rate = tf.reduce_mean(tf.cast(kernel_results.inner_results.is_accepted, tf.float64))\n", 680 | "\n", 681 | "theta = theta[n_burnin_steps:]\n", 682 | "mu = mu[n_burnin_steps:]\n", 683 | "invcov_chol = invcov_chol[n_burnin_steps:]\n", 
684 | "cov = tf.linalg.inv(tf.matmul(invcov_chol, tf.linalg.matrix_transpose(invcov_chol)))\n", 685 | "\n", 686 | "mean_theta = tf.reduce_mean(theta, axis=0)\n", 687 | "mean_mu = tf.reduce_mean(mu, axis=0)\n", 688 | "mean_invcov_chol = tf.reduce_mean(invcov_chol, axis=0)\n", 689 | "mean_cov = tf.reduce_mean(cov, axis=0)\n", 690 | "\n", 691 | "start_time = time.time()\n", 692 | "[val_acceptance_rate, val_mean_theta, val_mean_mu, val_mean_invcov_chol, val_mean_cov] = sess.run(\n", 693 | " [acceptance_rate, mean_theta, mean_mu, mean_invcov_chol, mean_cov])\n", 694 | "elapsed_time = time.time() - start_time\n", 695 | "print('elapsed time: {:.2f}s'.format(elapsed_time))\n", 696 | "\n", 697 | "print('acceptance_rate:', val_acceptance_rate)\n", 698 | "print('avg mix probs:', val_mean_theta)\n", 699 | "print('\\navg loc:\\n', val_mean_mu)\n", 700 | "print('\\navg precision cholesky:\\n', val_mean_invcov_chol)\n", 701 | "print('\\navg covariance matrix:\\n', val_mean_cov)" 702 | ] 703 | } 704 | ], 705 | "metadata": { 706 | "kernelspec": { 707 | "display_name": "Python 3", 708 | "language": "python", 709 | "name": "python3" 710 | }, 711 | "language_info": { 712 | "codemirror_mode": { 713 | "name": "ipython", 714 | "version": 3 715 | }, 716 | "file_extension": ".py", 717 | "mimetype": "text/x-python", 718 | "name": "python", 719 | "nbconvert_exporter": "python", 720 | "pygments_lexer": "ipython3", 721 | "version": "3.7.3" 722 | } 723 | }, 724 | "nbformat": 4, 725 | "nbformat_minor": 2 726 | } 727 | --------------------------------------------------------------------------------