├── .gitignore ├── LICENSE ├── README.md ├── conv_layers.py ├── dense_layers.py ├── layers.py ├── mnf_lenet_mnist.py ├── mnist.pkl.gz ├── mnist.py ├── norm_flows.py ├── utils.py └── wrappers.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 AMLAB 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Example implementation of the Multiplicative Normalizing Flow (MNF) parameter posteriors proposed in: 2 | 3 | *Multiplicative Normalizing Flows for Variational Bayesian Neural Networks* 4 | Christos Louizos & Max Welling 5 | https://arxiv.org/abs/1703.01961 6 | 7 | This code is provided as-is; it is not maintained and will not be updated. 8 | -------------------------------------------------------------------------------- /conv_layers.py: -------------------------------------------------------------------------------- 1 | from layers import Layer 2 | import numpy as np 3 | import tensorflow as tf 4 | from norm_flows import MaskedNVPFlow 5 | from utils import randmat, zeros_d, ones_d, outer 6 | 7 | 8 | class Conv2DMNF(Layer): 9 | '''2D convolutional layer with a multiplicative normalizing flow (MNF) approximate posterior over the weights. 10 | Prior is a standard normal.
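    Concretely (following Louizos & Welling, 2017), the posterior is roughly q(W) = int q(z) prod_ij N(w_ij | z_j * mu_ij, sigma_ij^2) dz,
    where z carries one multiplicative scale per output filter; q(z) is made flexible by the masked RealNVP flow `flow_q`,
    and `flow_r` parameterizes the auxiliary distribution r(z|W) used to bound the otherwise intractable entropy term.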
11 | ''' 12 | 13 | def __init__(self, nb_filter, nb_row, nb_col, input_shape=(), activation=tf.identity, N=1, name=None, 14 | border_mode='SAME', subsample=(1, 1, 1, 1), flows_q=2, flows_r=2, learn_p=False, use_z=True, 15 | prior_var=1., prior_var_b=1., flow_dim_h=50, logging=False, thres_var=1., **kwargs): 16 | 17 | if border_mode not in {'VALID', 'SAME'}: 18 | raise Exception('Invalid border mode for Convolution2D:', border_mode) 19 | 20 | self.nb_filter = nb_filter 21 | self.nb_row = nb_row 22 | self.nb_col = nb_col 23 | self.border_mode = border_mode 24 | self.subsample = subsample 25 | self.thres_var = thres_var 26 | 27 | self.N = N 28 | self.flow_dim_h = flow_dim_h 29 | self.learn_p = learn_p 30 | self.input_shape = input_shape 31 | 32 | self.prior_var = prior_var 33 | self.prior_var_b = prior_var_b 34 | self.n_flows_q = flows_q 35 | self.n_flows_r = flows_r 36 | self.use_z = use_z 37 | super(Conv2DMNF, self).__init__(N=N, nonlin=activation, name=name, logging=logging) 38 | 39 | def build(self): 40 | stack_size = self.input_shape[-1] 41 | self.W_shape = (self.nb_row, self.nb_col, stack_size, self.nb_filter) 42 | self.input_dim = self.nb_col * stack_size * self.nb_row 43 | self.stack_size = stack_size 44 | 45 | with tf.variable_scope(self.name): 46 | self.mu_W = randmat(self.W_shape, name='mean_W') 47 | self.logvar_W = randmat(self.W_shape, mu=-9., name='logvar_W', extra_scale=1e-6) 48 | self.mu_bias = tf.Variable(tf.zeros((self.nb_filter,)), name='mean_bias') 49 | self.logvar_bias = randmat((self.nb_filter,), mu=-9., name='logvar_bias', extra_scale=1e-6) 50 | 51 | if self.use_z: 52 | self.qzero_mean = randmat((self.nb_filter,), name='dropout_rates_mean', mu=1. if self.n_flows_q == 0 else 0.) 53 | self.qzero = randmat((self.nb_filter,), name='dropout_rates', mu=np.log(0.1), extra_scale=1e-6) 54 | self.rsr_M = randmat((self.nb_filter,), name='var_r_aux') 55 | self.apvar_M = randmat((self.nb_filter,), name='apvar_r_aux') 56 | self.rsri_M = randmat((self.nb_filter,), name='var_r_auxi') 57 | 58 | self.pvar = randmat((self.input_dim,), mu=np.log(self.prior_var), name='prior_var_r_p', extra_scale=1e-6, trainable=self.learn_p) 59 | self.pvar_bias = randmat((1,), mu=np.log(self.prior_var_b), name='prior_var_r_p_bias', extra_scale=1e-6, trainable=self.learn_p) 60 | 61 | if self.n_flows_r > 0: 62 | self.flow_r = MaskedNVPFlow(self.nb_filter, n_flows=self.n_flows_r, name=self.name + '_fr', n_hidden=0, 63 | dim_h=2 * self.flow_dim_h, scope=self.name) 64 | 65 | if self.n_flows_q > 0: 66 | self.flow_q = MaskedNVPFlow(self.nb_filter, n_flows=self.n_flows_q, name=self.name + '_fq', n_hidden=0, 67 | dim_h=self.flow_dim_h, scope=self.name) 68 | 69 | print 'Built layer {}, output_dim: {}, input_shape: {}, flows_r: {}, flows_q: {}, use_z: {}, learn_p: {}, ' \ 70 | 'pvar: {}, thres_var: {}'.format(self.name, self.nb_filter, self.input_shape, self.n_flows_r, 71 | self.n_flows_q, self.use_z, self.learn_p, self.prior_var, self.thres_var) 72 | 73 | def sample_z(self, size_M=1, sample=True): 74 | if not self.use_z: 75 | return ones_d((size_M, self.nb_filter)), zeros_d((size_M,)) 76 | qm0 = self.get_params_m() 77 | isample_M = tf.tile(tf.expand_dims(self.qzero_mean, 0), [size_M, 1]) 78 | eps = tf.random_normal(tf.stack((size_M, self.nb_filter))) 79 | sample_M = isample_M + tf.sqrt(qm0) * eps if sample else isample_M 80 | 81 | logdets = zeros_d((size_M,)) 82 | if self.n_flows_q > 0: 83 | sample_M, logdets = self.flow_q.get_output_for(sample_M, sample=sample) 84 | 85 | return sample_M, logdets 86 | 87 | def 
get_params_m(self): 88 | if not self.use_z: 89 | return None 90 | 91 | return tf.exp(self.qzero) 92 | 93 | def get_params_W(self): 94 | return tf.exp(self.logvar_W) 95 | 96 | def get_mean_var(self, x): 97 | var_w = tf.clip_by_value(self.get_params_W(), 0., self.thres_var) 98 | var_w = tf.square(var_w) 99 | var_b = tf.clip_by_value(tf.exp(self.logvar_bias), 0., self.thres_var**2) 100 | 101 | # formally we do cross-correlation here 102 | muout = tf.nn.conv2d(x, self.mu_W, self.subsample, self.border_mode, use_cudnn_on_gpu=True) + self.mu_bias 103 | varout = tf.nn.conv2d(tf.square(x), var_w, self.subsample, self.border_mode, use_cudnn_on_gpu=True) + var_b 104 | 105 | return muout, varout 106 | 107 | def kldiv(self): 108 | M, logdets = self.sample_z() 109 | logdets = logdets[0] 110 | M = tf.squeeze(M) 111 | 112 | std_w = self.get_params_W() 113 | mu = tf.reshape(self.mu_W, [-1, self.nb_filter]) 114 | std_w = tf.reshape(std_w, [-1, self.nb_filter]) 115 | Mtilde = mu * tf.expand_dims(M, 0) 116 | mbias = self.mu_bias * M 117 | Vtilde = tf.square(std_w) 118 | 119 | iUp = outer(tf.exp(self.pvar), ones_d((self.nb_filter,))) 120 | 121 | qm0 = self.get_params_m() 122 | logqm = 0. 123 | if self.use_z > 0.: 124 | logqm = - tf.reduce_sum(.5 * (tf.log(2 * np.pi) + tf.log(qm0) + 1)) 125 | logqm -= logdets 126 | 127 | kldiv_w = tf.reduce_sum(.5 * tf.log(iUp) - .5 * tf.log(Vtilde) + ((Vtilde + tf.square(Mtilde)) / (2 * iUp)) - .5) 128 | kldiv_bias = tf.reduce_sum(.5 * self.pvar_bias - .5 * self.logvar_bias + ((tf.exp(self.logvar_bias) + 129 | tf.square(mbias)) / (2 * tf.exp(self.pvar_bias))) - .5) 130 | 131 | logrm = 0. 132 | if self.use_z: 133 | apvar_M = self.apvar_M 134 | mw = tf.matmul(Mtilde, tf.expand_dims(apvar_M, 1)) 135 | vw = tf.matmul(Vtilde, tf.expand_dims(tf.square(apvar_M), 1)) 136 | eps = tf.expand_dims(tf.random_normal((self.input_dim,)), 1) 137 | a = mw + tf.sqrt(vw) * eps 138 | mb = tf.reduce_sum(mbias * apvar_M) 139 | vb = tf.reduce_sum(tf.exp(self.logvar_bias) * tf.square(apvar_M)) 140 | a += mb + tf.sqrt(vb) * tf.random_normal(()) 141 | 142 | w__ = tf.reduce_mean(outer(tf.squeeze(a), self.rsr_M), axis=0) 143 | wv__ = tf.reduce_mean(outer(tf.squeeze(a), self.rsri_M), axis=0) 144 | 145 | if self.flow_r is not None: 146 | M, logrm = self.flow_r.get_output_for(tf.expand_dims(M, 0)) 147 | M = tf.squeeze(M) 148 | logrm = logrm[0] 149 | 150 | logrm += tf.reduce_sum(-.5 * tf.exp(wv__) * tf.square(M - w__) - .5 * tf.log(2 * np.pi) + .5 * wv__) 151 | 152 | return - kldiv_w + logrm - logqm - kldiv_bias 153 | 154 | def call(self, x, sample=True, **kwargs): 155 | sample_M, _ = self.sample_z(size_M=tf.shape(x)[0], sample=sample) 156 | sample_M = tf.expand_dims(tf.expand_dims(sample_M, 1), 2) 157 | mean_out, var_out = self.get_mean_var(x) 158 | mean_gout = mean_out * sample_M 159 | var_gout = tf.sqrt(var_out) * tf.random_normal(tf.shape(mean_gout)) 160 | out = mean_gout + var_gout 161 | 162 | output = out if sample else mean_gout 163 | return output 164 | -------------------------------------------------------------------------------- /dense_layers.py: -------------------------------------------------------------------------------- 1 | from layers import Layer 2 | import numpy as np 3 | import tensorflow as tf 4 | from norm_flows import MaskedNVPFlow, PlanarFlow 5 | from utils import randmat, zeros_d, ones_d, outer 6 | 7 | 8 | class DenseMNF(Layer): 9 | '''Fully connected layer with a multiplicative normalizing flow (MNF) aproximate posterior over the weights. 10 | Prior is a standard normal. 
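    Here the multiplicative scales z have one component per input unit, so roughly q(w_ij | z) = N(z_i * mu_ij, sigma_ij^2);
    q(z) and the auxiliary r(z|W) are modeled by the flows built below (planar flows when the input dimension is 1,
    masked RealNVP flows otherwise).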
11 | ''' 12 | def __init__(self, output_dim, activation=tf.identity, N=1, input_dim=None, flows_q=2, flows_r=2, learn_p=False, 13 | use_z=True, prior_var=1., name=None, logging=False, flow_dim_h=50, prior_var_b=1., thres_var=1., 14 | **kwargs): 15 | 16 | self.output_dim = output_dim 17 | self.learn_p = learn_p 18 | self.prior_var = prior_var 19 | self.prior_var_b = prior_var_b 20 | self.thres_var = thres_var 21 | 22 | self.n_flows_q = flows_q 23 | self.n_flows_r = flows_r 24 | self.use_z = use_z 25 | self.flow_dim_h = flow_dim_h 26 | 27 | self.input_dim = input_dim 28 | super(DenseMNF, self).__init__(N=N, nonlin=activation, name=name, logging=logging) 29 | 30 | def build(self): 31 | dim_in, dim_out = self.input_dim, self.output_dim 32 | 33 | with tf.variable_scope(self.name): 34 | self.mu_W = randmat((dim_in, dim_out), name='mean_W', extra_scale=1.) 35 | self.logvar_W = randmat((dim_in, dim_out), mu=-9., name='var_W', extra_scale=1e-6) 36 | self.mu_bias = tf.Variable(tf.zeros((dim_out,)), name='mean_bias') 37 | self.logvar_bias = randmat((dim_out,), mu=-9., name='var_bias', extra_scale=1e-6) 38 | 39 | if self.use_z: 40 | self.qzero_mean = randmat((dim_in,), name='dropout_rates_mean', mu=1. if self.n_flows_q == 0 else 0.) 41 | self.qzero = randmat((dim_in,), mu=np.log(0.1), name='dropout_rates', extra_scale=1e-6) 42 | self.rsr_M = randmat((dim_in,), name='var_r_aux') 43 | self.apvar_M = randmat((dim_in,), name='apvar_r_aux') 44 | self.rsri_M = randmat((dim_in,), name='var_r_auxi') 45 | 46 | self.pvar = randmat((dim_in,), mu=np.log(self.prior_var), name='prior_var_r_p', trainable=self.learn_p, extra_scale=1e-6) 47 | self.pvar_bias = randmat((1,), mu=np.log(self.prior_var_b), name='prior_var_r_p_bias', trainable=self.learn_p, extra_scale=1e-6) 48 | 49 | if self.n_flows_r > 0: 50 | if dim_in == 1: 51 | self.flow_r = PlanarFlow(dim_in, n_flows=self.n_flows_r, name=self.name + '_fr', scope=self.name) 52 | else: 53 | self.flow_r = MaskedNVPFlow(dim_in, n_flows=self.n_flows_r, name=self.name + '_fr', n_hidden=0, 54 | dim_h=2 * self.flow_dim_h, scope=self.name) 55 | 56 | if self.n_flows_q > 0: 57 | if dim_in == 1: 58 | self.flow_q = PlanarFlow(dim_in, n_flows=self.n_flows_q, name=self.name + '_fq', scope=self.name) 59 | else: 60 | self.flow_q = MaskedNVPFlow(dim_in, n_flows=self.n_flows_q, name=self.name + '_fq', n_hidden=0, 61 | dim_h=self.flow_dim_h, scope=self.name) 62 | 63 | print 'Built layer', self.name, 'prior_var: {}'.format(self.prior_var), \ 64 | 'flows_q: {}, flows_r: {}, use_z: {}'.format(self.n_flows_q, self.n_flows_r, self.use_z), \ 65 | 'learn_p: {}, thres_var: {}'.format(self.learn_p, self.thres_var) 66 | 67 | def sample_z(self, size_M=1, sample=True): 68 | if not self.use_z: 69 | return ones_d((size_M, self.input_dim)), zeros_d((size_M,)) 70 | 71 | qm0 = self.get_params_m() 72 | isample_M = tf.tile(tf.expand_dims(self.qzero_mean, 0), [size_M, 1]) 73 | eps = tf.random_normal(tf.stack((size_M, self.input_dim))) 74 | sample_M = isample_M + tf.sqrt(qm0) * eps if sample else isample_M 75 | 76 | logdets = zeros_d((size_M,)) 77 | if self.n_flows_q > 0: 78 | sample_M, logdets = self.flow_q.get_output_for(sample_M, sample=sample) 79 | 80 | return sample_M, logdets 81 | 82 | def get_params_m(self): 83 | if not self.use_z: 84 | return None 85 | 86 | return tf.exp(self.qzero) 87 | 88 | def get_params_W(self): 89 | return tf.exp(self.logvar_W) 90 | 91 | def kldiv(self): 92 | M, logdets = self.sample_z() 93 | logdets = logdets[0] 94 | M = tf.squeeze(M) 95 | 96 | std_mg = self.get_params_W() 97 | 
qm0 = self.get_params_m() 98 | if len(M.get_shape()) == 0: 99 | Mexp = M 100 | else: 101 | Mexp = tf.expand_dims(M, 1) 102 | 103 | Mtilde = Mexp * self.mu_W 104 | Vtilde = tf.square(std_mg) 105 | 106 | iUp = outer(tf.exp(self.pvar), ones_d((self.output_dim,))) 107 | 108 | logqm = 0. 109 | if self.use_z: 110 | logqm = - tf.reduce_sum(.5 * (tf.log(2 * np.pi) + tf.log(qm0) + 1)) 111 | logqm -= logdets 112 | 113 | kldiv_w = tf.reduce_sum(.5 * tf.log(iUp) - tf.log(std_mg) + ((Vtilde + tf.square(Mtilde)) / (2 * iUp)) - .5) 114 | kldiv_bias = tf.reduce_sum(.5 * self.pvar_bias - .5 * self.logvar_bias + ((tf.exp(self.logvar_bias) + 115 | tf.square(self.mu_bias)) / (2 * tf.exp(self.pvar_bias))) - .5) 116 | 117 | if self.use_z: 118 | apvar_M = self.apvar_M 119 | # shared network for hidden layer 120 | mw = tf.matmul(tf.expand_dims(apvar_M, 0), Mtilde) 121 | eps = tf.expand_dims(tf.random_normal((self.output_dim,)), 0) 122 | varw = tf.matmul(tf.square(tf.expand_dims(apvar_M, 0)), Vtilde) 123 | a = tf.nn.tanh(mw + tf.sqrt(varw) * eps) 124 | # split at output layer 125 | if len(tf.squeeze(a).get_shape()) != 0: 126 | w__ = tf.reduce_mean(outer(self.rsr_M, tf.squeeze(a)), axis=1) 127 | wv__ = tf.reduce_mean(outer(self.rsri_M, tf.squeeze(a)), axis=1) 128 | else: 129 | w__ = self.rsr_M * tf.squeeze(a) 130 | wv__ = self.rsri_M * tf.squeeze(a) 131 | 132 | logrm = 0. 133 | if self.flow_r is not None: 134 | M, logrm = self.flow_r.get_output_for(tf.expand_dims(M, 0)) 135 | M = tf.squeeze(M) 136 | logrm = logrm[0] 137 | 138 | logrm += tf.reduce_sum(-.5 * tf.exp(wv__) * tf.square(M - w__) - .5 * tf.log(2 * np.pi) + .5 * wv__) 139 | else: 140 | logrm = 0. 141 | 142 | return - kldiv_w + logrm - logqm - kldiv_bias 143 | 144 | def call(self, x, sample=True, **kwargs): 145 | std_mg = tf.clip_by_value(self.get_params_W(), 0., self.thres_var) 146 | var_mg = tf.square(std_mg) 147 | sample_M, _ = self.sample_z(size_M=tf.shape(x)[0], sample=sample) 148 | xt = x * sample_M 149 | 150 | mu_out = tf.matmul(xt, self.mu_W) + self.mu_bias 151 | varin = tf.matmul(tf.square(x), var_mg) + tf.clip_by_value(tf.exp(self.logvar_bias), 0., self.thres_var**2) 152 | xin = tf.sqrt(varin) 153 | sigma_out = xin * tf.random_normal(tf.shape(mu_out)) 154 | 155 | output = mu_out + sigma_out if sample else mu_out 156 | return output 157 | 158 | -------------------------------------------------------------------------------- /layers.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from utils import get_layer_uid 3 | 4 | 5 | class Layer(object): 6 | def __init__(self, nonlin=tf.identity, N=1, name=None, logging=False): 7 | self.N = N 8 | if name is None: 9 | layer = self.__class__.__name__.lower() 10 | name = layer + '_' + str(get_layer_uid(layer)) 11 | self.name = name 12 | self.logging = logging 13 | self.nonlinearity = nonlin 14 | self.build() 15 | print 'Logging: {}'.format(self.logging) 16 | 17 | def __call__(self, x, sample=True, **kwargs): 18 | with tf.name_scope(self.name): 19 | if self.logging: 20 | tf.summary.histogram(self.name + '/inputs', x) 21 | output = self.call(x, sample=sample, **kwargs) 22 | if self.logging: 23 | tf.summary.histogram(self.name + '/outputs', output) 24 | outputs = self.nonlinearity(output) 25 | return outputs 26 | 27 | def call(self, x, sample=True, **kwargs): 28 | raise NotImplementedError() 29 | 30 | def build(self): 31 | raise NotImplementedError() 32 | 33 | def f(self, x, sampling=True, **kwargs): 34 | raise NotImplementedError() 35 | 36 | def 
get_reg(self): 37 | return - (1. / self.N) * self.kldiv() 38 | 39 | def kldiv(self): 40 | raise NotImplementedError 41 | 42 | from dense_layers import * 43 | from conv_layers import * 44 | -------------------------------------------------------------------------------- /mnf_lenet_mnist.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from progressbar import ETA, Bar, Percentage, ProgressBar 4 | from keras.utils.np_utils import to_categorical 5 | from mnist import MNIST 6 | import time, os 7 | from wrappers import MNFLeNet 8 | 9 | 10 | def train(): 11 | mnist = MNIST() 12 | (xtrain, ytrain), (xvalid, yvalid), (xtest, ytest) = mnist.images() 13 | xtrain, xvalid, xtest = np.transpose(xtrain, [0, 2, 3, 1]), np.transpose(xvalid, [0, 2, 3, 1]), np.transpose(xtest, [0, 2, 3, 1]) 14 | ytrain, yvalid, ytest = to_categorical(ytrain, 10), to_categorical(yvalid, 10), to_categorical(ytest, 10) 15 | 16 | N, height, width, n_channels = xtrain.shape 17 | iter_per_epoch = N / 100 18 | 19 | sess = tf.InteractiveSession() 20 | 21 | input_shape = [None, height, width, n_channels] 22 | x = tf.placeholder(tf.float32, input_shape, name='x') 23 | y_ = tf.placeholder(tf.float32, [None, 10], name='y_') 24 | 25 | model = MNFLeNet(N, input_shape=input_shape, flows_q=FLAGS.fq, flows_r=FLAGS.fr, use_z=not FLAGS.no_z, 26 | learn_p=FLAGS.learn_p, thres_var=FLAGS.thres_var, flow_dim_h=FLAGS.flow_h) 27 | 28 | tf.set_random_seed(FLAGS.seed) 29 | np.random.seed(FLAGS.seed) 30 | y = model.predict(x) 31 | yd = model.predict(x, sample=False) 32 | pyx = tf.nn.softmax(y) 33 | 34 | with tf.name_scope('KL_prior'): 35 | regs = model.get_reg() 36 | tf.summary.scalar('KL prior', regs) 37 | 38 | with tf.name_scope('cross_entropy'): 39 | cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=y, labels=y_)) 40 | tf.summary.scalar('Loglike', cross_entropy) 41 | 42 | global_step = tf.Variable(0, trainable=False) 43 | if FLAGS.anneal: 44 | number_zero, original_zero = FLAGS.epzero, FLAGS.epochs / 2 45 | with tf.name_scope('annealing_beta'): 46 | max_zero_step = number_zero * iter_per_epoch 47 | original_anneal = original_zero * iter_per_epoch 48 | beta_t_val = tf.cast((tf.cast(global_step, tf.float32) - max_zero_step) / original_anneal, tf.float32) 49 | beta_t = tf.maximum(beta_t_val, 0.) 50 | annealing = tf.minimum(1., tf.cond(global_step < max_zero_step, lambda: tf.zeros((1,))[0], lambda: beta_t)) 51 | tf.summary.scalar('annealing beta', annealing) 52 | else: 53 | annealing = 1. 
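    # Net effect: the objective below is cross_entropy + annealing * KL/N, where `annealing` stays at 0 for the
    # first `epzero` epochs, then ramps up linearly and reaches 1 after roughly epzero + epochs/2 epochs;
    # without -anneal it is simply fixed to 1.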
54 | 55 | with tf.name_scope('lower_bound'): 56 | lowerbound = cross_entropy + annealing * regs 57 | tf.summary.scalar('Lower bound', lowerbound) 58 | 59 | train_step = tf.train.AdamOptimizer(learning_rate=FLAGS.lr).minimize(lowerbound, global_step=global_step) 60 | 61 | with tf.name_scope('accuracy'): 62 | correct_prediction = tf.equal(tf.argmax(yd, 1), tf.argmax(y_, 1)) 63 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 64 | tf.summary.scalar('Accuracy', accuracy) 65 | 66 | merged = tf.summary.merge_all() 67 | train_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/train', sess.graph) 68 | 69 | tf.add_to_collection('logits', y) 70 | tf.add_to_collection('logits_map', yd) 71 | tf.add_to_collection('accuracy', accuracy) 72 | tf.add_to_collection('x', x) 73 | tf.add_to_collection('y', y_) 74 | saver = tf.train.Saver(tf.global_variables()) 75 | 76 | tf.global_variables_initializer().run() 77 | 78 | idx = np.arange(N) 79 | steps = 0 80 | model_dir = './models/mnf_lenet_mnist_fq{}_fr{}_usez{}_thres{}/model/'.format(FLAGS.fq, FLAGS.fr, not FLAGS.no_z, 81 | FLAGS.thres_var) 82 | if not os.path.exists(model_dir): 83 | os.makedirs(model_dir) 84 | print 'Will save model as: {}'.format(model_dir + 'model') 85 | # Train 86 | for epoch in xrange(FLAGS.epochs): 87 | widgets = ["epoch {}/{}|".format(epoch + 1, FLAGS.epochs), Percentage(), Bar(), ETA()] 88 | pbar = ProgressBar(iter_per_epoch, widgets=widgets) 89 | pbar.start() 90 | np.random.shuffle(idx) 91 | t0 = time.time() 92 | for j in xrange(iter_per_epoch): 93 | steps += 1 94 | pbar.update(j) 95 | batch = np.random.choice(idx, 100) 96 | if j == (iter_per_epoch - 1): 97 | summary, _ = sess.run([merged, train_step], feed_dict={x: xtrain[batch], y_: ytrain[batch]}) 98 | train_writer.add_summary(summary, steps) 99 | train_writer.flush() 100 | else: 101 | sess.run(train_step, feed_dict={x: xtrain[batch], y_: ytrain[batch]}) 102 | 103 | # the accuracy here is calculated by a crude MAP so as to have fast evaluation 104 | # it is much better if we properly integrate over the parameters by averaging across multiple samples 105 | tacc = sess.run(accuracy, feed_dict={x: xvalid, y_: yvalid}) 106 | string = 'Epoch {}/{}, valid_acc: {:0.3f}'.format(epoch + 1, FLAGS.epochs, tacc) 107 | 108 | if (epoch + 1) % 10 == 0: 109 | string += ', model_save: True' 110 | saver.save(sess, model_dir + 'model') 111 | 112 | string += ', dt: {:0.3f}'.format(time.time() - t0) 113 | print string 114 | 115 | saver.save(sess, model_dir + 'model') 116 | train_writer.close() 117 | 118 | preds = np.zeros_like(ytest) 119 | widgets = ["Sampling |", Percentage(), Bar(), ETA()] 120 | pbar = ProgressBar(FLAGS.L, widgets=widgets) 121 | pbar.start() 122 | for i in xrange(FLAGS.L): 123 | pbar.update(i) 124 | for j in xrange(xtest.shape[0] / 100): 125 | pyxi = sess.run(pyx, feed_dict={x: xtest[j * 100:(j + 1) * 100]}) 126 | preds[j * 100:(j + 1) * 100] += pyxi / FLAGS.L 127 | print 128 | sample_accuracy = np.mean(np.equal(np.argmax(preds, 1), np.argmax(ytest, 1))) 129 | print 'Sample test accuracy: {}'.format(sample_accuracy) 130 | 131 | 132 | def main(): 133 | if tf.gfile.Exists(FLAGS.summaries_dir): 134 | tf.gfile.DeleteRecursively(FLAGS.summaries_dir) 135 | tf.gfile.MakeDirs(FLAGS.summaries_dir) 136 | train() 137 | 138 | if __name__ == '__main__': 139 | import argparse 140 | parser = argparse.ArgumentParser() 141 | parser.add_argument('--summaries_dir', type=str, default='logs/mnf_lenet', 142 | help='Summaries directory') 143 | parser.add_argument('-epochs', 
type=int, default=100) 144 | parser.add_argument('-epzero', type=int, default=1) 145 | parser.add_argument('-fq', default=2, type=int) 146 | parser.add_argument('-fr', default=2, type=int) 147 | parser.add_argument('-no_z', action='store_true') 148 | parser.add_argument('-seed', type=int, default=1) 149 | parser.add_argument('-lr', type=float, default=0.001) 150 | parser.add_argument('-thres_var', type=float, default=0.5) 151 | parser.add_argument('-flow_h', type=int, default=50) 152 | parser.add_argument('-L', type=int, default=100) 153 | parser.add_argument('-anneal', action='store_true') 154 | parser.add_argument('-learn_p', action='store_true') 155 | FLAGS = parser.parse_args() 156 | main() 157 | -------------------------------------------------------------------------------- /mnist.pkl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AMLab-Amsterdam/MNF_VBNN/900cdf1e28a1f172a4abcfc6a0d3518bcb4f0f05/mnist.pkl.gz -------------------------------------------------------------------------------- /mnist.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import cPickle as pkl 3 | import numpy as np 4 | 5 | 6 | class MNIST(object): 7 | def __init__(self): 8 | self.nb_classes = 10 9 | self.name = self.__class__.__name__.lower() 10 | 11 | def load_data(self): 12 | with gzip.open('mnist.pkl.gz', 'rb') as f: 13 | try: 14 | train_set, valid_set, test_set = pkl.load(f, encoding='latin1') 15 | except TypeError: 16 | train_set, valid_set, test_set = pkl.load(f) 17 | return [train_set[0], train_set[1]], [valid_set[0], valid_set[1]], [test_set[0], test_set[1]] 18 | 19 | def permutation_invariant(self, n=None): 20 | train, valid, test = self.load_data() 21 | return train, valid, test 22 | 23 | def images(self, n=None): 24 | train, valid, test = self.load_data() 25 | train[0] = np.reshape(train[0], (train[0].shape[0], 1, 28, 28)) 26 | valid[0] = np.reshape(valid[0], (valid[0].shape[0], 1, 28, 28)) 27 | test[0] = np.reshape(test[0], (test[0].shape[0], 1, 28, 28)) 28 | return train, valid, test 29 | -------------------------------------------------------------------------------- /norm_flows.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from utils import randmat, get_layer_uid, zeros_d, random_bernoulli 3 | 4 | 5 | class MaskedNVPFlow(object): 6 | """RealNVP-style normalizing flow with random binary masks, used to model q(z) and the auxiliary r(z|W) of the MNF layers. 7 | """ 8 | def __init__(self, incoming, n_flows=2, n_hidden=0, dim_h=10, name=None, scope=None, nonlin=tf.nn.tanh, **kwargs): 9 | self.incoming = incoming 10 | self.n_flows = n_flows 11 | self.n_hidden = n_hidden 12 | if name is None: 13 | layer = self.__class__.__name__.lower() 14 | self.name = layer + '_' + str(get_layer_uid(layer)) 15 | else: 16 | self.name = name 17 | self.dim_h = dim_h 18 | self.params = [] 19 | self.nonlin = nonlin 20 | self.scope = scope 21 | self.build() 22 | print 'MaskedNVP flow {} with length: {}, n_hidden: {}, dim_h: {}, name: {}, ' \ 23 | 'scope: {}'.format(self.name, n_flows, n_hidden, dim_h, name, scope) 24 | 25 | def build_mnn(self, fid, param_list): 26 | dimin = self.incoming 27 | with tf.variable_scope(self.scope): 28 | w = randmat((dimin, self.dim_h), name='w{}_{}_{}'.format(0, self.name, fid)) 29 | b = tf.Variable(tf.zeros((self.dim_h,)), name='b{}_{}_{}'.format(0, self.name, fid)) 30 | param_list.append([(w, b)]) 31 | for l in xrange(self.n_hidden): 32 | wh = randmat((self.dim_h, self.dim_h), name='w{}_{}_{}'.format(l + 1, self.name, fid))
33 | bh = tf.Variable(tf.zeros((self.dim_h,)), name='b{}_{}_{}'.format(l + 1, self.name, fid)) 34 | param_list[-1].append((wh, bh)) 35 | wout = randmat((self.dim_h, dimin), name='w{}_{}_{}'.format(self.n_hidden, self.name, fid)) 36 | bout = tf.Variable(tf.zeros((dimin,)), name='b{}_{}_{}'.format(self.n_hidden, self.name, fid)) 37 | wout2 = randmat((self.dim_h, dimin), name='w{}_{}_{}_sigma'.format(self.n_hidden, self.name, fid)) 38 | bout2 = tf.Variable(tf.ones((dimin,)) * 2, name='b{}_{}_{}_sigma'.format(self.n_hidden, self.name, fid)) 39 | param_list[-1].append((wout, bout, wout2, bout2)) 40 | 41 | def build(self): 42 | for flow in xrange(self.n_flows): 43 | self.build_mnn('muf_{}'.format(flow), self.params) 44 | 45 | def ff(self, x, weights): 46 | inputs = [x] 47 | for j in xrange(len(weights[:-1])): 48 | h = tf.matmul(inputs[-1], weights[j][0]) + weights[j][1] 49 | inputs.append(self.nonlin(h)) 50 | wmu, bmu, wsigma, bsigma = weights[-1] 51 | mean = tf.matmul(inputs[-1], wmu) + bmu 52 | sigma = tf.matmul(inputs[-1], wsigma) + bsigma 53 | return mean, sigma 54 | 55 | def get_output_for(self, z, sample=True): 56 | logdets = zeros_d((tf.shape(z)[0],)) 57 | for flow in xrange(self.n_flows): 58 | mask = random_bernoulli(tf.shape(z), p=0.5) if sample else 0.5 59 | ggmu, ggsigma = self.ff(mask * z, self.params[flow]) 60 | gate = tf.nn.sigmoid(ggsigma) 61 | logdets += tf.reduce_sum((1 - mask) * tf.log(gate), axis=1) 62 | z = (1 - mask) * (z * gate + (1 - gate) * ggmu) + mask * z 63 | 64 | return z, logdets 65 | 66 | 67 | class PlanarFlow(object): 68 | """ 69 | """ 70 | def __init__(self, incoming, n_flows=2, name=None, scope=None, **kwargs): 71 | self.incoming = incoming 72 | self.n_flows = n_flows 73 | self.sigma = 0.01 74 | self.params = [] 75 | self.name = name 76 | self.scope = scope 77 | self.build() 78 | print 'Planar flow layer with nf: {}, name: {}, scope: {}'.format(n_flows, name, scope) 79 | 80 | def build(self): 81 | with tf.variable_scope(self.scope): 82 | for flow in xrange(self.n_flows): 83 | w = randmat((self.incoming, 1), name='w_{}_{}'.format(flow, self.name)) 84 | u = randmat((self.incoming, 1), name='u_{}_{}'.format(flow, self.name)) 85 | b = tf.Variable(tf.zeros((1,)), name='b_{}_{}'.format(flow, self.name)) 86 | self.params.append([w, u, b]) 87 | 88 | def get_output_for(self, z, **kwargs): 89 | logdets = zeros_d((tf.shape(z)[0],)) 90 | for flow in xrange(self.n_flows): 91 | w, u, b = self.params[flow] 92 | uw = tf.reduce_sum(u * w) 93 | muw = -1 + tf.nn.softplus(uw) # = -1 + T.log(1 + T.exp(uw)) 94 | u_hat = u + (muw - uw) * w / tf.reduce_sum(w ** 2) 95 | if len(z.get_shape()) == 1: 96 | zwb = z * w + b 97 | else: 98 | zwb = tf.matmul(z, w) + b 99 | psi = tf.matmul(1 - tf.nn.tanh(zwb) ** 2, tf.transpose(w)) # tanh(x)dx = 1 - tanh(x)**2 100 | psi_u = tf.matmul(psi, u_hat) 101 | logdets += tf.squeeze(tf.log(tf.abs(1 + psi_u))) 102 | zadd = tf.matmul(tf.nn.tanh(zwb), tf.transpose(u_hat)) 103 | z += zadd 104 | return z, logdets -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import os, sys 4 | 5 | _LAYER_UIDS = {} 6 | 7 | tf.set_random_seed(1) 8 | prng = np.random.RandomState(1) 9 | 10 | flags = tf.app.flags 11 | FLAGS = flags.FLAGS 12 | sigma_init = 0.01 13 | 14 | DATA_DIR = os.environ['DATA_DIR'] 15 | dtype = tf.float32 16 | 17 | 18 | def get_layer_uid(layer_name=''): 19 | """Helper function, 
assigns unique layer IDs 20 | """ 21 | if layer_name not in _LAYER_UIDS: 22 | _LAYER_UIDS[layer_name] = 1 23 | return 1 24 | else: 25 | _LAYER_UIDS[layer_name] += 1 26 | return _LAYER_UIDS[layer_name] 27 | 28 | 29 | def change_random_seed(seed): 30 | global prng 31 | prng = np.random.RandomState(seed) 32 | tf.set_random_seed(seed) 33 | 34 | 35 | def randmat(shape, name, mu=0., type_init='he2', type_dist='normal', trainable=True, extra_scale=1.): 36 | if len(shape) == 1: 37 | dim_in, dim_out = shape[0], 0 38 | elif len(shape) == 2: 39 | dim_in, dim_out = shape 40 | else: 41 | dim_in, dim_out = np.prod(shape[1:]), shape[0] 42 | if type_init == 'xavier': 43 | bound = np.sqrt(1. / dim_in) 44 | elif type_init == 'xavier2': 45 | bound = np.sqrt(2. / (dim_in + dim_out)) 46 | elif type_init == 'he': 47 | bound = np.sqrt(2. / dim_in) 48 | elif type_init == 'he2': 49 | bound = np.sqrt(4. / (dim_in + dim_out)) 50 | elif type_init == 'regular': 51 | bound = sigma_init 52 | else: 53 | raise Exception() 54 | if type_dist == 'normal': 55 | val = tf.random_normal(shape, mean=mu, stddev=extra_scale * bound, dtype=dtype) # actual weight initialization 56 | else: 57 | val = tf.random_uniform(shape, minval=mu - extra_scale * bound, maxval=mu + extra_scale * bound, dtype=dtype) 58 | 59 | return tf.Variable(initial_value=val, name=name, trainable=trainable) 60 | 61 | 62 | def ones_d(shape): 63 | if isinstance(shape, (list, tuple)): 64 | shape = tf.stack(shape) 65 | return tf.ones(shape) 66 | 67 | 68 | def zeros_d(shape): 69 | if isinstance(shape, (list, tuple)): 70 | shape = tf.stack(shape) 71 | return tf.zeros(shape) 72 | 73 | 74 | def random_bernoulli(shape, p=0.5): 75 | if isinstance(shape, (list, tuple)): 76 | shape = tf.stack(shape) 77 | return tf.where(tf.random_uniform(shape) < p, tf.ones(shape), tf.zeros(shape)) 78 | 79 | 80 | def outer(x, y): 81 | return tf.matmul(tf.expand_dims(x, 1), tf.transpose(tf.expand_dims(y, 1))) 82 | 83 | -------------------------------------------------------------------------------- /wrappers.py: -------------------------------------------------------------------------------- 1 | from layers import * 2 | from tensorflow.contrib import slim 3 | 4 | 5 | class MNFLeNet(object): 6 | def __init__(self, N, input_shape, flows_q=2, flows_r=2, use_z=True, activation=tf.nn.relu, logging=False, 7 | nb_classes=10, learn_p=False, layer_dims=(20, 50, 500), flow_dim_h=50, thres_var=1, prior_var_w=1., 8 | prior_var_b=1.): 9 | self.layer_dims = layer_dims 10 | self.activation = activation 11 | self.N = N 12 | self.input_shape = input_shape 13 | self.flows_q = flows_q 14 | self.flows_r = flows_r 15 | self.use_z = use_z 16 | 17 | self.logging = logging 18 | self.nb_classes = nb_classes 19 | self.flow_dim_h = flow_dim_h 20 | self.thres_var = thres_var 21 | self.learn_p = learn_p 22 | self.prior_var_w = prior_var_w 23 | self.prior_var_b = prior_var_b 24 | 25 | self.opts = 'fq{}_fr{}_usez{}'.format(self.flows_q, self.flows_r, self.use_z) 26 | self.built = False 27 | 28 | def build_mnf_lenet(self, x, sample=True): 29 | if not self.built: 30 | self.layers = [] 31 | with tf.variable_scope(self.opts): 32 | if not self.built: 33 | layer1 = Conv2DMNF(self.layer_dims[0], 5, 5, N=self.N, input_shape=self.input_shape, border_mode='VALID', 34 | flows_q=self.flows_q, flows_r=self.flows_r, logging=self.logging, use_z=self.use_z, 35 | learn_p=self.learn_p, prior_var=self.prior_var_w, prior_var_b=self.prior_var_b, 36 | thres_var=self.thres_var, flow_dim_h=self.flow_dim_h) 37 | self.layers.append(layer1) 38 | 
else: 39 | layer1 = self.layers[0] 40 | h1 = self.activation(tf.nn.max_pool(layer1(x, sample=sample), [1, 2, 2, 1], [1, 2, 2, 1], 'SAME')) 41 | 42 | if not self.built: 43 | shape = [None] + [s.value for s in h1.get_shape()[1:]] 44 | layer2 = Conv2DMNF(self.layer_dims[1], 5, 5, N=self.N, input_shape=shape, border_mode='VALID', 45 | flows_q=self.flows_q, flows_r=self.flows_r, use_z=self.use_z, logging=self.logging, 46 | learn_p=self.learn_p, flow_dim_h=self.flow_dim_h, thres_var=self.thres_var, 47 | prior_var=self.prior_var_w, prior_var_b=self.prior_var_b) 48 | self.layers.append(layer2) 49 | else: 50 | layer2 = self.layers[1] 51 | h2 = slim.flatten(self.activation(tf.nn.max_pool(layer2(h1, sample=sample), [1, 2, 2, 1], [1, 2, 2, 1], 'SAME'))) 52 | 53 | if not self.built: 54 | fcinp_dim = h2.get_shape()[1].value 55 | layer3 = DenseMNF(self.layer_dims[2], N=self.N, input_dim=fcinp_dim, flows_q=self.flows_q, 56 | flows_r=self.flows_r, use_z=self.use_z, logging=self.logging, learn_p=self.learn_p, 57 | prior_var=self.prior_var_w, prior_var_b=self.prior_var_b, flow_dim_h=self.flow_dim_h, 58 | thres_var=self.thres_var) 59 | self.layers.append(layer3) 60 | else: 61 | layer3 = self.layers[2] 62 | h3 = self.activation(layer3(h2, sample=sample)) 63 | 64 | if not self.built: 65 | fcinp_dim = h3.get_shape()[1].value 66 | layerout = DenseMNF(self.nb_classes, N=self.N, input_dim=fcinp_dim, flows_q=self.flows_q, 67 | flows_r=self.flows_r, use_z=self.use_z, logging=self.logging, learn_p=self.learn_p, 68 | prior_var=self.prior_var_w, prior_var_b=self.prior_var_b, flow_dim_h=self.flow_dim_h, 69 | thres_var=self.thres_var) 70 | self.layers.append(layerout) 71 | else: 72 | layerout = self.layers[3] 73 | 74 | if not self.built: 75 | self.built = True 76 | return layerout(h3, sample=sample) 77 | 78 | def predict(self, x, sample=True): 79 | return self.build_mnf_lenet(x, sample=sample) 80 | 81 | def get_reg(self): 82 | reg = 0. 83 | for j, layer in enumerate(self.layers): 84 | with tf.name_scope('kl_layer{}'.format(j + 1)): 85 | regi = layer.get_reg() 86 | tf.summary.scalar('kl_layer{}'.format(j + 1), regi) 87 | reg += regi 88 | 89 | return reg 90 | 91 | 92 | --------------------------------------------------------------------------------
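For orientation, the layers, flows and wrapper above are tied together by mnf_lenet_mnist.py. The snippet below is a minimal sketch of that wiring, not part of the repository: it assumes the same Python 2 / TensorFlow 1.x environment as the rest of the code, that the DATA_DIR environment variable read by utils.py is set, and it feeds random arrays in place of MNIST purely to illustrate the API.

import numpy as np
import tensorflow as tf
from wrappers import MNFLeNet

# Stand-in data: 100 random NHWC "images" and one-hot labels (illustrative only).
N = 100
xtrain = np.random.rand(N, 28, 28, 1).astype(np.float32)
ytrain = np.eye(10)[np.random.randint(0, 10, N)].astype(np.float32)

x = tf.placeholder(tf.float32, [None, 28, 28, 1], name='x')
y_ = tf.placeholder(tf.float32, [None, 10], name='y_')

# N sets the 1/N scaling of each layer's KL term inside get_reg().
model = MNFLeNet(N, input_shape=[None, 28, 28, 1], flows_q=2, flows_r=2, use_z=True)
logits = model.predict(x)                    # stochastic forward pass (samples weights)
logits_map = model.predict(x, sample=False)  # deterministic pass, handy for quick evaluation

cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y_))
kl = model.get_reg()              # roughly KL[q(W) || p(W)] / N, summed over layers
lower_bound = cross_entropy + kl  # the quantity minimized in mnf_lenet_mnist.py (optionally KL-annealed there)
train_step = tf.train.AdamOptimizer(learning_rate=1e-3).minimize(lower_bound)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in xrange(10):
        sess.run(train_step, feed_dict={x: xtrain, y_: ytrain})

For real training, mnf_lenet_mnist.py additionally anneals the KL term, writes TensorBoard summaries, checkpoints the model, and at test time averages the softmax over FLAGS.L sampled weight configurations.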