├── LICENSE
├── README.md
└── wgan_mnist.py


/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2017 Jonas Adler

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
Minimal Wasserstein GAN
=======================

This is a simple TensorFlow implementation of Wasserstein Generative Adversarial Networks (WGAN) applied to MNIST.

Some example generated digits:

![WGAN results](https://user-images.githubusercontent.com/2202312/32365318-b0ccc44a-c079-11e7-8fb1-6b1566c0bdc4.png)


How to run
----------

Simply run the file [`wgan_mnist.py`](wgan_mnist.py). Results are displayed in real time; full training takes about an hour on a GPU.

Implementation details
----------------------

The implementation follows [Improved Training of Wasserstein GANs](https://arxiv.org/abs/1704.00028), using the network from [the accompanying code](https://github.com/igul222/improved_wgan_training). In particular, both the generator and the discriminator use 3 convolutional layers with 5x5 convolutions.
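For reference, this is the critic objective with gradient penalty from the paper, written out as a sketch (the script uses penalty weight λ = 10 and minimises the same loss up to a sign flip of the critic, which is an equivalent formulation):

```latex
L = \mathbb{E}_{\tilde{x} \sim \mathbb{P}_g}\bigl[D(\tilde{x})\bigr]
  - \mathbb{E}_{x \sim \mathbb{P}_r}\bigl[D(x)\bigr]
  + \lambda \, \mathbb{E}_{\hat{x}}\bigl[(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1)^2\bigr],
\qquad \hat{x} = \epsilon x + (1 - \epsilon)\tilde{x}, \quad \epsilon \sim U[0, 1]
```

The `regularizer` block in `wgan_mnist.py` computes exactly this penalty term: it evaluates the critic on random interpolations `x_hat` between real and generated images and pushes the gradient norm towards 1.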
--------------------------------------------------------------------------------
/wgan_mnist.py:
--------------------------------------------------------------------------------
"""Minimal implementation of Wasserstein GAN (with gradient penalty) for MNIST."""

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.contrib import layers
from tensorflow.examples.tutorials.mnist import input_data


session = tf.InteractiveSession()


def leaky_relu(x):
    return tf.maximum(x, 0.2 * x)


def generator(z):
    """Map a latent vector z to a 28x28x1 image with values in [0, 1]."""
    with tf.variable_scope('generator'):
        z = layers.fully_connected(z, num_outputs=4096)
        z = tf.reshape(z, [-1, 4, 4, 256])

        # Upsample 4x4 -> 8x8 -> 16x16 -> 32x32 with transposed convolutions.
        z = layers.conv2d_transpose(z, num_outputs=128, kernel_size=5, stride=2)
        z = layers.conv2d_transpose(z, num_outputs=64, kernel_size=5, stride=2)
        z = layers.conv2d_transpose(z, num_outputs=1, kernel_size=5, stride=2,
                                    activation_fn=tf.nn.sigmoid)
        # Crop the 32x32 output down to the 28x28 MNIST size.
        return z[:, 2:-2, 2:-2, :]


def discriminator(x, reuse):
    """Critic: map an image to an unbounded real-valued score."""
    with tf.variable_scope('discriminator', reuse=reuse):
        x = layers.conv2d(x, num_outputs=64, kernel_size=5, stride=2,
                          activation_fn=leaky_relu)
        x = layers.conv2d(x, num_outputs=128, kernel_size=5, stride=2,
                          activation_fn=leaky_relu)
        x = layers.conv2d(x, num_outputs=256, kernel_size=5, stride=2,
                          activation_fn=leaky_relu)

        x = layers.flatten(x)
        # No final activation: the WGAN critic outputs an unconstrained score.
        return layers.fully_connected(x, num_outputs=1, activation_fn=None)


with tf.name_scope('placeholders'):
    x_true = tf.placeholder(tf.float32, [None, 28, 28, 1])
    z = tf.placeholder(tf.float32, [None, 128])


x_generated = generator(z)

d_true = discriminator(x_true, reuse=False)
d_generated = discriminator(x_generated, reuse=True)

with tf.name_scope('regularizer'):
    # Gradient penalty: evaluate the critic on random interpolations between
    # real and generated samples ([50, 1, 1, 1] matches the batch size below).
    epsilon = tf.random_uniform([50, 1, 1, 1], 0.0, 1.0)
    x_hat = epsilon * x_true + (1 - epsilon) * x_generated
    d_hat = discriminator(x_hat, reuse=True)

    gradients = tf.gradients(d_hat, x_hat)[0]
    # Per-sample gradient norm over all non-batch axes (height, width, channel).
    ddx = tf.sqrt(tf.reduce_sum(gradients ** 2, axis=[1, 2, 3]))
    d_regularizer = tf.reduce_mean((ddx - 1.0) ** 2)

with tf.name_scope('loss'):
    # Sign convention: the critic assigns *lower* scores to real samples, so
    # this is the WGAN-GP loss with D replaced by -D (an equivalent objective).
    g_loss = tf.reduce_mean(d_generated)
    d_loss = (tf.reduce_mean(d_true) - tf.reduce_mean(d_generated) +
              10 * d_regularizer)

with tf.name_scope('optimizer'):
    # Adam hyperparameters recommended in the WGAN-GP paper.
    optimizer = tf.train.AdamOptimizer(learning_rate=1e-4, beta1=0, beta2=0.9)

    g_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='generator')
    g_train = optimizer.minimize(g_loss, var_list=g_vars)
    d_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='discriminator')
    d_train = optimizer.minimize(d_loss, var_list=d_vars)

tf.global_variables_initializer().run()

mnist = input_data.read_data_sets('MNIST_data')

for i in range(20000):
    batch = mnist.train.next_batch(50)
    images = batch[0].reshape([-1, 28, 28, 1])
    z_train = np.random.randn(50, 128)

    # One generator step, then five critic steps, as in the WGAN-GP paper.
    session.run(g_train, feed_dict={z: z_train})
    for j in range(5):
        session.run(d_train, feed_dict={x_true: images, z: z_train})

    if i % 100 == 0:
        print('iter={}/20000'.format(i))
        # Display a freshly generated digit to monitor training progress.
        z_validate = np.random.randn(1, 128)
        generated = x_generated.eval(feed_dict={z: z_validate}).squeeze()

        plt.figure('results')
        plt.imshow(generated, clim=[0, 1], cmap='bone')
        plt.pause(0.001)
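
# The block below is a sketch, not part of the original script: after training,
# sample a small grid of digits and save it to disk for inspection (the output
# filename is arbitrary).
z_grid = np.random.randn(16, 128)
samples = x_generated.eval(feed_dict={z: z_grid})
fig, axes = plt.subplots(4, 4, figsize=(4, 4))
for ax, img in zip(axes.flat, samples):
    ax.imshow(img.squeeze(), clim=[0, 1], cmap='bone')
    ax.axis('off')
fig.savefig('wgan_mnist_samples.png')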
--------------------------------------------------------------------------------