├── figs ├── 1519649452.702026 │ ├── E9I937.png │ └── E9-Dist.png └── 1519649461.195146 │ ├── E9I937.png │ └── E9-Dist.png ├── requirements.txt ├── utils.py ├── README.md ├── models.py └── train.py /figs/1519649452.702026/E9I937.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/timbmg/VAE-CVAE-MNIST/HEAD/figs/1519649452.702026/E9I937.png -------------------------------------------------------------------------------- /figs/1519649461.195146/E9I937.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/timbmg/VAE-CVAE-MNIST/HEAD/figs/1519649461.195146/E9I937.png -------------------------------------------------------------------------------- /figs/1519649452.702026/E9-Dist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/timbmg/VAE-CVAE-MNIST/HEAD/figs/1519649452.702026/E9-Dist.png -------------------------------------------------------------------------------- /figs/1519649461.195146/E9-Dist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/timbmg/VAE-CVAE-MNIST/HEAD/figs/1519649461.195146/E9-Dist.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cycler==0.10.0 2 | kiwisolver==1.0.1 3 | matplotlib==3.0.2 4 | numpy==1.22.0 5 | pandas==0.24.1 6 | Pillow>=6.2.2 7 | pyparsing==2.3.1 8 | python-dateutil==2.8.0 9 | pytz==2018.9 10 | scipy==1.10.0 11 | seaborn==0.9.0 12 | six==1.12.0 13 | torch==2.2.0 14 | torchvision==0.2.1 15 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def idx2onehot(idx, n): 5 | 6 | assert torch.max(idx).item() < n 7 | 8 | if idx.dim() == 1: 9 | idx = idx.unsqueeze(1) 10 | onehot = torch.zeros(idx.size(0), n).to(idx.device) 11 | onehot.scatter_(1, idx, 1) 12 | 13 | return onehot 14 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Variational Autoencoder & Conditional Variational Autoenoder on MNIST 2 | 3 | VAE paper: [Auto-Encoding Variational Bayes](https://arxiv.org/abs/1312.6114) 4 | 5 | CVAE paper: [Semi-supervised Learning with Deep Generative Models](https://proceedings.neurips.cc/paper/2014/hash/d523773c6b194f37b938d340d5d02232-Abstract.html) 6 | 7 | --- 8 | In order to run _conditional_ variational autoencoder, add `--conditional` to the the command. Check out the other commandline options in the code for hyperparameter settings (like learning rate, batch size, encoder/decoder layer depth and size). 9 | 10 | --- 11 | 12 | ## Results 13 | 14 | All plots obtained after 10 epochs of training. Hyperparameters accordning to default settings in the code; not tuned. 15 | 16 | ### z ~ q(z|x) and q(z|x,c) 17 | The modeled latent distribution after 10 epochs and 100 samples per digit. 18 | 19 | VAE | CVAE 20 | --- | --- 21 | | 22 | 23 | ### p(x|z) and p(x|z,c) 24 | Randomly sampled z, and their output. For CVAE, each c has been given as input once. 25 | 26 | VAE | CVAE 27 | --- | --- 28 | | 29 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from utils import idx2onehot 5 | 6 | 7 | class VAE(nn.Module): 8 | 9 | def __init__(self, encoder_layer_sizes, latent_size, decoder_layer_sizes, 10 | conditional=False, num_labels=0): 11 | 12 | super().__init__() 13 | 14 | if conditional: 15 | assert num_labels > 0 16 | 17 | assert type(encoder_layer_sizes) == list 18 | assert type(latent_size) == int 19 | assert type(decoder_layer_sizes) == list 20 | 21 | self.latent_size = latent_size 22 | 23 | self.encoder = Encoder( 24 | encoder_layer_sizes, latent_size, conditional, num_labels) 25 | self.decoder = Decoder( 26 | decoder_layer_sizes, latent_size, conditional, num_labels) 27 | 28 | def forward(self, x, c=None): 29 | 30 | if x.dim() > 2: 31 | x = x.view(-1, 28*28) 32 | 33 | means, log_var = self.encoder(x, c) 34 | z = self.reparameterize(means, log_var) 35 | recon_x = self.decoder(z, c) 36 | 37 | return recon_x, means, log_var, z 38 | 39 | def reparameterize(self, mu, log_var): 40 | 41 | std = torch.exp(0.5 * log_var) 42 | eps = torch.randn_like(std) 43 | 44 | return mu + eps * std 45 | 46 | def inference(self, z, c=None): 47 | 48 | recon_x = self.decoder(z, c) 49 | 50 | return recon_x 51 | 52 | 53 | class Encoder(nn.Module): 54 | 55 | def __init__(self, layer_sizes, latent_size, conditional, num_labels): 56 | 57 | super().__init__() 58 | 59 | self.conditional = conditional 60 | if self.conditional: 61 | layer_sizes[0] += num_labels 62 | 63 | self.MLP = nn.Sequential() 64 | 65 | for i, (in_size, out_size) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])): 66 | self.MLP.add_module( 67 | name="L{:d}".format(i), module=nn.Linear(in_size, out_size)) 68 | self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU()) 69 | 70 | self.linear_means = nn.Linear(layer_sizes[-1], latent_size) 71 | self.linear_log_var = nn.Linear(layer_sizes[-1], latent_size) 72 | 73 | def forward(self, x, c=None): 74 | 75 | if self.conditional: 76 | c = idx2onehot(c, n=10) 77 | x = torch.cat((x, c), dim=-1) 78 | 79 | x = self.MLP(x) 80 | 81 | means = self.linear_means(x) 82 | log_vars = self.linear_log_var(x) 83 | 84 | return means, log_vars 85 | 86 | 87 | class Decoder(nn.Module): 88 | 89 | def __init__(self, layer_sizes, latent_size, conditional, num_labels): 90 | 91 | super().__init__() 92 | 93 | self.MLP = nn.Sequential() 94 | 95 | self.conditional = conditional 96 | if self.conditional: 97 | input_size = latent_size + num_labels 98 | else: 99 | input_size = latent_size 100 | 101 | for i, (in_size, out_size) in enumerate(zip([input_size]+layer_sizes[:-1], layer_sizes)): 102 | self.MLP.add_module( 103 | name="L{:d}".format(i), module=nn.Linear(in_size, out_size)) 104 | if i+1 < len(layer_sizes): 105 | self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU()) 106 | else: 107 | self.MLP.add_module(name="sigmoid", module=nn.Sigmoid()) 108 | 109 | def forward(self, z, c): 110 | 111 | if self.conditional: 112 | c = idx2onehot(c, n=10) 113 | z = torch.cat((z, c), dim=-1) 114 | 115 | x = self.MLP(z) 116 | 117 | return x 118 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import argparse 5 | import pandas as pd 6 | import seaborn as sns 7 | import matplotlib.pyplot as plt 8 | from torchvision import transforms 9 | from torchvision.datasets import MNIST 10 | from torch.utils.data import DataLoader 11 | from collections import defaultdict 12 | 13 | from models import VAE 14 | 15 | 16 | def main(args): 17 | 18 | torch.manual_seed(args.seed) 19 | if torch.cuda.is_available(): 20 | torch.cuda.manual_seed(args.seed) 21 | 22 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 23 | 24 | ts = time.time() 25 | 26 | dataset = MNIST( 27 | root='data', train=True, transform=transforms.ToTensor(), 28 | download=True) 29 | data_loader = DataLoader( 30 | dataset=dataset, batch_size=args.batch_size, shuffle=True) 31 | 32 | def loss_fn(recon_x, x, mean, log_var): 33 | BCE = torch.nn.functional.binary_cross_entropy( 34 | recon_x.view(-1, 28*28), x.view(-1, 28*28), reduction='sum') 35 | KLD = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp()) 36 | 37 | return (BCE + KLD) / x.size(0) 38 | 39 | vae = VAE( 40 | encoder_layer_sizes=args.encoder_layer_sizes, 41 | latent_size=args.latent_size, 42 | decoder_layer_sizes=args.decoder_layer_sizes, 43 | conditional=args.conditional, 44 | num_labels=10 if args.conditional else 0).to(device) 45 | 46 | optimizer = torch.optim.Adam(vae.parameters(), lr=args.learning_rate) 47 | 48 | logs = defaultdict(list) 49 | 50 | for epoch in range(args.epochs): 51 | 52 | tracker_epoch = defaultdict(lambda: defaultdict(dict)) 53 | 54 | for iteration, (x, y) in enumerate(data_loader): 55 | 56 | x, y = x.to(device), y.to(device) 57 | 58 | if args.conditional: 59 | recon_x, mean, log_var, z = vae(x, y) 60 | else: 61 | recon_x, mean, log_var, z = vae(x) 62 | 63 | for i, yi in enumerate(y): 64 | id = len(tracker_epoch) 65 | tracker_epoch[id]['x'] = z[i, 0].item() 66 | tracker_epoch[id]['y'] = z[i, 1].item() 67 | tracker_epoch[id]['label'] = yi.item() 68 | 69 | loss = loss_fn(recon_x, x, mean, log_var) 70 | 71 | optimizer.zero_grad() 72 | loss.backward() 73 | optimizer.step() 74 | 75 | logs['loss'].append(loss.item()) 76 | 77 | if iteration % args.print_every == 0 or iteration == len(data_loader)-1: 78 | print("Epoch {:02d}/{:02d} Batch {:04d}/{:d}, Loss {:9.4f}".format( 79 | epoch, args.epochs, iteration, len(data_loader)-1, loss.item())) 80 | 81 | if args.conditional: 82 | c = torch.arange(0, 10).long().unsqueeze(1).to(device) 83 | z = torch.randn([c.size(0), args.latent_size]).to(device) 84 | x = vae.inference(z, c=c) 85 | else: 86 | z = torch.randn([10, args.latent_size]).to(device) 87 | x = vae.inference(z) 88 | 89 | plt.figure() 90 | plt.figure(figsize=(5, 10)) 91 | for p in range(10): 92 | plt.subplot(5, 2, p+1) 93 | if args.conditional: 94 | plt.text( 95 | 0, 0, "c={:d}".format(c[p].item()), color='black', 96 | backgroundcolor='white', fontsize=8) 97 | plt.imshow(x[p].view(28, 28).cpu().data.numpy()) 98 | plt.axis('off') 99 | 100 | if not os.path.exists(os.path.join(args.fig_root, str(ts))): 101 | if not(os.path.exists(os.path.join(args.fig_root))): 102 | os.mkdir(os.path.join(args.fig_root)) 103 | os.mkdir(os.path.join(args.fig_root, str(ts))) 104 | 105 | plt.savefig( 106 | os.path.join(args.fig_root, str(ts), 107 | "E{:d}I{:d}.png".format(epoch, iteration)), 108 | dpi=300) 109 | plt.clf() 110 | plt.close('all') 111 | 112 | df = pd.DataFrame.from_dict(tracker_epoch, orient='index') 113 | g = sns.lmplot( 114 | x='x', y='y', hue='label', data=df.groupby('label').head(100), 115 | fit_reg=False, legend=True) 116 | g.savefig(os.path.join( 117 | args.fig_root, str(ts), "E{:d}-Dist.png".format(epoch)), 118 | dpi=300) 119 | 120 | 121 | if __name__ == '__main__': 122 | 123 | parser = argparse.ArgumentParser() 124 | parser.add_argument("--seed", type=int, default=0) 125 | parser.add_argument("--epochs", type=int, default=10) 126 | parser.add_argument("--batch_size", type=int, default=64) 127 | parser.add_argument("--learning_rate", type=float, default=0.001) 128 | parser.add_argument("--encoder_layer_sizes", type=list, default=[784, 256]) 129 | parser.add_argument("--decoder_layer_sizes", type=list, default=[256, 784]) 130 | parser.add_argument("--latent_size", type=int, default=2) 131 | parser.add_argument("--print_every", type=int, default=100) 132 | parser.add_argument("--fig_root", type=str, default='figs') 133 | parser.add_argument("--conditional", action='store_true') 134 | 135 | args = parser.parse_args() 136 | 137 | main(args) 138 | --------------------------------------------------------------------------------