├── imgs
│   ├── trans.jpg
│   ├── vec_math.jpg
│   ├── Epoch_28_data.jpg
│   └── Epoch_28_recon.jpg
├── README.md
└── src
    ├── util.py
    └── vanila_vae.py
/imgs/trans.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bhpfelix/Variational-Autoencoder-PyTorch/HEAD/imgs/trans.jpg
--------------------------------------------------------------------------------
/imgs/vec_math.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bhpfelix/Variational-Autoencoder-PyTorch/HEAD/imgs/vec_math.jpg
--------------------------------------------------------------------------------
/imgs/Epoch_28_data.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bhpfelix/Variational-Autoencoder-PyTorch/HEAD/imgs/Epoch_28_data.jpg
--------------------------------------------------------------------------------
/imgs/Epoch_28_recon.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bhpfelix/Variational-Autoencoder-PyTorch/HEAD/imgs/Epoch_28_recon.jpg
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Variational Autoencoder for face image generation in PyTorch

A variational autoencoder for face image generation, implemented in PyTorch and trained on a combination of the CelebA, FaceScrub, and JAFFE datasets.

Based on the Deep Feature Consistent Variational Autoencoder (paper: https://arxiv.org/abs/1610.00291 | reference implementation: https://github.com/houxianxu/DFC-VAE).

TODO: Add DFC-VAE implementation

Pretrained model available at https://drive.google.com/open?id=0B4y-iigc5IzcTlJfYlJyaF9ndlU

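To sample faces from the pretrained model, the sketch below mirrors `rand_faces()` in `src/vanila_vae.py`. It assumes it is run from `src/`, uses the repo's PyTorch 0.3-era API, and the checkpoint filename is a placeholder for wherever the download is saved:

```python
import torch
from torch.autograd import Variable
import torchvision
from vanila_vae import VAE  # importing this module also runs its argparse setup

model = VAE(nc=3, ngf=128, ndf=128, latent_variable_size=500)
model.load_state_dict(torch.load('../models/pretrained.pth'))  # placeholder path
model.eval()

# Draw 25 latent codes from the prior N(0, I) and decode them into faces.
z = Variable(torch.randn(25, model.latent_variable_size), volatile=True)
faces = model.decode(z)  # (25, 3, 128, 128) tensor with values in [0, 1]
torchvision.utils.save_image(faces.data, 'samples.jpg', nrow=5, padding=2)
```
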
## Results
Original faces vs. reconstructed faces:

![Original faces](imgs/Epoch_28_data.jpg)

![Reconstructed faces](imgs/Epoch_28_recon.jpg)

Linear interpolation between two face images (a sketch of the idea in code follows the figure):

![Linear interpolation in latent space](imgs/trans.jpg)

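The interpolation happens in latent space, as in `latent_space_transition()` in `src/vanila_vae.py`: encode both faces, blend the two codes, and decode each blend. A minimal sketch of the blending step, assuming `z_a` and `z_b` are latent codes returned by `model.get_latent_var()`:

```python
import numpy as np

def interpolate(z_a, z_b, steps=11):
    # Each blended code decodes to one frame of the morph from a to b.
    return [z_a + (z_b - z_a) * t for t in np.linspace(0, 1, steps)]
```
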
Vector arithmetic in latent space (sketched in code below the figure):

![Vector arithmetic in latent space](imgs/vec_math.jpg)

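This mirrors `perform_latent_space_arithmatics()` in `src/vanila_vae.py`: given latent codes for three faces, walk from `z_c` along the attribute direction `z_a - z_b`. A minimal sketch under the same assumptions as above:

```python
import numpy as np

def vector_arithmetic(z_a, z_b, z_c, steps=11):
    # Decoding each code shows c gradually acquiring the attribute (a - b).
    return [z_c + (z_a - z_b) * t for t in np.linspace(0, 1, steps)]
```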
--------------------------------------------------------------------------------
/src/util.py:
--------------------------------------------------------------------------------
import pickle as pk
import sys


############################################################
### IO
############################################################
def disp_to_term(msg):
    # Overwrite the current terminal line, for compact progress output.
    sys.stdout.write(msg + '\r')
    sys.stdout.flush()

def load_pickle(filename):
    # Binary mode is required for pickle files; text mode was the source of
    # the sporadic ValueError that the old retry workaround papered over.
    try:
        with open(filename, 'rb') as p:
            return pk.load(p)
    except IOError:
        print('Pickle file cannot be opened.')
        return None

def save_pickle(data_object, filename):
    with open(filename, 'wb') as pickle_file:
        pk.dump(data_object, pickle_file)
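# Example round trip (hypothetical path):
#   save_pickle({'epoch': 28, 'loss': 0.1234}, '../models/state.pkl')
#   state = load_pickle('../models/state.pkl')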
--------------------------------------------------------------------------------
/src/vanila_vae.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import argparse
import torch
import torch.utils.data
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import time
from glob import glob
from util import *
import numpy as np
from PIL import Image

parser = argparse.ArgumentParser(description='PyTorch VAE')
parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                    help='input batch size for training (default: 128; note that '
                         'load_batch below currently hard-codes 128 images per batch)')
parser.add_argument('--epochs', type=int, default=20, metavar='N',
                    help='number of epochs to train (default: 20)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=1, metavar='N',
                    help='how many batches to wait before logging training status')

args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

# kwargs is kept for DataLoader compatibility; the range-based "loaders" below
# are simply batch indices that load_batch maps to images on disk.
kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
train_loader = range(2080)
test_loader = range(40)

totensor = transforms.ToTensor()

def load_batch(batch_idx, istrain):
    if istrain:
        template = '../data/train/%s.jpg'
    else:
        template = '../data/test/%s.jpg'
    # Filenames are zero-padded six-digit indices, e.g. 000042.jpg.
    ids = [str(batch_idx*128 + i).zfill(6) for i in range(128)]
    data = [totensor(np.array(Image.open(template % idx))) for idx in ids]
    return torch.stack(data, dim=0)
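
# Assumed on-disk layout, inferred from the filename template and the
# batch-index "loaders" above (2080 train batches and 40 test batches of 128):
#   ../data/train/000000.jpg ... ../data/train/266239.jpg
#   ../data/test/000000.jpg  ... ../data/test/005119.jpg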


class VAE(nn.Module):
    def __init__(self, nc, ngf, ndf, latent_variable_size):
        super(VAE, self).__init__()

        self.nc = nc
        self.ngf = ngf
        self.ndf = ndf
        self.latent_variable_size = latent_variable_size

        # encoder: five stride-2 convolutions halve a 128x128 input down to 4x4
        self.e1 = nn.Conv2d(nc, ndf, 4, 2, 1)
        self.bn1 = nn.BatchNorm2d(ndf)

        self.e2 = nn.Conv2d(ndf, ndf*2, 4, 2, 1)
        self.bn2 = nn.BatchNorm2d(ndf*2)

        self.e3 = nn.Conv2d(ndf*2, ndf*4, 4, 2, 1)
        self.bn3 = nn.BatchNorm2d(ndf*4)

        self.e4 = nn.Conv2d(ndf*4, ndf*8, 4, 2, 1)
        self.bn4 = nn.BatchNorm2d(ndf*8)

        self.e5 = nn.Conv2d(ndf*8, ndf*8, 4, 2, 1)
        self.bn5 = nn.BatchNorm2d(ndf*8)

        # fc1 predicts mu, fc2 predicts log-variance
        self.fc1 = nn.Linear(ndf*8*4*4, latent_variable_size)
        self.fc2 = nn.Linear(ndf*8*4*4, latent_variable_size)

        # decoder: five (upsample, pad, conv) stages invert the encoder
        self.d1 = nn.Linear(latent_variable_size, ngf*8*2*4*4)

        self.up1 = nn.UpsamplingNearest2d(scale_factor=2)
        self.pd1 = nn.ReplicationPad2d(1)
        self.d2 = nn.Conv2d(ngf*8*2, ngf*8, 3, 1)
        self.bn6 = nn.BatchNorm2d(ngf*8, 1.e-3)

        self.up2 = nn.UpsamplingNearest2d(scale_factor=2)
        self.pd2 = nn.ReplicationPad2d(1)
        self.d3 = nn.Conv2d(ngf*8, ngf*4, 3, 1)
        self.bn7 = nn.BatchNorm2d(ngf*4, 1.e-3)

        self.up3 = nn.UpsamplingNearest2d(scale_factor=2)
        self.pd3 = nn.ReplicationPad2d(1)
        self.d4 = nn.Conv2d(ngf*4, ngf*2, 3, 1)
        self.bn8 = nn.BatchNorm2d(ngf*2, 1.e-3)

        self.up4 = nn.UpsamplingNearest2d(scale_factor=2)
        self.pd4 = nn.ReplicationPad2d(1)
        self.d5 = nn.Conv2d(ngf*2, ngf, 3, 1)
        self.bn9 = nn.BatchNorm2d(ngf, 1.e-3)

        self.up5 = nn.UpsamplingNearest2d(scale_factor=2)
        self.pd5 = nn.ReplicationPad2d(1)
        self.d6 = nn.Conv2d(ngf, nc, 3, 1)

        self.leakyrelu = nn.LeakyReLU(0.2)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def encode(self, x):
        h1 = self.leakyrelu(self.bn1(self.e1(x)))
        h2 = self.leakyrelu(self.bn2(self.e2(h1)))
        h3 = self.leakyrelu(self.bn3(self.e3(h2)))
        h4 = self.leakyrelu(self.bn4(self.e4(h3)))
        h5 = self.leakyrelu(self.bn5(self.e5(h4)))
        h5 = h5.view(-1, self.ndf*8*4*4)

        return self.fc1(h5), self.fc2(h5)

    def reparametrize(self, mu, logvar):
        # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I),
        # so gradients can flow back through mu and logvar.
        std = logvar.mul(0.5).exp_()
        if args.cuda:
            eps = torch.cuda.FloatTensor(std.size()).normal_()
        else:
            eps = torch.FloatTensor(std.size()).normal_()
        eps = Variable(eps)
        return eps.mul(std).add_(mu)

    def decode(self, z):
        h1 = self.relu(self.d1(z))
        h1 = h1.view(-1, self.ngf*8*2, 4, 4)
        h2 = self.leakyrelu(self.bn6(self.d2(self.pd1(self.up1(h1)))))
        h3 = self.leakyrelu(self.bn7(self.d3(self.pd2(self.up2(h2)))))
        h4 = self.leakyrelu(self.bn8(self.d4(self.pd3(self.up3(h3)))))
        h5 = self.leakyrelu(self.bn9(self.d5(self.pd4(self.up4(h4)))))

        return self.sigmoid(self.d6(self.pd5(self.up5(h5))))

    def get_latent_var(self, x):
        # Note: ndf and ngf double as the input height and width (both 128).
        mu, logvar = self.encode(x.view(-1, self.nc, self.ndf, self.ngf))
        z = self.reparametrize(mu, logvar)
        return z

    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, self.nc, self.ndf, self.ngf))
        z = self.reparametrize(mu, logvar)
        res = self.decode(z)
        return res, mu, logvar
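
# Shape walk-through for the default configuration (nc=3, ndf=ngf=128,
# latent_variable_size=500), as a quick sanity check:
#   input  (N, 3, 128, 128)
#   e1..e5 halve the spatial size each time: 64 -> 32 -> 16 -> 8 -> 4,
#          ending at (N, 1024, 4, 4), i.e. ndf*8 = 1024 channels
#   flatten to (N, 16384), then fc1/fc2 give mu, logvar of size (N, 500)
#   d1 maps z back to (N, 2048*4*4); the five upsample stages return
#          to (N, 3, 128, 128), squashed into [0, 1] by the sigmoid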


model = VAE(nc=3, ngf=128, ndf=128, latent_variable_size=500)

if args.cuda:
    model.cuda()

# Summed (not averaged) binary cross-entropy, so the reconstruction term is on
# the same scale as the summed KL divergence below.
reconstruction_function = nn.BCELoss(size_average=False)

def loss_function(recon_x, x, mu, logvar):
    BCE = reconstruction_function(recon_x, x)

    # https://arxiv.org/abs/1312.6114 (Appendix B)
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
    KLD = torch.sum(KLD_element).mul_(-0.5)

    return BCE + KLD
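
# Sanity note on the KL term (standard result; Kingma & Welling 2013, Appendix B):
#   KL( N(mu, sigma^2) || N(0, 1) ) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
# With logvar = log(sigma^2), each element of KLD_element above equals
#   1 + logvar - mu^2 - exp(logvar),
# so KLD = -0.5 * sum(KLD_element) matches the formula exactly.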

optimizer = optim.Adam(model.parameters(), lr=1e-4)

def train(epoch):
    model.train()
    train_loss = 0
    for batch_idx in train_loader:
        data = load_batch(batch_idx, True)
        data = Variable(data)
        if args.cuda:
            data = data.cuda()
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.data[0]
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), (len(train_loader)*128),
                100. * batch_idx / len(train_loader),
                loss.data[0] / len(data)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / (len(train_loader)*128)))
    return train_loss / (len(train_loader)*128)

def test(epoch):
    model.eval()
    test_loss = 0
    for batch_idx in test_loader:
        data = load_batch(batch_idx, False)
        data = Variable(data, volatile=True)
        if args.cuda:
            data = data.cuda()
        recon_batch, mu, logvar = model(data)
        test_loss += loss_function(recon_batch, data, mu, logvar).data[0]

        # Save the current batch and its reconstruction for visual inspection
        # (each batch overwrites the previous pair of images for this epoch).
        torchvision.utils.save_image(data.data, '../imgs/Epoch_{}_data.jpg'.format(epoch), nrow=8, padding=2)
        torchvision.utils.save_image(recon_batch.data, '../imgs/Epoch_{}_recon.jpg'.format(epoch), nrow=8, padding=2)

    test_loss /= (len(test_loader)*128)
    print('====> Test set loss: {:.4f}'.format(test_loss))
    return test_loss


def perform_latent_space_arithmatics(items):  # input is a list of 3-tuples [(a1, b1, c1), (a2, b2, c2), ...]
    load_last_model()
    model.eval()
    data = [im for item in items for im in item]
    data = [totensor(i) for i in data]
    data = torch.stack(data, dim=0)
    data = Variable(data, volatile=True)
    if args.cuda:
        data = data.cuda()
    z = model.get_latent_var(data.view(-1, model.nc, model.ndf, model.ngf))
    it = iter(z.split(1))
    z = zip(it, it, it)  # regroup the flat latent codes back into (a, b, c) triples
    zs = []
    numsample = 11
    for i, j, k in z:
        # Walk from c along the attribute direction (a - b).
        for factor in np.linspace(0, 1, numsample):
            zs.append((i - j) * factor + k)
    z = torch.cat(zs, 0)
    recon = model.decode(z)

    it1 = iter(data.split(1))
    it2 = [iter(recon.split(1))] * numsample
    result = zip(it1, it1, it1, *it2)  # one row: a, b, c, then the 11 decoded steps
    result = [im for item in result for im in item]

    result = torch.cat(result, 0)
    torchvision.utils.save_image(result.data, '../imgs/vec_math.jpg', nrow=3+numsample, padding=2)


def latent_space_transition(items):  # input is a list of tuples; the last image of each is ignored, leaving pairs (a, b)
    load_last_model()
    model.eval()
    data = [im for item in items for im in item[:-1]]
    data = [totensor(i) for i in data]
    data = torch.stack(data, dim=0)
    data = Variable(data, volatile=True)
    if args.cuda:
        data = data.cuda()
    z = model.get_latent_var(data.view(-1, model.nc, model.ndf, model.ngf))
    it = iter(z.split(1))
    z = zip(it, it)  # regroup the flat latent codes back into (a, b) pairs
    zs = []
    numsample = 11
    for i, j in z:
        # Interpolate linearly from a to b in latent space.
        for factor in np.linspace(0, 1, numsample):
            zs.append(i + (j - i) * factor)
    z = torch.cat(zs, 0)
    recon = model.decode(z)

    it1 = iter(data.split(1))
    it2 = [iter(recon.split(1))] * numsample
    result = zip(it1, it1, *it2)  # one row: a, b, then the 11 decoded steps
    result = [im for item in result for im in item]

    result = torch.cat(result, 0)
    torchvision.utils.save_image(result.data, '../imgs/trans.jpg', nrow=2+numsample, padding=2)


def rand_faces(num=5):
    # Sample num*num latent codes from the prior N(0, I) and decode them.
    load_last_model()
    model.eval()
    z = torch.randn(num*num, model.latent_variable_size)
    z = Variable(z, volatile=True)
    if args.cuda:
        z = z.cuda()
    recon = model.decode(z)
    torchvision.utils.save_image(recon.data, '../imgs/rand_faces.jpg', nrow=num, padding=2)

def load_last_model():
    # Checkpoints are named Epoch_<n>_Train_loss_<x>_Test_loss_<y>.pth,
    # so the epoch number is the second '_'-separated field.
    models = glob('../models/*.pth')
    model_ids = [(int(f.split('_')[1]), f) for f in models]
    start_epoch, last_cp = max(model_ids, key=lambda item: item[0])
    print('Last checkpoint: ', last_cp)
    model.load_state_dict(torch.load(last_cp))
    return start_epoch, last_cp

def resume_training():
    start_epoch, _ = load_last_model()

    for epoch in range(start_epoch + 1, start_epoch + args.epochs + 1):
        train_loss = train(epoch)
        test_loss = test(epoch)
        torch.save(model.state_dict(), '../models/Epoch_{}_Train_loss_{:.4f}_Test_loss_{:.4f}.pth'.format(epoch, train_loss, test_loss))

def last_model_to_cpu():
    _, last_cp = load_last_model()
    model.cpu()
    torch.save(model.state_dict(), '../models/cpu_'+last_cp.split('/')[-1])

if __name__ == '__main__':
    resume_training()
    # last_model_to_cpu()
    # load_last_model()
    # rand_faces(10)
    # da = load_pickle(test_loader[0])
    # da = da[:120]
    # it = iter(da)
    # l = zip(it, it, it)
    # # latent_space_transition(l)
    # perform_latent_space_arithmatics(l)
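
# Typical invocation (a sketch; assumes the ../data layout noted above and at
# least one existing checkpoint in ../models, since resume_training() always
# resumes from the latest one):
#   python vanila_vae.py --epochs 20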
--------------------------------------------------------------------------------