├── README.md
├── bayes_by_backprop.py
└── net.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# Bayes-by-Backprop

A simple Python implementation of Bayes by Backprop.
The original paper is "Weight Uncertainty in Neural Networks" by Charles Blundell, Julien Cornebise, Koray Kavukcuoglu, and Daan Wierstra: https://arxiv.org/abs/1505.05424

## Requirements

- Python (2.7.6)
- Chainer (1.16.0)
- six (1.10.0)
- numpy (1.11.1)
- scikit-learn (0.17.1)

Confirmed to work on Ubuntu 14.04, at least.

## Usage

```
python bayes_by_backprop.py
```

Still incomplete (Nov 2016).
If you find bugs, I would appreciate it if you could report them.

--------------------------------------------------------------------------------
/bayes_by_backprop.py:
--------------------------------------------------------------------------------

#!/usr/bin/env python
# coding: utf-8

import chainer.functions as F
from chainer import Variable, optimizers
import numpy as np
from sklearn.datasets import fetch_mldata
from sklearn.cross_validation import train_test_split
from sklearn import preprocessing
import net
import six
import os
import time

def mnist_preprocessing(sample_N = 60000, test_ratio = 0.25):
    # Fetch MNIST once and cache the arrays; MNIST has 70,000 examples in total.
    if os.path.exists("mnist_preprocessed_data.npy"):
        x = np.load("mnist_preprocessed_data.npy")
        y = np.load("mnist_preprocessed_target.npy")
        #y = np.int32(preprocessing.OneHotEncoder(sparse=False).fit_transform(y.reshape(y.shape[0],1)))
        idx = np.random.choice(x.shape[0], sample_N)
        x = x[idx]
        y = y[idx]
    else:
        mnist = fetch_mldata('MNIST original')
        x = np.float32(mnist.data[:]) / 255.  # scale the 0-255 pixel values into [0, 1]
        np.save("mnist_preprocessed_data", x)
        y = np.int32(mnist.target)
        #y = np.int32(preprocessing.OneHotEncoder(sparse=False).fit_transform(y.reshape(y.shape[0],1)))
        np.save("mnist_preprocessed_target", y)
        idx = np.random.choice(x.shape[0], sample_N)
        x = x[idx]
        y = y[idx]

    tr_idx, te_idx = train_test_split(np.arange(sample_N), test_size = test_ratio)
    tr_x, te_x = x[tr_idx], x[te_idx]
    tr_y, te_y = y[tr_idx], y[te_idx]

    return tr_x, te_x, tr_y, te_y

def get_gaussianloglikelihood_pw(x, mu, sigma):
    # Gaussian log-density with a fixed scalar sigma (used for the prior terms).
    return -0.5 * np.log(2*np.pi) - np.log(sigma) - (x - mu)**2 / (2 * sigma**2)

def get_gaussianloglikelihood_qw(x, mu, sigma):
    # Gaussian log-density where sigma is a chainer Variable (used for q(w|theta)).
    return -0.5 * np.log(2*np.pi) - F.log(sigma) - (x - mu)**2 / (2 * sigma**2)

"""
sample_N = 6000
test_ratio = 0.25
tr_x, te_x, tr_y, te_y = mnist_preprocessing(sample_N, test_ratio)
"""
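# NOTE (editorial sketch): the paper's scale-mixture prior is
#     log P(w) = log( pi * N(w|0, sigma_1^2) + (1 - pi) * N(w|0, sigma_2^2) ),
# i.e. the log of a mixture of densities. KL_minibatch below instead takes a
# pi-weighted sum of the two log-densities, which is only an approximation.
# A minimal sketch of the exact version, using chainer's F.exp/F.log so the
# gradient still flows back to w (the helper name is mine, not the repo's):
def get_scalemixture_loglikelihood_pw(w, ratio, sigma_1, sigma_2):
    # the densities can underflow in float32 for very unlikely w; fine for a sketch
    comp_1 = ratio * F.exp(get_gaussianloglikelihood_pw(w, 0., sigma_1))
    comp_2 = (1 - ratio) * F.exp(get_gaussianloglikelihood_pw(w, 0., sigma_2))
    return F.log(comp_1 + comp_2)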
26 | np.save("mnist_preprocessed_data",x) 27 | y = np.int32(mnist.target) 28 | #y = np.int32(preprocessing.OneHotEncoder(sparse=False).fit_transform(y.reshape(y.shape[0],1))) 29 | np.save("mnist_preprocessed_target",y) 30 | idx = np.random.choice(x.shape[0], sample_N) 31 | x = x[idx] 32 | y = y[idx] 33 | 34 | tr_idx, te_idx = train_test_split(np.arange(sample_N), test_size = test_ratio) 35 | tr_x, te_x = x[tr_idx], x[te_idx] 36 | tr_y, te_y = y[tr_idx], y[te_idx] 37 | 38 | return tr_x,te_x,tr_y,te_y 39 | 40 | def get_gaussianloglikelihood_pw(x,mu,sigma): 41 | return -0.5 * np.log(2*np.pi) - np.log(sigma) - (x - mu)**2 / (2 * sigma**2) 42 | 43 | def get_gaussianloglikelihood_qw(x,mu,sigma): 44 | return -0.5 * np.log(2*np.pi) - F.log(sigma) - (x - mu)**2 / (2 * sigma**2) 45 | 46 | """ 47 | sample_N = 6000 48 | test_ratio = 0.25 49 | tr_x, te_x, tr_y, te_y = mnist_preprocessing(sample_N, test_ratio) 50 | """ 51 | 52 | class BBP_agent(object): 53 | """docstring for BPP_agent""" 54 | def __init__(self, model_num = 3, sample_N = 60000, test_ratio = .25, batch_size = 32, max_iter = 100): 55 | super(BBP_agent, self).__init__() 56 | self.sample_N = sample_N 57 | self.test_ratio = test_ratio 58 | self.batch_size = batch_size 59 | self.max_iter = max_iter 60 | self.model_num = model_num 61 | 62 | def prepare_data(self): 63 | self.tr_x, self.te_x, self.tr_y, self.te_y = mnist_preprocessing(self.sample_N, self.test_ratio) 64 | self.tr_N = self.tr_x.shape[0] 65 | self.M = self.tr_N / float(self.batch_size) 66 | 67 | def set_model_parameter(self): 68 | self.prior_ratio = np.float32(0.5) 69 | self.prior_sigma_1 = np.float32(np.exp(-1)) 70 | self.prior_sigma_2 = np.float32(np.exp(-7)) 71 | self.n_in = self.tr_x.shape[1] 72 | self.n_hidden1 = 200 73 | self.n_hidden2 = 200 74 | self.n_out = 10 75 | self.prior_pho_var = np.float32(.05) 76 | self.model = net.MLP_MNIST_bbp(self.n_in, self.n_hidden1, self.n_hidden2, self.n_out, self.prior_ratio, 77 | self.prior_sigma_1, self.prior_sigma_2, self.prior_pho_var) 78 | 79 | 80 | def pD_w(self,Data_indices): 81 | in_data = Variable(self.tr_x[Data_indices]) 82 | t = Variable(self.tr_y[Data_indices]) 83 | cross_entropy = 0. 84 | for i in range(self.model_num): 85 | w1,w2,w3 = self.models[i] 86 | w1 = F.reshape(w1,(self.n_in,self.n_hidden1)) 87 | w2 = F.reshape(w2,(self.n_hidden1,self.n_hidden2)) 88 | w3 = F.reshape(w3,(self.n_hidden2,self.n_out)) 89 | h1 = F.relu(F.matmul(in_data,w1)) 90 | h2 = F.relu(F.matmul(h1,w2)) 91 | pred = F.softmax(F.matmul(h2,w3)) 92 | cross_entropy += F.softmax_cross_entropy(pred,t) 93 | return -1 * cross_entropy 94 | 95 | def KL_minibatch(self): 96 | log_qw_theta_sum = 0. 97 | log_pw_sum = 0. 
    def KL_minibatch(self):
        log_qw_theta_sum = 0.
        log_pw_sum = 0.
        for i in range(self.model_num):
            w1, w2, w3 = self.models[i]
            w = F.hstack([w1, w2, w3])
            log_qw_theta = get_gaussianloglikelihood_qw(w, self.model.mu_hstack(), self.model.sigma_hstack())
            #log_qw_theta_sum += F.sum(log_qw_theta, axis=1)
            log_qw_theta_sum += F.sum(log_qw_theta)
            log_pw = self.prior_ratio * get_gaussianloglikelihood_pw(w, 0, self.prior_sigma_1) + (1 - self.prior_ratio) * get_gaussianloglikelihood_pw(w, 0, self.prior_sigma_2)
            #log_pw_sum += F.sum(log_pw, axis=1)
            log_pw_sum += F.sum(log_pw)
        return (log_qw_theta_sum - log_pw_sum) / self.M

    def fit(self):
        now = time.time()
        for iter_ in range(self.max_iter):
            perm_tr = np.random.permutation(self.tr_N)
            for batch_idx in six.moves.range(0, self.tr_N, self.batch_size):
                print("start_batch:{}".format(batch_idx))
                Data_indices = perm_tr[batch_idx:batch_idx + self.batch_size]
                #self.model.zerograds()
                self.models = []
                for i in range(self.model_num):
                    self.models.append(self.model())
                print("finish_models_append")
                start_f_calc = time.time()
                t1 = self.KL_minibatch()
                t2 = self.pD_w(Data_indices)
                f_batch = t1 - t2
                end_f_calc = time.time()
                print("finish_f_calculation:{}".format(end_f_calc - start_f_calc))
                f_batch.backward()
                print("f_batch_grad:{}".format(f_batch.grad))
                print("finish_f_backward:{}".format(time.time() - end_f_calc))
                print("mu1_grad:{}".format(self.model.mu1.grad.shape))
                self.model.update()
                print(iter_, f_batch.data)
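# NOTE (editorial): BBP_agent2 below follows the paper's optimisation loop per
# minibatch, repeated for each of the model_num posterior samples:
#   1. draw eps ~ N(0, I)
#   2. set w = mu + log(1 + exp(pho)) * eps         (reparameterisation trick)
#   3. evaluate f(w, theta) = (1/M)*(log q(w|theta) - log P(w)) - log P(D_i|w)
#   4. backpropagate f and apply the mu / pho gradient identities implemented
#      in net.MLP_MNIST_bbp.update()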
class BBP_agent2(object):
    """docstring for BBP_agent2"""
    def __init__(self, model_num = 3, sample_N = 60000, test_ratio = .25, batch_size = 32, max_iter = 100, lr = 1e-4):
        super(BBP_agent2, self).__init__()
        self.sample_N = sample_N
        self.test_ratio = test_ratio
        self.batch_size = batch_size
        self.max_iter = max_iter
        self.model_num = model_num
        self.lr = lr

    def prepare_data(self):
        self.tr_x, self.te_x, self.tr_y, self.te_y = mnist_preprocessing(self.sample_N, self.test_ratio)
        self.tr_N = self.tr_x.shape[0]
        self.M = self.tr_N / float(self.batch_size)  # number of minibatches per epoch

    def set_model_parameter(self):
        self.prior_ratio = np.float32(0.5)
        self.prior_sigma_1 = np.float32(np.exp(-1))
        self.prior_sigma_2 = np.float32(np.exp(-7))
        self.n_in = self.tr_x.shape[1]
        self.n_hidden1 = 500
        self.n_hidden2 = 500
        self.n_out = 10
        self.prior_pho_var = np.float32(.05)
        self.model = net.MLP_MNIST_bbp(n_in = self.n_in, n_hidden1 = self.n_hidden1, n_hidden2 = self.n_hidden2,
                                       n_out = self.n_out, lr = self.lr, prior_ratio = self.prior_ratio,
                                       prior_sigma_1 = self.prior_sigma_1, prior_sigma_2 = self.prior_sigma_2, prior_pho_var = self.prior_pho_var)
        # set up but currently unused: parameter updates are applied manually
        # in self.model.update() rather than through the optimizer
        #self.optimizer = optimizers.Adam()
        self.optimizer = optimizers.SGD(lr = self.lr)
        self.optimizer.setup(self.model)

    def pD_w(self, Data_indices):
        # unused by fit(), which calls the model directly; kept for reference
        in_data = Variable(self.tr_x[Data_indices])
        t = Variable(self.tr_y[Data_indices])
        cross_entropy = 0.
        for i in range(self.model_num):
            w1, w2, w3 = self.models[i]
            w1 = F.reshape(w1, (self.n_in, self.n_hidden1))
            w2 = F.reshape(w2, (self.n_hidden1, self.n_hidden2))
            w3 = F.reshape(w3, (self.n_hidden2, self.n_out))
            h1 = F.relu(F.matmul(in_data, w1))
            h2 = F.relu(F.matmul(h1, w2))
            # raw logits; softmax_cross_entropy applies log-softmax internally
            pred = F.matmul(h2, w3)
            cross_entropy += F.softmax_cross_entropy(pred, t)
        return -1 * cross_entropy

    def KL_minibatch(self):
        log_qw_theta_sum = 0.
        log_pw_sum = 0.
        w = self.model.w_hstack()
        log_qw_theta = get_gaussianloglikelihood_qw(w, self.model.mu_hstack(), self.model.sigma_hstack())
        #log_qw_theta_sum += F.sum(log_qw_theta, axis=1)
        log_qw_theta_sum += F.sum(log_qw_theta)
        log_pw = self.prior_ratio * get_gaussianloglikelihood_pw(w, 0, self.prior_sigma_1) + (1 - self.prior_ratio) * get_gaussianloglikelihood_pw(w, 0, self.prior_sigma_2)
        #log_pw_sum += F.sum(log_pw, axis=1)
        log_pw_sum += F.sum(log_pw)
        #return (log_qw_theta_sum - log_pw_sum) / (self.M * self.model_num)
        return (log_qw_theta_sum - log_pw_sum) / self.M

    def fit(self):
        now = time.time()
        for iter_ in range(self.max_iter):
            perm_tr = np.random.permutation(self.tr_N)
            for batch_idx in six.moves.range(0, self.tr_N, self.batch_size):
                Data_indices = perm_tr[batch_idx:batch_idx + self.batch_size]
                f_batch_mean = 0.
                for i in range(self.model_num):
                    self.model.zerograds()
                    in_data = Variable(self.tr_x[Data_indices])
                    t = Variable(self.tr_y[Data_indices])
                    # the model call draws a fresh weight sample and returns the
                    # cross-entropy, i.e. -log P(D_i|w), so it is *added* to the
                    # complexity term to form the minibatch free energy
                    t2 = self.model(in_data, t)
                    t1 = self.KL_minibatch()
                    f_batch = t1 + t2
                    print("KL_t1:{}".format(t1.data))
                    print("Lh_t2:{}".format(t2.data))
                    print("weight:{}".format(self.model.mu1.W.data[0, :5]))
                    f_batch.backward(retain_grad = True)
                    print("weight_grad:{}".format(self.model.mu1.W.grad[0, :5]))
                    self.model.update(self.model_num)
                    #self.optimizer.update()
                    print("weight:{}".format(self.model.mu1.W.data[0, :5]))
                    f_batch_mean += f_batch.data
                print("f_batch_mean:{}".format(f_batch_mean / float(self.model_num)))
                print(iter_, f_batch.data)

agent = BBP_agent2(sample_N = 6000)
agent.prepare_data()
print("finish data preparation!!")
agent.set_model_parameter()
agent.fit()
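# NOTE (editorial sketch): the repo has no test-time evaluation yet. Assuming
# that each call to the model resamples its weights (as MLP_MNIST_bbp.__call__
# does, as a side effect, before returning the loss), the posterior predictive
# can be approximated by averaging softmax outputs over several weight samples.
# The function below is my own illustration, not part of the repo:
def predictive_accuracy(agent, n_samples = 10):
    x = Variable(agent.te_x)
    t = Variable(agent.te_y)
    probs = 0.
    for _ in range(n_samples):
        agent.model(x, t)  # side effect: draws fresh samples into w1..w3
        h1 = F.relu(agent.model.w1(x))
        h2 = F.relu(agent.model.w2(h1))
        probs += F.softmax(agent.model.w3(h2)).data
    pred = np.argmax(probs / n_samples, axis = 1)
    return np.mean(pred == agent.te_y)

#print("test accuracy:{}".format(predictive_accuracy(agent)))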
"""
u1 = prior_ratio * np.random.normal(0,prior_sigma_1**2,n_in * n_hidden1) + (1 - prior_ratio) * np.random.normal(0,prior_sigma_2**2,n_in * n_hidden1)
u1 = u1.reshape((n_in, n_hidden1)).astype(np.float32)
u2 = prior_ratio * np.random.normal(0,prior_sigma_1**2,n_hidden1 * n_hidden2) + (1 - prior_ratio) * np.random.normal(0,prior_sigma_2**2,n_hidden1 * n_hidden2)
u2 = u2.reshape((n_hidden1, n_hidden2)).astype(np.float32)
u3 = prior_ratio * np.random.normal(0,prior_sigma_1**2,n_hidden2 * n_out) + (1 - prior_ratio) * np.random.normal(0,prior_sigma_2**2,n_hidden2 * n_out)
u3 = u3.reshape((n_hidden2, n_out)).astype(np.float32)
"""

--------------------------------------------------------------------------------
/net.py:
--------------------------------------------------------------------------------

#!/usr/bin/env python
# coding: utf-8

from chainer import Chain
import chainer.functions as F
import chainer.links as L
from chainer import Variable
import numpy as np

class MLP_MNIST_bbp(Chain):
    def __init__(self, n_in = 784, n_hidden1 = 1200, n_hidden2 = 1200, n_out = 10, lr = 1e-4, prior_ratio = 0.5, prior_sigma_1 = np.exp(-1), prior_sigma_2 = np.exp(-7), prior_pho_var = .05):
        super(MLP_MNIST_bbp, self).__init__(
            mu1 = L.Linear(n_in, n_hidden1),
            mu2 = L.Linear(n_hidden1, n_hidden2),
            mu3 = L.Linear(n_hidden2, n_out),
            pho1 = L.Linear(n_in, n_hidden1),
            pho2 = L.Linear(n_hidden1, n_hidden2),
            pho3 = L.Linear(n_hidden2, n_out),
            w1 = L.Linear(n_in, n_hidden1),
            w2 = L.Linear(n_hidden1, n_hidden2),
            w3 = L.Linear(n_hidden2, n_out),
        )
        self.n_in = n_in
        self.n_hidden1 = n_hidden1
        self.n_hidden2 = n_hidden2
        self.n_out = n_out
        self.lr = lr

        # initialise the means from a draw mixing the two prior scales, and
        # the pho (rho) parameters from a small Gaussian
        for i, m in enumerate([self.mu1, self.mu2, self.mu3]):
            tmp_w = prior_ratio * np.random.normal(0, prior_sigma_1, m.W.shape[1] * m.W.shape[0]) + (1 - prior_ratio) * np.random.normal(0, prior_sigma_2, m.W.shape[1] * m.W.shape[0])
            m.W = Variable(tmp_w.reshape(m.W.shape).astype(np.float32))
            tmp_b = prior_ratio * np.random.normal(0, prior_sigma_1, m.b.shape[0]) + (1 - prior_ratio) * np.random.normal(0, prior_sigma_2, m.b.shape[0])
            m.b = Variable(tmp_b.reshape(m.b.shape).astype(np.float32))

        for i, m in enumerate([self.pho1, self.pho2, self.pho3]):
            tmp_w = np.random.normal(0, prior_pho_var, m.W.shape[1] * m.W.shape[0])
            m.W = Variable(tmp_w.reshape(m.W.shape).astype(np.float32))
            tmp_b = np.random.normal(0, prior_pho_var, m.b.shape[0])
            m.b = Variable(tmp_b.reshape(m.b.shape).astype(np.float32))
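    # NOTE (editorial): __call__ below draws one weight sample per forward pass
    # via the reparameterisation w = mu + sigma * eps with
    # sigma = log(1 + exp(pho)) (the softplus), which keeps sigma positive
    # while pho stays unconstrained. F.log(1 + F.exp(pho)) can overflow for
    # large pho; if this Chainer version provides F.softplus, that would be the
    # numerically stable equivalent (left as-is here to match the original code).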
    def __call__(self, x, t):
        # reparameterised weight sample: w = mu + log(1 + exp(pho)) * eps
        self.eps1_w = np.random.normal(0, 1, self.mu1.W.shape).astype(np.float32)
        self.eps2_w = np.random.normal(0, 1, self.mu2.W.shape).astype(np.float32)
        self.eps3_w = np.random.normal(0, 1, self.mu3.W.shape).astype(np.float32)

        self.w1.W = self.mu1.W + F.log(1 + F.exp(self.pho1.W)) * Variable(self.eps1_w)
        self.w2.W = self.mu2.W + F.log(1 + F.exp(self.pho2.W)) * Variable(self.eps2_w)
        self.w3.W = self.mu3.W + F.log(1 + F.exp(self.pho3.W)) * Variable(self.eps3_w)

        self.eps1_b = np.random.normal(0, 1, self.mu1.b.shape).astype(np.float32)
        self.eps2_b = np.random.normal(0, 1, self.mu2.b.shape).astype(np.float32)
        self.eps3_b = np.random.normal(0, 1, self.mu3.b.shape).astype(np.float32)

        self.w1.b = self.mu1.b + F.log(1 + F.exp(self.pho1.b)) * Variable(self.eps1_b)
        self.w2.b = self.mu2.b + F.log(1 + F.exp(self.pho2.b)) * Variable(self.eps2_b)
        self.w3.b = self.mu3.b + F.log(1 + F.exp(self.pho3.b)) * Variable(self.eps3_b)

        h1 = F.relu(self.w1(x))
        h2 = F.relu(self.w2(h1))
        h3 = self.w3(h2)

        # keep the noise draws around: update() needs them for the pho gradients
        self.eps1 = [self.eps1_w, self.eps1_b]
        self.eps2 = [self.eps2_w, self.eps2_b]
        self.eps3 = [self.eps3_w, self.eps3_b]

        return F.softmax_cross_entropy(h3, t)

    def mu_hstack(self):
        return F.hstack([F.flatten(self.mu1.W), F.flatten(self.mu1.b), F.flatten(self.mu2.W), F.flatten(self.mu2.b), F.flatten(self.mu3.W), F.flatten(self.mu3.b)])

    def w_hstack(self):
        return F.hstack([F.flatten(self.w1.W), F.flatten(self.w1.b), F.flatten(self.w2.W), F.flatten(self.w2.b), F.flatten(self.w3.W), F.flatten(self.w3.b)])

    def sigma_hstack(self):
        return F.log(1 + F.exp(F.hstack([F.flatten(self.pho1.W), F.flatten(self.pho1.b), F.flatten(self.pho2.W), F.flatten(self.pho2.b), F.flatten(self.pho3.W), F.flatten(self.pho3.b)])))

    def update(self, model_num):
        # mu update: with w = mu + sigma*eps and dw/dmu = 1, the paper's
        # gradient is  delta_mu = df/dw + df/dmu
        for m, w in zip([self.mu1, self.mu2, self.mu3], [self.w1, self.w2, self.w3]):
            delta_w = m.W.grad + w.W.grad
            m.W = m.W - self.lr * delta_w
            delta_b = m.b.grad + w.b.grad
            m.b = m.b - self.lr * delta_b

        # pho update: dw/dpho = eps * d(softplus)/dpho = eps / (1 + exp(-pho)),
        # so  delta_pho = df/dw * eps / (1 + exp(-pho)) + df/dpho
        # (note: each assignment stores a Variable with history on the link,
        # so the computational graph keeps growing across updates)
        for pho, w, eps in zip([self.pho1, self.pho2, self.pho3], [self.w1, self.w2, self.w3], [self.eps1, self.eps2, self.eps3]):
            delta_w = pho.W.grad + w.W.grad * eps[0] / (1 + F.exp(-1 * pho.W))
            #pho.W = pho.W - self.lr * delta_w / np.float32(model_num)
            pho.W = pho.W - self.lr * delta_w
            delta_b = pho.b.grad + w.b.grad * eps[1] / (1 + F.exp(-1 * pho.b))
            #pho.b = pho.b - self.lr * delta_b / np.float32(model_num)
            pho.b = pho.b - self.lr * delta_b
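# NOTE (editorial sketch): update() above relies on the identity
#     d/dpho log(1 + exp(pho)) = 1 / (1 + exp(-pho)),
# which is why the sampled-weight gradient is rescaled by that sigmoid factor
# before being applied to pho. A quick numerical self-check of the identity
# (my own addition, not used by the training code):
def _check_softplus_grad(pho = 0.3, h = 1e-5):
    numerical = (np.log(1 + np.exp(pho + h)) - np.log(1 + np.exp(pho - h))) / (2 * h)
    analytic = 1. / (1 + np.exp(-pho))
    assert abs(numerical - analytic) < 1e-6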
class MLP_MNIST_bbp_(object):
    # legacy functional version: mu / pho are flat Variables and __call__
    # returns sampled flat weight vectors instead of a loss
    def __init__(self, n_in = 784, n_hidden1 = 1200, n_hidden2 = 1200, n_out = 10, lr = 1e-4, prior_ratio = 0.5, prior_sigma_1 = np.exp(-1), prior_sigma_2 = np.exp(-7), prior_pho_var = .05):
        super(MLP_MNIST_bbp_, self).__init__()
        self.n_in = n_in
        self.n_hidden1 = n_hidden1
        self.n_hidden2 = n_hidden2
        self.n_out = n_out
        self.lr = lr

        # the second mixture component uses prior_sigma_2, matching the Chain
        # version above
        mu1 = prior_ratio * np.random.normal(0, prior_sigma_1, n_in * n_hidden1) + (1 - prior_ratio) * np.random.normal(0, prior_sigma_2, n_in * n_hidden1)
        mu2 = prior_ratio * np.random.normal(0, prior_sigma_1, n_hidden1 * n_hidden2) + (1 - prior_ratio) * np.random.normal(0, prior_sigma_2, n_hidden1 * n_hidden2)
        mu3 = prior_ratio * np.random.normal(0, prior_sigma_1, n_hidden2 * n_out) + (1 - prior_ratio) * np.random.normal(0, prior_sigma_2, n_hidden2 * n_out)
        self.mu1 = Variable(mu1.astype(np.float32))
        self.mu2 = Variable(mu2.astype(np.float32))
        self.mu3 = Variable(mu3.astype(np.float32))
        pho1 = np.random.normal(0, prior_pho_var, n_in * n_hidden1)
        self.pho1 = Variable(pho1.astype(np.float32))
        pho2 = np.random.normal(0, prior_pho_var, n_hidden1 * n_hidden2)
        self.pho2 = Variable(pho2.astype(np.float32))
        pho3 = np.random.normal(0, prior_pho_var, n_hidden2 * n_out)
        self.pho3 = Variable(pho3.astype(np.float32))
        #bnorm1 = L.BatchNormalization(n_hidden1),
        #bnorm2 = L.BatchNormalization(n_hidden2)

    def __call__(self):
        eps1 = np.random.normal(0, 1, self.n_in * self.n_hidden1).astype(np.float32)
        eps2 = np.random.normal(0, 1, self.n_hidden1 * self.n_hidden2).astype(np.float32)
        eps3 = np.random.normal(0, 1, self.n_hidden2 * self.n_out).astype(np.float32)

        w1 = self.mu1 + F.log(1 + F.exp(self.pho1)) * Variable(eps1)
        w2 = self.mu2 + F.log(1 + F.exp(self.pho2)) * Variable(eps2)
        w3 = self.mu3 + F.log(1 + F.exp(self.pho3)) * Variable(eps3)
        return w1, w2, w3

    def mu_hstack(self):
        return F.hstack([self.mu1, self.mu2, self.mu3])

    def sigma_hstack(self):
        return F.log(1 + F.exp(F.hstack([self.pho1, self.pho2, self.pho3])))

    def update(self):
        # the sampled w's are built from mu and pho inside the graph, so after
        # backward() each .grad already carries the full pathwise gradient
        self.mu1 = self.mu1 - self.lr * self.mu1.grad
        self.mu2 = self.mu2 - self.lr * self.mu2.grad
        self.mu3 = self.mu3 - self.lr * self.mu3.grad
        self.pho1 = self.pho1 - self.lr * self.pho1.grad
        self.pho2 = self.pho2 - self.lr * self.pho2.grad
        self.pho3 = self.pho3 - self.lr * self.pho3.grad
class MLP_MNIST_dropout(Chain):
    def __init__(self, n_in = 784, n_hidden1 = 1200, n_hidden2 = 1200):
        super(MLP_MNIST_dropout, self).__init__(
            l1 = L.Linear(n_in, n_hidden1),
            l2 = L.Linear(n_hidden1, n_hidden2),
            l3 = L.Linear(n_hidden2, 10),
            #bnorm1 = L.BatchNormalization(n_hidden1),
            #bnorm2 = L.BatchNormalization(n_hidden2)
        )

    def __call__(self, x):
        h1 = F.relu(self.l1(x))
        h1 = F.dropout(h1)
        h2 = F.relu(self.l2(h1))
        h2 = F.dropout(h2)
        #h1 = F.relu(self.bnorm1(self.l1(x)))
        #h1 = F.dropout(h1)
        #h2 = F.relu(self.bnorm2(self.l2(h1)))
        #h2 = F.dropout(h2)
        return self.l3(h2)

--------------------------------------------------------------------------------