├── README.md
├── LICENSE
└── vae.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# Variational Autoencoder in PyTorch

See this blog post:
http://kvfrans.com/variational-autoencoders-explained/

The Variational Autoencoder was introduced in this paper:
https://arxiv.org/abs/1312.6114

See also this tutorial paper:
https://arxiv.org/abs/1606.05908

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------

Copyright 2017 Yicheng Luo

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:

1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
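For reference, the training loop in vae.py below maximizes the evidence lower bound (ELBO), the standard objective from the Kingma and Welling paper linked above. With q_phi(z | x) the encoder's Gaussian posterior and p_theta(x | z) the decoder's likelihood, the objective is:

    \mathcal{L}(\theta, \phi; x)
      = \mathbb{E}_{q_\phi(z \mid x)}\!\left[\log p_\theta(x \mid z)\right]
      - D_{\mathrm{KL}}\!\left(q_\phi(z \mid x) \,\|\, \mathcal{N}(0, I)\right)

In the code, `log_likelihood` estimates the first term from a single reparameterized sample, `kl` is the second term in closed form via `torch.distributions.kl_divergence`, and `loss` is the negative of the whole expression.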
--------------------------------------------------------------------------------
/vae.py:
--------------------------------------------------------------------------------

import torch
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import transforms


class Encoder(torch.nn.Module):
    """Maps an input x to the approximate posterior q(z | x)."""

    def __init__(self, D_in, H, latent_size):
        super().__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, H)
        self.enc_mu = torch.nn.Linear(H, latent_size)
        self.enc_log_sigma = torch.nn.Linear(H, latent_size)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        mu = self.enc_mu(x)
        # Predict log sigma and exponentiate so the scale is always positive.
        sigma = torch.exp(self.enc_log_sigma(x))
        return torch.distributions.Normal(loc=mu, scale=sigma)


class Decoder(torch.nn.Module):
    """Maps a latent code z to the likelihood p(x | z)."""

    def __init__(self, D_in, H, D_out):
        super().__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        mu = torch.tanh(self.linear2(x))
        # Unit-variance Gaussian likelihood: its log-prob equals MSE up to a constant.
        return torch.distributions.Normal(mu, torch.ones_like(mu))


class VAE(torch.nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, state):
        q_z = self.encoder(state)
        # rsample() draws a reparameterized sample, so gradients flow
        # through the sampling step back into the encoder.
        z = q_z.rsample()
        return self.decoder(z), q_z


# Shift pixel values from [0, 1] to [-0.5, 0.5], inside the tanh decoder's range.
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize(0.5, 1)]
)
mnist = torchvision.datasets.MNIST('./', download=True, transform=transform)

input_dim = 28 * 28
batch_size = 128
num_epochs = 100
learning_rate = 0.001
hidden_size = 512
latent_size = 8

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

dataloader = torch.utils.data.DataLoader(
    mnist, batch_size=batch_size,
    shuffle=True,
    pin_memory=torch.cuda.is_available())

print('Number of samples: ', len(mnist))

encoder = Encoder(input_dim, hidden_size, latent_size)
decoder = Decoder(latent_size, hidden_size, input_dim)

vae = VAE(encoder, decoder).to(device)

optimizer = optim.Adam(vae.parameters(), lr=learning_rate)
for epoch in range(num_epochs):
    for data in dataloader:
        inputs, _ = data
        inputs = inputs.view(-1, input_dim).to(device)
        optimizer.zero_grad()
        p_x, q_z = vae(inputs)
        # Reconstruction term: log-likelihood under one sample of q(z | x).
        log_likelihood = p_x.log_prob(inputs).sum(-1).mean()
        # KL term, in closed form, against the standard normal prior p(z).
        kl = torch.distributions.kl_divergence(
            q_z,
            torch.distributions.Normal(0., 1.)
        ).sum(-1).mean()
        # Minimizing the negative ELBO maximizes the ELBO.
        loss = -(log_likelihood - kl)
        loss.backward()
        optimizer.step()
    # Report stats from the last batch of each epoch.
    print(epoch, loss.item(), log_likelihood.item(), kl.item())

--------------------------------------------------------------------------------
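Once training finishes, new digits can be generated by decoding draws from the prior. Below is a minimal sketch, not part of the original script; it assumes `decoder`, `latent_size`, and `device` from vae.py are still in scope:

    import matplotlib.pyplot as plt

    # Decode latent vectors drawn from the standard normal prior p(z).
    with torch.no_grad():
        z = torch.randn(16, latent_size).to(device)
        samples = decoder(z).mean              # use the Gaussian mean as the image
        samples = (samples + 0.5).clamp(0, 1)  # undo the [-0.5, 0.5] normalization
        samples = samples.view(-1, 28, 28).cpu()

    fig, axes = plt.subplots(4, 4, figsize=(4, 4))
    for ax, img in zip(axes.flat, samples):
        ax.imshow(img, cmap='gray')
        ax.axis('off')
    plt.show()

Taking the distribution's mean rather than sampling from it avoids adding unit-variance pixel noise on top of the decoded image.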