├── Figure ├── 3raxbh.gif ├── 3risln.gif ├── Figure_1.png ├── Figure_2.png ├── Figure_3.png └── Figure_4.png ├── README.md ├── Utils.py ├── main.py └── model ├── ConvIWAE.py ├── ExplicitIWAE.py └── PytorchIWAE.py /Figure/3raxbh.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JohanYe/IWAE-pytorch/059527de3da8920b0b7688dea516ce2adf76202b/Figure/3raxbh.gif -------------------------------------------------------------------------------- /Figure/3risln.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JohanYe/IWAE-pytorch/059527de3da8920b0b7688dea516ce2adf76202b/Figure/3risln.gif -------------------------------------------------------------------------------- /Figure/Figure_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JohanYe/IWAE-pytorch/059527de3da8920b0b7688dea516ce2adf76202b/Figure/Figure_1.png -------------------------------------------------------------------------------- /Figure/Figure_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JohanYe/IWAE-pytorch/059527de3da8920b0b7688dea516ce2adf76202b/Figure/Figure_2.png -------------------------------------------------------------------------------- /Figure/Figure_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JohanYe/IWAE-pytorch/059527de3da8920b0b7688dea516ce2adf76202b/Figure/Figure_3.png -------------------------------------------------------------------------------- /Figure/Figure_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JohanYe/IWAE-pytorch/059527de3da8920b0b7688dea516ce2adf76202b/Figure/Figure_4.png 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Importance Weighted Autoencoders (IWAE) 2 | Link to paper: https://arxiv.org/abs/1509.00519 3 | 4 | AnalyticalIWAE (model/ExplicitIWAE.py): IWAE that calculates the loss manually 5 | 6 | PytorchIWAE: IWAE that uses built-in torch functions to evaluate log-densities and calculate the loss.
# ------------------------------------------------------------------------------
# /README.md (continued)
# ------------------------------------------------------------------------------
# Includes example of algorithm very easy to apply to existing VAE (although a bit slower)
#
# ConvIWAE: An example of convolutional IWAE, not integrated with main script, only as example
#
#
# # Importance Weighted Autoencoders - Gaussian encoder and decoder
# #### Pytorch IWAE Loss Curve:
# ![MNIST](Figure/Figure_3.png)
#
# #### Pytorch IWAE 60 epoch results:
# ![MNIST sampled sampels](Figure/Figure_4.png)
#
# #### Training gif
# ![Giffygifgif1](Figure/3risln.gif)
#
# ## Importance Weighted Autoencoders - Gaussian encoder, Bernoulli decoder
# #### Analytical IWAE Loss Curve:
# ![MNIST sampled sampels](Figure/Figure_1.png)
#
# #### Analytical IWAE 60 epoch results:
# ![MNIST sampled sampels](Figure/Figure_2.png)
#
# #### Training gif
# ![Giffygifgif2](Figure/3raxbh.gif)
#
# ------------------------------------------------------------------------------
# /Utils.py
# ------------------------------------------------------------------------------
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn


def Plot_loss_curve(train_list, test_dict, save_path=None):
    """Plot the per-step training loss and the test-loss checkpoints on one figure.

    Args:
        train_list: Training loss per optimisation step; plotted against its index.
        test_dict: Mapping ``step -> test loss``; keys are the x positions.
        save_path: Optional output path. When given, the figure is saved *before*
            ``plt.show()``. With an interactive backend the figure is destroyed
            once its window is closed, so saving afterwards can write an empty
            image.

    Returns:
        The matplotlib figure that was drawn on (figure number 2).
    """
    fig = plt.figure(2)
    plt.xlabel('Num Steps')
    plt.ylabel('ELBO')
    plt.title('ELBO Loss Curve')
    plt.plot(np.arange(len(train_list)), train_list, label='train')
    plt.plot(list(test_dict.keys()), list(test_dict.values()), label='tst')
    plt.legend(loc='best')
    plt.locator_params(axis='x', nbins=10)

    if save_path is not None:
        fig.savefig(save_path, bbox_inches='tight')
    plt.show()
    return fig


