├── MNIST_cDCGAN_results
│   ├── MNIST_cDCGAN_30.png
│   ├── MNIST_cDCGAN_generation_animation.gif
│   └── MNIST_cDCGAN_train_hist.png
├── MNIST_cGAN_results
│   ├── MNIST_cGAN_100.png
│   ├── MNIST_cGAN_generation_animation.gif
│   ├── MNIST_cGAN_train_hist.png
│   └── raw_MNIST_10.png
├── README.md
├── tensorflow_MNIST_cDCGAN.py
├── tensorflow_MNIST_cGAN.py
└── tensorflow_cGAN.png

/MNIST_cDCGAN_results/MNIST_cDCGAN_30.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/tensorflow-MNIST-cGAN-cDCGAN/2558e30fcc05ba5002b2d6ff63849f8604e3ace3/MNIST_cDCGAN_results/MNIST_cDCGAN_30.png
--------------------------------------------------------------------------------
/MNIST_cDCGAN_results/MNIST_cDCGAN_generation_animation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/tensorflow-MNIST-cGAN-cDCGAN/2558e30fcc05ba5002b2d6ff63849f8604e3ace3/MNIST_cDCGAN_results/MNIST_cDCGAN_generation_animation.gif
--------------------------------------------------------------------------------
/MNIST_cDCGAN_results/MNIST_cDCGAN_train_hist.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/tensorflow-MNIST-cGAN-cDCGAN/2558e30fcc05ba5002b2d6ff63849f8604e3ace3/MNIST_cDCGAN_results/MNIST_cDCGAN_train_hist.png
--------------------------------------------------------------------------------
/MNIST_cGAN_results/MNIST_cGAN_100.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/tensorflow-MNIST-cGAN-cDCGAN/2558e30fcc05ba5002b2d6ff63849f8604e3ace3/MNIST_cGAN_results/MNIST_cGAN_100.png
--------------------------------------------------------------------------------
/MNIST_cGAN_results/MNIST_cGAN_generation_animation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/tensorflow-MNIST-cGAN-cDCGAN/2558e30fcc05ba5002b2d6ff63849f8604e3ace3/MNIST_cGAN_results/MNIST_cGAN_generation_animation.gif
--------------------------------------------------------------------------------
/MNIST_cGAN_results/MNIST_cGAN_train_hist.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/tensorflow-MNIST-cGAN-cDCGAN/2558e30fcc05ba5002b2d6ff63849f8604e3ace3/MNIST_cGAN_results/MNIST_cGAN_train_hist.png
--------------------------------------------------------------------------------
/MNIST_cGAN_results/raw_MNIST_10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/tensorflow-MNIST-cGAN-cDCGAN/2558e30fcc05ba5002b2d6ff63849f8604e3ace3/MNIST_cGAN_results/raw_MNIST_10.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # tensorflow-MNIST-cGAN-cDCGAN
2 | TensorFlow implementation of conditional Generative Adversarial Networks (cGAN) [1] and conditional Deep Convolutional Generative Adversarial Networks (cDCGAN) for the MNIST [2] dataset.
3 | 
4 | * You can download the dataset here:
5 |   - MNIST dataset: http://yann.lecun.com/exdb/mnist/
6 | 
7 | ## Implementation details
8 | * cGAN
9 | 
10 | ![GAN](tensorflow_cGAN.png)
11 | 
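Both scripts condition the networks by concatenation: the generator receives the noise vector joined with a one-hot label, and the discriminator receives the image joined with the same label (the `tf.concat([x, y], 1)` calls in `tensorflow_MNIST_cGAN.py`, and the corresponding `tf.concat(..., 3)` calls on `y_label` and `y_fill` in `tensorflow_MNIST_cDCGAN.py`). The NumPy sketch below only illustrates the resulting input shapes; the names `g_in`, `d_in`, `z4`, and `y4` are illustrative and do not appear in the scripts.

```python
import numpy as np

onehot = np.eye(10)                      # 10 MNIST classes
z = np.random.normal(0, 1, (1, 100))     # noise vector fed to G
y = onehot[[3]]                          # one-hot label for digit 3, shape (1, 10)

# cGAN (fully connected): concatenate along the feature axis
g_in = np.concatenate([z, y], axis=1)    # (1, 110) -> generator input
x = np.random.uniform(-1, 1, (1, 784))   # a flattened 28x28 image scaled to [-1, 1]
d_in = np.concatenate([x, y], axis=1)    # (1, 794) -> discriminator input

# cDCGAN (convolutional): the same idea along the channel axis
z4 = z.reshape(1, 1, 1, 100)             # noise as a 1x1 spatial map (the z placeholder)
y4 = y.reshape(1, 1, 1, 10)              # label map fed to G (the y_label placeholder)
y_fill = y4 * np.ones((1, 28, 28, 10))   # label tiled over the image, fed to D (the y_fill placeholder)
```

The cDCGAN's `y_fill` is simply the one-hot label broadcast over every pixel, so the discriminator sees the label at each spatial position.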
12 | ## Results
13 | * Generation using fixed noise (fixed_z_)
14 | 
15 | <table align='center'>
16 | <tr align='center'>
17 | <td> cGAN </td>
18 | <td> cDCGAN </td>
19 | </tr>
20 | <tr>
21 | <td><img src = 'MNIST_cGAN_results/MNIST_cGAN_generation_animation.gif'></td>
22 | <td><img src = 'MNIST_cDCGAN_results/MNIST_cDCGAN_generation_animation.gif'></td>
23 | </tr>
24 | </table>
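As built in both scripts, each row of the 10×10 grid above corresponds to one digit class (0 at the top through 9 at the bottom) and each column reuses the same fixed noise vector, so a column shows how a single noise vector is rendered under every label. The snippet below is a condensed, illustrative equivalent of the loop that builds `fixed_z_` and `fixed_y_` (the scripts use repeated `np.concatenate` calls, and the cDCGAN additionally reshapes both arrays to 4-D).

```python
import numpy as np

onehot = np.eye(10)
temp_z_ = np.random.normal(0, 1, (10, 100))      # ten fixed noise vectors (one per grid column)

fixed_z_ = np.tile(temp_z_, (10, 1))             # (100, 100): the same ten vectors repeated for every row block
fixed_y_ = onehot[np.repeat(np.arange(10), 10)]  # (100, 10): rows 0-9 labeled 0, rows 10-19 labeled 1, ...
```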
25 | 
26 | * MNIST vs. generated images
27 | 
28 | <table align='center'>
29 | <tr align='center'>
30 | <td> MNIST </td>
31 | <td> cGAN after 100 epochs </td>
32 | <td> cDCGAN after 30 epochs </td>
33 | </tr>
34 | <tr>
35 | <td><img src = 'MNIST_cGAN_results/raw_MNIST_10.png'></td>
36 | <td><img src = 'MNIST_cGAN_results/MNIST_cGAN_100.png'></td>
37 | <td><img src = 'MNIST_cDCGAN_results/MNIST_cDCGAN_30.png'></td>
38 | </tr>
39 | </table>
40 | 
41 | * Training loss
42 | 
43 | <table align='center'>
44 | <tr align='center'>
45 | <td> cGAN </td>
46 | <td> cDCGAN </td>
47 | </tr>
48 | <tr>
49 | <td><img src = 'MNIST_cGAN_results/MNIST_cGAN_train_hist.png'></td>
50 | <td><img src = 'MNIST_cDCGAN_results/MNIST_cDCGAN_train_hist.png'></td>
51 | </tr>
52 | </table>
53 | 54 | * Learning time 55 | * MNIST cGAN - Avg. per epoch: 3.21 sec; Total 100 epochs: 1800.37 sec 56 | * MNIST cDCGAN - Avg. per epoch: 53.07 sec; Total 30 epochs: 2072.29 sec 57 | 58 | ## Development Environment 59 | 60 | * Windows 7 61 | * GTX1080 ti 62 | * cuda 8.0 63 | * Python 3.5.3 64 | * tensorflow-gpu 1.2.1 65 | * numpy 1.13.1 66 | * matplotlib 2.0.2 67 | * imageio 2.2.0 68 | 69 | ## Reference 70 | 71 | [1] Mirza, Mehdi, and Simon Osindero. "Conditional generative adversarial nets." arXiv preprint arXiv:1411.1784 (2014). 72 | 73 | (Full paper: https://arxiv.org/pdf/1411.1784.pdf) 74 | 75 | [2] Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner. "Gradient-based learning applied to document recognition." Proceedings of the IEEE, 86(11):2278-2324, November 1998. 76 | -------------------------------------------------------------------------------- /tensorflow_MNIST_cDCGAN.py: -------------------------------------------------------------------------------- 1 | import os, time, itertools, imageio, pickle, random 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | import tensorflow as tf 5 | from tensorflow.examples.tutorials.mnist import input_data 6 | 7 | # leaky_relu 8 | def lrelu(X, leak=0.2): 9 | f1 = 0.5 * (1 + leak) 10 | f2 = 0.5 * (1 - leak) 11 | return f1 * X + f2 * tf.abs(X) 12 | 13 | # G(z) 14 | def generator(x, y_label, isTrain=True, reuse=False): 15 | with tf.variable_scope('generator', reuse=reuse): 16 | # initializer 17 | w_init = tf.truncated_normal_initializer(mean=0.0, stddev=0.02) 18 | b_init = tf.constant_initializer(0.0) 19 | 20 | # concat layer 21 | cat1 = tf.concat([x, y_label], 3) 22 | 23 | # 1st hidden layer 24 | deconv1 = tf.layers.conv2d_transpose(cat1, 256, [7, 7], strides=(1, 1), padding='valid', kernel_initializer=w_init, bias_initializer=b_init) 25 | lrelu1 = lrelu(tf.layers.batch_normalization(deconv1, training=isTrain), 0.2) 26 | 27 | # 2nd hidden layer 28 | deconv2 = tf.layers.conv2d_transpose(lrelu1, 128, [5, 5], strides=(2, 2), padding='same', kernel_initializer=w_init, bias_initializer=b_init) 29 | lrelu2 = lrelu(tf.layers.batch_normalization(deconv2, training=isTrain), 0.2) 30 | 31 | # output layer 32 | deconv3 = tf.layers.conv2d_transpose(lrelu2, 1, [5, 5], strides=(2, 2), padding='same', kernel_initializer=w_init, bias_initializer=b_init) 33 | o = tf.nn.tanh(deconv3) 34 | 35 | return o 36 | 37 | # D(x) 38 | def discriminator(x, y_fill, isTrain=True, reuse=False): 39 | with tf.variable_scope('discriminator', reuse=reuse): 40 | # initializer 41 | w_init = tf.truncated_normal_initializer(mean=0.0, stddev=0.02) 42 | b_init = tf.constant_initializer(0.0) 43 | 44 | # concat layer 45 | cat1 = tf.concat([x, y_fill], 3) 46 | 47 | # 1st hidden layer 48 | conv1 = tf.layers.conv2d(cat1, 128, [5, 5], strides=(2, 2), padding='same', kernel_initializer=w_init, bias_initializer=b_init) 49 | lrelu1 = lrelu(conv1, 0.2) 50 | 51 | # 2nd hidden layer 52 | conv2 = tf.layers.conv2d(lrelu1, 256, [5, 5], strides=(2, 2), padding='same', kernel_initializer=w_init, bias_initializer=b_init) 53 | lrelu2 = lrelu(tf.layers.batch_normalization(conv2, training=isTrain), 0.2) 54 | 55 | # output layer 56 | conv3 = tf.layers.conv2d(lrelu2, 1, [7, 7], strides=(1, 1), padding='valid', kernel_initializer=w_init) 57 | o = tf.nn.sigmoid(conv3) 58 | 59 | return o, conv3 60 | 61 | # preprocess 62 | img_size = 28 63 | onehot = np.eye(10) 64 | temp_z_ = np.random.normal(0, 1, (10, 1, 1, 100)) 65 | fixed_z_ = temp_z_ 66 | fixed_y_ = np.zeros((10, 1)) 67 | for i in range(9): 68 | 
fixed_z_ = np.concatenate([fixed_z_, temp_z_], 0) 69 | temp = np.ones((10, 1)) + i 70 | fixed_y_ = np.concatenate([fixed_y_, temp], 0) 71 | 72 | fixed_y_ = onehot[fixed_y_.astype(np.int32)].reshape((100, 1, 1, 10)) 73 | def show_result(num_epoch, show = False, save = False, path = 'result.png'): 74 | test_images = sess.run(G_z, {z: fixed_z_, y_label: fixed_y_, isTrain: False}) 75 | 76 | size_figure_grid = 10 77 | fig, ax = plt.subplots(size_figure_grid, size_figure_grid, figsize=(5, 5)) 78 | for i, j in itertools.product(range(size_figure_grid), range(size_figure_grid)): 79 | ax[i, j].get_xaxis().set_visible(False) 80 | ax[i, j].get_yaxis().set_visible(False) 81 | 82 | for k in range(10*10): 83 | i = k // 10 84 | j = k % 10 85 | ax[i, j].cla() 86 | ax[i, j].imshow(np.reshape(test_images[k], (img_size, img_size)), cmap='gray') 87 | 88 | label = 'Epoch {0}'.format(num_epoch) 89 | fig.text(0.5, 0.04, label, ha='center') 90 | 91 | if save: 92 | plt.savefig(path) 93 | 94 | if show: 95 | plt.show() 96 | else: 97 | plt.close() 98 | 99 | def show_train_hist(hist, show = False, save = False, path = 'Train_hist.png'): 100 | x = range(len(hist['D_losses'])) 101 | 102 | y1 = hist['D_losses'] 103 | y2 = hist['G_losses'] 104 | 105 | plt.plot(x, y1, label='D_loss') 106 | plt.plot(x, y2, label='G_loss') 107 | 108 | plt.xlabel('Epoch') 109 | plt.ylabel('Loss') 110 | 111 | plt.legend(loc=4) 112 | plt.grid(True) 113 | plt.tight_layout() 114 | 115 | if save: 116 | plt.savefig(path) 117 | 118 | if show: 119 | plt.show() 120 | else: 121 | plt.close() 122 | 123 | # training parameters 124 | batch_size = 100 125 | # lr = 0.0002 126 | train_epoch = 30 127 | global_step = tf.Variable(0, trainable=False) 128 | lr = tf.train.exponential_decay(0.0002, global_step, 500, 0.95, staircase=True) 129 | # load MNIST 130 | mnist = input_data.read_data_sets("MNIST_data/", one_hot=True, reshape=[]) 131 | 132 | # variables : input 133 | x = tf.placeholder(tf.float32, shape=(None, img_size, img_size, 1)) 134 | z = tf.placeholder(tf.float32, shape=(None, 1, 1, 100)) 135 | y_label = tf.placeholder(tf.float32, shape=(None, 1, 1, 10)) 136 | y_fill = tf.placeholder(tf.float32, shape=(None, img_size, img_size, 10)) 137 | isTrain = tf.placeholder(dtype=tf.bool) 138 | 139 | # networks : generator 140 | G_z = generator(z, y_label, isTrain) 141 | 142 | # networks : discriminator 143 | D_real, D_real_logits = discriminator(x, y_fill, isTrain) 144 | D_fake, D_fake_logits = discriminator(G_z, y_fill, isTrain, reuse=True) 145 | 146 | # loss for each network 147 | D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_real_logits, labels=tf.ones([batch_size, 1, 1, 1]))) 148 | D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_fake_logits, labels=tf.zeros([batch_size, 1, 1, 1]))) 149 | D_loss = D_loss_real + D_loss_fake 150 | G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_fake_logits, labels=tf.ones([batch_size, 1, 1, 1]))) 151 | 152 | # trainable variables for each network 153 | T_vars = tf.trainable_variables() 154 | D_vars = [var for var in T_vars if var.name.startswith('discriminator')] 155 | G_vars = [var for var in T_vars if var.name.startswith('generator')] 156 | 157 | # optimizer for each network 158 | 159 | with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): 160 | optim = tf.train.AdamOptimizer(lr, beta1=0.5) 161 | D_optim = optim.minimize(D_loss, global_step=global_step, var_list=D_vars) 162 | # D_optim = tf.train.AdamOptimizer(lr, 
beta1=0.5).minimize(D_loss, var_list=D_vars) 163 | G_optim = tf.train.AdamOptimizer(lr, beta1=0.5).minimize(G_loss, var_list=G_vars) 164 | 165 | # open session and initialize all variables 166 | sess = tf.InteractiveSession() 167 | tf.global_variables_initializer().run() 168 | 169 | # MNIST resize and normalization 170 | # train_set = tf.image.resize_images(mnist.train.images, [img_size, img_size]).eval() 171 | # train_set = (train_set - 0.5) / 0.5 # normalization; range: -1 ~ 1 172 | train_set = (mnist.train.images - 0.5) / 0.5 173 | train_label = mnist.train.labels 174 | 175 | # results save folder 176 | root = 'MNIST_cDCGAN_results/' 177 | model = 'MNIST_cDCGAN_' 178 | if not os.path.isdir(root): 179 | os.mkdir(root) 180 | if not os.path.isdir(root + 'Fixed_results'): 181 | os.mkdir(root + 'Fixed_results') 182 | 183 | train_hist = {} 184 | train_hist['D_losses'] = [] 185 | train_hist['G_losses'] = [] 186 | train_hist['per_epoch_ptimes'] = [] 187 | train_hist['total_ptime'] = [] 188 | 189 | # training-loop 190 | np.random.seed(int(time.time())) 191 | print('training start!') 192 | start_time = time.time() 193 | for epoch in range(train_epoch): 194 | G_losses = [] 195 | D_losses = [] 196 | epoch_start_time = time.time() 197 | shuffle_idxs = random.sample(range(0, train_set.shape[0]), train_set.shape[0]) 198 | shuffled_set = train_set[shuffle_idxs] 199 | shuffled_label = train_label[shuffle_idxs] 200 | for iter in range(shuffled_set.shape[0] // batch_size): 201 | # update discriminator 202 | x_ = shuffled_set[iter*batch_size:(iter+1)*batch_size] 203 | y_label_ = shuffled_label[iter*batch_size:(iter+1)*batch_size].reshape([batch_size, 1, 1, 10]) 204 | y_fill_ = y_label_ * np.ones([batch_size, img_size, img_size, 10]) 205 | z_ = np.random.normal(0, 1, (batch_size, 1, 1, 100)) 206 | 207 | loss_d_, _ = sess.run([D_loss, D_optim], {x: x_, z: z_, y_fill: y_fill_, y_label: y_label_, isTrain: True}) 208 | 209 | # update generator 210 | z_ = np.random.normal(0, 1, (batch_size, 1, 1, 100)) 211 | y_ = np.random.randint(0, 9, (batch_size, 1)) 212 | y_label_ = onehot[y_.astype(np.int32)].reshape([batch_size, 1, 1, 10]) 213 | y_fill_ = y_label_ * np.ones([batch_size, img_size, img_size, 10]) 214 | loss_g_, _ = sess.run([G_loss, G_optim], {z: z_, x: x_, y_fill: y_fill_, y_label: y_label_, isTrain: True}) 215 | 216 | errD_fake = D_loss_fake.eval({z: z_, y_label: y_label_, y_fill: y_fill_, isTrain: False}) 217 | errD_real = D_loss_real.eval({x: x_, y_label: y_label_, y_fill: y_fill_, isTrain: False}) 218 | errG = G_loss.eval({z: z_, y_label: y_label_, y_fill: y_fill_, isTrain: False}) 219 | 220 | D_losses.append(errD_fake + errD_real) 221 | G_losses.append(errG) 222 | 223 | epoch_end_time = time.time() 224 | per_epoch_ptime = epoch_end_time - epoch_start_time 225 | print('[%d/%d] - ptime: %.2f loss_d: %.3f, loss_g: %.3f' % ((epoch + 1), train_epoch, per_epoch_ptime, np.mean(D_losses), np.mean(G_losses))) 226 | fixed_p = root + 'Fixed_results/' + model + str(epoch + 1) + '.png' 227 | show_result((epoch + 1), save=True, path=fixed_p) 228 | train_hist['D_losses'].append(np.mean(D_losses)) 229 | train_hist['G_losses'].append(np.mean(G_losses)) 230 | train_hist['per_epoch_ptimes'].append(per_epoch_ptime) 231 | 232 | end_time = time.time() 233 | total_ptime = end_time - start_time 234 | train_hist['total_ptime'].append(total_ptime) 235 | 236 | print('Avg per epoch ptime: %.2f, total %d epochs ptime: %.2f' % (np.mean(train_hist['per_epoch_ptimes']), train_epoch, total_ptime)) 237 | print("Training finish!... 
save training results") 238 | with open(root + model + 'train_hist.pkl', 'wb') as f: 239 | pickle.dump(train_hist, f) 240 | 241 | show_train_hist(train_hist, save=True, path=root + model + 'train_hist.png') 242 | 243 | images = [] 244 | for e in range(train_epoch): 245 | img_name = root + 'Fixed_results/' + model + str(e + 1) + '.png' 246 | images.append(imageio.imread(img_name)) 247 | imageio.mimsave(root + model + 'generation_animation.gif', images, fps=5) 248 | 249 | sess.close() -------------------------------------------------------------------------------- /tensorflow_MNIST_cGAN.py: -------------------------------------------------------------------------------- 1 | import os, time, itertools, imageio, pickle 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | import tensorflow as tf 5 | from tensorflow.examples.tutorials.mnist import input_data 6 | 7 | # leaky_relu 8 | def lrelu(X, leak=0.2): 9 | f1 = 0.5 * (1 + leak) 10 | f2 = 0.5 * (1 - leak) 11 | return f1 * X + f2 * tf.abs(X) 12 | 13 | # G(z) 14 | def generator(x, y, isTrain=True, reuse=False): 15 | with tf.variable_scope('generator', reuse=reuse): 16 | w_init = tf.contrib.layers.xavier_initializer() 17 | 18 | cat1 = tf.concat([x, y], 1) 19 | 20 | dense1 = tf.layers.dense(cat1, 128, kernel_initializer=w_init) 21 | relu1 = tf.nn.relu(dense1) 22 | 23 | dense2 = tf.layers.dense(relu1, 784, kernel_initializer=w_init) 24 | o = tf.nn.tanh(dense2) 25 | 26 | return o 27 | 28 | # D(x) 29 | def discriminator(x, y, isTrain=True, reuse=False): 30 | with tf.variable_scope('discriminator', reuse=reuse): 31 | w_init = tf.contrib.layers.xavier_initializer() 32 | 33 | cat1 = tf.concat([x, y], 1) 34 | 35 | dense1 = tf.layers.dense(cat1, 128, kernel_initializer=w_init) 36 | lrelu1 = lrelu(dense1, 0.2) 37 | 38 | dense2 = tf.layers.dense(lrelu1, 1, kernel_initializer=w_init) 39 | o = tf.nn.sigmoid(dense2) 40 | 41 | return o, dense2 42 | 43 | # label preprocess 44 | onehot = np.eye(10) 45 | 46 | temp_z_ = np.random.normal(0, 1, (10, 100)) 47 | fixed_z_ = temp_z_ 48 | fixed_y_ = np.zeros((10, 1)) 49 | 50 | for i in range(9): 51 | fixed_z_ = np.concatenate([fixed_z_, temp_z_], 0) 52 | temp = np.ones((10,1)) + i 53 | fixed_y_ = np.concatenate([fixed_y_, temp], 0) 54 | 55 | fixed_y_ = onehot[fixed_y_.astype(np.int32)].squeeze() 56 | def show_result(num_epoch, show = False, save = False, path = 'result.png'): 57 | test_images = sess.run(G_z, {z: fixed_z_, y: fixed_y_, isTrain: False}) 58 | 59 | size_figure_grid = 10 60 | fig, ax = plt.subplots(size_figure_grid, size_figure_grid, figsize=(5, 5)) 61 | for i, j in itertools.product(range(size_figure_grid), range(size_figure_grid)): 62 | ax[i, j].get_xaxis().set_visible(False) 63 | ax[i, j].get_yaxis().set_visible(False) 64 | 65 | for k in range(size_figure_grid*size_figure_grid): 66 | i = k // size_figure_grid 67 | j = k % size_figure_grid 68 | ax[i, j].cla() 69 | ax[i, j].imshow(np.reshape(test_images[k], (28, 28)), cmap='gray') 70 | 71 | label = 'Epoch {0}'.format(num_epoch) 72 | fig.text(0.5, 0.04, label, ha='center') 73 | 74 | if save: 75 | plt.savefig(path) 76 | 77 | if show: 78 | plt.show() 79 | else: 80 | plt.close() 81 | 82 | def show_train_hist(hist, show = False, save = False, path = 'Train_hist.png'): 83 | x = range(len(hist['D_losses'])) 84 | 85 | y1 = hist['D_losses'] 86 | y2 = hist['G_losses'] 87 | 88 | plt.plot(x, y1, label='D_loss') 89 | plt.plot(x, y2, label='G_loss') 90 | 91 | plt.xlabel('Epoch') 92 | plt.ylabel('Loss') 93 | 94 | plt.legend(loc=4) 95 | plt.grid(True) 96 | 
plt.tight_layout() 97 | 98 | if save: 99 | plt.savefig(path) 100 | 101 | if show: 102 | plt.show() 103 | else: 104 | plt.close() 105 | 106 | # training parameters 107 | batch_size = 100 108 | lr = 0.0002 109 | train_epoch = 100 110 | 111 | # load MNIST 112 | mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) 113 | train_set = (mnist.train.images - 0.5) / 0.5 # normalization; range: -1 ~ 1 114 | train_label = mnist.train.labels 115 | 116 | # variables : input 117 | x = tf.placeholder(tf.float32, shape=(None, 784)) 118 | y = tf.placeholder(tf.float32, shape=(None, 10)) 119 | z = tf.placeholder(tf.float32, shape=(None, 100)) 120 | isTrain = tf.placeholder(dtype=tf.bool) 121 | 122 | # networks : generator 123 | G_z = generator(z, y, isTrain) 124 | 125 | # networks : discriminator 126 | D_real, D_real_logits = discriminator(x, y, isTrain) 127 | D_fake, D_fake_logits = discriminator(G_z, y, isTrain, reuse=True) 128 | 129 | # loss for each network 130 | D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_real_logits, labels=tf.ones([batch_size, 1]))) 131 | D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_fake_logits, labels=tf.zeros([batch_size, 1]))) 132 | D_loss = D_loss_real + D_loss_fake 133 | G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_fake_logits, labels=tf.ones([batch_size, 1]))) 134 | 135 | # trainable variables for each network 136 | T_vars = tf.trainable_variables() 137 | D_vars = [var for var in T_vars if var.name.startswith('discriminator')] 138 | G_vars = [var for var in T_vars if var.name.startswith('generator')] 139 | 140 | # optimizer for each network 141 | with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): 142 | D_optim = tf.train.AdamOptimizer(lr, beta1=0.5).minimize(D_loss, var_list=D_vars) 143 | G_optim = tf.train.AdamOptimizer(lr, beta1=0.5).minimize(G_loss, var_list=G_vars) 144 | 145 | # open session and initialize all variables 146 | sess = tf.InteractiveSession() 147 | tf.global_variables_initializer().run() 148 | 149 | # results save folder 150 | root = 'MNIST_cGAN_results/' 151 | model = 'MNIST_cGAN_' 152 | if not os.path.isdir(root): 153 | os.mkdir(root) 154 | if not os.path.isdir(root + 'Fixed_results'): 155 | os.mkdir(root + 'Fixed_results') 156 | 157 | train_hist = {} 158 | train_hist['D_losses'] = [] 159 | train_hist['G_losses'] = [] 160 | train_hist['per_epoch_ptimes'] = [] 161 | train_hist['total_ptime'] = [] 162 | 163 | # training-loop 164 | np.random.seed(int(time.time())) 165 | print('training start!') 166 | start_time = time.time() 167 | for epoch in range(train_epoch): 168 | G_losses = [] 169 | D_losses = [] 170 | epoch_start_time = time.time() 171 | for iter in range(len(train_set) // batch_size): 172 | # update discriminator 173 | x_ = train_set[iter * batch_size:(iter + 1) * batch_size] 174 | y_ = train_label[iter * batch_size:(iter + 1) * batch_size] 175 | 176 | z_ = np.random.normal(0, 1, (batch_size, 100)) 177 | 178 | loss_d_, _ = sess.run([D_loss, D_optim], {x: x_, y: y_, z: z_, isTrain: True}) 179 | D_losses.append(loss_d_) 180 | 181 | # update generator 182 | z_ = np.random.normal(0, 1, (batch_size, 100)) 183 | y_ = np.random.randint(0, 9, (batch_size, 1)) 184 | y_ = onehot[y_.astype(np.int32)].squeeze() 185 | loss_g_, _ = sess.run([G_loss, G_optim], {z: z_, x: x_, y: y_, isTrain: True}) 186 | G_losses.append(loss_g_) 187 | 188 | epoch_end_time = time.time() 189 | per_epoch_ptime = epoch_end_time - epoch_start_time 190 | 
print('[%d/%d] - ptime: %.2f loss_d: %.3f, loss_g: %.3f' % ((epoch + 1), train_epoch, per_epoch_ptime, np.mean(D_losses), np.mean(G_losses))) 191 | fixed_p = root + 'Fixed_results/' + model + str(epoch + 1) + '.png' 192 | show_result((epoch + 1), save=True, path=fixed_p) 193 | train_hist['D_losses'].append(np.mean(D_losses)) 194 | train_hist['G_losses'].append(np.mean(G_losses)) 195 | train_hist['per_epoch_ptimes'].append(per_epoch_ptime) 196 | 197 | end_time = time.time() 198 | total_ptime = end_time - start_time 199 | train_hist['total_ptime'].append(total_ptime) 200 | 201 | print('Avg per epoch ptime: %.2f, total %d epochs ptime: %.2f' % (np.mean(train_hist['per_epoch_ptimes']), train_epoch, total_ptime)) 202 | print("Training finish!... save training results") 203 | with open(root + model + 'train_hist.pkl', 'wb') as f: 204 | pickle.dump(train_hist, f) 205 | 206 | show_train_hist(train_hist, save=True, path=root + model + 'train_hist.png') 207 | 208 | images = [] 209 | for e in range(train_epoch): 210 | img_name = root + 'Fixed_results/' + model + str(e + 1) + '.png' 211 | images.append(imageio.imread(img_name)) 212 | imageio.mimsave(root + model + 'generation_animation.gif', images, fps=5) 213 | 214 | sess.close() -------------------------------------------------------------------------------- /tensorflow_cGAN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/tensorflow-MNIST-cGAN-cDCGAN/2558e30fcc05ba5002b2d6ff63849f8604e3ace3/tensorflow_cGAN.png --------------------------------------------------------------------------------