├── imgs
│   ├── trans.jpg
│   ├── vec_math.jpg
│   ├── Epoch_28_data.jpg
│   └── Epoch_28_recon.jpg
├── README.md
└── src
    ├── util.py
    └── vanila_vae.py

/imgs/trans.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bhpfelix/Variational-Autoencoder-PyTorch/HEAD/imgs/trans.jpg
--------------------------------------------------------------------------------
/imgs/vec_math.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bhpfelix/Variational-Autoencoder-PyTorch/HEAD/imgs/vec_math.jpg
--------------------------------------------------------------------------------
/imgs/Epoch_28_data.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bhpfelix/Variational-Autoencoder-PyTorch/HEAD/imgs/Epoch_28_data.jpg
--------------------------------------------------------------------------------
/imgs/Epoch_28_recon.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bhpfelix/Variational-Autoencoder-PyTorch/HEAD/imgs/Epoch_28_recon.jpg
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Variational Autoencoder for face image generation in PyTorch

A variational autoencoder for face image generation, implemented in PyTorch and trained on a combination of the CelebA, FaceScrub, and JAFFE datasets.

Based on the Deep Feature Consistent Variational Autoencoder (paper: https://arxiv.org/abs/1610.00291, reference implementation: https://github.com/houxianxu/DFC-VAE).

TODO: Add DFC-VAE implementation

Pretrained model available at https://drive.google.com/open?id=0B4y-iigc5IzcTlJfYlJyaF9ndlU

## Results
Original faces vs. reconstructed faces:
![Original faces](imgs/Epoch_28_data.jpg)
![Reconstructed faces](imgs/Epoch_28_recon.jpg)
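
A reconstruction is a single pass through the model: encode the image to a posterior mean and log-variance, draw a latent code with the reparametrization trick, and decode it. A minimal sketch, assuming `src/` is the working directory and a CPU checkpoint produced by `last_model_to_cpu()` exists (the checkpoint path below is hypothetical):

```python
# Minimal reconstruction sketch. The checkpoint path is hypothetical; note
# that vanila_vae.py parses command-line flags at import time, so run this
# without extra arguments.
import torch
from torch.autograd import Variable
from vanila_vae import VAE

model = VAE(nc=3, ngf=128, ndf=128, latent_variable_size=500)
model.load_state_dict(torch.load('../models/checkpoint.pth'))  # hypothetical path
model.eval()

x = Variable(torch.rand(1, 3, 128, 128))  # stand-in for a 128x128 face crop in [0, 1]
recon, mu, logvar = model(x)              # encode -> reparametrize -> decode
```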

Linear interpolation between two face images:
![Latent space interpolation](imgs/trans.jpg)
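
The strip above is produced by `latent_space_transition` in `src/vanila_vae.py`: encode the two endpoint faces, walk the latent code from one to the other in eleven evenly spaced steps, and decode each intermediate point. The scheme, with random codes standing in for encoded faces:

```python
# Interpolation scheme from latent_space_transition(); z1 and z2 are stand-ins
# for codes returned by model.get_latent_var(...).
import numpy as np
import torch

z1 = torch.randn(1, 500)
z2 = torch.randn(1, 500)
frames = [z1 + (z2 - z1) * float(t) for t in np.linspace(0, 1, 11)]
z = torch.cat(frames, 0)  # decode with model.decode(z) to render the strip
```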

Vector arithmetic in latent space:
![Latent space vector arithmetic](imgs/vec_math.jpg)
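
The grid above comes from `perform_latent_space_arithmatics` in `src/vanila_vae.py`: for each triple of faces (a, b, c), the difference between the first two latent codes is blended into the third in eleven steps, i.e. z = (za - zb) * t + zc for t from 0 to 1. With random codes standing in for encoded faces:

```python
# Vector arithmetic scheme from perform_latent_space_arithmatics(); za, zb and
# zc are stand-ins for codes returned by model.get_latent_var(...).
import numpy as np
import torch

za, zb, zc = torch.randn(1, 500), torch.randn(1, 500), torch.randn(1, 500)
steps = [(za - zb) * float(t) + zc for t in np.linspace(0, 1, 11)]
z = torch.cat(steps, 0)  # decode with model.decode(z) to render the grid
```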

--------------------------------------------------------------------------------
/src/util.py:
--------------------------------------------------------------------------------
import pickle as pk
import sys


############################################################
### IO
############################################################
def disp_to_term(msg):
    # Overwrite the current terminal line (carriage return without newline).
    sys.stdout.write(msg + '\r')
    sys.stdout.flush()

def load_pickle(filename):
    # Pickle files must be opened in binary mode.
    try:
        p = open(filename, 'rb')
    except IOError:
        print("Pickle file cannot be opened.")
        return None
    try:
        picklelicious = pk.load(p)
    except ValueError:
        print('load_pickle failed once, trying again')
        p.close()
        p = open(filename, 'rb')
        picklelicious = pk.load(p)

    p.close()
    return picklelicious

def save_pickle(data_object, filename):
    pickle_file = open(filename, 'wb')
    pk.dump(data_object, pickle_file)
    pickle_file.close()
--------------------------------------------------------------------------------
/src/vanila_vae.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import argparse
import torch
import torch.utils.data
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import time
from glob import glob
from util import *
import numpy as np
from PIL import Image

parser = argparse.ArgumentParser(description='PyTorch VAE')
parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                    help='input batch size for training (default: 128)')
parser.add_argument('--epochs', type=int, default=20, metavar='N',
                    help='number of epochs to train (default: 20)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=1, metavar='N',
                    help='how many batches to wait before logging training status')

args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
# Batches are read directly from numbered JPEG files rather than a Dataset:
# 2080 training batches and 40 test batches of 128 images each.
train_loader = range(2080)
test_loader = range(40)

totensor = transforms.ToTensor()
def load_batch(batch_idx, istrain):
    # Note: the batch size is hard-coded to 128 here, independent of --batch-size.
    if istrain:
        template = '../data/train/%s.jpg'
    else:
        template = '../data/test/%s.jpg'
    l = [str(batch_idx*128 + i).zfill(6) for i in range(128)]
    data = []
    for idx in l:
        img = Image.open(template%idx)
        data.append(np.array(img))
    data = [totensor(i) for i in data]
    return torch.stack(data, dim=0)


class VAE(nn.Module):
    def __init__(self, nc, ngf, ndf, latent_variable_size):
        super(VAE, self).__init__()

        self.nc = nc
        self.ngf = ngf
        self.ndf = ndf
        self.latent_variable_size = latent_variable_size

        # encoder: five stride-2 convolutions take a 128x128 input down to 4x4
        self.e1 = nn.Conv2d(nc, ndf, 4, 2, 1)
        self.bn1 = nn.BatchNorm2d(ndf)

        self.e2 = nn.Conv2d(ndf, ndf*2, 4, 2, 1)
        self.bn2 = nn.BatchNorm2d(ndf*2)

        self.e3 = nn.Conv2d(ndf*2, ndf*4, 4, 2, 1)
        self.bn3 = nn.BatchNorm2d(ndf*4)

        self.e4 = nn.Conv2d(ndf*4, ndf*8, 4, 2, 1)
        self.bn4 = nn.BatchNorm2d(ndf*8)

        self.e5 = nn.Conv2d(ndf*8, ndf*8, 4, 2, 1)
        self.bn5 = nn.BatchNorm2d(ndf*8)

        # two heads: mean and log-variance of the approximate posterior
        self.fc1 = nn.Linear(ndf*8*4*4, latent_variable_size)
        self.fc2 = nn.Linear(ndf*8*4*4, latent_variable_size)

        # decoder: a linear layer back to a 4x4 feature map, then five rounds of
        # nearest-neighbor upsampling + replication padding + 3x3 convolution
        self.d1 = nn.Linear(latent_variable_size, ngf*8*2*4*4)

        self.up1 = nn.UpsamplingNearest2d(scale_factor=2)
        self.pd1 = nn.ReplicationPad2d(1)
        self.d2 = nn.Conv2d(ngf*8*2, ngf*8, 3, 1)
        self.bn6 = nn.BatchNorm2d(ngf*8, 1.e-3)

        self.up2 = nn.UpsamplingNearest2d(scale_factor=2)
        self.pd2 = nn.ReplicationPad2d(1)
        self.d3 = nn.Conv2d(ngf*8, ngf*4, 3, 1)
        self.bn7 = nn.BatchNorm2d(ngf*4, 1.e-3)

        self.up3 = nn.UpsamplingNearest2d(scale_factor=2)
        self.pd3 = nn.ReplicationPad2d(1)
        self.d4 = nn.Conv2d(ngf*4, ngf*2, 3, 1)
        self.bn8 = nn.BatchNorm2d(ngf*2, 1.e-3)

        self.up4 = nn.UpsamplingNearest2d(scale_factor=2)
        self.pd4 = nn.ReplicationPad2d(1)
        self.d5 = nn.Conv2d(ngf*2, ngf, 3, 1)
        self.bn9 = nn.BatchNorm2d(ngf, 1.e-3)

        self.up5 = nn.UpsamplingNearest2d(scale_factor=2)
        self.pd5 = nn.ReplicationPad2d(1)
        self.d6 = nn.Conv2d(ngf, nc, 3, 1)

        self.leakyrelu = nn.LeakyReLU(0.2)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def encode(self, x):
        h1 = self.leakyrelu(self.bn1(self.e1(x)))
        h2 = self.leakyrelu(self.bn2(self.e2(h1)))
        h3 = self.leakyrelu(self.bn3(self.e3(h2)))
        h4 = self.leakyrelu(self.bn4(self.e4(h3)))
        h5 = self.leakyrelu(self.bn5(self.e5(h4)))
        h5 = h5.view(-1, self.ndf*8*4*4)

        return self.fc1(h5), self.fc2(h5)

    def reparametrize(self, mu, logvar):
        # Reparametrization trick: z = mu + sigma * eps, with eps ~ N(0, I).
        std = logvar.mul(0.5).exp_()
        if args.cuda:
            eps = torch.cuda.FloatTensor(std.size()).normal_()
        else:
            eps = torch.FloatTensor(std.size()).normal_()
        eps = Variable(eps)
        return eps.mul(std).add_(mu)

    def decode(self, z):
        h1 = self.relu(self.d1(z))
        h1 = h1.view(-1, self.ngf*8*2, 4, 4)
        h2 = self.leakyrelu(self.bn6(self.d2(self.pd1(self.up1(h1)))))
        h3 = self.leakyrelu(self.bn7(self.d3(self.pd2(self.up2(h2)))))
        h4 = self.leakyrelu(self.bn8(self.d4(self.pd3(self.up3(h3)))))
        h5 = self.leakyrelu(self.bn9(self.d5(self.pd4(self.up4(h4)))))

        return self.sigmoid(self.d6(self.pd5(self.up5(h5))))

    def get_latent_var(self, x):
        # self.ndf and self.ngf double as the input height and width (both 128).
        mu, logvar = self.encode(x.view(-1, self.nc, self.ndf, self.ngf))
        z = self.reparametrize(mu, logvar)
        return z

    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, self.nc, self.ndf, self.ngf))
        z = self.reparametrize(mu, logvar)
        res = self.decode(z)
        return res, mu, logvar


model = VAE(nc=3, ngf=128, ndf=128, latent_variable_size=500)

if args.cuda:
    model.cuda()

reconstruction_function = nn.BCELoss()
reconstruction_function.size_average = False  # sum over elements instead of averaging
def loss_function(recon_x, x, mu, logvar):
    BCE = reconstruction_function(recon_x, x)

    # KL divergence, https://arxiv.org/abs/1312.6114 (Appendix B):
    # -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
    KLD = torch.sum(KLD_element).mul_(-0.5)

    return BCE + KLD

optimizer = optim.Adam(model.parameters(), lr=1e-4)

def train(epoch):
    model.train()
    train_loss = 0
    for batch_idx in train_loader:
        data = load_batch(batch_idx, True)
        data = Variable(data)
        if args.cuda:
            data = data.cuda()
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.data[0]
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), (len(train_loader)*128),
                100. * batch_idx / len(train_loader),
                loss.data[0] / len(data)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / (len(train_loader)*128)))
    return train_loss / (len(train_loader)*128)

def test(epoch):
    model.eval()
    test_loss = 0
    for batch_idx in test_loader:
        data = load_batch(batch_idx, False)
        data = Variable(data, volatile=True)
        if args.cuda:
            data = data.cuda()
        recon_batch, mu, logvar = model(data)
        test_loss += loss_function(recon_batch, data, mu, logvar).data[0]

        # Each iteration overwrites the files, so only the last batch's grids survive.
        torchvision.utils.save_image(data.data, '../imgs/Epoch_{}_data.jpg'.format(epoch), nrow=8, padding=2)
        torchvision.utils.save_image(recon_batch.data, '../imgs/Epoch_{}_recon.jpg'.format(epoch), nrow=8, padding=2)

    test_loss /= (len(test_loader)*128)
    print('====> Test set loss: {:.4f}'.format(test_loss))
    return test_loss


def perform_latent_space_arithmatics(items):  # input is a list of 3-tuples of images, e.g. [(a1,b1,c1), (a2,b2,c2)]
    load_last_model()
    model.eval()
    data = [im for item in items for im in item]
    data = [totensor(i) for i in data]
    data = torch.stack(data, dim=0)
    data = Variable(data, volatile=True)
    if args.cuda:
        data = data.cuda()
    z = model.get_latent_var(data.view(-1, model.nc, model.ndf, model.ngf))
    it = iter(z.split(1))
    z = zip(it, it, it)  # regroup the flat batch of codes into (za, zb, zc) triples
    zs = []
    numsample = 11
    for i,j,k in z:
        for factor in np.linspace(0,1,numsample):
            zs.append((i-j)*factor+k)  # blend the difference (za - zb) into zc
    z = torch.cat(zs, 0)
    recon = model.decode(z)

    it1 = iter(data.split(1))
    it2 = [iter(recon.split(1))]*numsample
    result = zip(it1, it1, it1, *it2)  # one row per triple: a, b, c, then the 11 decoded blends
    result = [im for item in result for im in item]

    result = torch.cat(result, 0)
    torchvision.utils.save_image(result.data, '../imgs/vec_math.jpg', nrow=3+numsample, padding=2)


def latent_space_transition(items):  # input is a list of image tuples; the last element of each tuple is dropped
    load_last_model()
    model.eval()
    data = [im for item in items for im in item[:-1]]
    data = [totensor(i) for i in data]
    data = torch.stack(data, dim=0)
    data = Variable(data, volatile=True)
    if args.cuda:
        data = data.cuda()
    z = model.get_latent_var(data.view(-1, model.nc, model.ndf, model.ngf))
    it = iter(z.split(1))
    z = zip(it, it)  # regroup the flat batch of codes into (z1, z2) pairs
    zs = []
    numsample = 11
    for i,j in z:
        for factor in np.linspace(0,1,numsample):
            zs.append(i+(j-i)*factor)  # walk from z1 to z2
    z = torch.cat(zs, 0)
    recon = model.decode(z)

    it1 = iter(data.split(1))
    it2 = [iter(recon.split(1))]*numsample
    result = zip(it1, it1, *it2)  # one row per pair: the two inputs, then the 11 interpolants
    result = [im for item in result for im in item]

    result = torch.cat(result, 0)
    torchvision.utils.save_image(result.data, '../imgs/trans.jpg', nrow=2+numsample, padding=2)


def rand_faces(num=5):
    load_last_model()
    model.eval()
    # Sample latent codes from the standard normal prior and decode them.
    z = torch.randn(num*num, model.latent_variable_size)
    z = Variable(z, volatile=True)
    if args.cuda:
        z = z.cuda()
    recon = model.decode(z)
    torchvision.utils.save_image(recon.data, '../imgs/rand_faces.jpg', nrow=num, padding=2)

def load_last_model():
    # Checkpoints are named Epoch_<n>_Train_loss_<x>_Test_loss_<y>.pth, so the
    # epoch number is the second '_'-separated field.
    models = glob('../models/*.pth')
    model_ids = [(int(f.split('_')[1]), f) for f in models]
    start_epoch, last_cp = max(model_ids, key=lambda item: item[0])
    print('Last checkpoint: ', last_cp)
    model.load_state_dict(torch.load(last_cp))
    return start_epoch, last_cp

def resume_training():
    start_epoch, _ = load_last_model()

    for epoch in range(start_epoch + 1, start_epoch + args.epochs + 1):
        train_loss = train(epoch)
        test_loss = test(epoch)
        torch.save(model.state_dict(), '../models/Epoch_{}_Train_loss_{:.4f}_Test_loss_{:.4f}.pth'.format(epoch, train_loss, test_loss))

def last_model_to_cpu():
    _, last_cp = load_last_model()
    model.cpu()
    torch.save(model.state_dict(), '../models/cpu_'+last_cp.split('/')[-1])

if __name__ == '__main__':
    resume_training()
    # last_model_to_cpu()
    # load_last_model()
    # rand_faces(10)
    # da = load_pickle(test_loader[0])
    # da = da[:120]
    # it = iter(da)
    # l = zip(it, it, it)
    # # latent_space_transition(l)
    # perform_latent_space_arithmatics(l)
--------------------------------------------------------------------------------