# coding: utf-8
#
# Source-dump header (kept for reference):
#   ├── README.md
#   └── gan.py
#
#   /README.md:
#     # Introduction-to-GANs-with-Python-and-TensorFlow
#
#   /gan.py: (the script below)

# In[1]:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os


# In[2]:

def sample_Z(m, n):
    """Draw an (m, n) noise matrix uniformly from [-1, 1).

    Args:
        m: number of noise vectors (rows), e.g. the batch size.
        n: dimensionality of each noise vector (columns).

    Returns:
        A numpy array of shape (m, n).
    """
    return np.random.uniform(-1., 1., size=[m, n])


def plot(samples):
    """Tile flattened 28x28 images into a square grid and return the figure.

    The grid side is the smallest integer whose square holds every sample,
    so 16 samples give a 4x4 grid on a 4x4-inch figure. The previous
    hard-coded 4x4 grid raised IndexError for more than 16 samples.

    Args:
        samples: sequence of images, each reshapeable to (28, 28)
            (e.g. rows of length 784).

    Returns:
        The matplotlib Figure; the caller is responsible for saving and
        closing it.
    """
    # At least a 1x1 grid so an empty batch still yields a valid (blank) figure.
    side = max(1, int(np.ceil(np.sqrt(len(samples)))))
    fig = plt.figure(figsize=(side, side))
    gs = gridspec.GridSpec(side, side)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

    return fig


# In[3]:

# Declare inputs to the model.

# Input images for the discriminator: flattened 28x28 images (784 values each).
X = tf.placeholder(tf.float32, shape=[None, 784])
# Input noise for the generator.
# Dimension of each noise vector; the placeholder and every sample_Z call use it.
Z_dim = 100
Z = tf.placeholder(tf.float32, shape=[None, Z_dim])


# In[8]:

def generator(z):
    """Map a batch of noise vectors to a batch of flattened 28x28 images.

    Architecture: dense(128, ReLU) -> dense(784) -> sigmoid, so every output
    value lies in (0, 1). Variables live under the "generator" scope and are
    reused across calls (AUTO_REUSE).

    Args:
        z: float tensor of noise vectors, shape (batch, Z_dim).

    Returns:
        Float tensor of shape (batch, 784).
    """
    with tf.variable_scope("generator", reuse=tf.AUTO_REUSE):
        x = tf.layers.dense(z, 128, activation=tf.nn.relu)
        # Feed the hidden activations (not z) into the output layer. Feeding z
        # here left the hidden layer disconnected from the output.
        x = tf.layers.dense(x, 784)
        x = tf.nn.sigmoid(x)

    return x


def discriminator(x):
    """Score each flattened image with the probability that it is real.

    Architecture: dense(128, ReLU) -> dense(1) -> sigmoid. Variables live
    under the "discriminator" scope and are reused (AUTO_REUSE), so the real
    and fake branches built below share the same weights.

    Args:
        x: float tensor of flattened images, shape (batch, 784).

    Returns:
        Float tensor of shape (batch, 1) with values in (0, 1).
    """
    with tf.variable_scope("discriminator", reuse=tf.AUTO_REUSE):
        x = tf.layers.dense(x, 128, activation=tf.nn.relu)
        x = tf.layers.dense(x, 1)
        x = tf.nn.sigmoid(x)

    return x


# In[10]:

G_sample = generator(Z)
D_real = discriminator(X)
D_fake = discriminator(G_sample)

# Added inside every log: float32 sigmoid outputs can saturate to exactly 0.0
# or 1.0, and log(0) would make the losses and their gradients inf/NaN.
_LOG_EPS = 1e-8

# The discriminator maximises log D(x) + log(1 - D(G(z))); the generator
# minimises -log D(G(z)).
D_loss = -tf.reduce_mean(tf.log(D_real + _LOG_EPS) + tf.log(1. - D_fake + _LOG_EPS))
G_loss = -tf.reduce_mean(tf.log(D_fake + _LOG_EPS))

# Optimizers: each network only updates the variables in its own scope.
disc_vars = [var for var in tf.trainable_variables() if var.name.startswith("discriminator/")]
gen_vars = [var for var in tf.trainable_variables() if var.name.startswith("generator/")]

D_solver = tf.train.AdamOptimizer().minimize(D_loss, var_list=disc_vars)
G_solver = tf.train.AdamOptimizer().minimize(G_loss, var_list=gen_vars)


# In[11]:

# Batch size.
mb_size = 128

mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True)

sess = tf.Session()
sess.run(tf.global_variables_initializer())

if not os.path.exists('out2/'):
    os.makedirs('out2/')

# Index of the next sample grid written to out2/ (000.png, 001.png, ...).
i = 0

for it in range(1000000):

    # Save a grid of 16 generated images every 1000 iterations.
    if it % 1000 == 0:
        samples = sess.run(G_sample, feed_dict={Z: sample_Z(16, Z_dim)})

        fig = plot(samples)
        fig.savefig('out2/{}.png'.format(str(i).zfill(3)), bbox_inches='tight')
        i += 1
        plt.close(fig)

    # Get the next batch of real images (labels are unused).
    X_mb, _ = mnist.train.next_batch(mb_size)

    # One discriminator step on the real batch plus a freshly sampled fake batch.
    _, D_loss_curr = sess.run([D_solver, D_loss], feed_dict={X: X_mb, Z: sample_Z(mb_size, Z_dim)})

    # One generator step with a new noise batch.
    _, G_loss_curr = sess.run([G_solver, G_loss], feed_dict={Z: sample_Z(mb_size, Z_dim)})

    # Report both losses every 1000 iterations.
    if it % 1000 == 0:
        print('Iter: {}'.format(it))
        print('D loss: {:.4}'.format(D_loss_curr))
        print('G loss: {:.4}'.format(G_loss_curr))