├── .gitignore
├── .idea
│   ├── NIPS-18.iml
│   ├── markdown-navigator.xml
│   ├── markdown-navigator
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── vcs.xml
│   └── workspace.xml
├── LICENSE.txt
├── Readme.md
├── coders
│   ├── __init__.py
│   └── vae_coding.py
├── figs
│   └── train
│       ├── grid
│       │   ├── anime_samples.png
│       │   ├── mnist_samples.png
│       │   ├── samples.png
│       │   ├── small_samples.png
│       │   └── some_samples.png
│       └── scatter
│           └── latent.png
├── mnist
│   ├── t10k-images-idx3-ubyte.gz
│   ├── t10k-labels-idx1-ubyte.gz
│   ├── train-images-idx3-ubyte.gz
│   └── train-labels-idx1-ubyte.gz
├── models
│   ├── __init__.py
│   ├── generator.py
│   └── vae.py
├── plots
│   ├── __init__.py
│   └── grid_plots.py
├── providers
│   ├── __init__.py
│   └── anime.py
├── sample.py
├── utils
│   ├── __init__.py
│   └── path.py
├── vae_train.py
├── vae_train_anime.py
└── weights
    ├── vae_anime
    │   ├── checkpoint
    │   ├── generator.data-00000-of-00001
    │   ├── generator.index
    │   └── generator.meta
    └── vae_mnist
        ├── checkpoint
        ├── generator.data-00000-of-00001
        ├── generator.index
        └── generator.meta
/.gitignore:
--------------------------------------------------------------------------------
1 | /faces
2 |
--------------------------------------------------------------------------------
/.idea/NIPS-18.iml:
--------------------------------------------------------------------------------
[XML contents stripped in this export]
--------------------------------------------------------------------------------
/.idea/markdown-navigator.xml:
--------------------------------------------------------------------------------
[XML contents stripped in this export]
--------------------------------------------------------------------------------
/.idea/markdown-navigator/profiles_settings.xml:
--------------------------------------------------------------------------------
[XML contents stripped in this export]
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
[XML contents stripped in this export]
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
[XML contents stripped in this export]
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
[XML contents stripped in this export]
--------------------------------------------------------------------------------
/.idea/workspace.xml:
--------------------------------------------------------------------------------
[XML contents stripped in this export]
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) [2018] [Ga Wu] [University of Toronto]
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Readme.md:
--------------------------------------------------------------------------------
1 | Variational Autoencoder
2 | ===
3 | This is an enhanced implementation of the Variational Autoencoder.
4 | Both fully connected and convolutional encoders/decoders are built into this model.
5 | Please star if you like this implementation.
6 |
7 | ## Use
8 | ```bash
9 | $ python vae_train_anime.py  # for training
10 | $ python sample.py            # for sampling
11 | ```
12 |
13 | ## Update
14 | 1. Removed standard deviation learning from the Gaussian observation decoder.
15 | 2. Made the observation standard deviation a hyper-parameter.
16 | 3. Added deconvolutional CNN support for the Anime dataset.
17 | 4. Removed the Anime dataset itself to avoid legal issues.
18 |
19 | ## Pre-Trained Models
20 | There are two pre-trained models:
21 | 1. Anime
22 | 2. MNIST
23 |
24 | The weights of the pre-trained models are located in the `weights` folder.
25 |
26 | ## Samples
27 |
28 | ### ANIME
29 | 
30 |
31 | ### MNIST
32 | 
33 |
34 |
35 | ### Latent Space Distribution
36 | 
37 |
--------------------------------------------------------------------------------
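
The update list in the Readme above notes that the observation standard deviation is now a fixed hyper-parameter rather than a learned output. A minimal sketch of how that is configured through the `VAE` constructor in `models/vae.py` (the values shown are illustrative, not a recommendation):

```python
from coders.vae_coding import fc_mnist_encoder, fc_mnist_decoder
from models.vae import VAE

# Sketch: the observation standard deviation is passed in as a fixed
# hyper-parameter instead of being learned by the decoder.
vae = VAE(latent_dim=2,
          batch_size=128,
          encoder=fc_mnist_encoder,
          decoder=fc_mnist_decoder,
          observation_distribution='Gaussian',  # default is 'Bernoulli'
          observation_std=0.01)                 # fixed std of the Gaussian likelihood
```
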
/coders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuga214/IMPLEMENTATION_Variational-Auto-Encoder/6fb103bd44e2887b827861d5026f054802a76bd0/coders/__init__.py
--------------------------------------------------------------------------------
/coders/vae_coding.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import tensorlayer as tl
3 | from tensorlayer.layers import *
4 | from tensorflow.contrib import layers
5 |
6 |
7 | def fc_mnist_encoder(x, latent_dim, activation=None):
8 | e = layers.fully_connected(x, 500, scope='fc-01')
9 | e = layers.fully_connected(e, 500, scope='fc-02')
10 | e = layers.fully_connected(e, 200, scope='fc-03')
11 | output = layers.fully_connected(e, 2 * latent_dim, activation_fn=activation,
12 | scope='fc-final')
13 |
14 | return output
15 |
16 |
17 | def fc_mnist_decoder(z, observation_dim, activation=tf.sigmoid):
18 | x = layers.fully_connected(z, 200, scope='fc-01')
19 | x = layers.fully_connected(x, 500, scope='fc-02')
20 | x = layers.fully_connected(x, 500, scope='fc-03')
21 | output = layers.fully_connected(x, observation_dim, activation_fn=activation,
22 | scope='fc-final')
23 | return output
24 |
25 | def conv_anime_encoder(x, latent_dim, activation=None):
26 | is_train = True
27 | z_dim = latent_dim # 512
28 | ef_dim = 64 # encoder filter number
29 |
30 | w_init = tf.random_normal_initializer(stddev=0.02)
31 | gamma_init = tf.random_normal_initializer(1., 0.02)
32 |
33 | input = tf.reshape(x, [-1, 64, 64, 3])
34 | net_in = InputLayer(input, name='en/in') # (b_size,64,64,3)
35 | net_h0 = Conv2d(net_in, ef_dim, (5, 5), (2, 2), act=None,
36 | padding='SAME', W_init=w_init, name='en/h0/conv2d')
37 | net_h0 = BatchNormLayer(net_h0, act=tf.nn.relu,
38 | is_train=is_train, gamma_init=gamma_init, name='en/h0/batch_norm')
39 |
40 | net_h1 = Conv2d(net_h0, ef_dim * 2, (5, 5), (2, 2), act=None,
41 | padding='SAME', W_init=w_init, name='en/h1/conv2d')
42 | net_h1 = BatchNormLayer(net_h1, act=tf.nn.relu,
43 | is_train=is_train, gamma_init=gamma_init, name='en/h1/batch_norm')
44 |
45 | net_h2 = Conv2d(net_h1, ef_dim * 4, (5, 5), (2, 2), act=None,
46 | padding='SAME', W_init=w_init, name='en/h2/conv2d')
47 | net_h2 = BatchNormLayer(net_h2, act=tf.nn.relu,
48 | is_train=is_train, gamma_init=gamma_init, name='en/h2/batch_norm')
49 |
50 | net_h3 = Conv2d(net_h2, ef_dim * 8, (5, 5), (2, 2), act=None,
51 | padding='SAME', W_init=w_init, name='en/h3/conv2d')
52 | net_h3 = BatchNormLayer(net_h3, act=tf.nn.relu,
53 | is_train=is_train, gamma_init=gamma_init, name='en/h3/batch_norm')
54 |
55 |
56 | net_h4 = FlattenLayer(net_h3, name='en/h4/flatten')
57 |
58 | net_out = DenseLayer(net_h4, n_units=z_dim*2, act=tf.identity,
59 | W_init=w_init, name='en/h3/lin_sigmoid')
60 | net_out = BatchNormLayer(net_out, act=tf.identity,
61 | is_train=is_train, gamma_init=gamma_init, name='en/out1/batch_norm')
62 |
63 | output = net_out.outputs
64 |
65 | return output
66 |
67 |
68 | def conv_anime_decoder(z, observation_dim, activation=tf.tanh):
69 |
70 | is_train = True
71 | image_size = 64
72 | s2, s4, s8, s16 = int(image_size / 2), int(image_size / 4), int(image_size / 8), int(image_size / 16) # 32,16,8,4
73 | gf_dim = 64
74 | c_dim = 3
75 |
76 | w_init = tf.random_normal_initializer(stddev=0.02)
77 | gamma_init = tf.random_normal_initializer(1., 0.02)
78 |
79 | net_in = InputLayer(z, name='g/in')
80 | net_h0 = DenseLayer(net_in, n_units=gf_dim * 4 * s8 * s8, W_init=w_init,
81 | act=tf.identity, name='g/h0/lin')
82 | # net_h0.outputs._shape = (b_size,256*8*8)
83 | net_h0 = ReshapeLayer(net_h0, shape=[-1, s8, s8, gf_dim * 4], name='g/h0/reshape')
84 | # net_h0.outputs._shape = (b_size,8,8,256)
85 | net_h0 = BatchNormLayer(net_h0, act=tf.nn.relu, is_train=is_train,
86 | gamma_init=gamma_init, name='g/h0/batch_norm')
87 |
88 | # upsampling
89 | net_h1 = DeConv2d(net_h0, gf_dim * 4, (5, 5), out_size=(s4, s4), strides=(2, 2),
90 | padding='SAME', act=None, W_init=w_init, name='g/h1/decon2d')
91 | net_h1 = BatchNormLayer(net_h1, act=tf.nn.relu, is_train=is_train,
92 | gamma_init=gamma_init, name='g/h1/batch_norm')
93 | # net_h1.outputs._shape = (b_size,16,16,256)
94 |
95 | net_h2 = DeConv2d(net_h1, gf_dim * 2, (5, 5), out_size=(s2, s2), strides=(2, 2),
96 | padding='SAME', act=None, W_init=w_init, name='g/h2/decon2d')
97 | net_h2 = BatchNormLayer(net_h2, act=tf.nn.relu, is_train=is_train,
98 | gamma_init=gamma_init, name='g/h2/batch_norm')
99 | # net_h2.outputs._shape = (b_size,32,32,128)
100 |
101 | net_h3 = DeConv2d(net_h2, gf_dim // 2, (5, 5), out_size=(image_size, image_size), strides=(2, 2),
102 | padding='SAME', act=None, W_init=w_init, name='g/h3/decon2d')
103 | net_h3 = BatchNormLayer(net_h3, act=tf.nn.relu, is_train=is_train,
104 | gamma_init=gamma_init, name='g/h3/batch_norm')
105 | # net_h3.outputs._shape = (b_size,64,64,32)
106 |
107 | # no BN on last deconv
108 | net_h4 = DeConv2d(net_h3, c_dim, (5, 5), out_size=(image_size, image_size), strides=(1, 1),
109 | padding='SAME', act=None, W_init=w_init, name='g/h4/decon2d')
110 | # net_h4.outputs._shape = (b_size,64,64,3)
111 |
112 | logits = net_h4.outputs
113 | net_h4.outputs = tf.nn.tanh(net_h4.outputs)
114 |
115 | output = layers.flatten(net_h4.outputs)
116 |
117 | return output
118 |
--------------------------------------------------------------------------------
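
A note on the convention used by both encoders above: they emit `2 * latent_dim` units, which `VAE._build_graph` in `models/vae.py` splits into a posterior mean and a log-variance before applying the reparameterization trick. A toy NumPy sketch of that split:

```python
import numpy as np

latent_dim = 2
encoded = np.arange(2 * latent_dim, dtype=np.float32)  # stand-in for one encoder output row

mean = encoded[:latent_dim]          # first half: posterior mean
logvar = encoded[latent_dim:]        # second half: log-variance
stddev = np.sqrt(np.exp(logvar))     # same transform as in VAE._build_graph
z = mean + stddev * np.random.normal(size=latent_dim)  # reparameterization trick
```
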
/figs/train/grid/anime_samples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuga214/IMPLEMENTATION_Variational-Auto-Encoder/6fb103bd44e2887b827861d5026f054802a76bd0/figs/train/grid/anime_samples.png
--------------------------------------------------------------------------------
/figs/train/grid/mnist_samples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuga214/IMPLEMENTATION_Variational-Auto-Encoder/6fb103bd44e2887b827861d5026f054802a76bd0/figs/train/grid/mnist_samples.png
--------------------------------------------------------------------------------
/figs/train/grid/samples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuga214/IMPLEMENTATION_Variational-Auto-Encoder/6fb103bd44e2887b827861d5026f054802a76bd0/figs/train/grid/samples.png
--------------------------------------------------------------------------------
/figs/train/grid/small_samples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuga214/IMPLEMENTATION_Variational-Auto-Encoder/6fb103bd44e2887b827861d5026f054802a76bd0/figs/train/grid/small_samples.png
--------------------------------------------------------------------------------
/figs/train/grid/some_samples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuga214/IMPLEMENTATION_Variational-Auto-Encoder/6fb103bd44e2887b827861d5026f054802a76bd0/figs/train/grid/some_samples.png
--------------------------------------------------------------------------------
/figs/train/scatter/latent.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuga214/IMPLEMENTATION_Variational-Auto-Encoder/6fb103bd44e2887b827861d5026f054802a76bd0/figs/train/scatter/latent.png
--------------------------------------------------------------------------------
/mnist/t10k-images-idx3-ubyte.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuga214/IMPLEMENTATION_Variational-Auto-Encoder/6fb103bd44e2887b827861d5026f054802a76bd0/mnist/t10k-images-idx3-ubyte.gz
--------------------------------------------------------------------------------
/mnist/t10k-labels-idx1-ubyte.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuga214/IMPLEMENTATION_Variational-Auto-Encoder/6fb103bd44e2887b827861d5026f054802a76bd0/mnist/t10k-labels-idx1-ubyte.gz
--------------------------------------------------------------------------------
/mnist/train-images-idx3-ubyte.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuga214/IMPLEMENTATION_Variational-Auto-Encoder/6fb103bd44e2887b827861d5026f054802a76bd0/mnist/train-images-idx3-ubyte.gz
--------------------------------------------------------------------------------
/mnist/train-labels-idx1-ubyte.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuga214/IMPLEMENTATION_Variational-Auto-Encoder/6fb103bd44e2887b827861d5026f054802a76bd0/mnist/train-labels-idx1-ubyte.gz
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuga214/IMPLEMENTATION_Variational-Auto-Encoder/6fb103bd44e2887b827861d5026f054802a76bd0/models/__init__.py
--------------------------------------------------------------------------------
/models/generator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | from tensorflow.contrib.distributions import Normal
4 |
5 |
6 | class GENERATOR(object):
7 | def __init__(self, latent_dim, observation_dim, generator,
8 | obs_distrib="Bernoulli",
9 | obs_std=0.01,
10 | ):
11 | """
12 | Wraps a trained decoder so samples can be drawn from prior noise without building the full VAE graph.
13 | """
14 | self._latent_dim = latent_dim
15 | self._observation_dim = observation_dim
16 | self._generator = generator
17 | self._obs_distrib = obs_distrib
18 | self._obs_std = obs_std
19 | self._p_distribution = self._multivariate_normal()
20 | self._build_graph()
21 |
22 | def _multivariate_normal(self):
23 | return Normal([0.] * self._latent_dim, [1.] * self._latent_dim)
24 |
25 | def _build_graph(self):
26 |
27 | with tf.variable_scope('in'):
28 | # placeholder for the input noise
29 | self.candid = tf.placeholder(tf.float32, shape=[None, self._latent_dim], name='candidate')
30 |
31 | # decode batch
32 | with tf.variable_scope('generator'):
33 | self.generated = self._generator(self.candid, self._observation_dim)  # decoder signature is (z, observation_dim)
34 |
35 | self._sesh = tf.Session()
36 | init = tf.global_variables_initializer()
37 | self._sesh.run(init)
38 |
39 | def load_pretrained(self, path):
40 | generator_variables = []
41 | for v in tf.trainable_variables():
42 | if "generator" in v.name:
43 | generator_variables.append(v)
44 | saver = tf.train.Saver(generator_variables)
45 | saver.restore(self._sesh, path)
46 |
47 | def e2x(self, noise):
48 | x = self._sesh.run(self.generated,
49 | feed_dict={self.candid: noise})
50 | return x
--------------------------------------------------------------------------------
/models/vae.py:
--------------------------------------------------------------------------------
1 | import re
2 | import numpy as np
3 | import tensorflow as tf
4 | from tensorflow.contrib.distributions import Bernoulli
5 |
6 |
7 | class VAE(object):
8 |
9 | def __init__(self, latent_dim, batch_size, encoder, decoder,
10 | observation_dim=784,
11 | learning_rate=1e-4,
12 | optimizer=tf.train.RMSPropOptimizer,
13 | observation_distribution="Bernoulli", # or Gaussian
14 | observation_std=0.01):
15 |
16 | self._latent_dim = latent_dim
17 | self._batch_size = batch_size
18 | self._encode = encoder
19 | self._decode = decoder
20 | self._observation_dim = observation_dim
21 | self._learning_rate = learning_rate
22 | self._optimizer = optimizer
23 | self._observation_distribution = observation_distribution
24 | self._observation_std = observation_std
25 | self._build_graph()
26 |
27 | def _build_graph(self):
28 |
29 | with tf.variable_scope('vae'):
30 | self.x = tf.placeholder(tf.float32, shape=[None, self._observation_dim])
31 |
32 | with tf.variable_scope('encoder'):
33 | encoded = self._encode(self.x, self._latent_dim)
34 |
35 | with tf.variable_scope('latent'):
36 | self.mean = encoded[:, :self._latent_dim]
37 | logvar = encoded[:, self._latent_dim:]
38 | stddev = tf.sqrt(tf.exp(logvar))
39 | epsilon = tf.random_normal([self._batch_size, self._latent_dim])
40 | self.z = self.mean + stddev * epsilon
41 |
42 | with tf.variable_scope('decoder'):
43 | decoded = self._decode(self.z, self._observation_dim)
44 | self.obs_mean = decoded
45 | if self._observation_distribution == 'Gaussian':
46 | obs_epsilon = tf.random_normal([self._batch_size, self._observation_dim])
47 | self.sample = self.obs_mean + self._observation_std * obs_epsilon
48 | else:
49 | self.sample = Bernoulli(probs=self.obs_mean).sample()
50 |
51 |
52 | with tf.variable_scope('loss'):
53 | with tf.variable_scope('kl-divergence'):
54 | kl = self._kl_diagnormal_stdnormal(self.mean, logvar)
55 |
56 | if self._observation_distribution == 'Gaussian':
57 | with tf.variable_scope('gaussian'):
58 | obj = self._gaussian_log_likelihood(self.x, self.obs_mean, self._observation_std)
59 | else:
60 | with tf.variable_scope('bernoulli'):
61 | obj = self._bernoulli_log_likelihood(self.x, self.obs_mean)
62 |
63 | self._loss = (kl + obj) / self._batch_size
64 |
65 | with tf.variable_scope('optimizer'):
66 | optimizer = tf.train.RMSPropOptimizer(learning_rate=self._learning_rate)
67 | with tf.variable_scope('training-step'):
68 | self._train = optimizer.minimize(self._loss)
69 |
70 | self._sesh = tf.Session()
71 | init = tf.global_variables_initializer()
72 | self._sesh.run(init)
73 |
74 | @staticmethod
75 | def _kl_diagnormal_stdnormal(mu, log_var):
76 |
77 | var = tf.exp(log_var)
78 | kl = 0.5 * tf.reduce_sum(tf.square(mu) + var - 1. - log_var)
79 | return kl
80 |
81 | @staticmethod
82 | def _gaussian_log_likelihood(targets, mean, std):
83 | se = tf.reduce_sum(tf.square(targets - mean)) / (2. * tf.square(std)) + tf.log(std)  # negative Gaussian log-likelihood, up to an additive constant
84 | return se
85 |
86 | @staticmethod
87 | def _bernoulli_log_likelihood(targets, outputs, eps=1e-8):
88 |
89 | log_like = -tf.reduce_sum(targets * tf.log(outputs + eps)
90 | + (1. - targets) * tf.log((1. - outputs) + eps))
91 | return log_like
92 |
93 | def update(self, x):
94 | _, loss = self._sesh.run([self._train, self._loss], feed_dict={self.x: x})
95 | return loss
96 |
97 | def x2z(self, x):
98 |
99 | mean = self._sesh.run([self.mean], feed_dict={self.x: x})
100 |
101 | return np.asarray(mean).reshape(-1, self._latent_dim)
102 |
103 | def z2x(self, z):
104 |
105 | x = self._sesh.run([self.obs_mean], feed_dict={self.z: z})
106 | return x
107 |
108 | def sample_x(self, z):  # distinct name: self.sample already holds the sampling tensor
109 |
110 | x = self._sesh.run([self.sample], feed_dict={self.z: z})
111 |
112 | return x
113 |
114 | def save_generator(self, path, prefix="in/generator"):
115 | variables = tf.trainable_variables()
116 | var_dict = {}
117 | for v in variables:
118 | if "decoder" in v.name:
119 | name = prefix+re.sub("vae/decoder", "", v.name)
120 | name = re.sub(":0", "", name)
121 | var_dict[name] = v
122 | for k, v in var_dict.items():
123 | print(k)
124 | print(v)
125 | saver = tf.train.Saver(var_dict)
126 | saver.save(self._sesh, path)
--------------------------------------------------------------------------------
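
`_kl_diagnormal_stdnormal` above is the closed-form KL divergence between the diagonal-Gaussian posterior and the standard-normal prior, summed over the batch and latent dimensions:

$$
\mathrm{KL}\big(\mathcal{N}(\mu, \operatorname{diag}(\sigma^2)) \,\big\|\, \mathcal{N}(0, I)\big)
= \frac{1}{2} \sum_i \left( \mu_i^2 + \sigma_i^2 - 1 - \log \sigma_i^2 \right),
$$

where `log_var` in the code is $\log \sigma_i^2$; this matches `0.5 * tf.reduce_sum(tf.square(mu) + var - 1. - log_var)` term for term.
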
/plots/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuga214/IMPLEMENTATION_Variational-Auto-Encoder/6fb103bd44e2887b827861d5026f054802a76bd0/plots/__init__.py
--------------------------------------------------------------------------------
/plots/grid_plots.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | from mpl_toolkits.axes_grid1 import ImageGrid
4 | import seaborn as sns
5 | sns.axes_style("white")
6 |
7 |
8 | def show_samples(images, row, col, image_shape, name="Unknown", save=True, shift=False):
9 | num_images = row*col
10 | if shift:
11 | images = (images+1.)/2.
12 | fig = plt.figure(figsize=(col, row))
13 | grid = ImageGrid(fig, 111,
14 | nrows_ncols=(row, col),
15 | axes_pad=0.)
16 | for i in range(num_images):
17 | im = images[i].reshape(image_shape)
18 | axis = grid[i]
19 | axis.axis('off')
20 | axis.imshow(im)
21 | plt.axis('off')
22 | plt.tight_layout()
23 | if save:
24 | fig.savefig('figs/train/grid/'+name+'.png', bbox_inches="tight", pad_inches=0, format='png')
25 | else:
26 | plt.show()
27 |
28 |
29 | # Adapted from other GitHub code
30 | def show_latent_scatter(vae, data, name="latent"):
31 | n_test = 5000
32 | batch_size = 100
33 | zs = np.zeros((n_test, 2), dtype=np.float32)
34 | labels = np.zeros(n_test)
35 | for i in range(int(n_test / batch_size)):
36 | x, y = data.test.next_batch(batch_size)
37 | labels[(100 * i):(100 * (i + 1))] = y
38 | z = vae.x2z(x)
39 | zs[(100 * i):(100 * (i + 1)), :] = z
40 |
41 | indices = np.array([np.where(labels == i)[0] for i in range(10)])
42 | classes = np.array([zs[index] for index in indices])
43 | means = np.array([np.mean(c, axis=0) for c in classes])
44 |
45 | fig = plt.figure(figsize=(10, 10))
46 | ax = fig.add_subplot(111)
47 | ax.scatter(zs[:, 0], zs[:, 1], c=labels)
48 |
49 | # annotate means
50 | for i, mean in enumerate(means):
51 | ax.annotate(str(i), xy=mean, size=16,
52 | bbox=dict(boxstyle='round', facecolor='white', alpha=0.9))
53 | # plot details
54 | ax.set_xlim([-5, 5])
55 | ax.set_ylim([-5, 5])
56 | ax.set_xticks([])
57 | ax.set_yticks([])
58 | plt.tight_layout()
59 |
60 | plt.savefig('figs/train/scatter/' + name + '.png', format='png')
61 |
--------------------------------------------------------------------------------
/providers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuga214/IMPLEMENTATION_Variational-Auto-Encoder/6fb103bd44e2887b827861d5026f054802a76bd0/providers/__init__.py
--------------------------------------------------------------------------------
/providers/anime.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import tensorflow as tf
3 | import numpy as np
4 | from tqdm import tqdm
5 | from utils.path import getfiles
6 |
7 |
8 | class Anime(object):
9 | def __init__(self, image_folder=None, npy='faces/faces.npy'):
10 | try:
11 | self.data = np.load(npy)
12 | except (IOError, OSError):  # no cached array yet; build it from the raw images
13 | self.data = self._get_data(image_folder)
14 | self.save_txt(npy)
15 |
16 | def _get_data(self, image_folder):
17 |
18 | image_names = getfiles(image_folder)
19 |
20 | images = []
21 |
22 | for image_name in tqdm(image_names):
23 | image = np.array(Image.open('{0}/{1}'.format(image_folder, image_name)))
24 | if np.ndim(image) == 3:
25 | image = (image / 255.) * 2 - 1
26 | image = np.pad(image, ((2, 2), (2, 2), (0, 0)), mode='constant')
27 | images.append(image.flatten())
28 |
29 | images = np.array(images)
30 |
31 | return images
32 |
33 | def save_txt(self, path='faces/faces.npy'):
34 | np.save(path, self.data)
35 |
36 | def get_size(self):
37 | return self.data.shape[0]
38 |
39 | def next_batch(self, num):
40 | '''
41 | Return a total of `num` random samples and labels.
42 | '''
43 | idx = np.arange(0, self.get_size())
44 | np.random.shuffle(idx)
45 | idx = idx[:num]
46 | batch = self.data[idx]
47 |
48 | return batch
--------------------------------------------------------------------------------
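
A minimal usage sketch for the `Anime` provider: on first construction it reads and normalizes the raw images, then caches them to the `.npy` path, so later runs load the cache directly (which is why `vae_train_anime.py` can call `Anime()` with no arguments). The `faces` folder name matches the `/faces` entry in `.gitignore` but is otherwise an assumption:

```python
from providers.anime import Anime

provider = Anime(image_folder='faces')  # first run caches to faces/faces.npy
batch = provider.next_batch(64)         # 64 flattened 64x64x3 images in [-1, 1]
```
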
/sample.py:
--------------------------------------------------------------------------------
1 | from coders.vae_coding import conv_anime_decoder, conv_anime_encoder
2 | import tensorflow as tf
3 | import numpy as np
4 | from plots.grid_plots import show_samples, show_latent_scatter
5 | from providers.anime import Anime
6 | from tqdm import tqdm
7 | from models.generator import GENERATOR
8 |
9 | def main():
10 | flags = tf.flags
11 | flags.DEFINE_integer("latent_dim", 64, "Dimension of latent space.")
12 | flags.DEFINE_integer("obs_dim", 12288, "Dimension of observation space.")
13 | flags.DEFINE_integer("batch_size", 60, "Batch size.")
14 | flags.DEFINE_integer("epochs", 500, "As it said")
15 | flags.DEFINE_integer("updates_per_epoch", 100, "Really just can set to 1 if you don't like mini-batch.")
16 | FLAGS = flags.FLAGS
17 |
18 | kwargs = {
19 | 'latent_dim': FLAGS.latent_dim,
20 | 'observation_dim': FLAGS.obs_dim,
21 | 'generator': conv_anime_decoder,
22 | 'obs_distrib': 'Gaussian'
23 | }
24 | g = GENERATOR(**kwargs)
25 | g.load_pretrained("weights/vae_anime/generator")
26 |
27 | z = np.random.normal(size=[FLAGS.batch_size, FLAGS.latent_dim])
28 | samples = g.e2x(z)
29 | print(samples.shape)
30 | show_samples(samples, 4, 15, [64, 64, 3], name='small_samples', shift=True)
31 |
32 | if __name__ == '__main__':
33 | main()
--------------------------------------------------------------------------------
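
`sample.py` above only draws from the Anime weights. A companion sketch for the MNIST checkpoint, assuming it was written by `vae_train.py` with its default `latent_dim` of 2:

```python
import numpy as np
from coders.vae_coding import fc_mnist_decoder
from models.generator import GENERATOR
from plots.grid_plots import show_samples

g = GENERATOR(latent_dim=2, observation_dim=784, generator=fc_mnist_decoder)
g.load_pretrained("weights/vae_mnist/generator")

z = np.random.normal(size=[100, 2])    # draws from the standard-normal prior
samples = g.e2x(z)                     # decode noise into flattened digits
show_samples(samples, 10, 10, [28, 28], name='mnist_samples')
```
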
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuga214/IMPLEMENTATION_Variational-Auto-Encoder/6fb103bd44e2887b827861d5026f054802a76bd0/utils/__init__.py
--------------------------------------------------------------------------------
/utils/path.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pandas as pd
3 | from numpy import genfromtxt
4 |
5 |
6 | def pathfinder(path):
7 | script_dir = os.path.dirname(os.path.abspath(__file__))
8 | fullpath = os.path.join(script_dir, path)
9 | return fullpath
10 |
11 |
12 | def readdata(path, pandas=False):
13 | fullpath = pathfinder(path)
14 | if not pandas:
15 | return genfromtxt(fullpath, delimiter=',')
16 | else:
17 | return pd.read_csv(fullpath)
18 |
19 |
20 | def getfiles(path):
21 | files = []
22 | for (dirpath, dirnames, filenames) in os.walk(path):
23 | files.extend(filenames)
24 | return files
--------------------------------------------------------------------------------
/vae_train.py:
--------------------------------------------------------------------------------
1 | from coders.vae_coding import fc_mnist_encoder, fc_mnist_decoder
2 | import tensorflow as tf
3 | import numpy as np
4 | from plots.grid_plots import show_samples, show_latent_scatter
5 | from tensorflow.examples.tutorials.mnist import input_data
6 | from tqdm import tqdm
7 | from models.vae import VAE
8 |
9 | """
10 | This simple implementation draws heavily on other GitHub code,
11 | such as:
12 | https://github.com/kvfrans/variational-autoencoder
13 | https://github.com/hwalsuklee/tensorflow-mnist-VAE
14 | etc.
15 |
16 | The entire purpose of releasing this code is to help people understand the basic structure of a VAE.
17 | """
18 |
19 | def main():
20 | flags = tf.flags
21 | flags.DEFINE_integer("latent_dim", 2, "Dimension of latent space.")
22 | flags.DEFINE_integer("batch_size", 128, "Batch size.")
23 | flags.DEFINE_integer("epochs", 500, "As it said")
24 | flags.DEFINE_integer("updates_per_epoch", 100, "Really just can set to 1 if you don't like mini-batch.")
25 | flags.DEFINE_string("data_dir", 'mnist', "Tensorflow demo data download position.")
26 | FLAGS = flags.FLAGS
27 |
28 | kwargs = {
29 | 'latent_dim': FLAGS.latent_dim,
30 | 'batch_size': FLAGS.batch_size,
31 | 'encoder': fc_mnist_encoder,
32 | 'decoder': fc_mnist_decoder
33 | }
34 | vae = VAE(**kwargs)
35 | mnist = input_data.read_data_sets(train_dir=FLAGS.data_dir)
36 | tbar = tqdm(range(FLAGS.epochs))
37 | for epoch in tbar:
38 | training_loss = 0.
39 |
40 | for _ in range(FLAGS.updates_per_epoch):
41 | x, _ = mnist.train.next_batch(FLAGS.batch_size)
42 | loss = vae.update(x)
43 | training_loss += loss
44 |
45 | training_loss /= FLAGS.updates_per_epoch
46 | s = "Loss: {:.4f}".format(training_loss)
47 | tbar.set_description(s)
48 |
49 | z = np.random.normal(size=[FLAGS.batch_size, FLAGS.latent_dim])
50 | samples = vae.z2x(z)[0]
51 | show_samples(samples, 10, 10, [28, 28], name='samples')
52 | show_latent_scatter(vae, mnist, name='latent')
53 |
54 | vae.save_generator('weights/vae_mnist/generator')
55 |
56 |
57 | if __name__ == '__main__':
58 | main()
--------------------------------------------------------------------------------
/vae_train_anime.py:
--------------------------------------------------------------------------------
1 | from coders.vae_coding import conv_anime_decoder, conv_anime_encoder
2 | import tensorflow as tf
3 | import numpy as np
4 | from plots.grid_plots import show_samples, show_latent_scatter
5 | from providers.anime import Anime
6 | from tqdm import tqdm
7 | from models.vae import VAE
8 |
9 |
10 | def main():
11 | flags = tf.flags
12 | flags.DEFINE_integer("latent_dim", 64, "Dimension of latent space.")
13 | flags.DEFINE_integer("obs_dim", 12288, "Dimension of observation space.")
14 | flags.DEFINE_integer("batch_size", 64, "Batch size.")
15 | flags.DEFINE_integer("epochs", 500, "As it said")
16 | flags.DEFINE_integer("updates_per_epoch", 100, "Really just can set to 1 if you don't like mini-batch.")
17 | FLAGS = flags.FLAGS
18 |
19 | kwargs = {
20 | 'latent_dim': FLAGS.latent_dim,
21 | 'batch_size': FLAGS.batch_size,
22 | 'observation_dim': FLAGS.obs_dim,
23 | 'encoder': conv_anime_encoder,
24 | 'decoder': conv_anime_decoder,
25 | 'observation_distribution': 'Gaussian'
26 | }
27 | vae = VAE(**kwargs)
28 | provider = Anime()
29 | tbar = tqdm(range(FLAGS.epochs))
30 | for epoch in tbar:
31 | training_loss = 0.
32 |
33 | for _ in range(FLAGS.updates_per_epoch):
34 | x = provider.next_batch(FLAGS.batch_size)
35 | loss = vae.update(x)
36 | training_loss += loss
37 |
38 | training_loss /= FLAGS.updates_per_epoch
39 | s = "Loss: {:.4f}".format(training_loss)
40 | tbar.set_description(s)
41 |
42 | z = np.random.normal(size=[FLAGS.batch_size, FLAGS.latent_dim])
43 | samples = vae.z2x(z)[0]
44 | show_samples(samples, 8, 8, [64, 64, 3], name='samples')
45 |
46 | vae.save_generator('weights/vae_anime/generator')
47 |
48 |
49 | if __name__ == '__main__':
50 | main()
--------------------------------------------------------------------------------
/weights/vae_anime/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "generator"
2 | all_model_checkpoint_paths: "generator"
3 |
--------------------------------------------------------------------------------
/weights/vae_anime/generator.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuga214/IMPLEMENTATION_Variational-Auto-Encoder/6fb103bd44e2887b827861d5026f054802a76bd0/weights/vae_anime/generator.data-00000-of-00001
--------------------------------------------------------------------------------
/weights/vae_anime/generator.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuga214/IMPLEMENTATION_Variational-Auto-Encoder/6fb103bd44e2887b827861d5026f054802a76bd0/weights/vae_anime/generator.index
--------------------------------------------------------------------------------
/weights/vae_anime/generator.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuga214/IMPLEMENTATION_Variational-Auto-Encoder/6fb103bd44e2887b827861d5026f054802a76bd0/weights/vae_anime/generator.meta
--------------------------------------------------------------------------------
/weights/vae_mnist/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "generator"
2 | all_model_checkpoint_paths: "generator"
3 |
--------------------------------------------------------------------------------
/weights/vae_mnist/generator.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuga214/IMPLEMENTATION_Variational-Auto-Encoder/6fb103bd44e2887b827861d5026f054802a76bd0/weights/vae_mnist/generator.data-00000-of-00001
--------------------------------------------------------------------------------
/weights/vae_mnist/generator.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuga214/IMPLEMENTATION_Variational-Auto-Encoder/6fb103bd44e2887b827861d5026f054802a76bd0/weights/vae_mnist/generator.index
--------------------------------------------------------------------------------
/weights/vae_mnist/generator.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuga214/IMPLEMENTATION_Variational-Auto-Encoder/6fb103bd44e2887b827861d5026f054802a76bd0/weights/vae_mnist/generator.meta
--------------------------------------------------------------------------------