├── .gitignore
├── IAF
│   ├── IAF.py
│   ├── __init__.py
│   ├── autoregressive_linear.py
│   └── highway.py
├── LICENSE
├── README.md
├── blocks
│   ├── __init__.py
│   ├── generative_block.py
│   └── inference_block.py
├── data
│   ├── processed
│   │   ├── test.pt
│   │   └── training.pt
│   └── raw
│       ├── t10k-images-idx3-ubyte
│       ├── t10k-labels-idx1-ubyte
│       ├── train-images-idx3-ubyte
│       └── train-labels-idx1-ubyte
├── experimentation.py
├── mnist
│   ├── __init__.py
│   ├── generative_out.py
│   ├── parameters_inference.py
│   └── vae.py
└── train_vae.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | .idea/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | /prior_sampling
21 | trained_model
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | 
31 | # PyInstaller
32 | # Usually these files are written by a python script from a template
33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest
35 | *.spec
36 | 
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 | 
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | .hypothesis/
51 | 
52 | # Translations
53 | *.mo
54 | *.pot
55 | 
56 | # Django stuff:
57 | *.log
58 | local_settings.py
59 | 
60 | # Flask stuff:
61 | instance/
62 | .webassets-cache
63 | 
64 | # Scrapy stuff:
65 | .scrapy
66 | 
67 | # Sphinx documentation
68 | docs/_build/
69 | 
70 | # PyBuilder
71 | target/
72 | 
73 | # Jupyter Notebook
74 | .ipynb_checkpoints
75 | 
76 | # pyenv
77 | .python-version
78 | 
79 | # celery beat schedule file
80 | celerybeat-schedule
81 | 
82 | # SageMath parsed files
83 | *.sage.py
84 | 
85 | # dotenv
86 | .env
87 | 
88 | # virtualenv
89 | .venv
90 | venv/
91 | ENV/
92 | 
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 | 
97 | # Rope project settings
98 | .ropeproject
99 | 
100 | # mkdocs documentation
101 | /site
102 | 
103 | # mypy
104 | .mypy_cache/
105 | 
106 | # project
107 | experimentation.py
--------------------------------------------------------------------------------
/IAF/IAF.py:
--------------------------------------------------------------------------------
1 | import torch as t
2 | import torch.nn as nn
3 | 
4 | from IAF.autoregressive_linear import AutoregressiveLinear
5 | from IAF.highway import Highway
6 | 
7 | 
8 | class IAF(nn.Module):
9 |     def __init__(self, latent_size, h_size):
10 |         super(IAF, self).__init__()
11 | 
12 |         self.z_size = latent_size
13 |         self.h_size = h_size
14 | 
15 |         self.h = Highway(self.h_size, 3, nn.ELU())
16 | 
17 |         self.m = nn.Sequential(
18 |             AutoregressiveLinear(self.z_size + self.h_size, self.z_size),
19 |             nn.ELU(),
20 |             AutoregressiveLinear(self.z_size, self.z_size),
21 |             nn.ELU(),
22 |             AutoregressiveLinear(self.z_size, self.z_size)
23 |         )
24 | 
25 |         self.s = nn.Sequential(
26 |             AutoregressiveLinear(self.z_size + self.h_size, self.z_size),
27 |             nn.ELU(),
28 |             AutoregressiveLinear(self.z_size, self.z_size),
29 |             nn.ELU(),
30 |             AutoregressiveLinear(self.z_size, self.z_size)
31 |         )
32 | 
33 |     def forward(self, z, h):
34 |         """
35 |         :param z: A float tensor with shape of [batch_size, z_size]
36 |         :param h: A float tensor with shape of [batch_size, h_size]
37 |         :return: A float tensor with shape of [batch_size, z_size] and the log-det value of the IAF mapping Jacobian
38 |         """
39 | 
40 |         h = self.h(h)
41 | 
42 |         input = t.cat([z, h], 1)
43 | 
44 |         m = self.m(input)
45 |         s = self.s(input)
46 | 
47 |         z = s.exp() * z + m
48 | 
49 |         log_det = s.sum(1)
50 | 
51 |         return z, log_det
--------------------------------------------------------------------------------
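A quick shape-check sketch of the flow above — a minimal example, assuming it is run from the repo root with a PyTorch version where plain tensors can replace Variable:

    import torch as t

    from IAF.IAF import IAF

    flow = IAF(latent_size=30, h_size=50)
    z = t.randn(16, 30)  # base sample, i.e. eps * std + mu from the inference network
    h = t.randn(16, 50)  # per-sample context computed alongside mu and std
    z_k, log_det = flow(z, h)  # z_k: [16, 30], log_det: [16]
    # since z_k = exp(s) * z + m, log q(z_k | x) = log q(z | x) - log_det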
/IAF/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kefirski/bdir_vae/63c5fd2e4d96479d77db8240246428fff4a04ab9/IAF/__init__.py
--------------------------------------------------------------------------------
/IAF/autoregressive_linear.py:
--------------------------------------------------------------------------------
1 | import math
2 | 
3 | import torch as t
4 | import torch.nn as nn
5 | from torch.nn.init import xavier_normal
6 | from torch.nn.parameter import Parameter
7 | 
8 | 
9 | class AutoregressiveLinear(nn.Module):
10 |     def __init__(self, in_size, out_size, bias=True):
11 |         super(AutoregressiveLinear, self).__init__()
12 | 
13 |         self.in_size = in_size
14 |         self.out_size = out_size
15 | 
16 |         self.weight = Parameter(t.Tensor(self.in_size, self.out_size))
17 | 
18 |         if bias:
19 |             self.bias = Parameter(t.Tensor(self.out_size))
20 |         else:
21 |             self.register_parameter('bias', None)
22 | 
23 |         self.reset_parameters()
24 | 
25 |     def reset_parameters(self):
26 |         stdv = 1. / math.sqrt(self.out_size)
27 | 
28 |         self.weight = xavier_normal(self.weight)
29 | 
30 |         if self.bias is not None:
31 |             self.bias.data.uniform_(-stdv, stdv)
32 | 
33 |     def forward(self, input):
34 |         if input.dim() == 2 and self.bias is not None:
35 |             return t.addmm(self.bias, input, self.weight.tril(-1))
36 | 
37 |         output = input @ self.weight.tril(-1)  # strictly lower-triangular mask keeps the map autoregressive
38 |         if self.bias is not None:
39 |             output += self.bias
40 |         return output
--------------------------------------------------------------------------------
/IAF/highway.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | 
4 | 
5 | class Highway(nn.Module):
6 |     def __init__(self, size, num_layers, f):
7 |         super(Highway, self).__init__()
8 | 
9 |         self.num_layers = num_layers
10 | 
11 |         self.nonlinear = nn.ModuleList([nn.utils.weight_norm(nn.Linear(size, size)) for _ in range(num_layers)])
12 |         self.linear = nn.ModuleList([nn.utils.weight_norm(nn.Linear(size, size)) for _ in range(num_layers)])
13 |         self.gate = nn.ModuleList([nn.utils.weight_norm(nn.Linear(size, size)) for _ in range(num_layers)])
14 | 
15 |         self.f = f
16 | 
17 |     def forward(self, x):
18 |         """
19 |         :param x: tensor with shape of [batch_size, size]
20 |         :return: tensor with shape of [batch_size, size]
21 |         Applies the transformation σ(x) ⨀ f(G(x)) + (1 - σ(x)) ⨀ Q(x), where G and Q are affine
22 |         transformations, f is a non-linear transformation, σ(x) is an affine transformation
23 |         with a sigmoid non-linearity, and ⨀ is element-wise multiplication
24 |         """
25 | 
26 |         for layer in range(self.num_layers):
27 |             gate = F.sigmoid(self.gate[layer](x))
28 | 
29 |             nonlinear = self.f(self.nonlinear[layer](x))
30 |             linear = self.linear[layer](x)
31 | 
32 |             x = gate * nonlinear + (1 - gate) * linear
33 | 
34 |         return x
--------------------------------------------------------------------------------
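A self-contained sketch of the masking trick in AutoregressiveLinear above: weight.tril(-1) keeps only the strictly lower-triangular entries, so output j never depends on input j (or any earlier input):

    import torch as t

    W = t.randn(5, 5)
    M = W.tril(-1)    # M[i, j] != 0 only for i > j
    x = t.randn(1, 5)
    y = x @ M         # y[:, j] = sum over i > j of x[:, i] * M[i, j]
    # Stacking such layers preserves this structure, so in IAF.forward the
    # Jacobian of z -> exp(s) * z + m is triangular with diagonal exp(s),
    # which is why log|det J| reduces to s.sum(1).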
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2017 Daniil
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kefirski/bdir_vae/63c5fd2e4d96479d77db8240246428fff4a04ab9/README.md
--------------------------------------------------------------------------------
/blocks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kefirski/bdir_vae/63c5fd2e4d96479d77db8240246428fff4a04ab9/blocks/__init__.py
--------------------------------------------------------------------------------
/blocks/generative_block.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | 
3 | 
4 | class GenerativeBlock(nn.Module):
5 |     def __init__(self, **kwargs):
6 |         super(GenerativeBlock, self).__init__()
7 | 
8 |         '''
9 |         The computation graph of the generative block is identical to the one described in the
10 |         "Improved Variational Inference with Inverse Autoregressive Flow" paper.
11 | 
12 |         Its flow is hard to express as a single function,
13 |         so the block exposes separate inference and generation operations instead.
14 |         '''
15 | 
16 |         for name, value in kwargs.items():
17 |             setattr(self, name, value)
18 | 
19 |         self.top_most = kwargs.get('input') is None
20 | 
21 |     def inference(self, inference_input, type):
22 |         """
23 |         :param inference_input: A float tensor
24 |         :param type: 'posterior' or 'prior'
25 |         :return: distribution parameters
26 |         """
27 | 
28 |         assert not self.top_most, 'Generative error. Top-most block cannot perform posterior inference'
29 |         assert type in ['posterior', 'prior']
30 | 
31 |         return self.posterior(inference_input) if type == 'posterior' else self.prior(inference_input)
32 | 
33 |     def forward(self, inference_input):
34 |         """
35 |         :param inference_input: A float tensor with top-down input
36 |         :return: A float tensor with the output of the generative function in the top-down pass
37 |         """
38 | 
39 |         assert self.top_most
40 |         return self.out(inference_input)
--------------------------------------------------------------------------------
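A hypothetical wiring sketch for the block above (the sub-module sizes here are made up, not the ones from mnist/vae.py): keyword arguments become attributes, and a block built without an input operation is treated as top-most:

    import torch as t
    import torch.nn as nn

    from blocks.generative_block import GenerativeBlock
    from mnist.parameters_inference import ParametersInference

    top = GenerativeBlock(out=nn.Linear(30, 100))
    assert top.top_most  # no 'input' kwarg: only forward() is allowed

    mid = GenerativeBlock(input=nn.Linear(100, 100),
                          posterior=ParametersInference(100, latent_size=100),
                          prior=ParametersInference(100, latent_size=100),
                          out=nn.Linear(200, 784))
    mu, std, h = mid.inference(t.randn(16, 100), 'prior')  # h is None when h_size is None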
/blocks/inference_block.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | 
3 | 
4 | class InferenceBlock(nn.Module):
5 |     def __init__(self, **kwargs):
6 |         super(InferenceBlock, self).__init__()
7 | 
8 |         '''
9 |         The computation graph of the inference block is identical to the one described in the
10 |         "Improved Variational Inference with Inverse Autoregressive Flow" paper.
11 | 
12 |         First, the input goes through the input operation to produce a hidden state,
13 |         which is used to obtain the features of the posterior distribution.
14 | 
15 |         This hidden state is then passed to the out operation.
16 |         '''
17 | 
18 |         for name, value in kwargs.items():
19 |             setattr(self, name, value)
20 | 
21 |         self.top_most = kwargs.get('out') is None
22 | 
23 |     def forward(self, input):
24 |         """
25 |         :param input: A float tensor with shape appropriate to the input operation
26 |         :return: posterior features if top_most, otherwise (result of the out operation, posterior features)
27 |         """
28 | 
29 |         hidden_state = self.input(input)
30 |         posterior_parameters = self.posterior(hidden_state)
31 | 
32 |         if self.top_most:
33 |             return posterior_parameters
34 | 
35 |         result = self.out(hidden_state)
36 |         return result, posterior_parameters
--------------------------------------------------------------------------------
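A small sketch of the two call signatures, again with made-up sizes: a block constructed without an out operation is top-most and returns only the posterior features; otherwise it also returns the hidden state passed to the next block up:

    import torch as t
    import torch.nn as nn

    from blocks.inference_block import InferenceBlock
    from mnist.parameters_inference import ParametersInference

    block = InferenceBlock(input=nn.Linear(784, 100),
                           posterior=ParametersInference(100, latent_size=30),
                           out=nn.Linear(100, 50))
    up, (mu, std, h) = block(t.randn(16, 784))  # up feeds the next block

    top = InferenceBlock(input=nn.Linear(50, 20),
                         posterior=ParametersInference(20, latent_size=10))
    mu, std, h = top(t.randn(16, 50))  # no 'out' kwarg: posterior features only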
/data/processed/test.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kefirski/bdir_vae/63c5fd2e4d96479d77db8240246428fff4a04ab9/data/processed/test.pt
--------------------------------------------------------------------------------
/data/processed/training.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kefirski/bdir_vae/63c5fd2e4d96479d77db8240246428fff4a04ab9/data/processed/training.pt
--------------------------------------------------------------------------------
/data/raw/t10k-images-idx3-ubyte:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kefirski/bdir_vae/63c5fd2e4d96479d77db8240246428fff4a04ab9/data/raw/t10k-images-idx3-ubyte
--------------------------------------------------------------------------------
/data/raw/t10k-labels-idx1-ubyte:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kefirski/bdir_vae/63c5fd2e4d96479d77db8240246428fff4a04ab9/data/raw/t10k-labels-idx1-ubyte
--------------------------------------------------------------------------------
/data/raw/train-images-idx3-ubyte:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kefirski/bdir_vae/63c5fd2e4d96479d77db8240246428fff4a04ab9/data/raw/train-images-idx3-ubyte
--------------------------------------------------------------------------------
/data/raw/train-labels-idx1-ubyte:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kefirski/bdir_vae/63c5fd2e4d96479d77db8240246428fff4a04ab9/data/raw/train-labels-idx1-ubyte
--------------------------------------------------------------------------------
/experimentation.py:
--------------------------------------------------------------------------------
1 | from functools import reduce
2 | from operator import mul
3 | from mnist.vae import VAE
4 | 
5 | import torch as t
6 | import torch.nn as nn
7 | from torch.autograd import Variable
8 | import torch.nn.functional as F
9 | 
10 | 
11 | 
12 | 
13 | if __name__ == '__main__':
14 |     pass
--------------------------------------------------------------------------------
/mnist/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kefirski/bdir_vae/63c5fd2e4d96479d77db8240246428fff4a04ab9/mnist/__init__.py
--------------------------------------------------------------------------------
/mnist/generative_out.py:
--------------------------------------------------------------------------------
1 | import torch as t
2 | import torch.nn as nn
3 | 
4 | 
5 | class GenerativeOut(nn.Module):
6 |     def __init__(self, fc):
7 |         super(GenerativeOut, self).__init__()
8 | 
9 |         self.fc = fc
10 | 
11 |     def forward(self, latent_input, deterministic_input):
12 |         input = t.cat([latent_input, deterministic_input], 1)
13 | 
14 |         return self.fc(input)
--------------------------------------------------------------------------------
/mnist/parameters_inference.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | 
3 | 
4 | class ParametersInference(nn.Module):
5 |     def __init__(self, input_size, latent_size, h_size=None):
6 |         super(ParametersInference, self).__init__()
7 | 
8 |         self.mu = nn.utils.weight_norm(nn.Linear(input_size, latent_size))
9 |         self.std = nn.utils.weight_norm(nn.Linear(input_size, latent_size))
10 | 
11 |         self.h = nn.Sequential(
12 |             nn.utils.weight_norm(nn.Linear(input_size, h_size)),
13 |             nn.SELU()
14 |         ) if h_size is not None else None
15 | 
16 |     def forward(self, input):
17 |         mu = self.mu(input)
18 |         std = (0.5 * self.std(input)).exp()  # the linear head is read as log-variance
19 |         h = self.h(input) if self.h is not None else None
20 | 
21 |         return mu, std, h
--------------------------------------------------------------------------------
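The std head above reads its linear output as log-variance: halving and exponentiating yields a strictly positive standard deviation, so a posterior sample is the usual reparameterization. A sketch, assuming the repo is on the path:

    import torch as t

    from mnist.parameters_inference import ParametersInference

    head = ParametersInference(1500, latent_size=100, h_size=100)
    mu, std, h = head(t.randn(16, 1500))  # std = exp(0.5 * linear(x)) > 0
    z0 = t.randn(*mu.size()) * std + mu   # eps * std + mu, as in VAE.forward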
/mnist/vae.py:
--------------------------------------------------------------------------------
1 | from math import log, pi
2 | 
3 | import torch as t
4 | import torch.nn as nn
5 | from torch.autograd import Variable
6 | 
7 | from IAF.IAF import IAF
8 | from blocks.generative_block import GenerativeBlock
9 | from blocks.inference_block import InferenceBlock
10 | from mnist.generative_out import GenerativeOut
11 | from mnist.parameters_inference import ParametersInference
12 | 
13 | 
14 | class VAE(nn.Module):
15 |     def __init__(self):
16 |         super(VAE, self).__init__()
17 | 
18 |         self.inference = nn.ModuleList(
19 |             [
20 |                 InferenceBlock(
21 |                     input=nn.Sequential(
22 |                         nn.utils.weight_norm(nn.Linear(784, 1500)),
23 |                         nn.ELU()
24 |                     ),
25 |                     posterior=ParametersInference(1500, latent_size=100, h_size=100),
26 |                     out=nn.Sequential(
27 |                         nn.utils.weight_norm(nn.Linear(1500, 900)),
28 |                         nn.SELU()
29 |                     )
30 |                 ),
31 |                 InferenceBlock(
32 |                     input=nn.Sequential(
33 |                         nn.utils.weight_norm(nn.Linear(900, 400)),
34 |                         nn.ELU(),
35 |                         nn.utils.weight_norm(nn.Linear(400, 200)),
36 |                         nn.ELU(),
37 |                     ),
38 |                     posterior=ParametersInference(200, latent_size=30, h_size=50)
39 |                 )
40 |             ]
41 |         )
42 | 
43 |         self.iaf = nn.ModuleList(
44 |             [
45 |                 IAF(latent_size=100, h_size=100),
46 |                 IAF(latent_size=30, h_size=50)
47 |             ]
48 |         )
49 | 
50 |         self.generation = nn.ModuleList(
51 |             [
52 |                 GenerativeBlock(
53 |                     posterior=ParametersInference(100, latent_size=100),
54 |                     input=nn.Sequential(
55 |                         nn.utils.weight_norm(nn.Linear(100, 100)),
56 |                         nn.SELU()
57 |                     ),
58 |                     prior=ParametersInference(100, latent_size=100),
59 |                     out=GenerativeOut(nn.Sequential(
60 |                         nn.utils.weight_norm(nn.Linear(100 + 100, 300)),
61 |                         nn.SELU(),
62 |                         nn.utils.weight_norm(nn.Linear(300, 400)),
63 |                         nn.SELU(),
64 |                         nn.utils.weight_norm(nn.Linear(400, 600)),
65 |                         nn.SELU(),
66 |                         nn.utils.weight_norm(nn.Linear(600, 784)),
67 |                     )),
68 | 
69 |                 ),
70 |                 GenerativeBlock(
71 |                     out=nn.Sequential(
72 |                         nn.utils.weight_norm(nn.Linear(30, 80)),
73 |                         nn.SELU(),
74 |                         nn.utils.weight_norm(nn.Linear(80, 100)),
75 |                         nn.SELU()
76 |                     )
77 |                 )
78 |             ]
79 |         )
80 | 
81 |         self.latent_size = [100, 30]
82 | 
83 |         assert len(self.inference) == len(self.generation) == len(self.iaf)
84 |         self.vae_length = len(self.inference)
85 | 
86 |     def forward(self, input):
87 |         """
88 |         :param input: A float tensor with shape of [batch_size, 784]
89 |         :return: A float tensor with shape of [batch_size, 784]
90 |                      with logits of the marginal likelihood expectation
91 |         """
92 | 
93 |         [batch_size, _] = input.size()
94 | 
95 |         cuda = input.is_cuda
96 | 
97 |         posterior_parameters = []
98 | 
99 |         '''
100 |         Here we perform the bottom-up inference pass.
101 |         The parameters array is filled with posterior parameters [mu, std, h]
102 |         '''
103 |         for i in range(self.vae_length):
104 | 
105 |             if i < self.vae_length - 1:
106 |                 input, parameters = self.inference[i](input)
107 |             else:
108 |                 parameters = self.inference[i](input)
109 | 
110 |             posterior_parameters.append(parameters)
111 | 
112 |         '''
113 |         Here we perform generation in the top-most layer.
114 |         We will use both posterior and prior in the layers below it.
115 |         '''
116 |         [mu, std, h] = posterior_parameters[-1]
117 | 
118 |         prior = Variable(t.randn(*mu.size()))
119 |         eps = Variable(t.randn(*mu.size()))
120 | 
121 |         if cuda:
122 |             prior, eps = prior.cuda(), eps.cuda()
123 | 
124 |         posterior_gauss = eps * std + mu
125 |         posterior, log_det = self.iaf[-1](posterior_gauss, h)
126 | 
127 |         kld = VAE.monte_carlo_divergence(z=posterior,
128 |                                          z_gauss=posterior_gauss,
129 |                                          log_det=log_det,
130 |                                          posterior=[mu, std])
131 | 
132 |         posterior = self.generation[-1](posterior)
133 |         prior = self.generation[-1](prior)
134 | 
135 |         for i in range(self.vae_length - 2, -1, -1):
136 | 
137 |             '''
138 |             Iteration over the generative layers that are not top-most.
139 |             First we pass the input through the input operation to get deterministic features
140 |             '''
141 |             posterior_deterministic = self.generation[i].input(posterior)
142 |             prior_deterministic = self.generation[i].input(prior)
143 | 
144 |             '''
145 |             Then the posterior input goes through the inference function to get top-down features.
146 |             The posterior parameters are combined and a new latent variable is sampled
147 |             '''
148 |             [top_down_mu, top_down_std, _] = self.generation[i].inference(posterior, 'posterior')
149 |             [bottom_up_mu, bottom_up_std, h] = posterior_parameters[i]
150 | 
151 |             posterior_mu = top_down_mu + bottom_up_mu
152 |             posterior_std = top_down_std + bottom_up_std
153 | 
154 |             eps = Variable(t.randn(*posterior_mu.size()))
155 |             if cuda:
156 |                 eps = eps.cuda()
157 | 
158 |             posterior_gauss = eps * posterior_std + posterior_mu
159 |             posterior, log_det = self.iaf[i](posterior_gauss, h)
160 | 
161 |             '''
162 |             Prior parameters are obtained from the prior operation,
163 |             then a new prior variable is sampled
164 |             '''
165 |             prior_mu, prior_std, _ = self.generation[i].inference(prior_deterministic, 'prior')
166 | 
167 |             kld += VAE.monte_carlo_divergence(z=posterior,
168 |                                               z_gauss=posterior_gauss,
169 |                                               log_det=log_det,
170 |                                               posterior=[posterior_mu, posterior_std],
171 |                                               prior=[prior_mu, prior_std])
172 | 
173 |             posterior = self.generation[i].out(posterior, posterior_deterministic)
174 | 
175 |             if i != 0:
176 |                 '''
177 |                 Since there is no level below the bottom-most one,
178 |                 there is no reason to pass the prior through the out operation
179 |                 '''
180 | 
181 |                 eps = Variable(t.randn(*prior_mu.size()))
182 |                 if cuda:
183 |                     eps = eps.cuda()
184 | 
185 |                 prior = eps * prior_std + prior_mu
186 | 
187 |                 prior = self.generation[i].out(prior, prior_deterministic)
188 | 
189 |         return posterior, kld
190 | 
191 |     def sample(self, z):
192 |         """
193 |         :param z: A list of variables drawn from a normal distribution, each with shape of [batch_size, latent_size[i]]
194 |         :return: Sample from the generative model with shape of [batch_size, 784]
195 |         """
196 | 
197 |         top_variable = z[-1]
198 | 
199 |         out = self.generation[-1].out(top_variable)
200 | 
201 |         for i in range(self.vae_length - 2, -1, -1):
202 |             deterministic = self.generation[i].input(out)
203 | 
204 |             [mu, std, _] = self.generation[i].prior(deterministic)
205 |             prior = z[i] * std + mu
206 |             out = self.generation[i].out(prior, deterministic)
207 | 
208 |         return out
209 | 
210 |     @staticmethod
211 |     def monte_carlo_divergence(**kwargs):
212 | 
213 |         log_p_z_x = VAE.log_gauss(kwargs['z_gauss'], kwargs['posterior']) - kwargs['log_det']
214 | 
215 |         if kwargs.get('prior') is None:
216 |             kwargs['prior'] = [Variable(t.zeros(*kwargs['z'].size())),
217 |                                Variable(t.ones(*kwargs['z'].size()))]
218 | 
219 |         one = Variable(t.FloatTensor([1]))
220 | 
221 |         if kwargs['z'].is_cuda:
222 |             one = one.cuda()
223 |             kwargs['prior'] = [var.cuda()
224 |                                for var in kwargs['prior']]
225 |         log_p_z = VAE.log_gauss(kwargs['z'], kwargs['prior'])
226 | 
227 |         result = log_p_z_x - log_p_z
228 |         return t.max(t.stack([result.mean(), one]), 0)[0]  # floor the averaged divergence at one nat
229 | 
230 |     @staticmethod
231 |     def log_gauss(z, params):
232 |         [mu, std] = params
233 |         return - 0.5 * (t.pow(z - mu, 2) * t.pow(std + 1e-8, -2) + 2 * t.log(std + 1e-8) + log(2 * pi)).sum(1)
--------------------------------------------------------------------------------
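monte_carlo_divergence above is a single-sample estimate of E_q[log q(z|x) - log p(z)]: the density of the flowed sample z_K is evaluated as the Gaussian density of the pre-flow sample z_0 minus the accumulated log-determinant, and the batch-averaged result is clamped from below at one nat (the t.max against one), a free-bits-style floor that keeps the posterior from collapsing onto the prior. In symbols:

    log q(z_K | x) = log N(z_0; mu, std) - sum_j s_j
    KL_estimate    = max( mean_batch[ log q(z_K | x) - log N(z_K; prior_mu, prior_std) ], 1 )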
/train_vae.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from os import listdir
4 | 
5 | import imageio
6 | import torch as t
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import torchvision.transforms as transforms
10 | import torchvision.utils as vutils
11 | from torch.autograd import Variable
12 | from torch.optim import Adam
13 | from torchvision import datasets
14 | 
15 | from mnist.vae import VAE
16 | 
17 | 
18 | def make_grid(tensor, number, size):
19 |     tensor = t.transpose(tensor, 0, 1).contiguous().view(1, number, number * size, size)
20 |     tensor = t.transpose(tensor, 1, 2).contiguous().view(1, number * size, number * size)
21 | 
22 |     return tensor
23 | 
24 | 
25 | if __name__ == "__main__":
26 | 
27 |     if not os.path.exists('prior_sampling'):
28 |         os.mkdir('prior_sampling')
29 | 
30 |     parser = argparse.ArgumentParser(description='CDVAE')
31 |     parser.add_argument('--num-epochs', type=int, default=4, metavar='NI',
32 |                         help='num epochs (default: 4)')
33 |     parser.add_argument('--batch-size', type=int, default=40, metavar='BS',
34 |                         help='batch size (default: 40)')
35 |     parser.add_argument('--use-cuda', type=bool, default=False, metavar='CUDA',
36 |                         help='use cuda (default: False)')
37 |     parser.add_argument('--learning-rate', type=float, default=0.001, metavar='LR',
38 |                         help='learning rate (default: 0.001)')
39 |     parser.add_argument('--save', type=str, default='trained_model', metavar='TS',
40 |                         help='path where save trained model to (default: "trained_model")')
41 |     args = parser.parse_args()
42 | 
43 |     dataset = datasets.MNIST(root='data/',
44 |                              transform=transforms.Compose([
45 |                                  transforms.ToTensor()]),
46 |                              download=True,
47 |                              train=True)
48 |     dataloader = t.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
49 | 
50 |     vae = VAE()
51 |     if args.use_cuda:
52 |         vae = vae.cuda()
53 | 
54 |     optimizer = Adam(vae.parameters(), args.learning_rate, eps=1e-6)
55 | 
56 |     likelihood_function = nn.BCEWithLogitsLoss(size_average=False)
57 | 
58 |     z = [Variable(t.randn(256, size)) for size in vae.latent_size]
59 |     if args.use_cuda:
60 |         z = [var.cuda() for var in z]
61 | 
62 |     for epoch in range(args.num_epochs):
63 |         for iteration, (input, _) in enumerate(dataloader):
64 | 
65 |             input = Variable(input).view(-1, 784)
66 |             if args.use_cuda:
67 |                 input = input.cuda()
68 | 
69 |             optimizer.zero_grad()
70 | 
71 |             out, kld = vae(input)
72 | 
73 |             input = input.view(-1, 1, 28, 28)
74 |             out = out.contiguous().view(-1, 1, 28, 28)
75 | 
76 |             likelihood = likelihood_function(out, input) / args.batch_size
77 |             print(likelihood, kld)
78 |             loss = likelihood + kld
79 | 
80 |             loss.backward()
81 |             optimizer.step()
82 | 
83 |             if iteration % 10 == 0:
84 |                 print('epoch {}, iteration {}, loss {}'.format(epoch, iteration, loss.cpu().data.numpy()[0]))
85 | 
86 |                 sampling = vae.sample(z).view(-1, 1, 28, 28)
87 | 
88 |                 grid = make_grid(F.sigmoid(sampling).cpu().data, 16, 28)
89 |                 vutils.save_image(grid, 'prior_sampling/vae_{}.png'.format(epoch * len(dataloader) + iteration))
90 | 
91 |     samplings = [f for f in listdir('prior_sampling')]
92 |     samplings = [imageio.imread('prior_sampling/' + path) for path in samplings for _ in range(2)]
93 |     imageio.mimsave('prior_sampling/movie.gif', samplings)
94 | 
95 |     t.save(vae.cpu().state_dict(), args.save)
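Typical invocations, assuming MNIST can be downloaded into data/. Note that argparse's type=bool converts any non-empty string to True, so leave --use-cuda off entirely to run on CPU:

    python train_vae.py --num-epochs 4 --batch-size 40 --learning-rate 0.001
    python train_vae.py --use-cuda True   # any non-empty value enables CUDA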
--------------------------------------------------------------------------------
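A closing sketch of sampling from a trained checkpoint, assuming training finished with the default --save path and the same Variable-era PyTorch used throughout the repo:

    import torch as t
    import torch.nn.functional as F

    from mnist.vae import VAE

    vae = VAE()
    vae.load_state_dict(t.load('trained_model'))
    z = [t.randn(16, size) for size in vae.latent_size]    # one tensor per stochastic layer
    images = F.sigmoid(vae.sample(z)).view(-1, 1, 28, 28)  # logits -> pixel probabilities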