├── .gitignore
├── CHANGELOG.md
├── LICENSE
├── README.md
├── mnist_ebgan_generate.py
├── mnist_ebgan_train.py
├── model.py
└── png
├── sample.png
└── sample_with_pt.png
/.gitignore:
--------------------------------------------------------------------------------
1 | /asset
2 | # Byte-compiled / optimized / DLL files
3 | __pycache__/
4 | *.py[cod]
5 | *$py.class
6 |
7 | # C extensions
8 | *.so
9 |
10 | # Distribution / packaging
11 | .Python
12 | env/
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 |
49 | # Translations
50 | *.mo
51 | *.pot
52 |
53 | # Django stuff:
54 | *.log
55 | local_settings.py
56 |
57 | # Flask stuff:
58 | instance/
59 | .webassets-cache
60 |
61 | # Scrapy stuff:
62 | .scrapy
63 |
64 | # Sphinx documentation
65 | docs/_build/
66 |
67 | # PyBuilder
68 | target/
69 |
70 | # IPython Notebook
71 | .ipynb_checkpoints
72 |
73 | # pyenv
74 | .python-version
75 |
76 | # celery beat schedule file
77 | celerybeat-schedule
78 |
79 | # dotenv
80 | .env
81 |
82 | # virtualenv
83 | venv/
84 | ENV/
85 |
86 | # Spyder project settings
87 | .spyderproject
88 |
89 | # Rope project settings
90 | .ropeproject
91 |
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | ## 0.0.0.2 ( 2017-04-04 )
2 |
3 | Refactored :
4 |
5 | - Adapted to TensorFlow 1.0.0.
6 | - Split the modeling stub from train.py into model.py.
7 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2016 Namju Kim
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # EBGAN
2 | A TensorFlow implementation of Junbo et al.'s Energy-Based Generative Adversarial Network ( EBGAN ) paper.
3 | ( See : [https://arxiv.org/pdf/1609.03126v2.pdf](https://arxiv.org/pdf/1609.03126v2.pdf) )
4 | My implementation differs somewhat from the original paper; for example, I've used convolution layers
5 | in both the generator and the discriminator instead of fully connected layers.
6 | I think this change is not important and will not make a big difference in the final result.
7 |
8 | ## Version
9 |
10 | Current Version : **0.0.0.2**
11 |
12 | ## Dependencies ( VERSION MUST BE MATCHED EXACTLY! )
13 |
14 | 1. tensorflow == 1.0.0
15 | 1. sugartensor == 1.0.0.2
16 |
17 | ## Training the network
18 |
19 | Execute
20 |
21 | python mnist_ebgan_train.py
22 |
23 | to train the network. The resulting checkpoint ( ckpt ) files and log files are saved in the 'asset/train' directory.
24 | Launch `tensorboard --logdir asset/train/log` to monitor the training process.
25 |
26 |
27 | ## Generating image
28 |
29 | Execute
30 |
31 | python mnist_ebgan_generate.py
32 |
33 | to generate a sample image. The 'sample.png' file will be generated in the 'asset/train' directory.
34 |
35 | ## Generated image sample
36 |
37 | This image was generated by the EBGAN network.
38 |
39 | ![Generated sample](png/sample.png)
40 |
41 |
42 | ## Other resources
43 |
44 | 1. [Original GAN tensorflow implementation](https://github.com/buriburisuri/sugartensor/blob/master/sugartensor/example/mnist_gan.py)
45 | 1. [InfoGAN tensorflow implementation](https://github.com/buriburisuri/sugartensor/blob/master/sugartensor/example/mnist_info_gan.py)
46 | 1. [Supervised InfoGAN tensorflow implementation](https://github.com/buriburisuri/supervised_infogan)
47 |
48 | # Authors
49 | Namju Kim (buriburisuri@gmail.com) at Jamonglabs Co., Ltd.
--------------------------------------------------------------------------------
/mnist_ebgan_generate.py:
--------------------------------------------------------------------------------
1 | import sugartensor as tf
2 | import matplotlib
3 | matplotlib.use('Agg')
4 | import matplotlib.pyplot as plt
5 | from model import *
6 |
7 |
8 | __author__ = 'namju.kim@kakaobrain.com'
9 |
10 |
11 | # set log level to debug
12 | tf.sg_verbosity(10)
13 |
14 | #
15 | # hyper parameters
16 | #
17 |
18 | batch_size = 100
19 |
20 |
21 | # random uniform seed
22 | z = tf.random_uniform((batch_size, z_dim))
23 |
24 | # generator
25 | gen = generator(z)
26 |
27 | #
28 | # draw samples
29 | #
30 |
31 | with tf.Session() as sess:
32 |
33 | tf.sg_init(sess)
34 |
35 | # restore parameters
36 | tf.sg_restore(sess, tf.train.latest_checkpoint('asset/train'), category='generator')
37 |
38 | # run generator
39 | imgs = sess.run(gen.sg_squeeze())
40 |
41 | # plot result
42 | _, ax = plt.subplots(10, 10, sharex=True, sharey=True)
43 | for i in range(10):
44 | for j in range(10):
45 | ax[i][j].imshow(imgs[i * 10 + j], 'gray')
46 | ax[i][j].set_axis_off()
47 | plt.savefig('asset/train/sample.png', dpi=600)
48 | tf.sg_info('Sample image saved to "asset/train/sample.png"')
49 | plt.close()
50 |
--------------------------------------------------------------------------------
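The generate script above tiles 100 generator outputs into a 10x10 matplotlib grid. The same row-major layout can be built as a single image array with NumPy alone, which is handy for saving without a plotting backend. This is a sketch outside the repo; `imgs` is a hypothetical stand-in for the generator output.

```python
import numpy as np

# Hypothetical stand-in for the generator output: 100 MNIST-sized images.
imgs = np.random.rand(100, 28, 28)

# Tile the batch into one 10x10 grid image, row-major, matching the
# subplot layout used in mnist_ebgan_generate.py.
grid = (imgs.reshape(10, 10, 28, 28)   # (row, col, h, w)
            .transpose(0, 2, 1, 3)     # (row, h, col, w)
            .reshape(10 * 28, 10 * 28))
```

Cell (i, j) of the grid then holds `imgs[i * 10 + j]`, exactly as in the `ax[i][j].imshow(...)` loop.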
/mnist_ebgan_train.py:
--------------------------------------------------------------------------------
1 | import sugartensor as tf
2 | import numpy as np
3 | from model import *
4 |
5 |
6 | __author__ = 'namju.kim@kakaobrain.com'
7 |
8 |
9 | # set log level to debug
10 | tf.sg_verbosity(10)
11 |
12 | #
13 | # hyper parameters
14 | #
15 |
16 | batch_size = 128 # batch size
17 |
18 | #
19 | # inputs
20 | #
21 |
22 | # MNIST input tensor ( with QueueRunner )
23 | data = tf.sg_data.Mnist(batch_size=batch_size)
24 |
25 | # input images
26 | x = data.train.image
27 |
28 | # random uniform seed
29 | z = tf.random_uniform((batch_size, z_dim))
30 |
31 | #
32 | # Computational graph
33 | #
34 |
35 | # generator
36 | gen = generator(z)
37 |
38 | # add image summary
39 | tf.sg_summary_image(x, name='real')
40 | tf.sg_summary_image(gen, name='fake')
41 |
42 | # discriminator
43 | disc_real = discriminator(x)
44 | disc_fake = discriminator(gen)
45 |
46 | #
47 | # pull-away term ( PT ) regularizer
48 | #
49 |
50 | sample = gen.sg_flatten()
51 | nom = tf.matmul(sample, tf.transpose(sample, perm=[1, 0]))
52 | denom = tf.sqrt(tf.reduce_sum(tf.square(sample), reduction_indices=[1], keep_dims=True))
53 | pt = tf.square(nom / (denom * tf.transpose(denom, perm=[1, 0])))
54 | pt -= tf.diag(tf.diag_part(pt))
55 | pt = tf.reduce_sum(pt) / (batch_size * (batch_size - 1))
56 |
57 |
58 | #
59 | # loss & train ops
60 | #
61 |
62 | # mean squared errors
63 | mse_real = tf.reduce_mean(tf.square(disc_real - x), reduction_indices=[1, 2, 3])
64 | mse_fake = tf.reduce_mean(tf.square(disc_fake - gen), reduction_indices=[1, 2, 3])
65 |
66 | # discriminator loss
67 | loss_disc = mse_real + tf.maximum(margin - mse_fake, 0)
68 | # generator loss + PT regularizer
69 | loss_gen = mse_fake + pt * pt_weight
70 |
71 | train_disc = tf.sg_optim(loss_disc, lr=0.001, category='discriminator') # discriminator train ops
72 | train_gen = tf.sg_optim(loss_gen, lr=0.001, category='generator') # generator train ops
73 |
74 | # add summary
75 | tf.sg_summary_loss(loss_disc, name='disc')
76 | tf.sg_summary_loss(loss_gen, name='gen')
77 |
78 |
79 | #
80 | # training
81 | #
82 |
83 | # define alternating training function
84 | @tf.sg_train_func
85 | def alt_train(sess, opt):
86 | l_disc = sess.run([loss_disc, train_disc])[0] # training discriminator
87 | l_gen = sess.run([loss_gen, train_gen])[0] # training generator
88 | return np.mean(l_disc) + np.mean(l_gen)
89 |
90 | # do training
91 | alt_train(log_interval=10, max_ep=30, ep_size=data.train.num_batch)
92 |
93 |
--------------------------------------------------------------------------------
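The pull-away term in the training script penalizes pairs of generated samples that point in the same direction: the EBGAN paper defines it as the mean squared cosine similarity over all distinct pairs in the batch. A minimal NumPy sketch of that definition (not part of the repo, with made-up inputs):

```python
import numpy as np

def pull_away_term(sample):
    # Mean squared cosine similarity over distinct sample pairs,
    # i.e. the PT regularizer from the EBGAN paper.
    n = sample.shape[0]
    norm = np.linalg.norm(sample, axis=1, keepdims=True)  # (n, 1) row norms
    cos = (sample @ sample.T) / (norm @ norm.T)           # pairwise cosine similarity
    pt = cos ** 2
    pt -= np.diag(np.diag(pt))                            # drop self-similarity terms
    return pt.sum() / (n * (n - 1))
```

Parallel samples give a PT value of 1 (maximum penalty), orthogonal samples give 0, which is why the term pushes the generator toward diverse outputs.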
/model.py:
--------------------------------------------------------------------------------
1 | import sugartensor as tf
2 |
3 | #
4 | # hyper parameters
5 | #
6 |
7 | z_dim = 50 # noise dimension
8 | margin = 1 # max-margin for hinge loss
9 | pt_weight = 0.1 # PT regularizer's weight
10 |
11 |
12 | #
13 | # create generator
14 | #
15 |
16 | def generator(x):
17 |
18 | reuse = len([t for t in tf.global_variables() if t.name.startswith('generator')]) > 0
19 | with tf.sg_context(name='generator', size=4, stride=2, act='leaky_relu', bn=True, reuse=reuse):
20 |
21 | # generator network
22 | res = (x.sg_dense(dim=1024, name='fc_1')
23 | .sg_dense(dim=7*7*128, name='fc_2')
24 | .sg_reshape(shape=(-1, 7, 7, 128))
25 | .sg_upconv(dim=64, name='conv_1')
26 | .sg_upconv(dim=1, act='sigmoid', bn=False, name='conv_2'))
27 | return res
28 |
29 |
30 | #
31 | # create discriminator
32 | #
33 |
34 | def discriminator(x):
35 |
36 | reuse = len([t for t in tf.global_variables() if t.name.startswith('discriminator')]) > 0
37 | with tf.sg_context(name='discriminator', size=4, stride=2, act='leaky_relu', bn=True, reuse=reuse):
38 | res = (x.sg_conv(dim=64, name='conv_1')
39 | .sg_conv(dim=128, name='conv_2')
40 | .sg_upconv(dim=64, name='conv_3')
41 | .sg_upconv(dim=1, act='linear', name='conv_4'))
42 |
43 | return res
44 |
--------------------------------------------------------------------------------
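In this model the discriminator is an autoencoder, and its reconstruction MSE plays the role of the energy D(v). The training script combines that energy with the `margin` and `pt_weight` hyper parameters defined above: L_D = D(x) + max(0, margin - D(G(z))) and L_G = D(G(z)) + pt_weight * PT. A scalar sketch with hypothetical energy values (not part of the repo):

```python
# Hyper parameters as set in model.py.
margin, pt_weight = 1.0, 0.1

def ebgan_losses(energy_real, energy_fake, pt):
    # EBGAN hinge losses: the discriminator pushes real energy down and
    # fake energy up to the margin; the generator pushes fake energy down,
    # plus the pull-away regularizer weighted by pt_weight.
    loss_disc = energy_real + max(margin - energy_fake, 0.0)
    loss_gen = energy_fake + pt_weight * pt
    return loss_disc, loss_gen
```

Once the fake energy exceeds the margin, the hinge term vanishes and the discriminator loss reduces to the real reconstruction error alone.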
/png/sample.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buriburisuri/ebgan/370025905a29ccb3cad22ffdf80220dfe63f9088/png/sample.png
--------------------------------------------------------------------------------
/png/sample_with_pt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buriburisuri/ebgan/370025905a29ccb3cad22ffdf80220dfe63f9088/png/sample_with_pt.png
--------------------------------------------------------------------------------