├── MNIST_cDCGAN_results
│   ├── MNIST_cDCGAN_30.png
│   ├── MNIST_cDCGAN_generation_animation.gif
│   └── MNIST_cDCGAN_train_hist.png
├── MNIST_cGAN_results
│   ├── MNIST_cGAN_100.png
│   ├── MNIST_cGAN_generation_animation.gif
│   ├── MNIST_cGAN_train_hist.png
│   └── raw_MNIST_10.png
├── README.md
├── tensorflow_MNIST_cDCGAN.py
├── tensorflow_MNIST_cGAN.py
└── tensorflow_cGAN.png
/MNIST_cDCGAN_results/MNIST_cDCGAN_30.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/tensorflow-MNIST-cGAN-cDCGAN/2558e30fcc05ba5002b2d6ff63849f8604e3ace3/MNIST_cDCGAN_results/MNIST_cDCGAN_30.png
--------------------------------------------------------------------------------
/MNIST_cDCGAN_results/MNIST_cDCGAN_generation_animation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/tensorflow-MNIST-cGAN-cDCGAN/2558e30fcc05ba5002b2d6ff63849f8604e3ace3/MNIST_cDCGAN_results/MNIST_cDCGAN_generation_animation.gif
--------------------------------------------------------------------------------
/MNIST_cDCGAN_results/MNIST_cDCGAN_train_hist.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/tensorflow-MNIST-cGAN-cDCGAN/2558e30fcc05ba5002b2d6ff63849f8604e3ace3/MNIST_cDCGAN_results/MNIST_cDCGAN_train_hist.png
--------------------------------------------------------------------------------
/MNIST_cGAN_results/MNIST_cGAN_100.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/tensorflow-MNIST-cGAN-cDCGAN/2558e30fcc05ba5002b2d6ff63849f8604e3ace3/MNIST_cGAN_results/MNIST_cGAN_100.png
--------------------------------------------------------------------------------
/MNIST_cGAN_results/MNIST_cGAN_generation_animation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/tensorflow-MNIST-cGAN-cDCGAN/2558e30fcc05ba5002b2d6ff63849f8604e3ace3/MNIST_cGAN_results/MNIST_cGAN_generation_animation.gif
--------------------------------------------------------------------------------
/MNIST_cGAN_results/MNIST_cGAN_train_hist.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/tensorflow-MNIST-cGAN-cDCGAN/2558e30fcc05ba5002b2d6ff63849f8604e3ace3/MNIST_cGAN_results/MNIST_cGAN_train_hist.png
--------------------------------------------------------------------------------
/MNIST_cGAN_results/raw_MNIST_10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/tensorflow-MNIST-cGAN-cDCGAN/2558e30fcc05ba5002b2d6ff63849f8604e3ace3/MNIST_cGAN_results/raw_MNIST_10.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # tensorflow-MNIST-cGAN-cDCGAN
2 | Tensorflow implementation of conditional Generative Adversarial Networks (cGAN) [1] and conditional Deep Convolutional Generative Adversarial Networks (cDCGAN) for the MNIST [2] dataset.
3 |
4 | * You can download the dataset here:
5 |   - MNIST dataset: http://yann.lecun.com/exdb/mnist/
6 |
7 | ## Implementation details
8 | * cGAN
9 |
10 | ![cGAN network architecture](tensorflow_cGAN.png)
11 |
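* Conditioning (sketch): both networks receive the digit label. In the cGAN the one-hot label y is concatenated with the noise z at the generator input and with the flattened image at the discriminator input; in the cDCGAN the label is instead appended as extra channels (a 1x1x10 map for G, a tiled 28x28x10 `y_fill` map for D). Below is a minimal TF 1.x sketch of the cGAN case; variable names (`g_in`, `d_in`, `D_logit`) are illustrative only, not the ones used in the scripts.

```python
import tensorflow as tf

z = tf.placeholder(tf.float32, (None, 100))   # noise vector
y = tf.placeholder(tf.float32, (None, 10))    # one-hot digit label
x = tf.placeholder(tf.float32, (None, 784))   # flattened image scaled to [-1, 1]

with tf.variable_scope('generator'):
    g_in = tf.concat([z, y], axis=1)              # condition G on the label
    g_h = tf.nn.relu(tf.layers.dense(g_in, 128))
    G_z = tf.nn.tanh(tf.layers.dense(g_h, 784))   # fake image

with tf.variable_scope('discriminator'):
    d_in = tf.concat([x, y], axis=1)              # condition D on the same label
    d_h = tf.layers.dense(d_in, 128)
    d_h = tf.maximum(0.2 * d_h, d_h)              # leaky ReLU (TF 1.2 has no tf.nn.leaky_relu)
    D_logit = tf.layers.dense(d_h, 1)             # real/fake score
```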
12 | ## Results
13 | * Generate using fixed noise (fixed_z_)
14 |
15 | cGAN | cDCGAN
16 | :---: | :---:
17 | ![cGAN](MNIST_cGAN_results/MNIST_cGAN_generation_animation.gif) | ![cDCGAN](MNIST_cDCGAN_results/MNIST_cDCGAN_generation_animation.gif)
25 |
26 | * MNIST vs Generated images
27 |
28 | MNIST | cGAN after 100 epochs | cDCGAN after 30 epochs
29 | :---: | :---: | :---:
30 | ![MNIST](MNIST_cGAN_results/raw_MNIST_10.png) | ![cGAN](MNIST_cGAN_results/MNIST_cGAN_100.png) | ![cDCGAN](MNIST_cDCGAN_results/MNIST_cDCGAN_30.png)
40 |
41 | * Training loss
42 |
43 | cGAN | cDCGAN
44 | :---: | :---:
45 | ![cGAN loss](MNIST_cGAN_results/MNIST_cGAN_train_hist.png) | ![cDCGAN loss](MNIST_cDCGAN_results/MNIST_cDCGAN_train_hist.png)
53 |
54 | * Learning time
55 | * MNIST cGAN - Avg. per epoch: 3.21 sec; Total 100 epochs: 1800.37 sec
56 | * MNIST cDCGAN - Avg. per epoch: 53.07 sec; Total 30 epochs: 2072.29 sec
57 |
58 | ## Development Environment
59 |
60 | * Windows 7
61 | * GTX 1080 Ti
62 | * CUDA 8.0
63 | * Python 3.5.3
64 | * tensorflow-gpu 1.2.1
65 | * numpy 1.13.1
66 | * matplotlib 2.0.2
67 | * imageio 2.2.0
68 |
69 | ## Reference
70 |
71 | [1] Mirza, Mehdi, and Simon Osindero. "Conditional generative adversarial nets." arXiv preprint arXiv:1411.1784 (2014).
72 |
73 | (Full paper: https://arxiv.org/pdf/1411.1784.pdf)
74 |
75 | [2] Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner. "Gradient-based learning applied to document recognition." Proceedings of the IEEE, 86(11):2278-2324, November 1998.
76 |
--------------------------------------------------------------------------------
/tensorflow_MNIST_cDCGAN.py:
--------------------------------------------------------------------------------
1 | import os, time, itertools, imageio, pickle, random
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | import tensorflow as tf
5 | from tensorflow.examples.tutorials.mnist import input_data
6 |
7 | # leaky_relu
8 | def lrelu(X, leak=0.2):
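# computes max(X, leak*X) algebraically: for X >= 0, (f1 + f2)*X = X; for X < 0, (f1 - f2)*X = leak*X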
9 | f1 = 0.5 * (1 + leak)
10 | f2 = 0.5 * (1 - leak)
11 | return f1 * X + f2 * tf.abs(X)
12 |
13 | # G(z)
14 | def generator(x, y_label, isTrain=True, reuse=False):
15 | with tf.variable_scope('generator', reuse=reuse):
16 | # initializer
17 | w_init = tf.truncated_normal_initializer(mean=0.0, stddev=0.02)
18 | b_init = tf.constant_initializer(0.0)
19 |
20 | # concat layer
21 | cat1 = tf.concat([x, y_label], 3)
22 |
23 | # 1st hidden layer
24 | deconv1 = tf.layers.conv2d_transpose(cat1, 256, [7, 7], strides=(1, 1), padding='valid', kernel_initializer=w_init, bias_initializer=b_init)
25 | lrelu1 = lrelu(tf.layers.batch_normalization(deconv1, training=isTrain), 0.2)
26 |
27 | # 2nd hidden layer
28 | deconv2 = tf.layers.conv2d_transpose(lrelu1, 128, [5, 5], strides=(2, 2), padding='same', kernel_initializer=w_init, bias_initializer=b_init)
29 | lrelu2 = lrelu(tf.layers.batch_normalization(deconv2, training=isTrain), 0.2)
30 |
31 | # output layer
32 | deconv3 = tf.layers.conv2d_transpose(lrelu2, 1, [5, 5], strides=(2, 2), padding='same', kernel_initializer=w_init, bias_initializer=b_init)
33 | o = tf.nn.tanh(deconv3)
34 |
35 | return o
36 |
37 | # D(x)
38 | def discriminator(x, y_fill, isTrain=True, reuse=False):
39 | with tf.variable_scope('discriminator', reuse=reuse):
40 | # initializer
41 | w_init = tf.truncated_normal_initializer(mean=0.0, stddev=0.02)
42 | b_init = tf.constant_initializer(0.0)
43 |
44 | # concat layer
45 | cat1 = tf.concat([x, y_fill], 3)
46 |
47 | # 1st hidden layer
48 | conv1 = tf.layers.conv2d(cat1, 128, [5, 5], strides=(2, 2), padding='same', kernel_initializer=w_init, bias_initializer=b_init)
49 | lrelu1 = lrelu(conv1, 0.2)
50 |
51 | # 2nd hidden layer
52 | conv2 = tf.layers.conv2d(lrelu1, 256, [5, 5], strides=(2, 2), padding='same', kernel_initializer=w_init, bias_initializer=b_init)
53 | lrelu2 = lrelu(tf.layers.batch_normalization(conv2, training=isTrain), 0.2)
54 |
55 | # output layer
56 | conv3 = tf.layers.conv2d(lrelu2, 1, [7, 7], strides=(1, 1), padding='valid', kernel_initializer=w_init)
57 | o = tf.nn.sigmoid(conv3)
58 |
59 | return o, conv3
60 |
61 | # preprocess
62 | img_size = 28
63 | onehot = np.eye(10)
64 | temp_z_ = np.random.normal(0, 1, (10, 1, 1, 100))
65 | fixed_z_ = temp_z_
66 | fixed_y_ = np.zeros((10, 1))
67 | for i in range(9):
68 | fixed_z_ = np.concatenate([fixed_z_, temp_z_], 0)
69 | temp = np.ones((10, 1)) + i
70 | fixed_y_ = np.concatenate([fixed_y_, temp], 0)
71 |
72 | fixed_y_ = onehot[fixed_y_.astype(np.int32)].reshape((100, 1, 1, 10))
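# fixed_z_ / fixed_y_ define the 10x10 snapshot grid used by show_result():
# row i is digit i, and every row reuses the same ten noise vectors temp_z_,
# so each column shows how one noise vector renders under all ten labels.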
73 | def show_result(num_epoch, show = False, save = False, path = 'result.png'):
74 | test_images = sess.run(G_z, {z: fixed_z_, y_label: fixed_y_, isTrain: False})
75 |
76 | size_figure_grid = 10
77 | fig, ax = plt.subplots(size_figure_grid, size_figure_grid, figsize=(5, 5))
78 | for i, j in itertools.product(range(size_figure_grid), range(size_figure_grid)):
79 | ax[i, j].get_xaxis().set_visible(False)
80 | ax[i, j].get_yaxis().set_visible(False)
81 |
82 | for k in range(10*10):
83 | i = k // 10
84 | j = k % 10
85 | ax[i, j].cla()
86 | ax[i, j].imshow(np.reshape(test_images[k], (img_size, img_size)), cmap='gray')
87 |
88 | label = 'Epoch {0}'.format(num_epoch)
89 | fig.text(0.5, 0.04, label, ha='center')
90 |
91 | if save:
92 | plt.savefig(path)
93 |
94 | if show:
95 | plt.show()
96 | else:
97 | plt.close()
98 |
99 | def show_train_hist(hist, show = False, save = False, path = 'Train_hist.png'):
100 | x = range(len(hist['D_losses']))
101 |
102 | y1 = hist['D_losses']
103 | y2 = hist['G_losses']
104 |
105 | plt.plot(x, y1, label='D_loss')
106 | plt.plot(x, y2, label='G_loss')
107 |
108 | plt.xlabel('Epoch')
109 | plt.ylabel('Loss')
110 |
111 | plt.legend(loc=4)
112 | plt.grid(True)
113 | plt.tight_layout()
114 |
115 | if save:
116 | plt.savefig(path)
117 |
118 | if show:
119 | plt.show()
120 | else:
121 | plt.close()
122 |
123 | # training parameters
124 | batch_size = 100
125 | # lr = 0.0002
126 | train_epoch = 30
127 | global_step = tf.Variable(0, trainable=False)
128 | lr = tf.train.exponential_decay(0.0002, global_step, 500, 0.95, staircase=True)
129 | # load MNIST
130 | mnist = input_data.read_data_sets("MNIST_data/", one_hot=True, reshape=False)  # keep images as (N, 28, 28, 1)
131 |
132 | # variables : input
133 | x = tf.placeholder(tf.float32, shape=(None, img_size, img_size, 1))
134 | z = tf.placeholder(tf.float32, shape=(None, 1, 1, 100))
135 | y_label = tf.placeholder(tf.float32, shape=(None, 1, 1, 10))
136 | y_fill = tf.placeholder(tf.float32, shape=(None, img_size, img_size, 10))
137 | isTrain = tf.placeholder(dtype=tf.bool)
138 |
139 | # networks : generator
140 | G_z = generator(z, y_label, isTrain)
141 |
142 | # networks : discriminator
143 | D_real, D_real_logits = discriminator(x, y_fill, isTrain)
144 | D_fake, D_fake_logits = discriminator(G_z, y_fill, isTrain, reuse=True)
145 |
146 | # loss for each network
147 | D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_real_logits, labels=tf.ones([batch_size, 1, 1, 1])))
148 | D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_fake_logits, labels=tf.zeros([batch_size, 1, 1, 1])))
149 | D_loss = D_loss_real + D_loss_fake
150 | G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_fake_logits, labels=tf.ones([batch_size, 1, 1, 1])))
151 |
152 | # trainable variables for each network
153 | T_vars = tf.trainable_variables()
154 | D_vars = [var for var in T_vars if var.name.startswith('discriminator')]
155 | G_vars = [var for var in T_vars if var.name.startswith('generator')]
156 |
157 | # optimizer for each network
158 |
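# Running the optimizers under a dependency on UPDATE_OPS ensures the batch-norm
# moving mean/variance are refreshed each step; D_optim also increments global_step,
# which drives the staircase learning-rate decay (x0.95 every 500 steps) defined above.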
159 | with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
160 | optim = tf.train.AdamOptimizer(lr, beta1=0.5)
161 | D_optim = optim.minimize(D_loss, global_step=global_step, var_list=D_vars)
162 | # D_optim = tf.train.AdamOptimizer(lr, beta1=0.5).minimize(D_loss, var_list=D_vars)
163 | G_optim = tf.train.AdamOptimizer(lr, beta1=0.5).minimize(G_loss, var_list=G_vars)
164 |
165 | # open session and initialize all variables
166 | sess = tf.InteractiveSession()
167 | tf.global_variables_initializer().run()
168 |
169 | # MNIST resize and normalization
170 | # train_set = tf.image.resize_images(mnist.train.images, [img_size, img_size]).eval()
171 | # train_set = (train_set - 0.5) / 0.5 # normalization; range: -1 ~ 1
172 | train_set = (mnist.train.images - 0.5) / 0.5
173 | train_label = mnist.train.labels
174 |
175 | # results save folder
176 | root = 'MNIST_cDCGAN_results/'
177 | model = 'MNIST_cDCGAN_'
178 | if not os.path.isdir(root):
179 | os.mkdir(root)
180 | if not os.path.isdir(root + 'Fixed_results'):
181 | os.mkdir(root + 'Fixed_results')
182 |
183 | train_hist = {}
184 | train_hist['D_losses'] = []
185 | train_hist['G_losses'] = []
186 | train_hist['per_epoch_ptimes'] = []
187 | train_hist['total_ptime'] = []
188 |
189 | # training-loop
190 | np.random.seed(int(time.time()))
191 | print('training start!')
192 | start_time = time.time()
193 | for epoch in range(train_epoch):
194 | G_losses = []
195 | D_losses = []
196 | epoch_start_time = time.time()
197 | shuffle_idxs = random.sample(range(0, train_set.shape[0]), train_set.shape[0])
198 | shuffled_set = train_set[shuffle_idxs]
199 | shuffled_label = train_label[shuffle_idxs]
200 | for iter in range(shuffled_set.shape[0] // batch_size):
201 | # update discriminator
202 | x_ = shuffled_set[iter*batch_size:(iter+1)*batch_size]
203 | y_label_ = shuffled_label[iter*batch_size:(iter+1)*batch_size].reshape([batch_size, 1, 1, 10])
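# broadcast the one-hot label from (batch, 1, 1, 10) to (batch, img_size, img_size, 10)
# so it can be concatenated with the image along the channel axis inside D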
204 | y_fill_ = y_label_ * np.ones([batch_size, img_size, img_size, 10])
205 | z_ = np.random.normal(0, 1, (batch_size, 1, 1, 100))
206 |
207 | loss_d_, _ = sess.run([D_loss, D_optim], {x: x_, z: z_, y_fill: y_fill_, y_label: y_label_, isTrain: True})
208 |
209 | # update generator
210 | z_ = np.random.normal(0, 1, (batch_size, 1, 1, 100))
211 | y_ = np.random.randint(0, 10, (batch_size, 1))  # upper bound is exclusive, so use 10 to sample all labels 0-9
212 | y_label_ = onehot[y_.astype(np.int32)].reshape([batch_size, 1, 1, 10])
213 | y_fill_ = y_label_ * np.ones([batch_size, img_size, img_size, 10])
214 | loss_g_, _ = sess.run([G_loss, G_optim], {z: z_, x: x_, y_fill: y_fill_, y_label: y_label_, isTrain: True})
215 |
216 | errD_fake = D_loss_fake.eval({z: z_, y_label: y_label_, y_fill: y_fill_, isTrain: False})
217 | errD_real = D_loss_real.eval({x: x_, y_label: y_label_, y_fill: y_fill_, isTrain: False})
218 | errG = G_loss.eval({z: z_, y_label: y_label_, y_fill: y_fill_, isTrain: False})
219 |
220 | D_losses.append(errD_fake + errD_real)
221 | G_losses.append(errG)
222 |
223 | epoch_end_time = time.time()
224 | per_epoch_ptime = epoch_end_time - epoch_start_time
225 | print('[%d/%d] - ptime: %.2f loss_d: %.3f, loss_g: %.3f' % ((epoch + 1), train_epoch, per_epoch_ptime, np.mean(D_losses), np.mean(G_losses)))
226 | fixed_p = root + 'Fixed_results/' + model + str(epoch + 1) + '.png'
227 | show_result((epoch + 1), save=True, path=fixed_p)
228 | train_hist['D_losses'].append(np.mean(D_losses))
229 | train_hist['G_losses'].append(np.mean(G_losses))
230 | train_hist['per_epoch_ptimes'].append(per_epoch_ptime)
231 |
232 | end_time = time.time()
233 | total_ptime = end_time - start_time
234 | train_hist['total_ptime'].append(total_ptime)
235 |
236 | print('Avg per epoch ptime: %.2f, total %d epochs ptime: %.2f' % (np.mean(train_hist['per_epoch_ptimes']), train_epoch, total_ptime))
237 | print("Training finish!... save training results")
238 | with open(root + model + 'train_hist.pkl', 'wb') as f:
239 | pickle.dump(train_hist, f)
240 |
241 | show_train_hist(train_hist, save=True, path=root + model + 'train_hist.png')
242 |
243 | images = []
244 | for e in range(train_epoch):
245 | img_name = root + 'Fixed_results/' + model + str(e + 1) + '.png'
246 | images.append(imageio.imread(img_name))
247 | imageio.mimsave(root + model + 'generation_animation.gif', images, fps=5)
248 |
249 | sess.close()
--------------------------------------------------------------------------------
/tensorflow_MNIST_cGAN.py:
--------------------------------------------------------------------------------
1 | import os, time, itertools, imageio, pickle
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | import tensorflow as tf
5 | from tensorflow.examples.tutorials.mnist import input_data
6 |
7 | # leaky_relu
8 | def lrelu(X, leak=0.2):
9 | f1 = 0.5 * (1 + leak)
10 | f2 = 0.5 * (1 - leak)
11 | return f1 * X + f2 * tf.abs(X)
12 |
13 | # G(z)
14 | def generator(x, y, isTrain=True, reuse=False):
15 | with tf.variable_scope('generator', reuse=reuse):
16 | w_init = tf.contrib.layers.xavier_initializer()
17 |
18 | cat1 = tf.concat([x, y], 1)
19 |
20 | dense1 = tf.layers.dense(cat1, 128, kernel_initializer=w_init)
21 | relu1 = tf.nn.relu(dense1)
22 |
23 | dense2 = tf.layers.dense(relu1, 784, kernel_initializer=w_init)
24 | o = tf.nn.tanh(dense2)
25 |
26 | return o
27 |
28 | # D(x)
29 | def discriminator(x, y, isTrain=True, reuse=False):
30 | with tf.variable_scope('discriminator', reuse=reuse):
31 | w_init = tf.contrib.layers.xavier_initializer()
32 |
33 | cat1 = tf.concat([x, y], 1)
34 |
35 | dense1 = tf.layers.dense(cat1, 128, kernel_initializer=w_init)
36 | lrelu1 = lrelu(dense1, 0.2)
37 |
38 | dense2 = tf.layers.dense(lrelu1, 1, kernel_initializer=w_init)
39 | o = tf.nn.sigmoid(dense2)
40 |
41 | return o, dense2
42 |
43 | # label preprocess
44 | onehot = np.eye(10)
45 |
46 | temp_z_ = np.random.normal(0, 1, (10, 100))
47 | fixed_z_ = temp_z_
48 | fixed_y_ = np.zeros((10, 1))
49 |
50 | for i in range(9):
51 | fixed_z_ = np.concatenate([fixed_z_, temp_z_], 0)
52 | temp = np.ones((10,1)) + i
53 | fixed_y_ = np.concatenate([fixed_y_, temp], 0)
54 |
55 | fixed_y_ = onehot[fixed_y_.astype(np.int32)].squeeze()
56 | def show_result(num_epoch, show = False, save = False, path = 'result.png'):
57 | test_images = sess.run(G_z, {z: fixed_z_, y: fixed_y_, isTrain: False})
58 |
59 | size_figure_grid = 10
60 | fig, ax = plt.subplots(size_figure_grid, size_figure_grid, figsize=(5, 5))
61 | for i, j in itertools.product(range(size_figure_grid), range(size_figure_grid)):
62 | ax[i, j].get_xaxis().set_visible(False)
63 | ax[i, j].get_yaxis().set_visible(False)
64 |
65 | for k in range(size_figure_grid*size_figure_grid):
66 | i = k // size_figure_grid
67 | j = k % size_figure_grid
68 | ax[i, j].cla()
69 | ax[i, j].imshow(np.reshape(test_images[k], (28, 28)), cmap='gray')
70 |
71 | label = 'Epoch {0}'.format(num_epoch)
72 | fig.text(0.5, 0.04, label, ha='center')
73 |
74 | if save:
75 | plt.savefig(path)
76 |
77 | if show:
78 | plt.show()
79 | else:
80 | plt.close()
81 |
82 | def show_train_hist(hist, show = False, save = False, path = 'Train_hist.png'):
83 | x = range(len(hist['D_losses']))
84 |
85 | y1 = hist['D_losses']
86 | y2 = hist['G_losses']
87 |
88 | plt.plot(x, y1, label='D_loss')
89 | plt.plot(x, y2, label='G_loss')
90 |
91 | plt.xlabel('Epoch')
92 | plt.ylabel('Loss')
93 |
94 | plt.legend(loc=4)
95 | plt.grid(True)
96 | plt.tight_layout()
97 |
98 | if save:
99 | plt.savefig(path)
100 |
101 | if show:
102 | plt.show()
103 | else:
104 | plt.close()
105 |
106 | # training parameters
107 | batch_size = 100
108 | lr = 0.0002
109 | train_epoch = 100
110 |
111 | # load MNIST
112 | mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
113 | train_set = (mnist.train.images - 0.5) / 0.5 # normalization; range: -1 ~ 1
114 | train_label = mnist.train.labels
115 |
116 | # variables : input
117 | x = tf.placeholder(tf.float32, shape=(None, 784))
118 | y = tf.placeholder(tf.float32, shape=(None, 10))
119 | z = tf.placeholder(tf.float32, shape=(None, 100))
120 | isTrain = tf.placeholder(dtype=tf.bool)
121 |
122 | # networks : generator
123 | G_z = generator(z, y, isTrain)
124 |
125 | # networks : discriminator
126 | D_real, D_real_logits = discriminator(x, y, isTrain)
127 | D_fake, D_fake_logits = discriminator(G_z, y, isTrain, reuse=True)
128 |
129 | # loss for each network
130 | D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_real_logits, labels=tf.ones([batch_size, 1])))
131 | D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_fake_logits, labels=tf.zeros([batch_size, 1])))
132 | D_loss = D_loss_real + D_loss_fake
133 | G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_fake_logits, labels=tf.ones([batch_size, 1])))
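# non-saturating generator loss: G is trained to make D classify its samples as real
# (labels of ones), i.e. it minimizes -log D(G(z), y)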
134 |
135 | # trainable variables for each network
136 | T_vars = tf.trainable_variables()
137 | D_vars = [var for var in T_vars if var.name.startswith('discriminator')]
138 | G_vars = [var for var in T_vars if var.name.startswith('generator')]
139 |
140 | # optimizer for each network
141 | with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
142 | D_optim = tf.train.AdamOptimizer(lr, beta1=0.5).minimize(D_loss, var_list=D_vars)
143 | G_optim = tf.train.AdamOptimizer(lr, beta1=0.5).minimize(G_loss, var_list=G_vars)
144 |
145 | # open session and initialize all variables
146 | sess = tf.InteractiveSession()
147 | tf.global_variables_initializer().run()
148 |
149 | # results save folder
150 | root = 'MNIST_cGAN_results/'
151 | model = 'MNIST_cGAN_'
152 | if not os.path.isdir(root):
153 | os.mkdir(root)
154 | if not os.path.isdir(root + 'Fixed_results'):
155 | os.mkdir(root + 'Fixed_results')
156 |
157 | train_hist = {}
158 | train_hist['D_losses'] = []
159 | train_hist['G_losses'] = []
160 | train_hist['per_epoch_ptimes'] = []
161 | train_hist['total_ptime'] = []
162 |
163 | # training-loop
164 | np.random.seed(int(time.time()))
165 | print('training start!')
166 | start_time = time.time()
167 | for epoch in range(train_epoch):
168 | G_losses = []
169 | D_losses = []
170 | epoch_start_time = time.time()
171 | for iter in range(len(train_set) // batch_size):
172 | # update discriminator
173 | x_ = train_set[iter * batch_size:(iter + 1) * batch_size]
174 | y_ = train_label[iter * batch_size:(iter + 1) * batch_size]
175 |
176 | z_ = np.random.normal(0, 1, (batch_size, 100))
177 |
178 | loss_d_, _ = sess.run([D_loss, D_optim], {x: x_, y: y_, z: z_, isTrain: True})
179 | D_losses.append(loss_d_)
180 |
181 | # update generator
182 | z_ = np.random.normal(0, 1, (batch_size, 100))
183 | y_ = np.random.randint(0, 10, (batch_size, 1))  # upper bound is exclusive, so use 10 to sample all labels 0-9
184 | y_ = onehot[y_.astype(np.int32)].squeeze()
185 | loss_g_, _ = sess.run([G_loss, G_optim], {z: z_, x: x_, y: y_, isTrain: True})
186 | G_losses.append(loss_g_)
187 |
188 | epoch_end_time = time.time()
189 | per_epoch_ptime = epoch_end_time - epoch_start_time
190 | print('[%d/%d] - ptime: %.2f loss_d: %.3f, loss_g: %.3f' % ((epoch + 1), train_epoch, per_epoch_ptime, np.mean(D_losses), np.mean(G_losses)))
191 | fixed_p = root + 'Fixed_results/' + model + str(epoch + 1) + '.png'
192 | show_result((epoch + 1), save=True, path=fixed_p)
193 | train_hist['D_losses'].append(np.mean(D_losses))
194 | train_hist['G_losses'].append(np.mean(G_losses))
195 | train_hist['per_epoch_ptimes'].append(per_epoch_ptime)
196 |
197 | end_time = time.time()
198 | total_ptime = end_time - start_time
199 | train_hist['total_ptime'].append(total_ptime)
200 |
201 | print('Avg per epoch ptime: %.2f, total %d epochs ptime: %.2f' % (np.mean(train_hist['per_epoch_ptimes']), train_epoch, total_ptime))
202 | print("Training finish!... save training results")
203 | with open(root + model + 'train_hist.pkl', 'wb') as f:
204 | pickle.dump(train_hist, f)
205 |
206 | show_train_hist(train_hist, save=True, path=root + model + 'train_hist.png')
207 |
208 | images = []
209 | for e in range(train_epoch):
210 | img_name = root + 'Fixed_results/' + model + str(e + 1) + '.png'
211 | images.append(imageio.imread(img_name))
212 | imageio.mimsave(root + model + 'generation_animation.gif', images, fps=5)
213 |
214 | sess.close()
--------------------------------------------------------------------------------
/tensorflow_cGAN.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/tensorflow-MNIST-cGAN-cDCGAN/2558e30fcc05ba5002b2d6ff63849f8604e3ace3/tensorflow_cGAN.png
--------------------------------------------------------------------------------