def create_canvas(x, rows=10, columns=10, img_size=28):
    """Tile ``rows * columns`` flattened square images into a single 2-D array.

    Images are laid out column-major: image ``i + rows * j`` goes to grid cell
    (row ``i``, column ``j``). No pyplot state is touched. The original called
    ``plt.figure(1)`` here, which silently changed the current figure under the
    caller's ``plt.savefig``.

    Args:
        x: Indexable holding at least ``rows * columns`` items, each reshapeable to
            ``(img_size, img_size)``, e.g. numpy arrays or CPU torch tensors.
        rows: Number of grid rows (default 10).
        columns: Number of grid columns (default 10).
        img_size: Side length of each square image in pixels (default 28).

    Returns:
        ``np.ndarray`` of shape ``(rows * img_size, columns * img_size)``.
    """
    canvas = np.zeros((rows * img_size, columns * img_size))
    for i in range(rows):
        for j in range(columns):
            # `i + rows * j` equals the original `i % columns + rows * j` for the
            # default square grid, and stays a bijection when rows != columns.
            tile = x[i + rows * j].reshape((img_size, img_size))
            canvas[i * img_size:(i + 1) * img_size, j * img_size:(j + 1) * img_size] = tile
    return canvas
# ------------------------------------------------------------------------------
# /main.py
# ------------------------------------------------------------------------------
def main():
    """Train an IWAE on MNIST, log train/test losses, and write figures.

    Exactly one of ``Explicit`` (AnalyticalIWAE) or ``Implicit`` (PytorchIWAE)
    must be enabled. Each batch is expanded to ``num_samples`` importance
    copies on dim 0. ``beta`` is linearly warmed up by 0.001 per step until it
    reaches 2. After every epoch the model is scored on the MNIST test split.
    When ``gif_pics`` is set, every other epoch a reconstruction/sample grid is
    saved under ./Figure/GIF/.
    """
    # Explicit imports: `from x import *` is illegal inside a function, and the
    # star imports also hid that `sys` was never imported anywhere.
    import os
    import sys

    import matplotlib.pyplot as plt
    import numpy as np
    import seaborn as sns
    import torch
    import torch.optim as optim
    import torchvision

    from model.ExplicitIWAE import AnalyticalIWAE
    from model.PytorchIWAE import PytorchIWAE
    from Utils import Plot_loss_curve, create_canvas

    sns.set_style("darkgrid")

    # Hyperparameters
    gif_pics = True
    batch_size = 250
    lr = 1e-4
    num_epochs = 65
    train_log = []
    test_log = {}
    k = 0  # global optimisation-step counter
    num_samples = 5  # importance copies per image
    beta = 0  # warm-up weight on log p(z) - log q(z|x)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Model
    Explicit = True
    Implicit = False

    # Previously only ">1" was rejected; with zero models `net` was never bound.
    if (Explicit + Implicit) != 1:
        print('Enable exactly one model (Explicit or Implicit)')
        sys.exit(1)

    if Explicit:
        net = AnalyticalIWAE(1024, 512, 32).to(device)
    else:
        net = PytorchIWAE(1024, 512, 32).to(device)
    optimizer = optim.Adam(net.parameters(), lr=lr)

    # Data loading
    t = torchvision.transforms.ToTensor()
    train_data = torchvision.datasets.MNIST('./', train=True, transform=t, target_transform=None, download=True)
    test_data = torchvision.datasets.MNIST('./', train=False, transform=t, target_transform=None, download=True)
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False)

    if gif_pics:
        os.makedirs('./Figure/GIF', exist_ok=True)

    for epoch in range(num_epochs):
        for batch, _ in train_loader:
            batch = batch.view(batch.size(0), -1)  # flatten
            batch = batch.expand(num_samples, batch.shape[0], -1).to(device)  # make num_samples copies

            batch_loss = net.calc_loss(batch, beta)
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            train_log.append(batch_loss.item())
            if beta < 2:
                beta += 0.001  # Warm-up
            k += 1

        # Score the *test* split. The original looped over test_loader but read
        # `train_iter`, re-scoring the last training batch every time.
        loss_batch_mean = []
        with torch.no_grad():
            for batch, _ in test_loader:
                batch = batch.view(batch.size(0), -1)  # flatten
                batch = batch.expand(num_samples, batch.shape[0], -1).to(device)  # make num_samples copies
                test_loss = net.calc_loss(batch, beta)
                loss_batch_mean.append(test_loss.item())
        test_log[k] = np.mean(loss_batch_mean)

        if gif_pics and epoch % 2 == 0:
            vis = batch[0, :100, :]  # first importance copy of the last test batch
            with torch.no_grad():
                if Explicit:
                    recon_x = net(vis)
                else:
                    recon_x = net(vis.unsqueeze(0))[0].squeeze(0)  # get mu only
                samples = net.sample(100).cpu()

            fig, axs = plt.subplots(1, 2, figsize=(5, 10))

            # Reconstructions
            axs[0].set_title('Epoch {} Reconstructions'.format(epoch + 1))
            axs[0].axis('off')
            axs[0].imshow(create_canvas(recon_x.cpu()), cmap='gray')

            # Samples
            axs[1].set_title('Epoch {} Sampled Samples'.format(epoch + 1))
            axs[1].axis('off')
            axs[1].imshow(create_canvas(samples), cmap='gray')

            # Save/close this figure explicitly rather than whatever is "current".
            save_path = './Figure/GIF/gif_pic' + str(epoch + 1) + '.jpg'
            fig.savefig(save_path, bbox_inches='tight')
            plt.close(fig)

        print('[Epoch: {}/{}][Step: {}]\tTrain Loss: {},\tTest Loss: {}'.format(
            epoch + 1, num_epochs, k, round(train_log[k - 1], 2), round(test_log[k], 2)))

    ###### Loss Curve Plotting ######
    # NOTE(review): Plot_loss_curve calls plt.show() before this savefig; with an
    # interactive backend the saved image can come out empty.
    Plot_loss_curve(train_log, test_log)
    plt.savefig('./Figure/Figure_1.png', bbox_inches='tight')
    plt.close()

    ###### Sampling #########
    x = next(iter(train_loader))[0].to(device)
    x = x.view(x.size(0), -1)[:100]  # flatten and limit to 100
    with torch.no_grad():
        recon_x = net(x) if Explicit else net(x)[0]
        samples = net.sample(100).cpu()

    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    axs[0].set_title('Ground Truth MNIST Digits')
    axs[0].axis('off')
    axs[0].imshow(create_canvas(x.cpu()), cmap='gray')

    axs[1].set_title('Reconstructed MNIST Digits')
    axs[1].axis('off')
    axs[1].imshow(create_canvas(recon_x.cpu()), cmap='gray')

    axs[2].set_title('Sampled MNIST Digits')
    axs[2].axis('off')
    axs[2].imshow(create_canvas(samples), cmap='gray')
    fig.savefig('./Figure/Figure_2.png', bbox_inches='tight')
    plt.close(fig)


