├── LICENSE
├── README.md
├── results
│   ├── reconstruction_10.png
│   └── sample_10.png
└── vae_gumbel_softmax.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 Devinder Kumar

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# VAE with Gumbel-Softmax in PyTorch

PyTorch implementation of a Variational Autoencoder with the Gumbel-Softmax distribution. Refer to the following paper:

* [Categorical Reparameterization with Gumbel-Softmax](https://arxiv.org/pdf/1611.01144.pdf) by Jang, Gu and Poole


## Table of Contents
* [Installation](#installation)
* [Anaconda: Train](#anaconda-train)
* [Results](#results)
* [Gumbel-Softmax in brief](#gumbel-softmax-in-brief)

## Installation

The program requires the following dependencies (easy to install using pip or Anaconda):

* python 2.7/3.5
* pytorch (version 0.3.1)
* numpy


### Anaconda: Train

Train the VAE-Gumbel-Softmax model on the local machine using the MNIST dataset:

```bash
python vae_gumbel_softmax.py
```

## Results

### Hyperparameters
```
Batch Size: 128
Learning Rate: 0.001
Initial Temperature: 1.0
Minimum Temperature: 0.5
Anneal Rate: 0.00003
Learnable Temperature: False
```

### MNIST
| Ground Truth/Reconstructions | Generated Samples |
|:----------------------------------:|:--------------------------:|
| ![](results/reconstruction_10.png) | ![](results/sample_10.png) |
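
## Gumbel-Softmax in brief

The Gumbel-Softmax trick replaces a non-differentiable draw from a categorical distribution with a differentiable relaxation: add Gumbel(0, 1) noise to the class logits, then take a softmax at temperature tau. As tau approaches 0 the samples approach one-hot vectors; at higher tau they are smoother and gradients are better behaved. Below is a minimal, self-contained sketch that mirrors the sampling helpers in `vae_gumbel_softmax.py` (plain tensors, no training loop):

```python
import torch
import torch.nn.functional as F

def sample_gumbel(shape, eps=1e-20):
    # Gumbel(0, 1) noise via the inverse CDF: -log(-log(U)), U ~ Uniform(0, 1)
    U = torch.rand(shape)
    return -torch.log(-torch.log(U + eps) + eps)

def gumbel_softmax_sample(logits, temperature):
    # perturb the logits with Gumbel noise, then soften the argmax with a softmax
    y = logits + sample_gumbel(logits.size())
    return F.softmax(y / temperature, dim=-1)

logits = torch.randn(1, 10)  # one categorical variable with 10 classes
for tau in (5.0, 1.0, 0.1):
    print(tau, gumbel_softmax_sample(logits, tau))
```

Lowering `tau` pushes each sample towards a one-hot vector, which is why the training script anneals the temperature from 1.0 towards a minimum of 0.5.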
--------------------------------------------------------------------------------
/results/reconstruction_10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dev4488/VAE_gumble_softmax/14f2e936c45805cdc929c923beab1993c2f9fd34/results/reconstruction_10.png
--------------------------------------------------------------------------------
/results/sample_10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dev4488/VAE_gumble_softmax/14f2e936c45805cdc929c923beab1993c2f9fd34/results/sample_10.png
--------------------------------------------------------------------------------
/vae_gumbel_softmax.py:
--------------------------------------------------------------------------------
# Code to implement a VAE with Gumbel-Softmax in pytorch
# author: Devinder Kumar (devinder.kumar@uwaterloo.ca)
# The code has been modified from the pytorch example VAE code and inspired by the
# original tensorflow implementation of Gumbel-Softmax by Eric Jang.

from __future__ import print_function
import argparse
import numpy as np

import torch
from torch.autograd import Variable
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image


parser = argparse.ArgumentParser(description='VAE MNIST Example')
parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                    help='input batch size for training (default: 128)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
                    help='number of epochs to train (default: 10)')
parser.add_argument('--temp', type=float, default=1.0, metavar='S',
                    help='tau (softmax temperature) (default: 1.0)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')

args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

latent_dim = 20        # number of categorical latent variables
categorical_dim = 10   # each latent variable is a one-of-K vector

temp_min = 0.5
ANNEAL_RATE = 0.00003

kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=args.batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.ToTensor()),
    batch_size=args.batch_size, shuffle=True, **kwargs)


def sample_gumbel(shape, eps=1e-20):
    # Gumbel(0, 1) noise via the inverse CDF; move it to the GPU only when CUDA is used
    U = torch.rand(shape)
    if args.cuda:
        U = U.cuda()
    return -Variable(torch.log(-torch.log(U + eps) + eps))


def gumbel_softmax_sample(logits, temperature):
    y = logits + sample_gumbel(logits.size())
    return F.softmax(y / temperature, dim=-1)


def gumbel_softmax(logits, temperature):
    """
    Straight-through Gumbel-Softmax.
    input: [*, n_class]
    return: flattened [*, n_class] one-hot vector
    """
    y = gumbel_softmax_sample(logits, temperature)
    shape = y.size()
    _, ind = y.max(dim=-1)
    y_hard = torch.zeros_like(y).view(-1, shape[-1])
    y_hard.scatter_(1, ind.view(-1, 1), 1)
    y_hard = y_hard.view(*shape)
    y_hard = (y_hard - y).detach() + y
    return y_hard.view(-1, latent_dim * categorical_dim)
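# Straight-through detail: in the forward pass (y_hard - y).detach() + y cancels
# to the exactly one-hot y_hard, while gradients are routed through the soft
# sample y, so the hard draw stays differentiable with respect to the logits.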

class VAE_gumbel(nn.Module):

    def __init__(self, temp):
        super(VAE_gumbel, self).__init__()

        # encoder: 784 -> 512 -> 256 -> latent_dim * categorical_dim logits
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, latent_dim * categorical_dim)

        # decoder: mirror of the encoder
        self.fc4 = nn.Linear(latent_dim * categorical_dim, 256)
        self.fc5 = nn.Linear(256, 512)
        self.fc6 = nn.Linear(512, 784)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def encode(self, x):
        h1 = self.relu(self.fc1(x))
        h2 = self.relu(self.fc2(h1))
        return self.relu(self.fc3(h2))

    def decode(self, z):
        h4 = self.relu(self.fc4(z))
        h5 = self.relu(self.fc5(h4))
        return self.sigmoid(self.fc6(h5))

    def forward(self, x, temp):
        q = self.encode(x.view(-1, 784))
        q_y = q.view(q.size(0), latent_dim, categorical_dim)
        z = gumbel_softmax(q_y, temp)
        # softmax over the categorical dimension (not over the flattened vector),
        # then flatten back so the loss can sum the KL terms over all latents
        return self.decode(z), F.softmax(q_y, dim=-1).view(q.size(0), -1)


model = VAE_gumbel(args.temp)
if args.cuda:
    model.cuda()
optimizer = optim.Adam(model.parameters(), lr=1e-3)


# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, qy):
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), size_average=False)

    # KL(q(y|x) || Uniform(1/K)), summed over the latent variables
    log_qy = torch.log(qy + 1e-20)
    g = torch.log(torch.Tensor([1.0 / categorical_dim]))
    if args.cuda:
        g = g.cuda()
    g = Variable(g)
    KLD = torch.sum(qy * (log_qy - g), dim=-1).mean()

    return BCE + KLD


def train(epoch):
    model.train()
    train_loss = 0
    temp = args.temp
    for batch_idx, (data, _) in enumerate(train_loader):
        data = Variable(data)
        if args.cuda:
            data = data.cuda()
        optimizer.zero_grad()
        recon_batch, qy = model(data, temp)
        loss = loss_function(recon_batch, data, qy)
        loss.backward()
        train_loss += loss.data[0]
        optimizer.step()
        if batch_idx % 100 == 1:
            # anneal the temperature towards temp_min every 100 batches
            temp = np.maximum(temp * np.exp(-ANNEAL_RATE * batch_idx), temp_min)

        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.data[0] / len(data)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / len(train_loader.dataset)))


def test(epoch):
    model.eval()
    test_loss = 0
    temp = args.temp
    for i, (data, _) in enumerate(test_loader):
        if args.cuda:
            data = data.cuda()
        data = Variable(data, volatile=True)
        recon_batch, qy = model(data, temp)
        test_loss += loss_function(recon_batch, data, qy).data[0]
        if i % 100 == 1:
            temp = np.maximum(temp * np.exp(-ANNEAL_RATE * i), temp_min)
        if i == 0:
            n = min(data.size(0), 8)
            comparison = torch.cat([data[:n],
                                    recon_batch.view(data.size(0), 1, 28, 28)[:n]])
            save_image(comparison.data.cpu(),
                       'results/reconstruction_' + str(epoch) + '.png', nrow=n)

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))


def run():
    for epoch in range(1, args.epochs + 1):
        train(epoch)
        test(epoch)
        # sample the prior as uniform one-hot categorical codes: the decoder was
        # trained on concatenated one-hot vectors, so Gaussian noise would be
        # out of distribution for it
        M = 64 * latent_dim
        np_y = np.zeros((M, categorical_dim), dtype=np.float32)
        np_y[np.arange(M), np.random.choice(categorical_dim, M)] = 1
        sample = Variable(torch.from_numpy(
            np_y.reshape(64, latent_dim * categorical_dim)))
        if args.cuda:
            sample = sample.cuda()
        sample = model.decode(sample).cpu()
        save_image(sample.data.view(64, 1, 28, 28),
                   'results/sample_' + str(epoch) + '.png')


if __name__ == '__main__':
    run()
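# Smoke test (an assumed workflow, not part of the original training recipe):
#   python vae_gumbel_softmax.py --epochs 1 --no-cuda
# Note: save_image writes into results/, which must already exist (it ships with
# the repository), so run the script from the repository root.
--------------------------------------------------------------------------------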