├── .gitignore
├── bvae-pytorch
│   ├── README.md
│   ├── data_utils.py
│   ├── models.py
│   ├── train.py
│   └── traversal_75_2000.png
└── bvae-tf
    ├── README.md
    ├── data_utils.py
    ├── models.py
    ├── train.py
    └── traversal_40_1000.png

/.gitignore:
--------------------------------------------------------------------------------
**/results/
**/__pycache__/
*.swp
--------------------------------------------------------------------------------
/bvae-pytorch/README.md:
--------------------------------------------------------------------------------
Latent traversals after training for 75 epochs. The traversals are constructed by obtaining the posterior of the first datapoint in the dSprites dataset and then setting the individual latents, one at a time, to values ranging from -3 to 3. Each column represents the traversal of one latent.

The PyTorch version seems to work worse than the TensorFlow implementation, as it produces entangled reconstructions. This is probably because the authors of the original paper used TensorFlow in their experiments; due to subtle differences between the frameworks, the PyTorch version likely needs an additional hyperparameter search.

![](traversal_75_2000.png)
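A minimal sketch of the traversal procedure (the figure itself was produced by `_traverse_latents` in `train.py`; the helper below is illustrative and assumes a trained `DenseVAE` from `models.py` plus a single dSprites image as a numpy array):

```python
import numpy as np
import torch
from torchvision.utils import save_image

def traverse(model, img, nb_latents=10, steps=7, path='traversal.png'):
    """Sweep each latent over [-3, 3] while holding the others at the posterior mean."""
    model.eval()
    with torch.no_grad():
        # posterior mean of the chosen datapoint
        mu, _ = model.encode(torch.as_tensor(img, dtype=torch.float32).view(-1))
        grid = torch.zeros(steps, nb_latents, 64, 64)
        for zi in range(nb_latents):                      # one column per latent
            for i, val in enumerate(np.linspace(-3, 3, steps)):
                z = mu.squeeze().clone()
                z[zi] = val                               # move only latent zi
                grid[i, zi] = model.decode(z).view(64, 64)
    save_image(grid.view(-1, 1, 64, 64), path, nrow=nb_latents, pad_value=1)
```

Called as `traverse(model, dataloader.dataset[0])`, this produces the same kind of grid as the image above.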
--------------------------------------------------------------------------------
/bvae-pytorch/data_utils.py:
--------------------------------------------------------------------------------
import numpy as np
from torch.utils.data import Dataset, DataLoader


DSPRITES_PATH = '/home/stensootla/projects/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz'


class SpritesDataset(Dataset):
    def __init__(self, dataset_path):
        dataset_zip = np.load(dataset_path, encoding='latin1')
        self.imgs = dataset_zip['imgs']

    def __len__(self):
        return self.imgs.shape[0]

    def __getitem__(self, idx):
        return self.imgs[idx].astype(np.float32)


def get_dataloader(batch_size, shuffle=True):
    sprite_dataset = SpritesDataset(dataset_path=DSPRITES_PATH)
    return DataLoader(sprite_dataset, batch_size=batch_size, shuffle=shuffle)
--------------------------------------------------------------------------------
/bvae-pytorch/models.py:
--------------------------------------------------------------------------------
from torch import nn
from torch.autograd import Variable


class DenseVAE(nn.Module):
    def __init__(self, nb_latents):
        super(DenseVAE, self).__init__()

        # encoder
        self.fc1 = nn.Linear(4096, 1200)
        self.fc2 = nn.Linear(1200, 1200)
        self.fc_mean = nn.Linear(1200, nb_latents)
        self.fc_std = nn.Linear(1200, nb_latents)  # outputs log-variance, consumed by reparameterize
        # decoder
        self.fc3 = nn.Linear(nb_latents, 1200)
        self.fc4 = nn.Linear(1200, 1200)
        self.fc5 = nn.Linear(1200, 1200)
        self.fc6 = nn.Linear(1200, 4096)

        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()

    def encode(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.fc_mean(x), self.fc_std(x)

    def reparameterize(self, mu, logvar):
        # sample z = mu + sigma * eps with eps ~ N(0, I); use the posterior mean at test time
        if self.training:
            std = logvar.mul(0.5).exp_()
            eps = Variable(std.data.new(std.size()).normal_())
            return eps.mul(std).add_(mu)
        else:
            return mu

    def decode(self, z):
        x = self.tanh(self.fc3(z))
        x = self.tanh(self.fc4(x))
        x = self.tanh(self.fc5(x))
        return self.sigmoid(self.fc6(x))

    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, 64*64))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


class ConvVAE(nn.Module):
    def __init__(self, nb_latents):
        super(ConvVAE, self).__init__()
        # encoder: 64x64 input down to 4x4 feature maps
        self.conv1 = nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)
        self.conv4 = nn.Conv2d(64, 64, kernel_size=4, stride=2, padding=1)
        self.fc1 = nn.Linear(64*4*4, 256)

        self.fc_mean = nn.Linear(256, nb_latents)
        self.fc_std = nn.Linear(256, nb_latents)  # log-variance head

        # decoder: mirror of the encoder
        self.fc2 = nn.Linear(nb_latents, 256)
        self.fc3 = nn.Linear(256, 64*4*4)
        self.deconv1 = nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=1)
        self.deconv2 = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1)
        self.deconv3 = nn.ConvTranspose2d(32, 32, kernel_size=4, stride=2, padding=1)
        self.deconv4 = nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def encode(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.relu(self.conv4(x))
        x = self.relu(self.fc1(x.view(-1, 64*4*4)))
        return self.fc_mean(x), self.fc_std(x)

    def reparameterize(self, mu, logvar):
        if self.training:
            std = logvar.mul(0.5).exp_()
            eps = Variable(std.data.new(std.size()).normal_())
            return eps.mul(std).add_(mu)
        else:
            return mu

    def decode(self, z):
        x = self.relu(self.fc2(z))
        x = self.relu(self.fc3(x))
        x = self.relu(self.deconv1(x.view(-1, 64, 4, 4)))
        x = self.relu(self.deconv2(x))
        x = self.relu(self.deconv3(x))
        return self.sigmoid(self.deconv4(x))

    def forward(self, x):
        mu, logvar = self.encode(x.unsqueeze(1))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar
--------------------------------------------------------------------------------
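A quick shape check of the two models above (a hypothetical snippet, not part of the repo; `x` stands in for a batch of eight binary 64x64 dSprites images):

```python
import torch
from models import DenseVAE, ConvVAE

x = torch.rand(8, 64, 64)

# DenseVAE flattens the input and reconstructs a 4096-dim vector
dense = DenseVAE(nb_latents=10).eval()
recon, mu, logvar = dense(x)
assert recon.shape == (8, 4096) and mu.shape == logvar.shape == (8, 10)

# ConvVAE adds a channel dimension and reconstructs a 1x64x64 image
conv = ConvVAE(nb_latents=10).eval()
recon, mu, logvar = conv(x)
assert recon.shape == (8, 1, 64, 64) and mu.shape == logvar.shape == (8, 10)
```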
/bvae-pytorch/train.py:
--------------------------------------------------------------------------------
import os
import shutil
import time
import argparse
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import optim
from torch.autograd import Variable
from torch.nn import functional as F
from torchvision.utils import save_image

from models import DenseVAE, ConvVAE
from data_utils import get_dataloader


def _make_results_dir(dirpath='results'):
    if os.path.isdir(dirpath):
        shutil.rmtree(dirpath)
    os.makedirs(dirpath)


def _traverse_latents(model, datapoint, nb_latents, epoch_nb, batch_idx, dirpath='results'):
    model.eval()

    if isinstance(model, ConvVAE):
        datapoint = datapoint.unsqueeze(0).unsqueeze(1)
        mu, _ = model.encode(datapoint)
    else:
        mu, _ = model.encode(datapoint.view(-1))

    # sweep each latent over [-3, 3] while keeping the others at the posterior mean
    recons = torch.zeros((7, nb_latents, 64, 64))
    for zi in range(nb_latents):
        muc = mu.squeeze().clone()
        for i, val in enumerate(np.linspace(-3, 3, 7)):
            muc[zi] = val
            recon = model.decode(muc).cpu()
            recons[i, zi] = recon.view(64, 64)

    filename = os.path.join(dirpath, 'traversal_' + str(epoch_nb) + '_' + str(batch_idx) + '.png')
    save_image(recons.view(-1, 1, 64, 64), filename, nrow=nb_latents, pad_value=1)


def _loss_function(recon_x, x, mu, logvar, beta):
    """Reconstruction + KL divergence losses summed over all elements and batch."""
    bce = F.binary_cross_entropy(recon_x, x.view(-1, 64*64), size_average=False)
    kld = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
    # returns (per-latent KL averaged over the batch, total beta-VAE loss)
    return kld.mean(dim=0), bce + beta*kld.sum()


def train(args):
    torch.manual_seed(args.seed)
    _make_results_dir()
    dataloader = get_dataloader(args.batch_size)
    testpoint = torch.Tensor(dataloader.dataset[0])
    if not args.no_cuda: testpoint = testpoint.cuda()

    model = DenseVAE(args.nb_latents)
    model.train()
    if not args.no_cuda: model.cuda()
    optimizer = optim.Adagrad(model.parameters(), lr=args.eta)

    runloss, runkld = None, np.array([])
    start_time = time.time()

    for epoch_nb in range(1, args.epochs + 1):

        for batch_idx, data in enumerate(dataloader):
            if not args.no_cuda: data = data.cuda()

            recon_batch, mu, logvar = model(data)
            kld, loss = _loss_function(recon_batch, data, mu, logvar, args.beta)

            # param update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # running averages for logging
            loss /= len(data)
            runloss = loss if runloss is None else runloss*0.99 + loss*0.01
            runkld = np.zeros(args.nb_latents) if not len(runkld) else runkld*0.99 + kld.data.cpu().numpy()*0.01

            if not batch_idx % args.log_interval:
                print("Epoch {}, batch: {}/{} ({:.2f} s), loss: {:.2f}, kl: [{}]".format(
                    epoch_nb, batch_idx, len(dataloader), time.time() - start_time, runloss,
                    ", ".join("{:.2f}".format(kl) for kl in runkld)))
                start_time = time.time()

            if not batch_idx % args.save_interval:
                _traverse_latents(model, testpoint, args.nb_latents, epoch_nb, batch_idx)
                model.train()


def parse():
    parser = argparse.ArgumentParser(description='train beta-VAE on the sprites dataset')
    parser.add_argument('--eta', type=float, default=1e-2, metavar='L',
                        help='learning rate for Adagrad (default: 1e-2)')
    parser.add_argument('--beta', type=int, default=4, metavar='B',
                        help='the beta coefficient (default: 4)')
    parser.add_argument('--nb-latents', type=int, default=10, metavar='N',
                        help='number of latents (default: 10)')
    parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                        help='input batch size for training (default: 128)')
    parser.add_argument('--epochs', type=int, default=100, metavar='N',
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disable CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=100, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-interval', type=int, default=1000, metavar='T',
                        help="how many batches to wait before saving latent traversal")
    return parser.parse_args()


if __name__ == '__main__':
    args = parse()
    train(args)
--------------------------------------------------------------------------------
/bvae-pytorch/traversal_75_2000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sootlasten/beta-vae/cf2d65ecf08c78e06d7ff53794afd3b3c4259844/bvae-pytorch/traversal_75_2000.png
--------------------------------------------------------------------------------
/bvae-tf/README.md:
--------------------------------------------------------------------------------
Latent traversals after training for 40 epochs. The traversals are constructed by obtaining the posterior of the first datapoint in the dSprites dataset and then setting the individual latents, one at a time, to values ranging from -3 to 3. Each column represents the traversal of one latent.

![](traversal_40_1000.png)
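Both versions minimize the standard β-VAE objective (Higgins et al., 2017): a Bernoulli reconstruction cross-entropy plus β times the KL divergence between the diagonal-Gaussian posterior and the unit-Gaussian prior. This is the closed form computed in `models.py` here and in `_loss_function` of the PyTorch version, with β = 4 by default (`--beta`):

$$
\mathcal{L} = \mathbb{E}_{q_\phi(z|x)}\!\left[-\log p_\theta(x|z)\right] + \beta\, D_{\mathrm{KL}}\!\left(q_\phi(z|x)\,\|\,\mathcal{N}(0, I)\right),
\qquad
D_{\mathrm{KL}} = -\tfrac{1}{2}\sum_{j=1}^{d}\left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right).
$$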
--------------------------------------------------------------------------------
/bvae-tf/data_utils.py:
--------------------------------------------------------------------------------
import numpy as np


DSPRITES_PATH = '/home/stensootla/projects/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz'


class SpritesDataset():
    def __init__(self, batch_size, dataset_path=DSPRITES_PATH):
        dataset_zip = np.load(dataset_path, encoding='latin1')
        self.imgs = dataset_zip['imgs']
        self.batch_size = batch_size

    def __len__(self):
        return len(self.imgs)

    @property
    def nb_batches(self):
        return int(np.ceil(len(self) / self.batch_size))

    def gen(self, shuffle=True):
        # yields batches of raw 64x64 binary images; shuffles in place once per epoch
        if shuffle: np.random.shuffle(self.imgs)
        for i in range(0, len(self), self.batch_size):
            yield self.imgs[i:i+self.batch_size]
--------------------------------------------------------------------------------
/bvae-tf/models.py:
--------------------------------------------------------------------------------
import functools

import tensorflow as tf
from tensorflow.contrib.layers import fully_connected


class DenseVAE():
    def __init__(self, x, z, nb_latents, beta, eta):
        self.x = x
        self.z = z

        self.nb_latents = nb_latents
        self.beta = beta
        self.eta = eta

        # graph ops are built lazily by the properties below and cached here
        self._reconstruct = None
        self._encode = None
        self._decode = None
        self._optimize = None

        # touch the properties once so the graph is constructed in a fixed order
        self.encode
        self.reconstruct
        self.decode
        self.optimize

    @property
    def reconstruct(self):
        if self._reconstruct is None:
            mean, logvar = self.encode
            # reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I)
            eps = tf.random_normal(tf.shape(mean))
            z = mean + tf.sqrt(tf.exp(logvar))*eps
            self._reconstruct = self.__decode(z)
        return self._reconstruct

    @property
    def encode(self):
        if self._encode is None:
            fc1 = fully_connected(self.x, 1200)
            fc2 = fully_connected(fc1, 1200)
            mean = fully_connected(fc2, self.nb_latents, activation_fn=None)
            logvar = fully_connected(fc2, self.nb_latents, activation_fn=None)
            self._encode = mean, logvar
        return self._encode

    @property
    def decode(self):
        if self._decode is None:
            self._decode = self.__decode(self.z, reuse=True)
        return self._decode

    def __decode(self, z, reuse=False):
        # returns logits; the sigmoid is applied by the loss and by the caller for visualization
        with tf.variable_scope("decoder", reuse=reuse):
            fc1 = fully_connected(z, 1200, activation_fn=tf.nn.tanh)
            fc2 = fully_connected(fc1, 1200, activation_fn=tf.nn.tanh)
            fc3 = fully_connected(fc2, 1200, activation_fn=tf.nn.tanh)
            recon = fully_connected(fc3, 4096, activation_fn=None)
            return recon

    @property
    def optimize(self):
        if self._optimize is None:
            ce = tf.reduce_mean(tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(
                labels=self.x, logits=self.reconstruct), axis=1))
            mean, logvar = self.encode
            # beta-weighted KL contribution of each latent, per example
            kl = -self.beta*0.5*(1 + logvar - tf.exp(logvar) - tf.square(mean))
            loss = ce + tf.reduce_mean(tf.reduce_sum(kl, axis=1))
            optimizer = tf.train.AdagradOptimizer(self.eta)
            self._optimize = optimizer.minimize(loss), loss, tf.reduce_mean(kl, axis=0)
        return self._optimize
--------------------------------------------------------------------------------
/bvae-tf/train.py:
--------------------------------------------------------------------------------
import os
import shutil
import time
import argparse

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

from models import DenseVAE
from data_utils import SpritesDataset


RESULTS_DIR = 'results'


def sigmoid(x):
    return 1 / (1 + np.exp(-x))


def _make_results_dir(dirpath=RESULTS_DIR):
    if os.path.isdir(dirpath):
        shutil.rmtree(dirpath)
    os.makedirs(dirpath)


def train(args):
    _make_results_dir()

    sprites = SpritesDataset(args.batch_size)
    testpoint = sprites.imgs[0].flatten()[np.newaxis, :]

    x = tf.placeholder(tf.float32, [None, 4096])
    z = tf.placeholder(tf.float32, [None, args.nb_latents])
    model = DenseVAE(x, z, args.nb_latents, args.beta, args.eta)

    runloss, runkl = None, np.array([])
    start_time = time.time()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        for epoch_nb in range(1, args.epochs + 1):
            datagen = sprites.gen()

            for batch_idx, batch_xs in enumerate(datagen):
                batch_xs = batch_xs.reshape(len(batch_xs), -1)
                _, loss, kl = sess.run(model.optimize, feed_dict={x: batch_xs})

                # running averages for logging
                runloss = loss if runloss is None else runloss*0.99 + loss*0.01
                runkl = np.zeros(args.nb_latents) if not len(runkl) else runkl*0.99 + kl*0.01

                if not batch_idx % args.log_interval:
                    print("Epoch {}, batch: {}/{} ({:.2f} s), loss: {:.2f}, kl: [{}]".format(
                        epoch_nb, batch_idx, sprites.nb_batches, time.time() - start_time, runloss,
                        ", ".join("{:.2f}".format(kl) for kl in runkl)))
                    start_time = time.time()

                # latent traversal
                if not batch_idx % args.save_interval:
                    mu, _ = sess.run(model.encode, feed_dict={x: testpoint})
                    mus = np.tile(mu, (args.nb_trav, 1))

                    fig = plt.figure(figsize=(10, 10))
                    for zi in range(args.nb_latents):
                        muc = mus.copy()
                        muc[:, zi] = np.linspace(-3, 3, args.nb_trav)
                        recon = sess.run(model.decode, feed_dict={z: muc})
                        recon = sigmoid(recon)  # decoder outputs logits

                        for i in range(args.nb_trav):
                            ax = plt.subplot(args.nb_trav, args.nb_latents, i*args.nb_latents + (zi+1))
                            plt.axis('off')
                            ax.imshow(recon[i].reshape(64, 64), cmap='gray')

                    filename = os.path.join(RESULTS_DIR, 'traversal_' + str(epoch_nb) + '_' + str(batch_idx) + '.png')
                    plt.savefig(filename)
                    plt.close(fig)  # free the figure so long runs don't accumulate memory


def parse():
    parser = argparse.ArgumentParser(description='train beta-VAE on the sprites dataset')
    parser.add_argument('--eta', type=float, default=1e-2, metavar='L',
                        help='learning rate for Adagrad (default: 1e-2)')
    parser.add_argument('--beta', type=int, default=4, metavar='B',
                        help='the beta coefficient (default: 4)')
    parser.add_argument('--nb-latents', type=int, default=10, metavar='N',
                        help='number of latents (default: 10)')
    parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                        help='input batch size for training (default: 128)')
    parser.add_argument('--epochs', type=int, default=100, metavar='N',
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disable CUDA training (currently unused)')
    parser.add_argument('--log-interval', type=int, default=100, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-interval', type=int, default=1000, metavar='T',
                        help="how many batches to wait before saving latent traversal")
    parser.add_argument('--nb_trav', type=int, default=7, metavar='T',
                        help='how many points to choose linearly from interval [-3, 3] for \
                              latent traversal analysis')
    return parser.parse_args()


if __name__ == '__main__':
    args = parse()
    train(args)
--------------------------------------------------------------------------------
/bvae-tf/traversal_40_1000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sootlasten/beta-vae/cf2d65ecf08c78e06d7ff53794afd3b3c4259844/bvae-tf/traversal_40_1000.png
--------------------------------------------------------------------------------