if __name__ == "__main__":
    main()


# ------------------------------------------------------------------------------
# /model/ConvIWAE.py
# ------------------------------------------------------------------------------
import torch.nn as nn
import torch
import torch.nn.functional as F
import torch.distributions as td

# NOT INTEGRATED WITH MAIN.PY, ONLY FOR SHOW


class Flatten(nn.Module):
    """Reshape (batch, 50, 7, 7) feature maps to (batch, 50 * 7 * 7)."""

    def forward(self, input):
        return input.view(input.size(0), 50 * 7 * 7)


class UnFlatten(nn.Module):
    """Reshape (batch, 50 * 7 * 7) vectors back to (batch, 50, 7, 7)."""

    def forward(self, input):
        return input.view(input.size(0), 50, 7, 7)


class ConvIWAE(nn.Module):
    """Convolutional IWAE for 1 x 28 x 28 images with Gaussian encoder/decoder."""

    def __init__(self, z_dim=20, bs=None):
        # `bs` used to follow a defaulted parameter without a default (a
        # SyntaxError). It is kept, now optional, for positional callers, but
        # reparameterize() no longer needs it.
        super(ConvIWAE, self).__init__()
        self.z_dim = z_dim
        self.analytic_kl = True
        self.batch_size = bs
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Encoder: 28x28 -> 14x14 -> 7x7 feature maps, then a 700-unit layer.
        self.block1 = nn.Sequential(
            nn.Conv2d(1, 240, kernel_size=3, stride=1, padding=1), nn.LeakyReLU(0.2),
            nn.Conv2d(240, 160, kernel_size=3, stride=2, padding=1), nn.LeakyReLU(0.2),
            nn.BatchNorm2d(160),
            nn.Conv2d(160, 80, kernel_size=3, stride=1, padding=1), nn.ReLU(),
            nn.BatchNorm2d(80),
            nn.Conv2d(80, 50, kernel_size=3, stride=2, padding=1), nn.ReLU(),
            Flatten(),
            nn.Linear(50 * 7 * 7, 700),
            nn.ReLU())

        self.mu = nn.Linear(700, z_dim)
        self.std = nn.Linear(700, z_dim)

        # Decoder: 7x7 -> 14x14 -> 28x28 with 2 output channels (mean, std).
        self.dec = nn.Sequential(
            nn.Linear(z_dim, 700), nn.ReLU(),
            nn.Linear(700, 50 * 7 * 7), nn.ReLU(),
            UnFlatten(),
            nn.ConvTranspose2d(50, 220, kernel_size=3, stride=2, padding=1, output_padding=1), nn.ReLU(),
            nn.BatchNorm2d(220),
            nn.ConvTranspose2d(220, 160, kernel_size=3, stride=1, padding=1), nn.ReLU(),
            nn.BatchNorm2d(160),
            nn.ConvTranspose2d(160, 60, kernel_size=3, stride=2, padding=1, output_padding=1), nn.ReLU(),
            nn.BatchNorm2d(60),
            nn.ConvTranspose2d(60, 2, kernel_size=3, stride=1, padding=1),
            # NOTE: squashes both channels, including the one decoder() turns
            # into a std via softplus ("not sure why we sigmoid variance").
            nn.Sigmoid())

    def encoder(self, x):
        """Return the posterior (mu, std), each (batch, z_dim); std > 0 via softplus."""
        x = self.block1(x)
        mu = self.mu(x)
        std = F.softplus(self.std(x))
        return mu, std

    def reparameterize(self, mu, std, K, training):
        """Return latents shaped (K, batch, z_dim), or (1, batch, z_dim) = mu when not training."""
        if training:
            qz_Gx_obs = td.Normal(loc=mu, scale=std)
            return qz_Gx_obs.rsample(torch.Size([K]))  # if we do [m,k] we get (m,k,batch,zdim)
        # Deterministic path: a single "sample" at the posterior mean, kept in
        # the (K, batch, z_dim) layout that decoder() unpacks.
        return mu.unsqueeze(0)

    def decoder(self, z):
        """Decode z (K, bs, z_dim) to per-pixel (mean, std), each (K, bs, 1, 28, 28)."""
        K, bs = z.size(0), z.size(1)
        z = z.view([K * bs, -1])
        x = self.dec(z)
        x = x.view([K, bs, 2, 28, 28])
        x_mean = x[:, :, :1, :, :]  # (K, bs, param, dim1, dim2)
        x_std = F.softplus(x[:, :, 1:, :, :])  # (K, bs, param, dim1, dim2)
        return x_mean, x_std

    def elbo(self, mu_dec, std_dec, mu_enc, std_enc, x, z, K, beta):
        """Return -mean_batch logsumexp_K [log p(x|z) - beta * KL] (the constant -log K is omitted).

        NOTE(review): with ``analytic_kl`` the KL term is shared by all K samples
        of a data point, so this is not the exact IWAE importance weight.
        """
        qz_Gx_obs = td.Normal(loc=mu_enc, scale=std_enc)  # z_dist
        # Standard-normal prior on mu_enc's device/dtype. The cached self.device
        # can disagree with where the model actually lives.
        p_z = td.Normal(torch.zeros_like(mu_enc), torch.ones_like(mu_enc))

        if self.analytic_kl:
            kl = td.kl_divergence(qz_Gx_obs, p_z).sum(-1)
        else:
            lpz = p_z.log_prob(z).sum(-1)
            lqzx = qz_Gx_obs.log_prob(z).sum(-1)
            kl = lqzx - lpz

        xgz = td.Normal(loc=mu_dec, scale=std_dec)  # x_dist
        lpxgz = xgz.log_prob(x).sum([-1, -2, -3])  # (K,bs)
        elbo = lpxgz - beta * kl

        loss = -torch.mean(torch.logsumexp(elbo, 0))
        return loss

    def forward(self, x, K, beta, training):
        """Return (loss, mu_enc, std_enc, mu_dec, std_dec) for images x of shape (bs, 1, 28, 28)."""
        mu_enc, std_enc = self.encoder(x)
        z = self.reparameterize(mu_enc, std_enc, K, training)
        mu_dec, std_dec = self.decoder(z)  # (K, bs, 1, dim1, dim2)
        loss = self.elbo(mu_dec, std_dec, mu_enc, std_enc, x, z, K, beta)

        return loss, mu_enc, std_enc, mu_dec, std_dec


