├── .gitignore ├── README.md ├── gaps_over_training_exp ├── script.sh ├── plot_over_training.py └── train_mnist.py ├── models ├── generator.py ├── inference_net.py ├── optimize_local_q.py ├── utils │ ├── utils.py │ ├── ais.py │ ├── ais2.py │ ├── ais3.py │ └── ais4.py ├── pytorch_vae_v6.py ├── vae_1.py └── vae_2.py ├── flow_effect_on_amort_exp ├── train_encoder_only.py ├── train_encoder_only2.py └── compute_gaps.py ├── decoder_sizes_exp ├── train_mnist.py └── compute_gaps.py ├── test_different_dists ├── eval.py ├── plot_gaps_over_epochs.py ├── plot_8plots.py └── train_mnist.py └── test_set_inference_exp └── compute_gaps.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.DS_Store 3 | 4 | *.aux 5 | *.log 6 | *.dvi 7 | *.out 8 | *.blg 9 | *.bbl 10 | 11 | *.ipynb_checkpoints 12 | 13 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Inference Suboptimality in Variational Autoencoders 2 | 3 | This repo contains code for the paper *Inference Suboptimality in Variational Autoencoders* (2018) 4 | [[arxiv](https://arxiv.org/abs/1801.03558)] 5 | 6 | More code for this paper can be found in co-author Xuechen Li's repo: [[github](https://github.com/lxuechen/inference-suboptimality)] 7 | 8 | This paper's experiments have also been reproduced and extended by a group of MSc students at the University of Oxford: 9 | [[github](https://github.com/ATML-2020-Group9/Inference-Suboptimality)] 10 | 11 | 12 | ## Citation 13 | 14 | ``` 15 | @article{cremer2018inference, 16 | title={Inference Suboptimality in Variational Autoencoders}, 17 | author={Cremer, Chris and Li, Xuechen and Duvenaud, David}, 18 | journal={ICML}, 19 | year={2018} 20 | } 21 | ``` 22 | -------------------------------------------------------------------------------- /gaps_over_training_exp/script.sh: -------------------------------------------------------------------------------- 1 | 2 | # tmux ls ; tmux ls 3 | 4 | # tmux 5 | 6 | python compute_gaps2.py 0 100 amort ; python compute_gaps2.py 1 100 opt_train ; python compute_gaps2.py 0 100 opt_valid 7 | 8 | 9 | python compute_gaps2.py 1 1000 amort ; python compute_gaps2.py 0 1000 opt_train ; python compute_gaps2.py 1 1000 opt_valid 10 | 11 | 12 | python compute_gaps2.py 0 2200 amort ; python compute_gaps2.py 1 2200 opt_train ; python compute_gaps2.py 0 2200 opt_valid 13 | 14 | 15 | python compute_gaps2.py 1 3280 amort ; python compute_gaps2.py 0 3280 opt_train ; python compute_gaps2.py 1 3280 opt_valid 16 | 17 | 18 | 19 | 20 | 21 | python compute_gaps2.py 0 100 amort ; python compute_gaps2.py 0 700 amort ; python compute_gaps2.py 0 1300 amort 22 | 23 | 24 | 25 | 26 | 27 | 28 | python compute_gaps2.py 0 100 opt_train ; python compute_gaps2.py 0 300 opt_train ; python compute_gaps2.py 0 500 opt_train ; python compute_gaps2.py 0 700 opt_train ; python compute_gaps2.py 0 1000 opt_train 29 | python compute_gaps2.py 1 100 opt_valid ; python compute_gaps2.py 1 300 opt_valid ; python compute_gaps2.py 1 500 opt_valid ; python compute_gaps2.py 1 700 opt_valid ; python compute_gaps2.py 1 1000 opt_valid 30 | 31 | 32 | python compute_gaps2.py 0 100 amort ; python compute_gaps2.py 0 300 amort ; python compute_gaps2.py 0 500 amort ; python compute_gaps2.py 0 700 amort ; python compute_gaps2.py 0 1000 amort 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 
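Note on script.sh above: each call is `python compute_gaps2.py <gpu> <epoch> <mode>`, where the first argument apparently selects the GPU, the second a saved checkpoint epoch (100, 1000, 2200, 3280, ...), and the third which bound to compute (`amort` for the amortized encoder, `opt_train` / `opt_valid` for locally optimized posteriors on the training / validation set). compute_gaps2.py itself is not part of this listing, so the following is only a minimal, hypothetical sketch of how such an entry point could read its arguments; the names are assumptions, not the repo's actual code.

import os
import sys

def parse_args(argv):
    # Hypothetical argument parsing; compute_gaps2.py is not included in this listing.
    gpu = argv[1]            # GPU id, e.g. '0' or '1'
    epoch = int(argv[2])     # checkpoint epoch to evaluate, e.g. 100, 1000, 3280
    mode = argv[3]           # 'amort', 'opt_train', or 'opt_valid'
    assert mode in ('amort', 'opt_train', 'opt_valid')
    return gpu, epoch, mode

if __name__ == '__main__':
    gpu, epoch, mode = parse_args(sys.argv)
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu  # same GPU-selection pattern used elsewhere in the repo
    print('computing', mode, 'bound for checkpoint at epoch', epoch)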
-------------------------------------------------------------------------------- /models/generator.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | import torch 6 | from torch.autograd import Variable 7 | import torch.utils.data 8 | import torch.optim as optim 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | class Generator(nn.Module): 21 | 22 | def __init__(self, hyper_config): 23 | super(Generator, self).__init__() 24 | 25 | if hyper_config['cuda']: 26 | self.dtype = torch.cuda.FloatTensor 27 | else: 28 | self.dtype = torch.FloatTensor 29 | 30 | self.z_size = hyper_config['z_size'] 31 | self.x_size = hyper_config['x_size'] 32 | self.act_func = hyper_config['act_func'] 33 | 34 | #Decoder 35 | self.decoder_weights = [] 36 | self.layer_norms = [] 37 | for i in range(len(hyper_config['decoder_arch'])): 38 | self.decoder_weights.append(nn.Linear(hyper_config['decoder_arch'][i][0], hyper_config['decoder_arch'][i][1])) 39 | 40 | count =1 41 | for i in range(len(self.decoder_weights)): 42 | self.add_module(str(count), self.decoder_weights[i]) 43 | count+=1 44 | 45 | 46 | def decode(self, z): 47 | k = z.size()[0] 48 | B = z.size()[1] 49 | z = z.view(-1, self.z_size) 50 | 51 | out = z 52 | for i in range(len(self.decoder_weights)-1): 53 | out = self.act_func(self.decoder_weights[i](out)) 54 | # out = self.act_func(self.layer_norms[i].forward(self.decoder_weights[i](out))) 55 | out = self.decoder_weights[-1](out) 56 | 57 | x = out.view(k, B, self.x_size) 58 | return x 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | -------------------------------------------------------------------------------- /models/inference_net.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | import torch 7 | from torch.autograd import Variable 8 | import torch.utils.data 9 | import torch.optim as optim 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | 15 | import sys 16 | sys.path.insert(0, 'utils') 17 | from utils import lognormal2 as lognormal 18 | from utils import lognormal333 19 | 20 | 21 | 22 | from distributions import Gaussian 23 | from distributions import Flow 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | class standard(nn.Module): 32 | 33 | def __init__(self, hyper_config): 34 | super(standard, self).__init__() 35 | 36 | if torch.cuda.is_available(): 37 | self.dtype = torch.cuda.FloatTensor 38 | else: 39 | self.dtype = torch.FloatTensor 40 | 41 | self.hyper_config = hyper_config 42 | 43 | self.z_size = hyper_config['z_size'] 44 | self.x_size = hyper_config['x_size'] 45 | self.act_func = hyper_config['act_func'] 46 | 47 | 48 | #Encoder 49 | self.encoder_weights = [] 50 | self.layer_norms = [] 51 | for i in range(len(hyper_config['encoder_arch'])): 52 | self.encoder_weights.append(nn.Linear(hyper_config['encoder_arch'][i][0], hyper_config['encoder_arch'][i][1])) 53 | 54 | # if i != len(hyper_config['encoder_arch'])-1: 55 | # self.layer_norms.append(LayerNorm(hyper_config['encoder_arch'][i][1])) 56 | 57 | count =1 58 | for i in range(len(self.encoder_weights)): 59 | self.add_module(str(count), self.encoder_weights[i]) 60 | count+=1 61 | 62 | # if i != len(hyper_config['encoder_arch'])-1: 63 | # self.add_module(str(count), self.layer_norms[i]) 64 | # count+=1 65 | 66 | 67 | 68 | # self.q = Gaussian(self.hyper_config) #, mean, logvar) 69 | # self.q = Flow(self.hyper_config)#, mean, logvar) 70 | self.q = hyper_config['q'] 71 
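# Note on the forward pass defined below: this is the amortized inference network.
# The encoder maps x to the parameters (mean, logvar) of q(z|x), and self.q draws
# k samples of z together with log q(z|x).  For the plain fully-factorized Gaussian
# case this is just the reparameterization trick, as in the commented-out lines
# inside forward() (a minimal sketch, not the actual call path):
#
#     eps = Variable(torch.FloatTensor(k, B, z_size).normal_())    # eps ~ N(0, I)
#     z = eps.mul(torch.exp(.5 * logvar)) + mean                   # z ~ N(mean, diag(exp(logvar)))
#     logqz = lognormal(z, mean, logvar)                           # [k, B]
#
# Flow-based choices of hyper_config['q'] additionally transform z and adjust logqz
# by the log-determinant of the transformation's Jacobian; the 'hnf' variant is also
# given access to the log-posterior.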
| 72 | 73 | def forward(self, k, x, logposterior): 74 | ''' 75 | k: number of samples 76 | x: [B,X] 77 | logposterior(z) -> [P,B] 78 | ''' 79 | 80 | self.B = x.size()[0] 81 | 82 | #Encode 83 | out = x 84 | for i in range(len(self.encoder_weights)-1): 85 | out = self.act_func(self.encoder_weights[i](out)) 86 | # out = self.act_func(self.layer_norms[i].forward(self.encoder_weights[i](out))) 87 | 88 | out = self.encoder_weights[-1](out) 89 | mean = out[:,:self.z_size] #[B,Z] 90 | logvar = out[:,self.z_size:] 91 | 92 | # #Sample 93 | # eps = Variable(torch.FloatTensor(k, self.B, self.z_size).normal_().type(self.dtype)) #[P,B,Z] 94 | # z = eps.mul(torch.exp(.5*logvar)) + mean #[P,B,Z] 95 | # logqz = lognormal(z, mean, logvar) #[P,B] 96 | 97 | 98 | if self.hyper_config['hnf']: 99 | z, logqz = self.q.sample(mean, logvar, k, logposterior) 100 | else: 101 | z, logqz = self.q.sample(mean, logvar, k) 102 | 103 | return z, logqz 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | -------------------------------------------------------------------------------- /models/optimize_local_q.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | import numpy as np 9 | 10 | import torch 11 | from torch.autograd import Variable 12 | import torch.utils.data 13 | import torch.optim as optim 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | 17 | from utils import lognormal2 as lognormal 18 | from utils import lognormal333 19 | 20 | from utils import log_bernoulli 21 | 22 | import time 23 | 24 | import pickle 25 | 26 | quick = 0 27 | 28 | 29 | 30 | 31 | 32 | 33 | def optimize_local_q_dist(logposterior, hyper_config, x, q): 34 | 35 | B = x.size()[0] #batch size 36 | P = 50 37 | 38 | z_size = hyper_config['z_size'] 39 | x_size = hyper_config['x_size'] 40 | if torch.cuda.is_available(): 41 | dtype = torch.cuda.FloatTensor 42 | else: 43 | dtype = torch.FloatTensor 44 | 45 | mean = Variable(torch.zeros(B, z_size).type(dtype), requires_grad=True) 46 | logvar = Variable(torch.zeros(B, z_size).type(dtype), requires_grad=True) 47 | 48 | params = [mean, logvar] 49 | for aaa in q.parameters(): 50 | params.append(aaa) 51 | 52 | 53 | optimizer = optim.Adam(params, lr=.001) 54 | 55 | last_100 = [] 56 | best_last_100_avg = -1 57 | consecutive_worse = 0 58 | for epoch in range(1, 999999): 59 | 60 | # #Sample 61 | # eps = Variable(torch.FloatTensor(P, B, model.z_size).normal_().type(model.dtype)) #[P,B,Z] 62 | # z = eps.mul(torch.exp(.5*logvar)) + mean #[P,B,Z] 63 | # logqz = lognormal(z, mean, logvar) #[P,B] 64 | 65 | # fsadfad 66 | # z, logqz = q.sample(...) 
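# The loop below performs local (per-datapoint) optimization of q: the variational
# parameters mean and logvar, plus any parameters inside the q object (e.g. flow
# parameters), are trained directly with Adam to maximize a P-sample Monte Carlo
# estimate of the ELBO
#     L[q] = E_{z ~ q(z|x)} [ log p(x, z) - log q(z|x) ],
# where logposterior(z) evaluates the log-joint log p(x, z) (prior plus likelihood).
# Optimization stops once the 100-step average of the loss has failed to improve 10
# times in a row, and the final VAE/IWAE bounds are then re-estimated with 5000
# samples.  This locally optimized q* is what separates the approximation gap
# (log p(x) - L[q*]) from the amortization gap (L[q*] - L[q]).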
67 | z, logqz = q.sample(mean, logvar, P) 68 | 69 | logpx = logposterior(z) 70 | 71 | optimizer.zero_grad() 72 | 73 | 74 | loss = -(torch.mean(logpx-logqz)) 75 | loss_np = loss.data.cpu().numpy() 76 | # print (epoch, loss_np) 77 | # fasfaf 78 | 79 | loss.backward() 80 | optimizer.step() 81 | 82 | last_100.append(loss_np) 83 | if epoch % 100 ==0: 84 | 85 | last_100_avg = np.mean(last_100) 86 | if last_100_avg< best_last_100_avg or best_last_100_avg == -1: 87 | consecutive_worse=0 88 | best_last_100_avg = last_100_avg 89 | else: 90 | consecutive_worse +=1 91 | # print(consecutive_worse) 92 | if consecutive_worse> 10: 93 | # print ('done') 94 | break 95 | 96 | if epoch % 2000 ==0: 97 | print (epoch, last_100_avg, consecutive_worse)#,mean) 98 | # print (torch.mean(logpx)) 99 | 100 | last_100 = [] 101 | 102 | 103 | 104 | # Compute VAE and IWAE bounds 105 | 106 | # #Sample 107 | # eps = Variable(torch.FloatTensor(1000, B, model.z_size).normal_().type(model.dtype)) #[P,B,Z] 108 | # z = eps.mul(torch.exp(.5*logvar)) + mean #[P,B,Z] 109 | # logqz = lognormal(z, mean, logvar) #[P,B] 110 | z, logqz = q.sample(mean, logvar, 5000) 111 | 112 | # print (logqz) 113 | # fad 114 | logpx = logposterior(z) 115 | 116 | elbo = logpx-logqz #[P,B] 117 | vae = torch.mean(elbo) 118 | 119 | max_ = torch.max(elbo, 0)[0] #[B] 120 | elbo_ = torch.log(torch.mean(torch.exp(elbo - max_), 0)) + max_ #[B] 121 | iwae = torch.mean(elbo_) 122 | 123 | return vae, iwae 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | -------------------------------------------------------------------------------- /models/utils/utils.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | import math 5 | import torch 6 | from torch.autograd import Variable 7 | import numpy as np 8 | 9 | 10 | import torch 11 | from torch.autograd import Variable 12 | import torch.utils.data 13 | import torch.optim as optim 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | 17 | 18 | 19 | # def lognormal(x, mean, logvar): 20 | # ''' 21 | # x: [B,Z] 22 | # mean,logvar: [B,Z] 23 | # output: [B] 24 | # ''' 25 | 26 | # # # D = x.size()[1] 27 | # # # term1 = D * torch.log(torch.FloatTensor([2.*math.pi])) #[1] 28 | # # term2 = logvar.sum(1) #sum over D, [B] 29 | # # dif_cov = (x - mean).pow(2) 30 | # # # dif_cov.div(torch.exp(logvar)) #exp_()) #[P,B,D] 31 | # # term3 = (dif_cov/torch.exp(logvar)).sum(1) #sum over D, [P,B] 32 | # # # all_ = Variable(term1) + term2 + term3 #[P,B] 33 | # # all_ = term2 + term3 #[P,B] 34 | # # log_N = -.5 * all_ 35 | # # return log_N 36 | 37 | # # term2 = logvar.sum(1) #sum over D, [B] 38 | # # dif_cov = (x - mean).pow(2) 39 | # # term3 = (dif_cov/torch.exp(logvar)).sum(1) #sum over D, [P,B] 40 | # # all_ = term2 + term3 #[P,B] 41 | # # log_N = -.5 * all_ 42 | # # return log_N 43 | 44 | # # one line 45 | # return -.5 * (logvar.sum(1) + ((x - mean).pow(2)/torch.exp(logvar)).sum(1)) 46 | 47 | 48 | class LayerNorm(nn.Module): 49 | 50 | def __init__(self, features, eps=1e-6): 51 | super().__init__() 52 | self.gamma = nn.Parameter(torch.ones(features)) 53 | self.beta = nn.Parameter(torch.zeros(features)) 54 | self.eps = eps 55 | 56 | def forward(self, x): 57 | mean = x.mean(-1, keepdim=True) 58 | std = x.std(-1, keepdim=True) 59 | return self.gamma * (x - mean) / (std + self.eps) + self.beta 60 | 61 | 62 | 63 | 64 | 65 | def lognormal2(x, mean, logvar): 66 | ''' 67 | x: [P,B,Z] 68 | mean,logvar: [B,Z] 69 | output: [P,B] 70 | ''' 71 | 72 | assert 
len(x.size()) == 3 73 | assert len(mean.size()) == 2 74 | assert len(logvar.size()) == 2 75 | assert x.size()[1] == mean.size()[0] 76 | 77 | D = x.size()[2] 78 | 79 | if torch.cuda.is_available(): 80 | term1 = D * torch.log(torch.cuda.FloatTensor([2.*math.pi])) #[1] 81 | else: 82 | term1 = D * torch.log(torch.FloatTensor([2.*math.pi])) #[1] 83 | 84 | 85 | return -.5 * (Variable(term1) + logvar.sum(1) + ((x - mean).pow(2)/torch.exp(logvar)).sum(2)) 86 | 87 | 88 | def lognormal333(x, mean, logvar): 89 | ''' 90 | x: [P,B,Z] 91 | mean,logvar: [P,B,Z] 92 | output: [P,B] 93 | ''' 94 | 95 | assert len(x.size()) == 3 96 | assert len(mean.size()) == 3 97 | assert len(logvar.size()) == 3 98 | assert x.size()[0] == mean.size()[0] 99 | assert x.size()[1] == mean.size()[1] 100 | 101 | D = x.size()[2] 102 | 103 | if torch.cuda.is_available(): 104 | term1 = D * torch.log(torch.cuda.FloatTensor([2.*math.pi])) #[1] 105 | else: 106 | term1 = D * torch.log(torch.FloatTensor([2.*math.pi])) #[1] 107 | 108 | 109 | return -.5 * (Variable(term1) + logvar.sum(2) + ((x - mean).pow(2)/torch.exp(logvar)).sum(2)) 110 | 111 | 112 | 113 | 114 | def log_bernoulli(pred_no_sig, target): 115 | ''' 116 | pred_no_sig is [P, B, X] 117 | t is [B, X] 118 | output is [P, B] 119 | ''' 120 | 121 | assert len(pred_no_sig.size()) == 3 122 | assert len(target.size()) == 2 123 | assert pred_no_sig.size()[1] == target.size()[0] 124 | 125 | return -(torch.clamp(pred_no_sig, min=0) 126 | - pred_no_sig * target 127 | + torch.log(1. + torch.exp(-torch.abs(pred_no_sig)))).sum(2) #sum over dimensions 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | def lognormal3(x, mean, logvar): 137 | ''' 138 | x: [P] 139 | mean,logvar: [P] 140 | output: [1] 141 | ''' 142 | 143 | return -.5 * (logvar.sum(0) + ((x - mean).pow(2)/torch.exp(logvar)).sum(0)) 144 | 145 | 146 | 147 | 148 | def lognormal4(x, mean, logvar): 149 | ''' 150 | x: [B,X] 151 | mean,logvar: [X] 152 | output: [B] 153 | ''' 154 | # print x.size() 155 | # print mean.size() 156 | # print logvar.size() 157 | # print mean 158 | # print logvar 159 | D = x.size()[1] 160 | # print D 161 | term1 = D * torch.log(torch.FloatTensor([2.*math.pi])) #[1] 162 | # print term1 163 | # print logvar.sum(0) 164 | 165 | aaa = -.5 * (term1 + logvar.sum(0) + ((x - mean).pow(2)/torch.exp(logvar)).sum(1)) 166 | # print aaa.size() 167 | 168 | return aaa 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | -------------------------------------------------------------------------------- /models/utils/ais.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | import math 7 | import torch 8 | from torch.autograd import Variable 9 | import numpy as np 10 | 11 | from utils import lognormal2 as lognormal 12 | from utils import log_bernoulli 13 | 14 | import time 15 | 16 | 17 | def test_ais(model, data_x, batch_size, display, k, n_intermediate_dists): 18 | 19 | 20 | def intermediate_dist(t, z, mean, logvar, zeros, batch): 21 | logp1 = lognormal(z, mean, logvar) #[P,B] 22 | log_prior = lognormal(z, zeros, zeros) #[P,B] 23 | log_likelihood = log_bernoulli(model.decode(z), batch) 24 | logpT = log_prior + log_likelihood 25 | log_intermediate_2 = (1-float(t))*logp1 + float(t)*logpT 26 | return 
log_intermediate_2 27 | 28 | 29 | def hmc(z, intermediate_dist_func): 30 | 31 | if torch.cuda.is_available(): 32 | v = Variable(torch.FloatTensor(z.size()).normal_(), volatile=volatile_, requires_grad=requires_grad).cuda() 33 | else: 34 | v = Variable(torch.FloatTensor(z.size()).normal_()) 35 | 36 | v0 = v 37 | z0 = z 38 | 39 | gradients = torch.autograd.grad(outputs=intermediate_dist_func(z), inputs=z, 40 | grad_outputs=grad_outputs, 41 | create_graph=True, retain_graph=retain_graph, only_inputs=True)[0] 42 | 43 | gradients = gradients.detach() 44 | 45 | v = v + .5 *step_size*gradients 46 | z = z + step_size*v 47 | 48 | for LF_step in range(n_HMC_steps): 49 | 50 | # log_intermediate_2 = intermediate_dist(t1, z, mean, logvar, zeros, batch) 51 | gradients = torch.autograd.grad(outputs=intermediate_dist_func(z), inputs=z, 52 | grad_outputs=grad_outputs, 53 | create_graph=True, retain_graph=retain_graph, only_inputs=True)[0] 54 | gradients = gradients.detach() 55 | v = v + step_size*gradients 56 | z = z + step_size*v 57 | 58 | # log_intermediate_2 = intermediate_dist(t1, z, mean, logvar, zeros, batch) 59 | gradients = torch.autograd.grad(outputs=intermediate_dist_func(z), inputs=z, 60 | grad_outputs=grad_outputs, 61 | create_graph=True, retain_graph=retain_graph, only_inputs=True)[0] 62 | gradients = gradients.detach() 63 | v = v + .5 *step_size*gradients 64 | 65 | return z0, v0, z, v 66 | 67 | 68 | def mh_step(z0, v0, z, v, step_size, intermediate_dist_func): 69 | 70 | logpv0 = lognormal(v0, zeros, zeros) #[P,B] 71 | hamil_0 = intermediate_dist_func(z0) + logpv0 72 | 73 | logpvT = lognormal(v, zeros, zeros) #[P,B] 74 | hamil_T = intermediate_dist_func(z) + logpvT 75 | 76 | accept_prob = torch.exp(hamil_T - hamil_0) 77 | 78 | if torch.cuda.is_available(): 79 | rand_uni = Variable(torch.FloatTensor(accept_prob.size()).uniform_(), volatile=volatile_, requires_grad=requires_grad).cuda() 80 | else: 81 | rand_uni = Variable(torch.FloatTensor(accept_prob.size()).uniform_()) 82 | 83 | accept = accept_prob > rand_uni 84 | 85 | if torch.cuda.is_available(): 86 | accept = accept.type(torch.FloatTensor).cuda() 87 | else: 88 | accept = accept.type(torch.FloatTensor) 89 | 90 | accept = accept.view(k, model.B, 1) 91 | 92 | z = (accept * z) + ((1-accept) * z0) 93 | 94 | #Adapt step size 95 | avg_acceptance_rate = torch.mean(accept) 96 | 97 | if avg_acceptance_rate.cpu().data.numpy() > .7: 98 | step_size = 1.02 * step_size 99 | else: 100 | step_size = .98 * step_size 101 | 102 | if step_size < 0.0001: 103 | step_size = 0.0001 104 | if step_size > 0.5: 105 | step_size = 0.5 106 | 107 | return z, step_size 108 | 109 | 110 | 111 | 112 | # n_intermediate_dists = 10 113 | n_HMC_steps = 5 114 | step_size = .1 115 | 116 | retain_graph = False 117 | volatile_ = False 118 | requires_grad = False 119 | 120 | time_ = time.time() 121 | 122 | logws = [] 123 | data_index= 0 124 | for i in range(int(len(data_x)/ batch_size)): 125 | 126 | #AIS 127 | 128 | schedule = np.linspace(0.,1.,n_intermediate_dists) 129 | model.B = batch_size 130 | 131 | batch = data_x[data_index:data_index+batch_size] 132 | data_index += batch_size 133 | 134 | 135 | 136 | if torch.cuda.is_available(): 137 | batch = Variable(torch.from_numpy(batch), volatile=volatile_, requires_grad=requires_grad).cuda() 138 | zeros = Variable(torch.zeros(model.B, model.z_size), volatile=volatile_, requires_grad=requires_grad).cuda() # [B,Z] 139 | logw = Variable(torch.zeros(k, model.B), volatile=True, requires_grad=requires_grad).cuda() 140 | grad_outputs = 
torch.ones(k, model.B).cuda() 141 | else: 142 | batch = Variable(torch.from_numpy(batch)) 143 | zeros = Variable(torch.zeros(model.B, model.z_size)) # [B,Z] 144 | logw = Variable(torch.zeros(k, model.B)) 145 | grad_outputs = torch.ones(k, model.B) 146 | 147 | 148 | #Encode x 149 | mean, logvar = model.encode(batch) #[B,Z] 150 | #Init z 151 | z, logpz, logqz = model.sample(mean, logvar, k=k) #[P,B,Z], [P,B], [P,B] 152 | 153 | for (t0, t1) in zip(schedule[:-1], schedule[1:]): 154 | 155 | 156 | #logw = logw + logpt-1(zt-1) - logpt(zt-1) 157 | log_intermediate_1 = intermediate_dist(t0, z, mean, logvar, zeros, batch) 158 | log_intermediate_2 = intermediate_dist(t1, z, mean, logvar, zeros, batch) 159 | logw += log_intermediate_2 - log_intermediate_1 160 | 161 | 162 | 163 | #HMC dynamics 164 | intermediate_dist_func = lambda aaa: intermediate_dist(t1, aaa, mean, logvar, zeros, batch) 165 | z0, v0, z, v = hmc(z, intermediate_dist_func) 166 | 167 | #MH step 168 | z, step_size = mh_step(z0, v0, z, v, step_size, intermediate_dist_func) 169 | 170 | #log sum exp 171 | max_ = torch.max(logw,0)[0] #[B] 172 | logw = torch.log(torch.mean(torch.exp(logw - max_), 0)) + max_ #[B] 173 | 174 | logws.append(torch.mean(logw.cpu()).data.numpy()) 175 | 176 | 177 | if i%display==0: 178 | print (i,len(data_x)/ batch_size, np.mean(logws)) 179 | 180 | mean_ = np.mean(logws) 181 | print(mean_, 'T:', time.time()-time_) 182 | return mean_ 183 | 184 | 185 | 186 | -------------------------------------------------------------------------------- /models/utils/ais2.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | # This one samples the prior distribution 4 | 5 | 6 | import math 7 | import torch 8 | from torch.autograd import Variable 9 | import numpy as np 10 | 11 | from utils import lognormal2 as lognormal 12 | from utils import log_bernoulli 13 | 14 | import time 15 | 16 | 17 | def test_ais(model, data_x, batch_size, display, k, n_intermediate_dists): 18 | 19 | 20 | def intermediate_dist(t, z, mean, logvar, zeros, batch): 21 | # logp1 = lognormal(z, mean, logvar) #[P,B] 22 | log_prior = lognormal(z, zeros, zeros) #[P,B] 23 | log_likelihood = log_bernoulli(model.decode(z), batch) 24 | # logpT = log_prior + log_likelihood 25 | # log_intermediate_2 = (1-float(t))*logp1 + float(t)*logpT 26 | 27 | log_intermediate_2 = log_prior + float(t)*log_likelihood 28 | 29 | return log_intermediate_2 30 | 31 | 32 | def hmc(z, intermediate_dist_func): 33 | 34 | if torch.cuda.is_available(): 35 | v = Variable(torch.FloatTensor(z.size()).normal_(), volatile=volatile_, requires_grad=requires_grad).cuda() 36 | else: 37 | v = Variable(torch.FloatTensor(z.size()).normal_()) 38 | 39 | v0 = v 40 | z0 = z 41 | 42 | # print (intermediate_dist_func(z)) 43 | # fasdf 44 | gradients = torch.autograd.grad(outputs=intermediate_dist_func(z), inputs=z, 45 | grad_outputs=grad_outputs, 46 | create_graph=True, retain_graph=retain_graph, only_inputs=True)[0] 47 | 48 | gradients = gradients.detach() 49 | 50 | v = v + .5 *step_size*gradients 51 | z = z + step_size*v 52 | 53 | for LF_step in range(n_HMC_steps): 54 | 55 | # log_intermediate_2 = intermediate_dist(t1, z, mean, logvar, zeros, batch) 56 | gradients = torch.autograd.grad(outputs=intermediate_dist_func(z), inputs=z, 57 | grad_outputs=grad_outputs, 58 | create_graph=True, retain_graph=retain_graph, only_inputs=True)[0] 59 | gradients = gradients.detach() 60 | v = v + step_size*gradients 61 | z = z + step_size*v 62 | 63 | # log_intermediate_2 = 
intermediate_dist(t1, z, mean, logvar, zeros, batch) 64 | gradients = torch.autograd.grad(outputs=intermediate_dist_func(z), inputs=z, 65 | grad_outputs=grad_outputs, 66 | create_graph=True, retain_graph=retain_graph, only_inputs=True)[0] 67 | gradients = gradients.detach() 68 | v = v + .5 *step_size*gradients 69 | 70 | return z0, v0, z, v 71 | 72 | 73 | def mh_step(z0, v0, z, v, step_size, intermediate_dist_func): 74 | 75 | logpv0 = lognormal(v0, zeros, zeros) #[P,B] 76 | hamil_0 = intermediate_dist_func(z0) + logpv0 77 | 78 | logpvT = lognormal(v, zeros, zeros) #[P,B] 79 | hamil_T = intermediate_dist_func(z) + logpvT 80 | 81 | accept_prob = torch.exp(hamil_T - hamil_0) 82 | 83 | if torch.cuda.is_available(): 84 | rand_uni = Variable(torch.FloatTensor(accept_prob.size()).uniform_(), volatile=volatile_, requires_grad=requires_grad).cuda() 85 | else: 86 | rand_uni = Variable(torch.FloatTensor(accept_prob.size()).uniform_()) 87 | 88 | accept = accept_prob > rand_uni 89 | 90 | if torch.cuda.is_available(): 91 | accept = accept.type(torch.FloatTensor).cuda() 92 | else: 93 | accept = accept.type(torch.FloatTensor) 94 | 95 | accept = accept.view(k, model.B, 1) 96 | 97 | z = (accept * z) + ((1-accept) * z0) 98 | 99 | #Adapt step size 100 | avg_acceptance_rate = torch.mean(accept) 101 | 102 | if avg_acceptance_rate.cpu().data.numpy() > .65: 103 | step_size = 1.02 * step_size 104 | else: 105 | step_size = .98 * step_size 106 | 107 | if step_size < 0.0001: 108 | step_size = 0.0001 109 | if step_size > 0.5: 110 | step_size = 0.5 111 | 112 | return z, step_size 113 | 114 | 115 | 116 | 117 | # n_intermediate_dists = 10 118 | n_HMC_steps = 10 119 | step_size = .1 120 | 121 | retain_graph = False 122 | volatile_ = False 123 | requires_grad = False 124 | 125 | time_ = time.time() 126 | 127 | logws = [] 128 | data_index= 0 129 | for i in range(int(len(data_x)/ batch_size)): 130 | 131 | #AIS 132 | 133 | schedule = np.linspace(0.,1.,n_intermediate_dists) 134 | model.B = batch_size 135 | 136 | batch = data_x[data_index:data_index+batch_size] 137 | data_index += batch_size 138 | 139 | 140 | 141 | if torch.cuda.is_available(): 142 | batch = Variable(torch.from_numpy(batch), volatile=volatile_, requires_grad=requires_grad).cuda() 143 | zeros = Variable(torch.zeros(model.B, model.z_size), volatile=volatile_, requires_grad=requires_grad).cuda() # [B,Z] 144 | logw = Variable(torch.zeros(k, model.B), volatile=True, requires_grad=requires_grad).cuda() 145 | grad_outputs = torch.ones(k, model.B).cuda() 146 | else: 147 | batch = Variable(torch.from_numpy(batch)) 148 | zeros = Variable(torch.zeros(model.B, model.z_size)) # [B,Z] 149 | logw = Variable(torch.zeros(k, model.B)) 150 | grad_outputs = torch.ones(k, model.B) 151 | 152 | 153 | # #Encode x 154 | # mean, logvar = model.encode(batch) #[B,Z] 155 | # #Init z 156 | # z, logpz, logqz = model.sample(mean, logvar, k=k) #[P,B,Z], [P,B], [P,B] 157 | 158 | 159 | # z, logqz = model.q_dist.forward(k=k, x=batch, logposterior=model.logposterior) 160 | z = Variable(torch.FloatTensor(k, model.B, model.z_size).normal_().type(model.dtype),requires_grad=True) 161 | 162 | 163 | for (t0, t1) in zip(schedule[:-1], schedule[1:]): 164 | 165 | 166 | #logw = logw + logpt-1(zt-1) - logpt(zt-1) t, z, mean, logvar, zeros, batch 167 | # log_intermediate_1 = intermediate_dist(t0, z, mean, logvar, zeros, batch) 168 | # log_intermediate_2 = intermediate_dist(t1, z, mean, logvar, zeros, batch) 169 | log_intermediate_1 = intermediate_dist(t=t0, z=z, mean=zeros, logvar=zeros, zeros=zeros, 
batch=batch) 170 | log_intermediate_2 = intermediate_dist(t=t1, z=z, mean=zeros, logvar=zeros, zeros=zeros, batch=batch) 171 | 172 | 173 | logw += log_intermediate_2 - log_intermediate_1 174 | 175 | 176 | 177 | #HMC dynamics 178 | # intermediate_dist_func = lambda aaa: intermediate_dist(t1, aaa, mean, logvar, zeros, batch) 179 | intermediate_dist_func = lambda aaa: intermediate_dist(t1, aaa, zeros, zeros, zeros, batch) 180 | 181 | # print (t1) 182 | 183 | z0, v0, z, v = hmc(z, intermediate_dist_func) 184 | 185 | #MH step 186 | z, step_size = mh_step(z0, v0, z, v, step_size, intermediate_dist_func) 187 | 188 | #log sum exp 189 | max_ = torch.max(logw,0)[0] #[B] 190 | logw = torch.log(torch.mean(torch.exp(logw - max_), 0)) + max_ #[B] 191 | 192 | logws.append(torch.mean(logw.cpu()).data.numpy()) 193 | 194 | 195 | if i%display==0: 196 | print (i,len(data_x)/ batch_size, np.mean(logws),step_size) 197 | 198 | mean_ = np.mean(logws) 199 | print(mean_, 'T:', time.time()-time_) 200 | return mean_ 201 | 202 | 203 | 204 | -------------------------------------------------------------------------------- /flow_effect_on_amort_exp/train_encoder_only.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | import numpy as np 5 | import gzip 6 | import time 7 | import pickle 8 | 9 | from os.path import expanduser 10 | home = expanduser("~") 11 | 12 | import sys, os 13 | sys.path.insert(0, '../models') 14 | sys.path.insert(0, '../models/utils') 15 | 16 | 17 | import matplotlib 18 | matplotlib.use('Agg') 19 | import matplotlib.pyplot as plt 20 | 21 | 22 | import torch 23 | from torch.autograd import Variable 24 | import torch.utils.data 25 | import torch.optim as optim 26 | import torch.nn as nn 27 | import torch.nn.functional as F 28 | 29 | from vae_2 import VAE 30 | 31 | from inference_net import standard 32 | 33 | 34 | 35 | 36 | #Load data 37 | print ('Loading data' ) 38 | data_location = home + '/Documents/MNIST_data/' 39 | # with open(data_location + 'binarized_mnist.pkl', 'rb') as f: 40 | # train_x, valid_x, test_x = pickle.load(f) 41 | with open(data_location + 'binarized_mnist.pkl', 'rb') as f: 42 | train_x, valid_x, test_x = pickle.load(f, encoding='latin1') 43 | print ('Train', train_x.shape) 44 | print ('Valid', valid_x.shape) 45 | print ('Test', test_x.shape) 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | def train_encdoer_and_decoder(model, train_x, test_x, k, batch_size, 57 | start_at, save_freq, display_epoch, 58 | path_to_save_variables): 59 | 60 | train_y = torch.from_numpy(np.zeros(len(train_x))) 61 | train_x = torch.from_numpy(train_x).float().type(model.dtype) 62 | 63 | train_ = torch.utils.data.TensorDataset(train_x, train_y) 64 | train_loader = torch.utils.data.DataLoader(train_, batch_size=batch_size, shuffle=True) 65 | 66 | #IWAE paper training strategy 67 | time_ = time.time() 68 | total_epochs = 0 69 | 70 | i_max = 7 71 | 72 | warmup_over_epochs = 100. 
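# Schedule notes for the training loop below (only the encoder q_dist is optimized
# here; the decoder/generator weights are loaded from a checkpoint and kept fixed):
#  * IWAE-style learning-rate schedule: for i = 0..i_max the optimizer is rebuilt
#    with lr = 0.001 * 10**(-i / i_max) and run for 3**i epochs, so the total is
#    sum_{i=0}^{7} 3**i = 3280 epochs (matching the *_3280.pt checkpoint names).
#  * warmup = min(total_epochs / warmup_over_epochs, 1) is passed to model.forward;
#    as in the usual KL-annealing ("warm-up") scheme it is presumed to scale the
#    KL / logqz term of the ELBO over the first warmup_over_epochs epochs.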
73 | 74 | 75 | all_params = [] 76 | for aaa in model.q_dist.parameters(): 77 | all_params.append(aaa) 78 | # for aaa in model.generator.parameters(): 79 | # all_params.append(aaa) 80 | # print (len(all_params), 'number of params') 81 | 82 | print (model.q_dist) 83 | # print (model.q_dist.q) 84 | print (model.generator) 85 | 86 | for i in range(0,i_max+1): 87 | 88 | lr = .001 * 10**(-i/float(i_max)) 89 | print (i, 'LR:', lr) 90 | 91 | optimizer = optim.Adam(all_params, lr=lr) 92 | 93 | epochs = 3**(i) 94 | 95 | for epoch in range(1, epochs + 1): 96 | 97 | for batch_idx, (data, target) in enumerate(train_loader): 98 | 99 | batch = Variable(data)#.type(model.dtype) 100 | 101 | optimizer.zero_grad() 102 | 103 | warmup = total_epochs/warmup_over_epochs 104 | if warmup > 1.: 105 | warmup = 1. 106 | 107 | elbo, logpxz, logqz = model.forward(batch, k=k, warmup=warmup) 108 | 109 | loss = -(elbo) 110 | loss.backward() 111 | optimizer.step() 112 | 113 | total_epochs += 1 114 | 115 | 116 | if total_epochs%display_epoch==0: 117 | print ('Train Epoch: {}/{}'.format(epoch, epochs), 118 | 'total_epochs {}'.format(total_epochs), 119 | 'LL:{:.3f}'.format(-loss.data[0]), 120 | 'logpxz:{:.3f}'.format(logpxz.data[0]), 121 | 'logqz:{:.3f}'.format(logqz.data[0]), 122 | 'warmup:{:.3f}'.format(warmup), 123 | 'T:{:.2f}'.format(time.time()-time_), 124 | ) 125 | time_ = time.time() 126 | 127 | 128 | if total_epochs >= start_at and (total_epochs-start_at)%save_freq==0: 129 | 130 | # save params 131 | save_file = path_to_save_variables+'_encoder_'+str(total_epochs)+'.pt' 132 | torch.save(model.q_dist.state_dict(), save_file) 133 | print ('saved variables ' + save_file) 134 | # save_file = path_to_save_variables+'_generator_'+str(total_epochs)+'.pt' 135 | # torch.save(model.generator.state_dict(), save_file) 136 | # print ('saved variables ' + save_file) 137 | 138 | 139 | 140 | # save params 141 | save_file = path_to_save_variables+'_encoder_'+str(total_epochs)+'.pt' 142 | torch.save(model.q_dist.state_dict(), save_file) 143 | print ('saved variables ' + save_file) 144 | # save_file = path_to_save_variables+'_generator_'+str(total_epochs)+'.pt' 145 | # torch.save(model.generator.state_dict(), save_file) 146 | # print ('saved variables ' + save_file) 147 | 148 | 149 | print ('done training') 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | x_size = 784 162 | z_size = 50 163 | batch_size = 20 164 | k = 1 165 | #save params 166 | start_at = 100 167 | save_freq = 300 168 | display_epoch = 3 169 | 170 | # hyper_config = { 171 | # 'x_size': x_size, 172 | # 'z_size': z_size, 173 | # 'act_func': F.tanh,# F.relu, 174 | # 'encoder_arch': [[x_size,200],[200,200],[200,z_size*2]], 175 | # 'decoder_arch': [[z_size,200],[200,200],[200,x_size]], 176 | # 'q_dist': standard, #FFG_LN#,#hnf,#aux_nf,#flow1,#, 177 | # 'cuda': 1 178 | # } 179 | 180 | 181 | 182 | 183 | # Which gpu 184 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 185 | 186 | 187 | 188 | 189 | 190 | hyper_config = { 191 | 'x_size': x_size, 192 | 'z_size': z_size, 193 | 'act_func': F.tanh,# F.relu, 194 | 'encoder_arch': [[x_size,z_size*2]], 195 | 'decoder_arch': [[z_size,200],[200,200],[200,x_size]], 196 | 'q_dist': standard, #FFG_LN#,#hnf,#aux_nf,#flow1,#, 197 | 'cuda': 1 198 | } 199 | 200 | print ('Init model') 201 | model = VAE(hyper_config) 202 | if torch.cuda.is_available(): 203 | model.cuda() 204 | 205 | print('\nModel:', hyper_config,'\n') 206 | 207 | 208 | 209 | 210 | 211 | # path_to_load_variables='' 212 | 
path_to_save_variables=home+'/Documents/tmp/inference_suboptimality/vae_smallencoder_withflow' #.pt' 213 | # path_to_save_variables=home+'/Documents/tmp/inference_suboptimality/vae_regencoder' #.pt' 214 | # path_to_save_variables=home+'/Documents/tmp/pytorch_vae'+str(epochs)+'.pt' 215 | # path_to_save_variables=this_dir+'/params_'+model_name+'_' 216 | # path_to_save_variables='' 217 | 218 | 219 | # load generator 220 | print ('Load params for decoder') 221 | path_to_load_variables=home+'/Documents/tmp/inference_suboptimality/vae_generator_3280.pt' 222 | model.generator.load_state_dict(torch.load(path_to_load_variables, map_location=lambda storage, loc: storage)) 223 | print ('loaded variables ' + path_to_load_variables) 224 | print () 225 | 226 | 227 | 228 | 229 | 230 | 231 | 232 | 233 | print('\nTraining') 234 | # train_lr_schedule(model=model, train_x=train_x, test_x=test_x, k=k, batch_size=batch_size, 235 | # start_at=start_at, save_freq=save_freq, display_epoch=display_epoch, 236 | # path_to_save_variables=path_to_save_variables) 237 | 238 | 239 | train_encdoer_and_decoder(model=model, train_x=train_x, test_x=test_x, k=k, batch_size=batch_size, 240 | start_at=start_at, save_freq=save_freq, display_epoch=display_epoch, 241 | path_to_save_variables=path_to_save_variables) 242 | 243 | print ('Done.') 244 | 245 | 246 | 247 | 248 | 249 | 250 | 251 | 252 | 253 | 254 | 255 | 256 | 257 | 258 | 259 | 260 | 261 | 262 | 263 | 264 | 265 | 266 | 267 | 268 | 269 | 270 | 271 | 272 | 273 | 274 | 275 | 276 | 277 | 278 | 279 | 280 | 281 | -------------------------------------------------------------------------------- /models/utils/ais3.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | # This one samples the prior distribution 4 | 5 | # Use this to time everytthing 6 | 7 | 8 | import math 9 | import torch 10 | from torch.autograd import Variable 11 | import numpy as np 12 | 13 | from utils import lognormal2 as lognormal 14 | from utils import log_bernoulli 15 | 16 | import time 17 | 18 | 19 | def test_ais(model, data_x, batch_size, display, k, n_intermediate_dists): 20 | 21 | 22 | def intermediate_dist(t, z, mean, logvar, zeros, batch): 23 | # logp1 = lognormal(z, mean, logvar) #[P,B] 24 | log_prior = lognormal(z, zeros, zeros) #[P,B] 25 | log_likelihood = log_bernoulli(model.decode(z), batch) 26 | # logpT = log_prior + log_likelihood 27 | # log_intermediate_2 = (1-float(t))*logp1 + float(t)*logpT 28 | 29 | log_intermediate_2 = log_prior + float(t)*log_likelihood 30 | 31 | return log_intermediate_2 32 | 33 | 34 | def hmc(z, intermediate_dist_func): 35 | 36 | if torch.cuda.is_available(): 37 | v = Variable(torch.FloatTensor(z.size()).normal_(), volatile=volatile_, requires_grad=requires_grad).cuda() 38 | else: 39 | v = Variable(torch.FloatTensor(z.size()).normal_()) 40 | 41 | v0 = v 42 | z0 = z 43 | 44 | # print (intermediate_dist_func(z)) 45 | # fasdf 46 | gradients = torch.autograd.grad(outputs=intermediate_dist_func(z), inputs=z, 47 | grad_outputs=grad_outputs, 48 | create_graph=True, retain_graph=retain_graph, only_inputs=True)[0] 49 | 50 | gradients = gradients.detach() 51 | 52 | v = v + .5 *step_size*gradients 53 | z = z + step_size*v 54 | 55 | for LF_step in range(n_HMC_steps): 56 | 57 | # log_intermediate_2 = intermediate_dist(t1, z, mean, logvar, zeros, batch) 58 | gradients = torch.autograd.grad(outputs=intermediate_dist_func(z), inputs=z, 59 | grad_outputs=grad_outputs, 60 | create_graph=True, retain_graph=retain_graph, only_inputs=True)[0] 61 | 
gradients = gradients.detach() 62 | v = v + step_size*gradients 63 | z = z + step_size*v 64 | 65 | # log_intermediate_2 = intermediate_dist(t1, z, mean, logvar, zeros, batch) 66 | gradients = torch.autograd.grad(outputs=intermediate_dist_func(z), inputs=z, 67 | grad_outputs=grad_outputs, 68 | create_graph=True, retain_graph=retain_graph, only_inputs=True)[0] 69 | gradients = gradients.detach() 70 | v = v + .5 *step_size*gradients 71 | 72 | return z0, v0, z, v 73 | 74 | 75 | def mh_step(z0, v0, z, v, step_size, intermediate_dist_func): 76 | 77 | logpv0 = lognormal(v0, zeros, zeros) #[P,B] 78 | hamil_0 = intermediate_dist_func(z0) + logpv0 79 | 80 | logpvT = lognormal(v, zeros, zeros) #[P,B] 81 | hamil_T = intermediate_dist_func(z) + logpvT 82 | 83 | accept_prob = torch.exp(hamil_T - hamil_0) 84 | 85 | if torch.cuda.is_available(): 86 | rand_uni = Variable(torch.FloatTensor(accept_prob.size()).uniform_(), volatile=volatile_, requires_grad=requires_grad).cuda() 87 | else: 88 | rand_uni = Variable(torch.FloatTensor(accept_prob.size()).uniform_()) 89 | 90 | accept = accept_prob > rand_uni 91 | 92 | if torch.cuda.is_available(): 93 | accept = accept.type(torch.FloatTensor).cuda() 94 | else: 95 | accept = accept.type(torch.FloatTensor) 96 | 97 | accept = accept.view(k, model.B, 1) 98 | 99 | z = (accept * z) + ((1-accept) * z0) 100 | 101 | #Adapt step size 102 | avg_acceptance_rate = torch.mean(accept) 103 | 104 | if avg_acceptance_rate.cpu().data.numpy() > .65: 105 | step_size = 1.02 * step_size 106 | else: 107 | step_size = .98 * step_size 108 | 109 | if step_size < 0.0001: 110 | step_size = 0.0001 111 | if step_size > 0.5: 112 | step_size = 0.5 113 | 114 | return z, step_size 115 | 116 | 117 | 118 | 119 | # n_intermediate_dists = 10 120 | n_HMC_steps = 10 121 | step_size = .1 122 | 123 | retain_graph = False 124 | volatile_ = False 125 | requires_grad = False 126 | 127 | time_ = time.time() 128 | 129 | logws = [] 130 | data_index= 0 131 | for i in range(int(len(data_x)/ batch_size)): 132 | 133 | #AIS 134 | 135 | schedule = np.linspace(0.,1.,n_intermediate_dists) 136 | model.B = batch_size 137 | 138 | batch = data_x[data_index:data_index+batch_size] 139 | data_index += batch_size 140 | 141 | B = int(model.B) 142 | 143 | if torch.cuda.is_available(): 144 | batch = Variable(torch.from_numpy(batch).type(model.dtype), volatile=volatile_, requires_grad=requires_grad).cuda() 145 | zeros = Variable(torch.zeros(B, int(model.z_size)).type(model.dtype), volatile=volatile_, requires_grad=requires_grad).cuda() # [B,Z] 146 | logw = Variable(torch.zeros(k, B).type(model.dtype), volatile=True, requires_grad=requires_grad).cuda() 147 | grad_outputs = torch.ones(k, B).cuda() 148 | else: 149 | batch = Variable(torch.from_numpy(batch)) 150 | zeros = Variable(torch.zeros(model.B, model.z_size)) # [B,Z] 151 | logw = Variable(torch.zeros(k, model.B)) 152 | grad_outputs = torch.ones(k, model.B) 153 | 154 | 155 | # #Encode x 156 | # mean, logvar = model.encode(batch) #[B,Z] 157 | # #Init z 158 | # z, logpz, logqz = model.sample(mean, logvar, k=k) #[P,B,Z], [P,B], [P,B] 159 | 160 | 161 | # z, logqz = model.q_dist.forward(k=k, x=batch, logposterior=model.logposterior) 162 | 163 | 164 | # z = Variable(torch.FloatTensor(k, model.B, model.z_size).normal_().type(model.dtype),requires_grad=True) 165 | 166 | z = Variable(torch.FloatTensor(k, B, model.z_size).normal_().type(model.dtype)) 167 | 168 | time_2 = time.time() 169 | for (t0, t1) in zip(schedule[:-1], schedule[1:]): 170 | 171 | 172 | #logw = logw + logpt-1(zt-1) - 
logpt(zt-1) t, z, mean, logvar, zeros, batch 173 | # log_intermediate_1 = intermediate_dist(t0, z, mean, logvar, zeros, batch) 174 | # log_intermediate_2 = intermediate_dist(t1, z, mean, logvar, zeros, batch) 175 | log_intermediate_1 = intermediate_dist(t=t0, z=z, mean=zeros, logvar=zeros, zeros=zeros, batch=batch) 176 | log_intermediate_2 = intermediate_dist(t=t1, z=z, mean=zeros, logvar=zeros, zeros=zeros, batch=batch) 177 | 178 | 179 | logw += log_intermediate_2 - log_intermediate_1 180 | 181 | # print('ere') 182 | 183 | z = z.data 184 | 185 | z = Variable(z, requires_grad=True) 186 | 187 | 188 | #HMC dynamics 189 | # intermediate_dist_func = lambda aaa: intermediate_dist(t1, aaa, mean, logvar, zeros, batch) 190 | intermediate_dist_func = lambda aaa: intermediate_dist(t1, aaa, zeros, zeros, zeros, batch) 191 | 192 | # print (t1) 193 | time_1 = time.time() 194 | z0, v0, z, v = hmc(z, intermediate_dist_func) 195 | # print (t0, 'time to do hmc', time.time()-time_1) 196 | 197 | 198 | #MH step 199 | z, step_size = mh_step(z0, v0, z, v, step_size, intermediate_dist_func) 200 | 201 | z = z.detach() 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | # print ('time to do whole schedule', time.time()-time_2) 211 | # fasd 212 | #log sum exp 213 | max_ = torch.max(logw,0)[0] #[B] 214 | logw = torch.log(torch.mean(torch.exp(logw - max_), 0)) + max_ #[B] 215 | 216 | logws.append(torch.mean(logw.cpu()).data.numpy()) 217 | 218 | 219 | if i%display==0: 220 | print (i,len(data_x)/ batch_size, np.mean(logws),step_size, time.time()-time_) 221 | 222 | mean_ = np.mean(logws) 223 | print(mean_, 'T:', time.time()-time_) 224 | return mean_ 225 | 226 | 227 | 228 | -------------------------------------------------------------------------------- /models/utils/ais4.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | # This one samples the prior distribution 4 | 5 | # Use this to time everytthing 6 | 7 | 8 | import math 9 | import torch 10 | from torch.autograd import Variable 11 | import numpy as np 12 | 13 | from utils import lognormal2 as lognormal 14 | from utils import log_bernoulli 15 | 16 | import time 17 | 18 | 19 | def test_ais(model, data_x, batch_size, display, k, n_intermediate_dists): 20 | 21 | 22 | def intermediate_dist(t, z, mean, logvar, zeros, batch): 23 | # logp1 = lognormal(z, mean, logvar) #[P,B] 24 | log_prior = lognormal(z, zeros, zeros) #[P,B] 25 | log_likelihood = log_bernoulli(model.generator.decode(z), batch) 26 | # logpT = log_prior + log_likelihood 27 | # log_intermediate_2 = (1-float(t))*logp1 + float(t)*logpT 28 | 29 | log_intermediate_2 = log_prior + float(t)*log_likelihood 30 | 31 | return log_intermediate_2 32 | 33 | 34 | def hmc(z, intermediate_dist_func): 35 | 36 | if torch.cuda.is_available(): 37 | v = Variable(torch.FloatTensor(z.size()).normal_(), volatile=volatile_, requires_grad=requires_grad).cuda() 38 | else: 39 | v = Variable(torch.FloatTensor(z.size()).normal_()) 40 | 41 | v0 = v 42 | z0 = z 43 | 44 | # print (intermediate_dist_func(z)) 45 | # fasdf 46 | gradients = torch.autograd.grad(outputs=intermediate_dist_func(z), inputs=z, 47 | grad_outputs=grad_outputs, 48 | create_graph=True, retain_graph=retain_graph, only_inputs=True)[0] 49 | 50 | gradients = gradients.detach() 51 | 52 | v = v + .5 *step_size*gradients 53 | z = z + step_size*v 54 | 55 | for LF_step in range(n_HMC_steps): 56 | 57 | # log_intermediate_2 = intermediate_dist(t1, z, mean, logvar, zeros, batch) 58 | gradients = 
torch.autograd.grad(outputs=intermediate_dist_func(z), inputs=z, 59 | grad_outputs=grad_outputs, 60 | create_graph=True, retain_graph=retain_graph, only_inputs=True)[0] 61 | gradients = gradients.detach() 62 | v = v + step_size*gradients 63 | z = z + step_size*v 64 | 65 | # log_intermediate_2 = intermediate_dist(t1, z, mean, logvar, zeros, batch) 66 | gradients = torch.autograd.grad(outputs=intermediate_dist_func(z), inputs=z, 67 | grad_outputs=grad_outputs, 68 | create_graph=True, retain_graph=retain_graph, only_inputs=True)[0] 69 | gradients = gradients.detach() 70 | v = v + .5 *step_size*gradients 71 | 72 | return z0, v0, z, v 73 | 74 | 75 | def mh_step(z0, v0, z, v, step_size, intermediate_dist_func): 76 | 77 | logpv0 = lognormal(v0, zeros, zeros) #[P,B] 78 | hamil_0 = intermediate_dist_func(z0) + logpv0 79 | 80 | logpvT = lognormal(v, zeros, zeros) #[P,B] 81 | hamil_T = intermediate_dist_func(z) + logpvT 82 | 83 | accept_prob = torch.exp(hamil_T - hamil_0) 84 | 85 | if torch.cuda.is_available(): 86 | rand_uni = Variable(torch.FloatTensor(accept_prob.size()).uniform_(), volatile=volatile_, requires_grad=requires_grad).cuda() 87 | else: 88 | rand_uni = Variable(torch.FloatTensor(accept_prob.size()).uniform_()) 89 | 90 | accept = accept_prob > rand_uni 91 | 92 | if torch.cuda.is_available(): 93 | accept = accept.type(torch.FloatTensor).cuda() 94 | else: 95 | accept = accept.type(torch.FloatTensor) 96 | 97 | accept = accept.view(k, int(model.B), 1) 98 | 99 | z = (accept * z) + ((1-accept) * z0) 100 | 101 | #Adapt step size 102 | avg_acceptance_rate = torch.mean(accept) 103 | 104 | if avg_acceptance_rate.cpu().data.numpy() > .65: 105 | step_size = 1.02 * step_size 106 | else: 107 | step_size = .98 * step_size 108 | 109 | if step_size < 0.0001: 110 | step_size = 0.0001 111 | if step_size > 0.5: 112 | step_size = 0.5 113 | 114 | return z, step_size 115 | 116 | 117 | 118 | 119 | # n_intermediate_dists = 10 120 | n_HMC_steps = 10 121 | step_size = .1 122 | 123 | retain_graph = False 124 | volatile_ = False 125 | requires_grad = False 126 | 127 | time_ = time.time() 128 | 129 | logws = [] 130 | data_index= 0 131 | for i in range(int(len(data_x)/ batch_size)): 132 | 133 | #AIS 134 | 135 | schedule = np.linspace(0.,1.,n_intermediate_dists) 136 | model.B = batch_size 137 | 138 | batch = data_x[data_index:data_index+batch_size] 139 | data_index += batch_size 140 | 141 | B = int(model.B) 142 | 143 | if torch.cuda.is_available(): 144 | batch = Variable(torch.from_numpy(batch).type(model.dtype), volatile=volatile_, requires_grad=requires_grad).cuda() 145 | zeros = Variable(torch.zeros(B, int(model.z_size)).type(model.dtype), volatile=volatile_, requires_grad=requires_grad).cuda() # [B,Z] 146 | logw = Variable(torch.zeros(k, B).type(model.dtype), volatile=True, requires_grad=requires_grad).cuda() 147 | grad_outputs = torch.ones(k, B).cuda() 148 | else: 149 | batch = Variable(torch.from_numpy(batch)) 150 | zeros = Variable(torch.zeros(model.B, model.z_size)) # [B,Z] 151 | logw = Variable(torch.zeros(k, model.B)) 152 | grad_outputs = torch.ones(k, model.B) 153 | 154 | 155 | # #Encode x 156 | # mean, logvar = model.encode(batch) #[B,Z] 157 | # #Init z 158 | # z, logpz, logqz = model.sample(mean, logvar, k=k) #[P,B,Z], [P,B], [P,B] 159 | 160 | 161 | # z, logqz = model.q_dist.forward(k=k, x=batch, logposterior=model.logposterior) 162 | 163 | 164 | # z = Variable(torch.FloatTensor(k, model.B, model.z_size).normal_().type(model.dtype),requires_grad=True) 165 | 166 | z = Variable(torch.FloatTensor(k, B, 
model.z_size).normal_().type(model.dtype)) 167 | 168 | time_2 = time.time() 169 | for (t0, t1) in zip(schedule[:-1], schedule[1:]): 170 | 171 | 172 | #logw = logw + logpt-1(zt-1) - logpt(zt-1) t, z, mean, logvar, zeros, batch 173 | # log_intermediate_1 = intermediate_dist(t0, z, mean, logvar, zeros, batch) 174 | # log_intermediate_2 = intermediate_dist(t1, z, mean, logvar, zeros, batch) 175 | log_intermediate_1 = intermediate_dist(t=t0, z=z, mean=zeros, logvar=zeros, zeros=zeros, batch=batch) 176 | log_intermediate_2 = intermediate_dist(t=t1, z=z, mean=zeros, logvar=zeros, zeros=zeros, batch=batch) 177 | 178 | 179 | logw += log_intermediate_2 - log_intermediate_1 180 | 181 | # print('ere') 182 | 183 | z = z.data 184 | 185 | z = Variable(z, requires_grad=True) 186 | 187 | 188 | #HMC dynamics 189 | # intermediate_dist_func = lambda aaa: intermediate_dist(t1, aaa, mean, logvar, zeros, batch) 190 | intermediate_dist_func = lambda aaa: intermediate_dist(t1, aaa, zeros, zeros, zeros, batch) 191 | 192 | # print (t1) 193 | time_1 = time.time() 194 | z0, v0, z, v = hmc(z, intermediate_dist_func) 195 | # print (t0, 'time to do hmc', time.time()-time_1) 196 | 197 | 198 | #MH step 199 | z, step_size = mh_step(z0, v0, z, v, step_size, intermediate_dist_func) 200 | 201 | z = z.detach() 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | # print ('time to do whole schedule', time.time()-time_2) 211 | # fasd 212 | #log sum exp 213 | max_ = torch.max(logw,0)[0] #[B] 214 | logw = torch.log(torch.mean(torch.exp(logw - max_), 0)) + max_ #[B] 215 | 216 | logws.append(torch.mean(logw.cpu()).data.numpy()) 217 | 218 | 219 | if i%display==0: 220 | print (i,len(data_x)/ batch_size, np.mean(logws),step_size, time.time()-time_) 221 | 222 | mean_ = np.mean(logws) 223 | print(mean_, 'T:', time.time()-time_) 224 | return mean_ 225 | 226 | 227 | 228 | -------------------------------------------------------------------------------- /flow_effect_on_amort_exp/train_encoder_only2.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | import numpy as np 5 | import gzip 6 | import time 7 | import pickle 8 | 9 | from os.path import expanduser 10 | home = expanduser("~") 11 | 12 | import sys, os 13 | sys.path.insert(0, '../models') 14 | sys.path.insert(0, '../models/utils') 15 | 16 | 17 | import matplotlib 18 | matplotlib.use('Agg') 19 | import matplotlib.pyplot as plt 20 | 21 | 22 | import torch 23 | from torch.autograd import Variable 24 | import torch.utils.data 25 | import torch.optim as optim 26 | import torch.nn as nn 27 | import torch.nn.functional as F 28 | 29 | from vae_2 import VAE 30 | 31 | from inference_net import standard 32 | 33 | from distributions import Gaussian 34 | from distributions import Flow1 35 | 36 | 37 | 38 | #Load data 39 | print ('Loading data' ) 40 | data_location = home + '/Documents/MNIST_data/' 41 | # with open(data_location + 'binarized_mnist.pkl', 'rb') as f: 42 | # train_x, valid_x, test_x = pickle.load(f) 43 | with open(data_location + 'binarized_mnist.pkl', 'rb') as f: 44 | train_x, valid_x, test_x = pickle.load(f, encoding='latin1') 45 | print ('Train', train_x.shape) 46 | print ('Valid', valid_x.shape) 47 | print ('Test', test_x.shape) 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | def train_encdoer_and_decoder(model, train_x, test_x, k, batch_size, 59 | start_at, save_freq, display_epoch, 60 | path_to_save_variables): 61 | 62 | train_y = torch.from_numpy(np.zeros(len(train_x))) 63 | train_x = 
torch.from_numpy(train_x).float().type(model.dtype) 64 | 65 | train_ = torch.utils.data.TensorDataset(train_x, train_y) 66 | train_loader = torch.utils.data.DataLoader(train_, batch_size=batch_size, shuffle=True) 67 | 68 | #IWAE paper training strategy 69 | time_ = time.time() 70 | total_epochs = 0 71 | 72 | i_max = 7 73 | 74 | warmup_over_epochs = 100. 75 | 76 | 77 | all_params = [] 78 | for aaa in model.q_dist.parameters(): 79 | all_params.append(aaa) 80 | # for aaa in model.generator.parameters(): 81 | # all_params.append(aaa) 82 | # print (len(all_params), 'number of params') 83 | 84 | print (model.q_dist) 85 | # print (model.q_dist.q) 86 | print (model.generator) 87 | 88 | for i in range(0,i_max+1): 89 | 90 | lr = .001 * 10**(-i/float(i_max)) 91 | print (i, 'LR:', lr) 92 | 93 | optimizer = optim.Adam(all_params, lr=lr) 94 | 95 | epochs = 3**(i) 96 | 97 | for epoch in range(1, epochs + 1): 98 | 99 | for batch_idx, (data, target) in enumerate(train_loader): 100 | 101 | batch = Variable(data)#.type(model.dtype) 102 | 103 | optimizer.zero_grad() 104 | 105 | warmup = total_epochs/warmup_over_epochs 106 | if warmup > 1.: 107 | warmup = 1. 108 | 109 | elbo, logpxz, logqz = model.forward(batch, k=k, warmup=warmup) 110 | 111 | loss = -(elbo) 112 | loss.backward() 113 | optimizer.step() 114 | 115 | total_epochs += 1 116 | 117 | 118 | if total_epochs%display_epoch==0: 119 | print ('Train Epoch: {}/{}'.format(epoch, epochs), 120 | 'total_epochs {}'.format(total_epochs), 121 | 'LL:{:.3f}'.format(-loss.data[0]), 122 | 'logpxz:{:.3f}'.format(logpxz.data[0]), 123 | 'logqz:{:.3f}'.format(logqz.data[0]), 124 | 'warmup:{:.3f}'.format(warmup), 125 | 'T:{:.2f}'.format(time.time()-time_), 126 | ) 127 | time_ = time.time() 128 | 129 | 130 | if total_epochs >= start_at and (total_epochs-start_at)%save_freq==0: 131 | 132 | # save params 133 | save_file = path_to_save_variables+'_encoder_'+str(total_epochs)+'.pt' 134 | torch.save(model.q_dist.state_dict(), save_file) 135 | print ('saved variables ' + save_file) 136 | # save_file = path_to_save_variables+'_generator_'+str(total_epochs)+'.pt' 137 | # torch.save(model.generator.state_dict(), save_file) 138 | # print ('saved variables ' + save_file) 139 | 140 | 141 | 142 | # save params 143 | save_file = path_to_save_variables+'_encoder_'+str(total_epochs)+'.pt' 144 | torch.save(model.q_dist.state_dict(), save_file) 145 | print ('saved variables ' + save_file) 146 | # save_file = path_to_save_variables+'_generator_'+str(total_epochs)+'.pt' 147 | # torch.save(model.generator.state_dict(), save_file) 148 | # print ('saved variables ' + save_file) 149 | 150 | 151 | print ('done training') 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | x_size = 784 164 | z_size = 50 165 | batch_size = 20 166 | k = 1 167 | #save params 168 | start_at = 100 169 | save_freq = 300 170 | display_epoch = 3 171 | 172 | # hyper_config = { 173 | # 'x_size': x_size, 174 | # 'z_size': z_size, 175 | # 'act_func': F.tanh,# F.relu, 176 | # 'encoder_arch': [[x_size,200],[200,200],[200,z_size*2]], 177 | # 'decoder_arch': [[z_size,200],[200,200],[200,x_size]], 178 | # 'q_dist': standard, #FFG_LN#,#hnf,#aux_nf,#flow1,#, 179 | # 'cuda': 1 180 | # } 181 | 182 | 183 | 184 | 185 | # Which gpu 186 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 187 | 188 | 189 | 190 | 191 | 192 | # hyper_config = { 193 | # 'x_size': x_size, 194 | # 'z_size': z_size, 195 | # 'act_func': F.tanh,# F.relu, 196 | # 'encoder_arch': [[x_size,z_size*2]], 197 | # 'decoder_arch': 
[[z_size,200],[200,200],[200,x_size]], 198 | # 'q_dist': standard, #FFG_LN#,#hnf,#aux_nf,#flow1,#, 199 | # 'cuda': 1 200 | # } 201 | 202 | 203 | #flow1 204 | hyper_config = { 205 | 'x_size': x_size, 206 | 'z_size': z_size, 207 | 'act_func': F.tanh, #F.elu, #,# F.relu, 208 | 'encoder_arch': [[x_size,z_size*2]], 209 | 'decoder_arch': [[z_size,200],[200,200],[200,x_size]], 210 | 'q_dist': standard, #FFG_LN#,#hnf,#aux_nf,#flow1,#, 211 | 'cuda': 1, 212 | 'hnf': 0 213 | } 214 | 215 | hyper_config['q'] = Flow1(hyper_config) 216 | # hyper_config['q'] = Gaussian(hyper_config) 217 | 218 | 219 | print ('Init model') 220 | model = VAE(hyper_config) 221 | if torch.cuda.is_available(): 222 | model.cuda() 223 | 224 | print('\nModel:', hyper_config,'\n') 225 | 226 | 227 | 228 | 229 | 230 | # path_to_load_variables='' 231 | path_to_save_variables=home+'/Documents/tmp/inference_suboptimality/vae_smallencoder_withflow1' #.pt' 232 | # path_to_save_variables=home+'/Documents/tmp/inference_suboptimality/vae_regencoder' #.pt' 233 | # path_to_save_variables=home+'/Documents/tmp/pytorch_vae'+str(epochs)+'.pt' 234 | # path_to_save_variables=this_dir+'/params_'+model_name+'_' 235 | # path_to_save_variables='' 236 | 237 | 238 | # load generator 239 | print ('Load params for decoder') 240 | path_to_load_variables=home+'/Documents/tmp/inference_suboptimality/vae_generator_3280.pt' 241 | model.generator.load_state_dict(torch.load(path_to_load_variables, map_location=lambda storage, loc: storage)) 242 | print ('loaded variables ' + path_to_load_variables) 243 | print () 244 | 245 | 246 | 247 | 248 | 249 | 250 | 251 | 252 | print('\nTraining') 253 | # train_lr_schedule(model=model, train_x=train_x, test_x=test_x, k=k, batch_size=batch_size, 254 | # start_at=start_at, save_freq=save_freq, display_epoch=display_epoch, 255 | # path_to_save_variables=path_to_save_variables) 256 | 257 | 258 | train_encdoer_and_decoder(model=model, train_x=train_x, test_x=test_x, k=k, batch_size=batch_size, 259 | start_at=start_at, save_freq=save_freq, display_epoch=display_epoch, 260 | path_to_save_variables=path_to_save_variables) 261 | 262 | print ('Done.') 263 | 264 | 265 | 266 | 267 | 268 | 269 | 270 | 271 | 272 | 273 | 274 | 275 | 276 | 277 | 278 | 279 | 280 | 281 | 282 | 283 | 284 | 285 | 286 | 287 | 288 | 289 | 290 | 291 | 292 | 293 | 294 | 295 | 296 | 297 | 298 | 299 | 300 | -------------------------------------------------------------------------------- /gaps_over_training_exp/plot_over_training.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | import time 6 | import numpy as np 7 | import pickle 8 | from os.path import expanduser 9 | home = expanduser("~") 10 | 11 | import matplotlib 12 | matplotlib.use('agg') 13 | import matplotlib.pyplot as plt 14 | 15 | import csv 16 | 17 | # epochs=['100','1000','1900','2800'] 18 | # epochs=['100','1000','2200'] 19 | # epochs=['100','2800'] 20 | epochs = [] 21 | bounds = ['logpx', 'L_q_star', 'L_q'] 22 | 23 | 24 | values = {} 25 | values['training'] = {} 26 | values['validation'] = {} 27 | # for epoch in epochs: 28 | # values['training'][epoch] = {} 29 | # values['validation'][epoch] = {} 30 | # for bound in bounds: 31 | # for epoch in epochs: 32 | # values['training'][epoch][bound] = {} 33 | # values['validation'][epoch][bound] = {} 34 | 35 | 36 | 37 | #read values 38 | # results_file = 'results_50' 39 | # results_file = 'results_2_fashion' 40 | # results_file = 'results_10_fashion' 41 | # results_file = 'results_100_fashion' 42 | 
results_file = 'results_20_fashion_binarized_2' 43 | 44 | file_ = home+'/Documents/tmp/inference_suboptimality/over_training_exps/'+results_file+'.txt' 45 | 46 | max_value = None 47 | min_value = None 48 | 49 | with open(file_, 'r') as f: 50 | reader = csv.reader(f, delimiter=' ') 51 | for row in reader: 52 | if len(row) and row[0] in ['training','validation']: 53 | # print (row) 54 | dataset = row[0] 55 | epoch = row[1] 56 | bound = row[2] 57 | value = row[3] 58 | 59 | if epoch not in values[dataset]: 60 | values[dataset][epoch] = {} 61 | if epoch not in epochs: 62 | epochs.append(epoch) 63 | print (epoch) 64 | 65 | values[dataset][epoch][bound] = value 66 | 67 | if max_value == None or float(value) > max_value: 68 | max_value = float(value) 69 | if min_value == None or float(value) < min_value: 70 | min_value = float(value) 71 | 72 | # print (values) 73 | 74 | # #sort epochs 75 | # epochs.sort() 76 | 77 | # print (epochs) 78 | # fads 79 | 80 | #convert to list 81 | training_plot = {} 82 | for bound in bounds: 83 | values_to_plot = [] 84 | for epoch in epochs: 85 | values_to_plot.append(float(values['training'][epoch][bound])) 86 | training_plot[bound] = values_to_plot 87 | print (training_plot) 88 | 89 | 90 | validation_plot = {} 91 | for bound in bounds: 92 | values_to_plot = [] 93 | for epoch in epochs: 94 | values_to_plot.append(float(values['validation'][epoch][bound])) 95 | validation_plot[bound] = values_to_plot 96 | print (validation_plot) 97 | 98 | 99 | epochs_float = [float(x) for x in epochs] 100 | 101 | 102 | rows = 1 103 | cols = 2 104 | 105 | legend=False 106 | 107 | fig = plt.figure(figsize=(8+cols,2+rows), facecolor='white') 108 | 109 | # ylimits = [-110, -84] 110 | ylimits = [min_value, max_value] 111 | 112 | 113 | 114 | 115 | # Training set 116 | ax = plt.subplot2grid((rows,cols), (0,0), frameon=False) 117 | 118 | ax.set_title('Training Set',family='serif') 119 | 120 | # for bound in bounds: 121 | # ax.plot(epochs_float,training_plot[bound]) #, label=legends[i], c=colors[i], ls=line_styles[i]) 122 | 123 | ax.fill_between(epochs_float, training_plot['logpx'], training_plot['L_q_star']) 124 | ax.fill_between(epochs_float, training_plot['L_q_star'], training_plot['L_q']) 125 | 126 | ax.set_ylim(ylimits) 127 | ax.grid(True, alpha=.1) 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | # Validation set 137 | ax = plt.subplot2grid((rows,cols), (0,1), frameon=False) 138 | 139 | ax.set_title('Validation Set',family='serif') 140 | 141 | # for bound in bounds: 142 | # ax.plot(epochs_float,validation_plot[bound]) #, label=legends[i], c=colors[i], ls=line_styles[i]) 143 | 144 | ax.grid(True, alpha=.1) 145 | 146 | ax.fill_between(epochs_float, validation_plot['logpx'], validation_plot['L_q_star']) 147 | ax.fill_between(epochs_float, validation_plot['L_q_star'], validation_plot['L_q']) 148 | 149 | ax.set_ylim(ylimits) 150 | 151 | 152 | 153 | # ax.set_yticks() 154 | 155 | # family='serif' 156 | # fontProperties = {'family':'serif'} 157 | # ax.set_xticklabels(ax.get_xticks(), fontProperties) 158 | 159 | 160 | 161 | 162 | 163 | 164 | name_file = home+'/Documents/tmp/inference_suboptimality/over_training_exps/'+results_file+'.png' 165 | name_file2 = home+'/Documents/tmp/inference_suboptimality/over_training_exps/'+results_file+'.pdf' 166 | # name_file = home+'/Documents/tmp/plot.png' 167 | plt.savefig(name_file) 168 | plt.savefig(name_file2) 169 | print ('Saved fig', name_file) 170 | print ('Saved fig', name_file2) 171 | 172 | 173 | 174 | print ('DONE') 175 | fdsa 176 | 177 | 178 | 179 
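# Note: `fdsa` above is an undefined name, so the script raises a NameError and
# stops at that point -- presumably left in deliberately as a stopper, since
# everything below it is leftover plotting code that references variables
# (`models`, `legends`, `x`, `model_names`, ...) never defined in this file.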
| 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | # # models = [standard,flow1,aux_nf]#,hnf] 210 | # # models = [standard,standard_large_encoder]#, aux_nf aux_large_encoder]#,hnf] 211 | # models = [standard,aux_nf]#, aux_nf aux_large_encoder]#,hnf] 212 | 213 | 214 | # # model_names = ['standard','flow1','aux_nf','hnf'] 215 | # # model_names = ['VAE','NF','Aux+NF']#,'HNF'] 216 | # # model_names = ['FFG','Flow']#,'HNF'] 217 | # # model_names = ['FFG','Flow']#,'HNF'] 218 | # model_names = ['FFG','Flow']# 'aux_nf','aux_large_encoder']#,'HNF'] 219 | 220 | 221 | 222 | 223 | 224 | # # legends = ['IW train', 'IW test', 'AIS train', 'AIS test'] 225 | # # legends = ['VAE train', 'VAE test', 'IW train', 'IW test', 'AIS train', 'AIS test'] 226 | 227 | # legends = ['VAE train', 'VAE test', 'IW train', 'IW test', 'AIS train', 'AIS test'] 228 | 229 | 230 | # colors = ['blue', 'blue', 'green', 'green', 'red', 'red'] 231 | 232 | # line_styles = [':', '-', ':', '-', ':', '-'] 233 | 234 | 235 | 236 | 237 | rows = 1 238 | cols = 1 239 | 240 | legend=False 241 | 242 | fig = plt.figure(figsize=(2+cols,2+rows), facecolor='white') 243 | 244 | # Get y-axis limits 245 | min_ = None 246 | max_ = None 247 | for m in range(len(models)): 248 | for i in range(len(legends)): 249 | if i == 1: 250 | continue 251 | this_min = np.min(models[m][i]) 252 | this_max = np.max(models[m][i]) 253 | if min_ ==None or this_min < min_: 254 | min_ = this_min 255 | if max_ ==None or this_max > max_: 256 | max_ = this_max 257 | 258 | min_ -= .1 259 | max_ += .1 260 | # print (min_) 261 | # print (max_) 262 | ylimits = [min_, max_] 263 | xlimits = [x[0], x[-1]] 264 | 265 | # fasd 266 | 267 | # ax.plot(x,hnf_ais, label='hnf_ais') 268 | # ax.set_yticks([]) 269 | # ax.set_xticks([]) 270 | # if samp_i==0: ax.annotate('Sample', xytext=(.3, 1.1), xy=(0, 1), textcoords='axes fraction') 271 | 272 | for m in range(len(models)): 273 | ax = plt.subplot2grid((rows,cols), (0,m), frameon=False) 274 | for i in range(len(legends)): 275 | if i == 1: 276 | continue 277 | ax.set_title(model_names[m],family='serif') 278 | ax.plot(x,models[m][i], label=legends[i], c=colors[i], ls=line_styles[i]) 279 | plt.legend(fontsize=6) 280 | # ax.set(adjustable='box-forced', aspect='equal') 281 | plt.yticks(size=6) 282 | # plt.xticks(x,size=6) 283 | plt.xticks([400,1300,2200,3100],size=6) 284 | 285 | # ax.set_xlim(xlimits) 286 | ax.set_ylim(ylimits) 287 | ax.set_xlim(xlimits) 288 | 289 | ax.set_xlabel('Epochs',size=6) 290 | if m==0: 291 | ax.set_ylabel('Log-Likelihood',size=6) 292 | 293 | 294 | ax.grid(True, alpha=.1) 295 | 296 | 297 | # m+=1 298 | # ax = plt.subplot2grid((rows,cols), (0,m), frameon=False) 299 | # ax.set_title('AIS_test') 300 | # for m in range(len(models)): 301 | # ax.plot(x,models[m][3], label=model_names[m]) 302 | # plt.legend(fontsize=4) 303 | # plt.yticks(size=6) 304 | 305 | 306 | 307 | 308 | 309 | # plt.gca().set_aspect('equal', adjustable='box') 310 | name_file = home+'/Documents/tmp/plot.png' 311 | plt.savefig(name_file) 312 | print ('Saved fig', name_file) 313 | 314 | name_file = home+'/Documents/tmp/plot.eps' 315 | plt.savefig(name_file) 316 | print ('Saved fig', name_file) 317 | 318 | 319 | name_file = home+'/Documents/tmp/plot.pdf' 320 | plt.savefig(name_file) 321 | print ('Saved fig', name_file) 322 | 323 | 324 | 325 | 326 | 327 | 328 | 329 | 330 | 331 | 332 | 
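# The two fill_between bands drawn above correspond to the gap decomposition
# from the paper: the upper band (log p(x) down to L[q*]) is the approximation
# gap, and the lower band (L[q*] down to L[q]) is the amortization gap.
# A minimal sketch of that decomposition, using a hypothetical helper that is
# not part of this script (its inputs are the three per-checkpoint bounds):

def decompose_gaps(logpx, L_q_star, L_q):
    """Split the total inference gap into its two components."""
    approximation_gap = logpx - L_q_star   # upper shaded band
    amortization_gap = L_q_star - L_q      # lower shaded band
    inference_gap = logpx - L_q            # sum of the two
    return approximation_gap, amortization_gap, inference_gap

# Example with made-up values: decompose_gaps(-89.0, -90.5, -92.0) -> (1.5, 1.5, 3.0)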
-------------------------------------------------------------------------------- /flow_effect_on_amort_exp/compute_gaps.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | import numpy as np 7 | import gzip 8 | import time 9 | import pickle 10 | 11 | from os.path import expanduser 12 | home = expanduser("~") 13 | 14 | import sys, os 15 | sys.path.insert(0, '../models') 16 | sys.path.insert(0, '../models/utils') 17 | 18 | 19 | import matplotlib 20 | matplotlib.use('Agg') 21 | import matplotlib.pyplot as plt 22 | 23 | 24 | import torch 25 | from torch.autograd import Variable 26 | import torch.utils.data 27 | import torch.optim as optim 28 | import torch.nn as nn 29 | import torch.nn.functional as F 30 | 31 | 32 | from vae_2 import VAE 33 | 34 | 35 | # from approx_posteriors_v6 import standard 36 | from inference_net import standard 37 | 38 | # from ais3 import test_ais 39 | 40 | 41 | # from optimize_local import optimize_local_gaussian 42 | 43 | 44 | from optimize_local_q import optimize_local_q_dist 45 | 46 | 47 | 48 | 49 | from distributions import Gaussian 50 | from distributions import Flow 51 | from distributions import Flow1 52 | 53 | 54 | 55 | 56 | 57 | def test_vae(model, data_x, batch_size, display, k): 58 | 59 | time_ = time.time() 60 | elbos = [] 61 | data_index= 0 62 | for i in range(int(len(data_x)/ batch_size)): 63 | 64 | batch = data_x[data_index:data_index+batch_size] 65 | data_index += batch_size 66 | 67 | batch = Variable(torch.from_numpy(batch)).type(model.dtype) 68 | 69 | elbo, logpxz, logqz = model.forward2(batch, k=k) 70 | 71 | elbos.append(elbo.data[0]) 72 | 73 | # if i%display==0: 74 | # print (i,len(data_x)/ batch_size, np.mean(elbos)) 75 | 76 | mean_ = np.mean(elbos) 77 | # print(mean_, 'T:', time.time()-time_) 78 | 79 | return mean_#, time.time()-time_ 80 | 81 | 82 | 83 | 84 | 85 | def test(model, data_x, batch_size, display, k): 86 | 87 | time_ = time.time() 88 | elbos = [] 89 | data_index= 0 90 | for i in range(int(len(data_x)/ batch_size)): 91 | 92 | batch = data_x[data_index:data_index+batch_size] 93 | data_index += batch_size 94 | 95 | batch = Variable(torch.from_numpy(batch)).type(model.dtype) 96 | 97 | elbo, logpxz, logqz = model(batch, k=k) 98 | 99 | elbos.append(elbo.data[0]) 100 | 101 | # if i%display==0: 102 | # print (i,len(data_x)/ batch_size, np.mean(elbos)) 103 | 104 | mean_ = np.mean(elbos) 105 | # print(mean_, 'T:', time.time()-time_) 106 | 107 | return mean_#, time.time()-time_ 108 | 109 | 110 | 111 | 112 | 113 | ########################### 114 | # Load data 115 | 116 | # print ('Loading data') 117 | # with open(home+'/Documents/MNIST_data/mnist.pkl','rb') as f: 118 | # mnist_data = pickle.load(f, encoding='latin1') 119 | # train_x = mnist_data[0][0] 120 | # valid_x = mnist_data[1][0] 121 | # test_x = mnist_data[2][0] 122 | # train_x = np.concatenate([train_x, valid_x], axis=0) 123 | # print (train_x.shape) 124 | 125 | #Load data 126 | print ('Loading data' ) 127 | data_location = home + '/Documents/MNIST_data/' 128 | # with open(data_location + 'binarized_mnist.pkl', 'rb') as f: 129 | # train_x, valid_x, test_x = pickle.load(f) 130 | with open(data_location + 'binarized_mnist.pkl', 'rb') as f: 131 | train_x, valid_x, test_x = pickle.load(f, encoding='latin1') 132 | print ('Train', train_x.shape) 133 | print ('Valid', valid_x.shape) 134 | print ('Test', test_x.shape) 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | ########################### 147 | # Load model 148 | 
149 | 150 | 151 | # this_ckt_file = path_to_save_variables + str(ckt) + '.pt' 152 | # model.load_params(path_to_load_variables=this_ckt_file) 153 | # print ('Init model') 154 | # model = VAE(hyper_config) 155 | # if torch.cuda.is_available(): 156 | # model.cuda() 157 | 158 | # print('\nModel:', hyper_config,'\n') 159 | 160 | 161 | x_size = 784 162 | z_size = 50 163 | # batch_size = 20 164 | # k = 1 165 | #save params 166 | # start_at = 100 167 | # save_freq = 300 168 | # display_epoch = 3 169 | 170 | # hyper_config = { 171 | # 'x_size': x_size, 172 | # 'z_size': z_size, 173 | # 'act_func': F.tanh,# F.relu, 174 | # 'encoder_arch': [[x_size,200],[200,200],[200,z_size*2]], 175 | # 'decoder_arch': [[z_size,200],[200,200],[200,x_size]], 176 | # 'q_dist': standard, #FFG_LN#,#hnf,#aux_nf,#flow1,#, 177 | # 'cuda': 1 178 | # } 179 | 180 | 181 | hyper_config = { 182 | 'x_size': x_size, 183 | 'z_size': z_size, 184 | 'act_func': F.tanh,# F.relu, 185 | 'encoder_arch': [[x_size,z_size*2]], 186 | 'decoder_arch': [[z_size,200],[200,200],[200,x_size]], 187 | 'q_dist': standard, #FFG_LN#,#hnf,#aux_nf,#flow1,#, 188 | 'cuda': 1, 189 | 'hnf':0 190 | } 191 | 192 | 193 | # q = Gaussian(hyper_config) 194 | # q = Flow(hyper_config) 195 | q = Flow1(hyper_config) 196 | hyper_config['q'] = q 197 | 198 | 199 | 200 | 201 | # Which gpu 202 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 203 | 204 | print ('Init model') 205 | model = VAE(hyper_config) 206 | if torch.cuda.is_available(): 207 | model.cuda() 208 | print('\nModel:', hyper_config,'\n') 209 | 210 | 211 | 212 | print ('Load params for decoder') 213 | path_to_load_variables=home+'/Documents/tmp/inference_suboptimality/vae_generator_3280.pt' 214 | model.generator.load_state_dict(torch.load(path_to_load_variables, map_location=lambda storage, loc: storage)) 215 | print ('loaded variables ' + path_to_load_variables) 216 | print () 217 | 218 | 219 | 220 | compute_local_opt = 0 221 | compute_amort = 1 222 | 223 | 224 | if compute_amort: 225 | 226 | print ('Load params for encoder') 227 | # path_to_load_variables=home+'/Documents/tmp/inference_suboptimality/vae_encoder_100.pt' 228 | # path_to_load_variables=home+'/Documents/tmp/inference_suboptimality/vae_encoder_3280.pt' 229 | # path_to_load_variables=home+'/Documents/tmp/inference_suboptimality/vae_smallencoder_encoder_3280.pt' 230 | # path_to_load_variables=home+'/Documents/tmp/inference_suboptimality/vae_regencoder_encoder_3280.pt' 231 | # path_to_load_variables=home+'/Documents/tmp/inference_suboptimality/vae_smallencoder_withflow_encoder_3280.pt' 232 | # path_to_load_variables=home+'/Documents/tmp/inference_suboptimality/vae_smallencoder_withflow1_encoder_1000.pt' 233 | path_to_load_variables=home+'/Documents/tmp/inference_suboptimality/vae_smallencoder_withflow1_encoder_3280.pt' 234 | 235 | 236 | model.q_dist.load_state_dict(torch.load(path_to_load_variables, map_location=lambda storage, loc: storage)) 237 | print ('loaded variables ' + path_to_load_variables) 238 | 239 | 240 | 241 | 242 | 243 | 244 | ########################### 245 | # For each datapoint, compute L[q], L[q*], log p(x) 246 | 247 | # # log it 248 | # with open(experiment_log, "a") as myfile: 249 | # myfile.write('Checkpoint' +str(ckt)+'\n') 250 | 251 | # start_time = time.time() 252 | 253 | # n_data = 1000 #2 #100 254 | n_data = 500 #2 #100 255 | 256 | 257 | vaes = [] 258 | iwaes = [] 259 | vaes_flex = [] 260 | iwaes_flex = [] 261 | 262 | 263 | 264 | if compute_local_opt: 265 | print ('optmizing local') 266 | # for i in range(len(train_x[:n_data])): 
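# The loop below estimates L[q*] per datapoint: a fresh variational
# distribution (here a Flow1) is instantiated for each example and its
# parameters are fit to that single example via optimize_local_q_dist, with
# the decoder held fixed. The compute_amort branch further down gives the
# other two quantities: test_vae presumably evaluates the standard ELBO L[q]
# (model.forward2), while test evaluates the k=5000 importance-weighted bound,
# which stands in for the log p(x) estimate here (the AIS call is commented out).
# A rough, hypothetical sketch of one local-optimization step -- the real
# routine lives in models/optimize_local_q.py, and the method names below are
# assumptions rather than the repo's actual API:
#
#   local_opt = optim.Adam(q_local.parameters(), lr=1e-3)
#   for step in range(n_steps):                  # n_steps: assumed
#       z, logqz = q_local.sample(x, k=1)        # sample + log-density: assumed
#       elbo = (logposterior(z) - logqz).mean()
#       local_opt.zero_grad(); (-elbo).backward(); local_opt.step()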
267 | for i in range(n_data): 268 | 269 | print (i+500) 270 | 271 | x = train_x[i+500] 272 | x = Variable(torch.from_numpy(x)).type(model.dtype) 273 | x = x.view(1,784) 274 | 275 | logposterior = lambda aa: model.logposterior_func2(x=x,z=aa) 276 | 277 | 278 | # # flex_model = aux_nf__(model, hyper_config) 279 | # # if torch.cuda.is_available(): 280 | # # flex_model.cuda() 281 | # # vae, iwae = flex_model.train_and_eval(logposterior=logposterior, model=model, x=x) 282 | 283 | 284 | # vae, iwae = optimize_local_expressive(logposterior, model, x) 285 | # print (vae.data.cpu().numpy(),iwae.data.cpu().numpy(),'flex') 286 | # vaes_flex.append(vae.data.cpu().numpy()) 287 | # iwaes_flex.append(iwae.data.cpu().numpy()) 288 | 289 | # q_local = Gaussian(hyper_config) #, mean, logvar) 290 | # q_local = Flow(hyper_config).cuda()#, mean, logvar) 291 | q_local = Flow1(hyper_config).cuda() #, mean, logvar) 292 | 293 | 294 | # print (q_local) 295 | 296 | # vae, iwae = optimize_local_gaussian(logposterior, model, x) 297 | vae, iwae = optimize_local_q_dist(logposterior, hyper_config, x, q_local) 298 | print (vae.data.cpu().numpy(),iwae.data.cpu().numpy(),'reg') 299 | vaes.append(vae.data.cpu().numpy()) 300 | iwaes.append(iwae.data.cpu().numpy()) 301 | 302 | print() 303 | print ('opt vae',np.mean(vaes)) 304 | print ('opt iwae',np.mean(iwaes)) 305 | print() 306 | 307 | # print ('opt vae flex',np.mean(vaes_flex)) 308 | # print ('opt iwae flex',np.mean(iwaes_flex)) 309 | # print() 310 | 311 | if compute_amort: 312 | VAE_train = test_vae(model=model, data_x=train_x[:n_data], batch_size=np.minimum(n_data, 50), display=10, k=5000) 313 | IW_train = test(model=model, data_x=train_x[:n_data], batch_size=np.minimum(n_data, 50), display=10, k=5000) 314 | print ('amortized VAE',VAE_train) 315 | print ('amortized IW',IW_train) 316 | 317 | 318 | # print() 319 | # AIS_train = test_ais(model=model, data_x=train_x[:n_data], batch_size=n_data, display=2, k=50, n_intermediate_dists=500) 320 | # print ('AIS_train',AIS_train) 321 | 322 | 323 | 324 | # print() 325 | # print() 326 | # print ('AIS_train',AIS_train) 327 | # print() 328 | # print ('opt vae flex',np.mean(vaes_flex)) 329 | # # print() 330 | # print ('opt vae',np.mean(vaes)) 331 | # # print() 332 | # print ('amortized VAE',VAE_train) 333 | # print() 334 | 335 | 336 | 337 | 338 | 339 | 340 | 341 | 342 | 343 | 344 | 345 | -------------------------------------------------------------------------------- /decoder_sizes_exp/train_mnist.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | import numpy as np 5 | import gzip 6 | import time 7 | import pickle 8 | 9 | from os.path import expanduser 10 | home = expanduser("~") 11 | 12 | import sys, os 13 | sys.path.insert(0, '../models') 14 | sys.path.insert(0, '../models/utils') 15 | 16 | 17 | import matplotlib 18 | matplotlib.use('Agg') 19 | import matplotlib.pyplot as plt 20 | 21 | 22 | import torch 23 | from torch.autograd import Variable 24 | import torch.utils.data 25 | import torch.optim as optim 26 | import torch.nn as nn 27 | import torch.nn.functional as F 28 | 29 | # from ais2 import test_ais 30 | 31 | # from pytorch_vae_v6 import VAE 32 | 33 | # from vae_1 import VAE 34 | from vae_2 import VAE 35 | 36 | 37 | # from approx_posteriors_v6 import FFG_LN 38 | # from approx_posteriors_v6 import ANF_LN 39 | # import argparse 40 | # from approx_posteriors_v6 import standard 41 | from inference_net import standard 42 | 43 | from distributions import Gaussian 44 | from distributions 
import Flow 45 | 46 | 47 | 48 | 49 | 50 | 51 | #FASHION 52 | # def load_mnist(path, kind='train'): 53 | 54 | # images_path = os.path.join(path, 55 | # '%s-images-idx3-ubyte.gz' 56 | # % kind) 57 | 58 | # with gzip.open(images_path, 'rb') as imgpath: 59 | # images = np.frombuffer(imgpath.read(), dtype=np.uint8, 60 | # offset=16).reshape(-1, 784) 61 | 62 | # return images#, labels 63 | 64 | 65 | # path = home+'/Documents/fashion_MNIST' 66 | 67 | # train_x = load_mnist(path=path) 68 | # test_x = load_mnist(path=path, kind='t10k') 69 | 70 | # train_x = train_x / 255. 71 | # test_x = test_x / 255. 72 | 73 | # print (train_x.shape) 74 | # print (test_x.shape) 75 | 76 | # print (np.max(train_x)) 77 | # print (test_x[3]) 78 | # fsda 79 | 80 | 81 | # print ('Loading data') 82 | # with open(home+'/Documents/MNIST_data/mnist.pkl','rb') as f: 83 | # mnist_data = pickle.load(f, encoding='latin1') 84 | # train_x = mnist_data[0][0] 85 | # valid_x = mnist_data[1][0] 86 | # test_x = mnist_data[2][0] 87 | # train_x = np.concatenate([train_x, valid_x], axis=0) 88 | # print (train_x.shape) 89 | 90 | #Load data 91 | print ('Loading data' ) 92 | data_location = home + '/Documents/MNIST_data/' 93 | # with open(data_location + 'binarized_mnist.pkl', 'rb') as f: 94 | # train_x, valid_x, test_x = pickle.load(f) 95 | with open(data_location + 'binarized_mnist.pkl', 'rb') as f: 96 | train_x, valid_x, test_x = pickle.load(f, encoding='latin1') 97 | print ('Train', train_x.shape) 98 | print ('Valid', valid_x.shape) 99 | print ('Test', test_x.shape) 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | def train_encoder_and_decoder(model, train_x, test_x, k, batch_size, 108 | start_at, save_freq, display_epoch, 109 | path_to_save_variables): 110 | 111 | train_y = torch.from_numpy(np.zeros(len(train_x))) 112 | train_x = torch.from_numpy(train_x).float().type(model.dtype) 113 | 114 | train_ = torch.utils.data.TensorDataset(train_x, train_y) 115 | train_loader = torch.utils.data.DataLoader(train_, batch_size=batch_size, shuffle=True) 116 | 117 | #IWAE paper training strategy 118 | time_ = time.time() 119 | total_epochs = 0 120 | 121 | i_max = 7 122 | 123 | warmup_over_epochs = 100. 124 | 125 | 126 | all_params = [] 127 | for aaa in model.q_dist.parameters(): 128 | all_params.append(aaa) 129 | for aaa in model.generator.parameters(): 130 | all_params.append(aaa) 131 | # print (len(all_params), 'number of params') 132 | 133 | print (model.q_dist) 134 | # print (model.q_dist.q) 135 | print (model.generator) 136 | 137 | # fads 138 | 139 | 140 | for i in range(0,i_max+1): 141 | 142 | lr = .001 * 10**(-i/float(i_max)) 143 | print (i, 'LR:', lr) 144 | 145 | # # optimizer = optim.Adam(model.parameters(), lr=lr) 146 | # print (model.q_dist) 147 | # print (model.generator) 148 | # print (model.q_dist.parameters()) 149 | # print (model.generator.parameters()) 150 | 151 | # print ('Encoder') 152 | # for aaa in model.q_dist.parameters(): 153 | # # print (aaa) 154 | # print (aaa.size()) 155 | # print ('Decoder') 156 | # for aaa in model.generator.parameters(): 157 | # # print (aaa) 158 | # print (aaa.size()) 159 | # # fasdfs 160 | # fads 161 | 162 | 163 | optimizer = optim.Adam(all_params, lr=lr) 164 | 165 | epochs = 3**(i) 166 | 167 | for epoch in range(1, epochs + 1): 168 | 169 | for batch_idx, (data, target) in enumerate(train_loader): 170 | 171 | batch = Variable(data)#.type(model.dtype) 172 | 173 | optimizer.zero_grad() 174 | 175 | warmup = total_epochs/warmup_over_epochs 176 | if warmup > 1.: 177 | warmup = 1. 
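# The warm-up coefficient computed above ramps linearly from 0 to 1 over the
# first `warmup_over_epochs` epochs and is then clamped at 1; it is passed to
# model.forward below and presumably scales the KL / regularization term of
# the ELBO (the forward pass is defined in models/vae_2.py, not shown here).
# For example, with warmup_over_epochs = 100.:
#   total_epochs =  25 -> warmup 0.25
#   total_epochs =  50 -> warmup 0.50
#   total_epochs = 150 -> warmup 1.00 (clamped)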
178 | 179 | elbo, logpxz, logqz = model.forward(batch, k=k, warmup=warmup) 180 | 181 | loss = -(elbo) 182 | loss.backward() 183 | optimizer.step() 184 | 185 | total_epochs += 1 186 | 187 | 188 | if total_epochs%display_epoch==0: 189 | print ('Train Epoch: {}/{}'.format(epoch, epochs), 190 | 'total_epochs {}'.format(total_epochs), 191 | 'LL:{:.3f}'.format(-loss.data[0]), 192 | 'logpxz:{:.3f}'.format(logpxz.data[0]), 193 | 'logqz:{:.3f}'.format(logqz.data[0]), 194 | 'warmup:{:.3f}'.format(warmup), 195 | 'T:{:.2f}'.format(time.time()-time_), 196 | ) 197 | time_ = time.time() 198 | 199 | 200 | if total_epochs >= start_at and (total_epochs-start_at)%save_freq==0: 201 | 202 | # save params 203 | save_file = path_to_save_variables+'_encoder_'+str(total_epochs)+'.pt' 204 | torch.save(model.q_dist.state_dict(), save_file) 205 | print ('saved variables ' + save_file) 206 | save_file = path_to_save_variables+'_generator_'+str(total_epochs)+'.pt' 207 | torch.save(model.generator.state_dict(), save_file) 208 | print ('saved variables ' + save_file) 209 | 210 | 211 | 212 | # save params 213 | save_file = path_to_save_variables+'_encoder_'+str(total_epochs)+'.pt' 214 | torch.save(model.q_dist.state_dict(), save_file) 215 | print ('saved variables ' + save_file) 216 | save_file = path_to_save_variables+'_generator_'+str(total_epochs)+'.pt' 217 | torch.save(model.generator.state_dict(), save_file) 218 | print ('saved variables ' + save_file) 219 | 220 | 221 | print ('done training') 222 | 223 | 224 | 225 | 226 | 227 | 228 | 229 | 230 | 231 | 232 | 233 | 234 | # Which gpu 235 | os.environ['CUDA_VISIBLE_DEVICES'] = '1' 236 | 237 | 238 | x_size = 784 239 | z_size = 50 240 | batch_size = 20 241 | k = 1 242 | #save params 243 | start_at = 100 244 | save_freq = 500 245 | display_epoch = 3 246 | 247 | # hyper_config = { 248 | # 'x_size': x_size, 249 | # 'z_size': z_size, 250 | # 'act_func': F.tanh,# F.relu, 251 | # 'encoder_arch': [[x_size,200],[200,200],[200,z_size*2]], 252 | # 'decoder_arch': [[z_size,200],[200,200],[200,x_size]], 253 | # 'q_dist': standard, #FFG_LN#,#hnf,#aux_nf,#flow1,#, 254 | # 'cuda': 1 255 | # } 256 | 257 | # hyper_config = { 258 | # 'x_size': x_size, 259 | # 'z_size': z_size, 260 | # 'act_func': F.tanh,# F.relu, 261 | # 'encoder_arch': [[x_size,200],[200,200],[200,z_size*2]], 262 | # 'decoder_arch': [[z_size,x_size]], 263 | # 'q_dist': standard, #FFG_LN#,#hnf,#aux_nf,#flow1,#, 264 | # 'cuda': 1 265 | # } 266 | 267 | 268 | hyper_config = { 269 | 'x_size': x_size, 270 | 'z_size': z_size, 271 | 'act_func': F.tanh,# F.relu, 272 | 'encoder_arch': [[x_size,200],[200,200],[200,z_size*2]], 273 | 'decoder_arch': [[z_size,200],[200,200],[200,200],[200,200],[200,x_size]], 274 | 'q_dist': standard, #FFG_LN#,#hnf,#aux_nf,#flow1,#, 275 | 'cuda': 1 276 | } 277 | 278 | 279 | q = Gaussian(hyper_config) 280 | # q = Flow(hyper_config) 281 | hyper_config['q'] = q 282 | 283 | 284 | print ('Init model') 285 | model = VAE(hyper_config) 286 | if torch.cuda.is_available(): 287 | model.cuda() 288 | 289 | print('\nModel:', hyper_config,'\n') 290 | 291 | 292 | 293 | 294 | # path_to_load_variables='' 295 | path_to_save_variables=home+'/Documents/tmp/inference_suboptimality/decoder_exps/hidden_layers_4' #.pt' 296 | # path_to_save_variables=home+'/Documents/tmp/pytorch_vae'+str(epochs)+'.pt' 297 | # path_to_save_variables=this_dir+'/params_'+model_name+'_' 298 | # path_to_save_variables='' 299 | 300 | 301 | 302 | print('\nTraining') 303 | # train_lr_schedule(model=model, train_x=train_x, test_x=test_x, k=k, 
batch_size=batch_size, 304 | # start_at=start_at, save_freq=save_freq, display_epoch=display_epoch, 305 | # path_to_save_variables=path_to_save_variables) 306 | 307 | 308 | train_encoder_and_decoder(model=model, train_x=train_x, test_x=test_x, k=k, batch_size=batch_size, 309 | start_at=start_at, save_freq=save_freq, display_epoch=display_epoch, 310 | path_to_save_variables=path_to_save_variables) 311 | 312 | print ('Done.') 313 | 314 | 315 | 316 | 317 | 318 | 319 | 320 | 321 | 322 | 323 | 324 | 325 | 326 | 327 | 328 | 329 | 330 | 331 | 332 | 333 | 334 | 335 | 336 | 337 | 338 | 339 | 340 | 341 | 342 | 343 | 344 | 345 | 346 | 347 | 348 | 349 | 350 | -------------------------------------------------------------------------------- /test_different_dists/eval.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | import numpy as np 10 | import gzip 11 | import time 12 | import pickle 13 | 14 | from os.path import expanduser 15 | home = expanduser("~") 16 | 17 | import sys, os 18 | sys.path.insert(0, '../models') 19 | sys.path.insert(0, '../models/utils') 20 | 21 | 22 | import matplotlib 23 | matplotlib.use('Agg') 24 | import matplotlib.pyplot as plt 25 | 26 | 27 | import torch 28 | from torch.autograd import Variable 29 | import torch.utils.data 30 | import torch.optim as optim 31 | import torch.nn as nn 32 | import torch.nn.functional as F 33 | 34 | 35 | from vae_2 import VAE 36 | 37 | 38 | # from approx_posteriors_v6 import standard 39 | from inference_net import standard 40 | 41 | # from ais3 import test_ais 42 | 43 | 44 | # from optimize_local import optimize_local_gaussian 45 | 46 | 47 | import csv 48 | 49 | from optimize_local_q import optimize_local_q_dist 50 | 51 | 52 | 53 | 54 | from distributions import Gaussian 55 | from distributions import Flow 56 | from distributions import HNF 57 | 58 | 59 | 60 | 61 | 62 | gpu_to_use = sys.argv[1] 63 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu_to_use #'1' 64 | 65 | q_name = sys.argv[2] 66 | 67 | hnf = 0 68 | if q_name == 'Gaus': 69 | q = Gaussian 70 | elif q_name == 'Flow': 71 | q = Flow 72 | elif q_name == 'HNF': 73 | q = HNF 74 | hnf = 1 75 | else: 76 | dfadfas 77 | 78 | 79 | # path_to_save_variables=home+'/Documents/tmp/inference_suboptimality/fashion_params/100warm_10k_binarized_fashion_'+q_name #.pt' 80 | path_to_save_variables=home+'/Documents/tmp/inference_suboptimality/fashion_params/10k_binarized_fashion2_SSE_'+q_name #.pt' 81 | 82 | 83 | # epochs = [100,1000,2200,3280] 84 | # epochs = [400,700] 85 | # epochs = [20,100,200,300,360] 86 | # epochs = [100,300,500,700,1000] 87 | epochs = [100,300,600,700] 88 | # epochs = [360] 89 | 90 | 91 | 92 | n_data =1006 93 | 94 | 95 | # write to 96 | file_ = home+'/Documents/tmp/inference_suboptimality/over_training_exps/results_'+str(n_data)+'_10k_fashion_binarized2_SSE.txt' 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | def test_vae(model, data_x, batch_size, display, k): 106 | 107 | time_ = time.time() 108 | elbos = [] 109 | data_index= 0 110 | for i in range(int(len(data_x)/ batch_size)): 111 | 112 | batch = data_x[data_index:data_index+batch_size] 113 | data_index += batch_size 114 | 115 | batch = Variable(torch.from_numpy(batch)).type(model.dtype) 116 | 117 | elbo, logpxz, logqz = model.forward2(batch, k=k) 118 | 119 | elbos.append(elbo.data[0]) 120 | 121 | # if i%display==0: 122 | # print (i,len(data_x)/ batch_size, np.mean(elbos)) 123 | 124 | mean_ = np.mean(elbos) 125 | # print(mean_, 'T:', time.time()-time_) 126 | 127 | 
return mean_#, time.time()-time_ 128 | 129 | 130 | 131 | 132 | 133 | def test(model, data_x, batch_size, display, k): 134 | 135 | time_ = time.time() 136 | elbos = [] 137 | data_index= 0 138 | for i in range(int(len(data_x)/ batch_size)): 139 | 140 | batch = data_x[data_index:data_index+batch_size] 141 | data_index += batch_size 142 | 143 | batch = Variable(torch.from_numpy(batch)).type(model.dtype) 144 | 145 | elbo, logpxz, logqz = model(batch, k=k) 146 | 147 | elbos.append(elbo.data[0]) 148 | 149 | # if i%display==0: 150 | # print (i,len(data_x)/ batch_size, np.mean(elbos)) 151 | 152 | mean_ = np.mean(elbos) 153 | # print(mean_, 'T:', time.time()-time_) 154 | 155 | return mean_#, time.time()-time_ 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | #FASHION 168 | def load_mnist(path, kind='train'): 169 | 170 | images_path = os.path.join(path, 171 | '%s-images-idx3-ubyte.gz' 172 | % kind) 173 | 174 | with gzip.open(images_path, 'rb') as imgpath: 175 | images = np.frombuffer(imgpath.read(), dtype=np.uint8, 176 | offset=16).reshape(-1, 784) 177 | 178 | return images#, labels 179 | 180 | 181 | path = home+'/Documents/fashion_MNIST' 182 | 183 | train_x = load_mnist(path=path) 184 | test_x = load_mnist(path=path, kind='t10k') 185 | 186 | train_x = train_x / 255. 187 | test_x = test_x / 255. 188 | 189 | #binarize 190 | train_x = (train_x > .5).astype(float) 191 | test_x = (test_x > .5).astype(float) 192 | 193 | 194 | print (train_x.shape) 195 | print (test_x.shape) 196 | print () 197 | 198 | valid_x = train_x[50000:] 199 | # train_x = train_x[:50000] 200 | train_x = train_x[:10000] #small dataset 201 | 202 | 203 | print (train_x.shape) 204 | print (valid_x.shape) 205 | print (test_x.shape) 206 | print () 207 | 208 | 209 | 210 | 211 | 212 | x_size = 784 213 | z_size = 20 214 | batch_size = 50 215 | k = 1 216 | #save params 217 | start_at = 100 218 | save_freq = 300 219 | display_epoch = 3 220 | 221 | # hyper_config = { 222 | # 'x_size': x_size, 223 | # 'z_size': z_size, 224 | # 'act_func': F.elu, #F.tanh,# F.relu, 225 | # 'encoder_arch': [[x_size,200],[200,200],[200,z_size*2]], 226 | # 'decoder_arch': [[z_size,200],[200,200],[200,x_size]], 227 | # 'q_dist': standard, #FFG_LN#,#hnf,#aux_nf,#flow1,#, 228 | # 'cuda': 1, 229 | # 'hnf': hnf 230 | # } 231 | 232 | 233 | 234 | # #SE 235 | # hyper_config = { 236 | # 'x_size': x_size, 237 | # 'z_size': z_size, 238 | # 'act_func': F.elu, #F.tanh,# F.relu, 239 | # 'encoder_arch': [[x_size,100],[100,z_size*2]], 240 | # 'decoder_arch': [[z_size,200],[200,200],[200,x_size]], 241 | # 'q_dist': standard, #FFG_LN#,#hnf,#aux_nf,#flow1,#, 242 | # 'cuda': 1, 243 | # 'hnf': hnf 244 | # } 245 | 246 | 247 | #SSE 248 | hyper_config = { 249 | 'x_size': x_size, 250 | 'z_size': z_size, 251 | 'act_func': F.elu, #F.tanh,# F.relu, 252 | 'encoder_arch': [[x_size,z_size*2]], 253 | 'decoder_arch': [[z_size,200],[200,200],[200,x_size]], 254 | 'q_dist': standard, #FFG_LN#,#hnf,#aux_nf,#flow1,#, 255 | 'cuda': 1, 256 | 'hnf': hnf 257 | } 258 | 259 | 260 | 261 | hyper_config['q'] = q(hyper_config) 262 | 263 | 264 | print ('Init model') 265 | model = VAE(hyper_config) 266 | if torch.cuda.is_available(): 267 | model.cuda() 268 | 269 | print('\nModel:', hyper_config,'\n') 270 | 271 | 272 | 273 | 274 | 275 | 276 | 277 | 278 | 279 | batch_size = 4 280 | 281 | k = 5000 282 | 283 | for epoch in epochs: 284 | 285 | print (epoch, epochs) 286 | 287 | 288 | 289 | print ('Load params for decoder') 290 | 
path_to_load_variables=path_to_save_variables+'_generator_'+str(epoch)+'.pt' 291 | model.generator.load_state_dict(torch.load(path_to_load_variables, map_location=lambda storage, loc: storage)) 292 | print ('loaded variables ' + path_to_load_variables) 293 | # print () 294 | 295 | 296 | print ('Load params for encoder') 297 | path_to_load_variables=path_to_save_variables+'_encoder_'+str(epoch)+'.pt' 298 | model.q_dist.load_state_dict(torch.load(path_to_load_variables, map_location=lambda storage, loc: storage)) 299 | print ('loaded variables ' + path_to_load_variables) 300 | 301 | 302 | 303 | 304 | #TRAINING SET AMORT 305 | 306 | VAE_train = test_vae(model=model, data_x=train_x[:n_data], batch_size=np.minimum(n_data, batch_size), display=10, k=k) 307 | IW_train = test(model=model, data_x=train_x[:n_data], batch_size=np.minimum(n_data, batch_size), display=10, k=k) 308 | print ('train amortized VAE',VAE_train) 309 | print ('train amortized IW',IW_train) 310 | 311 | 312 | with open(file_, 'a') as f: 313 | writer = csv.writer(f, delimiter=' ') 314 | 315 | writer.writerow([q_name,'training', epoch, 'L_q', VAE_train]) 316 | writer.writerow([q_name,'training', epoch, 'L_q_IWAE', IW_train]) 317 | 318 | 319 | 320 | #TEST SET AMORT 321 | 322 | VAE_test = test_vae(model=model, data_x=test_x[:n_data], batch_size=np.minimum(n_data, batch_size), display=10, k=k) 323 | IW_test = test(model=model, data_x=test_x[:n_data], batch_size=np.minimum(n_data, batch_size), display=10, k=k) 324 | print ('test amortized VAE',VAE_test) 325 | print ('test amortized IW',IW_test) 326 | print() 327 | 328 | 329 | with open(file_, 'a') as f: 330 | writer = csv.writer(f, delimiter=' ') 331 | 332 | writer.writerow([q_name,'validation', epoch, 'L_q', VAE_test]) 333 | writer.writerow([q_name,'validation', epoch, 'L_q_IWAE', IW_test]) 334 | 335 | 336 | 337 | 338 | # values = {} 339 | # values['training'] = {} 340 | # values['validation'] = {} 341 | 342 | # max_value = None 343 | # min_value = None 344 | 345 | 346 | # #Get numbder of distributions and epochs and datasets 347 | # datasets = [] 348 | # distributions = [] 349 | # epochs = [] 350 | # bounds = [] 351 | 352 | # with open(file_, 'r') as f: 353 | # reader = csv.reader(f, delimiter=' ') 354 | # for row in reader: 355 | # if len(row): 356 | 357 | # distribution = row[0] 358 | # dataset = row[1] 359 | # epoch = row[2] 360 | # bound = row[2] 361 | 362 | # if distribution not in distributions: 363 | # distributions.append(distribution) 364 | 365 | # and row[1] in ['training','validation']: 366 | # # print (row) 367 | 368 | # value = row[3] 369 | 370 | # if epoch not in values[dataset]: 371 | # values[dataset][epoch] = {} 372 | # if epoch not in epochs: 373 | # epochs.append(epoch) 374 | # print (epoch) 375 | 376 | # values[dataset][epoch][bound] = value 377 | 378 | # if max_value == None or float(value) > max_value: 379 | # max_value = float(value) 380 | # if min_value == None or float(value) < min_value: 381 | # min_value = float(value) 382 | 383 | 384 | 385 | 386 | 387 | 388 | 389 | 390 | 391 | 392 | 393 | 394 | 395 | 396 | 397 | 398 | -------------------------------------------------------------------------------- /gaps_over_training_exp/train_mnist.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | import numpy as np 5 | import gzip 6 | import time 7 | import pickle 8 | 9 | from os.path import expanduser 10 | home = expanduser("~") 11 | 12 | import sys, os 13 | sys.path.insert(0, '../models') 14 | sys.path.insert(0, 
'../models/utils') 15 | 16 | 17 | import matplotlib 18 | matplotlib.use('Agg') 19 | import matplotlib.pyplot as plt 20 | 21 | 22 | import torch 23 | from torch.autograd import Variable 24 | import torch.utils.data 25 | import torch.optim as optim 26 | import torch.nn as nn 27 | import torch.nn.functional as F 28 | 29 | # from ais2 import test_ais 30 | 31 | # from pytorch_vae_v6 import VAE 32 | 33 | # from vae_1 import VAE 34 | from vae_2 import VAE 35 | 36 | 37 | # from approx_posteriors_v6 import FFG_LN 38 | # from approx_posteriors_v6 import ANF_LN 39 | # import argparse 40 | # from approx_posteriors_v6 import standard 41 | from inference_net import standard 42 | 43 | from distributions import Gaussian 44 | from distributions import Flow 45 | 46 | 47 | 48 | 49 | 50 | 51 | #FASHION 52 | def load_mnist(path, kind='train'): 53 | 54 | images_path = os.path.join(path, 55 | '%s-images-idx3-ubyte.gz' 56 | % kind) 57 | 58 | with gzip.open(images_path, 'rb') as imgpath: 59 | images = np.frombuffer(imgpath.read(), dtype=np.uint8, 60 | offset=16).reshape(-1, 784) 61 | 62 | return images#, labels 63 | 64 | 65 | path = home+'/Documents/fashion_MNIST' 66 | 67 | train_x = load_mnist(path=path) 68 | test_x = load_mnist(path=path, kind='t10k') 69 | 70 | train_x = train_x / 255. 71 | test_x = test_x / 255. 72 | 73 | print (train_x.shape) 74 | print (test_x.shape) 75 | 76 | 77 | 78 | #binarize 79 | train_x = (train_x > .5).astype(float) 80 | test_x = (test_x > .5).astype(float) 81 | 82 | # print (train_x) 83 | # fads 84 | 85 | # print (np.max(train_x)) 86 | # print (test_x[3]) 87 | # fsda 88 | 89 | 90 | # print ('Loading data') 91 | # with open(home+'/Documents/MNIST_data/mnist.pkl','rb') as f: 92 | # mnist_data = pickle.load(f, encoding='latin1') 93 | # train_x = mnist_data[0][0] 94 | # valid_x = mnist_data[1][0] 95 | # test_x = mnist_data[2][0] 96 | # train_x = np.concatenate([train_x, valid_x], axis=0) 97 | # print (train_x.shape) 98 | 99 | 100 | 101 | # #Load data mnist 102 | # print ('Loading data' ) 103 | # data_location = home + '/Documents/MNIST_data/' 104 | # # with open(data_location + 'binarized_mnist.pkl', 'rb') as f: 105 | # # train_x, valid_x, test_x = pickle.load(f) 106 | # with open(data_location + 'binarized_mnist.pkl', 'rb') as f: 107 | # train_x, valid_x, test_x = pickle.load(f, encoding='latin1') 108 | # print ('Train', train_x.shape) 109 | # print ('Valid', valid_x.shape) 110 | # print ('Test', test_x.shape) 111 | 112 | 113 | # print (np.max(train_x)) 114 | 115 | # fadad 116 | 117 | 118 | 119 | 120 | def train_encoder_and_decoder(model, train_x, test_x, k, batch_size, 121 | start_at, save_freq, display_epoch, 122 | path_to_save_variables): 123 | 124 | train_y = torch.from_numpy(np.zeros(len(train_x))) 125 | train_x = torch.from_numpy(train_x).float().type(model.dtype) 126 | 127 | train_ = torch.utils.data.TensorDataset(train_x, train_y) 128 | train_loader = torch.utils.data.DataLoader(train_, batch_size=batch_size, shuffle=True) 129 | 130 | #IWAE paper training strategy 131 | time_ = time.time() 132 | total_epochs = 0 133 | 134 | i_max = 7 135 | 136 | warmup_over_epochs = 100. 
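# The "IWAE paper training strategy" implemented below: pass i (for i = 0..7,
# since i_max = 7) uses learning rate 1e-3 * 10**(-i/7) and runs 3**i epochs,
# so a full run lasts sum(3**i for i in range(8)) = 3280 epochs -- which is
# where the *_3280.pt checkpoint names used throughout these scripts come from.
# A quick standalone sketch (not part of the script) to print the schedule:
#
#   for i in range(8):
#       print(i, 0.001 * 10 ** (-i / 7.0), 3 ** i)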
137 | 138 | 139 | all_params = [] 140 | for aaa in model.q_dist.parameters(): 141 | all_params.append(aaa) 142 | for aaa in model.generator.parameters(): 143 | all_params.append(aaa) 144 | # print (len(all_params), 'number of params') 145 | 146 | print (model.q_dist) 147 | # print (model.q_dist.q) 148 | print (model.generator) 149 | 150 | # fads 151 | 152 | 153 | for i in range(0,i_max+1): 154 | 155 | lr = .001 * 10**(-i/float(i_max)) 156 | print (i, 'LR:', lr) 157 | 158 | # # optimizer = optim.Adam(model.parameters(), lr=lr) 159 | # print (model.q_dist) 160 | # print (model.generator) 161 | # print (model.q_dist.parameters()) 162 | # print (model.generator.parameters()) 163 | 164 | # print ('Encoder') 165 | # for aaa in model.q_dist.parameters(): 166 | # # print (aaa) 167 | # print (aaa.size()) 168 | # print ('Decoder') 169 | # for aaa in model.generator.parameters(): 170 | # # print (aaa) 171 | # print (aaa.size()) 172 | # # fasdfs 173 | # fads 174 | 175 | 176 | optimizer = optim.Adam(all_params, lr=lr) 177 | 178 | epochs = 3**(i) 179 | 180 | for epoch in range(1, epochs + 1): 181 | 182 | for batch_idx, (data, target) in enumerate(train_loader): 183 | 184 | batch = Variable(data)#.type(model.dtype) 185 | 186 | optimizer.zero_grad() 187 | 188 | warmup = total_epochs/warmup_over_epochs 189 | if warmup > 1.: 190 | warmup = 1. 191 | 192 | elbo, logpxz, logqz = model.forward(batch, k=k, warmup=warmup) 193 | 194 | loss = -(elbo) 195 | loss.backward() 196 | optimizer.step() 197 | 198 | total_epochs += 1 199 | 200 | 201 | if total_epochs%display_epoch==0: 202 | print ('Train Epoch: {}/{}'.format(epoch, epochs), 203 | 'total_epochs {}'.format(total_epochs), 204 | 'LL:{:.3f}'.format(-loss.data[0]), 205 | 'logpxz:{:.3f}'.format(logpxz.data[0]), 206 | 'logqz:{:.3f}'.format(logqz.data[0]), 207 | 'warmup:{:.3f}'.format(warmup), 208 | 'T:{:.2f}'.format(time.time()-time_), 209 | ) 210 | time_ = time.time() 211 | 212 | 213 | if total_epochs >= start_at and (total_epochs-start_at)%save_freq==0: 214 | 215 | # save params 216 | save_file = path_to_save_variables+'_encoder_'+str(total_epochs)+'.pt' 217 | torch.save(model.q_dist.state_dict(), save_file) 218 | print ('saved variables ' + save_file) 219 | save_file = path_to_save_variables+'_generator_'+str(total_epochs)+'.pt' 220 | torch.save(model.generator.state_dict(), save_file) 221 | print ('saved variables ' + save_file) 222 | 223 | 224 | 225 | # save params 226 | save_file = path_to_save_variables+'_encoder_'+str(total_epochs)+'.pt' 227 | torch.save(model.q_dist.state_dict(), save_file) 228 | print ('saved variables ' + save_file) 229 | save_file = path_to_save_variables+'_generator_'+str(total_epochs)+'.pt' 230 | torch.save(model.generator.state_dict(), save_file) 231 | print ('saved variables ' + save_file) 232 | 233 | 234 | print ('done training') 235 | 236 | 237 | 238 | 239 | 240 | 241 | 242 | 243 | 244 | 245 | 246 | 247 | 248 | # Which gpu 249 | os.environ['CUDA_VISIBLE_DEVICES'] = '1' 250 | 251 | 252 | x_size = 784 253 | z_size = 50 254 | batch_size = 20 255 | k = 1 256 | #save params 257 | start_at = 100 258 | save_freq = 300 259 | display_epoch = 3 260 | 261 | # hyper_config = { 262 | # 'x_size': x_size, 263 | # 'z_size': z_size, 264 | # 'act_func': F.elu, #F.tanh,# F.relu, 265 | # 'encoder_arch': [[x_size,200],[200,200],[200,z_size*2]], 266 | # 'decoder_arch': [[z_size,200],[200,200],[200,x_size]], 267 | # 'q_dist': standard, #FFG_LN#,#hnf,#aux_nf,#flow1,#, 268 | # 'cuda': 1 269 | # } 270 | 271 | 272 | #LE 273 | hyper_config = { 274 | 'x_size': 
x_size, 275 | 'z_size': z_size, 276 | 'act_func': F.elu, #F.tanh,# F.relu, 277 | 'encoder_arch': [[x_size,500],[500,500],[500,z_size*2]], 278 | 'decoder_arch': [[z_size,200],[200,200],[200,x_size]], 279 | 'q_dist': standard, #FFG_LN#,#hnf,#aux_nf,#flow1,#, 280 | 'cuda': 1 281 | } 282 | 283 | # hyper_config = { 284 | # 'x_size': x_size, 285 | # 'z_size': z_size, 286 | # 'act_func': F.tanh,# F.relu, 287 | # 'encoder_arch': [[x_size,200],[200,200],[200,z_size*2]], 288 | # 'decoder_arch': [[z_size,x_size]], 289 | # 'q_dist': standard, #FFG_LN#,#hnf,#aux_nf,#flow1,#, 290 | # 'cuda': 1 291 | # } 292 | 293 | 294 | # hyper_config = { 295 | # 'x_size': x_size, 296 | # 'z_size': z_size, 297 | # 'act_func': F.tanh,# F.relu, 298 | # 'encoder_arch': [[x_size,200],[200,200],[200,z_size*2]], 299 | # 'decoder_arch': [[z_size,200],[200,200],[200,200],[200,200],[200,x_size]], 300 | # 'q_dist': standard, #FFG_LN#,#hnf,#aux_nf,#flow1,#, 301 | # 'cuda': 1 302 | # } 303 | 304 | 305 | q = Gaussian(hyper_config) 306 | # q = Flow(hyper_config) 307 | hyper_config['q'] = q 308 | 309 | 310 | print ('Init model') 311 | model = VAE(hyper_config) 312 | if torch.cuda.is_available(): 313 | model.cuda() 314 | 315 | print('\nModel:', hyper_config,'\n') 316 | 317 | 318 | 319 | 320 | # path_to_load_variables='' 321 | path_to_save_variables=home+'/Documents/tmp/inference_suboptimality/fashion_params/LE_binarized_fashion' #.pt' 322 | # path_to_save_variables=home+'/Documents/tmp/inference_suboptimality/fashion_params/binarized_fashion_' #.pt' 323 | 324 | # path_to_save_variables=home+'/Documents/tmp/pytorch_vae'+str(epochs)+'.pt' 325 | # path_to_save_variables=this_dir+'/params_'+model_name+'_' 326 | # path_to_save_variables='' 327 | 328 | 329 | 330 | print('\nTraining') 331 | # train_lr_schedule(model=model, train_x=train_x, test_x=test_x, k=k, batch_size=batch_size, 332 | # start_at=start_at, save_freq=save_freq, display_epoch=display_epoch, 333 | # path_to_save_variables=path_to_save_variables) 334 | 335 | 336 | train_encoder_and_decoder(model=model, train_x=train_x, test_x=test_x, k=k, batch_size=batch_size, 337 | start_at=start_at, save_freq=save_freq, display_epoch=display_epoch, 338 | path_to_save_variables=path_to_save_variables) 339 | 340 | print ('Done.') 341 | 342 | 343 | 344 | 345 | 346 | 347 | 348 | 349 | 350 | 351 | 352 | 353 | 354 | 355 | 356 | 357 | 358 | 359 | 360 | 361 | 362 | 363 | 364 | 365 | 366 | 367 | 368 | 369 | 370 | 371 | 372 | 373 | 374 | 375 | 376 | 377 | 378 | 379 | -------------------------------------------------------------------------------- /test_different_dists/plot_gaps_over_epochs.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | import time 7 | import numpy as np 8 | import pickle 9 | from os.path import expanduser 10 | home = expanduser("~") 11 | 12 | import matplotlib 13 | matplotlib.use('agg') 14 | import matplotlib.pyplot as plt 15 | 16 | import csv 17 | 18 | import sys 19 | 20 | 21 | just_amort = 0 22 | 23 | 24 | 25 | 26 | 27 | # epochs=['100','1000','1900','2800'] 28 | # epochs=['100','1000','2200'] 29 | # epochs=['100','2800'] 30 | 31 | 32 | if just_amort: 33 | bounds = ['L_q', 'L_q_IWAE'] 34 | else: 35 | bounds = ['logpx', 'L_q_star', 'L_q'] 36 | 37 | 38 | 39 | epochs = [] 40 | 41 | values = {} 42 | values['training'] = {} 43 | values['validation'] = {} 44 | # for epoch in epochs: 45 | # values['training'][epoch] = {} 46 | # values['validation'][epoch] = {} 47 | # for bound in bounds: 48 | # for epoch in epochs: 49 | # 
values['training'][epoch][bound] = {} 50 | # values['validation'][epoch][bound] = {} 51 | 52 | 53 | 54 | #read values 55 | # results_file = 'results_50' 56 | # results_file = 'results_2_fashion' 57 | # results_file = 'results_10_fashion' 58 | # results_file = 'results_100_fashion' 59 | 60 | results_file = sys.argv[1] 61 | 62 | 63 | 64 | # ndata_101_binarized_fashion3_Flow 65 | 66 | 67 | file_ = home+'/Documents/tmp/inference_suboptimality/over_training_exps/results_'+results_file+'.txt' 68 | 69 | max_value = None 70 | min_value = None 71 | 72 | with open(file_, 'r') as f: 73 | reader = csv.reader(f, delimiter=' ') 74 | for row in reader: 75 | if len(row) and row[0] in ['training','validation']: 76 | # print (row) 77 | dataset = row[0] 78 | epoch = row[1] 79 | bound = row[2] 80 | value = row[3] 81 | 82 | if epoch not in values[dataset]: 83 | values[dataset][epoch] = {} 84 | if epoch not in epochs: 85 | epochs.append(epoch) 86 | print (epoch) 87 | 88 | values[dataset][epoch][bound] = value 89 | 90 | if max_value == None or float(value) > max_value: 91 | max_value = float(value) 92 | if min_value == None or float(value) < min_value: 93 | min_value = float(value) 94 | 95 | max_value += .2 96 | 97 | # max_value = -81 98 | # min_value = -110 99 | 100 | 101 | # print (values) 102 | 103 | #sort epochs 104 | # epochs.sort() 105 | 106 | # print (epochs) 107 | # fads 108 | 109 | #convert to list 110 | # training_plot = {} 111 | # for bound in bounds: 112 | # values_to_plot = [] 113 | # for epoch in epochs: 114 | # values_to_plot.append(float(values['training'][epoch][bound])) 115 | # training_plot[bound] = values_to_plot 116 | # print (training_plot) 117 | 118 | 119 | # validation_plot = {} 120 | # for bound in bounds: 121 | # values_to_plot = [] 122 | # for epoch in epochs: 123 | # values_to_plot.append(float(values['validation'][epoch][bound])) 124 | # validation_plot[bound] = values_to_plot 125 | # print (validation_plot) 126 | 127 | 128 | 129 | training_plot = {} 130 | for bound in bounds: 131 | values_to_plot = [] 132 | for epoch in epochs: 133 | if bound == 'logpx' and 'AIS' in values['training'][epoch]: 134 | # print (values['training'][epoch]['AIS'], values['training'][epoch]['logpx']) 135 | # fadsfa 136 | # value = max() 137 | value = (max(float(values['training'][epoch]['AIS']), float(values['training'][epoch]['logpx']))) 138 | else: 139 | value = float(values['training'][epoch][bound]) 140 | values_to_plot.append(value) 141 | 142 | training_plot[bound] = values_to_plot 143 | print (training_plot) 144 | # fadsa 145 | 146 | 147 | validation_plot = {} 148 | for bound in bounds: 149 | values_to_plot = [] 150 | for epoch in epochs: 151 | if bound == 'logpx' and 'AIS' in values['validation'][epoch]: 152 | value = (max(float(values['validation'][epoch]['AIS']), float(values['validation'][epoch]['logpx']))) 153 | else: 154 | value = float(values['validation'][epoch][bound]) 155 | # values_to_plot.append(float(values['validation'][epoch][bound])) 156 | values_to_plot.append(value) 157 | 158 | validation_plot[bound] = values_to_plot 159 | print (validation_plot) 160 | 161 | epochs_float = [float(x) for x in epochs] 162 | 163 | 164 | rows = 1 165 | cols = 2 166 | 167 | legend=False 168 | 169 | fig = plt.figure(figsize=(8+cols,2+rows), facecolor='white') 170 | 171 | # ylimits = [-110, -84] 172 | ylimits = [min_value, max_value] 173 | 174 | 175 | 176 | 177 | # Training set 178 | ax = plt.subplot2grid((rows,cols), (0,0), frameon=False) 179 | 180 | 181 | 182 | 183 | # 
ax.set_title(results_file,family='serif') 184 | ax.set_title('Training Set',family='serif') 185 | 186 | # for bound in bounds: 187 | # ax.plot(epochs_float,training_plot[bound]) #, label=legends[i], c=colors[i], ls=line_styles[i]) 188 | 189 | 190 | if not just_amort: 191 | ax.fill_between(epochs_float, training_plot['logpx'], training_plot['L_q_star']) 192 | ax.fill_between(epochs_float, training_plot['L_q_star'], training_plot['L_q']) 193 | else: 194 | ax.plot(epochs_float, training_plot['L_q']) 195 | ax.plot(epochs_float, training_plot['L_q_IWAE']) 196 | 197 | 198 | ax.set_ylim(ylimits) 199 | ax.grid(True, alpha=.5) 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | # Validation set 208 | ax = plt.subplot2grid((rows,cols), (0,1), frameon=False) 209 | 210 | ax.set_title('Validation Set',family='serif') 211 | 212 | # for bound in bounds: 213 | # ax.plot(epochs_float,validation_plot[bound]) #, label=legends[i], c=colors[i], ls=line_styles[i]) 214 | 215 | ax.grid(True, alpha=.5) 216 | 217 | if not just_amort: 218 | ax.fill_between(epochs_float, validation_plot['logpx'], validation_plot['L_q_star']) 219 | ax.fill_between(epochs_float, validation_plot['L_q_star'], validation_plot['L_q']) 220 | else: 221 | ax.plot(epochs_float, validation_plot['L_q']) 222 | ax.plot(epochs_float, validation_plot['L_q_IWAE']) 223 | 224 | 225 | ax.set_ylim(ylimits) 226 | 227 | 228 | 229 | # ax.set_yticks() 230 | 231 | # family='serif' 232 | # fontProperties = {'family':'serif'} 233 | # ax.set_xticklabels(ax.get_xticks(), fontProperties) 234 | 235 | 236 | 237 | # ax.annotate('fdfafadf', xytext=(.5, .5), xy=(.5, .5), textcoords='figure fraction') 238 | 239 | # ax.annotate('fdfafadf', xy=(0, 0), xytext=(.5, .5), textcoords='figure fraction') 240 | # ax.annotate('local max', xy=(3, 1), xycoords='data', 241 | # xytext=(0.8, 0.95), textcoords='axes fraction', 242 | # arrowprops=dict(facecolor='black', shrink=0.05), 243 | # horizontalalignment='right', verticalalignment='top', 244 | # ) 245 | 246 | 247 | name_file = home+'/Documents/tmp/inference_suboptimality/over_training_exps/'+results_file+'.png' 248 | name_file2 = home+'/Documents/tmp/inference_suboptimality/over_training_exps/'+results_file+'.pdf' 249 | # name_file = home+'/Documents/tmp/plot.png' 250 | plt.savefig(name_file) 251 | plt.savefig(name_file2) 252 | print ('Saved fig', name_file) 253 | print ('Saved fig', name_file2) 254 | 255 | 256 | 257 | print ('DONE') 258 | # fdsa 259 | 260 | 261 | 262 | 263 | 264 | 265 | 266 | 267 | 268 | 269 | 270 | 271 | 272 | 273 | 274 | 275 | 276 | 277 | 278 | 279 | 280 | 281 | 282 | 283 | 284 | 285 | 286 | 287 | 288 | 289 | 290 | 291 | 292 | # # # models = [standard,flow1,aux_nf]#,hnf] 293 | # # # models = [standard,standard_large_encoder]#, aux_nf aux_large_encoder]#,hnf] 294 | # # models = [standard,aux_nf]#, aux_nf aux_large_encoder]#,hnf] 295 | 296 | 297 | # # # model_names = ['standard','flow1','aux_nf','hnf'] 298 | # # # model_names = ['VAE','NF','Aux+NF']#,'HNF'] 299 | # # # model_names = ['FFG','Flow']#,'HNF'] 300 | # # # model_names = ['FFG','Flow']#,'HNF'] 301 | # # model_names = ['FFG','Flow']# 'aux_nf','aux_large_encoder']#,'HNF'] 302 | 303 | 304 | 305 | 306 | 307 | # # # legends = ['IW train', 'IW test', 'AIS train', 'AIS test'] 308 | # # # legends = ['VAE train', 'VAE test', 'IW train', 'IW test', 'AIS train', 'AIS test'] 309 | 310 | # # legends = ['VAE train', 'VAE test', 'IW train', 'IW test', 'AIS train', 'AIS test'] 311 | 312 | 313 | # # colors = ['blue', 'blue', 'green', 'green', 'red', 'red'] 314 | 315 | # # 
line_styles = [':', '-', ':', '-', ':', '-'] 316 | 317 | 318 | 319 | 320 | # rows = 1 321 | # cols = 1 322 | 323 | # legend=False 324 | 325 | # fig = plt.figure(figsize=(2+cols,2+rows), facecolor='white') 326 | 327 | # # Get y-axis limits 328 | # min_ = None 329 | # max_ = None 330 | # for m in range(len(models)): 331 | # for i in range(len(legends)): 332 | # if i == 1: 333 | # continue 334 | # this_min = np.min(models[m][i]) 335 | # this_max = np.max(models[m][i]) 336 | # if min_ ==None or this_min < min_: 337 | # min_ = this_min 338 | # if max_ ==None or this_max > max_: 339 | # max_ = this_max 340 | 341 | # min_ -= .1 342 | # max_ += .1 343 | # # print (min_) 344 | # # print (max_) 345 | # ylimits = [min_, max_] 346 | # xlimits = [x[0], x[-1]] 347 | 348 | # # fasd 349 | 350 | # # ax.plot(x,hnf_ais, label='hnf_ais') 351 | # # ax.set_yticks([]) 352 | # # ax.set_xticks([]) 353 | # # if samp_i==0: ax.annotate('Sample', xytext=(.3, 1.1), xy=(0, 1), textcoords='axes fraction') 354 | 355 | # for m in range(len(models)): 356 | # ax = plt.subplot2grid((rows,cols), (0,m), frameon=False) 357 | # for i in range(len(legends)): 358 | # if i == 1: 359 | # continue 360 | # ax.set_title(model_names[m],family='serif') 361 | # ax.plot(x,models[m][i], label=legends[i], c=colors[i], ls=line_styles[i]) 362 | # plt.legend(fontsize=6) 363 | # # ax.set(adjustable='box-forced', aspect='equal') 364 | # plt.yticks(size=6) 365 | # # plt.xticks(x,size=6) 366 | # plt.xticks([400,1300,2200,3100],size=6) 367 | 368 | # # ax.set_xlim(xlimits) 369 | # ax.set_ylim(ylimits) 370 | # ax.set_xlim(xlimits) 371 | 372 | # ax.set_xlabel('Epochs',size=6) 373 | # if m==0: 374 | # ax.set_ylabel('Log-Likelihood',size=6) 375 | 376 | 377 | # ax.grid(True, alpha=.1) 378 | 379 | 380 | # # m+=1 381 | # # ax = plt.subplot2grid((rows,cols), (0,m), frameon=False) 382 | # # ax.set_title('AIS_test') 383 | # # for m in range(len(models)): 384 | # # ax.plot(x,models[m][3], label=model_names[m]) 385 | # # plt.legend(fontsize=4) 386 | # # plt.yticks(size=6) 387 | 388 | 389 | 390 | 391 | 392 | # # plt.gca().set_aspect('equal', adjustable='box') 393 | # name_file = home+'/Documents/tmp/plot.png' 394 | # plt.savefig(name_file) 395 | # print ('Saved fig', name_file) 396 | 397 | # name_file = home+'/Documents/tmp/plot.eps' 398 | # plt.savefig(name_file) 399 | # print ('Saved fig', name_file) 400 | 401 | 402 | # name_file = home+'/Documents/tmp/plot.pdf' 403 | # plt.savefig(name_file) 404 | # print ('Saved fig', name_file) 405 | 406 | 407 | 408 | 409 | 410 | 411 | 412 | 413 | 414 | 415 | -------------------------------------------------------------------------------- /decoder_sizes_exp/compute_gaps.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | import numpy as np 7 | import gzip 8 | import time 9 | import pickle 10 | 11 | from os.path import expanduser 12 | home = expanduser("~") 13 | 14 | import sys, os 15 | sys.path.insert(0, '../models') 16 | sys.path.insert(0, '../models/utils') 17 | 18 | 19 | import matplotlib 20 | matplotlib.use('Agg') 21 | import matplotlib.pyplot as plt 22 | 23 | 24 | import torch 25 | from torch.autograd import Variable 26 | import torch.utils.data 27 | import torch.optim as optim 28 | import torch.nn as nn 29 | import torch.nn.functional as F 30 | 31 | 32 | from vae_2 import VAE 33 | 34 | 35 | # from approx_posteriors_v6 import standard 36 | from inference_net import standard 37 | 38 | # from ais3 import test_ais 39 | 40 | 41 | # from optimize_local import 
optimize_local_gaussian 42 | 43 | 44 | from optimize_local_q import optimize_local_q_dist 45 | 46 | 47 | 48 | 49 | from distributions import Gaussian 50 | from distributions import Flow 51 | 52 | 53 | 54 | 55 | 56 | 57 | def test_vae(model, data_x, batch_size, display, k): 58 | 59 | time_ = time.time() 60 | elbos = [] 61 | data_index= 0 62 | for i in range(int(len(data_x)/ batch_size)): 63 | 64 | batch = data_x[data_index:data_index+batch_size] 65 | data_index += batch_size 66 | 67 | batch = Variable(torch.from_numpy(batch)).type(model.dtype) 68 | 69 | elbo, logpxz, logqz = model.forward2(batch, k=k) 70 | 71 | elbos.append(elbo.data[0]) 72 | 73 | # if i%display==0: 74 | # print (i,len(data_x)/ batch_size, np.mean(elbos)) 75 | 76 | mean_ = np.mean(elbos) 77 | # print(mean_, 'T:', time.time()-time_) 78 | 79 | return mean_#, time.time()-time_ 80 | 81 | 82 | 83 | 84 | 85 | def test(model, data_x, batch_size, display, k): 86 | 87 | time_ = time.time() 88 | elbos = [] 89 | data_index= 0 90 | for i in range(int(len(data_x)/ batch_size)): 91 | 92 | batch = data_x[data_index:data_index+batch_size] 93 | data_index += batch_size 94 | 95 | batch = Variable(torch.from_numpy(batch)).type(model.dtype) 96 | 97 | elbo, logpxz, logqz = model(batch, k=k) 98 | 99 | elbos.append(elbo.data[0]) 100 | 101 | # if i%display==0: 102 | # print (i,len(data_x)/ batch_size, np.mean(elbos)) 103 | 104 | mean_ = np.mean(elbos) 105 | # print(mean_, 'T:', time.time()-time_) 106 | 107 | return mean_#, time.time()-time_ 108 | 109 | 110 | 111 | 112 | 113 | ########################### 114 | # Load data 115 | 116 | # print ('Loading data') 117 | # with open(home+'/Documents/MNIST_data/mnist.pkl','rb') as f: 118 | # mnist_data = pickle.load(f, encoding='latin1') 119 | # train_x = mnist_data[0][0] 120 | # valid_x = mnist_data[1][0] 121 | # test_x = mnist_data[2][0] 122 | # train_x = np.concatenate([train_x, valid_x], axis=0) 123 | # print (train_x.shape) 124 | 125 | #Load data 126 | print ('Loading data' ) 127 | data_location = home + '/Documents/MNIST_data/' 128 | # with open(data_location + 'binarized_mnist.pkl', 'rb') as f: 129 | # train_x, valid_x, test_x = pickle.load(f) 130 | with open(data_location + 'binarized_mnist.pkl', 'rb') as f: 131 | train_x, valid_x, test_x = pickle.load(f, encoding='latin1') 132 | print ('Train', train_x.shape) 133 | print ('Valid', valid_x.shape) 134 | print ('Test', test_x.shape) 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | ########################### 147 | # Load model 148 | 149 | 150 | 151 | # this_ckt_file = path_to_save_variables + str(ckt) + '.pt' 152 | # model.load_params(path_to_load_variables=this_ckt_file) 153 | # print ('Init model') 154 | # model = VAE(hyper_config) 155 | # if torch.cuda.is_available(): 156 | # model.cuda() 157 | 158 | # print('\nModel:', hyper_config,'\n') 159 | 160 | 161 | x_size = 784 162 | z_size = 50 163 | # batch_size = 20 164 | # k = 1 165 | #save params 166 | # start_at = 100 167 | # save_freq = 300 168 | # display_epoch = 3 169 | 170 | 171 | 172 | #small encoder 173 | # hyper_config = { 174 | # 'x_size': x_size, 175 | # 'z_size': z_size, 176 | # 'act_func': F.tanh,# F.relu, 177 | # 'encoder_arch': [[x_size,z_size*2]], 178 | # 'decoder_arch': [[z_size,200],[200,200],[200,x_size]], 179 | # 'q_dist': standard, #FFG_LN#,#hnf,#aux_nf,#flow1,#, 180 | # 'cuda': 1 181 | # } 182 | 183 | #no hidden decoder 184 | # hyper_config = { 185 | # 'x_size': x_size, 186 | # 'z_size': z_size, 187 | # 'act_func': F.tanh,# F.relu, 188 | # 'encoder_arch': 
[[x_size,200],[200,200],[200,z_size*2]], 189 | # 'decoder_arch': [[z_size,x_size]], 190 | # 'q_dist': standard, #FFG_LN#,#hnf,#aux_nf,#flow1,#, 191 | # 'cuda': 1 192 | # } 193 | 194 | # 2 hidden decoder 195 | # hyper_config = { 196 | # 'x_size': x_size, 197 | # 'z_size': z_size, 198 | # 'act_func': F.tanh,# F.relu, 199 | # 'encoder_arch': [[x_size,200],[200,200],[200,z_size*2]], 200 | # 'decoder_arch': [[z_size,200],[200,200],[200,x_size]], 201 | # 'q_dist': standard, #FFG_LN#,#hnf,#aux_nf,#flow1,#, 202 | # 'cuda': 1 203 | # } 204 | 205 | 206 | # 4 hidden decoder 207 | hyper_config = { 208 | 'x_size': x_size, 209 | 'z_size': z_size, 210 | 'act_func': F.tanh,# F.relu, 211 | 'encoder_arch': [[x_size,200],[200,200],[200,z_size*2]], 212 | 'decoder_arch': [[z_size,200],[200,200],[200,200],[200,200],[200,x_size]], 213 | 'q_dist': standard, #FFG_LN#,#hnf,#aux_nf,#flow1,#, 214 | 'cuda': 1 215 | } 216 | 217 | 218 | 219 | q = Gaussian(hyper_config) 220 | # q = Flow(hyper_config) 221 | hyper_config['q'] = q 222 | 223 | 224 | 225 | 226 | # Which gpu 227 | os.environ['CUDA_VISIBLE_DEVICES'] = '1' 228 | 229 | print ('Init model') 230 | model = VAE(hyper_config) 231 | if torch.cuda.is_available(): 232 | model.cuda() 233 | print('\nModel:', hyper_config,'\n') 234 | 235 | print (model.q_dist) 236 | # print (model.q_dist.q) 237 | print (model.generator) 238 | 239 | 240 | print ('Load params for decoder') 241 | # path_to_load_variables=home+'/Documents/tmp/inference_suboptimality/vae_generator_3280.pt' 242 | path_to_load_variables=home+'/Documents/tmp/inference_suboptimality/decoder_exps/hidden_layers_4_generator_3280.pt' 243 | # path_to_load_variables=home+'/Documents/tmp/inference_suboptimality/decoder_exps/hidden_layers_2_generator_3280.pt' 244 | # path_to_load_variables=home+'/Documents/tmp/inference_suboptimality/decoder_exps/hidden_layers_0_generator_3280.pt' 245 | 246 | model.generator.load_state_dict(torch.load(path_to_load_variables, map_location=lambda storage, loc: storage)) 247 | print ('loaded variables ' + path_to_load_variables) 248 | print () 249 | 250 | 251 | 252 | compute_local_opt = 1 253 | compute_amort = 0 254 | 255 | 256 | if compute_amort: 257 | 258 | print ('Load params for encoder') 259 | # path_to_load_variables=home+'/Documents/tmp/inference_suboptimality/vae_encoder_100.pt' 260 | # path_to_load_variables=home+'/Documents/tmp/inference_suboptimality/vae_encoder_3280.pt' 261 | # path_to_load_variables=home+'/Documents/tmp/inference_suboptimality/vae_smallencoder_encoder_3280.pt' 262 | # path_to_load_variables=home+'/Documents/tmp/inference_suboptimality/vae_regencoder_encoder_3280.pt' 263 | # path_to_load_variables=home+'/Documents/tmp/inference_suboptimality/vae_smallencoder_withflow_encoder_3280.pt' 264 | path_to_load_variables=home+'/Documents/tmp/inference_suboptimality/decoder_exps/hidden_layers_4_encoder_3280.pt' 265 | # path_to_load_variables=home+'/Documents/tmp/inference_suboptimality/decoder_exps/hidden_layers_2_encoder_3280.pt' 266 | # path_to_load_variables=home+'/Documents/tmp/inference_suboptimality/decoder_exps/hidden_layers_0_encoder_3280.pt' 267 | 268 | 269 | 270 | 271 | model.q_dist.load_state_dict(torch.load(path_to_load_variables, map_location=lambda storage, loc: storage)) 272 | print ('loaded variables ' + path_to_load_variables) 273 | 274 | 275 | 276 | 277 | 278 | 279 | 280 | 281 | ########################### 282 | # For each datapoint, compute L[q], L[q*], log p(x) 283 | 284 | # # log it 285 | # with open(experiment_log, "a") as myfile: 286 | # 
myfile.write('Checkpoint' +str(ckt)+'\n') 287 | 288 | # start_time = time.time() 289 | 290 | n_data = 2 #1000 #100 291 | 292 | vaes = [] 293 | iwaes = [] 294 | vaes_flex = [] 295 | iwaes_flex = [] 296 | 297 | 298 | 299 | if compute_local_opt: 300 | print ('optimizing local') 301 | for i in range(len(train_x[:n_data])): 302 | 303 | print (i) 304 | 305 | x = train_x[i] 306 | x = Variable(torch.from_numpy(x)).type(model.dtype) 307 | x = x.view(1,784) 308 | 309 | logposterior = lambda aa: model.logposterior_func2(x=x,z=aa) 310 | 311 | 312 | # # flex_model = aux_nf__(model, hyper_config) 313 | # # if torch.cuda.is_available(): 314 | # # flex_model.cuda() 315 | # # vae, iwae = flex_model.train_and_eval(logposterior=logposterior, model=model, x=x) 316 | 317 | 318 | # vae, iwae = optimize_local_expressive(logposterior, model, x) 319 | # print (vae.data.cpu().numpy(),iwae.data.cpu().numpy(),'flex') 320 | # vaes_flex.append(vae.data.cpu().numpy()) 321 | # iwaes_flex.append(iwae.data.cpu().numpy()) 322 | 323 | q_local = Gaussian(hyper_config) #, mean, logvar) 324 | # q_local = Flow(hyper_config).cuda()#, mean, logvar) 325 | 326 | # print (q_local) 327 | 328 | # vae, iwae = optimize_local_gaussian(logposterior, model, x) 329 | vae, iwae = optimize_local_q_dist(logposterior, hyper_config, x, q_local) 330 | print (vae.data.cpu().numpy(),iwae.data.cpu().numpy(),'reg') 331 | vaes.append(vae.data.cpu().numpy()) 332 | iwaes.append(iwae.data.cpu().numpy()) 333 | 334 | print() 335 | print ('opt vae',np.mean(vaes)) 336 | print ('opt iwae',np.mean(iwaes)) 337 | print() 338 | 339 | # print ('opt vae flex',np.mean(vaes_flex)) 340 | # print ('opt iwae flex',np.mean(iwaes_flex)) 341 | # print() 342 | 343 | if compute_amort: 344 | VAE_train = test_vae(model=model, data_x=train_x[:n_data], batch_size=np.minimum(n_data, 50), display=10, k=5000) 345 | IW_train = test(model=model, data_x=train_x[:n_data], batch_size=np.minimum(n_data, 50), display=10, k=5000) 346 | print ('amortized VAE',VAE_train) 347 | print ('amortized IW',IW_train) 348 | 349 | 350 | # print() 351 | # AIS_train = test_ais(model=model, data_x=train_x[:n_data], batch_size=n_data, display=2, k=50, n_intermediate_dists=500) 352 | # print ('AIS_train',AIS_train) 353 | 354 | 355 | 356 | # print() 357 | # print() 358 | # print ('AIS_train',AIS_train) 359 | # print() 360 | # print ('opt vae flex',np.mean(vaes_flex)) 361 | # # print() 362 | # print ('opt vae',np.mean(vaes)) 363 | # # print() 364 | # print ('amortized VAE',VAE_train) 365 | # print() 366 | 367 | 368 | 369 | 370 | 371 | 372 | 373 | 374 | 375 | 376 | 377 | -------------------------------------------------------------------------------- /test_different_dists/plot_8plots.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | import time 7 | import numpy as np 8 | import pickle 9 | from os.path import expanduser 10 | home = expanduser("~") 11 | 12 | import matplotlib 13 | matplotlib.use('agg') 14 | import matplotlib.pyplot as plt 15 | 16 | import csv 17 | 18 | import sys 19 | 20 | 21 | just_amort = 0 22 | 23 | 24 | 25 | 26 | 27 | # epochs=['100','1000','1900','2800'] 28 | # epochs=['100','1000','2200'] 29 | # epochs=['100','2800'] 30 | 31 | 32 | if just_amort: 33 | bounds = ['L_q', 'L_q_IWAE'] 34 | else: 35 | bounds = ['logpx', 'L_q_star', 'L_q'] 36 | 37 | 38 | 39 | 40 | # for epoch in epochs: 41 | # values['training'][epoch] = {} 42 | # values['validation'][epoch] = {} 43 | # for bound in bounds: 44 | # for epoch in epochs: 45 | # 
values['training'][epoch][bound] = {} 46 | # values['validation'][epoch][bound] = {} 47 | 48 | 49 | 50 | #read values 51 | # results_file = 'results_50' 52 | # results_file = 'results_2_fashion' 53 | # results_file = 'results_10_fashion' 54 | # results_file = 'results_100_fashion' 55 | 56 | # results_file = sys.argv[1] 57 | 58 | 59 | 60 | 61 | # file_1 = home+'/Documents/tmp/inference_suboptimality/over_training_exps/results_binarized_fashion3_Gaus.txt' 62 | # file_2 = home+'/Documents/tmp/inference_suboptimality/over_training_exps/results_binarized_fashion3_Gaus.txt' 63 | # file_3 = home+'/Documents/tmp/inference_suboptimality/over_training_exps/results_binarized_fashion3_Gaus.txt' 64 | # file_4 = home+'/Documents/tmp/inference_suboptimality/over_training_exps/results_binarized_fashion3_Gaus.txt' 65 | 66 | file_1 = 'ndata_101_binarized_fashion3_Gaus' 67 | file_2 = 'ndata_101_binarized_fashion3_Flow1' 68 | file_3 = 'ndata_101_binarized_fashion3_LD_Gaus' 69 | file_4 = 'ndata_101_binarized_fashion3_LE_Gaus' 70 | 71 | files = [file_1,file_2,file_3,file_4] 72 | 73 | 74 | 75 | rows = 2 76 | cols = 4 77 | 78 | legend=False 79 | 80 | fig = plt.figure(figsize=(12+cols,4+rows), facecolor='white') 81 | 82 | 83 | 84 | for file__i in range(len(files)): 85 | 86 | print (file__i, files[file__i]) 87 | 88 | 89 | epochs = [] 90 | 91 | values = {} 92 | values['training'] = {} 93 | values['validation'] = {} 94 | 95 | 96 | file_ = home+'/Documents/tmp/inference_suboptimality/over_training_exps/results_'+files[file__i]+'.txt' 97 | 98 | 99 | max_value = None 100 | min_value = None 101 | 102 | with open(file_, 'r') as f: 103 | reader = csv.reader(f, delimiter=' ') 104 | for row in reader: 105 | if len(row) and row[0] in ['training','validation']: 106 | # print (row) 107 | dataset = row[0] 108 | epoch = row[1] 109 | bound = row[2] 110 | value = row[3] 111 | 112 | if epoch not in values[dataset]: 113 | values[dataset][epoch] = {} 114 | if epoch not in epochs: 115 | epochs.append(epoch) 116 | print (epoch) 117 | 118 | values[dataset][epoch][bound] = value 119 | 120 | if max_value == None or float(value) > max_value: 121 | max_value = float(value) 122 | if min_value == None or float(value) < min_value: 123 | min_value = float(value) 124 | 125 | max_value += .2 126 | 127 | # max_value = -81 128 | # min_value = -110 129 | 130 | 131 | # print (values) 132 | 133 | #sort epochs 134 | # epochs.sort() 135 | 136 | # print (epochs) 137 | # fads 138 | 139 | #convert to list 140 | training_plot = {} 141 | for bound in bounds: 142 | values_to_plot = [] 143 | for epoch in epochs: 144 | if bound == 'logpx' and 'AIS' in values['training'][epoch]: 145 | # print (values['training'][epoch]['AIS'], values['training'][epoch]['logpx']) 146 | # fadsfa 147 | # value = max() 148 | value = (max(float(values['training'][epoch]['AIS']), float(values['training'][epoch]['logpx']))) 149 | else: 150 | value = float(values['training'][epoch][bound]) 151 | values_to_plot.append(value) 152 | 153 | training_plot[bound] = values_to_plot 154 | print (training_plot) 155 | # fadsa 156 | 157 | 158 | validation_plot = {} 159 | for bound in bounds: 160 | values_to_plot = [] 161 | for epoch in epochs: 162 | values_to_plot.append(float(values['validation'][epoch][bound])) 163 | validation_plot[bound] = values_to_plot 164 | print (validation_plot) 165 | 166 | 167 | epochs_float = [float(x) for x in epochs] 168 | 169 | 170 | 171 | 172 | # ylimits = [-110, -84] 173 | ylimits = [min_value, max_value] 174 | 175 | 176 | 177 | if file__i == 0: 178 | # pos = (0,0) 
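# NOTE: the 8-panel figure uses a rows=2 x cols=4 subplot grid; each results file is assigned the
# [row, col] of its Training-set panel here (Standard [0,0], Flow [0,2], Larger Decoder [1,0],
# Larger Encoder [1,2]), and its Validation-set panel is drawn one column to the right via
# pos[1] = pos[1]+1 below.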
179 | pos = [0,0] 180 | elif file__i == 1: 181 | pos = [0,2] 182 | elif file__i == 2: 183 | pos = [1,0] 184 | elif file__i == 3: 185 | pos = [1,2] 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | # Training set 195 | ax = plt.subplot2grid((rows,cols), pos, frameon=False) 196 | 197 | 198 | 199 | 200 | # ax.set_title(results_file,family='serif') 201 | if pos[0]==0 or 1: 202 | ax.set_title('Training Set',family='serif') 203 | 204 | if pos==[0,0]: 205 | ax.text(.98, 1.2, 'Standard', fontsize=13,transform=ax.transAxes,family='serif') 206 | if pos==[0,2]: 207 | ax.text(.98, 1.2, 'Flow', fontsize=13,transform=ax.transAxes,family='serif') 208 | if pos==[1,0]: 209 | ax.text(.9, 1.2, 'Larger Decoder', fontsize=13,transform=ax.transAxes,family='serif') 210 | if pos==[1,2]: 211 | ax.text(.9, 1.2, 'Larger Encoder', fontsize=13,transform=ax.transAxes,family='serif') 212 | 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | # for bound in bounds: 221 | # ax.plot(epochs_float,training_plot[bound]) #, label=legends[i], c=colors[i], ls=line_styles[i]) 222 | 223 | 224 | if not just_amort: 225 | ax.fill_between(epochs_float, training_plot['logpx'], training_plot['L_q_star']) 226 | ax.fill_between(epochs_float, training_plot['L_q_star'], training_plot['L_q']) 227 | else: 228 | ax.plot(epochs_float, training_plot['L_q']) 229 | ax.plot(epochs_float, training_plot['L_q_IWAE']) 230 | 231 | 232 | ax.xaxis.set_ticks([0,1000,2000,3000]) 233 | 234 | ax.set_ylim(ylimits) 235 | ax.grid(True, alpha=.5) 236 | 237 | 238 | 239 | 240 | pos[1] = pos[1]+1 241 | 242 | 243 | # Validation set 244 | ax = plt.subplot2grid((rows,cols), pos, frameon=False) 245 | 246 | if pos[0]==0 or 1: 247 | ax.set_title('Validation Set',family='serif') 248 | 249 | # for bound in bounds: 250 | # ax.plot(epochs_float,validation_plot[bound]) #, label=legends[i], c=colors[i], ls=line_styles[i]) 251 | 252 | ax.grid(True, alpha=.5) 253 | 254 | if not just_amort: 255 | ax.fill_between(epochs_float, validation_plot['logpx'], validation_plot['L_q_star']) 256 | ax.fill_between(epochs_float, validation_plot['L_q_star'], validation_plot['L_q']) 257 | else: 258 | ax.plot(epochs_float, validation_plot['L_q']) 259 | ax.plot(epochs_float, validation_plot['L_q_IWAE']) 260 | 261 | # plt.xticks([1000,2000,3000],size=6) 262 | ax.xaxis.set_ticks([0,1000,2000,3000]) 263 | 264 | ax.set_ylim(ylimits) 265 | 266 | 267 | 268 | # ax.set_yticks() 269 | 270 | # family='serif' 271 | # fontProperties = {'family':'serif'} 272 | # ax.set_xticklabels(ax.get_xticks(), fontProperties) 273 | 274 | 275 | 276 | # 277 | 278 | # ax.annotate('fdfafadf', xy=(0, 0), xytext=(.5, .5), textcoords='figure fraction') 279 | # ax.annotate('local max', xy=(3, 1), xycoords='data', 280 | # xytext=(0.8, 0.95), textcoords='axes fraction', 281 | # arrowprops=dict(facecolor='black', shrink=0.05), 282 | # horizontalalignment='right', verticalalignment='top', 283 | # ) 284 | 285 | 286 | # ax.annotate('fdfafadf', xytext=(.5, .5), xy=(.5, .5), textcoords='a fraction') 287 | 288 | 289 | 290 | 291 | 292 | # ax = plt.subplot2grid((rows,cols), [1,0], frameon=False, colspan=2, rowspan=1) 293 | # ax.set_title('fsdfa',family='serif') 294 | 295 | # plt.tight_layout(pad=0.4, w_pad=0.5, h_pad=1.0) 296 | plt.tight_layout(pad=3.0,h_pad=4.0) 297 | 298 | # fig.suptitle('Standard', fontsize=12, x=.3, y=.1, family='serif') 299 | # fig.suptitle('Flow', fontsize=12, x=.7, y=.1, family='serif') 300 | # fig.suptitle('Larger Decoder', fontsize=12, x=.3, y=.5, family='serif') 301 | # fig.suptitle('Larger Encoder', fontsize=12, 
x=.7, y=.5, family='serif') 302 | 303 | 304 | name_file = home+'/Documents/tmp/inference_suboptimality/over_training_exps/8plots_withAIS.png' 305 | name_file2 = home+'/Documents/tmp/inference_suboptimality/over_training_exps/8plots_withAIS.pdf' 306 | # name_file = home+'/Documents/tmp/plot.png' 307 | plt.savefig(name_file) 308 | plt.savefig(name_file2) 309 | print ('Saved fig', name_file) 310 | print ('Saved fig', name_file2) 311 | 312 | 313 | 314 | print ('DONE') 315 | # fdsa 316 | 317 | 318 | 319 | 320 | 321 | 322 | 323 | 324 | 325 | 326 | 327 | 328 | 329 | 330 | 331 | 332 | 333 | 334 | 335 | 336 | 337 | 338 | 339 | 340 | 341 | 342 | 343 | 344 | 345 | 346 | 347 | 348 | 349 | # # # models = [standard,flow1,aux_nf]#,hnf] 350 | # # # models = [standard,standard_large_encoder]#, aux_nf aux_large_encoder]#,hnf] 351 | # # models = [standard,aux_nf]#, aux_nf aux_large_encoder]#,hnf] 352 | 353 | 354 | # # # model_names = ['standard','flow1','aux_nf','hnf'] 355 | # # # model_names = ['VAE','NF','Aux+NF']#,'HNF'] 356 | # # # model_names = ['FFG','Flow']#,'HNF'] 357 | # # # model_names = ['FFG','Flow']#,'HNF'] 358 | # # model_names = ['FFG','Flow']# 'aux_nf','aux_large_encoder']#,'HNF'] 359 | 360 | 361 | 362 | 363 | 364 | # # # legends = ['IW train', 'IW test', 'AIS train', 'AIS test'] 365 | # # # legends = ['VAE train', 'VAE test', 'IW train', 'IW test', 'AIS train', 'AIS test'] 366 | 367 | # # legends = ['VAE train', 'VAE test', 'IW train', 'IW test', 'AIS train', 'AIS test'] 368 | 369 | 370 | # # colors = ['blue', 'blue', 'green', 'green', 'red', 'red'] 371 | 372 | # # line_styles = [':', '-', ':', '-', ':', '-'] 373 | 374 | 375 | 376 | 377 | # rows = 1 378 | # cols = 1 379 | 380 | # legend=False 381 | 382 | # fig = plt.figure(figsize=(2+cols,2+rows), facecolor='white') 383 | 384 | # # Get y-axis limits 385 | # min_ = None 386 | # max_ = None 387 | # for m in range(len(models)): 388 | # for i in range(len(legends)): 389 | # if i == 1: 390 | # continue 391 | # this_min = np.min(models[m][i]) 392 | # this_max = np.max(models[m][i]) 393 | # if min_ ==None or this_min < min_: 394 | # min_ = this_min 395 | # if max_ ==None or this_max > max_: 396 | # max_ = this_max 397 | 398 | # min_ -= .1 399 | # max_ += .1 400 | # # print (min_) 401 | # # print (max_) 402 | # ylimits = [min_, max_] 403 | # xlimits = [x[0], x[-1]] 404 | 405 | # # fasd 406 | 407 | # # ax.plot(x,hnf_ais, label='hnf_ais') 408 | # # ax.set_yticks([]) 409 | # # ax.set_xticks([]) 410 | # # if samp_i==0: ax.annotate('Sample', xytext=(.3, 1.1), xy=(0, 1), textcoords='axes fraction') 411 | 412 | # for m in range(len(models)): 413 | # ax = plt.subplot2grid((rows,cols), (0,m), frameon=False) 414 | # for i in range(len(legends)): 415 | # if i == 1: 416 | # continue 417 | # ax.set_title(model_names[m],family='serif') 418 | # ax.plot(x,models[m][i], label=legends[i], c=colors[i], ls=line_styles[i]) 419 | # plt.legend(fontsize=6) 420 | # # ax.set(adjustable='box-forced', aspect='equal') 421 | # plt.yticks(size=6) 422 | # # plt.xticks(x,size=6) 423 | # plt.xticks([400,1300,2200,3100],size=6) 424 | 425 | # # ax.set_xlim(xlimits) 426 | # ax.set_ylim(ylimits) 427 | # ax.set_xlim(xlimits) 428 | 429 | # ax.set_xlabel('Epochs',size=6) 430 | # if m==0: 431 | # ax.set_ylabel('Log-Likelihood',size=6) 432 | 433 | 434 | # ax.grid(True, alpha=.1) 435 | 436 | 437 | # # m+=1 438 | # # ax = plt.subplot2grid((rows,cols), (0,m), frameon=False) 439 | # # ax.set_title('AIS_test') 440 | # # for m in range(len(models)): 441 | # # ax.plot(x,models[m][3], 
label=model_names[m]) 442 | # # plt.legend(fontsize=4) 443 | # # plt.yticks(size=6) 444 | 445 | 446 | 447 | 448 | 449 | # # plt.gca().set_aspect('equal', adjustable='box') 450 | # name_file = home+'/Documents/tmp/plot.png' 451 | # plt.savefig(name_file) 452 | # print ('Saved fig', name_file) 453 | 454 | # name_file = home+'/Documents/tmp/plot.eps' 455 | # plt.savefig(name_file) 456 | # print ('Saved fig', name_file) 457 | 458 | 459 | # name_file = home+'/Documents/tmp/plot.pdf' 460 | # plt.savefig(name_file) 461 | # print ('Saved fig', name_file) 462 | 463 | 464 | 465 | 466 | 467 | 468 | 469 | 470 | 471 | 472 | -------------------------------------------------------------------------------- /test_set_inference_exp/compute_gaps.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | import numpy as np 7 | import gzip 8 | import time 9 | import pickle 10 | 11 | from os.path import expanduser 12 | home = expanduser("~") 13 | 14 | import sys, os 15 | sys.path.insert(0, '../models') 16 | sys.path.insert(0, '../models/utils') 17 | 18 | 19 | import matplotlib 20 | matplotlib.use('Agg') 21 | import matplotlib.pyplot as plt 22 | 23 | 24 | import torch 25 | from torch.autograd import Variable 26 | import torch.utils.data 27 | import torch.optim as optim 28 | import torch.nn as nn 29 | import torch.nn.functional as F 30 | 31 | 32 | from vae_2 import VAE 33 | 34 | 35 | # from approx_posteriors_v6 import standard 36 | from inference_net import standard 37 | 38 | # from ais3 import test_ais 39 | 40 | 41 | # from optimize_local import optimize_local_gaussian 42 | 43 | 44 | from optimize_local_q import optimize_local_q_dist 45 | 46 | 47 | 48 | 49 | from distributions import Gaussian 50 | from distributions import Flow 51 | 52 | 53 | 54 | 55 | 56 | 57 | def test_vae(model, data_x, batch_size, display, k): 58 | 59 | time_ = time.time() 60 | elbos = [] 61 | data_index= 0 62 | for i in range(int(len(data_x)/ batch_size)): 63 | 64 | batch = data_x[data_index:data_index+batch_size] 65 | data_index += batch_size 66 | 67 | batch = Variable(torch.from_numpy(batch)).type(model.dtype) 68 | 69 | elbo, logpxz, logqz = model.forward2(batch, k=k) 70 | 71 | elbos.append(elbo.data[0]) 72 | 73 | # if i%display==0: 74 | # print (i,len(data_x)/ batch_size, np.mean(elbos)) 75 | 76 | mean_ = np.mean(elbos) 77 | # print(mean_, 'T:', time.time()-time_) 78 | 79 | return mean_#, time.time()-time_ 80 | 81 | 82 | 83 | 84 | 85 | def test(model, data_x, batch_size, display, k): 86 | 87 | time_ = time.time() 88 | elbos = [] 89 | data_index= 0 90 | for i in range(int(len(data_x)/ batch_size)): 91 | 92 | batch = data_x[data_index:data_index+batch_size] 93 | data_index += batch_size 94 | 95 | batch = Variable(torch.from_numpy(batch)).type(model.dtype) 96 | 97 | elbo, logpxz, logqz = model(batch, k=k) 98 | 99 | elbos.append(elbo.data[0]) 100 | 101 | # if i%display==0: 102 | # print (i,len(data_x)/ batch_size, np.mean(elbos)) 103 | 104 | mean_ = np.mean(elbos) 105 | # print(mean_, 'T:', time.time()-time_) 106 | 107 | return mean_#, time.time()-time_ 108 | 109 | 110 | 111 | 112 | 113 | ########################### 114 | # Load data 115 | 116 | # print ('Loading data') 117 | # with open(home+'/Documents/MNIST_data/mnist.pkl','rb') as f: 118 | # mnist_data = pickle.load(f, encoding='latin1') 119 | # train_x = mnist_data[0][0] 120 | # valid_x = mnist_data[1][0] 121 | # test_x = mnist_data[2][0] 122 | # train_x = np.concatenate([train_x, valid_x], axis=0) 123 | # print (train_x.shape) 124 | 125 
| #Load data 126 | print ('Loading data' ) 127 | data_location = home + '/Documents/MNIST_data/' 128 | # with open(data_location + 'binarized_mnist.pkl', 'rb') as f: 129 | # train_x, valid_x, test_x = pickle.load(f) 130 | with open(data_location + 'binarized_mnist.pkl', 'rb') as f: 131 | train_x, valid_x, test_x = pickle.load(f, encoding='latin1') 132 | print ('Train', train_x.shape) 133 | print ('Valid', valid_x.shape) 134 | print ('Test', test_x.shape) 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | ########################### 147 | # Load model 148 | 149 | 150 | 151 | # this_ckt_file = path_to_save_variables + str(ckt) + '.pt' 152 | # model.load_params(path_to_load_variables=this_ckt_file) 153 | # print ('Init model') 154 | # model = VAE(hyper_config) 155 | # if torch.cuda.is_available(): 156 | # model.cuda() 157 | 158 | # print('\nModel:', hyper_config,'\n') 159 | 160 | 161 | x_size = 784 162 | z_size = 50 163 | # batch_size = 20 164 | # k = 1 165 | #save params 166 | # start_at = 100 167 | # save_freq = 300 168 | # display_epoch = 3 169 | 170 | 171 | 172 | #small encoder 173 | # hyper_config = { 174 | # 'x_size': x_size, 175 | # 'z_size': z_size, 176 | # 'act_func': F.tanh,# F.relu, 177 | # 'encoder_arch': [[x_size,z_size*2]], 178 | # 'decoder_arch': [[z_size,200],[200,200],[200,x_size]], 179 | # 'q_dist': standard, #FFG_LN#,#hnf,#aux_nf,#flow1,#, 180 | # 'cuda': 1 181 | # } 182 | 183 | #no hidden decoder 184 | # hyper_config = { 185 | # 'x_size': x_size, 186 | # 'z_size': z_size, 187 | # 'act_func': F.tanh,# F.relu, 188 | # 'encoder_arch': [[x_size,200],[200,200],[200,z_size*2]], 189 | # 'decoder_arch': [[z_size,x_size]], 190 | # 'q_dist': standard, #FFG_LN#,#hnf,#aux_nf,#flow1,#, 191 | # 'cuda': 1 192 | # } 193 | 194 | # 2 hidden decoder 195 | hyper_config = { 196 | 'x_size': x_size, 197 | 'z_size': z_size, 198 | 'act_func': F.tanh,# F.relu, 199 | 'encoder_arch': [[x_size,200],[200,200],[200,z_size*2]], 200 | 'decoder_arch': [[z_size,200],[200,200],[200,x_size]], 201 | 'q_dist': standard, #FFG_LN#,#hnf,#aux_nf,#flow1,#, 202 | 'cuda': 1 203 | } 204 | 205 | 206 | # # 4 hidden decoder 207 | # hyper_config = { 208 | # 'x_size': x_size, 209 | # 'z_size': z_size, 210 | # 'act_func': F.tanh,# F.relu, 211 | # 'encoder_arch': [[x_size,200],[200,200],[200,z_size*2]], 212 | # 'decoder_arch': [[z_size,200],[200,200],[200,200],[200,200],[200,x_size]], 213 | # 'q_dist': standard, #FFG_LN#,#hnf,#aux_nf,#flow1,#, 214 | # 'cuda': 1 215 | # } 216 | 217 | 218 | 219 | q = Gaussian(hyper_config) 220 | # q = Flow(hyper_config) 221 | hyper_config['q'] = q 222 | 223 | 224 | 225 | 226 | # Which gpu 227 | os.environ['CUDA_VISIBLE_DEVICES'] = '1' 228 | 229 | print ('Init model') 230 | model = VAE(hyper_config) 231 | if torch.cuda.is_available(): 232 | model.cuda() 233 | print('\nModel:', hyper_config,'\n') 234 | 235 | print (model.q_dist) 236 | # print (model.q_dist.q) 237 | print (model.generator) 238 | 239 | 240 | print ('Load params for decoder') 241 | path_to_load_variables=home+'/Documents/tmp/inference_suboptimality/vae_generator_3280.pt' 242 | # path_to_load_variables=home+'/Documents/tmp/inference_suboptimality/decoder_exps/hidden_layers_4_generator_3280.pt' 243 | # path_to_load_variables=home+'/Documents/tmp/inference_suboptimality/decoder_exps/hidden_layers_2_generator_3280.pt' 244 | # path_to_load_variables=home+'/Documents/tmp/inference_suboptimality/decoder_exps/hidden_layers_0_generator_3280.pt' 245 | 246 | 
model.generator.load_state_dict(torch.load(path_to_load_variables, map_location=lambda storage, loc: storage)) 247 | print ('loaded variables ' + path_to_load_variables) 248 | print () 249 | 250 | 251 | 252 | compute_local_opt = 1 253 | compute_amort = 1 254 | 255 | compute_local_opt_test = 1 256 | compute_amort_test = 1 257 | 258 | 259 | 260 | 261 | 262 | 263 | 264 | if compute_amort: 265 | 266 | print ('Load params for encoder') 267 | # path_to_load_variables=home+'/Documents/tmp/inference_suboptimality/vae_encoder_100.pt' 268 | path_to_load_variables=home+'/Documents/tmp/inference_suboptimality/vae_encoder_3280.pt' 269 | # path_to_load_variables=home+'/Documents/tmp/inference_suboptimality/vae_smallencoder_encoder_3280.pt' 270 | # path_to_load_variables=home+'/Documents/tmp/inference_suboptimality/vae_regencoder_encoder_3280.pt' 271 | # path_to_load_variables=home+'/Documents/tmp/inference_suboptimality/vae_smallencoder_withflow_encoder_3280.pt' 272 | # path_to_load_variables=home+'/Documents/tmp/inference_suboptimality/decoder_exps/hidden_layers_4_encoder_3280.pt' 273 | # path_to_load_variables=home+'/Documents/tmp/inference_suboptimality/decoder_exps/hidden_layers_2_encoder_3280.pt' 274 | # path_to_load_variables=home+'/Documents/tmp/inference_suboptimality/decoder_exps/hidden_layers_0_encoder_3280.pt' 275 | 276 | 277 | 278 | 279 | model.q_dist.load_state_dict(torch.load(path_to_load_variables, map_location=lambda storage, loc: storage)) 280 | print ('loaded variables ' + path_to_load_variables) 281 | 282 | 283 | 284 | 285 | 286 | 287 | 288 | 289 | ########################### 290 | # For each datapoint, compute L[q], L[q*], log p(x) 291 | 292 | # # log it 293 | # with open(experiment_log, "a") as myfile: 294 | # myfile.write('Checkpoint' +str(ckt)+'\n') 295 | 296 | # start_time = time.time() 297 | 298 | n_data = 100 #1000 #100 299 | 300 | vaes = [] 301 | iwaes = [] 302 | vaes_flex = [] 303 | iwaes_flex = [] 304 | 305 | 306 | 307 | if compute_local_opt: 308 | print ('optimizing local') 309 | for i in range(len(train_x[:n_data])): 310 | 311 | print (i) 312 | 313 | x = train_x[i] 314 | x = Variable(torch.from_numpy(x)).type(model.dtype) 315 | x = x.view(1,784) 316 | 317 | logposterior = lambda aa: model.logposterior_func2(x=x,z=aa) 318 | 319 | 320 | # # flex_model = aux_nf__(model, hyper_config) 321 | # # if torch.cuda.is_available(): 322 | # # flex_model.cuda() 323 | # # vae, iwae = flex_model.train_and_eval(logposterior=logposterior, model=model, x=x) 324 | 325 | 326 | # vae, iwae = optimize_local_expressive(logposterior, model, x) 327 | # print (vae.data.cpu().numpy(),iwae.data.cpu().numpy(),'flex') 328 | # vaes_flex.append(vae.data.cpu().numpy()) 329 | # iwaes_flex.append(iwae.data.cpu().numpy()) 330 | 331 | q_local = Gaussian(hyper_config) #, mean, logvar) 332 | # q_local = Flow(hyper_config).cuda()#, mean, logvar) 333 | 334 | # print (q_local) 335 | 336 | # vae, iwae = optimize_local_gaussian(logposterior, model, x) 337 | vae, iwae = optimize_local_q_dist(logposterior, hyper_config, x, q_local) 338 | print (vae.data.cpu().numpy(),iwae.data.cpu().numpy(),'reg') 339 | vaes.append(vae.data.cpu().numpy()) 340 | iwaes.append(iwae.data.cpu().numpy()) 341 | 342 | print() 343 | print ('opt vae',np.mean(vaes)) 344 | print ('opt iwae',np.mean(iwaes)) 345 | print() 346 | 347 | # print ('opt vae flex',np.mean(vaes_flex)) 348 | # print ('opt iwae flex',np.mean(iwaes_flex)) 349 | # print() 350 | 351 | if compute_amort: 352 | VAE_train = test_vae(model=model, data_x=train_x[:n_data], 
batch_size=np.minimum(n_data, 50), display=10, k=5000) 353 | IW_train = test(model=model, data_x=train_x[:n_data], batch_size=np.minimum(n_data, 50), display=10, k=5000) 354 | print ('amortized VAE',VAE_train) 355 | print ('amortized IW',IW_train) 356 | 357 | 358 | # print() 359 | # AIS_train = test_ais(model=model, data_x=train_x[:n_data], batch_size=n_data, display=2, k=50, n_intermediate_dists=500) 360 | # print ('AIS_train',AIS_train) 361 | 362 | 363 | 364 | # print() 365 | # print() 366 | # print ('AIS_train',AIS_train) 367 | # print() 368 | # print ('opt vae flex',np.mean(vaes_flex)) 369 | # # print() 370 | # print ('opt vae',np.mean(vaes)) 371 | # # print() 372 | # print ('amortized VAE',VAE_train) 373 | # print() 374 | 375 | 376 | 377 | 378 | 379 | 380 | 381 | # TEST SET 382 | print ('TEST SET') 383 | 384 | vaes_test = [] 385 | iwaes_test = [] 386 | # vaes_flex = [] 387 | # iwaes_flex = [] 388 | 389 | 390 | 391 | if compute_local_opt_test: 392 | print ('optimizing local') 393 | for i in range(len(test_x[:n_data])): 394 | 395 | print (i) 396 | 397 | x = test_x[i] 398 | x = Variable(torch.from_numpy(x)).type(model.dtype) 399 | x = x.view(1,784) 400 | 401 | logposterior = lambda aa: model.logposterior_func2(x=x,z=aa) 402 | 403 | 404 | # # flex_model = aux_nf__(model, hyper_config) 405 | # # if torch.cuda.is_available(): 406 | # # flex_model.cuda() 407 | # # vae, iwae = flex_model.train_and_eval(logposterior=logposterior, model=model, x=x) 408 | 409 | 410 | # vae, iwae = optimize_local_expressive(logposterior, model, x) 411 | # print (vae.data.cpu().numpy(),iwae.data.cpu().numpy(),'flex') 412 | # vaes_flex.append(vae.data.cpu().numpy()) 413 | # iwaes_flex.append(iwae.data.cpu().numpy()) 414 | 415 | q_local = Gaussian(hyper_config) #, mean, logvar) 416 | # q_local = Flow(hyper_config).cuda()#, mean, logvar) 417 | 418 | # print (q_local) 419 | 420 | # vae, iwae = optimize_local_gaussian(logposterior, model, x) 421 | vae, iwae = optimize_local_q_dist(logposterior, hyper_config, x, q_local) 422 | print (vae.data.cpu().numpy(),iwae.data.cpu().numpy(),'reg') 423 | vaes_test.append(vae.data.cpu().numpy()) 424 | iwaes_test.append(iwae.data.cpu().numpy()) 425 | 426 | print() 427 | print ('opt vae',np.mean(vaes_test)) 428 | print ('opt iwae',np.mean(iwaes_test)) 429 | print() 430 | 431 | # print ('opt vae flex',np.mean(vaes_flex)) 432 | # print ('opt iwae flex',np.mean(iwaes_flex)) 433 | # print() 434 | 435 | if compute_amort_test: 436 | VAE_test = test_vae(model=model, data_x=test_x[:n_data], batch_size=np.minimum(n_data, 50), display=10, k=5000) 437 | IW_test = test(model=model, data_x=test_x[:n_data], batch_size=np.minimum(n_data, 50), display=10, k=5000) 438 | print ('amortized VAE',VAE_test) 439 | print ('amortized IW',IW_test) 440 | 441 | 442 | 443 | 444 | print('TRAIN') 445 | print ('opt vae',np.mean(vaes)) 446 | print ('opt iwae',np.mean(iwaes)) 447 | print ('amortized VAE',VAE_train) 448 | print ('amortized IW',IW_train) 449 | 450 | 451 | print('TEST') 452 | print ('opt vae',np.mean(vaes_test)) 453 | print ('opt iwae',np.mean(iwaes_test)) 454 | print ('amortized VAE',VAE_test) 455 | print ('amortized IW',IW_test) 456 | 457 | 458 | 459 | 460 | 461 | 462 | -------------------------------------------------------------------------------- /models/pytorch_vae_v6.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | # adding layer norm 6 | 7 | import numpy as np 8 | import pickle 9 | # import cPickle as pickle 10 | from os.path import 
expanduser 11 | home = expanduser("~") 12 | import time 13 | import sys 14 | sys.path.insert(0, 'utils') 15 | 16 | import torch 17 | from torch.autograd import Variable 18 | import torch.utils.data 19 | import torch.optim as optim 20 | import torch.nn as nn 21 | import torch.nn.functional as F 22 | 23 | from utils import lognormal2 as lognormal 24 | from utils import log_bernoulli 25 | 26 | from utils import LayerNorm 27 | 28 | 29 | from ais import test_ais 30 | 31 | from approx_posteriors_v5 import standard 32 | from approx_posteriors_v5 import flow1 33 | from approx_posteriors_v5 import aux_nf 34 | from approx_posteriors_v5 import hnf 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | class VAE(nn.Module): 44 | def __init__(self, hyper_config, seed=1): 45 | super(VAE, self).__init__() 46 | 47 | torch.manual_seed(seed) 48 | 49 | 50 | self.z_size = hyper_config['z_size'] 51 | self.x_size = hyper_config['x_size'] 52 | self.act_func = hyper_config['act_func'] 53 | 54 | self.q_dist = hyper_config['q_dist'](self, hyper_config=hyper_config) 55 | 56 | # for aaa in self.q_dist.parameters(): 57 | # # print (aaa) 58 | # print (aaa.size()) 59 | 60 | # # fasdfs 61 | 62 | 63 | if torch.cuda.is_available(): 64 | self.dtype = torch.cuda.FloatTensor 65 | self.q_dist.cuda() 66 | else: 67 | self.dtype = torch.FloatTensor 68 | 69 | 70 | #Decoder 71 | self.decoder_weights = [] 72 | self.layer_norms = [] 73 | for i in range(len(hyper_config['decoder_arch'])): 74 | self.decoder_weights.append(nn.Linear(hyper_config['decoder_arch'][i][0], hyper_config['decoder_arch'][i][1])) 75 | 76 | if i != len(hyper_config['decoder_arch'])-1: 77 | self.layer_norms.append(LayerNorm(hyper_config['decoder_arch'][i][1])) 78 | 79 | count =1 80 | for i in range(len(self.decoder_weights)): 81 | self.add_module(str(count), self.decoder_weights[i]) 82 | count+=1 83 | 84 | if i != len(hyper_config['decoder_arch'])-1: 85 | self.add_module(str(count), self.layer_norms[i]) 86 | count+=1 87 | 88 | # self.hyper_config = hyper_config 89 | 90 | # # See params 91 | # print('all') 92 | # for aaa in self.parameters(): 93 | # # print (aaa) 94 | # print (aaa.size()) 95 | # fsadfsa 96 | 97 | 98 | def decode(self, z): 99 | k = z.size()[0] 100 | B = z.size()[1] 101 | z = z.view(-1, self.z_size) 102 | 103 | out = z 104 | for i in range(len(self.decoder_weights)-1): 105 | # out = self.act_func(self.decoder_weights[i](out)) 106 | out = self.act_func(self.layer_norms[i].forward(self.decoder_weights[i](out))) 107 | out = self.decoder_weights[-1](out) 108 | 109 | x = out.view(k, B, self.x_size) 110 | return x 111 | 112 | 113 | def forward(self, x, k, warmup=1.): 114 | 115 | self.B = x.size()[0] #batch size 116 | self.zeros = Variable(torch.zeros(self.B, self.z_size).type(self.dtype)) 117 | 118 | self.logposterior = lambda aa: lognormal(aa, self.zeros, self.zeros) + log_bernoulli(self.decode(aa), x) 119 | 120 | z, logqz = self.q_dist.forward(k, x, self.logposterior) 121 | 122 | logpxz = self.logposterior(z) 123 | 124 | #Compute elbo 125 | elbo = logpxz - (warmup*logqz) #[P,B] 126 | if k>1: 127 | max_ = torch.max(elbo, 0)[0] #[B] 128 | elbo = torch.log(torch.mean(torch.exp(elbo - max_), 0)) + max_ #[B] 129 | 130 | elbo = torch.mean(elbo) #[1] 131 | logpxz = torch.mean(logpxz) #[1] 132 | logqz = torch.mean(logqz) 133 | 134 | return elbo, logpxz, logqz 135 | 136 | 137 | def sample_q(self, x, k): 138 | 139 | self.B = x.size()[0] #batch size 140 | self.zeros = Variable(torch.zeros(self.B, self.z_size).type(self.dtype)) 141 | 142 | self.logposterior = lambda aa: 
lognormal(aa, self.zeros, self.zeros) + log_bernoulli(self.decode(aa), x) 143 | 144 | z, logqz = self.q_dist.forward(k=k, x=x, logposterior=self.logposterior) 145 | 146 | return z 147 | 148 | 149 | def logposterior_func(self, x, z): 150 | self.B = x.size()[0] #batch size 151 | self.zeros = Variable(torch.zeros(self.B, self.z_size).type(self.dtype)) 152 | 153 | # print (x) #[B,X] 154 | # print(z) #[P,Z] 155 | z = Variable(z).type(self.dtype) 156 | z = z.view(-1,self.B,self.z_size) 157 | return lognormal(z, self.zeros, self.zeros) + log_bernoulli(self.decode(z), x) 158 | 159 | 160 | 161 | def logposterior_func2(self, x, z): 162 | self.B = x.size()[0] #batch size 163 | self.zeros = Variable(torch.zeros(self.B, self.z_size).type(self.dtype)) 164 | 165 | # print (x) #[B,X] 166 | # print(z) #[P,Z] 167 | # z = Variable(z).type(self.dtype) 168 | z = z.view(-1,self.B,self.z_size) 169 | 170 | # print (z) 171 | return lognormal(z, self.zeros, self.zeros) + log_bernoulli(self.decode(z), x) 172 | 173 | 174 | 175 | def forward2(self, x, k): 176 | 177 | self.B = x.size()[0] #batch size 178 | self.zeros = Variable(torch.zeros(self.B, self.z_size).type(self.dtype)) 179 | 180 | self.logposterior = lambda aa: lognormal(aa, self.zeros, self.zeros) + log_bernoulli(self.decode(aa), x) 181 | 182 | z, logqz = self.q_dist.forward(k, x, self.logposterior) 183 | 184 | logpxz = self.logposterior(z) 185 | 186 | #Compute elbo 187 | elbo = logpxz - logqz #[P,B] 188 | # if k>1: 189 | # max_ = torch.max(elbo, 0)[0] #[B] 190 | # elbo = torch.log(torch.mean(torch.exp(elbo - max_), 0)) + max_ #[B] 191 | 192 | elbo = torch.mean(elbo) #[1] 193 | logpxz = torch.mean(logpxz) #[1] 194 | logqz = torch.mean(logqz) 195 | 196 | return elbo, logpxz, logqz 197 | 198 | 199 | 200 | 201 | def forward3_prior(self, x, k): 202 | 203 | self.B = x.size()[0] #batch size 204 | self.zeros = Variable(torch.zeros(self.B, self.z_size).type(self.dtype)) 205 | 206 | self.logposterior = lambda aa: log_bernoulli(self.decode(aa), x) #+ lognormal(aa, self.zeros, self.zeros) 207 | 208 | # z, logqz = self.q_dist.forward(k, x, self.logposterior) 209 | 210 | z = Variable(torch.FloatTensor(k, self.B, self.z_size).normal_().type(self.dtype)) #[P,B,Z] 211 | 212 | logpxz = self.logposterior(z) 213 | 214 | #Compute elbo 215 | elbo = logpxz #- logqz #[P,B] 216 | if k>1: 217 | max_ = torch.max(elbo, 0)[0] #[B] 218 | elbo = torch.log(torch.mean(torch.exp(elbo - max_), 0)) + max_ #[B] 219 | 220 | elbo = torch.mean(elbo) #[1] 221 | # logpxz = torch.mean(logpxz) #[1] 222 | # logqz = torch.mean(logqz) 223 | 224 | return elbo#, logpxz, logqz 225 | 226 | 227 | 228 | 229 | 230 | 231 | 232 | 233 | 234 | 235 | 236 | def train(self, train_x, k, epochs, batch_size, display_epoch, learning_rate): 237 | 238 | optimizer = optim.Adam(self.parameters(), lr=learning_rate) 239 | time_ = time.time() 240 | n_data = len(train_x) 241 | arr = np.array(range(n_data)) 242 | 243 | for epoch in range(1, epochs + 1): 244 | 245 | #shuffle 246 | np.random.shuffle(arr) 247 | train_x = train_x[arr] 248 | 249 | data_index= 0 250 | for i in range(int(n_data/batch_size)): 251 | batch = train_x[data_index:data_index+batch_size] 252 | data_index += batch_size 253 | 254 | batch = Variable(torch.from_numpy(batch)).type(self.dtype) 255 | optimizer.zero_grad() 256 | 257 | elbo, logpxz, logqz = self.forward(batch, k=k) 258 | 259 | loss = -(elbo) 260 | loss.backward() 261 | optimizer.step() 262 | 263 | 264 | if epoch%display_epoch==0: 265 | print ('Train Epoch: {}/{}'.format(epoch, epochs), 266 | 
'LL:{:.3f}'.format(-loss.data[0]), 267 | 'logpxz:{:.3f}'.format(logpxz.data[0]), 268 | # 'logpz:{:.3f}'.format(logpz.data[0]), 269 | 'logqz:{:.3f}'.format(logqz.data[0]), 270 | 'T:{:.2f}'.format(time.time()-time_), 271 | ) 272 | 273 | time_ = time.time() 274 | 275 | 276 | 277 | 278 | 279 | def test(self, data_x, batch_size, display, k): 280 | 281 | time_ = time.time() 282 | elbos = [] 283 | data_index= 0 284 | for i in range(int(len(data_x)/ batch_size)): 285 | 286 | batch = data_x[data_index:data_index+batch_size] 287 | data_index += batch_size 288 | 289 | batch = Variable(torch.from_numpy(batch)).type(self.dtype) 290 | 291 | elbo, logpxz, logqz = self(batch, k=k) 292 | 293 | elbos.append(elbo.data[0]) 294 | 295 | if i%display==0: 296 | print (i,len(data_x)/ batch_size, np.mean(elbos)) 297 | 298 | mean_ = np.mean(elbos) 299 | print(mean_, 'T:', time.time()-time_) 300 | 301 | 302 | 303 | 304 | 305 | def load_params(self, path_to_load_variables=''): 306 | # model.load_state_dict(torch.load(path_to_load_variables)) 307 | self.load_state_dict(torch.load(path_to_load_variables, map_location=lambda storage, loc: storage)) 308 | print ('loaded variables ' + path_to_load_variables) 309 | 310 | 311 | def save_params(self, path_to_save_variables=''): 312 | torch.save(self.state_dict(), path_to_save_variables) 313 | print ('saved variables ' + path_to_save_variables) 314 | 315 | 316 | 317 | 318 | 319 | 320 | 321 | 322 | 323 | 324 | 325 | 326 | 327 | 328 | # if __name__ == "__main__": 329 | 330 | # load_params = 0 331 | # train_ = 1 332 | # eval_IW = 1 333 | # eval_AIS = 0 334 | 335 | # print ('Loading data') 336 | # with open(home+'/Documents/MNIST_data/mnist.pkl','rb') as f: 337 | # mnist_data = pickle.load(f, encoding='latin1') 338 | 339 | # train_x = mnist_data[0][0] 340 | # valid_x = mnist_data[1][0] 341 | # test_x = mnist_data[2][0] 342 | 343 | # train_x = np.concatenate([train_x, valid_x], axis=0) 344 | 345 | # print (train_x.shape) 346 | 347 | # x_size = 784 348 | # z_size = 50 349 | 350 | # hyper_config = { 351 | # 'x_size': x_size, 352 | # 'z_size': z_size, 353 | # 'act_func': F.tanh,# F.relu, 354 | # 'encoder_arch': [[x_size,200],[200,200],[200,z_size*2]], 355 | # 'decoder_arch': [[z_size,200],[200,200],[200,x_size]], 356 | # 'q_dist': hnf,#aux_nf,#flow1,#standard,#, #, #, #,#, #,# , 357 | # 'n_flows': 2, 358 | # 'qv_arch': [[x_size,200],[200,200],[200,z_size*2]], 359 | # 'qz_arch': [[x_size+z_size,200],[200,200],[200,z_size*2]], 360 | # 'rv_arch': [[x_size+z_size,200],[200,200],[200,z_size*2]], 361 | # 'flow_hidden_size': 100 362 | # } 363 | 364 | 365 | # model = VAE(hyper_config) 366 | 367 | # if torch.cuda.is_available(): 368 | # model.cuda() 369 | 370 | 371 | 372 | # #Train params 373 | # learning_rate = .0001 374 | # batch_size = 100 375 | # epochs = 3000 376 | # display_epoch = 2 377 | # k = 1 378 | 379 | # path_to_load_variables='' 380 | # # path_to_load_variables=home+'/Documents/tmp/pytorch_bvae.pt' 381 | # path_to_save_variables=home+'/Documents/tmp/pytorch_vae'+str(epochs)+'.pt' 382 | # # path_to_save_variables='' 383 | 384 | 385 | 386 | # if load_params: 387 | # print ('\nLoading parameters') 388 | # model.load_params(path_to_save_variables) 389 | 390 | # if train_: 391 | 392 | # print('\nTraining') 393 | # print('k='+str(k), 'lr='+str(learning_rate), 'batch_size='+str(batch_size)) 394 | # print('\nModel:', hyper_config,'\n') 395 | # model.train(train_x=train_x, k=k, epochs=epochs, batch_size=batch_size, 396 | # display_epoch=display_epoch, learning_rate=learning_rate) 397 | # 
model.save_params(path_to_save_variables) 398 | 399 | 400 | # if eval_IW: 401 | # k_IW = 2000 402 | # batch_size = 20 403 | # print('\nTesting with IW, Train set[:10000], B'+str(batch_size)+' k'+str(k_IW)) 404 | # model.test(data_x=train_x[:10000], batch_size=batch_size, display=100, k=k_IW) 405 | 406 | # print('\nTesting with IW, Test set, B'+str(batch_size)+' k'+str(k_IW)) 407 | # model.test(data_x=test_x, batch_size=batch_size, display=100, k=k_IW) 408 | 409 | # if eval_AIS: 410 | # k_AIS = 10 411 | # batch_size = 100 412 | # n_intermediate_dists = 100 413 | # print('\nTesting with AIS, Train set[:10000], B'+str(batch_size)+' k'+str(k_AIS)+' intermediates'+str(n_intermediate_dists)) 414 | # test_ais(model, data_x=train_x[:10000], batch_size=batch_size, display=10, k=k_AIS, n_intermediate_dists=n_intermediate_dists) 415 | 416 | # print('\nTesting with AIS, Test set, B'+str(batch_size)+' k'+str(k_AIS)+' intermediates'+str(n_intermediate_dists)) 417 | # test_ais(model, data_x=test_x, batch_size=batch_size, display=10, k=k_AIS, n_intermediate_dists=n_intermediate_dists) 418 | 419 | 420 | 421 | 422 | 423 | 424 | 425 | 426 | 427 | 428 | 429 | 430 | 431 | 432 | 433 | 434 | 435 | 436 | 437 | 438 | -------------------------------------------------------------------------------- /models/vae_1.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | # adding layer norm 6 | 7 | import numpy as np 8 | import pickle 9 | # import cPickle as pickle 10 | from os.path import expanduser 11 | home = expanduser("~") 12 | import time 13 | import sys 14 | sys.path.insert(0, 'utils') 15 | 16 | import torch 17 | from torch.autograd import Variable 18 | import torch.utils.data 19 | import torch.optim as optim 20 | import torch.nn as nn 21 | import torch.nn.functional as F 22 | 23 | from utils import lognormal2 as lognormal 24 | from utils import log_bernoulli 25 | 26 | # from utils import LayerNorm 27 | 28 | 29 | # from ais import test_ais 30 | 31 | # from approx_posteriors_v5 import standard 32 | # from approx_posteriors_v5 import flow1 33 | # from approx_posteriors_v5 import aux_nf 34 | # from approx_posteriors_v5 import hnf 35 | 36 | # from approx_posteriors_v6 import standard 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | class VAE(nn.Module): 45 | def __init__(self, hyper_config, seed=1): 46 | super(VAE, self).__init__() 47 | 48 | torch.manual_seed(seed) 49 | 50 | 51 | self.z_size = hyper_config['z_size'] 52 | self.x_size = hyper_config['x_size'] 53 | self.act_func = hyper_config['act_func'] 54 | 55 | self.q_dist = hyper_config['q_dist'](self, hyper_config=hyper_config) 56 | 57 | # for aaa in self.q_dist.parameters(): 58 | # # print (aaa) 59 | # print (aaa.size()) 60 | 61 | # # fasdfs 62 | 63 | 64 | if torch.cuda.is_available(): 65 | self.dtype = torch.cuda.FloatTensor 66 | self.q_dist.cuda() 67 | else: 68 | self.dtype = torch.FloatTensor 69 | 70 | 71 | #Decoder 72 | self.decoder_weights = [] 73 | self.layer_norms = [] 74 | for i in range(len(hyper_config['decoder_arch'])): 75 | self.decoder_weights.append(nn.Linear(hyper_config['decoder_arch'][i][0], hyper_config['decoder_arch'][i][1])) 76 | 77 | # if i != len(hyper_config['decoder_arch'])-1: 78 | # self.layer_norms.append(LayerNorm(hyper_config['decoder_arch'][i][1])) 79 | 80 | count =1 81 | for i in range(len(self.decoder_weights)): 82 | self.add_module(str(count), self.decoder_weights[i]) 83 | count+=1 84 | 85 | # if i != len(hyper_config['decoder_arch'])-1: 86 | # self.add_module(str(count), self.layer_norms[i]) 
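# NOTE: unlike pytorch_vae_v6.py above, which registers a LayerNorm after each hidden decoder
# layer, vae_1.py leaves layer norm commented out here and in decode(), so its decoder is plain
# nn.Linear layers followed by hyper_config['act_func'].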
87 | # count+=1 88 | 89 | # self.hyper_config = hyper_config 90 | 91 | # # See params 92 | # print('all') 93 | # for aaa in self.parameters(): 94 | # # print (aaa) 95 | # print (aaa.size()) 96 | # fsadfsa 97 | 98 | 99 | def decode(self, z): 100 | k = z.size()[0] 101 | B = z.size()[1] 102 | z = z.view(-1, self.z_size) 103 | 104 | out = z 105 | for i in range(len(self.decoder_weights)-1): 106 | out = self.act_func(self.decoder_weights[i](out)) 107 | # out = self.act_func(self.layer_norms[i].forward(self.decoder_weights[i](out))) 108 | out = self.decoder_weights[-1](out) 109 | 110 | x = out.view(k, B, self.x_size) 111 | return x 112 | 113 | 114 | def forward(self, x, k, warmup=1.): 115 | 116 | self.B = x.size()[0] #batch size 117 | self.zeros = Variable(torch.zeros(self.B, self.z_size).type(self.dtype)) 118 | 119 | self.logposterior = lambda aa: lognormal(aa, self.zeros, self.zeros) + log_bernoulli(self.decode(aa), x) 120 | 121 | z, logqz = self.q_dist.forward(k, x, self.logposterior) 122 | 123 | logpxz = self.logposterior(z) 124 | 125 | #Compute elbo 126 | elbo = logpxz - (warmup*logqz) #[P,B] 127 | if k>1: 128 | max_ = torch.max(elbo, 0)[0] #[B] 129 | elbo = torch.log(torch.mean(torch.exp(elbo - max_), 0)) + max_ #[B] 130 | 131 | elbo = torch.mean(elbo) #[1] 132 | logpxz = torch.mean(logpxz) #[1] 133 | logqz = torch.mean(logqz) 134 | 135 | return elbo, logpxz, logqz 136 | 137 | 138 | def sample_q(self, x, k): 139 | 140 | self.B = x.size()[0] #batch size 141 | self.zeros = Variable(torch.zeros(self.B, self.z_size).type(self.dtype)) 142 | 143 | self.logposterior = lambda aa: lognormal(aa, self.zeros, self.zeros) + log_bernoulli(self.decode(aa), x) 144 | 145 | z, logqz = self.q_dist.forward(k=k, x=x, logposterior=self.logposterior) 146 | 147 | return z 148 | 149 | 150 | def logposterior_func(self, x, z): 151 | self.B = x.size()[0] #batch size 152 | self.zeros = Variable(torch.zeros(self.B, self.z_size).type(self.dtype)) 153 | 154 | # print (x) #[B,X] 155 | # print(z) #[P,Z] 156 | z = Variable(z).type(self.dtype) 157 | z = z.view(-1,self.B,self.z_size) 158 | return lognormal(z, self.zeros, self.zeros) + log_bernoulli(self.decode(z), x) 159 | 160 | 161 | 162 | def logposterior_func2(self, x, z): 163 | self.B = x.size()[0] #batch size 164 | self.zeros = Variable(torch.zeros(self.B, self.z_size).type(self.dtype)) 165 | 166 | # print (x) #[B,X] 167 | # print(z) #[P,Z] 168 | # z = Variable(z).type(self.dtype) 169 | z = z.view(-1,self.B,self.z_size) 170 | 171 | # print (z) 172 | return lognormal(z, self.zeros, self.zeros) + log_bernoulli(self.decode(z), x) 173 | 174 | 175 | 176 | def forward2(self, x, k): 177 | 178 | self.B = x.size()[0] #batch size 179 | self.zeros = Variable(torch.zeros(self.B, self.z_size).type(self.dtype)) 180 | 181 | self.logposterior = lambda aa: lognormal(aa, self.zeros, self.zeros) + log_bernoulli(self.decode(aa), x) 182 | 183 | z, logqz = self.q_dist.forward(k, x, self.logposterior) 184 | 185 | logpxz = self.logposterior(z) 186 | 187 | #Compute elbo 188 | elbo = logpxz - logqz #[P,B] 189 | # if k>1: 190 | # max_ = torch.max(elbo, 0)[0] #[B] 191 | # elbo = torch.log(torch.mean(torch.exp(elbo - max_), 0)) + max_ #[B] 192 | 193 | elbo = torch.mean(elbo) #[1] 194 | logpxz = torch.mean(logpxz) #[1] 195 | logqz = torch.mean(logqz) 196 | 197 | return elbo, logpxz, logqz 198 | 199 | 200 | 201 | 202 | def forward3_prior(self, x, k): 203 | 204 | self.B = x.size()[0] #batch size 205 | self.zeros = Variable(torch.zeros(self.B, self.z_size).type(self.dtype)) 206 | 207 | self.logposterior 
= lambda aa: log_bernoulli(self.decode(aa), x) #+ lognormal(aa, self.zeros, self.zeros) 208 | 209 | # z, logqz = self.q_dist.forward(k, x, self.logposterior) 210 | 211 | z = Variable(torch.FloatTensor(k, self.B, self.z_size).normal_().type(self.dtype)) #[P,B,Z] 212 | 213 | logpxz = self.logposterior(z) 214 | 215 | #Compute elbo 216 | elbo = logpxz #- logqz #[P,B] 217 | if k>1: 218 | max_ = torch.max(elbo, 0)[0] #[B] 219 | elbo = torch.log(torch.mean(torch.exp(elbo - max_), 0)) + max_ #[B] 220 | 221 | elbo = torch.mean(elbo) #[1] 222 | # logpxz = torch.mean(logpxz) #[1] 223 | # logqz = torch.mean(logqz) 224 | 225 | return elbo#, logpxz, logqz 226 | 227 | 228 | 229 | 230 | 231 | 232 | 233 | 234 | 235 | 236 | 237 | def train(self, train_x, k, epochs, batch_size, display_epoch, learning_rate): 238 | 239 | optimizer = optim.Adam(self.parameters(), lr=learning_rate) 240 | time_ = time.time() 241 | n_data = len(train_x) 242 | arr = np.array(range(n_data)) 243 | 244 | for epoch in range(1, epochs + 1): 245 | 246 | #shuffle 247 | np.random.shuffle(arr) 248 | train_x = train_x[arr] 249 | 250 | data_index= 0 251 | for i in range(int(n_data/batch_size)): 252 | batch = train_x[data_index:data_index+batch_size] 253 | data_index += batch_size 254 | 255 | batch = Variable(torch.from_numpy(batch)).type(self.dtype) 256 | optimizer.zero_grad() 257 | 258 | elbo, logpxz, logqz = self.forward(batch, k=k) 259 | 260 | loss = -(elbo) 261 | loss.backward() 262 | optimizer.step() 263 | 264 | 265 | if epoch%display_epoch==0: 266 | print ('Train Epoch: {}/{}'.format(epoch, epochs), 267 | 'LL:{:.3f}'.format(-loss.data[0]), 268 | 'logpxz:{:.3f}'.format(logpxz.data[0]), 269 | # 'logpz:{:.3f}'.format(logpz.data[0]), 270 | 'logqz:{:.3f}'.format(logqz.data[0]), 271 | 'T:{:.2f}'.format(time.time()-time_), 272 | ) 273 | 274 | time_ = time.time() 275 | 276 | 277 | 278 | 279 | 280 | def test(self, data_x, batch_size, display, k): 281 | 282 | time_ = time.time() 283 | elbos = [] 284 | data_index= 0 285 | for i in range(int(len(data_x)/ batch_size)): 286 | 287 | batch = data_x[data_index:data_index+batch_size] 288 | data_index += batch_size 289 | 290 | batch = Variable(torch.from_numpy(batch)).type(self.dtype) 291 | 292 | elbo, logpxz, logqz = self(batch, k=k) 293 | 294 | elbos.append(elbo.data[0]) 295 | 296 | if i%display==0: 297 | print (i,len(data_x)/ batch_size, np.mean(elbos)) 298 | 299 | mean_ = np.mean(elbos) 300 | print(mean_, 'T:', time.time()-time_) 301 | 302 | 303 | 304 | 305 | 306 | def load_params(self, path_to_load_variables=''): 307 | # model.load_state_dict(torch.load(path_to_load_variables)) 308 | self.load_state_dict(torch.load(path_to_load_variables, map_location=lambda storage, loc: storage)) 309 | print ('loaded variables ' + path_to_load_variables) 310 | 311 | 312 | def save_params(self, path_to_save_variables=''): 313 | torch.save(self.state_dict(), path_to_save_variables) 314 | print ('saved variables ' + path_to_save_variables) 315 | 316 | 317 | 318 | 319 | 320 | 321 | 322 | 323 | 324 | 325 | 326 | 327 | 328 | 329 | # if __name__ == "__main__": 330 | 331 | # load_params = 0 332 | # train_ = 1 333 | # eval_IW = 1 334 | # eval_AIS = 0 335 | 336 | # print ('Loading data') 337 | # with open(home+'/Documents/MNIST_data/mnist.pkl','rb') as f: 338 | # mnist_data = pickle.load(f, encoding='latin1') 339 | 340 | # train_x = mnist_data[0][0] 341 | # valid_x = mnist_data[1][0] 342 | # test_x = mnist_data[2][0] 343 | 344 | # train_x = np.concatenate([train_x, valid_x], axis=0) 345 | 346 | # print (train_x.shape) 347 | 
348 | # x_size = 784 349 | # z_size = 50 350 | 351 | # hyper_config = { 352 | # 'x_size': x_size, 353 | # 'z_size': z_size, 354 | # 'act_func': F.tanh,# F.relu, 355 | # 'encoder_arch': [[x_size,200],[200,200],[200,z_size*2]], 356 | # 'decoder_arch': [[z_size,200],[200,200],[200,x_size]], 357 | # 'q_dist': hnf,#aux_nf,#flow1,#standard,#, #, #, #,#, #,# , 358 | # 'n_flows': 2, 359 | # 'qv_arch': [[x_size,200],[200,200],[200,z_size*2]], 360 | # 'qz_arch': [[x_size+z_size,200],[200,200],[200,z_size*2]], 361 | # 'rv_arch': [[x_size+z_size,200],[200,200],[200,z_size*2]], 362 | # 'flow_hidden_size': 100 363 | # } 364 | 365 | 366 | # model = VAE(hyper_config) 367 | 368 | # if torch.cuda.is_available(): 369 | # model.cuda() 370 | 371 | 372 | 373 | # #Train params 374 | # learning_rate = .0001 375 | # batch_size = 100 376 | # epochs = 3000 377 | # display_epoch = 2 378 | # k = 1 379 | 380 | # path_to_load_variables='' 381 | # # path_to_load_variables=home+'/Documents/tmp/pytorch_bvae.pt' 382 | # path_to_save_variables=home+'/Documents/tmp/pytorch_vae'+str(epochs)+'.pt' 383 | # # path_to_save_variables='' 384 | 385 | 386 | 387 | # if load_params: 388 | # print ('\nLoading parameters') 389 | # model.load_params(path_to_save_variables) 390 | 391 | # if train_: 392 | 393 | # print('\nTraining') 394 | # print('k='+str(k), 'lr='+str(learning_rate), 'batch_size='+str(batch_size)) 395 | # print('\nModel:', hyper_config,'\n') 396 | # model.train(train_x=train_x, k=k, epochs=epochs, batch_size=batch_size, 397 | # display_epoch=display_epoch, learning_rate=learning_rate) 398 | # model.save_params(path_to_save_variables) 399 | 400 | 401 | # if eval_IW: 402 | # k_IW = 2000 403 | # batch_size = 20 404 | # print('\nTesting with IW, Train set[:10000], B'+str(batch_size)+' k'+str(k_IW)) 405 | # model.test(data_x=train_x[:10000], batch_size=batch_size, display=100, k=k_IW) 406 | 407 | # print('\nTesting with IW, Test set, B'+str(batch_size)+' k'+str(k_IW)) 408 | # model.test(data_x=test_x, batch_size=batch_size, display=100, k=k_IW) 409 | 410 | # if eval_AIS: 411 | # k_AIS = 10 412 | # batch_size = 100 413 | # n_intermediate_dists = 100 414 | # print('\nTesting with AIS, Train set[:10000], B'+str(batch_size)+' k'+str(k_AIS)+' intermediates'+str(n_intermediate_dists)) 415 | # test_ais(model, data_x=train_x[:10000], batch_size=batch_size, display=10, k=k_AIS, n_intermediate_dists=n_intermediate_dists) 416 | 417 | # print('\nTesting with AIS, Test set, B'+str(batch_size)+' k'+str(k_AIS)+' intermediates'+str(n_intermediate_dists)) 418 | # test_ais(model, data_x=test_x, batch_size=batch_size, display=10, k=k_AIS, n_intermediate_dists=n_intermediate_dists) 419 | 420 | 421 | 422 | 423 | 424 | 425 | 426 | 427 | 428 | 429 | 430 | 431 | 432 | 433 | 434 | 435 | 436 | 437 | 438 | 439 | -------------------------------------------------------------------------------- /test_different_dists/train_mnist.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | import numpy as np 5 | import gzip 6 | import time 7 | import pickle 8 | 9 | from os.path import expanduser 10 | home = expanduser("~") 11 | 12 | import sys, os 13 | sys.path.insert(0, '../models') 14 | sys.path.insert(0, '../models/utils') 15 | 16 | 17 | import matplotlib 18 | matplotlib.use('Agg') 19 | import matplotlib.pyplot as plt 20 | 21 | 22 | import torch 23 | from torch.autograd import Variable 24 | import torch.utils.data 25 | import torch.optim as optim 26 | import torch.nn as nn 27 | import torch.nn.functional as F 28 | 29 
| # from ais2 import test_ais 30 | 31 | # from pytorch_vae_v6 import VAE 32 | 33 | # from vae_1 import VAE 34 | from vae_2 import VAE 35 | 36 | 37 | # from approx_posteriors_v6 import FFG_LN 38 | # from approx_posteriors_v6 import ANF_LN 39 | # import argparse 40 | # from approx_posteriors_v6 import standard 41 | from inference_net import standard 42 | 43 | from distributions import Gaussian 44 | from distributions import Flow 45 | from distributions import HNF 46 | from distributions import Flow1 47 | 48 | 49 | 50 | 51 | 52 | gpu_to_use = sys.argv[1] 53 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu_to_use #'1' 54 | 55 | q_name = sys.argv[2] 56 | 57 | hnf = 0 58 | if q_name == 'Gaus': 59 | q = Gaussian 60 | elif q_name == 'Flow': 61 | q = Flow 62 | elif q_name == 'Flow1': 63 | q = Flow1 64 | elif q_name == 'HNF': 65 | q = HNF 66 | hnf = 1 67 | else: 68 | raise ValueError('unknown q_name: ' + q_name) 69 | 70 | 71 | 72 | 73 | # path_to_save_variables=home+'/Documents/tmp/inference_suboptimality/fashion_params/10k_binarized_fashion2_SSE_'+q_name #.pt' 74 | 75 | 76 | 77 | # path_to_save_variables=home+'/Documents/tmp/inference_suboptimality/fashion_params/binarized_fashion3_LE_'+q_name #.pt' 78 | 79 | path_to_save_variables=home+'/Documents/tmp/inference_suboptimality/fashion_params/binarized_fashion3_'+q_name #.pt' 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | #FASHION: load gzipped Fashion-MNIST idx images as flat 784-dim arrays 89 | def load_mnist(path, kind='train'): 90 | 91 | images_path = os.path.join(path, 92 | '%s-images-idx3-ubyte.gz' 93 | % kind) 94 | 95 | with gzip.open(images_path, 'rb') as imgpath: 96 | images = np.frombuffer(imgpath.read(), dtype=np.uint8, 97 | offset=16).reshape(-1, 784) 98 | 99 | return images#, labels 100 | 101 | 102 | path = home+'/Documents/fashion_MNIST' 103 | 104 | train_x = load_mnist(path=path) 105 | test_x = load_mnist(path=path, kind='t10k') 106 | 107 | train_x = train_x / 255. 108 | test_x = test_x / 255. 
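# NOTE: the pixels were just rescaled to [0, 1]; the block below then statically binarizes both
# splits at a fixed 0.5 threshold (one fixed binarization of the data, rather than re-sampling
# binary pixels every epoch).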
109 | 110 | #binarize 111 | train_x = (train_x > .5).astype(float) 112 | test_x = (test_x > .5).astype(float) 113 | 114 | 115 | print (train_x.shape) 116 | print (test_x.shape) 117 | print () 118 | 119 | valid_x = train_x[50000:] 120 | train_x = train_x[:50000] 121 | # train_x = train_x[:10000] #small dataset 122 | 123 | 124 | print (train_x.shape) 125 | print (valid_x.shape) 126 | print (test_x.shape) 127 | print () 128 | 129 | 130 | # fdsa 131 | 132 | 133 | 134 | 135 | # print (train_x) 136 | # fads 137 | 138 | # print (np.max(train_x)) 139 | # print (test_x[3]) 140 | # fsda 141 | 142 | 143 | # print ('Loading data') 144 | # with open(home+'/Documents/MNIST_data/mnist.pkl','rb') as f: 145 | # mnist_data = pickle.load(f, encoding='latin1') 146 | # train_x = mnist_data[0][0] 147 | # valid_x = mnist_data[1][0] 148 | # test_x = mnist_data[2][0] 149 | # train_x = np.concatenate([train_x, valid_x], axis=0) 150 | # print (train_x.shape) 151 | 152 | 153 | 154 | # #Load data mnist 155 | # print ('Loading data' ) 156 | # data_location = home + '/Documents/MNIST_data/' 157 | # # with open(data_location + 'binarized_mnist.pkl', 'rb') as f: 158 | # # train_x, valid_x, test_x = pickle.load(f) 159 | # with open(data_location + 'binarized_mnist.pkl', 'rb') as f: 160 | # train_x, valid_x, test_x = pickle.load(f, encoding='latin1') 161 | # print ('Train', train_x.shape) 162 | # print ('Valid', valid_x.shape) 163 | # print ('Test', test_x.shape) 164 | 165 | 166 | # print (np.max(train_x)) 167 | 168 | # fadad 169 | 170 | 171 | 172 | 173 | def train_encoder_and_decoder(model, train_x, test_x, k, batch_size, 174 | start_at, save_freq, display_epoch, 175 | path_to_save_variables): 176 | 177 | train_y = torch.from_numpy(np.zeros(len(train_x))) 178 | train_x = torch.from_numpy(train_x).float().type(model.dtype) 179 | 180 | train_ = torch.utils.data.TensorDataset(train_x, train_y) 181 | train_loader = torch.utils.data.DataLoader(train_, batch_size=batch_size, shuffle=True) 182 | 183 | #IWAE paper training strategy 184 | time_ = time.time() 185 | total_epochs = 0 186 | 187 | i_max = 7 188 | # i_max = 6 189 | 190 | warmup_over_epochs = 100. 191 | # warmup_over_epochs = 20. 192 | 193 | 194 | all_params = [] 195 | for aaa in model.q_dist.parameters(): 196 | all_params.append(aaa) 197 | for aaa in model.generator.parameters(): 198 | all_params.append(aaa) 199 | # print (len(all_params), 'number of params') 200 | 201 | print (model.q_dist) 202 | # print (model.q_dist.q) 203 | print (model.generator) 204 | 205 | # fads 206 | 207 | 208 | for i in range(0,i_max+1): 209 | 210 | lr = .001 * 10**(-i/float(i_max)) 211 | print (i, 'LR:', lr) 212 | 213 | # # optimizer = optim.Adam(model.parameters(), lr=lr) 214 | # print (model.q_dist) 215 | # print (model.generator) 216 | # print (model.q_dist.parameters()) 217 | # print (model.generator.parameters()) 218 | 219 | # print ('Encoder') 220 | # for aaa in model.q_dist.parameters(): 221 | # # print (aaa) 222 | # print (aaa.size()) 223 | # print ('Decoder') 224 | # for aaa in model.generator.parameters(): 225 | # # print (aaa) 226 | # print (aaa.size()) 227 | # # fasdfs 228 | # fads 229 | 230 | 231 | optimizer = optim.Adam(all_params, lr=lr) 232 | 233 | epochs = 3**(i) 234 | 235 | for epoch in range(1, epochs + 1): 236 | 237 | for batch_idx, (data, target) in enumerate(train_loader): 238 | 239 | batch = Variable(data)#.type(model.dtype) 240 | 241 | optimizer.zero_grad() 242 | 243 | warmup = total_epochs/warmup_over_epochs 244 | if warmup > 1.: 245 | warmup = 1. 
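                # Learning-rate schedule (IWAE paper strategy): stage i uses
                # lr = 1e-3 * 10**(-i/7) for 3**i epochs, i = 0..7, so
                # sum(3**i for i in range(8)) == 3280 epochs in total.
                # `warmup` ramps linearly from 0 to 1 over `warmup_over_epochs` epochs and
                # multiplies log q(z|x) inside model.forward, so the posterior/KL part of
                # the ELBO is annealed in gradually (KL warm-up).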
246 | 247 | elbo, logpxz, logqz = model.forward(batch, k=k, warmup=warmup) 248 | 249 | loss = -(elbo) 250 | loss.backward() 251 | optimizer.step() 252 | 253 | total_epochs += 1 254 | 255 | 256 | if total_epochs%display_epoch==0: 257 | print ('Train Epoch: {}/{}'.format(epoch, epochs), 258 | 'total_epochs {}'.format(total_epochs), 259 | 'LL:{:.3f}'.format(-loss.data[0]), 260 | 'logpxz:{:.3f}'.format(logpxz.data[0]), 261 | 'logqz:{:.3f}'.format(logqz.data[0]), 262 | 'warmup:{:.3f}'.format(warmup), 263 | 'T:{:.2f}'.format(time.time()-time_), 264 | ) 265 | time_ = time.time() 266 | 267 | 268 | if total_epochs >= start_at and (total_epochs-start_at)%save_freq==0: 269 | 270 | # save params 271 | save_file = path_to_save_variables+'_encoder_'+str(total_epochs)+'.pt' 272 | torch.save(model.q_dist.state_dict(), save_file) 273 | print ('saved variables ' + save_file) 274 | save_file = path_to_save_variables+'_generator_'+str(total_epochs)+'.pt' 275 | torch.save(model.generator.state_dict(), save_file) 276 | print ('saved variables ' + save_file) 277 | 278 | 279 | 280 | # save params 281 | save_file = path_to_save_variables+'_encoder_'+str(total_epochs)+'.pt' 282 | torch.save(model.q_dist.state_dict(), save_file) 283 | print ('saved variables ' + save_file) 284 | save_file = path_to_save_variables+'_generator_'+str(total_epochs)+'.pt' 285 | torch.save(model.generator.state_dict(), save_file) 286 | print ('saved variables ' + save_file) 287 | 288 | 289 | print ('done training') 290 | 291 | 292 | 293 | 294 | 295 | 296 | 297 | 298 | 299 | 300 | 301 | 302 | 303 | # Which gpu 304 | # os.environ['CUDA_VISIBLE_DEVICES'] = '1' 305 | 306 | 307 | x_size = 784 308 | z_size = 20 309 | batch_size = 50 310 | k = 1 311 | #save params 312 | # start_at = 50 313 | # save_freq = 250 314 | start_at = 100 315 | save_freq = 300 316 | 317 | display_epoch = 3 318 | 319 | hyper_config = { 320 | 'x_size': x_size, 321 | 'z_size': z_size, 322 | 'act_func': F.elu, #F.tanh,# F.relu, 323 | 'encoder_arch': [[x_size,200],[200,200],[200,z_size*2]], 324 | 'decoder_arch': [[z_size,200],[200,200],[200,x_size]], 325 | 'q_dist': standard, #FFG_LN#,#hnf,#aux_nf,#flow1,#, 326 | 'cuda': 1, 327 | 'hnf': hnf 328 | } 329 | 330 | 331 | # #LB 332 | # hyper_config = { 333 | # 'x_size': x_size, 334 | # 'z_size': z_size, 335 | # 'act_func': F.elu, #F.tanh,# F.relu, 336 | # 'encoder_arch': [[x_size,500],[500,500],[500,z_size*2]], 337 | # 'decoder_arch': [[z_size,500],[500,500],[500,x_size]], 338 | # 'q_dist': standard, #FFG_LN#,#hnf,#aux_nf,#flow1,#, 339 | # 'cuda': 1, 340 | # 'hnf': hnf 341 | # } 342 | 343 | 344 | # #LD 345 | # hyper_config = { 346 | # 'x_size': x_size, 347 | # 'z_size': z_size, 348 | # 'act_func': F.elu, #F.tanh,# F.relu, 349 | # 'encoder_arch': [[x_size,200],[200,200],[200,z_size*2]], 350 | # 'decoder_arch': [[z_size,500],[500,500],[500,500],[500,x_size]], 351 | # 'q_dist': standard, #FFG_LN#,#hnf,#aux_nf,#flow1,#, 352 | # 'cuda': 1, 353 | # 'hnf': hnf 354 | # } 355 | 356 | 357 | 358 | 359 | # #LE 360 | # hyper_config = { 361 | # 'x_size': x_size, 362 | # 'z_size': z_size, 363 | # 'act_func': F.elu, #F.tanh,# F.relu, 364 | # 'encoder_arch': [[x_size,500],[500,500],[500,500],[500,z_size*2]], 365 | # 'decoder_arch': [[z_size,200],[200,200],[200,x_size]], 366 | # 'q_dist': standard, #FFG_LN#,#hnf,#aux_nf,#flow1,#, 367 | # 'cuda': 1, 368 | # 'hnf': hnf 369 | # } 370 | 371 | 372 | 373 | 374 | # #SE 375 | # hyper_config = { 376 | # 'x_size': x_size, 377 | # 'z_size': z_size, 378 | # 'act_func': F.elu, #F.tanh,# F.relu, 379 | # 
'encoder_arch': [[x_size,100],[100,z_size*2]], 380 | # 'decoder_arch': [[z_size,200],[200,200],[200,x_size]], 381 | # 'q_dist': standard, #FFG_LN#,#hnf,#aux_nf,#flow1,#, 382 | # 'cuda': 1, 383 | # 'hnf': hnf 384 | # } 385 | 386 | 387 | 388 | 389 | # #SSE 390 | # hyper_config = { 391 | # 'x_size': x_size, 392 | # 'z_size': z_size, 393 | # 'act_func': F.elu, #F.tanh,# F.relu, 394 | # 'encoder_arch': [[x_size,z_size*2]], 395 | # 'decoder_arch': [[z_size,200],[200,200],[200,x_size]], 396 | # 'q_dist': standard, #FFG_LN#,#hnf,#aux_nf,#flow1,#, 397 | # 'cuda': 1, 398 | # 'hnf': hnf 399 | # } 400 | 401 | 402 | 403 | # #LE 404 | # hyper_config = { 405 | # 'x_size': x_size, 406 | # 'z_size': z_size, 407 | # 'act_func': F.elu, #F.tanh,# F.relu, 408 | # 'encoder_arch': [[x_size,500],[500,500],[500,z_size*2]], 409 | # 'decoder_arch': [[z_size,200],[200,200],[200,x_size]], 410 | # 'q_dist': standard, #FFG_LN#,#hnf,#aux_nf,#flow1,#, 411 | # 'cuda': 1 412 | # } 413 | 414 | # hyper_config = { 415 | # 'x_size': x_size, 416 | # 'z_size': z_size, 417 | # 'act_func': F.tanh,# F.relu, 418 | # 'encoder_arch': [[x_size,200],[200,200],[200,z_size*2]], 419 | # 'decoder_arch': [[z_size,x_size]], 420 | # 'q_dist': standard, #FFG_LN#,#hnf,#aux_nf,#flow1,#, 421 | # 'cuda': 1 422 | # } 423 | 424 | 425 | # hyper_config = { 426 | # 'x_size': x_size, 427 | # 'z_size': z_size, 428 | # 'act_func': F.tanh,# F.relu, 429 | # 'encoder_arch': [[x_size,200],[200,200],[200,z_size*2]], 430 | # 'decoder_arch': [[z_size,200],[200,200],[200,200],[200,200],[200,x_size]], 431 | # 'q_dist': standard, #FFG_LN#,#hnf,#aux_nf,#flow1,#, 432 | # 'cuda': 1 433 | # } 434 | 435 | 436 | # q = Gaussian(hyper_config) 437 | # # q = Flow(hyper_config) 438 | hyper_config['q'] = q(hyper_config) 439 | 440 | 441 | print ('Init model') 442 | model = VAE(hyper_config) 443 | if torch.cuda.is_available(): 444 | model.cuda() 445 | 446 | print('\nModel:', hyper_config,'\n') 447 | 448 | 449 | 450 | 451 | # path_to_load_variables='' 452 | # path_to_save_variables=home+'/Documents/tmp/inference_suboptimality/fashion_params/LE_binarized_fashion' #.pt' 453 | # path_to_save_variables=home+'/Documents/tmp/inference_suboptimality/fashion_params/binarized_fashion_' #.pt' 454 | 455 | # path_to_save_variables=home+'/Documents/tmp/pytorch_vae'+str(epochs)+'.pt' 456 | # path_to_save_variables=this_dir+'/params_'+model_name+'_' 457 | # path_to_save_variables='' 458 | 459 | 460 | 461 | print('\nTraining') 462 | # train_lr_schedule(model=model, train_x=train_x, test_x=test_x, k=k, batch_size=batch_size, 463 | # start_at=start_at, save_freq=save_freq, display_epoch=display_epoch, 464 | # path_to_save_variables=path_to_save_variables) 465 | 466 | 467 | train_encoder_and_decoder(model=model, train_x=train_x, test_x=test_x, k=k, batch_size=batch_size, 468 | start_at=start_at, save_freq=save_freq, display_epoch=display_epoch, 469 | path_to_save_variables=path_to_save_variables) 470 | 471 | print ('Done.') 472 | 473 | 474 | 475 | 476 | 477 | 478 | 479 | 480 | 481 | 482 | 483 | 484 | 485 | 486 | 487 | 488 | 489 | 490 | 491 | 492 | 493 | 494 | 495 | 496 | 497 | 498 | 499 | 500 | 501 | 502 | 503 | 504 | 505 | 506 | 507 | 508 | 509 | 510 | -------------------------------------------------------------------------------- /models/vae_2.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | #separate class for decoder 4 | 5 | 6 | 7 | 8 | 9 | 10 | # adding layer norm 11 | 12 | import numpy as np 13 | import pickle 14 | # import cPickle as pickle 15 | 
from os.path import expanduser 16 | home = expanduser("~") 17 | import time 18 | import sys 19 | sys.path.insert(0, 'utils') 20 | 21 | import torch 22 | from torch.autograd import Variable 23 | import torch.utils.data 24 | import torch.optim as optim 25 | import torch.nn as nn 26 | import torch.nn.functional as F 27 | 28 | from utils import lognormal2 as lognormal 29 | from utils import log_bernoulli 30 | 31 | import os 32 | 33 | 34 | # from utils import LayerNorm 35 | 36 | 37 | # from ais import test_ais 38 | 39 | # from approx_posteriors_v5 import standard 40 | # from approx_posteriors_v5 import flow1 41 | # from approx_posteriors_v5 import aux_nf 42 | # from approx_posteriors_v5 import hnf 43 | 44 | # from approx_posteriors_v6 import standard 45 | 46 | from generator import Generator 47 | 48 | 49 | 50 | 51 | 52 | 53 | class VAE(nn.Module): 54 | def __init__(self, hyper_config, seed=1): 55 | super(VAE, self).__init__() 56 | 57 | torch.manual_seed(seed) 58 | 59 | 60 | self.z_size = hyper_config['z_size'] 61 | self.x_size = hyper_config['x_size'] 62 | self.act_func = hyper_config['act_func'] 63 | 64 | 65 | self.q_dist = hyper_config['q_dist'](hyper_config=hyper_config) 66 | # self.q_dist = hyper_config['q_dist'](self, hyper_config=hyper_config) 67 | # print (self.q_dist.parameters()) 68 | 69 | 70 | self.generator = Generator(hyper_config=hyper_config) 71 | # print (self.generator.parameters()) 72 | # fasd 73 | 74 | 75 | # print ('Encoder') 76 | # for aaa in self.q_dist.parameters(): 77 | # # print (aaa) 78 | # print (aaa.size()) 79 | # print ('Decoder') 80 | # for aaa in self.generator.parameters(): 81 | # # print (aaa) 82 | # print (aaa.size()) 83 | # # fasdfs 84 | 85 | # if hyper_config[''] 86 | # os.environ['CUDA_VISIBLE_DEVICES'] = hyper_config['cuda'] 87 | 88 | 89 | if torch.cuda.is_available(): 90 | self.dtype = torch.cuda.FloatTensor 91 | self.q_dist.cuda() 92 | else: 93 | self.dtype = torch.FloatTensor 94 | 95 | 96 | # #Decoder 97 | # self.decoder_weights = [] 98 | # self.layer_norms = [] 99 | # for i in range(len(hyper_config['decoder_arch'])): 100 | # self.decoder_weights.append(nn.Linear(hyper_config['decoder_arch'][i][0], hyper_config['decoder_arch'][i][1])) 101 | 102 | # # if i != len(hyper_config['decoder_arch'])-1: 103 | # # self.layer_norms.append(LayerNorm(hyper_config['decoder_arch'][i][1])) 104 | 105 | # count =1 106 | # for i in range(len(self.decoder_weights)): 107 | # self.add_module(str(count), self.decoder_weights[i]) 108 | # count+=1 109 | 110 | # if i != len(hyper_config['decoder_arch'])-1: 111 | # self.add_module(str(count), self.layer_norms[i]) 112 | # count+=1 113 | 114 | # self.hyper_config = hyper_config 115 | 116 | # # See params 117 | # print('all') 118 | # for aaa in self.parameters(): 119 | # # print (aaa) 120 | # print (aaa.size()) 121 | # fsadfsa 122 | 123 | 124 | # def decode(self, z): 125 | # k = z.size()[0] 126 | # B = z.size()[1] 127 | # z = z.view(-1, self.z_size) 128 | 129 | # out = z 130 | # for i in range(len(self.decoder_weights)-1): 131 | # out = self.act_func(self.decoder_weights[i](out)) 132 | # # out = self.act_func(self.layer_norms[i].forward(self.decoder_weights[i](out))) 133 | # out = self.decoder_weights[-1](out) 134 | 135 | # x = out.view(k, B, self.x_size) 136 | # return x 137 | 138 | 139 | def forward(self, x, k, warmup=1.): 140 | 141 | self.B = x.size()[0] #batch size 142 | self.zeros = Variable(torch.zeros(self.B, self.z_size).type(self.dtype)) 143 | 144 | self.logposterior = lambda aa: lognormal(aa, self.zeros, self.zeros) + 
log_bernoulli(self.generator.decode(aa), x) 145 | 146 | z, logqz = self.q_dist.forward(k, x, self.logposterior) 147 | 148 | logpxz = self.logposterior(z) 149 | 150 | #Compute elbo 151 | elbo = logpxz - (warmup*logqz) #[P,B] 152 | if k>1: 153 | max_ = torch.max(elbo, 0)[0] #[B] 154 | elbo = torch.log(torch.mean(torch.exp(elbo - max_), 0)) + max_ #[B] 155 | 156 | elbo = torch.mean(elbo) #[1] 157 | logpxz = torch.mean(logpxz) #[1] 158 | logqz = torch.mean(logqz) 159 | 160 | return elbo, logpxz, logqz 161 | 162 | 163 | def sample_q(self, x, k): 164 | 165 | self.B = x.size()[0] #batch size 166 | self.zeros = Variable(torch.zeros(self.B, self.z_size).type(self.dtype)) 167 | 168 | self.logposterior = lambda aa: lognormal(aa, self.zeros, self.zeros) + log_bernoulli(self.generator.decode(aa), x) 169 | 170 | z, logqz = self.q_dist.forward(k=k, x=x, logposterior=self.logposterior) 171 | 172 | return z 173 | 174 | 175 | def logposterior_func(self, x, z): 176 | self.B = x.size()[0] #batch size 177 | self.zeros = Variable(torch.zeros(self.B, self.z_size).type(self.dtype)) 178 | 179 | # print (x) #[B,X] 180 | # print(z) #[P,Z] 181 | z = Variable(z).type(self.dtype) 182 | z = z.view(-1,self.B,self.z_size) 183 | return lognormal(z, self.zeros, self.zeros) + log_bernoulli(self.generator.decode(z), x) 184 | 185 | 186 | 187 | def logposterior_func2(self, x, z): 188 | self.B = x.size()[0] #batch size 189 | self.zeros = Variable(torch.zeros(self.B, self.z_size).type(self.dtype)) 190 | 191 | # print (x) #[B,X] 192 | # print(z) #[P,Z] 193 | # z = Variable(z).type(self.dtype) 194 | z = z.view(-1,self.B,self.z_size) 195 | 196 | # print (z) 197 | return lognormal(z, self.zeros, self.zeros) + log_bernoulli(self.generator.decode(z), x) 198 | 199 | 200 | 201 | def forward2(self, x, k): 202 | 203 | self.B = x.size()[0] #batch size 204 | self.zeros = Variable(torch.zeros(self.B, self.z_size).type(self.dtype)) 205 | 206 | self.logposterior = lambda aa: lognormal(aa, self.zeros, self.zeros) + log_bernoulli(self.generator.decode(aa), x) 207 | 208 | z, logqz = self.q_dist.forward(k, x, self.logposterior) 209 | 210 | logpxz = self.logposterior(z) 211 | 212 | #Compute elbo 213 | elbo = logpxz - logqz #[P,B] 214 | # if k>1: 215 | # max_ = torch.max(elbo, 0)[0] #[B] 216 | # elbo = torch.log(torch.mean(torch.exp(elbo - max_), 0)) + max_ #[B] 217 | 218 | elbo = torch.mean(elbo) #[1] 219 | logpxz = torch.mean(logpxz) #[1] 220 | logqz = torch.mean(logqz) 221 | 222 | return elbo, logpxz, logqz 223 | 224 | 225 | 226 | 227 | def forward3_prior(self, x, k): 228 | 229 | self.B = x.size()[0] #batch size 230 | self.zeros = Variable(torch.zeros(self.B, self.z_size).type(self.dtype)) 231 | 232 | self.logposterior = lambda aa: log_bernoulli(self.generator.decode(aa), x) #+ lognormal(aa, self.zeros, self.zeros) 233 | 234 | # z, logqz = self.q_dist.forward(k, x, self.logposterior) 235 | 236 | z = Variable(torch.FloatTensor(k, self.B, self.z_size).normal_().type(self.dtype)) #[P,B,Z] 237 | 238 | logpxz = self.logposterior(z) 239 | 240 | #Compute elbo 241 | elbo = logpxz #- logqz #[P,B] 242 | if k>1: 243 | max_ = torch.max(elbo, 0)[0] #[B] 244 | elbo = torch.log(torch.mean(torch.exp(elbo - max_), 0)) + max_ #[B] 245 | 246 | elbo = torch.mean(elbo) #[1] 247 | # logpxz = torch.mean(logpxz) #[1] 248 | # logqz = torch.mean(logqz) 249 | 250 | return elbo#, logpxz, logqz 251 | 252 | 253 | 254 | 255 | 256 | 257 | 258 | 259 | 260 | 261 | 262 | # def train(self, train_x, k, epochs, batch_size, display_epoch, learning_rate): 263 | 264 | # optimizer = 
optim.Adam(self.parameters(), lr=learning_rate) 265 | # time_ = time.time() 266 | # n_data = len(train_x) 267 | # arr = np.array(range(n_data)) 268 | 269 | # for epoch in range(1, epochs + 1): 270 | 271 | # #shuffle 272 | # np.random.shuffle(arr) 273 | # train_x = train_x[arr] 274 | 275 | # data_index= 0 276 | # for i in range(int(n_data/batch_size)): 277 | # batch = train_x[data_index:data_index+batch_size] 278 | # data_index += batch_size 279 | 280 | # batch = Variable(torch.from_numpy(batch)).type(self.dtype) 281 | # optimizer.zero_grad() 282 | 283 | # elbo, logpxz, logqz = self.forward(batch, k=k) 284 | 285 | # loss = -(elbo) 286 | # loss.backward() 287 | # optimizer.step() 288 | 289 | 290 | # if epoch%display_epoch==0: 291 | # print ('Train Epoch: {}/{}'.format(epoch, epochs), 292 | # 'LL:{:.3f}'.format(-loss.data[0]), 293 | # 'logpxz:{:.3f}'.format(logpxz.data[0]), 294 | # # 'logpz:{:.3f}'.format(logpz.data[0]), 295 | # 'logqz:{:.3f}'.format(logqz.data[0]), 296 | # 'T:{:.2f}'.format(time.time()-time_), 297 | # ) 298 | 299 | # time_ = time.time() 300 | 301 | 302 | 303 | 304 | 305 | # def test(self, data_x, batch_size, display, k): 306 | 307 | # time_ = time.time() 308 | # elbos = [] 309 | # data_index= 0 310 | # for i in range(int(len(data_x)/ batch_size)): 311 | 312 | # batch = data_x[data_index:data_index+batch_size] 313 | # data_index += batch_size 314 | 315 | # batch = Variable(torch.from_numpy(batch)).type(self.dtype) 316 | 317 | # elbo, logpxz, logqz = self(batch, k=k) 318 | 319 | # elbos.append(elbo.data[0]) 320 | 321 | # if i%display==0: 322 | # print (i,len(data_x)/ batch_size, np.mean(elbos)) 323 | 324 | # mean_ = np.mean(elbos) 325 | # print(mean_, 'T:', time.time()-time_) 326 | 327 | 328 | 329 | 330 | 331 | # def load_params(self, path_to_load_variables=''): 332 | # # model.load_state_dict(torch.load(path_to_load_variables)) 333 | # self.load_state_dict(torch.load(path_to_load_variables, map_location=lambda storage, loc: storage)) 334 | # print ('loaded variables ' + path_to_load_variables) 335 | 336 | 337 | # def save_params(self, path_to_save_variables=''): 338 | # torch.save(self.state_dict(), path_to_save_variables) 339 | # print ('saved variables ' + path_to_save_variables) 340 | 341 | 342 | 343 | 344 | 345 | 346 | 347 | 348 | 349 | 350 | 351 | 352 | 353 | 354 | # if __name__ == "__main__": 355 | 356 | # load_params = 0 357 | # train_ = 1 358 | # eval_IW = 1 359 | # eval_AIS = 0 360 | 361 | # print ('Loading data') 362 | # with open(home+'/Documents/MNIST_data/mnist.pkl','rb') as f: 363 | # mnist_data = pickle.load(f, encoding='latin1') 364 | 365 | # train_x = mnist_data[0][0] 366 | # valid_x = mnist_data[1][0] 367 | # test_x = mnist_data[2][0] 368 | 369 | # train_x = np.concatenate([train_x, valid_x], axis=0) 370 | 371 | # print (train_x.shape) 372 | 373 | # x_size = 784 374 | # z_size = 50 375 | 376 | # hyper_config = { 377 | # 'x_size': x_size, 378 | # 'z_size': z_size, 379 | # 'act_func': F.tanh,# F.relu, 380 | # 'encoder_arch': [[x_size,200],[200,200],[200,z_size*2]], 381 | # 'decoder_arch': [[z_size,200],[200,200],[200,x_size]], 382 | # 'q_dist': hnf,#aux_nf,#flow1,#standard,#, #, #, #,#, #,# , 383 | # 'n_flows': 2, 384 | # 'qv_arch': [[x_size,200],[200,200],[200,z_size*2]], 385 | # 'qz_arch': [[x_size+z_size,200],[200,200],[200,z_size*2]], 386 | # 'rv_arch': [[x_size+z_size,200],[200,200],[200,z_size*2]], 387 | # 'flow_hidden_size': 100 388 | # } 389 | 390 | 391 | # model = VAE(hyper_config) 392 | 393 | # if torch.cuda.is_available(): 394 | # model.cuda() 395 | 
396 | 397 | 398 | # #Train params 399 | # learning_rate = .0001 400 | # batch_size = 100 401 | # epochs = 3000 402 | # display_epoch = 2 403 | # k = 1 404 | 405 | # path_to_load_variables='' 406 | # # path_to_load_variables=home+'/Documents/tmp/pytorch_bvae.pt' 407 | # path_to_save_variables=home+'/Documents/tmp/pytorch_vae'+str(epochs)+'.pt' 408 | # # path_to_save_variables='' 409 | 410 | 411 | 412 | # if load_params: 413 | # print ('\nLoading parameters') 414 | # model.load_params(path_to_save_variables) 415 | 416 | # if train_: 417 | 418 | # print('\nTraining') 419 | # print('k='+str(k), 'lr='+str(learning_rate), 'batch_size='+str(batch_size)) 420 | # print('\nModel:', hyper_config,'\n') 421 | # model.train(train_x=train_x, k=k, epochs=epochs, batch_size=batch_size, 422 | # display_epoch=display_epoch, learning_rate=learning_rate) 423 | # model.save_params(path_to_save_variables) 424 | 425 | 426 | # if eval_IW: 427 | # k_IW = 2000 428 | # batch_size = 20 429 | # print('\nTesting with IW, Train set[:10000], B'+str(batch_size)+' k'+str(k_IW)) 430 | # model.test(data_x=train_x[:10000], batch_size=batch_size, display=100, k=k_IW) 431 | 432 | # print('\nTesting with IW, Test set, B'+str(batch_size)+' k'+str(k_IW)) 433 | # model.test(data_x=test_x, batch_size=batch_size, display=100, k=k_IW) 434 | 435 | # if eval_AIS: 436 | # k_AIS = 10 437 | # batch_size = 100 438 | # n_intermediate_dists = 100 439 | # print('\nTesting with AIS, Train set[:10000], B'+str(batch_size)+' k'+str(k_AIS)+' intermediates'+str(n_intermediate_dists)) 440 | # test_ais(model, data_x=train_x[:10000], batch_size=batch_size, display=10, k=k_AIS, n_intermediate_dists=n_intermediate_dists) 441 | 442 | # print('\nTesting with AIS, Test set, B'+str(batch_size)+' k'+str(k_AIS)+' intermediates'+str(n_intermediate_dists)) 443 | # test_ais(model, data_x=test_x, batch_size=batch_size, display=10, k=k_AIS, n_intermediate_dists=n_intermediate_dists) 444 | 445 | 446 | 447 | 448 | 449 | 450 | 451 | 452 | 453 | 454 | 455 | 456 | 457 | 458 | 459 | 460 | 461 | 462 | 463 | 464 | --------------------------------------------------------------------------------
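For reference, the objective computed by `VAE.forward` in `models/vae_2.py` above can be read as follows. With `logpxz = log p(x|z) + log p(z)` (Bernoulli likelihood from the generator plus a standard-normal prior) and `logqz = log q(z|x)`, the single-sample objective with warmup weight β is

$$ \mathcal{L}_\beta(x) \;=\; \mathbb{E}_{q(z|x)}\big[\log p(x|z) + \log p(z) - \beta\,\log q(z|x)\big], \qquad \beta \in [0,1], $$

which is the standard ELBO at β = 1. For k > 1 samples (with β = 1), the per-sample terms are combined by a log-mean-exp, evaluated with the max-subtraction trick for numerical stability, giving the importance-weighted (IWAE) bound

$$ \mathcal{L}_k(x) \;=\; \mathbb{E}_{z_1,\dots,z_k \sim q(z|x)}\Big[\log \frac{1}{k}\sum_{i=1}^{k} \frac{p(x, z_i)}{q(z_i \mid x)}\Big]. $$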