├── figs
│   ├── 1519649452.702026
│   │   ├── E9I937.png
│   │   └── E9-Dist.png
│   └── 1519649461.195146
│       ├── E9I937.png
│       └── E9-Dist.png
├── requirements.txt
├── utils.py
├── README.md
├── models.py
└── train.py
/figs/1519649452.702026/E9I937.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/timbmg/VAE-CVAE-MNIST/HEAD/figs/1519649452.702026/E9I937.png
--------------------------------------------------------------------------------
/figs/1519649461.195146/E9I937.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/timbmg/VAE-CVAE-MNIST/HEAD/figs/1519649461.195146/E9I937.png
--------------------------------------------------------------------------------
/figs/1519649452.702026/E9-Dist.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/timbmg/VAE-CVAE-MNIST/HEAD/figs/1519649452.702026/E9-Dist.png
--------------------------------------------------------------------------------
/figs/1519649461.195146/E9-Dist.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/timbmg/VAE-CVAE-MNIST/HEAD/figs/1519649461.195146/E9-Dist.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | cycler==0.10.0
2 | kiwisolver==1.0.1
3 | matplotlib==3.0.2
4 | numpy==1.22.0
5 | pandas==0.24.1
6 | Pillow>=6.2.2
7 | pyparsing==2.3.1
8 | python-dateutil==2.8.0
9 | pytz==2018.9
10 | scipy==1.10.0
11 | seaborn==0.9.0
12 | six==1.12.0
13 | torch==2.2.0
14 | torchvision==0.17.0
15 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def idx2onehot(idx, n):
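    """Convert a tensor of integer class indices, shape (N,) or (N, 1),
    into a one-hot matrix of shape (N, n) on idx's device.

    e.g. idx2onehot(torch.tensor([2, 0]), n=4)
         -> tensor([[0., 0., 1., 0.],
                    [1., 0., 0., 0.]])
    """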
5 |
6 | assert torch.max(idx).item() < n
7 |
8 | if idx.dim() == 1:
9 | idx = idx.unsqueeze(1)
10 | onehot = torch.zeros(idx.size(0), n).to(idx.device)
11 | onehot.scatter_(1, idx, 1)
12 |
13 | return onehot
14 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Variational Autoencoder & Conditional Variational Autoencoder on MNIST
2 |
3 | VAE paper: [Auto-Encoding Variational Bayes](https://arxiv.org/abs/1312.6114)
4 |
5 | CVAE paper: [Semi-supervised Learning with Deep Generative Models](https://proceedings.neurips.cc/paper/2014/hash/d523773c6b194f37b938d340d5d02232-Abstract.html)
6 |
7 | ---
8 | To train the _conditional_ variational autoencoder, add `--conditional` to the command, as in the example below. Check out the other command-line options in the code for hyperparameter settings (learning rate, batch size, encoder/decoder layer depth and size).
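
For example, with the default hyperparameters (all flags shown are defined in `train.py`):

```
python train.py --conditional --epochs 10 --latent_size 2
```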
9 |
10 | ---
11 |
12 | ## Results
13 |
14 | All plots were obtained after 10 epochs of training. Hyperparameters follow the default settings in the code and were not tuned.
15 |
16 | ### z ~ q(z|x) and q(z|x,c)
17 | The learned latent distribution after 10 epochs, plotted with 100 samples per digit.
18 |
19 | VAE | CVAE
20 | --- | ---
21 | *(VAE: `E9-Dist.png` from its run folder under `figs/`)* | *(CVAE: `E9-Dist.png` from its run folder under `figs/`)*
22 |
23 | ### p(x|z) and p(x|z,c)
24 | Randomly sampled z and the corresponding decoder outputs. For the CVAE, each label c is given as input exactly once; a minimal sampling sketch follows the table.
25 |
26 | VAE | CVAE
27 | --- | ---
28 | *(VAE: `E9I937.png` from its run folder under `figs/`)* | *(CVAE: `E9I937.png` from its run folder under `figs/`)*
29 |
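
A minimal sketch of this sampling step, mirroring the end-of-epoch inference block in `train.py` (the model here is freshly constructed and untrained, so the decoded digits are illustrative only):

```python
import torch
from models import VAE

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
vae = VAE(encoder_layer_sizes=[784, 256], latent_size=2,
          decoder_layer_sizes=[256, 784], conditional=True, num_labels=10).to(device)

c = torch.arange(0, 10).long().unsqueeze(1).to(device)  # one label per digit, shape (10, 1)
z = torch.randn([c.size(0), 2]).to(device)              # z ~ N(0, I)
x = vae.inference(z, c=c)                               # decoded digits, shape (10, 784)
```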
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from utils import idx2onehot
5 |
6 |
7 | class VAE(nn.Module):
8 |
9 | def __init__(self, encoder_layer_sizes, latent_size, decoder_layer_sizes,
10 | conditional=False, num_labels=0):
11 |
12 | super().__init__()
13 |
14 | if conditional:
15 | assert num_labels > 0
16 |
17 |         assert isinstance(encoder_layer_sizes, list)
18 |         assert isinstance(latent_size, int)
19 |         assert isinstance(decoder_layer_sizes, list)
20 |
21 | self.latent_size = latent_size
22 |
23 | self.encoder = Encoder(
24 | encoder_layer_sizes, latent_size, conditional, num_labels)
25 | self.decoder = Decoder(
26 | decoder_layer_sizes, latent_size, conditional, num_labels)
27 |
28 | def forward(self, x, c=None):
29 |
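        # flatten (B, 1, 28, 28) image batches to (B, 784)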
30 | if x.dim() > 2:
31 | x = x.view(-1, 28*28)
32 |
33 | means, log_var = self.encoder(x, c)
34 | z = self.reparameterize(means, log_var)
35 | recon_x = self.decoder(z, c)
36 |
37 | return recon_x, means, log_var, z
38 |
39 | def reparameterize(self, mu, log_var):
40 |
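        # reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I),
        # which keeps sampling differentiable w.r.t. mu and log_var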
41 | std = torch.exp(0.5 * log_var)
42 | eps = torch.randn_like(std)
43 |
44 | return mu + eps * std
45 |
46 | def inference(self, z, c=None):
47 |
48 | recon_x = self.decoder(z, c)
49 |
50 | return recon_x
51 |
52 |
53 | class Encoder(nn.Module):
54 |
55 | def __init__(self, layer_sizes, latent_size, conditional, num_labels):
56 |
57 | super().__init__()
58 |
59 |         self.conditional = conditional
60 |         if self.conditional:
61 |             layer_sizes = [layer_sizes[0] + num_labels] + layer_sizes[1:]  # don't mutate the caller's list
62 |
63 | self.MLP = nn.Sequential()
64 |
65 | for i, (in_size, out_size) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
66 | self.MLP.add_module(
67 | name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
68 | self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())
69 |
70 | self.linear_means = nn.Linear(layer_sizes[-1], latent_size)
71 | self.linear_log_var = nn.Linear(layer_sizes[-1], latent_size)
72 |
73 | def forward(self, x, c=None):
74 |
75 | if self.conditional:
76 | c = idx2onehot(c, n=10)
77 | x = torch.cat((x, c), dim=-1)
78 |
79 | x = self.MLP(x)
80 |
81 | means = self.linear_means(x)
82 | log_vars = self.linear_log_var(x)
83 |
84 | return means, log_vars
85 |
86 |
87 | class Decoder(nn.Module):
88 |
89 | def __init__(self, layer_sizes, latent_size, conditional, num_labels):
90 |
91 | super().__init__()
92 |
93 | self.MLP = nn.Sequential()
94 |
95 | self.conditional = conditional
96 | if self.conditional:
97 | input_size = latent_size + num_labels
98 | else:
99 | input_size = latent_size
100 |
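        # hidden layers use ReLU; the last layer ends in a sigmoid so pixel
        # outputs lie in (0, 1), matching the BCE reconstruction loss in train.py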
101 | for i, (in_size, out_size) in enumerate(zip([input_size]+layer_sizes[:-1], layer_sizes)):
102 | self.MLP.add_module(
103 | name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
104 | if i+1 < len(layer_sizes):
105 | self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())
106 | else:
107 | self.MLP.add_module(name="sigmoid", module=nn.Sigmoid())
108 |
109 |     def forward(self, z, c=None):
110 |
111 | if self.conditional:
112 | c = idx2onehot(c, n=10)
113 | z = torch.cat((z, c), dim=-1)
114 |
115 | x = self.MLP(z)
116 |
117 | return x
118 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import torch
4 | import argparse
5 | import pandas as pd
6 | import seaborn as sns
7 | import matplotlib.pyplot as plt
8 | from torchvision import transforms
9 | from torchvision.datasets import MNIST
10 | from torch.utils.data import DataLoader
11 | from collections import defaultdict
12 |
13 | from models import VAE
14 |
15 |
16 | def main(args):
17 |
18 | torch.manual_seed(args.seed)
19 | if torch.cuda.is_available():
20 | torch.cuda.manual_seed(args.seed)
21 |
22 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
23 |
24 | ts = time.time()
25 |
26 | dataset = MNIST(
27 | root='data', train=True, transform=transforms.ToTensor(),
28 | download=True)
29 | data_loader = DataLoader(
30 | dataset=dataset, batch_size=args.batch_size, shuffle=True)
31 |
32 | def loss_fn(recon_x, x, mean, log_var):
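        # negative ELBO averaged over the batch: summed binary cross-entropy
        # reconstruction term plus the closed-form Gaussian KL divergence
        # KL(q(z|x) || N(0, I)) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)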
33 | BCE = torch.nn.functional.binary_cross_entropy(
34 | recon_x.view(-1, 28*28), x.view(-1, 28*28), reduction='sum')
35 | KLD = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
36 |
37 | return (BCE + KLD) / x.size(0)
38 |
39 | vae = VAE(
40 | encoder_layer_sizes=args.encoder_layer_sizes,
41 | latent_size=args.latent_size,
42 | decoder_layer_sizes=args.decoder_layer_sizes,
43 | conditional=args.conditional,
44 | num_labels=10 if args.conditional else 0).to(device)
45 |
46 | optimizer = torch.optim.Adam(vae.parameters(), lr=args.learning_rate)
47 |
48 | logs = defaultdict(list)
49 |
50 | for epoch in range(args.epochs):
51 |
52 | tracker_epoch = defaultdict(lambda: defaultdict(dict))
53 |
54 | for iteration, (x, y) in enumerate(data_loader):
55 |
56 | x, y = x.to(device), y.to(device)
57 |
58 | if args.conditional:
59 | recon_x, mean, log_var, z = vae(x, y)
60 | else:
61 | recon_x, mean, log_var, z = vae(x)
62 |
63 |             for i, yi in enumerate(y):  # record 2-D latent codes (assumes latent_size >= 2)
64 |                 key = len(tracker_epoch)  # avoid shadowing the builtin id()
65 |                 tracker_epoch[key]['x'] = z[i, 0].item()
66 |                 tracker_epoch[key]['y'] = z[i, 1].item()
67 |                 tracker_epoch[key]['label'] = yi.item()
68 |
69 | loss = loss_fn(recon_x, x, mean, log_var)
70 |
71 | optimizer.zero_grad()
72 | loss.backward()
73 | optimizer.step()
74 |
75 | logs['loss'].append(loss.item())
76 |
77 | if iteration % args.print_every == 0 or iteration == len(data_loader)-1:
78 | print("Epoch {:02d}/{:02d} Batch {:04d}/{:d}, Loss {:9.4f}".format(
79 | epoch, args.epochs, iteration, len(data_loader)-1, loss.item()))
80 |
81 | if args.conditional:
82 | c = torch.arange(0, 10).long().unsqueeze(1).to(device)
83 | z = torch.randn([c.size(0), args.latent_size]).to(device)
84 | x = vae.inference(z, c=c)
85 | else:
86 | z = torch.randn([10, args.latent_size]).to(device)
87 | x = vae.inference(z)
88 |
89 |                 # one 5x2 figure of generated digits per checkpoint
90 |                 plt.figure(figsize=(5, 10))
91 | for p in range(10):
92 | plt.subplot(5, 2, p+1)
93 | if args.conditional:
94 | plt.text(
95 | 0, 0, "c={:d}".format(c[p].item()), color='black',
96 | backgroundcolor='white', fontsize=8)
97 |                     plt.imshow(x[p].view(28, 28).cpu().detach().numpy())
98 | plt.axis('off')
99 |
100 |                 # create the per-run figure directory on first save
101 |                 fig_dir = os.path.join(args.fig_root, str(ts))
102 |                 os.makedirs(fig_dir, exist_ok=True)
103 |
104 |
105 | plt.savefig(
106 | os.path.join(args.fig_root, str(ts),
107 | "E{:d}I{:d}.png".format(epoch, iteration)),
108 | dpi=300)
109 | plt.clf()
110 | plt.close('all')
111 |
112 | df = pd.DataFrame.from_dict(tracker_epoch, orient='index')
113 | g = sns.lmplot(
114 | x='x', y='y', hue='label', data=df.groupby('label').head(100),
115 | fit_reg=False, legend=True)
116 | g.savefig(os.path.join(
117 | args.fig_root, str(ts), "E{:d}-Dist.png".format(epoch)),
118 | dpi=300)
119 |
120 |
121 | if __name__ == '__main__':
122 |
123 | parser = argparse.ArgumentParser()
124 | parser.add_argument("--seed", type=int, default=0)
125 | parser.add_argument("--epochs", type=int, default=10)
126 | parser.add_argument("--batch_size", type=int, default=64)
127 | parser.add_argument("--learning_rate", type=float, default=0.001)
128 |     parser.add_argument("--encoder_layer_sizes", type=int, nargs='+', default=[784, 256])
129 |     parser.add_argument("--decoder_layer_sizes", type=int, nargs='+', default=[256, 784])
130 | parser.add_argument("--latent_size", type=int, default=2)
131 | parser.add_argument("--print_every", type=int, default=100)
132 | parser.add_argument("--fig_root", type=str, default='figs')
133 | parser.add_argument("--conditional", action='store_true')
134 |
135 | args = parser.parse_args()
136 |
137 | main(args)
138 |
--------------------------------------------------------------------------------