# ------------------------------------------------------------------------------
# /model/ExplicitIWAE.py
# ------------------------------------------------------------------------------
import torch
import torch.nn as nn
import numpy as np


class AnalyticalIWAE(nn.Module):
    """IWAE with Gaussian encoder and Bernoulli decoder; every log-density is written out by hand."""

    def __init__(self, num_hidden1, num_hidden2, latent_space):
        super(AnalyticalIWAE, self).__init__()
        self.fc1 = nn.Sequential(
            nn.Linear(in_features=784, out_features=num_hidden1),
            nn.ReLU(),
            nn.Linear(num_hidden1, num_hidden2),
            nn.ReLU(),
        )

        self.fc21 = nn.Linear(in_features=num_hidden2, out_features=latent_space)
        self.fc22 = nn.Linear(in_features=num_hidden2, out_features=latent_space)

        self.fc3 = nn.Sequential(
            nn.Linear(in_features=latent_space, out_features=num_hidden2),
            nn.ReLU(),
            nn.Linear(num_hidden2, num_hidden1),
            nn.ReLU(),
        )
        self.decode = nn.Linear(in_features=num_hidden1, out_features=784)  # outputs logits
        self.latent = latent_space
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)

    def encode(self, x):
        """Return (mu, log_var, h, eps), where h = mu + exp(log_var / 2) * eps and eps ~ N(0, I)."""
        x = self.fc1(x)
        mu = self.fc21(x)
        log_var = self.fc22(x)

        # Reparameterize
        eps = torch.randn_like(log_var)
        h = mu + torch.exp(log_var * 0.5) * eps
        return mu, log_var, h, eps

    def forward(self, x):
        """
        Purely to see reconstruction, not for calculating loss.

        Returns Bernoulli means (pixel probabilities) for one stochastic sample of h.
        """
        _, _, h, _ = self.encode(x)
        return torch.sigmoid(self.decode(self.fc3(h)))

    def calc_loss(self, x, beta):
        """Return the IWAE loss for x of shape (num_samples, batch, 784).

        Importance copies live on dim 0. Normalised weights (eq. 13 of the IWAE
        paper) are computed from log p(x, h) - log q(h|x) and treated as
        constants. The loss is ``-mean_batch sum_k w_k * (log p(x|h_k) +
        beta * (log p(h_k) - log q(h_k|x)))``.
        """
        mu, log_var, h, eps = self.encode(x)
        log_2pi = torch.log(2 * h.new_tensor(np.pi))

        # log p(h) under the standard-normal prior.
        log_Ph = torch.sum(-0.5 * h ** 2 - 0.5 * log_2pi, -1)

        # Bernoulli decoder log p(x|h), computed from logits (Appendix C.1, Kingma).
        # log(sigmoid(l)) = logsigmoid(l) and log(1 - sigmoid(l)) = logsigmoid(-l)
        # stay finite where sigmoid rounds to exactly 0 or 1 in float32; the
        # previous torch.log(...) produced -inf / NaN there.
        logits = self.decode(self.fc3(h))
        log_PxGh = torch.sum(x * nn.functional.logsigmoid(logits)
                             + (1 - x) * nn.functional.logsigmoid(-logits), -1)
        log_Pxh = log_Ph + log_PxGh  # log(p(x,h))

        # log q(h|x) evaluated at h = mu + sigma * eps.
        log_QhGx = torch.sum(-0.5 * eps ** 2 - 0.5 * log_2pi - 0.5 * log_var, -1)

        # Normalised importance weights over the sample dim (stable softmax),
        # without gradient.
        weight = torch.softmax((log_Pxh - log_QhGx).detach(), dim=0)

        loss = torch.mean(-torch.sum(weight * (log_PxGh + (log_Ph - log_QhGx) * beta), 0))
        return loss

    def sample(self, n_samples):
        """Decode n_samples prior draws to Bernoulli means of shape (n_samples, 784)."""
        # Use the device the parameters are on; the cached self.device can differ.
        eps = torch.randn((n_samples, self.latent), device=self.decode.weight.device)
        return torch.sigmoid(self.decode(self.fc3(eps)))


