├── readme_imgs ├── originals_epoch_200_example_0.png └── reconstruction_epoch_200_example_0.png ├── .gitignore ├── maxout.py ├── LICENSE ├── README.md ├── evaluation_trained_model.py ├── utils └── code_to_load_the_dataset.py ├── flows.py ├── flows_with_amortized_weights.py ├── MNIST_experiment.py └── VAE_with_flows.py /readme_imgs/originals_epoch_200_example_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/federicobergamin/Variational-Inference-with-Normalizing-Flows/HEAD/readme_imgs/originals_epoch_200_example_0.png -------------------------------------------------------------------------------- /readme_imgs/reconstruction_epoch_200_example_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/federicobergamin/Variational-Inference-with-Normalizing-Flows/HEAD/readme_imgs/reconstruction_epoch_200_example_0.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | .DS_Store 3 | __pycache__/ 4 | __init__.py 5 | Original_MNIST_binarized/ 6 | random_samples/ 7 | reconstruction_during_training/ 8 | runs/ 9 | samples_during_training/ 10 | saved_models/ 11 | flows.py -------------------------------------------------------------------------------- /maxout.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.autograd import Variable 4 | import torch.nn as nn 5 | from torch.autograd import Function 6 | 7 | 8 | # class Maxout(nn.Module): 9 | # def __init__(self, pool_size): 10 | # super().__init__() 11 | # self._pool_size = pool_size 12 | # 13 | # def forward(self, x): 14 | # assert x.shape[-1] % self._pool_size == 0, \ 15 | # 'Wrong input last dim size ({}) for Maxout({})'.format(x.shape[-1], self._pool_size) 16 | # m, i = x.view(*x.shape[:-1], x.shape[-1] // self._pool_size, self._pool_size).max(-1) 17 | # return m 18 | 19 | class Maxout(nn.Module): 20 | def __init__(self, pool_size): 21 | super().__init__() 22 | self._pool_size = pool_size 23 | 24 | def forward(self, x): 25 | assert x.shape[1] % self._pool_size == 0, \ 26 | 'Wrong input last dim size ({}) for Maxout({})'.format(x.shape[1], self._pool_size) 27 | m, i = x.view(*x.shape[:1], x.shape[1] // self._pool_size, self._pool_size, *x.shape[2:]).max(2) 28 | return m -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Federico Bergamin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Variational Autoencoders with Normalizing Flows (only Planar Flows) 2 | 3 | 4 | Personal implementation of the paper *Rezende, Danilo Jimenez, and Shakir Mohamed. "Variational inference with normalizing flows." by Rezende, D. J., & Mohamed, S. (2015).*. 5 | The main purpose of this repository is to make the paper implementation accessible and clear to people 6 | that just started getting into Variational Autoencoders without having to look into highly optimized and difficult to 7 | search libraries. 8 | 9 | I use the hyperparameters of the paper: every MLP has only one hidden layer with 400 hidden units, the dimension 10 | of the latent space is 40 and I use 10 Planar Flows. Also in this case I trained only for 200 epochs. I consider amortized weights 11 | for the flows. The log-likelihood on the test set is -92.0902. 12 | 13 | ### Reconstruction examples after 200 epochs, z_dim = 40, MLP with one layer (400 units), 10 Planar Flows 14 | 15 | ![alt text](readme_imgs/originals_epoch_200_example_0.png "Original MNIST example") 16 | ![alt text](readme_imgs/reconstruction_epoch_200_example_0.png "Reconstruction MNIST example") 17 | 18 | TODO: this is far from being a complete repo. There are some changes I am still want 19 | make during my free time: 20 | 1. learn how to sample when using amortized weights and then train until convergence using a GPU 21 | 2. create more functions to avoid repeated code 22 | 3. print more infos during training 23 | 4. using another dataset other than MNIST 24 | 5. add argparse -------------------------------------------------------------------------------- /evaluation_trained_model.py: -------------------------------------------------------------------------------- 1 | ''' Now that we have trained a VAE model wioth normalizing flows, we want to evalueate it in some way. 2 | We estimate the probability of data under the model using an importance sampling technique. 3 | We can write the marginal likelihood of a datapoint as: 4 | log p_theta(x) = log E_q [p_theta(x,z) / q_phi(z|x)] 5 | ~ log 1/L sum( (p_theta(x|z) * p(z)) / q_phi(z|x) ) 6 | ''' 7 | 8 | import math 9 | import numpy as np 10 | import torch 11 | import torch.utils 12 | import torch.utils.data 13 | from torch.utils.tensorboard import SummaryWriter 14 | import torchvision 15 | from torch import nn 16 | import torch.nn.functional as F 17 | from torchvision import datasets, transforms, utils 18 | from torch.autograd import Variable 19 | from VAE_with_normalizing_flows.VAE_with_flows import VariationalAutoencoderWithFlows 20 | from sklearn.decomposition import PCA 21 | from VAE_with_normalizing_flows.utils.code_to_load_the_dataset import load_MNIST_dataset 22 | 23 | import matplotlib.pyplot as plt 24 | 25 | def show_images(images, title=None, path=None): 26 | images = utils.make_grid(images) 27 | show_image(images[0], title, path) 28 | 29 | def show_image(img, title = "", path = None): 30 | plt.imshow(img, cmap='gray') 31 | plt.title(title) 32 | if path is not None: 33 | plt.savefig(path) 34 | plt.show() 35 | 36 | use_cuda = torch.cuda.is_available() 37 | print('Do we get access to a CUDA? - ', use_cuda) 38 | device = torch.device("cuda" if use_cuda else "cpu") 39 | 40 | ORIGINAL_BINARIZED_MNIST = True 41 | 42 | BATCH_SIZE = 100 43 | HIDDEN_LAYERS = [400] 44 | Z_DIM = 40 45 | N_FLOWS = 10 46 | 47 | N_EPOCHS = 200 48 | LEARNING_RATE = 1e-5 49 | MOMENTUM = 0.9 50 | WEIGHT_DECAY = -1 51 | 52 | AMORTIZED_WEIGHTS = True 53 | 54 | ESTIMATION_SAMPLES = 100 55 | 56 | PATH = 'saved_models/VAE_flows_zdim_40_epoch_200_elbo_-94.84511192321777_learnrate_1e-05' 57 | 58 | ## we have the binarized MNIST 59 | ## in this case we look at the test set, since we are interested in these examples that 60 | ## were not used to train the model 61 | if ORIGINAL_BINARIZED_MNIST: 62 | ## we load the original dataset by Larochelle 63 | train_loader, val_loader, test_loader = load_MNIST_dataset('Original_MNIST_binarized/', BATCH_SIZE, True, True, 64 | True) 65 | else: 66 | # we have the binarized MNIST 67 | ## TRAIN SET 68 | 69 | flatten_bernoulli = lambda x: transforms.ToTensor()(x).view(-1).bernoulli() 70 | 71 | ## TEST SET 72 | test_loader = torch.utils.data.DataLoader( 73 | datasets.MNIST('../MNIST_dataset', train=False, transform=flatten_bernoulli), 74 | batch_size=BATCH_SIZE, shuffle=True) 75 | 76 | ## we can create our model and try to train it 77 | model = VariationalAutoencoderWithFlows(28*28, HIDDEN_LAYERS, Z_DIM, N_FLOWS, amortized_params_flow = AMORTIZED_WEIGHTS) 78 | print('Model overview and recap\n') 79 | print(model) 80 | print('\n') 81 | 82 | # now we have to load the trained model dict 83 | model.load_state_dict(torch.load(PATH)) 84 | 85 | 86 | 87 | ## now for each datapoint of the test set we want to compute the marginal likelihood 88 | 89 | marginal_log_likelihood = 0 90 | model.eval() 91 | 92 | with torch.no_grad(): 93 | for i, data in enumerate(test_loader, 0): 94 | if ORIGINAL_BINARIZED_MNIST: 95 | images = data 96 | else: 97 | images, labels = data 98 | images = images.to(device) 99 | 100 | batch_log_likelihood = torch.zeros((len(images), ESTIMATION_SAMPLES)) 101 | 102 | for j in range(ESTIMATION_SAMPLES): 103 | # I have to forward the images through the model, this way we get the reconstruction 104 | reconstruction = model(images) 105 | # we should get the kl 106 | # kl = torch.sum(model.kl_divergence) 107 | kl = model.qz - model.pz 108 | # print(kl.shape) 109 | 110 | likelihood = - torch.sum(F.binary_cross_entropy(reconstruction, images, reduction = 'none'), 1) ## BATCH_SIZE element 111 | bound = likelihood - kl 112 | 113 | batch_log_likelihood[:,j] = bound 114 | 115 | ## at the end we have this matrix of size BATCH_SIZE x ESTIMATION_SAMPLES 116 | # print(batch_log_likelihood) 117 | log_likel = math.log(1/ESTIMATION_SAMPLES) + torch.logsumexp(batch_log_likelihood, dim = 1) 118 | marginal_log_likelihood += torch.sum(log_likel) 119 | 120 | print('The marginal log likelihood we get on average on a test example is:', marginal_log_likelihood / len(test_loader.dataset)) 121 | 122 | 123 | 124 | -------------------------------------------------------------------------------- /utils/code_to_load_the_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data_utils 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | import torch.nn.functional as F 6 | from torchvision import datasets, transforms, utils 7 | 8 | 9 | def show_images(images, title=None, path=None): 10 | images = utils.make_grid(images) 11 | show_image(images[0], title, path) 12 | 13 | def show_image(img, title = "", path = None): 14 | plt.imshow(img, cmap='gray') 15 | plt.title(title) 16 | if path is not None: 17 | plt.savefig(path) 18 | plt.show() 19 | 20 | 21 | # print('loading the datasets...') 22 | # train = np.loadtxt('binarized_mnist_train.amat') 23 | # valid = np.loadtxt('binarized_mnist_valid.amat') 24 | # test = np.loadtxt('binarized_mnist_test.amat') 25 | # print('Datatsets loaded') 26 | # print(train.shape) 27 | # 28 | # plt.imshow(train[0,:].reshape(28,28), cmap='gray') 29 | # plt.title('Digit example from the training set') 30 | # plt.show() 31 | # 32 | # # batch size 33 | # BATCH_SIZE = 64 34 | # 35 | # train = train.reshape(-1,28,28) 36 | # print(train.shape) 37 | # valid = valid.reshape(-1,28,28) 38 | # test = test.reshape(-1,28,28) 39 | # 40 | # train = torch.from_numpy(train).float() 41 | # validation = torch.from_numpy(valid).float() 42 | # test = torch.from_numpy(test).float() 43 | # print(train.shape) 44 | # 45 | # # plt.imshow(train[0,:,:], cmap='gray') 46 | # # plt.title('Digit example from the training set') 47 | # # plt.show() 48 | # 49 | # # pytorch data loader 50 | # # train = data_utils.TensorDataset(torch.from_numpy(train).float()) 51 | # train_loader = data_utils.DataLoader(train, batch_size=BATCH_SIZE, shuffle=True) 52 | # 53 | # # validation = data_utils.TensorDataset(torch.from_numpy(valid).float()) 54 | # val_loader = data_utils.DataLoader(validation, batch_size=BATCH_SIZE, shuffle=False) 55 | # 56 | # # test = data_utils.TensorDataset(torch.from_numpy(test).float()) 57 | # test_loader = data_utils.DataLoader(test, batch_size=BATCH_SIZE, shuffle=True) 58 | # 59 | # 60 | # dataiter = iter(train_loader) 61 | # images = dataiter.next() ## next return a complete batch --> BATCH_SIZE images 62 | # print(images.shape) 63 | # print(images.unsqueeze(1).shape) 64 | # show_images(images.unsqueeze(1)) 65 | # 66 | # i = 0 67 | # for data in train_loader: 68 | # i+=1 69 | # 70 | # print(i) 71 | # 72 | 73 | def load_MNIST_dataset(dir, batch_size, flatten = False, shuffled = False, verbose = False, show_examples = False): 74 | 75 | if isinstance(batch_size, list): 76 | assert len(batch_size) == 3, "In case you are using a list for the batch size, it should contain the batch size for " \ 77 | "the training, validation and test set" 78 | train_batch_size = batch_size[0] 79 | valid_batch_size = batch_size[1] 80 | test_batch_size = batch_size[2] 81 | else: 82 | # we assume the same batch size for all the three 83 | train_batch_size = batch_size 84 | valid_batch_size = batch_size 85 | test_batch_size = batch_size 86 | 87 | 88 | print('Loading the datasets...') 89 | train = np.loadtxt(dir + 'binarized_mnist_train.amat') 90 | valid = np.loadtxt(dir + 'binarized_mnist_valid.amat') 91 | test = np.loadtxt(dir + 'binarized_mnist_test.amat') 92 | print('Datatsets loaded\n') 93 | 94 | if not flatten: 95 | train = train.reshape(-1, 28, 28) 96 | valid = valid.reshape(-1, 28, 28) 97 | test = test.reshape(-1, 28, 28) 98 | 99 | train = torch.from_numpy(train).float() 100 | validation = torch.from_numpy(valid).float() 101 | test = torch.from_numpy(test).float() 102 | 103 | if verbose: 104 | print('Training set shape:', train.shape) 105 | print('Validation et shape:', validation.shape) 106 | print('Test set shape', test.shape) 107 | 108 | train_loader = data_utils.DataLoader(train, batch_size=train_batch_size, shuffle=shuffled) 109 | val_loader = data_utils.DataLoader(validation, batch_size=valid_batch_size, shuffle=shuffled) 110 | test_loader = data_utils.DataLoader(test, batch_size=test_batch_size, shuffle=shuffled) 111 | 112 | if show_examples: 113 | dataiter = iter(train_loader) 114 | images = dataiter.next() ## next return a complete batch --> BATCH_SIZE images 115 | if flatten: 116 | show_images(images.view(train_batch_size, 1, 28,28)) 117 | else: 118 | show_images(images.unsqueeze(1)) 119 | 120 | return train_loader, val_loader, test_loader 121 | 122 | 123 | 124 | 125 | #t, v, te = load_MNIST_dataset('../Original_MNIST_binarized/', 64, True, True, True) 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | -------------------------------------------------------------------------------- /flows.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Implementation of 'Variational Inference with Normalizing Flows' by [Rezende, D. & Mohamed, S. 2015] 3 | ''' 4 | 5 | import math 6 | import torch 7 | import torch.utils 8 | import torch.utils.data 9 | from torch import nn 10 | import torch.nn.functional as F 11 | from torch.autograd import Variable 12 | from torch.distributions import LogNormal 13 | from torch.nn import init 14 | # 15 | class PlanarFlow(nn.Module): 16 | 17 | def __init__(self, z_dim): 18 | ''' 19 | 20 | :param z_dim: since we are using it with VAEs this would be the size of the latent. 21 | However, flows can be used also in other scenario where the input they 22 | get is not the latent variables 23 | ''' 24 | super(PlanarFlow, self).__init__() 25 | self.init_sigma = 0.01 26 | # self.u3 = nn.Parameter(torch.randn(1, 2).normal_(0, init_sigma)) 27 | self.n_features = z_dim 28 | self.weights = nn.Parameter(torch.randn(1, z_dim).normal_(0, self.init_sigma)) 29 | self.bias = nn.Parameter(torch.zeros(1).normal_(0, self.init_sigma)) 30 | self.u = nn.Parameter(torch.randn(1, z_dim).normal_(0, self.init_sigma)) 31 | 32 | 33 | def forward(self, z): 34 | 35 | ## we follow the instruction on the Appendix A of the paper: 36 | ## for the planar flow to be invertible whe using tanh a sufficient 37 | ## condition is to w^Tu >= -1 38 | ## --> we have to compute uhat parallel to w 39 | u_temp = (self.weights @ self.u.t()).squeeze() 40 | # print('u_temp', u_temp.shape) 41 | m_u_temp = -1 + F.softplus(u_temp) 42 | # print('m_u_temp', m_u_temp.shape) 43 | 44 | uhat = self.u + (m_u_temp - u_temp) * (self.weights / (self.weights @ self.weights.t())) 45 | # print('uhat', uhat.shape) 46 | 47 | z_temp = z @ self.weights.t() + self.bias #F.linear(z, self.weights, self.bias) 48 | 49 | new_z = z + uhat * torch.tanh(z_temp) 50 | 51 | ## now we have to compute psi 52 | 53 | psi = (1 - torch.tanh(z_temp)**2) @ self.weights 54 | # print('psi', psi.shape) 55 | 56 | det_jac = 1 + psi @ uhat.t() #uhat * psi 57 | 58 | logdet_jacobian = torch.log(torch.abs(det_jac) + 1e-8).squeeze() 59 | # print(torch.sum(logdet_jacobian, 1).shape) 60 | # print('log_det_jacobian', logdet_jacobian.shape) 61 | 62 | return new_z, logdet_jacobian # torch.sum(logdet_jacobian, 1) 63 | 64 | 65 | # class PlanarFlow(nn.Module): 66 | # 67 | # def __init__(self, z_dim): 68 | # ''' 69 | # 70 | # :param z_dim: since we are using it with VAEs this would be the size of the latent. 71 | # However, flows can be used also in other scenario where the input they 72 | # get is not the latent variables 73 | # ''' 74 | # super(PlanarFlow, self).__init__() 75 | # self.init_sigma = 0.01 76 | # # self.u3 = nn.Parameter(torch.randn(1, 2).normal_(0, init_sigma)) 77 | # self.n_features = z_dim 78 | # # self.weights = nn.Parameter(torch.randn(1, z_dim).normal_(0, self.init_sigma)) 79 | # # self.bias = nn.Parameter(torch.zeros(1).normal_(0, self.init_sigma)) 80 | # # self.u = nn.Parameter(torch.randn(1, z_dim).normal_(0, self.init_sigma)) 81 | # 82 | # 83 | # def forward(self, z, weights, bias, u): 84 | # 85 | # ## we follow the instruction on the Appendix A of the paper: 86 | # ## for the planar flow to be invertible whe using tanh a sufficient 87 | # ## condition is to w^Tu >= -1 88 | # ## --> we have to compute uhat parallel to w 89 | # u_temp = (weights @ u.t()).squeeze() 90 | # # print('u_temp', u_temp.shape) 91 | # m_u_temp = -1 + F.softplus(u_temp) 92 | # # print('m_u_temp', m_u_temp.shape) 93 | # 94 | # uhat = u + (m_u_temp - u_temp) * (self.weights / (self.weights @ self.weights.t())) 95 | # # print('uhat', uhat.shape) 96 | # 97 | # z_temp = z @ self.weights.t() + self.bias #F.linear(z, self.weights, self.bias) 98 | # 99 | # new_z = z + uhat * torch.tanh(z_temp) 100 | # 101 | # ## now we have to compute psi 102 | # 103 | # psi = (1 - torch.tanh(z_temp)**2) @ self.weights 104 | # # print('psi', psi.shape) 105 | # 106 | # det_jac = 1 + psi @ uhat.t() #uhat * psi 107 | # 108 | # logdet_jacobian = torch.log(torch.abs(det_jac) + 1e-8).squeeze() 109 | # # print(torch.sum(logdet_jacobian, 1).shape) 110 | # # print('log_det_jacobian', logdet_jacobian.shape) 111 | # 112 | # return new_z, logdet_jacobian # torch.sum(logdet_jacobian, 1) 113 | 114 | 115 | class NormalizingFlows(nn.Module): 116 | 117 | def __init__(self, z_dims, n_flows = 1, flow_type = PlanarFlow): 118 | ''' 119 | 120 | :param z_dims: dimension of the latent variables 121 | :param n_flows: how many flows we should use in term of sequence of function f_k (f_k-1(f_k-2(.. 122 | :param flow_type: we have implemented only the Planar Flow, but in case we implement also the radial flow, one can 123 | select what type of flows to use 124 | ''' 125 | 126 | super(NormalizingFlows, self).__init__() 127 | self.z_dims = z_dims 128 | self.n_flows = n_flows 129 | self.flow_type = flow_type 130 | 131 | flows_sequence = [self.flow_type(self.z_dims) for _ in range(self.n_flows)] 132 | 133 | self.flows = nn.ModuleList(flows_sequence) 134 | 135 | def forward(self, z): 136 | 137 | # we have to collect all the logdet_jacobian to sum them up in the end 138 | logdet_jacobians = [] 139 | # i = 0 140 | for flow in self.flows: 141 | # i += 1 142 | # print(i) 143 | z, logdet_j = flow(z) 144 | # print(logdet_j.shape) 145 | logdet_jacobians.append(logdet_j) 146 | 147 | z_k = z 148 | # print('final_logdet_jacobian', logdet_jacobians) 149 | logdet_jacobians = torch.stack(logdet_jacobians, dim=1) 150 | # print(logdet_jacobians.shape) 151 | # print(torch.sum(logdet_jacobians, 1).shape) 152 | # print('new_z', z_k.shape) 153 | # print('we sum them') 154 | # print(torch.sum(logdet_jacobians, 1)) 155 | return z_k, torch.sum(logdet_jacobians, 1) 156 | 157 | 158 | 159 | 160 | 161 | 162 | -------------------------------------------------------------------------------- /flows_with_amortized_weights.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Implementation of 'Variational Inference with Normalizing Flows' by [Rezende, D. & Mohamed, S. 2015] 3 | ''' 4 | 5 | import math 6 | import torch 7 | import torch.utils 8 | import torch.utils.data 9 | from torch import nn 10 | import torch.nn.functional as F 11 | from torch.autograd import Variable 12 | from torch.distributions import LogNormal 13 | from torch.nn import init 14 | # 15 | class PlanarFlow(nn.Module): 16 | 17 | def __init__(self, amortized_params_flow = False, z_dim = None): 18 | ''' 19 | We are considering in this case amortized inference of the weights/params. This means that 20 | instead of learning some weights that we use for all the examples, we learn specific weights for each example. 21 | In other words, the u, w, b are computed form the input, so we are learning a function that maps each 22 | imput to 23 | ''' 24 | super(PlanarFlow, self).__init__() 25 | self.amortized_params_flow = amortized_params_flow 26 | 27 | if not self.amortized_params_flow: 28 | self.init_sigma = 0.01 29 | self.n_features = z_dim 30 | self.weights = nn.Parameter(torch.randn(1, z_dim).normal_(0, self.init_sigma)) 31 | self.bias = nn.Parameter(torch.zeros(1).normal_(0, self.init_sigma)) 32 | self.u = nn.Parameter(torch.randn(1, z_dim).normal_(0, self.init_sigma)) 33 | 34 | 35 | def forward(self, zk, u_k, weights_k, bias_k): 36 | 37 | if self.amortized_params_flow: 38 | ''' 39 | Assumes the following input shapes: 40 | shape u = (batch_size, z_size, 1) 41 | shape w = (batch_size, 1, z_size) 42 | shape b = (batch_size, 1, 1) 43 | shape z = (batch_size, z_size) 44 | 45 | this part is taken from https://github.com/riannevdberg/sylvester-flows/blob/master/models/flows.py 46 | ''' 47 | # print('z',zk.shape) # 48 | # print('u', u_k.shape) 49 | # print('w', weights_k.shape) 50 | # print('b', bias_k.shape) 51 | u = u_k 52 | weights = weights_k 53 | bias = bias_k 54 | 55 | zk = zk.unsqueeze(2) 56 | # reparameterize u such that the flow becomes invertible (see appendix paper) 57 | uw = torch.bmm(weights, u) 58 | m_uw = -1. + F.softplus(uw) 59 | w_norm_sq = torch.sum(weights ** 2, dim=2, keepdim=True) 60 | u_hat = u + ((m_uw - uw) * weights.transpose(2, 1) / w_norm_sq) 61 | # print('uhat', u_hat.shape) 62 | 63 | # compute flow with u_hat 64 | wzb = torch.bmm(weights, zk) + bias 65 | new_z = zk + u_hat * torch.tanh(wzb) 66 | new_z = new_z.squeeze(2) 67 | 68 | # compute logdetJ 69 | psi = weights * (1-torch.tanh(wzb)**2) 70 | log_det_jacobian = torch.log(torch.abs(1 + torch.bmm(psi, u_hat))+ 1e-8) 71 | logdet_jacobian = log_det_jacobian.squeeze(2).squeeze(1) 72 | # print(logdet_jacobian.shape) 73 | 74 | 75 | else: 76 | 77 | u = self.u 78 | weights = self.weights 79 | bias = self.bias 80 | 81 | ## we follow the instruction on the Appendix A of the paper: 82 | ## for the planar flow to be invertible whe using tanh a sufficient 83 | ## condition is to w^Tu >= -1 84 | ## --> we have to compute uhat parallel to w 85 | u_temp = (weights @ u.t()).squeeze() 86 | # print('u_temp', u_temp.shape) 87 | m_u_temp = -1 + F.softplus(u_temp) 88 | # print('m_u_temp', m_u_temp.shape) 89 | 90 | uhat = u + (m_u_temp - u_temp) * (weights / (weights @ weights.t())) 91 | print('uhat', uhat.shape) 92 | 93 | z_temp = zk @ weights.t() + bias #F.linear(z, self.weights, self.bias) 94 | 95 | new_z = zk + uhat * torch.tanh(z_temp) 96 | 97 | ## now we have to compute psi 98 | 99 | psi = (1 - torch.tanh(z_temp)**2) @ weights 100 | print('psi', psi.shape) 101 | 102 | det_jac = 1 + psi @ uhat.t() #uhat * psi 103 | 104 | logdet_jacobian = torch.log(torch.abs(det_jac) + 1e-8).squeeze() 105 | # print(torch.sum(logdet_jacobian, 1).shape) 106 | print('log_det_jacobian', logdet_jacobian.shape) 107 | 108 | return new_z, logdet_jacobian # torch.sum(logdet_jacobian, 1) 109 | 110 | 111 | 112 | 113 | class NormalizingFlows(nn.Module): 114 | 115 | def __init__(self, n_flows = 1, amortized_params_flow = False, z_dim = None, flow_type = PlanarFlow): 116 | ''' 117 | 118 | :param z_dims: dimension of the latent variables 119 | :param n_flows: how many flows we should use in term of sequence of function f_k (f_k-1(f_k-2(.. 120 | :param flow_type: we have implemented only the Planar Flow, but in case we implement also the radial flow, one can 121 | select what type of flows to use 122 | ''' 123 | 124 | super(NormalizingFlows, self).__init__() 125 | self.n_flows = n_flows 126 | self.flow_type = flow_type 127 | self.amortized_params_flow = amortized_params_flow 128 | self.z_dim = z_dim 129 | 130 | flows_sequence = [self.flow_type(self.amortized_params_flow, self.z_dim) for _ in range(self.n_flows)] 131 | 132 | self.flows = nn.ModuleList(flows_sequence) 133 | 134 | def forward(self, z, u, w, b): 135 | 136 | # we have to collect all the logdet_jacobian to sum them up in the end 137 | logdet_jacobians = [] 138 | # i = 0 139 | # print('u') 140 | # print(u) 141 | # print(u.shape) 142 | for k, flow in enumerate(self.flows): 143 | # i += 1 144 | # print(i) 145 | # print(self.amortized_params_flow) 146 | if self.amortized_params_flow: 147 | z, logdet_j = flow(z, u[:, k, :, :], w[:, k, :, :], b[:, k, :, :]) 148 | else: 149 | z, logdet_j = flow(z, None, None, None) 150 | # print(logdet_j.shape) 151 | logdet_jacobians.append(logdet_j) 152 | 153 | z_k = z 154 | # print('final_logdet_jacobian', logdet_jacobians) 155 | logdet_jacobians = torch.stack(logdet_jacobians, dim=1) 156 | # print(logdet_jacobians.shape) 157 | # print(torch.sum(logdet_jacobians, 1).shape) 158 | # print('new_z', z_k.shape) 159 | # print('we sum them') 160 | # print(torch.sum(logdet_jacobians, 1)) 161 | return z_k, torch.sum(logdet_jacobians, 1) 162 | 163 | 164 | 165 | 166 | 167 | 168 | -------------------------------------------------------------------------------- /MNIST_experiment.py: -------------------------------------------------------------------------------- 1 | ''' 2 | We are going to learn a latent space and a generative model for the MNIST dataset. 3 | 4 | ''' 5 | 6 | import math 7 | import numpy as np 8 | import torch 9 | import torch.utils 10 | import torch.utils.data 11 | from torch.utils.tensorboard import SummaryWriter 12 | import torchvision 13 | from torch import nn 14 | import torch.nn.functional as F 15 | from torchvision import datasets, transforms, utils 16 | from torch.autograd import Variable 17 | from VAE_with_normalizing_flows.VAE_with_flows import VariationalAutoencoderWithFlows 18 | from VAE_with_normalizing_flows.utils.code_to_load_the_dataset import load_MNIST_dataset 19 | from sklearn.decomposition import PCA 20 | 21 | import matplotlib.pyplot as plt 22 | 23 | def show_images(images, title=None, path=None): 24 | images = utils.make_grid(images) 25 | show_image(images[0], title, path) 26 | 27 | def show_image(img, title = "", path = None): 28 | plt.imshow(img, cmap='gray') 29 | plt.title(title) 30 | if path is not None: 31 | plt.savefig(path) 32 | plt.show() 33 | 34 | # We use this custom binary cross entropy 35 | # def binary_cross_entropy(r, x): 36 | # return -torch.sum(x * torch.log(r + 1e-8) + (1 - x) * torch.log(1 - r + 1e-8), dim=-1) 37 | 38 | # Writer will output to ./runs/ directory by default 39 | writer = SummaryWriter() 40 | ORIGINAL_BINARIZED_MNIST = True 41 | use_cuda = torch.cuda.is_available() 42 | print('Do we get access to a CUDA? - ', use_cuda) 43 | device = torch.device("cuda" if use_cuda else "cpu") 44 | BATCH_SIZE = 100 45 | HIDDEN_LAYERS = [400] 46 | Z_DIM = 40 47 | N_FLOWS = 10 48 | 49 | N_EPOCHS = 200 50 | LEARNING_RATE = 1e-5 51 | MOMENTUM = 0.9 52 | WEIGHT_DECAY = -1 53 | 54 | AMORTIZED_WEIGHTS = True 55 | 56 | N_SAMPLE = 64 57 | 58 | SAVE_MODEL_EPOCH = N_EPOCHS - 5 59 | PATH = 'saved_models/' 60 | 61 | ## we have the binarized MNIST 62 | ## TRAIN SET 63 | if ORIGINAL_BINARIZED_MNIST: 64 | train_loader, val_loader, test_loader = load_MNIST_dataset('Original_MNIST_binarized/', BATCH_SIZE, True, True, True) 65 | else: 66 | training_set = datasets.MNIST('../MNIST_dataset', train=True, download=True, 67 | transform=transforms.ToTensor()) 68 | print('Number of examples in the training set:', len(training_set)) 69 | print('Size of the image:', training_set[0][0].shape) 70 | ## we plot an example only to check it 71 | idx_ex = 1000 72 | x, y = training_set[idx_ex] # x is now a torch.Tensor 73 | plt.imshow(x.numpy()[0], cmap='gray') 74 | plt.title('Example n {}, label: {}'.format(idx_ex, y)) 75 | plt.show() 76 | 77 | ### we only check if it is binarized 78 | input_dim = x.numpy().size 79 | print('Size of the image:', input_dim) 80 | 81 | flatten_bernoulli = lambda x: transforms.ToTensor()(x).view(-1).bernoulli() 82 | 83 | train_loader = torch.utils.data.DataLoader( 84 | datasets.MNIST('../MNIST_dataset', train=True, transform=flatten_bernoulli), 85 | batch_size=BATCH_SIZE, shuffle=True) 86 | 87 | ## TEST SET 88 | test_loader = torch.utils.data.DataLoader( 89 | datasets.MNIST('../MNIST_dataset', train=False, transform=flatten_bernoulli), 90 | batch_size=BATCH_SIZE, shuffle=True) 91 | 92 | ## another way to plot some images from the dataset 93 | dataiter = iter(train_loader) 94 | images, labels = dataiter.next() ## next return a complete batch --> BATCH_SIZE images 95 | show_images(images.view(BATCH_SIZE,1,28,28)) 96 | 97 | 98 | ## now we have our train and test set 99 | ## we can create our model and try to train it 100 | model = VariationalAutoencoderWithFlows(28*28, HIDDEN_LAYERS, Z_DIM, N_FLOWS, amortized_params_flow = AMORTIZED_WEIGHTS) 101 | print('Model overview and recap\n') 102 | print(model) 103 | print('\n') 104 | 105 | optimizer = torch.optim.RMSprop(model.parameters(), lr=LEARNING_RATE, momentum = MOMENTUM) 106 | 107 | ## training loop 108 | training_loss = [] 109 | approx_kl = [] 110 | anal_kl = [] 111 | print('.....Starting trianing') 112 | t = 0 113 | for epoch in range(N_EPOCHS): 114 | tmp_elbo = 0 115 | tmp_kl = 0 116 | tmp_recon = 0 117 | n_batch = 0 118 | for i, data in enumerate(train_loader, 0): 119 | beta = min(1, 0.01 + t / 700) 120 | n_batch += 1 121 | if ORIGINAL_BINARIZED_MNIST: 122 | images = data 123 | else: 124 | images, labels = data 125 | images = images.to(device) 126 | 127 | reconstruction = model(images) 128 | # print('images shape', images.shape) 129 | # print('recon shape', test_set_reconstruction.shape) 130 | 131 | # likelihood = -binary_cross_entropy(reconstruction, images) 132 | likelihood = - F.binary_cross_entropy(reconstruction, images, reduction='sum') 133 | kl = torch.sum(model.qz - beta * model.pz) 134 | bound = beta * torch.sum(likelihood) - kl 135 | 136 | L = - bound / len(images) #BATCH_SIZE 137 | 138 | L.backward() 139 | optimizer.step() 140 | optimizer.zero_grad() 141 | # if L.item()/len(images) > 4: 142 | # print('Epoch: {}, Batch: {}, images in the batch: {}, L.item: {}'.format(epoch, i, len(images), L.item())) 143 | training_loss.append(-bound/ len(images)) 144 | tmp_elbo += - L.item() * BATCH_SIZE 145 | tmp_recon += torch.sum(likelihood) 146 | tmp_kl += kl 147 | 148 | ## we should update our t 149 | t += 1 150 | 151 | 152 | ## at the end of each epoch we can try to store some images 153 | ## 154 | with torch.no_grad(): 155 | for r, data in enumerate(test_loader, 0): 156 | if ORIGINAL_BINARIZED_MNIST: 157 | images = data 158 | else: 159 | images, labels = data 160 | images = images.to(device) 161 | reconstruction = model(images) 162 | # print(test_set_reconstruction.shape) 163 | recon_image_ = reconstruction.view(reconstruction.shape[0], 1, 28, 28) 164 | images = images.view(images.shape[0], 1, 28, 28) 165 | if r % 100 == 0: 166 | # show_images(images, 'original') 167 | # show_images(recon_image_, 'test_set_reconstruction') 168 | grid1 = torchvision.utils.make_grid(images) 169 | writer.add_image('orig images', grid1, 0) 170 | grid2 = torchvision.utils.make_grid(recon_image_) 171 | writer.add_image('recon images', grid2) 172 | writer.close() 173 | ## maybe we just store the test_set_reconstruction 174 | ## maybe we just store the test_set_reconstruction 175 | images = utils.make_grid(images) 176 | recon_image_ = utils.make_grid(recon_image_) 177 | plt.imshow(images[0], cmap='gray') 178 | plt.title('Original from epoch {}'.format(epoch + 1)) 179 | plt.savefig('reconstruction_during_training/originals_epoch_{}_example_{}'.format(epoch + 1, r)) 180 | plt.imshow(recon_image_[0], cmap='gray') 181 | plt.title('Reconstruction from epoch {}'.format(epoch + 1)) 182 | plt.savefig('reconstruction_during_training/reconstruction_epoch_{}_example_{}'.format(epoch + 1, r)) 183 | 184 | model.eval() 185 | ## we want also to sample something from the model during training 186 | rendom_samples = model.sample(N_SAMPLE) 187 | samples = rendom_samples.view(rendom_samples.shape[0], 1, 28, 28) 188 | samples = utils.make_grid(samples) 189 | plt.imshow(samples[0], cmap='gray') 190 | plt.title('Samples from epoch {}'.format(epoch + 1)) 191 | plt.savefig('samples_during_training/samples_epoch_{}'.format(epoch + 1)) 192 | 193 | 194 | 195 | print('Epoch: {}, Elbo: {}, recon_error: {}, kl: {}'.format(epoch+1, tmp_elbo/ len(train_loader.dataset), -tmp_recon/ len(train_loader.dataset), tmp_kl/ len(train_loader.dataset))) 196 | 197 | if epoch + 1 > SAVE_MODEL_EPOCH: 198 | ## we have to store the model 199 | torch.save(model.state_dict(), PATH + 'VAE_flows_zdim_{}_epoch_{}_elbo_{}_learnrate_{}'.format(Z_DIM, epoch+1, tmp_elbo/ len(train_loader.dataset), LEARNING_RATE)) 200 | 201 | 202 | 203 | 204 | print('....Training ended') 205 | fig = plt.figure() 206 | plt.plot(training_loss, label='Bound mean per batch') 207 | plt.legend() 208 | plt.show() 209 | 210 | # plt.plot(approx_kl, label='Approximated KL (mean)') 211 | # plt.legend() 212 | # plt.show() 213 | 214 | 215 | model.eval() 216 | with torch.no_grad(): 217 | for i, data in enumerate(test_loader, 0): 218 | if ORIGINAL_BINARIZED_MNIST: 219 | images = data 220 | else: 221 | images, labels = data 222 | images = images.to(device) 223 | reconstruction = model(images) 224 | # print(test_set_reconstruction.shape) 225 | recon_image_ = reconstruction.view(reconstruction.shape[0], 1, 28, 28) 226 | images = images.view(images.shape[0], 1, 28, 28) 227 | if i % 100 == 0: 228 | show_images(images, 'original') 229 | show_images(recon_image_, 'test_set_reconstruction') 230 | images = utils.make_grid(images) 231 | recon_image_ = utils.make_grid(recon_image_) 232 | plt.imshow(images[0], cmap='gray') 233 | plt.title('Original') 234 | plt.savefig('reconstruction_during_training/originals_example_{}'.format(i)) 235 | plt.imshow(recon_image_[0], cmap='gray') 236 | plt.title('Reconstruction') 237 | plt.savefig('reconstruction_during_training/reconstruction_example_{}'.format(i)) 238 | 239 | 240 | # samples form the prios 241 | for i in range(5): 242 | # random_latent = torch.randn((N_SAMPLE, Z_DIM), dtype = torch.float).to(device) 243 | images_from_random = model.sample(N_SAMPLE) 244 | sampled_ima = images_from_random.view(images_from_random.shape[0], 1, 28, 28) 245 | show_images(sampled_ima, 'Random sampled imagess', 'random_samples/Random_samples_ex_{}'.format(i+1)) 246 | -------------------------------------------------------------------------------- /VAE_with_flows.py: -------------------------------------------------------------------------------- 1 | ''' 2 | In this file we are going to create our first implementation of a VAE, following 3 | the Kingma and Welling [2014] paper. I cannot be sure it will be an optimized version, 4 | but I will try to do my best. 5 | 6 | 7 | ''' 8 | import math 9 | import torch 10 | import torch.utils 11 | import torch.utils.data 12 | from torch import nn 13 | import torch.nn.functional as F 14 | from torch.autograd import Variable 15 | from torch.nn import init 16 | from VAE_with_normalizing_flows.flows_with_amortized_weights import PlanarFlow, NormalizingFlows 17 | from VAE_with_normalizing_flows.maxout import Maxout 18 | import numpy as np 19 | 20 | ## function to compute the standard gaussian N(x;0,I) and a gaussian parametrized by 21 | ## mean mu and variance sigma log N(x|µ,σ) 22 | def log_standard_gaussian(x): 23 | """ 24 | Evaluates the log pdf of a standard normal distribution at x. (Univariate distribution) 25 | :param x: point to evaluate 26 | :return: log N(x|0,I) 27 | """ 28 | return torch.sum(-0.5 * math.log(2 * math.pi) - (x ** 2 + 1e-8) / 2, dim=-1) 29 | 30 | 31 | def log_gaussian(x, mu, log_var): 32 | """ 33 | Returns the log pdf of a normal distribution parametrised 34 | by mu and log_var evaluated at x. (Univariate distribution) 35 | :param x: point to evaluate 36 | :param mu: mean of distribution 37 | :param log_var: log variance of distribution 38 | :return: log N(x|µ,σ) 39 | """ 40 | log_pdf = - 0.5 * math.log(2 * math.pi) - (log_var + 1e-8) / 2 - ((x - mu)**2 + 1e-8) / (2 * torch.exp(log_var)) 41 | # print('Size log_pdf:', log_pdf.shape) 42 | return torch.sum(log_pdf, dim=-1) 43 | 44 | ## in a simple explanation a VAE is made up of three different parts: 45 | ## - Inference model (or encoder) q_phi(z|x) 46 | ## - A stochastic layer that sample (Reparametrization trick) 47 | ## - a generative model (or decoder) p_theta(z|x) 48 | ## given this, we want to minimize the ELBO 49 | 50 | def reparametrization_trick(mu, log_var): 51 | ''' 52 | Function that given the mean (mu) and the logarithmic variance (log_var) compute 53 | the latent variables using the reparametrization trick. 54 | z = mu + sigma * noise, where the noise is sample 55 | 56 | :param mu: mean of the z_variables 57 | :param log_var: variance of the latent variables 58 | :return: z = mu + sigma * noise 59 | ''' 60 | # we should get the std from the log_var 61 | # log_std = 0.5 * log_var (use the logarithm properties) 62 | # std = exp(log_std) 63 | std = torch.exp(log_var * 0.5) 64 | 65 | # we have to sample the noise (we do not have to keep the gradient wrt the noise) 66 | eps = Variable(torch.randn_like(std), requires_grad=False) 67 | z = mu.addcmul(std, eps) 68 | 69 | return z 70 | 71 | 72 | class Encoder(nn.Module): 73 | def __init__(self, input_dim, hidden_dims, latent_dim, amortized_params_flow = False, n_flows = None): 74 | ''' 75 | Probabilistic inference network given by a MLP. In case of a Gaussian MLP, we will 76 | have to output: log(sigma^2) and mu. 77 | 78 | :param input_dim: dimension of the input (scalar) 79 | :param hidden_dims: dimensions of the hidden layers (vector) 80 | :param latent_dim: dimension of the latent space 81 | :param amortized_params_flow: bool, if true we get the u, w ,b from the input 82 | :param n_flows: number of flows 83 | 84 | In addition to return z, _mu, _log_var, if amortized_params_flow = True it returns also 85 | the weights needed for the transformation 86 | ''' 87 | 88 | super(Encoder, self).__init__() 89 | 90 | self.z_dims = latent_dim 91 | ## now we have to create the architecture 92 | neurons = [input_dim, *hidden_dims] 93 | ## common part of the architecture 94 | self.hidden_layers = nn.ModuleList([nn.Linear(neurons[i-1], neurons[i]) for i in range(1,len(neurons))]) 95 | # self.maxout = Maxout(4) 96 | # dim_after_maxout = int(hidden_dims[-1] / 4) 97 | ## we have two output: mu and log(sigma^2) 98 | self.mu = nn.Linear(hidden_dims[-1], latent_dim) 99 | self.log_var = nn.Linear(hidden_dims[-1], latent_dim) 100 | 101 | ## flows part 102 | self.amortized_params_flow = amortized_params_flow 103 | self.n_flows = n_flows 104 | 105 | if self.amortized_params_flow: 106 | self.u = nn.Linear(hidden_dims[-1], latent_dim * self.n_flows) 107 | self.weights = nn.Linear(hidden_dims[-1], latent_dim * self.n_flows) 108 | self.bias = nn.Linear(hidden_dims[-1], self.n_flows) 109 | 110 | 111 | def forward(self, input): 112 | x = input 113 | batch_size = x.size(0) 114 | for layer in self.hidden_layers: 115 | x = F.relu(layer(x)) 116 | 117 | ## now we should compute the mu and log var 118 | _mu = self.mu(x) 119 | # _log_var = F.softplus(self.log_var(x)) 120 | _log_var = self.log_var(x) 121 | 122 | ## now we have also to return our z as the reparametrization trick told us 123 | ## z = mu + sigma * noise, where the noise is sample 124 | 125 | z = reparametrization_trick(_mu, _log_var) 126 | 127 | if self.amortized_params_flow: 128 | u = self.u(x) 129 | w = self.weights(x) 130 | b = self.bias(x) 131 | 132 | # print( u.view(batch_size, self.n_flows, self.z_dims).shape) 133 | return z, _mu, _log_var, u.view(batch_size, self.n_flows, self.z_dims, 1), w.view(batch_size, self.n_flows, 1, self.z_dims), b.view(batch_size, self.n_flows, 1, 1) 134 | 135 | else: 136 | return z, _mu, _log_var 137 | 138 | 139 | ## now we have to create the Decoder class 140 | class Decoder(nn.Module): 141 | def __init__(self, latent_dim, hidden_dims, input_dim): 142 | ''' 143 | 144 | :param latent_dim: dimension of the latent space (scalar) 145 | :param hidden_dims: dimensions of the hidden layers (vector) 146 | :param input_dim: dimension of the input (scalar) 147 | ''' 148 | 149 | super(Decoder, self).__init__() 150 | 151 | # this is kind of symmetric to the encoder, it starts from the latent variables z and it 152 | # tries to get the original x back 153 | 154 | neurons = [latent_dim, *hidden_dims] 155 | self.hidden_layers = nn.ModuleList([nn.Linear(neurons[i - 1], neurons[i]) for i in range(1, len(neurons))]) 156 | # self.maxout = Maxout(4) 157 | # dim_after_maxout = int(hidden_dims[-1] / 4) 158 | self.reconstruction = nn.Linear(hidden_dims[-1], input_dim) 159 | self.output_activation = nn.Sigmoid() 160 | 161 | def forward(self, input): 162 | x = input 163 | for layer in self.hidden_layers: 164 | x = F.relu(layer(x)) 165 | # print(self.test_set_reconstruction(x).shape) 166 | return self.output_activation(self.reconstruction(x)) 167 | 168 | 169 | ## at this point we have both the encoder and decoder, so we can create the VAE 170 | 171 | class VariationalAutoencoderWithFlows(nn.Module): 172 | def __init__(self, input_dim, hidden_dims, latent_dim, n_flows, flow_type = PlanarFlow, amortized_params_flow = False): 173 | ''' 174 | Variational AutoEncoder as described in Kingma and Welling 2014. We have an encoder - decoder 175 | and we want to learn a meaningful latent representation to being able to reconstruct the input 176 | 177 | :param input_dim: dimension of the input 178 | :param hidden_dims: dimension of hidden layers #todo: maybe we can differentiate between the encoder and decoder? 179 | :param latent_dim: dimension of the latent variables 180 | ''' 181 | 182 | super(VariationalAutoencoderWithFlows, self).__init__() 183 | 184 | self.input_dim = input_dim 185 | self.hidden_dims = hidden_dims 186 | self.z_dims = latent_dim 187 | 188 | ## infos about using flows 189 | self.n_flows = n_flows 190 | self.flow_type = flow_type 191 | self.amortized_params_flow = amortized_params_flow 192 | 193 | ## we should create the encoder and the decoder 194 | 195 | if self.amortized_params_flow: 196 | self.encoder = Encoder(input_dim, hidden_dims, latent_dim, self.amortized_params_flow, self.n_flows) 197 | self.flows = NormalizingFlows(self.n_flows, self.amortized_params_flow, None) 198 | else: 199 | self.encoder = Encoder(input_dim, hidden_dims, latent_dim) 200 | self.flows = NormalizingFlows(self.n_flows, z_dim = self.z_dims) 201 | 202 | self.decoder = Decoder(latent_dim, list(reversed(hidden_dims)), input_dim) 203 | 204 | 205 | self.kl_divergence = 0 206 | 207 | ## we should initialize the weights # 208 | for m in self.modules(): 209 | if isinstance(m, nn.Linear): 210 | init.xavier_normal_(m.weight.data) 211 | if m.bias is not None: 212 | m.bias.data.zero_() 213 | # 214 | # def _kl_divergence(self, z, q_params, p_params = None): 215 | # ''' 216 | # The function compute the KL divergence between the distribution q_phi(z|x) and the prior p_theta(z) 217 | # of a sample z. 218 | # 219 | # KL(q_phi(z|x) || p_theta(z)) = -∫ q_phi(z|x) log [ p_theta(z) / q_phi(z|x) ] 220 | # = -E[log p_theta(z) - log q_phi(z|x)] 221 | # 222 | # :param z: sample from the distribution q_phi(z|x) 223 | # :param q_params: (mu, log_var) of the q_phi(z|x) 224 | # :param p_params: (mu, log_var) of the p_theta(z) 225 | # :return: the kl divergence KL(q_phi(z|x) || p_theta(z)) computed in z 226 | # ''' 227 | # 228 | # ## we have to compute the pdf of z wrt q_phi(z|x) 229 | # (mu, log_var) = q_params 230 | # qz = log_gaussian(z, mu, log_var) 231 | # # print('size qz:', qz.shape) 232 | # ## we should do the same with p 233 | # if p_params is None: 234 | # pz = log_standard_gaussian(z) 235 | # else: 236 | # (mu, log_var) = p_params 237 | # pz = log_gaussian(z, mu, log_var) 238 | # # print('size pz:', pz.shape) 239 | # 240 | # kl = qz - pz 241 | # 242 | # return kl 243 | 244 | def _kl_divergence_flows(self, z, q_init_params, new_z, logdet_jacobians, p_params = None): 245 | ''' 246 | The function compute the KL divergence between the distribution q_phi(z|x) and the prior p_theta(z) 247 | of a sample z. 248 | 249 | KL(q_phi(z|x) || p_theta(z)) = -∫ q_phi(z|x) log [ p_theta(z) / q_phi(z|x) ] 250 | = -E[log p_theta(z) - log q_phi(z|x)] 251 | 252 | :param z: sample from the distribution q_phi(z|x) 253 | :param q_params: (mu, log_var) of the q_phi(z|x) 254 | :param p_params: (mu, log_var) of the p_theta(z) 255 | :return: the kl divergence KL(q_phi(z|x) || p_theta(z)) computed in z 256 | ''' 257 | 258 | ## we have to compute the pdf of z wrt q_phi(z|x) 259 | 260 | (mu, log_var) = q_init_params 261 | q0 = log_gaussian(z, mu, log_var) 262 | # print(q0) 263 | # print('q0 shape', q0.shape) 264 | # print('log_det_shape', logdet_jacobians.shape) 265 | 266 | # now we have to compute the qz 267 | qz = q0 - logdet_jacobians 268 | # print('size qz:', qz.shape) 269 | ## we should do the same with p 270 | if p_params is None: 271 | pz = log_standard_gaussian(new_z) 272 | else: 273 | (mu, log_var) = p_params 274 | pz = log_gaussian(new_z, mu, log_var) 275 | # print('size pz:', pz.shape) 276 | 277 | return qz, pz 278 | 279 | # ## in case we are using a gaussian prior and a gaussian approximation family 280 | # def _analytical_kl_gaussian(self, q_params): 281 | # ''' 282 | # Way for computing the kl in an analytical way. This works for gaussian prior 283 | # and gaussian density family for the approximated posterior. 284 | # 285 | # :param q_params: (mu, log_var) of the q_phi(z|x) 286 | # :return: the kl value computed analytically 287 | # ''' 288 | # 289 | # (mu, log_var) = q_params 290 | # # print(mu.shape) 291 | # # print(log_var.shape) 292 | # # prova = (log_var + 1 - mu**2 - log_var.exp()) 293 | # # print(prova.shape) 294 | # # print(torch.sum(prova, 1).shape) 295 | # # kl = 0.5 * torch.sum(log_var + 1 - mu**2 - log_var.exp(), 1) 296 | # kl = 0.5 * torch.sum(log_var + 1 - mu.pow(2) - log_var.exp(), 1) 297 | # 298 | # return kl 299 | 300 | 301 | 302 | def forward(self, input): 303 | ''' 304 | Given an input, we want to run the encoder, compute the kl, and reconstruct it 305 | 306 | :param input: an input example 307 | :return: the reconstructed input 308 | ''' 309 | 310 | # we pass the input through the encoder 311 | if self.amortized_params_flow: 312 | z, z_mu, z_log_var, u, w, b = self.encoder(input) 313 | # print(u.shape) 314 | # print(w.shape) 315 | # print(b.shape) 316 | # we have to process the z through the flows 317 | new_z, logdet_jacobians = self.flows(z, u, w, b) 318 | else: 319 | z, z_mu, z_log_var = self.encoder(input) 320 | # we have to process the z through the flows 321 | new_z, logdet_jacobians = self.flows(z, None, None, None) 322 | # print('original z ', z) 323 | 324 | # we compute the kl 325 | self.qz, self.pz = self._kl_divergence_flows(z, (z_mu, z_log_var), new_z, logdet_jacobians) 326 | # self.kl_analytical = self._analytical_kl_gaussian((z_mu, z_log_var)) 327 | 328 | 329 | # we reconstruct it 330 | x_mu = self.decoder(new_z) 331 | 332 | return x_mu 333 | 334 | 335 | def sample(self, n_images): 336 | ''' 337 | Method to sample from our generative model 338 | 339 | :return: a sample starting from z ~ N(0,1) 340 | ''' 341 | 342 | # in a VAE + normalizing flows, we should start by a random sample from N(0,1) 343 | # and then we should propagate the z into the flows 344 | z = torch.randn((n_images, self.z_dims), dtype = torch.float) 345 | 346 | # since it is amortized flow, maybe we should sample also 347 | # the u, the weights and the bias ? #todo: check this 348 | u = torch.randn((n_images, self.n_flows, self.z_dims, 1)) 349 | w = torch.randn((n_images, self.n_flows, 1, self.z_dims)) 350 | b = torch.randn((n_images, self.n_flows, 1, 1)) 351 | # print('when we are sampling') 352 | # print(u.shape) 353 | # print(w.shape) 354 | # print(b.shape) 355 | if self.amortized_params_flow: 356 | # print( self.encoder.u) 357 | new_z, _ = self.flows(z, u, w, b) 358 | else: 359 | new_z, _ = self.flows(z, None, None, None) 360 | # when we get to the final flow we get the new_z and we propagate it 361 | # into the decoder 362 | samples = self.decoder(new_z) 363 | 364 | return samples 365 | 366 | 367 | 368 | 369 | 370 | 371 | 372 | 373 | 374 | 375 | 376 | 377 | --------------------------------------------------------------------------------