├── requirements.txt
├── README.md
├── train_vae.py
└── model.py
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torch==1.4.0
numpy
argparse
tqdm
torchvision==0.5.0
tensorboard
scipy
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Simple and Effective VAE training with σ-VAE in PyTorch

[[Project page]](https://orybkin.github.io/sigma-vae/) [[Colab]](https://colab.research.google.com/drive/1mQr1SkiSQLhCSsVaj4R7XLcinknwHiV8?usp=sharing) [[TensorFlow implementation]](https://github.com/orybkin/sigma-vae-tensorflow)

This is the PyTorch implementation of the σ-VAE paper. See the σ-VAE project page for more info, results, and alternative
implementations. Also see the Colab version of this repo to train a sigma-VAE with zero setup needed!

This implementation is based on the VAE from the PyTorch [examples](https://github.com/pytorch/examples/blob/master/vae/main.py). In contrast to the original implementation, the σ-VAE
achieves good results without tuning the heuristic weight beta, since the decoder variance balances the objective.
It is also very easy to implement: check out the individual commits to see the few lines of code needed to add this to your own VAE!

## Installation
```
git clone https://github.com/orybkin/sigma-vae-pytorch.git
cd sigma-vae-pytorch/
pip3 install torch numpy argparse tqdm torchvision scipy tensorboard
```

## How to run it

This repo implements several VAE versions.

First, a VAE from the original PyTorch example repo that uses the MSE loss. This implementation works very poorly because
the MSE loss averages over the pixels instead of summing over them. Don't do this! You have to sum the loss across pixels and
latent dimensions, according to the definition of the multivariate Gaussian (and other) distributions.
```
python train_vae.py --log_dir mse_vae --model mse_vae
```

Summing the loss works a bit better and is equivalent to the Gaussian negative log-likelihood (NLL) with a certain constant
variance. The second model therefore uses the Gaussian NLL as the reconstruction term. However, since the variance is constant,
it is still unable to balance the reconstruction and KL divergence terms.
```
python train_vae.py --log_dir gaussian_vae --model gaussian_vae
```

The third model is the σ-VAE. It learns the variance of the decoding distribution, which works significantly better and produces
high-quality samples. This is because learning the variance automatically balances the VAE objective. One could balance
the objective manually by using a beta-VAE; however, this is not required when the variance is learned!
```
python train_vae.py --log_dir sigma_vae --model sigma_vae
```

Finally, the optimal sigma-VAE uses a batch-wise analytic estimate of the variance, which speeds up learning and improves results.
It is also extremely easy to implement!
```
python train_vae.py --log_dir optimal_sigma_vae --model optimal_sigma_vae
```
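
## The σ-VAE loss in a nutshell

For reference, the snippet below is a condensed, self-contained sketch of the loss implemented in `model.py` (the formulas and names mirror that file; only the standalone `sigma_vae_loss` wrapper is introduced here for illustration). The decoder predicts a mean image `x_hat`, a single log standard deviation `log_sigma` is either learned or estimated from the batch, and the Gaussian NLL plus the KL term form the objective:

```python
import math
import torch
import torch.nn.functional as F


def softclip(tensor, min_value):
    """Softly clip a tensor to a minimum of min_value (as in model.py)."""
    return min_value + F.softplus(tensor - min_value)


def sigma_vae_loss(x_hat, x, mu, logvar, log_sigma):
    """Negative ELBO with a Gaussian decoder of shared (learned) variance."""
    log_sigma = softclip(log_sigma, -6)
    # Gaussian NLL, summed over all pixels and over the batch
    rec = (0.5 * torch.pow((x - x_hat) / log_sigma.exp(), 2)
           + log_sigma + 0.5 * math.log(2 * math.pi)).sum()
    # KL divergence between q(z|x) and the standard normal prior
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return rec + kl
```

For the optimal σ-VAE, `log_sigma` is not a learned parameter but is computed from the current batch as `((x - x_hat) ** 2).mean().sqrt().log()`, the maximum-likelihood estimate of a shared standard deviation.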

--------------------------------------------------------------------------------
/train_vae.py:
--------------------------------------------------------------------------------
import argparse
import os

import torch
from torch import optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
from torchvision.utils import save_image
from tqdm import tqdm

from model import ConvVAE

""" This script is an example of Sigma VAE training in PyTorch. The code was adapted from:
https://github.com/pytorch/examples/blob/master/vae/main.py """

## Arguments
parser = argparse.ArgumentParser(description='VAE SVHN Example')
parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                    help='input batch size for training (default: 128)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
                    help='number of epochs to train (default: 10)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--model', type=str, default='mse_vae', metavar='N',
                    help='which model to use: mse_vae, gaussian_vae, sigma_vae, or optimal_sigma_vae')
parser.add_argument('--log_dir', type=str, metavar='N', required=True)
args = parser.parse_args()

## Cuda
args.cuda = not args.no_cuda and torch.cuda.is_available()
device = torch.device("cuda" if args.cuda else "cpu")

## Dataset
transform = transforms.Compose([transforms.Resize((28, 28)), transforms.ToTensor()])
train_dataset = datasets.SVHN('../../data', split='train', download=True, transform=transform)
test_dataset = datasets.SVHN('../../data', split='test', download=True, transform=transform)
kwargs = {'num_workers': 10, 'pin_memory': True} if args.cuda else {}
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=True, **kwargs)

## Logging
os.makedirs('vae_logs/{}'.format(args.log_dir), exist_ok=True)
summary_writer = SummaryWriter(log_dir='vae_logs/' + args.log_dir, purge_step=0)

## Build Model
model = ConvVAE(device, 3, args).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
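
# Note (added comment): the objective optimized below is the negative ELBO, i.e. the reconstruction
# NLL plus the KL divergence with an implicit weight of beta = 1. For the sigma-VAE variants the
# learned decoder variance balances the two terms, so no beta tuning is needed (see README).
# Optional diagnostic, not part of the original script: report the model size once at startup.
print('Number of trainable parameters:', sum(p.numel() for p in model.parameters() if p.requires_grad))
# A checkpoint saved at the end of an epoch below can be restored with, e.g.:
#   model.load_state_dict(torch.load('vae_logs/<log_dir>/checkpoint_<epoch>.pt', map_location=device))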


def train(epoch):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()

        # Run VAE
        recon_batch, mu, logvar = model(data)
        # Compute loss
        rec, kl = model.loss_function(recon_batch, data, mu, logvar)

        total_loss = rec + kl
        total_loss.backward()
        train_loss += total_loss.item()
        optimizer.step()

        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tRec: {:.6f}\tKL: {:.6f}\tlog_sigma: {:f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                rec.item() / len(data),
                kl.item() / len(data),
                model.log_sigma))

    train_loss /= len(train_loader.dataset)
    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss))
    summary_writer.add_scalar('train/elbo', train_loss, epoch)
    summary_writer.add_scalar('train/rec', rec.item() / len(data), epoch)
    summary_writer.add_scalar('train/kld', kl.item() / len(data), epoch)
    summary_writer.add_scalar('train/log_sigma', model.log_sigma, epoch)


def test(epoch):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, (data, _) in enumerate(tqdm(test_loader)):
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            # Compute the same reconstruction and KL terms as in training
            rec, kl = model.loss_function(recon_batch, data, mu, logvar)
            test_loss += (rec + kl).item()
            if i == 0:
                n = min(data.size(0), 8)
                comparison = torch.cat([data[:n], recon_batch.view(data.size(0), -1, 28, 28)[:n]])
                save_image(comparison.cpu(), 'vae_logs/{}/reconstruction_{}.png'.format(args.log_dir, str(epoch)), nrow=n)

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))
    summary_writer.add_scalar('test/elbo', test_loss, epoch)


if __name__ == "__main__":
    for epoch in range(1, args.epochs + 1):
        train(epoch)
        test(epoch)
        with torch.no_grad():
            sample = model.sample(64).cpu()
            save_image(sample.view(64, -1, 28, 28),
                       'vae_logs/{}/sample_{}.png'.format(args.log_dir, str(epoch)))
        summary_writer.flush()

        torch.save(model.state_dict(), 'vae_logs/{}/checkpoint_{}.pt'.format(args.log_dir, str(epoch)))
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
import torch.utils.data
from torch import nn
import numpy as np
import torch.nn.functional as F


def softclip(tensor, min_value):
    """ Clips the tensor values at the minimum value min_value in a soft way.
    Taken from Handful of Trials """
    result_tensor = min_value + F.softplus(tensor - min_value)

    return result_tensor
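
# Note on softclip (added comment): since softplus(y) approaches y for large y and approaches 0 for
# very negative y, softclip(t, m) is approximately t whenever t is well above m and smoothly
# saturates at m when t falls below it. For example, softclip(torch.tensor(0.), -6) is about 0.0025,
# while softclip(torch.tensor(-20.), -6) is about -6.0. This keeps log_sigma from collapsing below -6.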


class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)


class UnFlatten(nn.Module):
    def __init__(self, n_channels):
        super(UnFlatten, self).__init__()
        self.n_channels = n_channels

    def forward(self, input):
        size = int((input.size(1) // self.n_channels) ** 0.5)
        return input.view(input.size(0), self.n_channels, size, size)


class ConvVAE(nn.Module):
    def __init__(self, device='cuda', img_channels=3, args=None):
        super().__init__()
        self.batch_size = args.batch_size
        self.device = device
        self.z_dim = 20
        self.img_channels = img_channels
        self.model = args.model
        img_size = 28
        filters_m = 32

        ## Build network
        self.encoder = self.get_encoder(self.img_channels, filters_m)

        # the encoder output size depends on the input image size, so compute it with a dummy input
        demo_input = torch.ones([1, self.img_channels, img_size, img_size])
        h_dim = self.encoder(demo_input).shape[1]
        print('h_dim', h_dim)

        # map to latent z
        self.fc11 = nn.Linear(h_dim, self.z_dim)
        self.fc12 = nn.Linear(h_dim, self.z_dim)

        # decoder
        self.fc2 = nn.Linear(self.z_dim, h_dim)
        self.decoder = self.get_decoder(filters_m, self.img_channels)

        self.log_sigma = 0
        if self.model == 'sigma_vae':
            ## Sigma VAE: the decoder log standard deviation is a learned scalar parameter
            self.log_sigma = torch.nn.Parameter(torch.full((1,), 0.0)[0], requires_grad=True)

    @staticmethod
    def get_encoder(img_channels, filters_m):
        return nn.Sequential(
            nn.Conv2d(img_channels, filters_m, (3, 3), stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(filters_m, 2 * filters_m, (4, 4), stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(2 * filters_m, 4 * filters_m, (5, 5), stride=2, padding=2),
            nn.ReLU(),
            Flatten()
        )

    @staticmethod
    def get_decoder(filters_m, out_channels):
        return nn.Sequential(
            UnFlatten(4 * filters_m),
            nn.ConvTranspose2d(4 * filters_m, 2 * filters_m, (6, 6), stride=2, padding=2),
            nn.ReLU(),
            nn.ConvTranspose2d(2 * filters_m, filters_m, (6, 6), stride=2, padding=2),
            nn.ReLU(),
            nn.ConvTranspose2d(filters_m, out_channels, (5, 5), stride=1, padding=2),
            nn.Sigmoid(),
        )

    def encode(self, x):
        h = self.encoder(x)
        return self.fc11(h), self.fc12(h)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        return self.decoder(self.fc2(z))

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

    def sample(self, n):
        sample = torch.randn(n, self.z_dim).to(self.device)
        return self.decode(sample)

    def reconstruction_loss(self, x_hat, x):
        """ Computes the negative log-likelihood of the data given the latent variable, using a
        Gaussian decoder whose mean is predicted by the network and whose shared variance is
        fixed, learned, or estimated analytically depending on the model. """

        if self.model == 'gaussian_vae':
            # Naive Gaussian VAE uses a constant variance (sigma = 1)
            log_sigma = torch.zeros([], device=x_hat.device)
        elif self.model == 'sigma_vae':
            # Sigma VAE learns the variance of the decoder as another parameter
            log_sigma = self.log_sigma
        elif self.model == 'optimal_sigma_vae':
            # Optimal sigma VAE uses a batch-wise analytic (maximum-likelihood) estimate of sigma
            log_sigma = ((x - x_hat) ** 2).mean([0, 1, 2, 3], keepdim=True).sqrt().log()
            self.log_sigma = log_sigma.item()
        else:
            raise NotImplementedError

        # Learning the variance can become unstable in some cases. Softly limiting log_sigma to a
        # minimum of -6 ensures stable training.
        log_sigma = softclip(log_sigma, -6)

        rec = gaussian_nll(x_hat, log_sigma, x).sum()

        return rec
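
    # Note on the optimal sigma (added comment): for a Gaussian decoder with a single shared sigma,
    # the summed NLL  0.5 * ((x - x_hat) / sigma)^2 + log(sigma)  is minimized exactly at
    # sigma^2 = mean((x - x_hat)^2), which is the batch-wise estimate computed in the
    # 'optimal_sigma_vae' branch above (before the soft clipping).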

    def loss_function(self, recon_x, x, mu, logvar):
        # Important: both the reconstruction and the KL divergence loss have to be summed over all elements!
        # Here we also sum over the batch and divide by the number of elements in the data later.
        if self.model == 'mse_vae':
            rec = torch.nn.MSELoss()(recon_x, x)
        else:
            rec = self.reconstruction_loss(recon_x, x)

        # see Appendix B from the VAE paper:
        # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
        # https://arxiv.org/abs/1312.6114
        # KL = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

        return rec, kl


def gaussian_nll(mu, log_sigma, x):
    # Element-wise negative log-likelihood of x under N(mu, sigma^2):
    # 0.5 * ((x - mu) / sigma)^2 + log(sigma) + 0.5 * log(2 * pi)
    return 0.5 * torch.pow((x - mu) / log_sigma.exp(), 2) + log_sigma + 0.5 * np.log(2 * np.pi)

--------------------------------------------------------------------------------