├── .gitignore
├── .idea
│   ├── IntroVAE-Pytorch.iml
│   ├── inspectionProfiles
│   │   └── Project_Default.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── vcs.xml
│   └── workspace.xml
├── README.md
├── __pycache__
│   ├── celeba.cpython-36.pyc
│   ├── model.cpython-36.pyc
│   └── utils.cpython-36.pyc
├── assets
│   ├── heart.gif
│   ├── train.png
│   └── xr_750000.png
├── celeba.py
├── celebahd.py
├── lsun.py
├── main.py
├── model.py
├── nohup.out
├── res
│   ├── xp_20000.jpg
│   └── xr_20000.jpg
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | res
2 | ckpt
3 | .idea
4 | __pycache__
5 |
6 |
--------------------------------------------------------------------------------
/.idea/IntroVAE-Pytorch.iml:
--------------------------------------------------------------------------------
(JetBrains IDE module file; XML content not captured in this dump)
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
(JetBrains inspection profile; XML content not captured in this dump)
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
(JetBrains IDE settings; XML content not captured in this dump)
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
(JetBrains IDE module list; XML content not captured in this dump)
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
(JetBrains VCS settings; XML content not captured in this dump)
--------------------------------------------------------------------------------
/.idea/workspace.xml:
--------------------------------------------------------------------------------
(JetBrains IDE workspace file; XML content not captured in this dump)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # IntroVAE-Pytorch
2 |
3 | PyTorch implementation of the NeurIPS 2018 paper:
4 | [IntroVAE: Introspective Variational Autoencoders for Photographic Image Synthesis](https://arxiv.org/abs/1807.06358).
5 |
6 | This repository contains a basic implementation of IntroVAE. Since no official implementation has been released, some hyperparameters had to be guessed, and the results do not reach the performance reported in the paper.
7 |
8 | 
9 |
10 | # HowTo
11 | 1. Download the [CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) dataset and extract it as:
12 | ```
13 | /home/i/dbs/
14 | └── img_align_celeba    # only ONE folder in this directory
15 |     ├── 050939.jpg
16 |     ├── 050940.jpg
17 |     ├── 050941.jpg
18 |     ├── 050942.jpg
19 |     ├── 050943.jpg
20 |     ├── 050944.jpg
21 |     ├── 050945.jpg
22 | ```
23 |
24 | Modify `/home/i/dbs/` to your own path, making sure that `/home/i/dbs/` contains only ONE folder, since we use the
25 | `torchvision.datasets.ImageFolder` API to load the dataset.
26 | ```python
27 | argparser.add_argument('--root', type=str, default='/home/i/dbs/',
28 | help='root/label/*.jpg')
29 | ```
30 |
31 | 2. Run `python main.py --epoch 750000` to train from scratch, or `python main.py --resume '' --epoch 1000000` to resume training from the latest checkpoint; pass a checkpoint path to `--resume` to load a specific one.
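
Once training has saved a checkpoint into `ckpt/`, images can be sampled from the decoder alone. Below is a minimal sketch, not part of the repo's CLI; it assumes the default arguments from `main.py`, and the checkpoint filename is hypothetical; use whatever actually exists in `ckpt/`.

```python
# minimal sampling sketch; the checkpoint name below is hypothetical
import argparse
import torch
from torchvision.utils import save_image

from celeba import unnorm_
from model import IntroVAE

# mirror main.py's default hyperparameters
args = argparse.Namespace(imgsz=128, z_dim=256, margin=110,
                          alpha=0.25, beta=0.5, gamma=1., lr=0.0002)
vae = IntroVAE(args)
vae.load_state_dict(torch.load('ckpt/introvae_750000.mdl', map_location='cpu'))
vae.eval()

with torch.no_grad():
    zp = torch.randn(8, args.z_dim)              # latent codes from N(0, I)
    xp = vae.output_activation(vae.decoder(zp))  # decode; tanh output
unnorm_(xp)                                      # undo the training-time normalization
save_image(xp, 'samples.jpg', nrow=4)
```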
32 |
33 |
34 | # Training
35 |
36 | Only tested on the CelebA 128x128 experiment.
37 |
38 | - training curves
39 | 
40 |
41 | - sampled x
42 | 
43 |
44 |
--------------------------------------------------------------------------------
/__pycache__/celeba.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dragen1860/IntroVAE-Pytorch/fd42b940eeaf9308b4547a6d11a2c5e6eecee052/__pycache__/celeba.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/model.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dragen1860/IntroVAE-Pytorch/fd42b940eeaf9308b4547a6d11a2c5e6eecee052/__pycache__/model.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dragen1860/IntroVAE-Pytorch/fd42b940eeaf9308b4547a6d11a2c5e6eecee052/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/assets/heart.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dragen1860/IntroVAE-Pytorch/fd42b940eeaf9308b4547a6d11a2c5e6eecee052/assets/heart.gif
--------------------------------------------------------------------------------
/assets/train.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dragen1860/IntroVAE-Pytorch/fd42b940eeaf9308b4547a6d11a2c5e6eecee052/assets/train.png
--------------------------------------------------------------------------------
/assets/xr_750000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dragen1860/IntroVAE-Pytorch/fd42b940eeaf9308b4547a6d11a2c5e6eecee052/assets/xr_750000.png
--------------------------------------------------------------------------------
/celeba.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision import datasets, transforms
3 |
4 |
5 | class UnNormalize:
6 | def __init__(self, mean, std):
7 | self.mean = mean
8 | self.std = std
9 |
10 | def __call__(self, tensor):
11 | """
12 | Args:
13 | tensor (Tensor): Tensor image of size (B, C, H, W) to be normalized.
14 | Returns:
15 | Tensor: Normalized image.
16 | """
17 | with torch.no_grad():
18 |
19 | for i, (m, s) in enumerate(zip(self.mean, self.std)):
20 | tensor[:, i,...].mul_(s).add_(m)
21 | # The normalize code -> t.sub_(m).div_(s)
22 | return tensor
23 |
24 | def load_celeba(root, imgsz):
25 |
26 | transform = transforms.Compose([
27 | # transforms.RandomSizedCrop(224),
28 | # transforms.RandomHorizontalFlip(),
29 | transforms.Resize([imgsz, imgsz]),
30 | transforms.ToTensor(),
31 | transforms.Normalize(mean=[0.485, 0.456, 0.406],
32 | std=[0.229, 0.224, 0.225])
33 | ])
34 |
35 | db = datasets.ImageFolder(root, transform=transform)
36 |
37 | return db
38 |
39 |
40 | def unnorm_(*args):
41 | """
42 | conduct reverse normalize on each tensor in-place
43 | :param args:
44 | :return:
45 | """
46 | net = UnNormalize(mean=[0.485, 0.456, 0.406],
47 | std=[0.229, 0.224, 0.225])
48 | for img in args:
49 | net(img)
50 |
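# usage sketch (hypothetical snippet; the path is the repo default from main.py):
#   db = load_celeba('/home/i/dbs/', 128)   # ImageFolder over root/<one-folder>/*.jpg
#   x, _ = db[0]                            # normalized tensor, [3, 128, 128]
#   unnorm_(x.unsqueeze(0))                 # reverses normalization in place; expects [B, C, H, W]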
--------------------------------------------------------------------------------
/celebahd.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dragen1860/IntroVAE-Pytorch/fd42b940eeaf9308b4547a6d11a2c5e6eecee052/celebahd.py
--------------------------------------------------------------------------------
/lsun.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dragen1860/IntroVAE-Pytorch/fd42b940eeaf9308b4547a6d11a2c5e6eecee052/lsun.py
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os, glob
2 | import torch
3 | import numpy as np
4 | from torch.utils.data import DataLoader
5 | import argparse
6 | from torchvision.utils import save_image
7 |
8 | from celeba import load_celeba, unnorm_
9 | from model import IntroVAE
10 | import visdom
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 | def main(args):
26 | print(args)
27 |
28 | torch.manual_seed(22)
29 | np.random.seed(22)
30 |
31 | viz = visdom.Visdom()
32 |
33 | db = load_celeba(args.root, args.imgsz)
34 | db_loader = DataLoader(db, batch_size=args.batchsz, shuffle=True, num_workers=4, pin_memory=True)
35 |
36 | device = torch.device('cuda')
37 | vae = IntroVAE(args).to(device)
38 | params = filter(lambda x: x.requires_grad, vae.parameters())
39 | num = sum(map(lambda x: np.prod(x.shape), params))
40 | print('Total trainable tensors:', num)
41 | # print(vae)
42 |
43 | for path in ['res', 'ckpt']:
44 | if not os.path.exists(path):
45 | os.mkdir(path)
46 | print('mkdir:', path)
47 |
48 |
49 | epoch_start = 0
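    # --resume semantics: None -> train from scratch; '' (empty string) -> load the
    # latest checkpoint in ckpt/ by modification time; a path -> load that checkpoint.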
50 | if args.resume is not None and args.resume != 'None':
51 |         if args.resume == '':  # empty string means: load the latest checkpoint
52 | ckpts = glob.glob('ckpt/*_*.mdl')
53 | if not ckpts:
54 |                 print('no available ckpt found.')
55 | raise FileNotFoundError
56 | ckpts = sorted(ckpts, key=os.path.getmtime)
57 | # print(ckpts)
58 | ckpt = ckpts[-1]
59 | epoch_start = int(ckpt.split('.')[-2].split('_')[-1])
60 | vae.load_state_dict(torch.load(ckpt))
61 | print('load latest ckpt from:', ckpt, epoch_start)
62 | else: # load specific ckpt
63 | if os.path.isfile(args.resume):
64 | vae.load_state_dict(torch.load(args.resume))
65 | print('load ckpt from:', args.resume, epoch_start)
66 | else:
67 | raise FileNotFoundError
68 | else:
69 | print('pre-training and training from scratch...')
70 |
71 | viz.line([[0 for _ in range(6)]], [epoch_start], win='train', opts=dict(title='training',
72 | legend=['b*ae', 'a*inf(x)', 'a*inf(xr)', 'a*inf(xp)', 'a*gen(xr)', 'a*gen(xp)']))
73 | viz.line([0], [epoch_start], win='encoder_loss', opts=dict(title='encoder_loss'))
74 | viz.line([0], [epoch_start], win='decoder_loss', opts=dict(title='decoder_loss'))
75 | viz.line([0], [epoch_start], win='ae_loss', opts=dict(title='ae_loss'))
76 | viz.line([0], [epoch_start], win='reg_ae', opts=dict(title='reg_ae'))
77 | viz.line([0], [epoch_start], win='encoder_adv', opts=dict(title='encoder_adv'))
78 | viz.line([0], [epoch_start], win='decoder_adv', opts=dict(title='decoder_adv'))
79 |
80 |
81 |     # pre-train as a plain VAE (alpha=0) for roughly one epoch:
82 |     # number of minibatch iterations = 1.0 * (0.9 * dataset size) // batchsz
83 |     pretraining_epoch = int(1. * (0.9 * len(db))) // args.batchsz
84 | if pretraining_epoch - epoch_start > 0:
85 |         print('>>pre-training for %d iterations,' % pretraining_epoch, 'already completed:', epoch_start)
86 |         vae.set_alpha_beta_gamma(0, args.beta, args.gamma)
87 |
88 |     # run at most 2 passes over the dataset during pre-training.
89 | for _ in range(2):
90 | db_loader = DataLoader(db, batch_size=args.batchsz, shuffle=True, num_workers=4, pin_memory=True)
91 | print('epoch\tvae\tenc-adv\t\tdec-adv\t\tae\t\tenc\t\tdec')
92 |
93 | for _, (x, label) in enumerate(db_loader):
94 | x = x.to(device)
95 |
96 | encoder_loss, decoder_loss, reg_ae, encoder_adv, decoder_adv, loss_ae, xr, xp, \
97 | regr, regr_ng, regpp, regpp_ng = vae(x)
98 |
99 | if epoch_start % 50 == 0:
100 |
101 | print(epoch_start, '\t%0.3f\t%0.3f\t\t%0.3f\t\t%0.3f\t\t%0.3f\t\t%0.3f'%(
102 | reg_ae.item(), encoder_adv.item(), decoder_adv.item(), loss_ae.item(), encoder_loss.item(),
103 | decoder_loss.item()
104 | ))
105 |
106 | viz.line([[args.beta*loss_ae.item(), args.gamma*reg_ae.item(), args.alpha*regr_ng.item(),
107 | args.alpha * regpp_ng.item(), args.alpha*regr.item(), args.alpha*regpp.item()]],
108 | [epoch_start], win='train', update='append')
109 | viz.line([encoder_loss.item()], [epoch_start], win='encoder_loss', update='append')
110 | viz.line([decoder_loss.item()], [epoch_start], win='decoder_loss', update='append')
111 | viz.line([loss_ae.item()], [epoch_start], win='ae_loss', update='append')
112 | viz.line([reg_ae.item()], [epoch_start], win='reg_ae', update='append')
113 | viz.line([encoder_adv.item()], [epoch_start], win='encoder_adv', update='append')
114 | viz.line([decoder_adv.item()], [epoch_start], win='decoder_adv', update='append')
115 |
116 | if epoch_start % 200 == 0:
117 | x, xr, xp = x[:8], xr[:8], xp[:8]
118 | viz.histogram(xr[0].view(-1), win='xr_hist', opts=dict(title='xr_hist'))
119 | unnorm_(x, xr, xp)
120 | viz.images(x, nrow=4, win='x', opts=dict(title='x'))
121 | viz.images(xr, nrow=4, win='xr', opts=dict(title='xr'))
122 | viz.images(xp, nrow=4, win='xp', opts=dict(title='xp'))
123 |
124 | if epoch_start % 10000 == 0:
125 | save_image(xr, 'res/xr_%d.jpg'%epoch_start, nrow=4)
126 | save_image(xp, 'res/xp_%d.jpg'%epoch_start, nrow=4)
127 | print('save xr xp to res directory.')
128 |
129 | if epoch_start % 10000 == 0:
130 | torch.save(vae.state_dict(), 'ckpt/introvae_%d.mdl'%epoch_start)
131 | print('saved ckpt:', 'ckpt/introvae_%d.mdl'%epoch_start)
132 |
133 | epoch_start += 1
134 | if epoch_start > pretraining_epoch:
135 | break
136 |
137 |
138 |
139 |
140 | # training.
141 | print('>>training Intro-VAE now...')
142 |     vae.set_alpha_beta_gamma(args.alpha, args.beta, args.gamma)
143 | db_iter = iter(db_loader)
144 | print('epoch\tvae\tenc-adv\t\tdec-adv\t\tae\t\tenc\t\tdec')
145 | for epoch in range(epoch_start, args.epoch):
146 |
147 | try:
148 |             # cannot call iter(db_loader).next() each step; that would restart the loader every iteration
149 | x, label = next(db_iter)
150 | except StopIteration as err:
151 | db_loader = DataLoader(db, batch_size=args.batchsz, shuffle=True, num_workers=4, pin_memory=True)
152 | db_iter = iter(db_loader)
153 | x, label = next(db_iter)
154 | print('epoch\tvae\tenc-adv\t\tdec-adv\t\tae\t\tenc\t\tdec')
155 |
156 | x = x.to(device)
157 |
158 | encoder_loss, decoder_loss, reg_ae, encoder_adv, decoder_adv, loss_ae, xr, xp, \
159 | regr, regr_ng, regpp, regpp_ng = vae(x)
160 |
161 | if epoch % 100 == 0:
162 |
163 |             print(epoch, '\t%0.3f\t%0.3f\t\t%0.3f\t\t%0.3f\t\t%0.3f\t\t%0.3f' % (
164 | reg_ae.item(), encoder_adv.item(), decoder_adv.item(), loss_ae.item(), encoder_loss.item(),
165 | decoder_loss.item()
166 | ))
167 |
168 | viz.line([[args.beta * loss_ae.item(), args.gamma * reg_ae.item(), args.alpha * regr_ng.item(),
169 | args.alpha * regpp_ng.item(), args.alpha * regr.item(), args.alpha * regpp.item()]],
170 | [epoch], win='train', update='append')
171 | viz.line([encoder_loss.item()], [epoch], win='encoder_loss', update='append')
172 | viz.line([decoder_loss.item()], [epoch], win='decoder_loss', update='append')
173 | viz.line([loss_ae.item()], [epoch], win='ae_loss', update='append')
174 | viz.line([reg_ae.item()], [epoch], win='reg_ae', update='append')
175 | viz.line([encoder_adv.item()], [epoch], win='encoder_adv', update='append')
176 | viz.line([decoder_adv.item()], [epoch], win='decoder_adv', update='append')
177 |
178 | if epoch % 500 == 0:
179 | x, xr, xp = x[:8], xr[:8], xp[:8]
180 | viz.histogram(xr[0].view(-1), win='xr_hist', opts=dict(title='xr_hist'))
181 | unnorm_(x, xr, xp)
182 | viz.images(x, nrow=4, win='x', opts=dict(title='x'))
183 | viz.images(xr, nrow=4, win='xr', opts=dict(title='xr'))
184 | viz.images(xp, nrow=4, win='xp', opts=dict(title='xp'))
185 |
186 | if epoch % 10000 == 0:
187 | save_image(xr, 'res/xr_%d.jpg' % epoch, nrow=4)
188 | save_image(xp, 'res/xp_%d.jpg' % epoch, nrow=4)
189 | print('save xr, xp to res directory')
190 |
191 | if epoch % 10000 == 0:
192 | torch.save(vae.state_dict(), 'ckpt/introvae_%d.mdl'%epoch)
193 | print('saved ckpt:', 'ckpt/introvae_%d.mdl'%epoch)
194 |
195 |
196 |
197 |
198 | torch.save(vae.state_dict(), 'ckpt/introvae_%d.mdl'%args.epoch)
199 | print('saved final ckpt:', 'ckpt/introvae_%d.mdl'%args.epoch)
200 |
201 |
202 |
203 |
204 |
205 |
206 | if __name__ == '__main__':
207 |
208 | argparser = argparse.ArgumentParser()
209 | argparser.add_argument('--imgsz', type=int, default=128, help='imgsz')
210 | argparser.add_argument('--batchsz', type=int, default=8, help='batch size')
211 | argparser.add_argument('--z_dim', type=int, default=256, help='hidden latent z dim')
212 |     argparser.add_argument('--epoch', type=int, default=750000, help='total training iterations')
213 | argparser.add_argument('--margin', type=int, default=110, help='margin')
214 | argparser.add_argument('--alpha', type=float, default=0.25, help='alpha * loss_adv')
215 | argparser.add_argument('--beta', type=float, default=0.5, help='beta * ae_loss')
216 | argparser.add_argument('--gamma', type=float, default=1., help='gamma * kl(q||p)_loss')
217 | argparser.add_argument('--lr', type=float, default=0.0002, help='learning rate')
218 | argparser.add_argument('--root', type=str, default='/home/i/dbs/',
219 | help='root/label/*.jpg')
220 | argparser.add_argument('--resume', type=str, default=None,
221 | help='with ckpt path, set None to train from scratch, set empty str to load latest ckpt')
222 |
223 |
224 | args = argparser.parse_args()
225 |
226 | main(args)
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, optim
3 | from torch.nn import functional as F
4 | import math
5 | from utils import Reshape, Flatten, ResBlk
6 |
7 |
8 |
9 | class Encoder(nn.Module):
10 |
11 | def __init__(self, imgsz, ch):
12 | """
13 |
14 | :param imgsz:
15 | :param ch: base channels
16 | """
17 | super(Encoder, self).__init__()
18 |
19 | x = torch.randn(2, 3, imgsz, imgsz)
20 | print('Encoder:', list(x.shape), end='=>')
21 |
22 | layers = [
23 | nn.Conv2d(3, ch, kernel_size=5, stride=1, padding=2),
24 | nn.BatchNorm2d(ch),
25 | nn.ReLU(inplace=True),
26 | nn.AvgPool2d(2, stride=None, padding=0),
27 | ]
28 | # just for print
29 | out = nn.Sequential(*layers)(x)
30 | print(list(out.shape), end='=>')
31 |
32 | # [b, ch_cur, imgsz, imgsz] => [b, ch_next, mapsz, mapsz]
33 | mapsz = imgsz // 2
34 | ch_cur = ch
35 | ch_next = ch_cur * 2
36 |
37 |         while mapsz > 4:  # until the feature map reaches [b, ch_cur, 4, 4]
38 | # add resblk
39 | layers.extend([
40 | ResBlk([1, 3, 3], [ch_cur, ch_next, ch_next, ch_next]),
41 | nn.AvgPool2d(kernel_size=2, stride=None)
42 | ])
43 | mapsz = mapsz // 2
44 | ch_cur = ch_next
45 | ch_next = ch_next * 2 if ch_next < 512 else 512 # set max ch=512
46 |
47 | # for print
48 | out = nn.Sequential(*layers)(x)
49 | print(list(out.shape), end='=>')
50 |
51 | layers.extend([
52 | ResBlk([3, 3], [ch_cur, ch_next, ch_next]),
53 | nn.AvgPool2d(kernel_size=2, stride=None),
54 | ResBlk([3, 3], [ch_next, ch_next, ch_next]),
55 | nn.AvgPool2d(kernel_size=2, stride=None),
56 | Flatten()
57 | ])
58 |
59 | self.net = nn.Sequential(*layers)
60 |
61 | # for printing
62 | out = nn.Sequential(*layers)(x)
63 | print(list(out.shape))
64 |
65 |
66 | def forward(self, x):
67 | """
68 |
69 | :param x:
70 | :return:
71 | """
72 | return self.net(x)
73 |
74 |
75 |
76 |
77 | class Decoder(nn.Module):
78 |
79 |
80 | def __init__(self, imgsz, z_dim):
81 | """
82 |
83 | :param imgsz:
84 | :param z_dim:
85 | """
86 | super(Decoder, self).__init__()
87 |
88 | mapsz = 4
89 | ch_next = z_dim
90 | print('Decoder:', [z_dim], '=>', [2, ch_next, mapsz, mapsz], end='=>')
91 |
92 | # z: [b, z_dim] => [b, z_dim, 4, 4]
93 | layers = [
94 | # z_dim => z_dim * 4 * 4 => [z_dim, 4, 4] => [z_dim, 4, 4]
95 | nn.Linear(z_dim, z_dim * mapsz * mapsz),
96 | nn.BatchNorm1d(z_dim * mapsz * mapsz),
97 | nn.ReLU(inplace=True),
98 | Reshape(z_dim, mapsz, mapsz),
99 | ResBlk([3, 3], [z_dim, z_dim, z_dim])
100 | ]
101 |
102 |
103 | # scale imgsz up while keeping channel untouched
104 | # [b, z_dim, 4, 4] => [b, z_dim, 8, 8] => [b, z_dim, 16, 16]
105 | for i in range(2):
106 | layers.extend([
107 | nn.Upsample(scale_factor=2),
108 | ResBlk([3, 3], [ch_next, ch_next, ch_next])
109 | ])
110 | mapsz = mapsz * 2
111 |
112 | # for print
113 | tmp = torch.randn(2, z_dim)
114 | net = nn.Sequential(*layers)
115 | out = net(tmp)
116 | print(list(out.shape), end='=>')
117 | del net
118 |
119 | # scale imgsz up and scale imgc down
120 | # [b, z_dim, 16, 16] => [z_dim//2, 32, 32] => [z_dim//4, 64, 64] => [z_dim//8, 128, 128]
121 | # => [z_dim//16, 256, 256] => [z_dim//32, 512, 512]
122 | while mapsz < imgsz//2:
123 | ch_cur = ch_next
124 |             ch_next = ch_next // 2 if ch_next >= 32 else ch_next  # keep a minimum of 16 channels
125 | layers.extend([
126 | # [2, 32, 32, 32] => [2, 32, 64, 64]
127 | nn.Upsample(scale_factor=2),
128 | # => [2, 16, 64, 64]
129 | ResBlk([1, 3, 3], [ch_cur, ch_next, ch_next, ch_next])
130 | ])
131 | mapsz = mapsz * 2
132 |
133 | # for print
134 | tmp = torch.randn(2, z_dim)
135 | net = nn.Sequential(*layers)
136 | out = net(tmp)
137 | print(list(out.shape), end='=>')
138 | del net
139 |
140 |
141 |         # final upsample and projection: [b, ch_next, imgsz//2, imgsz//2] => [b, 3, imgsz, imgsz]
142 | layers.extend([
143 | nn.Upsample(scale_factor=2),
144 | ResBlk([3, 3], [ch_next, ch_next, ch_next]),
145 | nn.Conv2d(ch_next, 3, kernel_size=5, stride=1, padding=2),
146 | # sigmoid / tanh
147 | ])
148 |
149 | self.net = nn.Sequential(*layers)
150 |
151 | # for print
152 | tmp = torch.randn(2, z_dim)
153 | out = self.net(tmp)
154 | print(list(out.shape))
155 |
156 | def forward(self, x):
157 | """
158 |
159 | :param x: [b, z_dim]
160 | :return:
161 | """
162 | # print('before forward:', x.shape)
163 | x = self.net(x)
164 | # print('after forward:', x.shape)
165 | return x
166 |
167 |
168 |
169 |
170 |
171 | class IntroVAE(nn.Module):
172 |
173 |
174 | def __init__(self, args):
175 | """
176 |
177 | :param imgsz:
178 | :param z_dim: h_dim is the output dim of encoder, and we use z_net net to convert it from
179 | h_dim to 2*z_dim and then splitting.
180 | """
181 | super(IntroVAE, self).__init__()
182 |
183 | imgsz = args.imgsz
184 | z_dim = args.z_dim
185 |
186 |
187 | # set first conv channel as 16
188 | self.encoder = Encoder(imgsz, 16)
189 |
190 | # get h_dim of encoder output
191 | x = torch.randn(2, 3, imgsz, imgsz)
192 | z_ = self.encoder(x)
193 | h_dim = z_.size(1)
194 |
195 | # convert h_dim to 2*z_dim
196 | self.z_net = nn.Linear(h_dim, 2 * z_dim)
197 |
198 | # sample
199 | z, mu, log_sigma2 = self.reparametrization(z_)
200 |
201 | # create decoder by z_dim
202 | self.decoder = Decoder(imgsz, z_dim)
203 | out = self.decoder(z)
204 |
205 | # print
206 | print('IntroVAE x:', list(x.shape), 'z_:', list(z_.shape), 'z:', list(z.shape), 'out:', list(out.shape))
207 |
208 |
209 | self.alpha = args.alpha # for adversarial loss
210 | self.beta = args.beta # for reconstruction loss
211 | self.gamma = args.gamma # for variational loss
212 | self.margin = args.margin # margin in eq. 11
213 | self.z_dim = z_dim # z is the hidden vector while h is the output of encoder
214 | self.h_dim = h_dim
215 |
216 | self.optim_encoder = optim.Adam(self.encoder.parameters(), lr=args.lr)
217 | self.optim_decoder = optim.Adam(self.decoder.parameters(), lr=args.lr)
218 |
219 |
220 |     def set_alpha_beta_gamma(self, alpha, beta, gamma):
221 | """
222 |         used for pre-training: setting alpha=0 reduces the model to a vanilla VAE.
223 | :param alpha: for adversarial loss
224 | :param beta: for reconstruction loss
225 | :param gamma: for variational loss
226 | :return:
227 | """
228 | self.alpha = alpha
229 | self.beta = beta
230 | self.gamma = gamma
231 |
232 | def reparametrization(self, z_):
233 | """
234 |
235 | :param z_: [b, 2*z_dim]
236 | :return:
237 | """
238 | # [b, 2*z_dim] => [b, z_dim], [b, z_dim]
239 | mu, log_sigma2 = self.z_net(z_).chunk(2, dim=1)
240 | # sample from normal dist
241 | eps = torch.randn_like(log_sigma2)
242 | # reparametrization trick
243 |         # z = mu + sigma * eps, with sigma = sqrt(exp(log_sigma2))
244 | z = mu + torch.exp(log_sigma2).sqrt() * eps
245 |
246 | return z, mu, log_sigma2
247 |
248 | def kld(self, mu, log_sigma2):
249 | """
250 | compute the kl divergence between N(mu, sigma^2) and N(0, 1)
251 | :param mu: [b, z_dim]
252 | :param log_sigma2: [b, z_dim]
253 | :return:
254 | """
255 | batchsz = mu.size(0)
256 | # https://stats.stackexchange.com/questions/7440/kl-divergence-between-two-univariate-gaussians
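        # closed form: KL(N(mu, sigma^2) || N(0, I)) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)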
257 | kl = - 0.5 * (1 + log_sigma2 - torch.pow(mu, 2) - torch.exp(log_sigma2))
258 |         kl = kl.sum()  # summed over batch and latent dims; not normalized by (batchsz * self.z_dim)
259 |
260 | return kl
261 |
262 | def output_activation(self, x):
263 | """
264 |
265 | :param x:
266 | :return:
267 | """
268 | return torch.tanh(x)
269 |
270 | def forward(self, x):
271 | """
272 |         All notation here follows Algorithm 1 (page 6) of the official paper;
273 |         see also Figure 7 on page 15.
274 |         :param x: [b, 3, imgsz, imgsz]
275 | :return:
276 | """
277 | batchsz = x.size(0)
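        # Loss structure (Algorithm 1, Eqs. 11-12 of the paper):
        #   encoder step: gamma*KL(q(z|x)||p(z)) + beta*L_AE
        #                 + alpha * sum over {xr, xp} of max(0, margin - KL), with xr, xp DETACHED
        #   decoder step: beta*L_AE + alpha * sum over {xr, xp} of KL, gradients flowing into the decoder
        # The encoder pushes the posteriors of reconstructed/sampled images above the margin
        # (treating them as fake), while the decoder pulls them back toward the prior, which
        # is the introspective adversarial game.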
278 |
279 | # 1. update encoder
280 | z_ = self.encoder(x)
281 | z, mu, log_sigma2 = self.reparametrization(z_)
282 | xr = self.output_activation(self.decoder(z))
283 | zp = torch.randn_like(z)
284 | xp = self.output_activation(self.decoder(zp))
285 |
286 | loss_ae = F.mse_loss(xr, x, reduction='sum').sqrt()
287 | reg_ae = self.kld(mu, log_sigma2)
288 |
289 | zr_ng_ = self.encoder(xr.detach())
290 | zr_ng, mur_ng, log_sigma2r_ng = self.reparametrization(zr_ng_)
291 | regr_ng = self.kld(mur_ng, log_sigma2r_ng)
292 | # max(0, margin - l)
293 | regr_ng = torch.clamp(self.margin - regr_ng, min=0)
294 | zpp_ng_ = self.encoder(xp.detach())
295 | zpp_ng, mupp_ng, log_sigma2pp_ng = self.reparametrization(zpp_ng_)
296 | regpp_ng = self.kld(mupp_ng, log_sigma2pp_ng)
297 | # max(0, margin - l)
298 | regpp_ng = torch.clamp(self.margin - regpp_ng, min=0)
299 |
300 |
301 | encoder_adv = regr_ng + regpp_ng
302 | encoder_loss = self.gamma * reg_ae + self.alpha * encoder_adv + self.beta * loss_ae
303 | self.optim_encoder.zero_grad()
304 | encoder_loss.backward()
305 | self.optim_encoder.step()
306 |
307 |
308 | # 2. update decoder
309 | z_ = self.encoder(x)
310 | z, mu, log_sigma2 = self.reparametrization(z_)
311 | xr = self.output_activation(self.decoder(z))
312 | zp = torch.randn_like(z)
313 | xp = self.output_activation(self.decoder(zp))
314 |
315 | loss_ae = F.mse_loss(xr, x, reduction='sum').sqrt()
316 |
317 | zr_ = self.encoder(xr)
318 | zr, mur, log_sigma2r = self.reparametrization(zr_)
319 | regr = self.kld(mur, log_sigma2r)
320 | zpp_ = self.encoder(xp)
321 | zpp, mupp, log_sigma2pp = self.reparametrization(zpp_)
322 | regpp = self.kld(mupp, log_sigma2pp)
323 |
324 | # by Eq.12, the 1st term of loss
325 | decoder_adv = regr + regpp
326 | decoder_loss = self.alpha * decoder_adv + self.beta * loss_ae
327 | self.optim_decoder.zero_grad()
328 | decoder_loss.backward()
329 | self.optim_decoder.step()
330 |
331 |
332 |
333 | return encoder_loss, decoder_loss, reg_ae, encoder_adv, decoder_adv, loss_ae, xr, xp, \
334 | regr, regr_ng, regpp, regpp_ng
335 |
336 |
337 |
338 |
339 |
340 |
341 |
342 |
343 |
344 |
345 |
346 |
347 |
348 |
349 |
350 |
351 |
352 |
353 | def main():
354 | pass
355 |
356 |
357 |
358 |
359 |
360 | if __name__ == '__main__':
361 | main()
--------------------------------------------------------------------------------
/res/xp_20000.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dragen1860/IntroVAE-Pytorch/fd42b940eeaf9308b4547a6d11a2c5e6eecee052/res/xp_20000.jpg
--------------------------------------------------------------------------------
/res/xr_20000.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dragen1860/IntroVAE-Pytorch/fd42b940eeaf9308b4547a6d11a2c5e6eecee052/res/xr_20000.jpg
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 |
6 | class Flatten(nn.Module):
7 | def __init__(self):
8 | super(Flatten, self).__init__()
9 |
10 | def forward(self, x):
11 |         # print('before flatten:', x.shape)
12 | return x.view(x.size(0), -1)
13 |
14 | class Reshape(nn.Module):
15 |
16 | def __init__(self, *args):
17 | super(Reshape, self).__init__()
18 | self.shape = args
19 |
20 | def forward(self, x):
21 | return x.view(-1, *self.shape)
22 |
23 | class Add(nn.Module):
24 |
25 | def __init__(self):
26 | super(Add, self).__init__()
27 |
28 | def forward(self, x, residual):
29 | """
30 |
31 | :param x:
32 | :param residual:
33 | :return:
34 | """
35 | return x + residual
36 |
37 | def extra_repr(self):
38 | return "ResNet Element-wise Add Layer"
39 |
40 |
41 | class ResBlk(nn.Module):
42 |
43 | def __init__(self, kernels, chs):
44 | """
45 |
46 | :param kernels: [1, 3, 3], as [kernel_1, kernel_2, kernel_3]
47 | :param chs: [ch_in, 64, 64, 64], as [ch_in, ch_out1, ch_out2, ch_out3]
48 | :return:
49 | """
50 | super(ResBlk, self).__init__()
51 |
52 | layers = []
53 |
54 |         assert len(chs) - 1 == len(kernels), "mismatch between chs and kernels"
55 |
56 | for idx in range(len(kernels)):
57 | layers.extend([
58 | nn.Conv2d(chs[idx], chs[idx+1], kernel_size=kernels[idx], stride=1,
59 | padding=1 if kernels[idx]!=1 else 0), # no padding for kernel=1
60 | nn.BatchNorm2d(chs[idx+1]),
61 | nn.ReLU(inplace=True)
62 | ])
63 |
64 | self.net = nn.Sequential(*layers)
65 |
66 | self.shortcut = nn.Sequential()
67 |         if chs[0] != chs[-1]:  # project from ch_in to ch_out3
68 | self.shortcut = nn.Sequential(
69 | nn.Conv2d(chs[0], chs[-1], kernel_size=1),
70 | nn.BatchNorm2d(chs[-1]),
71 | nn.ReLU(inplace=True)
72 | )
73 |
74 | def forward(self, x):
75 | """
76 |
77 | :param x:
78 | :return:
79 | """
80 | res = self.net(x)
81 | x_ = self.shortcut(x)
82 | # print(x.shape, x_.shape, res.shape)
83 | return x_ + res
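
# shape sketch: ResBlk([1, 3, 3], [64, 128, 128, 128]) maps [b, 64, h, w] -> [b, 128, h, w];
# when chs[0] != chs[-1], the 1x1 conv shortcut projects the input to the output width.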
--------------------------------------------------------------------------------