├── Figure
├── 3raxbh.gif
├── 3risln.gif
├── Figure_1.png
├── Figure_2.png
├── Figure_3.png
└── Figure_4.png
├── README.md
├── Utils.py
├── main.py
└── model
├── ConvIWAE.py
├── ExplicitIWAE.py
└── PytorchIWAE.py
/Figure/3raxbh.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JohanYe/IWAE-pytorch/059527de3da8920b0b7688dea516ce2adf76202b/Figure/3raxbh.gif
--------------------------------------------------------------------------------
/Figure/3risln.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JohanYe/IWAE-pytorch/059527de3da8920b0b7688dea516ce2adf76202b/Figure/3risln.gif
--------------------------------------------------------------------------------
/Figure/Figure_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JohanYe/IWAE-pytorch/059527de3da8920b0b7688dea516ce2adf76202b/Figure/Figure_1.png
--------------------------------------------------------------------------------
/Figure/Figure_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JohanYe/IWAE-pytorch/059527de3da8920b0b7688dea516ce2adf76202b/Figure/Figure_2.png
--------------------------------------------------------------------------------
/Figure/Figure_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JohanYe/IWAE-pytorch/059527de3da8920b0b7688dea516ce2adf76202b/Figure/Figure_3.png
--------------------------------------------------------------------------------
/Figure/Figure_4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JohanYe/IWAE-pytorch/059527de3da8920b0b7688dea516ce2adf76202b/Figure/Figure_4.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Importance Weighted Autoencoders (IWAE)
2 | Link to paper: https://arxiv.org/abs/1509.00519
3 |
4 | AnalyticalIWAE: IWAE calculating loss manually
5 |
6 | PytorchIWAE: IWAE using built-in torch functions to evaluate and calculate the loss.
7 | Includes an example of an algorithm that is very easy to apply to an existing VAE (although a bit slower).
8 |
9 | ConvIWAE: An example of a convolutional IWAE; it is not integrated with the main script and is included only as an example.
10 |
11 |
12 | ## Importance Weighted Autoencoders - Gaussian encoder and decoder
13 | #### Pytorch IWAE Loss Curve:
14 | 
15 |
16 | #### Pytorch IWAE 60 epoch results:
17 | 
18 |
19 | #### Training gif
20 | 
21 |
22 | ## Importance Weighted Autoencoders - Gaussian encoder, Bernoulli decoder
23 | #### Analytical IWAE Loss Curve:
24 | 
25 |
26 | #### Analytical IWAE 60 epoch results:
27 | 
28 |
29 | #### Training gif
30 | 
31 |
32 |
--------------------------------------------------------------------------------
/Utils.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | import torch.nn as nn
4 |
5 |
6 |
def Plot_loss_curve(train_list, test_dict):
    """Plot the training and test loss curves on matplotlib figure 2 and show it.

    Args:
        train_list: sequence of training losses; plotted against 0..len-1.
        test_dict: mapping whose keys are x positions and whose values are
            the test losses plotted at those positions.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    plt.figure(2)

    # Axis annotations.
    plt.xlabel('Num Steps')
    plt.ylabel('ELBO')
    plt.title('ELBO Loss Curve')

    # Training curve: one point per entry in train_list.
    train_steps = np.arange(len(train_list))
    plt.plot(train_steps, train_list, label='train')

    # Test curve: keys give the x positions, values the losses.
    test_steps = list(test_dict.keys())
    test_losses = list(test_dict.values())
    plt.plot(test_steps, test_losses, label='tst')

    plt.legend(loc='best')
    plt.locator_params(axis='x', nbins=10)

    plt.show()
22 |
def create_canvas(x):
    """Tile the first 100 images of ``x`` into a single 10 x 10 mosaic.

    Args:
        x: indexable collection (e.g. array or CPU tensor) holding at least
            100 images, each with 784 elements so it can be reshaped to
            28 x 28.

    Returns:
        A ``(280, 280)`` NumPy array. Image ``i + 10 * j`` is placed in tile
        row ``i``, tile column ``j`` (consecutive images run down a column).

    Raises:
        IndexError: if ``x`` holds fewer than 100 images (raised by indexing).
    """
    rows = 10
    columns = 10
    side = 28  # edge length of one image tile in pixels

    # Deliberately does not touch matplotlib state: this only builds an
    # array, so the caller's current figure stays current and a following
    # plt.savefig / plt.close acts on the figure the caller created.
    canvas = np.zeros((side * rows, columns * side))
    for i in range(rows):
        for j in range(columns):
            idx = i % columns + rows * j
            canvas[i * side:(i + 1) * side, j * side:(j + 1) * side] = \
                x[idx].reshape((side, side))
    return canvas
34 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import torchvision
2 | import torch.optim as optim
3 | import seaborn as sns
4 | from model.ExplicitIWAE import *
5 | from model.PytorchIWAE import *
6 | from Utils import *
7 |
import os
import sys

# Training / evaluation script for the IWAE models on MNIST.
# `torch`, `np` and `plt` reach this namespace through the star imports above.

sns.set_style("darkgrid")

# Hyperparameters
gif_pics = True  # save a reconstruction/sample figure every second epoch
batch_size = 250
lr = 1e-4
num_epochs = 65
train_log = []  # training loss of every optimisation step
test_log = {}  # global step -> mean test loss measured after that step
k = 0  # global optimisation step counter
num_samples = 5  # number of copies of each data point fed to calc_loss
beta = 0  # weight passed to calc_loss, warmed up during training
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model
Explicit = True
Implicit = False

if (Explicit + Implicit) > 1:
    print('More than one model enabled')
    sys.exit()
if not (Explicit or Implicit):
    # Without this guard `net` would be undefined further down.
    print('No model enabled')
    sys.exit()

if Explicit:
    net = AnalyticalIWAE(1024, 512, 32).to(device)
if Implicit:
    net = PytorchIWAE(1024, 512, 32).to(device)
optimizer = optim.Adam(net.parameters(), lr=lr)

# Data loading
t = torchvision.transforms.transforms.ToTensor()
train_data = torchvision.datasets.MNIST('./', train=True, transform=t, target_transform=None, download=True)
test_data = torchvision.datasets.MNIST('./', train=False, transform=t, target_transform=None, download=True)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False)

if gif_pics:
    # savefig does not create missing directories.
    os.makedirs('./Figure/GIF', exist_ok=True)

for epoch in range(num_epochs):
    for train_iter in train_loader:
        batch = train_iter[0]  # labels are not used
        batch = batch.view(batch.size(0), -1)  # flatten
        batch = batch.expand(num_samples, batch.shape[0], -1).to(device)  # make num_samples copies

        batch_loss = net.calc_loss(batch, beta)
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        train_log.append(batch_loss.item())
        if beta < 2:
            beta += 0.001  # Warm-up
        k += 1

    # Evaluate on the held-out set; no gradients are needed here.
    loss_batch_mean = []
    with torch.no_grad():
        for test_iter in test_loader:
            batch = test_iter[0]  # read the *test* batch, not the last training batch
            batch = batch.view(batch.size(0), -1)  # flatten
            batch = batch.expand(num_samples, batch.shape[0], batch.shape[1]).to(device)  # make num_samples copies

            test_loss = net.calc_loss(batch, beta)

            loss_batch_mean.append(test_loss.item())

    test_log[k] = np.mean(loss_batch_mean)
    if gif_pics and epoch % 2 == 0:
        # `batch` is the last test batch; reconstruct its first 100 images.
        with torch.no_grad():
            if Explicit:
                batch = batch[0, :100, :].squeeze()
                recon_x = net(batch)
            else:
                batch = batch[0, :100, :].unsqueeze(0)
                recon_x = net(batch)[0].squeeze()  # get mu only

            samples = net.sample(100).detach().cpu()
        fig, axs = plt.subplots(1, 2, figsize=(5, 10))

        # Reconstructions
        recon_x = create_canvas(recon_x.detach().cpu())
        axs[0].set_title('Epoch {} Reconstructions'.format(epoch + 1))
        axs[0].axis('off')
        axs[0].imshow(recon_x, cmap='gray')

        # Samples
        samples = create_canvas(samples)
        axs[1].set_title('Epoch {} Sampled Samples'.format(epoch + 1))
        axs[1].axis('off')
        axs[1].imshow(samples, cmap='gray')
        save_path = './Figure/GIF/gif_pic' + str(epoch + 1) + '.jpg'
        plt.savefig(save_path, bbox_inches='tight')
        plt.close()

    print('[Epoch: {}/{}][Step: {}]\tTrain Loss: {},\tTest Loss: {}'.format(
        epoch + 1, num_epochs, k, round(train_log[k - 1], 2), round(test_log[k], 2)))

###### Loss Curve Plotting ######
Plot_loss_curve(train_log, test_log)
# NOTE(review): Plot_loss_curve calls plt.show() before this savefig; with an
# interactive backend the saved image may come out blank - confirm.
plt.savefig('./Figure/Figure_1.png', bbox_inches='tight')
plt.close()

###### Sampling #########
x = next(iter(train_loader))[0].to(device)
x = x.view(x.size(0), -1)[:100]  # flatten and limit to 100
with torch.no_grad():
    if Explicit:
        recon_x = net(x)
    else:
        recon_x = net(x)[0]

fig, axs = plt.subplots(1, 3, figsize=(15, 5))
x_true = create_canvas(x.detach().cpu())
axs[0].set_title('Ground Truth MNIST Digits')
axs[0].axis('off')
axs[0].imshow(x_true, cmap='gray')

recon_x = create_canvas(recon_x.detach().cpu())
axs[1].set_title('Reconstructed MNIST Digits')
axs[1].axis('off')
axs[1].imshow(recon_x, cmap='gray')

with torch.no_grad():
    samples = net.sample(100).detach().cpu()
samples = create_canvas(samples)
axs[2].set_title('Sampled MNIST Digits')
axs[2].axis('off')
axs[2].imshow(samples, cmap='gray')
plt.savefig('./Figure/Figure_2.png', bbox_inches='tight')
plt.close()
130 |
--------------------------------------------------------------------------------
/model/ConvIWAE.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import torch.nn.functional as F
4 | import torch.distributions as td
5 |
6 | # NOT INTEGRATED WITH MAIN.PY, ONLY FOR SHOW
7 |
8 |
class Flatten(nn.Module):
    """Reshape a tensor with 50 * 7 * 7 elements per sample to (N, 2450)."""

    def forward(self, input):
        return input.view(input.size(0), 50 * 7 * 7)


class UnFlatten(nn.Module):
    """Reshape a tensor with 50 * 7 * 7 elements per sample to (N, 50, 7, 7)."""

    def forward(self, input):
        return input.view(input.size(0), 50, 7, 7)


class ConvIWAE(nn.Module):
    """Convolutional importance-weighted autoencoder with Gaussian encoder and decoder.

    The encoder maps (bs, 1, 28, 28) images to the mean and standard deviation
    of a diagonal Gaussian over ``z_dim`` latent variables. The decoder maps
    K latent samples per image to a per-pixel mean and standard deviation.

    Args:
        z_dim: number of latent dimensions.
        bs: batch size. Stored as ``self.batch_size`` for reference only; no
            method depends on it, so it may be omitted.
    """

    def __init__(self, z_dim=20, bs=None):
        super(ConvIWAE, self).__init__()
        self.z_dim = z_dim
        self.analytic_kl = True  # use the closed-form KL instead of a sample estimate
        self.batch_size = bs
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.block1 = nn.Sequential(
            nn.Conv2d(1, 240, kernel_size=3, stride=1, padding=1), nn.LeakyReLU(0.2),
            nn.Conv2d(240, 160, kernel_size=3, stride=2, padding=1), nn.LeakyReLU(0.2),
            nn.BatchNorm2d(160),
            nn.Conv2d(160, 80, kernel_size=3, stride=1, padding=1), nn.ReLU(),
            nn.BatchNorm2d(80),
            nn.Conv2d(80, 50, kernel_size=3, stride=2, padding=1), nn.ReLU(),
            Flatten(),
            nn.Linear(50 * 7 * 7, 700),
            nn.ReLU())

        self.mu = nn.Linear(700, z_dim)
        self.std = nn.Linear(700, z_dim)

        self.dec = nn.Sequential(
            nn.Linear(z_dim, 700), nn.ReLU(),
            nn.Linear(700, 50 * 7 * 7), nn.ReLU(),
            UnFlatten(),
            nn.ConvTranspose2d(50, 220, kernel_size=3, stride=2, padding=1, output_padding=1), nn.ReLU(),
            nn.BatchNorm2d(220),
            nn.ConvTranspose2d(220, 160, kernel_size=3, stride=1, padding=1), nn.ReLU(),
            nn.BatchNorm2d(160),
            nn.ConvTranspose2d(160, 60, kernel_size=3, stride=2, padding=1, output_padding=1), nn.ReLU(),
            nn.BatchNorm2d(60),
            # Two output channels: channel 0 is the pixel mean, channel 1 feeds the pixel std.
            nn.ConvTranspose2d(60, 2, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid())  # not sure why we sigmoid variance

    def encoder(self, x):
        """Return the mean and (softplus-positive) std of q(z|x), each (bs, z_dim)."""
        x = self.block1(x)
        mu = self.mu(x)
        std = nn.functional.softplus(self.std(x))
        return mu, std

    def reparameterize(self, mu, std, K, training):
        """Draw latent codes for the decoder.

        Args:
            mu, std: encoder outputs of shape (bs, z_dim).
            K: number of importance samples drawn when ``training`` is truthy.
            training: when truthy, sample K codes with the reparameterisation
                trick; otherwise use the mean as a single deterministic code.

        Returns:
            Tensor of shape (K, bs, z_dim) when training, else (1, bs, z_dim).
        """
        if training:
            qz_Gx_obs = td.Normal(loc=mu, scale=std)
            z = qz_Gx_obs.rsample(torch.Size([K]))  # if we do [m,k] we get (m,k,batch,zdim)
        else:
            # Keep a leading sample axis so decoder() sees (samples, bs, z_dim).
            z = mu.unsqueeze(0)
        return z

    def decoder(self, z):
        """Decode (K, bs, z_dim) codes to pixel means and stds, each (K, bs, 1, 28, 28)."""
        K, bs = z.size(0), z.size(1)
        z = z.view([K * bs, -1])  # fold the sample axis into the batch axis
        x = self.dec(z)
        x = x.view([K, bs, 2, 28, 28])
        x_mean = x[:, :, :1, :, :]  # (K, bs, param, dim1, dim2)
        x_std = nn.functional.softplus(x[:, :, 1:, :, :])  # (K, bs, param, dim1, dim2)
        return x_mean, x_std

    def elbo(self, mu_dec, std_dec, mu_enc, std_enc, x, z, K, beta):
        """Return the negative importance-weighted objective as a scalar loss.

        ``beta`` scales the KL term. ``K`` is accepted for interface
        compatibility but is not used: the log-sum-exp over the sample axis is
        not normalised by log(K).
        """
        qz_Gx_obs = td.Normal(loc=mu_enc, scale=std_enc)  # z_dist
        # Build the standard-normal prior on the same device/dtype as the
        # encoder outputs so the model works wherever it has been moved.
        p_z = td.Normal(torch.zeros_like(mu_enc), torch.ones_like(std_enc))

        if self.analytic_kl:
            kl = td.kl_divergence(qz_Gx_obs, p_z).sum(-1)
        else:
            lpz = p_z.log_prob(z).sum(-1)
            lqzx = qz_Gx_obs.log_prob(z).sum(-1)
            kl = lqzx - lpz

        xgz = td.Normal(loc=mu_dec, scale=std_dec)  # x_dist
        lpxgz = xgz.log_prob(x).sum([-1, -2, -3])  # (K,bs)
        elbo = lpxgz - beta * kl

        loss = -torch.mean(torch.logsumexp(elbo, 0))
        return loss

    def forward(self, x, K, beta, training):
        """Encode, sample, decode and score ``x``.

        Returns:
            Tuple ``(loss, mu_enc, std_enc, mu_dec, std_dec)``.
        """
        mu_enc, std_enc = self.encoder(x)
        z = self.reparameterize(mu_enc, std_enc, K, training)
        mu_dec, std_dec = self.decoder(z)  # (K, bs, 1, dim1, dim2)
        loss = self.elbo(mu_dec, std_dec, mu_enc, std_enc, x, z, K, beta)

        return loss, mu_enc, std_enc, mu_dec, std_dec
--------------------------------------------------------------------------------
/model/ExplicitIWAE.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 |
class AnalyticalIWAE(nn.Module):
    """Fully connected IWAE whose loss is written out explicitly.

    Gaussian encoder over ``latent_space`` dimensions, Bernoulli decoder over
    784 pixels.

    Args:
        num_hidden1: width of the outer hidden layer (next to the pixels).
        num_hidden2: width of the inner hidden layer (next to the latent code).
        latent_space: number of latent dimensions.
    """

    def __init__(self, num_hidden1, num_hidden2, latent_space):
        super(AnalyticalIWAE, self).__init__()
        self.fc1 = nn.Sequential(
            nn.Linear(in_features=784, out_features=num_hidden1),
            nn.ReLU(),
            nn.Linear(num_hidden1, num_hidden2),
            nn.ReLU(),
        )

        self.fc21 = nn.Linear(in_features=num_hidden2, out_features=latent_space)  # mean head
        self.fc22 = nn.Linear(in_features=num_hidden2, out_features=latent_space)  # log-variance head

        self.fc3 = nn.Sequential(
            nn.Linear(in_features=latent_space, out_features=num_hidden2),
            nn.ReLU(),
            nn.Linear(num_hidden2, num_hidden1),
            nn.ReLU(),
        )
        self.decode = nn.Linear(in_features=num_hidden1, out_features=784)
        self.latent = latent_space
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)

    def encode(self, x):
        """Encode ``x`` and draw one reparameterised latent sample per row.

        Returns:
            Tuple ``(mu, log_var, h, eps)`` where ``h = mu + exp(log_var / 2) * eps``
            and ``eps`` is standard-normal noise.
        """
        x = self.fc1(x)
        mu = self.fc21(x)
        log_var = self.fc22(x)

        # Reparameterize
        eps = torch.randn_like(log_var)
        h = mu + torch.exp(log_var * 0.5) * eps
        return mu, log_var, h, eps

    def _decode_logits(self, h):
        """Return the pre-sigmoid Bernoulli logits for latent codes ``h``."""
        return self.decode(self.fc3(h))

    def forward(self, x):
        """
        Purely to see reconstruction, not for calculating loss.

        Returns the Bernoulli means (sigmoid of the decoder logits) for one
        sampled latent code per input row.
        """
        _, _, h, _ = self.encode(x)
        return torch.sigmoid(self._decode_logits(h))

    def calc_loss(self, x, beta):
        """Return the importance-weighted loss for ``x``.

        Args:
            x: tensor whose first axis indexes the importance samples and whose
                last axis holds the 784 pixels, e.g. (K, batch, 784) as built
                in main.py.
            beta: weight applied to ``log p(h) - log q(h|x)``.

        Returns:
            Scalar tensor: the batch mean of the negated weighted objective.
        """
        # Encode
        mu, log_var, h, eps = self.encode(x)
        log_2pi = float(np.log(2.0 * np.pi))

        # log p(h): standard-normal log density, summed over latent dims.
        log_Ph = torch.sum(-0.5 * h ** 2 - 0.5 * log_2pi, -1)

        # log p(x|h): Bernoulli log-likelihood computed from logits.
        # logsigmoid(a) = log(sigmoid(a)) and logsigmoid(-a) = log(1 - sigmoid(a)),
        # which stays finite when the sigmoid saturates to exactly 0 or 1.
        logits = self._decode_logits(h)
        log_PxGh = torch.sum(
            x * nn.functional.logsigmoid(logits)
            + (1 - x) * nn.functional.logsigmoid(-logits),
            -1)
        log_Pxh = log_Ph + log_PxGh  # log(p(x,h))

        # log q(h|x): Gaussian log density of the sample, expressed through eps.
        log_QhGx = torch.sum(-0.5 * eps ** 2 - 0.5 * log_2pi - 0.5 * log_var, -1)

        # Weighting according to equation 13 from IWAE paper.
        # The weights are constants w.r.t. the gradient; the max is subtracted
        # before exponentiating for numerical stability.
        log_weight = (log_Pxh - log_QhGx).detach()
        log_weight = log_weight - torch.max(log_weight, 0)[0]
        weight = torch.exp(log_weight)
        weight = weight / torch.sum(weight, 0)

        # scaling
        loss = torch.mean(-torch.sum(weight * (log_PxGh + (log_Ph - log_QhGx) * beta), 0))

        return loss

    def sample(self, n_samples):
        """Decode ``n_samples`` standard-normal latent codes to (n_samples, 784) pixel means."""
        # Sample on the device the decoder weights live on, so this works
        # wherever the module has been moved.
        eps = torch.randn((n_samples, self.latent), device=self.decode.weight.device)
        return torch.sigmoid(self._decode_logits(eps))
88 |
--------------------------------------------------------------------------------
/model/PytorchIWAE.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import torch.distributions as td
4 | import torch.nn.functional as F
5 |
6 |
7 | # define network
class PytorchIWAE(nn.Module):
    """Fully connected IWAE with Gaussian encoder and decoder.

    Network uses in-built pytorch distribution objects for variational
    inference, instead of having to explicitly calculate the densities.

    Args:
        num_hidden1: width of the outer hidden layer (next to the data).
        num_hidden2: width of the inner hidden layer (next to the latent code).
        latent: number of latent dimensions.
        in_dim: number of input (and reconstructed) features.
    """

    def __init__(self, num_hidden1, num_hidden2, latent, in_dim=784):
        super(PytorchIWAE, self).__init__()
        self.latent = latent
        self.out_dec = in_dim
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.block1 = nn.Sequential(
            nn.Linear(in_features=in_dim, out_features=num_hidden1),
            nn.ReLU(),
            nn.Linear(num_hidden1, num_hidden2),
            nn.ReLU(),
        )
        self.mu_enc = nn.Linear(in_features=num_hidden2, out_features=self.latent)
        self.lvar_enc = nn.Linear(in_features=num_hidden2, out_features=self.latent)

        self.block2 = nn.Sequential(
            nn.Linear(in_features=self.latent, out_features=num_hidden2),
            nn.ReLU(),
            nn.Linear(num_hidden2, num_hidden1),
            nn.ReLU(),
        )

        self.mu_dec = nn.Linear(in_features=num_hidden1, out_features=self.out_dec)
        self.lvar_dec = nn.Linear(in_features=num_hidden1, out_features=self.out_dec)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)

    def encoder(self, x):
        """Return the mean and log-variance of q(z|x)."""
        x = self.block1(x)
        mu = self.mu_enc(x)
        log_var = self.lvar_enc(x)
        return mu, log_var

    def decoder(self, z):
        """Decode latent codes ``z``.

        Returns:
            Tuple ``(mu_x, var_x)``, both squashed to (0, 1) by a sigmoid.
            The loss functions of this class treat the second element as a
            log-variance.
        """
        h = self.block2(z)

        mu_x = torch.sigmoid(self.mu_dec(h))
        var_x = torch.sigmoid(self.lvar_dec(h))  # Stability reasons

        return (mu_x, var_x)

    def reparameterize(self, mu, std):
        """Draw a reparameterised sample from N(mu, std); return it with the distribution."""
        qz_Gx_obs = td.Normal(loc=mu, scale=std)
        # find z|x
        z_Gx = qz_Gx_obs.rsample()
        return z_Gx, qz_Gx_obs

    def forward(self, x, train=True):
        """Reconstruct ``x``.

        Args:
            x: input tensor whose last axis has ``in_dim`` features.
            train: when True a latent code is sampled; otherwise the
                encoder mean is decoded.

        Returns:
            The ``(mu_x, var_x)`` tuple produced by :meth:`decoder`.
        """
        mu, log_var = self.encoder(x)
        if train:
            std = torch.exp(0.5 * log_var)
            z, _ = self.reparameterize(mu, std)
        else:
            z = mu
        x_recon = self.decoder(z)
        # can also show mu
        return x_recon

    def calc_loss(self, x, beta):
        """Return the importance-weighted loss for ``x``.

        Args:
            x: tensor whose first axis indexes the importance samples, e.g.
                (K, batch, in_dim) as built in main.py.
            beta: weight applied to ``log p(z) - log q(z|x)``.

        Returns:
            Scalar tensor: negative batch mean of the log-sum-exp over axis 0.
        """
        # Encode
        mu_enc, log_var_enc = self.encoder(x)
        std_enc = torch.exp(0.5 * log_var_enc)

        # Reparameterize:
        z_Gx, qz_Gx_obs = self.reparameterize(mu_enc, std_enc)
        mu_dec, log_var_dec = self.decoder(z_Gx)

        # Find q(z|x)
        log_QhGx = torch.sum(qz_Gx_obs.log_prob(z_Gx), -1)

        # Find p(z); the prior is created on the same device/dtype as z so
        # the model works wherever it has been moved.
        p_z = td.Normal(loc=z_Gx.new_zeros(self.latent), scale=z_Gx.new_ones(self.latent))
        log_Ph = torch.sum(p_z.log_prob(z_Gx), -1)

        # Find p(x|z)
        std_dec = torch.exp(0.5 * log_var_dec)
        px_Gz = td.Normal(loc=mu_dec, scale=std_dec).log_prob(x)
        log_PxGh = torch.sum(px_Gz, -1)

        # Calculate loss
        w = log_PxGh + (log_Ph - log_QhGx) * beta
        loss = -torch.mean(torch.logsumexp(w, 0))
        return loss

    def calc_loss_simple(self, x, beta, k_samples):
        """
        Simple to understand and lazy algorithm, but slow.
        Not made compatible with the remaining script.
        Useful as it is likely easy to implement into existing models.

        Draws ``k_samples`` latent codes in a Python loop, scores each with a
        Gaussian log-likelihood minus the beta-weighted analytic KL, and
        combines them as ``logsumexp - log(k_samples)``.

        Args:
            x: tensor of shape (batch, in_dim), without a sample axis.
            beta: weight applied to the KL term.
            k_samples: positive integer number of importance samples.

        Returns:
            Scalar tensor: negative batch mean of the k-sample bound.

        Raises:
            ValueError: if ``k_samples`` is smaller than 1.
        """
        import math
        from torch.distributions.kl import kl_divergence as KL

        if k_samples < 1:
            raise ValueError("k_samples must be at least 1")

        # The encoder is deterministic, so q(z|x), the prior and the analytic
        # KL are identical for every sample and only need computing once.
        mu_enc, log_var_enc = self.encoder(x)
        std_enc = torch.exp(0.5 * log_var_enc)
        p_z = td.Normal(loc=torch.zeros_like(mu_enc), scale=torch.ones_like(std_enc))
        qz_Gx_obs = td.Normal(loc=mu_enc, scale=std_enc)
        kld = KL(qz_Gx_obs, p_z).sum(-1)  # (batch,)

        log_px_z = []
        for _ in range(k_samples):
            # Reparameterize and decode one latent sample per data point.
            z_Gx, _ = self.reparameterize(mu_enc, std_enc)
            mu_dec, log_var_dec = self.decoder(z_Gx)
            std_dec = torch.exp(0.5 * log_var_dec)
            px_Gz = td.Normal(loc=mu_dec, scale=std_dec)
            log_px_z.append(px_Gz.log_prob(x).sum(-1))  # (batch,)

        # loss calculation
        log_px_z = torch.stack(log_px_z, dim=1)  # (batch, k_samples)
        log_wk = log_px_z - beta * kld.unsqueeze(1)
        L_k = log_wk.logsumexp(dim=-1) - math.log(k_samples)  # division by k in logspace
        return -torch.mean(L_k)

    def sample(self, num_samples):
        """Decode ``num_samples`` standard-normal latent codes to (num_samples, in_dim) means."""
        # Sample directly on the device the decoder weights live on.
        z_sample = torch.randn(1, num_samples, self.latent, device=self.mu_dec.weight.device)
        samples = self.decoder(z_sample)[0].view(num_samples, -1)
        return samples
143 |
--------------------------------------------------------------------------------