# ------------------------------------------------------------------------------
# /model/PytorchIWAE.py
# ------------------------------------------------------------------------------
import torch.nn as nn
import torch
import torch.distributions as td
import torch.nn.functional as F


# define network
class PytorchIWAE(nn.Module):
    """IWAE with Gaussian encoder and decoder, using torch.distributions for log-densities.

    Network uses in-built pytorch functions for variational inference instead of
    calculating the densities explicitly (compare AnalyticalIWAE).
    """

    def __init__(self, num_hidden1, num_hidden2, latent, in_dim=784):
        super(PytorchIWAE, self).__init__()
        self.latent = latent
        self.out_dec = in_dim
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.block1 = nn.Sequential(
            nn.Linear(in_features=in_dim, out_features=num_hidden1),
            nn.ReLU(),
            nn.Linear(num_hidden1, num_hidden2),
            nn.ReLU(),
        )
        self.mu_enc = nn.Linear(in_features=num_hidden2, out_features=self.latent)
        self.lvar_enc = nn.Linear(in_features=num_hidden2, out_features=self.latent)

        self.block2 = nn.Sequential(
            nn.Linear(in_features=self.latent, out_features=num_hidden2),
            nn.ReLU(),
            nn.Linear(num_hidden2, num_hidden1),
            nn.ReLU(),
        )

        self.mu_dec = nn.Linear(in_features=num_hidden1, out_features=self.out_dec)
        self.lvar_dec = nn.Linear(in_features=num_hidden1, out_features=self.out_dec)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)

    def encoder(self, x):
        """Return posterior (mu, log_var), each (..., latent)."""
        x = self.block1(x)
        return self.mu_enc(x), self.lvar_enc(x)

    def decoder(self, z):
        """Return decoder parameters (mu_x, log_var_x), each (..., in_dim).

        Both heads pass through a sigmoid ("stability reasons"). The second
        output is used as a log-variance by calc_loss, so the decoder std lies
        in (1, exp(0.5)). NOTE(review): confirm that std >= 1 is intended.
        """
        h = self.block2(z)
        mu_x = torch.sigmoid(self.mu_dec(h))
        log_var_x = torch.sigmoid(self.lvar_dec(h))  # Stability reasons
        return (mu_x, log_var_x)

    def reparameterize(self, mu, std):
        """Return (z ~ q(z|x) via rsample, the q(z|x) distribution)."""
        qz_Gx_obs = td.Normal(loc=mu, scale=std)
        z_Gx = qz_Gx_obs.rsample()
        return z_Gx, qz_Gx_obs

    def _prior(self, like):
        """Standard-normal prior over the latent space on like's device/dtype."""
        mu_prior = torch.zeros(self.latent, device=like.device, dtype=like.dtype)
        return td.Normal(loc=mu_prior, scale=torch.ones_like(mu_prior))

    def forward(self, x, train=True):
        """Return decoder (mu_x, log_var_x); z is sampled when train else set to the posterior mean."""
        mu, log_var = self.encoder(x)
        if train:
            z, _ = self.reparameterize(mu, log_var.mul(0.5).exp())
        else:
            z = mu
        return self.decoder(z)

    def calc_loss(self, x, beta):
        """Return the IWAE loss for x of shape (num_samples, batch, in_dim).

        The loss is ``-mean_batch logsumexp_k [log p(x|z_k) + beta * (log p(z_k)
        - log q(z_k|x))]``. The constant -log(num_samples) is omitted because it
        does not change gradients.
        """
        mu_enc, log_var_enc = self.encoder(x)
        std_enc = torch.exp(0.5 * log_var_enc)

        z_Gx, qz_Gx_obs = self.reparameterize(mu_enc, std_enc)
        mu_dec, log_var_dec = self.decoder(z_Gx)

        log_QhGx = torch.sum(qz_Gx_obs.log_prob(z_Gx), -1)  # q(z|x)
        log_Ph = torch.sum(self._prior(z_Gx).log_prob(z_Gx), -1)  # p(z)

        std_dec = log_var_dec.mul(0.5).exp()
        log_PxGh = torch.sum(td.Normal(loc=mu_dec, scale=std_dec).log_prob(x), -1)  # p(x|z)

        w = log_PxGh + (log_Ph - log_QhGx) * beta
        return -torch.mean(torch.logsumexp(w, 0))

    def calc_loss_simple(self, x, beta, k_samples):
        """
        Simple to understand and lazy algorithm, but slow.
        Not made compatible with the remaining script.
        Useful as it is likely easy to implement into existing models.

        Runs ``k_samples`` independent encode/sample/decode passes over x (no
        importance copies needed, e.g. x of shape (batch, in_dim)). It uses the
        analytic KL in place of log q(z|x) - log p(z) and returns
        ``-mean(logsumexp_k [log p(x|z_k) - beta * KL_k] - log k_samples)``.
        Previously this crashed: ``log_px_z`` was unbound, ``.log_prob`` was
        called on a tensor, ``int.log()`` was used, and ``beta`` was ignored.

        Raises:
            ValueError: if ``k_samples < 1``.
        """
        import math
        from torch.distributions.kl import kl_divergence as KL

        if k_samples < 1:
            raise ValueError('k_samples must be >= 1')

        log_px_z, kld = [], []
        for _ in range(k_samples):
            mu_enc, log_var_enc = self.encoder(x)
            std_enc = torch.exp(0.5 * log_var_enc)
            z_Gx, qz_Gx_obs = self.reparameterize(mu_enc, std_enc)
            p_z = self._prior(z_Gx)

            mu_dec, log_var_dec = self.decoder(z_Gx)
            px_Gz = td.Normal(loc=mu_dec, scale=log_var_dec.mul(0.5).exp())

            log_px_z.append(px_Gz.log_prob(x).sum(-1))
            kld.append(KL(qz_Gx_obs, p_z).sum(-1))

        log_wk = torch.stack(log_px_z, -1) - beta * torch.stack(kld, -1)
        L_k = log_wk.logsumexp(dim=-1) - math.log(k_samples)  # division by k in logspace
        return -torch.mean(L_k)

    def sample(self, num_samples):
        """Decode num_samples prior draws; return decoder means of shape (num_samples, in_dim)."""
        z_sample = torch.randn(1, num_samples, self.latent, device=self.mu_dec.weight.device)
        return self.decoder(z_sample)[0].view(num_samples, -1)