├── .gitignore ├── CHANGELOG.md ├── LICENSE ├── README.md ├── mnist_ebgan_generate.py ├── mnist_ebgan_train.py ├── model.py └── png ├── sample.png └── sample_with_pt.png /.gitignore: -------------------------------------------------------------------------------- 1 | /asset 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | env/ 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *,cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # IPython Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # dotenv 80 | .env 81 | 82 | # virtualenv 83 | venv/ 84 | ENV/ 85 | 86 | # Spyder project settings 87 | .spyderproject 88 | 89 | # Rope project settings 90 | .ropeproject 91 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | ## 0.0.0.2 ( 2017-04-04 ) 2 | 3 | Features : 4 | 5 | 6 | Refactored : 7 | 8 
| - adapted to tensorflow 1.0.0 9 | - split modeling stub from train.py to model.py 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2016 Namju Kim 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # EBGAN 2 | A TensorFlow implementation of Junbo Zhao et al.'s Energy-Based Generative Adversarial Network ( EBGAN ) paper. 3 | ( See : [https://arxiv.org/pdf/1609.03126v2.pdf](https://arxiv.org/pdf/1609.03126v2.pdf) ) 4 | My implementation differs somewhat from the original paper; for example, I've used convolution layers 5 | in both the generator and the discriminator instead of fully connected layers. 
6 | I don't think this difference is significant, and it should not make a big difference in the final result. 7 | 8 | ## Version 9 | 10 | Current Version : __***0.0.0.2***__ 11 | 12 | ## Dependencies ( VERSION MUST BE MATCHED EXACTLY! ) 13 | 14 | 1. tensorflow == 1.0.0 15 | 1. sugartensor == 1.0.0.2 16 | 17 | ## Training the network 18 | 19 | Execute 20 |

21 | python mnist_ebgan_train.py
22 | 
23 | to train the network. The resulting checkpoint ( ckpt ) files and log files are saved in the 'asset/train' directory. 24 | Launch `tensorboard --logdir asset/train/log` to monitor the training process. 25 | 26 | 27 | ## Generating images 28 | 29 | Execute 30 |

31 | python mnist_ebgan_generate.py
32 | 
33 | to generate a sample image. The 'sample.png' file will be written to the 'asset/train' directory. 34 | 35 | ## Generated image sample 36 | 37 | This image was generated by the EBGAN network. 38 |

39 | 40 |

41 | 42 | ## Other resources 43 | 44 | 1. [Original GAN tensorflow implementation](https://github.com/buriburisuri/sugartensor/blob/master/sugartensor/example/mnist_gan.py) 45 | 1. [InfoGAN tensorflow implementation](https://github.com/buriburisuri/sugartensor/blob/master/sugartensor/example/mnist_info_gan.py) 46 | 1. [Supervised InfoGAN tensorflow implementation](https://github.com/buriburisuri/supervised_infogan) 47 | 48 | # Authors 49 | Namju Kim (buriburisuri@gmail.com) at Jamonglabs Co., Ltd. -------------------------------------------------------------------------------- /mnist_ebgan_generate.py: -------------------------------------------------------------------------------- 1 | import sugartensor as tf 2 | import matplotlib 3 | matplotlib.use('Agg') 4 | import matplotlib.pyplot as plt 5 | from model import * 6 | 7 | 8 | __author__ = 'namju.kim@kakaobrain.com' 9 | 10 | 11 | # set log level to debug 12 | tf.sg_verbosity(10) 13 | 14 | # 15 | # hyper parameters 16 | # 17 | 18 | batch_size = 100 19 | 20 | 21 | # random uniform seed 22 | z = tf.random_uniform((batch_size, z_dim)) 23 | 24 | # generator 25 | gen = generator(z) 26 | 27 | # 28 | # draw samples 29 | # 30 | 31 | with tf.Session() as sess: 32 | 33 | tf.sg_init(sess) 34 | 35 | # restore parameters 36 | tf.sg_restore(sess, tf.train.latest_checkpoint('asset/train'), category='generator') 37 | 38 | # run generator 39 | imgs = sess.run(gen.sg_squeeze()) 40 | 41 | # plot result 42 | _, ax = plt.subplots(10, 10, sharex=True, sharey=True) 43 | for i in range(10): 44 | for j in range(10): 45 | ax[i][j].imshow(imgs[i * 10 + j], 'gray') 46 | ax[i][j].set_axis_off() 47 | plt.savefig('asset/train/sample.png', dpi=600) 48 | tf.sg_info('Sample image saved to "asset/train/sample.png"') 49 | plt.close() 50 | -------------------------------------------------------------------------------- /mnist_ebgan_train.py: -------------------------------------------------------------------------------- 1 | import sugartensor as tf 2 
| import numpy as np 3 | from model import * 4 | 5 | 6 | __author__ = 'namju.kim@kakaobrain.com' 7 | 8 | 9 | # set log level to debug 10 | tf.sg_verbosity(10) 11 | 12 | # 13 | # hyper parameters 14 | # 15 | 16 | batch_size = 128 # batch size 17 | 18 | # 19 | # inputs 20 | # 21 | 22 | # MNIST input tensor ( with QueueRunner ) 23 | data = tf.sg_data.Mnist(batch_size=batch_size) 24 | 25 | # input images 26 | x = data.train.image 27 | 28 | # random uniform seed 29 | z = tf.random_uniform((batch_size, z_dim)) 30 | 31 | # 32 | # Computational graph 33 | # 34 | 35 | # generator 36 | gen = generator(z) 37 | 38 | # add image summary 39 | tf.sg_summary_image(x, name='real') 40 | tf.sg_summary_image(gen, name='fake') 41 | 42 | # discriminator 43 | disc_real = discriminator(x) 44 | disc_fake = discriminator(gen) 45 | 46 | # 47 | # pull-away term ( PT ) regularizer 48 | # 49 | 50 | sample = gen.sg_flatten() 51 | nom = tf.matmul(sample, tf.transpose(sample, perm=[1, 0])) 52 | denom = tf.sqrt(tf.reduce_sum(tf.square(sample), reduction_indices=[1], keep_dims=True)) 53 | pt = tf.square(nom / (denom * tf.transpose(denom, perm=[1, 0])))  # squared cosine similarity between sample pairs 54 | pt -= tf.diag(tf.diag_part(pt)) 55 | pt = tf.reduce_sum(pt) / (batch_size * (batch_size - 1)) 56 | 57 | 58 | # 59 | # loss & train ops 60 | # 61 | 62 | # mean squared errors 63 | mse_real = tf.reduce_mean(tf.square(disc_real - x), reduction_indices=[1, 2, 3]) 64 | mse_fake = tf.reduce_mean(tf.square(disc_fake - gen), reduction_indices=[1, 2, 3]) 65 | 66 | # discriminator loss 67 | loss_disc = mse_real + tf.maximum(margin - mse_fake, 0) 68 | # generator loss + PT regularizer 69 | loss_gen = mse_fake + pt * pt_weight 70 | 71 | train_disc = tf.sg_optim(loss_disc, lr=0.001, category='discriminator') # discriminator train ops 72 | train_gen = tf.sg_optim(loss_gen, lr=0.001, category='generator') # generator train ops 73 | 74 | # add summary 75 | tf.sg_summary_loss(loss_disc, name='disc') 76 | tf.sg_summary_loss(loss_gen, name='gen') 77 | 78 | 79 | # 80 | # training 81 | # 82 | 83 | # def alternate 
training func 84 | @tf.sg_train_func 85 | def alt_train(sess, opt): 86 | l_disc = sess.run([loss_disc, train_disc])[0] # training discriminator 87 | l_gen = sess.run([loss_gen, train_gen])[0] # training generator 88 | return np.mean(l_disc) + np.mean(l_gen) 89 | 90 | # do training 91 | alt_train(log_interval=10, max_ep=30, ep_size=data.train.num_batch) 92 | 93 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import sugartensor as tf 2 | 3 | # 4 | # hyper parameters 5 | # 6 | 7 | z_dim = 50 # noise dimension 8 | margin = 1 # max-margin for hinge loss 9 | pt_weight = 0.1 # PT regularizer's weight 10 | 11 | 12 | # 13 | # create generator 14 | # 15 | 16 | def generator(x): 17 | 18 | reuse = len([t for t in tf.global_variables() if t.name.startswith('generator')]) > 0 19 | with tf.sg_context(name='generator', size=4, stride=2, act='leaky_relu', bn=True, reuse=reuse): 20 | 21 | # generator network 22 | res = (x.sg_dense(dim=1024, name='fc_1') 23 | .sg_dense(dim=7*7*128, name='fc_2') 24 | .sg_reshape(shape=(-1, 7, 7, 128)) 25 | .sg_upconv(dim=64, name='conv_1') 26 | .sg_upconv(dim=1, act='sigmoid', bn=False, name='conv_2')) 27 | return res 28 | 29 | 30 | # 31 | # create discriminator 32 | # 33 | 34 | def discriminator(x): 35 | 36 | reuse = len([t for t in tf.global_variables() if t.name.startswith('discriminator')]) > 0 37 | with tf.sg_context(name='discriminator', size=4, stride=2, act='leaky_relu', bn=True, reuse=reuse): 38 | res = (x.sg_conv(dim=64, name='conv_1') 39 | .sg_conv(dim=128, name='conv_2') 40 | .sg_upconv(dim=64, name='conv_3') 41 | .sg_upconv(dim=1, act='linear', name='conv_4')) 42 | 43 | return res 44 | -------------------------------------------------------------------------------- /png/sample.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/buriburisuri/ebgan/370025905a29ccb3cad22ffdf80220dfe63f9088/png/sample.png -------------------------------------------------------------------------------- /png/sample_with_pt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buriburisuri/ebgan/370025905a29ccb3cad22ffdf80220dfe63f9088/png/sample_with_pt.png --------------------------------------------------------------------------------
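For reference, the pull-away term ( PT ) regularizer computed in mnist_ebgan_train.py is, per the EBGAN paper, the mean squared cosine similarity over all distinct pairs of generated samples. A minimal NumPy sketch of that formula (illustrative only — the function name and NumPy usage are mine, not part of this repository):

```python
import numpy as np

def pull_away_term(samples):
    # samples: [batch, features] array of flattened generator outputs
    n = samples.shape[0]
    # normalize each sample to unit length
    unit = samples / np.linalg.norm(samples, axis=1, keepdims=True)
    # pairwise squared cosine similarities
    sim = np.square(unit @ unit.T)
    # drop self-similarity on the diagonal and average over the n*(n-1) ordered pairs
    np.fill_diagonal(sim, 0.0)
    return sim.sum() / (n * (n - 1))
```

Identical samples give a PT of 1 and mutually orthogonal samples give 0, so adding this term to the generator loss pushes generated samples apart and discourages mode collapse.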
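The losses in mnist_ebgan_train.py follow the EBGAN formulation: the discriminator is an auto-encoder whose reconstruction error acts as an energy, hinged at a margin for fake samples. A scalar sketch of those two loss expressions (the function name and defaults are mine; the margin and PT weight values mirror model.py):

```python
import numpy as np

def ebgan_losses(mse_real, mse_fake, pt, margin=1.0, pt_weight=0.1):
    # discriminator: low energy on real data, energy pushed above the margin on fakes
    loss_disc = mse_real + np.maximum(margin - mse_fake, 0.0)
    # generator: low reconstruction energy on fakes, plus the pull-away regularizer
    loss_gen = mse_fake + pt * pt_weight
    return loss_disc, loss_gen
```

Note the hinge: once a fake's reconstruction error exceeds the margin, it stops contributing to the discriminator loss.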