├── .DS_Store
├── .gitignore
├── BigGAN_128.py
├── BigGAN_256.py
├── BigGAN_512.py
├── LICENSE
├── README.md
├── assets
├── 128.png
├── 256.png
├── 512.png
├── architecture.png
└── main.png
├── main.py
├── ops.py
└── utils.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/BigGAN-Tensorflow/d64d62ecd2b0761d08ff9d8c51241e963be06183/.DS_Store
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/BigGAN_128.py:
--------------------------------------------------------------------------------
1 | import time
2 | from ops import *
3 | from utils import *
4 | from tensorflow.contrib.data import prefetch_to_device, shuffle_and_repeat, map_and_batch
5 | from tensorflow.contrib.opt import MovingAverageOptimizer
6 |
7 |
class BigGAN_128(object):
    """BigGAN 128 model: graph construction, training loop and sampling.

    The instance wraps a TensorFlow 1.x session (``sess``) and a parsed
    argument namespace (``args``).  ``build_model()`` must be called before
    ``train()`` or ``test()`` because it creates the optimizers, the input
    pipeline and the ``fake_images`` tensor those methods use.
    """

    # The generator consumes one latent chunk in its dense stem and one in
    # each of its five ``resblock_up_condition`` blocks.
    _NUM_Z_CHUNKS = 6

    # Chunk width used when z_dim == 128; preserved from the original special
    # case so that configuration keeps producing [20] * 5 + [28].
    _Z128_CHUNK = 20

    def __init__(self, sess, args):
        """Copy hyper-parameters from ``args``, load the dataset, print a summary.

        Args:
            sess: session used for every ``run`` / ``save`` / ``restore`` call.
            args: namespace providing every attribute read below
                (``dataset``, ``checkpoint_dir``, ..., ``moving_decay``).
        """
        self.model_name = "BigGAN"  # name for checkpoint
        self.sess = sess
        self.dataset_name = args.dataset
        self.checkpoint_dir = args.checkpoint_dir
        self.sample_dir = args.sample_dir
        self.result_dir = args.result_dir
        self.log_dir = args.log_dir

        self.epoch = args.epoch
        self.iteration = args.iteration
        self.batch_size = args.batch_size
        self.print_freq = args.print_freq
        self.save_freq = args.save_freq
        self.img_size = args.img_size

        # Generator
        self.ch = args.ch
        self.z_dim = args.z_dim  # dimension of noise-vector
        self.gan_type = args.gan_type

        # Discriminator
        self.n_critic = args.n_critic
        self.sn = args.sn
        self.ld = args.ld

        self.sample_num = args.sample_num  # number of generated images to be saved
        self.test_num = args.test_num

        # train
        self.g_learning_rate = args.g_lr
        self.d_learning_rate = args.d_lr
        self.beta1 = args.beta1
        self.beta2 = args.beta2
        self.moving_decay = args.moving_decay

        # Only datasets other than the two built-in ones are flagged as custom.
        self.custom_dataset = False

        if self.dataset_name == 'mnist':
            self.c_dim = 1
            self.data = load_mnist()

        elif self.dataset_name == 'cifar10':
            self.c_dim = 3
            self.data = load_cifar10()

        else:
            self.c_dim = 3
            self.data = load_data(dataset_name=self.dataset_name)
            self.custom_dataset = True

        self.dataset_num = len(self.data)

        # Samples go into a sub-folder named after the model configuration.
        self.sample_dir = os.path.join(self.sample_dir, self.model_dir)
        check_folder(self.sample_dir)

        print()

        print("##### Information #####")
        print("# BigGAN 128")
        print("# gan type : ", self.gan_type)
        print("# dataset : ", self.dataset_name)
        print("# dataset number : ", self.dataset_num)
        print("# batch_size : ", self.batch_size)
        print("# epoch : ", self.epoch)
        print("# iteration per epoch : ", self.iteration)

        print()

        print("##### Generator #####")
        print("# spectral normalization : ", self.sn)
        print("# learning rate : ", self.g_learning_rate)

        print()

        print("##### Discriminator #####")
        print("# the number of critic : ", self.n_critic)
        print("# spectral normalization : ", self.sn)
        print("# learning rate : ", self.d_learning_rate)

    ##################################################################################
    # Generator
    ##################################################################################

    @classmethod
    def _z_split_sizes(cls, z_dim):
        """Return the chunk widths used to split a latent vector of size ``z_dim``.

        The list always has ``_NUM_Z_CHUNKS`` entries and sums to ``z_dim``.
        When ``z_dim`` is not an exact multiple, the last chunk absorbs the
        leftover.  (The previous code computed the leftover against all six
        chunks but emitted only five equal ones, so the sizes summed to less
        than ``z_dim``.)
        """
        n_chunks = cls._NUM_Z_CHUNKS
        width = cls._Z128_CHUNK if z_dim == 128 else z_dim // n_chunks

        if width * n_chunks == z_dim:
            return [width] * n_chunks

        head = n_chunks - 1
        return [width] * head + [z_dim - width * head]

    def generator(self, z, is_training=True, reuse=False):
        """Build the generator graph and return the tanh-activated output tensor.

        Args:
            z: latent tensor; it is split along its last axis into six chunks.
            is_training: forwarded to the normalisation layers.
            reuse: forwarded to ``tf.variable_scope`` to share variables.
        """
        with tf.variable_scope("generator", reuse=reuse):
            z_split = tf.split(z, num_or_size_splits=self._z_split_sizes(self.z_dim), axis=-1)

            ch = 16 * self.ch
            x = fully_conneted(z_split[0], units=4 * 4 * ch, sn=self.sn, scope='dense')
            x = tf.reshape(x, shape=[-1, 4, 4, ch])

            x = resblock_up_condition(x, z_split[1], channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_up_16')
            ch = ch // 2

            x = resblock_up_condition(x, z_split[2], channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_up_8')
            ch = ch // 2

            x = resblock_up_condition(x, z_split[3], channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_up_4')
            ch = ch // 2

            x = resblock_up_condition(x, z_split[4], channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_up_2')

            # Non-Local Block
            x = self_attention_2(x, channels=ch, sn=self.sn, scope='self_attention')
            ch = ch // 2

            x = resblock_up_condition(x, z_split[5], channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_up_1')

            x = batch_norm(x, is_training)
            x = relu(x)
            x = conv(x, channels=self.c_dim, kernel=3, stride=1, pad=1, use_bias=False, sn=self.sn, scope='G_logit')

            x = tanh(x)

            return x

    ##################################################################################
    # Discriminator
    ##################################################################################

    def discriminator(self, x, is_training=True, reuse=False):
        """Build the discriminator graph and return its single-unit logit tensor.

        Args:
            x: image tensor to score.
            is_training: forwarded to the residual blocks.
            reuse: forwarded to ``tf.variable_scope`` to share variables.
        """
        with tf.variable_scope("discriminator", reuse=reuse):
            ch = self.ch

            x = resblock_down(x, channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_down_1')

            # Non-Local Block
            x = self_attention_2(x, channels=ch, sn=self.sn, scope='self_attention')
            ch = ch * 2

            x = resblock_down(x, channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_down_2')
            ch = ch * 2

            x = resblock_down(x, channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_down_4')
            ch = ch * 2

            x = resblock_down(x, channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_down_8')
            ch = ch * 2

            x = resblock_down(x, channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_down_16')

            x = resblock(x, channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock')
            x = relu(x)

            x = global_sum_pooling(x)

            x = fully_conneted(x, units=1, sn=self.sn, scope='D_logit')

            return x

    def gradient_penalty(self, real, fake):
        """Return the gradient-penalty term for the configured ``gan_type``.

        For gan types containing ``'dragan'`` the ``fake`` argument is replaced
        by a noisy copy of ``real``.  The discriminator gradient is evaluated
        at a random interpolation between ``real`` and ``fake``.  The result
        is ``ld * mean(max(0, |grad| - 1)^2)`` for ``'wgan-lp'``,
        ``ld * mean((|grad| - 1)^2)`` for ``'wgan-gp'`` / ``'dragan'`` and
        ``0`` otherwise.
        """
        if 'dragan' in self.gan_type:
            eps = tf.random_uniform(shape=tf.shape(real), minval=0., maxval=1.)
            _, x_var = tf.nn.moments(real, axes=[0, 1, 2, 3])
            x_std = tf.sqrt(x_var)  # magnitude of noise decides the size of local region

            fake = real + 0.5 * x_std * eps

        alpha = tf.random_uniform(shape=[self.batch_size, 1, 1, 1], minval=0., maxval=1.)
        interpolated = real + alpha * (fake - real)

        logit = self.discriminator(interpolated, reuse=True)

        grad = tf.gradients(logit, interpolated)[0]  # gradient of D(interpolated)
        grad_norm = tf.norm(flatten(grad), axis=1)  # l2 norm

        # WGAN - LP
        if self.gan_type == 'wgan-lp':
            return self.ld * tf.reduce_mean(tf.square(tf.maximum(0.0, grad_norm - 1.)))

        if self.gan_type == 'wgan-gp' or self.gan_type == 'dragan':
            return self.ld * tf.reduce_mean(tf.square(grad_norm - 1.))

        return 0

    ##################################################################################
    # Model
    ##################################################################################

    def build_model(self):
        """Create the input pipeline, losses, optimizers, test tensor and summaries."""
        # ---- Graph input: images ----
        Image_Data_Class = ImageData(self.img_size, self.c_dim, self.custom_dataset)
        inputs = tf.data.Dataset.from_tensor_slices(self.data)

        gpu_device = '/gpu:0'
        inputs = inputs.apply(shuffle_and_repeat(self.dataset_num))
        inputs = inputs.apply(
            map_and_batch(Image_Data_Class.image_processing, self.batch_size, num_parallel_batches=16, drop_remainder=True))
        inputs = inputs.apply(prefetch_to_device(gpu_device, self.batch_size))

        inputs_iterator = inputs.make_one_shot_iterator()

        self.inputs = inputs_iterator.get_next()

        # ---- Graph input: noises ----
        self.z = tf.truncated_normal(shape=[self.batch_size, 1, 1, self.z_dim], name='random_z')

        # ---- Loss function ----
        # output of D for real images
        real_logits = self.discriminator(self.inputs)

        # output of D for fake images
        fake_images = self.generator(self.z)
        fake_logits = self.discriminator(fake_images, reuse=True)

        if 'wgan' in self.gan_type or self.gan_type == 'dragan':
            GP = self.gradient_penalty(real=self.inputs, fake=fake_images)
        else:
            GP = 0

        # get loss for discriminator
        self.d_loss = discriminator_loss(self.gan_type, real=real_logits, fake=fake_logits) + GP

        # get loss for generator
        self.g_loss = generator_loss(self.gan_type, fake=fake_logits)

        # ---- Training ----
        # divide trainable variables into a group for D and a group for G
        t_vars = tf.trainable_variables()
        d_vars = [var for var in t_vars if 'discriminator' in var.name]
        g_vars = [var for var in t_vars if 'generator' in var.name]

        # optimizers: both train ops wait for the ops in the UPDATE_OPS collection
        with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            self.d_optim = tf.train.AdamOptimizer(self.d_learning_rate, beta1=self.beta1, beta2=self.beta2).minimize(self.d_loss, var_list=d_vars)

            self.opt = MovingAverageOptimizer(tf.train.AdamOptimizer(self.g_learning_rate, beta1=self.beta1, beta2=self.beta2), average_decay=self.moving_decay)

            self.g_optim = self.opt.minimize(self.g_loss, var_list=g_vars)

        # ---- Testing ----
        self.fake_images = self.generator(self.z, is_training=False, reuse=True)

        # ---- Summary ----
        self.d_sum = tf.summary.scalar("d_loss", self.d_loss)
        self.g_sum = tf.summary.scalar("g_loss", self.g_loss)

    ##################################################################################
    # Train
    ##################################################################################

    def train(self):
        """Run the training loop, resuming from the latest checkpoint if one exists.

        The discriminator is updated every step; the generator is updated when
        ``(counter - 1) % n_critic == 0``.  Sample grids are written every
        ``print_freq`` steps and checkpoints every ``save_freq`` steps, after
        each epoch, and once more at the end.
        """
        # initialize all variables
        tf.global_variables_initializer().run()

        # saver to save model
        self.saver = self.opt.swapping_saver()

        # summary writer
        self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_dir, self.sess.graph)

        # restore check-point if it exists
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        if could_load:
            start_epoch = (int)(checkpoint_counter / self.iteration)
            start_batch_id = checkpoint_counter - start_epoch * self.iteration
            counter = checkpoint_counter
            print(" [*] Load SUCCESS")
        else:
            start_epoch = 0
            start_batch_id = 0
            counter = 1
            print(" [!] Load failed...")

        # loop for epoch
        start_time = time.time()
        past_g_loss = -1.
        for epoch in range(start_epoch, self.epoch):
            # get batch data
            for idx in range(start_batch_id, self.iteration):

                # update D network
                _, summary_str, d_loss = self.sess.run([self.d_optim, self.d_sum, self.d_loss])
                self.writer.add_summary(summary_str, counter)

                # update G network
                g_loss = None
                if (counter - 1) % self.n_critic == 0:
                    _, summary_str, g_loss = self.sess.run([self.g_optim, self.g_sum, self.g_loss])
                    self.writer.add_summary(summary_str, counter)
                    past_g_loss = g_loss

                # display training status
                counter += 1
                if g_loss is None:
                    # G was not updated this step: report the last known value.
                    g_loss = past_g_loss
                print("Epoch: [%2d] [%5d/%5d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
                      % (epoch, idx, self.iteration, time.time() - start_time, d_loss, g_loss))

                # save training results every print_freq steps
                if np.mod(idx + 1, self.print_freq) == 0:
                    samples = self.sess.run(self.fake_images)
                    tot_num_samples = min(self.sample_num, self.batch_size)
                    manifold_h = int(np.floor(np.sqrt(tot_num_samples)))
                    manifold_w = int(np.floor(np.sqrt(tot_num_samples)))
                    save_images(samples[:manifold_h * manifold_w, :, :, :],
                                [manifold_h, manifold_w],
                                './' + self.sample_dir + '/' + self.model_name + '_train_{:02d}_{:05d}.png'.format(
                                    epoch, idx + 1))

                if np.mod(idx + 1, self.save_freq) == 0:
                    self.save(self.checkpoint_dir, counter)

            # After an epoch, start_batch_id is set to zero
            # non-zero value is only for the first epoch after loading pre-trained model
            start_batch_id = 0

            # save model
            self.save(self.checkpoint_dir, counter)

            # show temporal results
            # self.visualize_results(epoch)

        # save model for final step
        self.save(self.checkpoint_dir, counter)

    @property
    def model_dir(self):
        """Directory name encoding model, dataset, gan type, image size, z_dim and SN flag."""
        sn = '_sn' if self.sn else ''

        return "{}_{}_{}_{}_{}{}".format(
            self.model_name, self.dataset_name, self.gan_type, self.img_size, self.z_dim, sn)

    def save(self, checkpoint_dir, step):
        """Write a checkpoint for ``step`` under ``checkpoint_dir/<model_dir>``."""
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)

        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name + '.model'), global_step=step)

    def load(self, checkpoint_dir):
        """Restore the latest checkpoint under ``checkpoint_dir/<model_dir>``.

        Returns:
            ``(True, step)`` when a checkpoint was restored, where ``step`` is
            parsed from the text after the last ``-`` in the checkpoint file
            name; ``(False, 0)`` when no checkpoint was found.
        """
        print(" [*] Reading checkpoints...")
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
            counter = int(ckpt_name.split('-')[-1])
            print(" [*] Success to read {}".format(ckpt_name))
            return True, counter
        else:
            print(" [*] Failed to find a checkpoint")
            return False, 0

    def visualize_results(self, epoch):
        """Save one square grid of generated samples tagged with ``epoch``."""
        tot_num_samples = min(self.sample_num, self.batch_size)
        image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))

        # random condition, random noise
        samples = self.sess.run(self.fake_images)

        save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
                    self.sample_dir + '/' + self.model_name + '_epoch%02d' % epoch + '_visualize.png')

    def test(self):
        """Restore the latest checkpoint and write ``test_num`` sample grids to ``result_dir``.

        Sampling proceeds even when no checkpoint could be loaded; in that case
        the freshly initialised variables are used.
        """
        tf.global_variables_initializer().run()

        self.saver = tf.train.Saver()
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        result_dir = os.path.join(self.result_dir, self.model_dir)
        check_folder(result_dir)

        if could_load:
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        tot_num_samples = min(self.sample_num, self.batch_size)
        image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))

        # random condition, random noise
        for i in range(self.test_num):
            samples = self.sess.run(self.fake_images)

            save_images(samples[:image_frame_dim * image_frame_dim, :, :, :],
                        [image_frame_dim, image_frame_dim],
                        result_dir + '/' + self.model_name + '_test_{}.png'.format(i))
414 |
--------------------------------------------------------------------------------
/BigGAN_256.py:
--------------------------------------------------------------------------------
1 | import time
2 | from ops import *
3 | from utils import *
4 | from tensorflow.contrib.data import prefetch_to_device, shuffle_and_repeat, map_and_batch
5 | from tensorflow.contrib.opt import MovingAverageOptimizer
6 |
7 |
8 | class BigGAN_256(object):
9 |
10 | def __init__(self, sess, args):
11 | self.model_name = "BigGAN" # name for checkpoint
12 | self.sess = sess
13 | self.dataset_name = args.dataset
14 | self.checkpoint_dir = args.checkpoint_dir
15 | self.sample_dir = args.sample_dir
16 | self.result_dir = args.result_dir
17 | self.log_dir = args.log_dir
18 |
19 | self.epoch = args.epoch
20 | self.iteration = args.iteration
21 | self.batch_size = args.batch_size
22 | self.print_freq = args.print_freq
23 | self.save_freq = args.save_freq
24 | self.img_size = args.img_size
25 |
26 | """ Generator """
27 | self.ch = args.ch
28 | self.z_dim = args.z_dim # dimension of noise-vector
29 | self.gan_type = args.gan_type
30 |
31 | """ Discriminator """
32 | self.n_critic = args.n_critic
33 | self.sn = args.sn
34 | self.ld = args.ld
35 |
36 | self.sample_num = args.sample_num # number of generated images to be saved
37 | self.test_num = args.test_num
38 |
39 | # train
40 | self.g_learning_rate = args.g_lr
41 | self.d_learning_rate = args.d_lr
42 | self.beta1 = args.beta1
43 | self.beta2 = args.beta2
44 | self.moving_decay = args.moving_decay
45 |
46 | self.custom_dataset = False
47 |
48 | if self.dataset_name == 'mnist':
49 | self.c_dim = 1
50 | self.data = load_mnist()
51 |
52 | elif self.dataset_name == 'cifar10':
53 | self.c_dim = 3
54 | self.data = load_cifar10()
55 |
56 | else:
57 | self.c_dim = 3
58 | self.data = load_data(dataset_name=self.dataset_name)
59 | self.custom_dataset = True
60 |
61 | self.dataset_num = len(self.data)
62 |
63 | self.sample_dir = os.path.join(self.sample_dir, self.model_dir)
64 | check_folder(self.sample_dir)
65 |
66 | print()
67 |
68 | print("##### Information #####")
69 | print("# BigGAN 256")
70 | print("# gan type : ", self.gan_type)
71 | print("# dataset : ", self.dataset_name)
72 | print("# dataset number : ", self.dataset_num)
73 | print("# batch_size : ", self.batch_size)
74 | print("# epoch : ", self.epoch)
75 | print("# iteration per epoch : ", self.iteration)
76 |
77 | print()
78 |
79 | print("##### Generator #####")
80 | print("# spectral normalization : ", self.sn)
81 | print("# learning rate : ", self.g_learning_rate)
82 |
83 | print()
84 |
85 | print("##### Discriminator #####")
86 | print("# the number of critic : ", self.n_critic)
87 | print("# spectral normalization : ", self.sn)
88 | print("# learning rate : ", self.d_learning_rate)
89 |
90 | ##################################################################################
91 | # Generator
92 | ##################################################################################
93 |
94 | def generator(self, z, is_training=True, reuse=False):
95 | with tf.variable_scope("generator", reuse=reuse):
96 | # 7
97 | if self.z_dim == 128:
98 | split_dim = 18
99 | split_dim_remainder = self.z_dim - (split_dim * 6)
100 |
101 | z_split = tf.split(z, num_or_size_splits=[split_dim] * 6 + [split_dim_remainder], axis=-1)
102 |
103 | else:
104 | split_dim = self.z_dim // 7
105 | split_dim_remainder = self.z_dim - (split_dim * 7)
106 |
107 | if split_dim_remainder == 0 :
108 | z_split = tf.split(z, num_or_size_splits=[split_dim] * 7, axis=-1)
109 | else :
110 | z_split = tf.split(z, num_or_size_splits=[split_dim] * 6 + [split_dim_remainder], axis=-1)
111 |
112 |
113 | ch = 16 * self.ch
114 | x = fully_conneted(z_split[0], units=4 * 4 * ch, sn=self.sn, scope='dense')
115 | x = tf.reshape(x, shape=[-1, 4, 4, ch])
116 |
117 | x = resblock_up_condition(x, z_split[1], channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_up_16')
118 | ch = ch // 2
119 |
120 | x = resblock_up_condition(x, z_split[2], channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_up_8_0')
121 | x = resblock_up_condition(x, z_split[3], channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_up_8_1')
122 | ch = ch // 2
123 |
124 | x = resblock_up_condition(x, z_split[4], channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_up_4')
125 | ch = ch // 2
126 |
127 | x = resblock_up_condition(x, z_split[5], channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_up_2')
128 |
129 | # Non-Local Block
130 | x = self_attention_2(x, channels=ch, sn=self.sn, scope='self_attention')
131 | ch = ch // 2
132 |
133 | x = resblock_up_condition(x, z_split[6], channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_up_1')
134 |
135 | x = batch_norm(x, is_training)
136 | x = relu(x)
137 | x = conv(x, channels=self.c_dim, kernel=3, stride=1, pad=1, use_bias=False, sn=self.sn, scope='G_logit')
138 |
139 | x = tanh(x)
140 |
141 | return x
142 |
143 | ##################################################################################
144 | # Discriminator
145 | ##################################################################################
146 |
147 | def discriminator(self, x, is_training=True, reuse=False):
148 | with tf.variable_scope("discriminator", reuse=reuse):
149 | ch = self.ch
150 |
151 | x = resblock_down(x, channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_down_1')
152 | ch = ch * 2
153 |
154 | x = resblock_down(x, channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_down_2')
155 |
156 | # Non-Local Block
157 | x = self_attention_2(x, channels=ch, sn=self.sn, scope='self_attention')
158 | ch = ch * 2
159 |
160 | x = resblock_down(x, channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_down_4')
161 | ch = ch * 2
162 |
163 | x = resblock_down(x, channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_down_8_0')
164 | x = resblock_down(x, channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_down_8_1')
165 | ch = ch * 2
166 |
167 | x = resblock_down(x, channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_down_16')
168 |
169 | x = resblock(x, channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock')
170 | x = relu(x)
171 |
172 | x = global_sum_pooling(x)
173 |
174 | x = fully_conneted(x, units=1, sn=self.sn, scope='D_logit')
175 |
176 | return x
177 |
178 | def gradient_penalty(self, real, fake):
179 | if self.gan_type.__contains__('dragan'):
180 | eps = tf.random_uniform(shape=tf.shape(real), minval=0., maxval=1.)
181 | _, x_var = tf.nn.moments(real, axes=[0, 1, 2, 3])
182 | x_std = tf.sqrt(x_var) # magnitude of noise decides the size of local region
183 |
184 | fake = real + 0.5 * x_std * eps
185 |
186 | alpha = tf.random_uniform(shape=[self.batch_size, 1, 1, 1], minval=0., maxval=1.)
187 | interpolated = real + alpha * (fake - real)
188 |
189 | logit = self.discriminator(interpolated, reuse=True)
190 |
191 | grad = tf.gradients(logit, interpolated)[0] # gradient of D(interpolated)
192 | grad_norm = tf.norm(flatten(grad), axis=1) # l2 norm
193 |
194 | GP = 0
195 |
196 | # WGAN - LP
197 | if self.gan_type == 'wgan-lp':
198 | GP = self.ld * tf.reduce_mean(tf.square(tf.maximum(0.0, grad_norm - 1.)))
199 |
200 | elif self.gan_type == 'wgan-gp' or self.gan_type == 'dragan':
201 | GP = self.ld * tf.reduce_mean(tf.square(grad_norm - 1.))
202 |
203 | return GP
204 |
205 | ##################################################################################
206 | # Model
207 | ##################################################################################
208 |
209 | def build_model(self):
210 | """ Graph Input """
211 | # images
212 | Image_Data_Class = ImageData(self.img_size, self.c_dim, self.custom_dataset)
213 | inputs = tf.data.Dataset.from_tensor_slices(self.data)
214 |
215 | gpu_device = '/gpu:0'
216 | inputs = inputs.\
217 | apply(shuffle_and_repeat(self.dataset_num)).\
218 | apply(map_and_batch(Image_Data_Class.image_processing, self.batch_size, num_parallel_batches=16, drop_remainder=True)).\
219 | apply(prefetch_to_device(gpu_device, self.batch_size))
220 |
221 | inputs_iterator = inputs.make_one_shot_iterator()
222 |
223 | self.inputs = inputs_iterator.get_next()
224 |
225 | # noises
226 | self.z = tf.truncated_normal(shape=[self.batch_size, 1, 1, self.z_dim], name='random_z')
227 |
228 | """ Loss Function """
229 | # output of D for real images
230 | real_logits = self.discriminator(self.inputs)
231 |
232 | # output of D for fake images
233 | fake_images = self.generator(self.z)
234 | fake_logits = self.discriminator(fake_images, reuse=True)
235 |
236 | if self.gan_type.__contains__('wgan') or self.gan_type == 'dragan':
237 | GP = self.gradient_penalty(real=self.inputs, fake=fake_images)
238 | else:
239 | GP = 0
240 |
241 | # get loss for discriminator
242 | self.d_loss = discriminator_loss(self.gan_type, real=real_logits, fake=fake_logits) + GP
243 |
244 | # get loss for generator
245 | self.g_loss = generator_loss(self.gan_type, fake=fake_logits)
246 |
247 | """ Training """
248 | # divide trainable variables into a group for D and a group for G
249 | t_vars = tf.trainable_variables()
250 | d_vars = [var for var in t_vars if 'discriminator' in var.name]
251 | g_vars = [var for var in t_vars if 'generator' in var.name]
252 |
253 | # optimizers
254 | with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
255 | self.d_optim = tf.train.AdamOptimizer(self.d_learning_rate, beta1=self.beta1, beta2=self.beta2).minimize(self.d_loss, var_list=d_vars)
256 |
257 | self.opt = MovingAverageOptimizer(tf.train.AdamOptimizer(self.g_learning_rate, beta1=self.beta1, beta2=self.beta2), average_decay=self.moving_decay)
258 |
259 | self.g_optim = self.opt.minimize(self.g_loss, var_list=g_vars)
260 |
261 | """" Testing """
262 | # for test
263 | self.fake_images = self.generator(self.z, is_training=False, reuse=True)
264 |
265 | """ Summary """
266 | self.d_sum = tf.summary.scalar("d_loss", self.d_loss)
267 | self.g_sum = tf.summary.scalar("g_loss", self.g_loss)
268 |
269 | ##################################################################################
270 | # Train
271 | ##################################################################################
272 |
273 | def train(self):
274 | # initialize all variables
275 | tf.global_variables_initializer().run()
276 |
277 | # saver to save model
278 | self.saver = self.opt.swapping_saver()
279 |
280 | # summary writer
281 | self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_dir, self.sess.graph)
282 |
283 | # restore check-point if it exits
284 | could_load, checkpoint_counter = self.load(self.checkpoint_dir)
285 | if could_load:
286 | start_epoch = (int)(checkpoint_counter / self.iteration)
287 | start_batch_id = checkpoint_counter - start_epoch * self.iteration
288 | counter = checkpoint_counter
289 | print(" [*] Load SUCCESS")
290 | else:
291 | start_epoch = 0
292 | start_batch_id = 0
293 | counter = 1
294 | print(" [!] Load failed...")
295 |
296 | # loop for epoch
297 | start_time = time.time()
298 | past_g_loss = -1.
299 | for epoch in range(start_epoch, self.epoch):
300 | # get batch data
301 | for idx in range(start_batch_id, self.iteration):
302 | # update D network
303 | _, summary_str, d_loss = self.sess.run([self.d_optim, self.d_sum, self.d_loss])
304 | self.writer.add_summary(summary_str, counter)
305 |
306 | # update G network
307 | g_loss = None
308 | if (counter - 1) % self.n_critic == 0:
309 | _, summary_str, g_loss = self.sess.run([self.g_optim, self.g_sum, self.g_loss])
310 | self.writer.add_summary(summary_str, counter)
311 | past_g_loss = g_loss
312 |
313 | # display training status
314 | counter += 1
315 | if g_loss == None:
316 | g_loss = past_g_loss
317 | print("Epoch: [%2d] [%5d/%5d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
318 | % (epoch, idx, self.iteration, time.time() - start_time, d_loss, g_loss))
319 |
320 | # save training results for every 300 steps
321 | if np.mod(idx + 1, self.print_freq) == 0:
322 | samples = self.sess.run(self.fake_images)
323 | tot_num_samples = min(self.sample_num, self.batch_size)
324 | manifold_h = int(np.floor(np.sqrt(tot_num_samples)))
325 | manifold_w = int(np.floor(np.sqrt(tot_num_samples)))
326 | save_images(samples[:manifold_h * manifold_w, :, :, :],
327 | [manifold_h, manifold_w],
328 | './' + self.sample_dir + '/' + self.model_name + '_train_{:02d}_{:05d}.png'.format(
329 | epoch, idx + 1))
330 |
331 | if np.mod(idx + 1, self.save_freq) == 0:
332 | self.save(self.checkpoint_dir, counter)
333 |
334 | # After an epoch, start_batch_id is set to zero
335 | # non-zero value is only for the first epoch after loading pre-trained model
336 | start_batch_id = 0
337 |
338 | # save model
339 | self.save(self.checkpoint_dir, counter)
340 |
341 | # show temporal results
342 | # self.visualize_results(epoch)
343 |
344 | # save model for final step
345 | self.save(self.checkpoint_dir, counter)
346 |
347 | @property
348 | def model_dir(self):
349 | if self.sn :
350 | sn = '_sn'
351 | else :
352 | sn = ''
353 |
354 | return "{}_{}_{}_{}_{}{}".format(
355 | self.model_name, self.dataset_name, self.gan_type, self.img_size, self.z_dim, sn)
356 |
357 | def save(self, checkpoint_dir, step):
358 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)
359 |
360 | if not os.path.exists(checkpoint_dir):
361 | os.makedirs(checkpoint_dir)
362 |
363 | self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name + '.model'), global_step=step)
364 |
def load(self, checkpoint_dir):
    """Restore the latest checkpoint found for this model configuration.

    Looks in the ``self.model_dir`` sub-folder of ``checkpoint_dir``.

    Returns:
        ``(True, step)`` when a checkpoint was restored, where ``step`` is the
        integer after the last ``-`` in the checkpoint file name; otherwise
        ``(False, 0)``.
    """
    print(" [*] Reading checkpoints...")
    model_ckpt_dir = os.path.join(checkpoint_dir, self.model_dir)
    state = tf.train.get_checkpoint_state(model_ckpt_dir)

    # Guard clause: nothing recorded for this configuration
    if not (state and state.model_checkpoint_path):
        print(" [*] Failed to find a checkpoint")
        return False, 0

    ckpt_name = os.path.basename(state.model_checkpoint_path)
    self.saver.restore(self.sess, os.path.join(model_ckpt_dir, ckpt_name))
    step = int(ckpt_name.rsplit('-', 1)[-1])
    print(" [*] Success to read {}".format(ckpt_name))
    return True, step
379 |
def visualize_results(self, epoch):
    """Run the test-mode generator once and save a square grid of its outputs.

    The grid side is ``floor(sqrt(min(sample_num, batch_size)))``; the image is
    written under ``self.sample_dir`` with ``epoch`` in the file name.
    """
    grid_side = int(np.floor(np.sqrt(min(self.sample_num, self.batch_size))))

    # random condition, random noise
    generated = self.sess.run(self.fake_images)

    out_path = self.sample_dir + '/' + self.model_name + '_epoch%02d' % epoch + '_visualize.png'
    save_images(generated[:grid_side * grid_side, :, :, :], [grid_side, grid_side], out_path)
390 |
def test(self):
    """Restore a checkpoint and write ``self.test_num`` grids of generated images.

    Images go to ``<result_dir>/<model_dir>/``. Generation proceeds even when
    no checkpoint could be loaded (only a message is printed in that case).
    """
    tf.global_variables_initializer().run()

    self.saver = tf.train.Saver()
    could_load, _ = self.load(self.checkpoint_dir)
    result_dir = os.path.join(self.result_dir, self.model_dir)
    check_folder(result_dir)

    print(" [*] Load SUCCESS" if could_load else " [!] Load failed...")

    grid_side = int(np.floor(np.sqrt(min(self.sample_num, self.batch_size))))
    file_prefix = result_dir + '/' + self.model_name

    # random condition, random noise
    for i in range(self.test_num):
        generated = self.sess.run(self.fake_images)
        save_images(generated[:grid_side * grid_side, :, :, :],
                    [grid_side, grid_side],
                    file_prefix + '_test_{}.png'.format(i))
415 |
--------------------------------------------------------------------------------
/BigGAN_512.py:
--------------------------------------------------------------------------------
1 | import time
2 | from ops import *
3 | from utils import *
4 | from tensorflow.contrib.data import prefetch_to_device, shuffle_and_repeat, map_and_batch
5 | from tensorflow.contrib.opt import MovingAverageOptimizer
6 |
7 |
class BigGAN_512(object):
    """BigGAN generator/discriminator pair laid out for 512x512 images.

    The constructor copies hyper-parameters from ``args`` and loads the
    dataset; ``build_model`` creates the TensorFlow graph; ``train`` and
    ``test`` drive it through the session given to the constructor.
    """

    def __init__(self, sess, args):
        """Store settings from ``args``, load the dataset and print a summary.

        Args:
            sess: session used for every run / save / restore call.
            args: parsed command-line namespace providing the attributes read below.
        """
        self.model_name = "BigGAN"  # name for checkpoint
        self.sess = sess
        self.dataset_name = args.dataset
        self.checkpoint_dir = args.checkpoint_dir
        self.sample_dir = args.sample_dir
        self.result_dir = args.result_dir
        self.log_dir = args.log_dir

        self.epoch = args.epoch
        self.iteration = args.iteration
        self.batch_size = args.batch_size
        self.print_freq = args.print_freq
        self.save_freq = args.save_freq
        self.img_size = args.img_size

        # Generator
        self.ch = args.ch
        self.z_dim = args.z_dim  # dimension of noise-vector
        self.gan_type = args.gan_type

        # Discriminator
        self.n_critic = args.n_critic
        self.sn = args.sn
        self.ld = args.ld

        self.sample_num = args.sample_num  # number of generated images to be saved
        self.test_num = args.test_num

        # train
        self.g_learning_rate = args.g_lr
        self.d_learning_rate = args.d_lr
        self.beta1 = args.beta1
        self.beta2 = args.beta2
        self.moving_decay = args.moving_decay

        self.custom_dataset = False

        if self.dataset_name == 'mnist':
            self.c_dim = 1
            self.data = load_mnist()

        elif self.dataset_name == 'cifar10':
            self.c_dim = 3
            self.data = load_cifar10()

        else:
            # any other name is treated as a custom dataset
            self.c_dim = 3
            self.data = load_data(dataset_name=self.dataset_name)
            self.custom_dataset = True

        self.dataset_num = len(self.data)

        self.sample_dir = os.path.join(self.sample_dir, self.model_dir)
        check_folder(self.sample_dir)

        print()

        print("##### Information #####")
        print("# BigGAN 512")
        print("# gan type : ", self.gan_type)
        print("# dataset : ", self.dataset_name)
        print("# dataset number : ", self.dataset_num)
        print("# batch_size : ", self.batch_size)
        print("# epoch : ", self.epoch)
        print("# iteration per epoch : ", self.iteration)

        print()

        print("##### Generator #####")
        print("# spectral normalization : ", self.sn)
        print("# learning rate : ", self.g_learning_rate)

        print()

        print("##### Discriminator #####")
        print("# the number of critic : ", self.n_critic)
        print("# spectral normalization : ", self.sn)
        print("# learning rate : ", self.d_learning_rate)

    ##################################################################################
    # Generator
    ##################################################################################

    @staticmethod
    def _latent_split_sizes(z_dim, num_chunks=8):
        """Return ``num_chunks`` chunk sizes that always sum to ``z_dim``.

        The first ``num_chunks - 1`` chunks have ``z_dim // num_chunks``
        entries and the last chunk absorbs the remainder. The previous
        inline computation produced sizes summing to less than ``z_dim``
        whenever ``z_dim`` was not divisible by the chunk count.
        """
        base = z_dim // num_chunks
        return [base] * (num_chunks - 1) + [z_dim - base * (num_chunks - 1)]

    def generator(self, z, is_training=True, reuse=False):
        """Build the generator graph and return the tanh-activated image tensor.

        Args:
            z: latent tensor; it is split into 8 chunks along its last axis,
                one for the initial dense layer and one per conditional
                up-sampling block.
            is_training: forwarded to the batch-norm layers.
            reuse: whether to reuse variables of the "generator" scope.
        """
        with tf.variable_scope("generator", reuse=reuse):
            # 8 chunks: 1 for the dense stem + 7 conditional up-blocks
            z_split = tf.split(z, num_or_size_splits=self._latent_split_sizes(self.z_dim), axis=-1)

            ch = 16 * self.ch
            x = fully_conneted(z_split[0], units=4 * 4 * ch, sn=self.sn, scope='dense')
            x = tf.reshape(x, shape=[-1, 4, 4, ch])

            x = resblock_up_condition(x, z_split[1], channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_up_16')
            ch = ch // 2

            x = resblock_up_condition(x, z_split[2], channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_up_8_0')
            x = resblock_up_condition(x, z_split[3], channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_up_8_1')
            ch = ch // 2

            x = resblock_up_condition(x, z_split[4], channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_up_4')

            # Non-Local Block
            x = self_attention_2(x, channels=ch, sn=self.sn, scope='self_attention')
            ch = ch // 2

            x = resblock_up_condition(x, z_split[5], channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_up_2')
            ch = ch // 2

            x = resblock_up_condition(x, z_split[6], channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_up_1_0')
            x = resblock_up_condition(x, z_split[7], channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_up_1_1')

            x = batch_norm(x, is_training)
            x = relu(x)
            x = conv(x, channels=self.c_dim, kernel=3, stride=1, pad=1, use_bias=False, sn=self.sn, scope='G_logit')

            x = tanh(x)

            return x

    ##################################################################################
    # Discriminator
    ##################################################################################

    def discriminator(self, x, is_training=True, reuse=False):
        """Build the discriminator graph and return one logit per input image.

        Args:
            x: batch of images.
            is_training: forwarded to the residual blocks.
            reuse: whether to reuse variables of the "discriminator" scope.
        """
        with tf.variable_scope("discriminator", reuse=reuse):
            ch = self.ch

            x = resblock_down(x, channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_down_1_0')
            x = resblock_down(x, channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_down_1_1')
            ch = ch * 2

            x = resblock_down(x, channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_down_2')

            # Non-Local Block
            x = self_attention_2(x, channels=ch, sn=self.sn, scope='self_attention')
            ch = ch * 2

            x = resblock_down(x, channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_down_4')
            ch = ch * 2

            x = resblock_down(x, channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_down_8_0')
            x = resblock_down(x, channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_down_8_1')
            ch = ch * 2

            x = resblock_down(x, channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_down_16')

            x = resblock(x, channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock')
            x = relu(x)

            x = global_sum_pooling(x)

            x = fully_conneted(x, units=1, sn=self.sn, scope='D_logit')

            return x

    def gradient_penalty(self, real, fake):
        """Return the gradient-penalty term for the configured ``gan_type``.

        For 'dragan' the "fake" samples are replaced by noisy copies of
        ``real``. The discriminator is evaluated on random interpolations
        between ``real`` and ``fake`` and the norm of its input gradient is
        penalised. Returns 0 for GAN types other than 'wgan-lp', 'wgan-gp'
        and 'dragan'.
        """
        if 'dragan' in self.gan_type:
            eps = tf.random_uniform(shape=tf.shape(real), minval=0., maxval=1.)
            _, x_var = tf.nn.moments(real, axes=[0, 1, 2, 3])
            x_std = tf.sqrt(x_var)  # magnitude of noise decides the size of local region

            fake = real + 0.5 * x_std * eps

        alpha = tf.random_uniform(shape=[self.batch_size, 1, 1, 1], minval=0., maxval=1.)
        interpolated = real + alpha * (fake - real)

        logit = self.discriminator(interpolated, reuse=True)

        grad = tf.gradients(logit, interpolated)[0]  # gradient of D(interpolated)
        grad_norm = tf.norm(flatten(grad), axis=1)  # l2 norm

        GP = 0

        # WGAN - LP: only gradient norms above 1 are penalised
        if self.gan_type == 'wgan-lp':
            GP = self.ld * tf.reduce_mean(tf.square(tf.maximum(0.0, grad_norm - 1.)))

        elif self.gan_type == 'wgan-gp' or self.gan_type == 'dragan':
            GP = self.ld * tf.reduce_mean(tf.square(grad_norm - 1.))

        return GP

    ##################################################################################
    # Model
    ##################################################################################

    def build_model(self):
        """Create the input pipeline, losses, optimizers, test output and summaries."""
        # Graph Input
        # images
        Image_Data_Class = ImageData(self.img_size, self.c_dim, self.custom_dataset)
        inputs = tf.data.Dataset.from_tensor_slices(self.data)

        gpu_device = '/gpu:0'
        inputs = inputs.\
            apply(shuffle_and_repeat(self.dataset_num)).\
            apply(map_and_batch(Image_Data_Class.image_processing, self.batch_size, num_parallel_batches=16, drop_remainder=True)).\
            apply(prefetch_to_device(gpu_device, self.batch_size))

        inputs_iterator = inputs.make_one_shot_iterator()

        self.inputs = inputs_iterator.get_next()

        # noises
        self.z = tf.truncated_normal(shape=[self.batch_size, 1, 1, self.z_dim], name='random_z')

        # Loss Function
        # output of D for real images
        real_logits = self.discriminator(self.inputs)

        # output of D for fake images
        fake_images = self.generator(self.z)
        fake_logits = self.discriminator(fake_images, reuse=True)

        if 'wgan' in self.gan_type or self.gan_type == 'dragan':
            GP = self.gradient_penalty(real=self.inputs, fake=fake_images)
        else:
            GP = 0

        # get loss for discriminator
        self.d_loss = discriminator_loss(self.gan_type, real=real_logits, fake=fake_logits) + GP

        # get loss for generator
        self.g_loss = generator_loss(self.gan_type, fake=fake_logits)

        # Training
        # divide trainable variables into a group for D and a group for G
        t_vars = tf.trainable_variables()
        d_vars = [var for var in t_vars if 'discriminator' in var.name]
        g_vars = [var for var in t_vars if 'generator' in var.name]

        # optimizers (run the collected UPDATE_OPS before each optimizer step)
        with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            self.d_optim = tf.train.AdamOptimizer(self.d_learning_rate, beta1=self.beta1, beta2=self.beta2).minimize(self.d_loss, var_list=d_vars)

            self.opt = MovingAverageOptimizer(tf.train.AdamOptimizer(self.g_learning_rate, beta1=self.beta1, beta2=self.beta2), average_decay=self.moving_decay)
            self.g_optim = self.opt.minimize(self.g_loss, var_list=g_vars)

        # Testing
        # for test
        self.fake_images = self.generator(self.z, is_training=False, reuse=True)

        # Summary
        self.d_sum = tf.summary.scalar("d_loss", self.d_loss)
        self.g_sum = tf.summary.scalar("g_loss", self.g_loss)

    ##################################################################################
    # Train
    ##################################################################################

    def train(self):
        """Run the training loop, resuming from the latest checkpoint if one exists.

        The discriminator is updated every step; the generator is updated
        when ``(counter - 1) % n_critic == 0``. Sample grids are written every
        ``print_freq`` iterations and checkpoints every ``save_freq``
        iterations, at the end of each epoch and once more after the last epoch.
        """
        # initialize all variables
        tf.global_variables_initializer().run()

        # saver to save model
        self.saver = self.opt.swapping_saver()

        # summary writer
        self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_dir, self.sess.graph)

        # restore check-point if it exits
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        if could_load:
            start_epoch = checkpoint_counter // self.iteration
            start_batch_id = checkpoint_counter - start_epoch * self.iteration
            counter = checkpoint_counter
            print(" [*] Load SUCCESS")
        else:
            start_epoch = 0
            start_batch_id = 0
            counter = 1
            print(" [!] Load failed...")

        # loop for epoch
        start_time = time.time()
        past_g_loss = -1.
        for epoch in range(start_epoch, self.epoch):
            # get batch data
            for idx in range(start_batch_id, self.iteration):

                # update D network
                _, summary_str, d_loss = self.sess.run([self.d_optim, self.d_sum, self.d_loss])
                self.writer.add_summary(summary_str, counter)

                # update G network
                g_loss = None
                if (counter - 1) % self.n_critic == 0:
                    _, summary_str, g_loss = self.sess.run([self.g_optim, self.g_sum, self.g_loss])
                    self.writer.add_summary(summary_str, counter)
                    past_g_loss = g_loss

                # display training status
                counter += 1
                if g_loss is None:
                    # G was not updated this step: report the last known value
                    g_loss = past_g_loss
                print("Epoch: [%2d] [%5d/%5d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
                      % (epoch, idx, self.iteration, time.time() - start_time, d_loss, g_loss))

                # save training results every print_freq steps
                if np.mod(idx + 1, self.print_freq) == 0:
                    samples = self.sess.run(self.fake_images)
                    tot_num_samples = min(self.sample_num, self.batch_size)
                    manifold_h = int(np.floor(np.sqrt(tot_num_samples)))
                    manifold_w = int(np.floor(np.sqrt(tot_num_samples)))
                    # os.path.join keeps an absolute sample_dir intact (a './' prefix would not)
                    save_images(samples[:manifold_h * manifold_w, :, :, :],
                                [manifold_h, manifold_w],
                                os.path.join(self.sample_dir,
                                             self.model_name + '_train_{:02d}_{:05d}.png'.format(epoch, idx + 1)))

                if np.mod(idx + 1, self.save_freq) == 0:
                    self.save(self.checkpoint_dir, counter)

            # After an epoch, start_batch_id is set to zero
            # non-zero value is only for the first epoch after loading pre-trained model
            start_batch_id = 0

            # save model
            self.save(self.checkpoint_dir, counter)

            # show temporal results
            # self.visualize_results(epoch)

        # save model for final step
        self.save(self.checkpoint_dir, counter)

    @property
    def model_dir(self):
        """Sub-directory name identifying this configuration (``_sn`` suffix when sn is on)."""
        if self.sn:
            sn = '_sn'
        else:
            sn = ''

        return "{}_{}_{}_{}_{}{}".format(
            self.model_name, self.dataset_name, self.gan_type, self.img_size, self.z_dim, sn)

    def save(self, checkpoint_dir, step):
        """Write a checkpoint under ``checkpoint_dir/model_dir`` with ``step`` as global step."""
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)

        # exist_ok avoids the check-then-create race of a separate existence test
        os.makedirs(checkpoint_dir, exist_ok=True)

        self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name + '.model'), global_step=step)

    def load(self, checkpoint_dir):
        """Restore the latest checkpoint under ``checkpoint_dir/model_dir``.

        Returns:
            ``(True, step)`` on success, with ``step`` parsed from the text
            after the last ``-`` in the checkpoint name; ``(False, 0)`` when
            no checkpoint is found.
        """
        print(" [*] Reading checkpoints...")
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
            counter = int(ckpt_name.split('-')[-1])
            print(" [*] Success to read {}".format(ckpt_name))
            return True, counter
        else:
            print(" [*] Failed to find a checkpoint")
            return False, 0

    def visualize_results(self, epoch):
        """Save one square grid of generated images for ``epoch`` under ``sample_dir``."""
        tot_num_samples = min(self.sample_num, self.batch_size)
        image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))

        # random condition, random noise
        samples = self.sess.run(self.fake_images)

        save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
                    self.sample_dir + '/' + self.model_name + '_epoch%02d' % epoch + '_visualize.png')

    def test(self):
        """Restore a checkpoint and write ``test_num`` grids of generated images.

        Images are written to ``result_dir/model_dir``. Generation continues
        even if no checkpoint could be loaded (a message is printed).
        """
        tf.global_variables_initializer().run()

        self.saver = tf.train.Saver()
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        result_dir = os.path.join(self.result_dir, self.model_dir)
        check_folder(result_dir)

        if could_load:
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        tot_num_samples = min(self.sample_num, self.batch_size)
        image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))

        # random condition, random noise
        for i in range(self.test_num):
            samples = self.sess.run(self.fake_images)

            save_images(samples[:image_frame_dim * image_frame_dim, :, :, :],
                        [image_frame_dim, image_frame_dim],
                        result_dir + '/' + self.model_name + '_test_{}.png'.format(i))
417 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Junho Kim (1993.01.12)
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # BigGAN-Tensorflow
2 | Simple Tensorflow implementation of ["Large Scale GAN Training for High Fidelity Natural Image Synthesis" (BigGAN)](https://arxiv.org/abs/1809.11096)
3 |
4 | 
5 |
6 | ## Issue
7 | * **The paper** uses `orthogonal initialization`, but this implementation uses `random normal initialization`, because the model did not train properly with orthogonal initialization.
8 | * I have applied a hierarchical latent space, but **not** a class embedding.
9 |
10 | ## Usage
11 | ### dataset
12 | * `mnist` and `cifar10` are used inside keras
13 | * For `your dataset`, put images like this:
14 |
15 | ```
16 | ├── dataset
17 | └── YOUR_DATASET_NAME
18 | ├── xxx.jpg (name, format doesn't matter)
19 | ├── yyy.png
20 | └── ...
21 | ```
22 | ### train
23 | * python main.py --phase train --dataset celebA-HQ --gan_type hinge
24 |
25 | ### test
26 | * python main.py --phase test --dataset celebA-HQ --gan_type hinge
27 |
28 | ## Architecture
29 |
30 |
31 | ### 128x128
32 |
33 |
34 | ### 256x256
35 |
36 |
37 | ### 512x512
38 |
39 |
40 | ## Author
41 | Junho Kim
42 |
--------------------------------------------------------------------------------
/assets/128.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/BigGAN-Tensorflow/d64d62ecd2b0761d08ff9d8c51241e963be06183/assets/128.png
--------------------------------------------------------------------------------
/assets/256.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/BigGAN-Tensorflow/d64d62ecd2b0761d08ff9d8c51241e963be06183/assets/256.png
--------------------------------------------------------------------------------
/assets/512.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/BigGAN-Tensorflow/d64d62ecd2b0761d08ff9d8c51241e963be06183/assets/512.png
--------------------------------------------------------------------------------
/assets/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/BigGAN-Tensorflow/d64d62ecd2b0761d08ff9d8c51241e963be06183/assets/architecture.png
--------------------------------------------------------------------------------
/assets/main.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/BigGAN-Tensorflow/d64d62ecd2b0761d08ff9d8c51241e963be06183/assets/main.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from BigGAN_512 import BigGAN_512
2 | from BigGAN_256 import BigGAN_256
3 | from BigGAN_128 import BigGAN_128
4 | import argparse
5 | from utils import *
6 |
7 | """parsing and configuration"""
def parse_args():
    """Parse the command line and hand the namespace to ``check_args``.

    Returns whatever ``check_args`` returns for the parsed arguments.
    """
    parser = argparse.ArgumentParser(description="Tensorflow implementation of BigGAN")

    # One row per option: (flag, type, default, help), registered in this order.
    option_table = [
        ('--phase', str, 'train', 'train or test ?'),
        ('--dataset', str, 'celebA-HQ', '[mnist / cifar10 / custom_dataset]'),

        ('--epoch', int, 50, 'The number of epochs to run'),
        ('--iteration', int, 10000, 'The number of training iterations'),
        ('--batch_size', int, 2048, 'The size of batch per gpu'),
        ('--ch', int, 96, 'base channel number per layer'),

        # SAGAN
        # batch_size = 256
        # base channel = 64
        # epoch = 100 (1M iterations)

        ('--print_freq', int, 1000, 'The number of image_print_freqy'),
        ('--save_freq', int, 1000, 'The number of ckpt_save_freq'),

        ('--g_lr', float, 0.00005, 'learning rate for generator'),
        ('--d_lr', float, 0.0002, 'learning rate for discriminator'),

        # if lower batch size
        # g_lr = 0.0001
        # d_lr = 0.0004

        # if larger batch size
        # g_lr = 0.00005
        # d_lr = 0.0002

        ('--beta1', float, 0.0, 'beta1 for Adam optimizer'),
        ('--beta2', float, 0.9, 'beta2 for Adam optimizer'),
        ('--moving_decay', float, 0.9999, 'moving average decay for generator'),

        ('--z_dim', int, 128, 'Dimension of noise vector'),
        ('--sn', str2bool, True, 'using spectral norm'),

        ('--gan_type', str, 'hinge', '[gan / lsgan / wgan-gp / wgan-lp / dragan / hinge]'),
        ('--ld', float, 10.0, 'The gradient penalty lambda'),

        ('--n_critic', int, 2, 'The number of critic'),

        ('--img_size', int, 512, 'The size of image'),
        ('--sample_num', int, 64, 'The number of sample images'),

        ('--test_num', int, 10, 'The number of images generated by the test'),

        ('--checkpoint_dir', str, 'checkpoint', 'Directory name to save the checkpoints'),
        ('--result_dir', str, 'results', 'Directory name to save the generated images'),
        ('--log_dir', str, 'logs', 'Directory name to save training logs'),
        ('--sample_dir', str, 'samples', 'Directory name to save the samples on training'),
    ]

    for flag, value_type, default_value, help_text in option_table:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)

    return check_args(parser.parse_args())
65 |
66 | """checking arguments"""
def check_args(args):
    """Validate the parsed arguments and make sure the output folders exist.

    Args:
        args: namespace produced by the argument parser.

    Returns:
        ``args`` unchanged when it is valid; ``None`` when ``epoch`` or
        ``batch_size`` is smaller than one (a message is printed), which lets
        the caller's ``args is None`` check stop the program.
    """
    # Explicit comparisons instead of assert-in-try: asserts are stripped
    # under ``python -O`` and the bare ``except`` hid every other error.

    # --epoch
    if args.epoch < 1:
        print('number of epochs must be larger than or equal to one')
        return None

    # --batch_size
    if args.batch_size < 1:
        print('batch size must be larger than or equal to one')
        return None

    # --checkpoint_dir
    check_folder(args.checkpoint_dir)

    # --result_dir
    check_folder(args.result_dir)

    # --log_dir
    check_folder(args.log_dir)

    # --sample_dir
    check_folder(args.sample_dir)

    return args
92 |
93 |
94 | """main"""
def main():
    """Parse arguments, build the BigGAN variant for ``img_size`` and run the phase.

    ``img_size`` 512 and 256 select their dedicated classes; every other value
    falls back to the 128 model. ``phase`` 'train' trains and then writes a
    final visualization; 'test' generates images from a checkpoint.
    """
    # parse arguments
    args = parse_args()
    if args is None:
        # raise SystemExit directly: the ``exit`` helper is injected by the
        # ``site`` module and is not guaranteed to exist (e.g. ``python -S``)
        raise SystemExit

    # open session
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        # default gan = BigGAN_128
        model_by_size = {512: BigGAN_512, 256: BigGAN_256}
        gan = model_by_size.get(args.img_size, BigGAN_128)(sess, args)

        # build graph
        gan.build_model()

        # show network architecture
        show_all_variables()

        if args.phase == 'train':
            # launch the graph in a session
            gan.train()

            # visualize learned generator
            gan.visualize_results(args.epoch - 1)

            print(" [*] Training finished!")

        elif args.phase == 'test':
            gan.test()
            print(" [*] Test finished!")


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/ops.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from utils import orthogonal_regularizer_fully, orthogonal_regularizer
3 |
4 | ##################################################################################
5 | # Initialization
6 | ##################################################################################
7 |
8 | # Xavier : tf_contrib.layers.xavier_initializer()
9 | # He : tf_contrib.layers.variance_scaling_initializer()
10 | # Normal : tf.random_normal_initializer(mean=0.0, stddev=0.02)
11 | # Truncated_normal : tf.truncated_normal_initializer(mean=0.0, stddev=0.02)
12 | # Orthogonal : tf.orthogonal_initializer(1.0) / relu = sqrt(2), the others = 1.0
13 |
14 | ##################################################################################
15 | # Regularization
16 | ##################################################################################
17 |
18 | # l2_decay : tf_contrib.layers.l2_regularizer(0.0001)
19 | # orthogonal_regularizer : orthogonal_regularizer(0.0001) / orthogonal_regularizer_fully(0.0001)
20 |
# Module-wide defaults shared by the layer helpers below.
# Kernel initializer: truncated normal with mean 0.0 and stddev 0.02.
weight_init = tf.truncated_normal_initializer(mean=0.0, stddev=0.02)
# Orthogonal regularizer (scale 0.0001) passed to the conv / deconv kernels.
weight_regularizer = orthogonal_regularizer(0.0001)
# Orthogonal regularizer (scale 0.0001) passed to the fully connected kernels.
weight_regularizer_fully = orthogonal_regularizer_fully(0.0001)
24 |
25 | # Regularization only G in BigGAN
26 |
27 | ##################################################################################
28 | # Layer
29 | ##################################################################################
30 |
31 | # pad = ceil[ (kernel - stride) / 2 ]
32 |
def conv(x, channels, kernel=4, stride=2, pad=0, pad_type='zero', use_bias=True, sn=False, scope='conv_0'):
    """2-D convolution with explicit padding and optional spectral normalization.

    Args:
        x: input tensor; axis 1 is read as the height used to size the padding.
        channels: number of output channels.
        kernel: square kernel size.
        stride: stride applied along both spatial axes.
        pad: when > 0, the input is padded before a 'VALID' convolution
            (total padding is ``2 * pad`` if the height is a multiple of
            ``stride``, otherwise ``max(kernel - height % stride, 0)``).
        pad_type: 'zero' or 'reflect'; any other value leaves ``x`` unpadded.
        use_bias: whether a bias is added.
        sn: when True the kernel is wrapped in ``spectral_norm``.
        scope: variable scope name of this layer.

    Returns:
        The convolved tensor.
    """
    with tf.variable_scope(scope):
        if pad > 0:
            h = x.get_shape().as_list()[1]
            if h % stride == 0:
                pad = pad * 2
            else:
                pad = max(kernel - (h % stride), 0)

            pad_top = pad // 2
            pad_bottom = pad - pad_top
            pad_left = pad // 2
            pad_right = pad - pad_left
            paddings = [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]]

            if pad_type == 'zero':
                x = tf.pad(x, paddings)
            elif pad_type == 'reflect':
                x = tf.pad(x, paddings, mode='REFLECT')

        # "Regularization only G": decide from the full enclosing variable
        # scope. The `scope` argument is just this layer's own name (e.g.
        # 'conv_0'), so testing it alone never matched generator layers.
        in_generator = 'generator' in tf.get_variable_scope().name
        kernel_regularizer = weight_regularizer if in_generator else None

        if sn:
            w = tf.get_variable("kernel", shape=[kernel, kernel, x.get_shape()[-1], channels],
                                initializer=weight_init, regularizer=kernel_regularizer)

            x = tf.nn.conv2d(input=x, filter=spectral_norm(w),
                             strides=[1, stride, stride, 1], padding='VALID')
            if use_bias:
                bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0))
                x = tf.nn.bias_add(x, bias)

        else:
            x = tf.layers.conv2d(inputs=x, filters=channels,
                                 kernel_size=kernel, kernel_initializer=weight_init,
                                 kernel_regularizer=kernel_regularizer,
                                 strides=stride, use_bias=use_bias)

        return x
80 |
81 |
def deconv(x, channels, kernel=4, stride=2, padding='SAME', use_bias=True, sn=False, scope='deconv_0'):
    """Transposed 2-D convolution with optional spectral normalization.

    The output spatial size is ``input * stride`` for 'SAME' padding and
    ``input * stride + max(kernel - stride, 0)`` otherwise. The kernel always
    uses the module-level ``weight_regularizer``.
    """
    with tf.variable_scope(scope):
        in_shape = x.get_shape().as_list()

        # extra border produced by non-'SAME' padding
        border = 0 if padding == 'SAME' else max(kernel - stride, 0)
        out_h = in_shape[1] * stride + border
        out_w = in_shape[2] * stride + border
        output_shape = [in_shape[0], out_h, out_w, channels]

        if not sn:
            return tf.layers.conv2d_transpose(inputs=x, filters=channels,
                                              kernel_size=kernel, kernel_initializer=weight_init,
                                              kernel_regularizer=weight_regularizer,
                                              strides=stride, padding=padding, use_bias=use_bias)

        w = tf.get_variable("kernel", shape=[kernel, kernel, channels, x.get_shape()[-1]],
                            initializer=weight_init, regularizer=weight_regularizer)
        x = tf.nn.conv2d_transpose(x, filter=spectral_norm(w), output_shape=output_shape,
                                   strides=[1, stride, stride, 1], padding=padding)

        if use_bias:
            bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0))
            x = tf.nn.bias_add(x, bias)

        return x
106 |
def fully_conneted(x, units, use_bias=True, sn=False, scope='fully_0'):
    """Fully connected layer applied to the flattened input.

    Args:
        x: input tensor; it is flattened with ``flatten`` first.
        units: number of output units.
        use_bias: whether a bias is added.
        sn: when True the kernel is wrapped in ``spectral_norm``.
        scope: variable scope name of this layer.

    Returns:
        The projected tensor.
    """
    with tf.variable_scope(scope):
        x = flatten(x)
        shape = x.get_shape().as_list()
        channels = shape[-1]

        # "Regularization only G": decide from the full enclosing variable
        # scope. The `scope` argument is just this layer's own name (e.g.
        # 'dense'), so testing it alone never matched generator layers.
        in_generator = 'generator' in tf.get_variable_scope().name
        kernel_regularizer = weight_regularizer_fully if in_generator else None

        if sn:
            w = tf.get_variable("kernel", [channels, units], tf.float32,
                                initializer=weight_init, regularizer=kernel_regularizer)

            if use_bias:
                bias = tf.get_variable("bias", [units], initializer=tf.constant_initializer(0.0))

                x = tf.matmul(x, spectral_norm(w)) + bias
            else:
                x = tf.matmul(x, spectral_norm(w))

        else:
            x = tf.layers.dense(x, units=units, kernel_initializer=weight_init,
                                kernel_regularizer=kernel_regularizer, use_bias=use_bias)

        return x
135 |
def flatten(x):
    """Flatten `x` by delegating to `tf.layers.flatten`."""
    flat = tf.layers.flatten(x)
    return flat
138 |
def hw_flatten(x):
    """Reshape `x` to three axes: first axis, everything in between, last axis."""
    leading, trailing = x.shape[0], x.shape[-1]
    return tf.reshape(x, shape=[leading, -1, trailing])
141 |
142 | ##################################################################################
143 | # Residual-block, Self-Attention-block
144 | ##################################################################################
145 |
def resblock(x_init, channels, use_bias=True, is_training=True, sn=False, scope='resblock'):
    """Residual block: conv -> batch norm -> ReLU -> conv -> batch norm, plus the input.

    Both convolutions use a 3x3 kernel, stride 1 and pad 1. The result is
    added to the untouched `x_init`.
    """
    conv_kwargs = dict(kernel=3, stride=1, pad=1, use_bias=use_bias, sn=sn)

    with tf.variable_scope(scope):
        with tf.variable_scope('res1'):
            branch = conv(x_init, channels, **conv_kwargs)
            branch = relu(batch_norm(branch, is_training))

        with tf.variable_scope('res2'):
            branch = conv(branch, channels, **conv_kwargs)
            branch = batch_norm(branch, is_training)

        return branch + x_init
158 |
def resblock_up(x_init, channels, use_bias=True, is_training=True, sn=False, scope='resblock_up'):
    """Upsampling residual block built from transposed convolutions.

    Main branch: batch norm -> ReLU -> stride-2 deconv, then
    batch norm -> ReLU -> stride-1 deconv. The skip branch applies a single
    stride-2 deconv to `x_init`. The two branches are summed.
    """
    with tf.variable_scope(scope):
        with tf.variable_scope('res1'):
            branch = relu(batch_norm(x_init, is_training))
            branch = deconv(branch, channels, kernel=3, stride=2, use_bias=use_bias, sn=sn)

        with tf.variable_scope('res2'):
            branch = relu(batch_norm(branch, is_training))
            branch = deconv(branch, channels, kernel=3, stride=1, use_bias=use_bias, sn=sn)

        with tf.variable_scope('skip'):
            shortcut = deconv(x_init, channels, kernel=3, stride=2, use_bias=use_bias, sn=sn)

    return branch + shortcut
176 |
def resblock_up_condition(x_init, z, channels, use_bias=True, is_training=True, sn=False, scope='resblock_up'):
    """Upsampling residual block whose normalisation is conditioned on `z`.

    Same layout as `resblock_up`, except both normalisation steps go through
    `condition_batch_norm(x, z, is_training)`. The skip branch is a single
    stride-2 deconv of `x_init`.
    """
    with tf.variable_scope(scope):
        with tf.variable_scope('res1'):
            branch = relu(condition_batch_norm(x_init, z, is_training))
            branch = deconv(branch, channels, kernel=3, stride=2, use_bias=use_bias, sn=sn)

        with tf.variable_scope('res2'):
            branch = relu(condition_batch_norm(branch, z, is_training))
            branch = deconv(branch, channels, kernel=3, stride=1, use_bias=use_bias, sn=sn)

        with tf.variable_scope('skip'):
            shortcut = deconv(x_init, channels, kernel=3, stride=2, use_bias=use_bias, sn=sn)

    return branch + shortcut
194 |
195 |
def resblock_down(x_init, channels, use_bias=True, is_training=True, sn=False, scope='resblock_down'):
    """Downsampling residual block built from strided convolutions.

    Main branch: batch norm -> ReLU -> stride-2 conv, then
    batch norm -> ReLU -> stride-1 conv (all 3x3, pad 1). The skip branch is
    a single stride-2 conv of `x_init`. The two branches are summed.
    """
    with tf.variable_scope(scope):
        with tf.variable_scope('res1'):
            branch = relu(batch_norm(x_init, is_training))
            branch = conv(branch, channels, kernel=3, stride=2, pad=1, use_bias=use_bias, sn=sn)

        with tf.variable_scope('res2'):
            branch = relu(batch_norm(branch, is_training))
            branch = conv(branch, channels, kernel=3, stride=1, pad=1, use_bias=use_bias, sn=sn)

        with tf.variable_scope('skip'):
            shortcut = conv(x_init, channels, kernel=3, stride=2, pad=1, use_bias=use_bias, sn=sn)

    return branch + shortcut
213 |
def self_attention(x, channels, sn=False, scope='self_attention'):
    """Add a learned, gamma-scaled self-attention term to `x`.

    Three 1x1 convolutions produce the f / g maps (`channels // 8` filters)
    and the h map (`channels` filters). The softmax of g . f^T weights h, and
    the result is reshaped to `x.shape` and blended in as `gamma * o + x`.
    `gamma` is a scalar variable initialised to 0.
    """
    with tf.variable_scope(scope):
        reduced = channels // 8
        feat_f = conv(x, reduced, kernel=1, stride=1, sn=sn, scope='f_conv')
        feat_g = conv(x, reduced, kernel=1, stride=1, sn=sn, scope='g_conv')
        feat_h = conv(x, channels, kernel=1, stride=1, sn=sn, scope='h_conv')

        # Pairwise scores between flattened spatial positions.
        flat_g = hw_flatten(feat_g)
        flat_f = hw_flatten(feat_f)
        scores = tf.matmul(flat_g, flat_f, transpose_b=True)

        attention = tf.nn.softmax(scores)

        attended = tf.matmul(attention, hw_flatten(feat_h))
        gamma = tf.get_variable("gamma", [1], initializer=tf.constant_initializer(0.0))

        attended = tf.reshape(attended, shape=x.shape)
        x = gamma * attended + x

    return x
232 |
def self_attention_2(x, channels, sn=False, scope='self_attention'):
    """Self-attention variant with max-pooled f / h maps and an output conv.

    f (`channels // 8` filters) and h (`channels // 2` filters) are each
    passed through `max_pooling`; g (`channels // 8` filters) is not. The
    attended tensor is reshaped to x's first three dims with `channels // 2`
    features, projected back to `channels` by a 1x1 conv ('attn_conv'), and
    blended in as `gamma * o + x`, with scalar `gamma` initialised to 0.
    """
    with tf.variable_scope(scope):
        feat_f = max_pooling(conv(x, channels // 8, kernel=1, stride=1, sn=sn, scope='f_conv'))

        feat_g = conv(x, channels // 8, kernel=1, stride=1, sn=sn, scope='g_conv')

        feat_h = max_pooling(conv(x, channels // 2, kernel=1, stride=1, sn=sn, scope='h_conv'))

        # Scores between every position of g and every (pooled) position of f.
        flat_g = hw_flatten(feat_g)
        flat_f = hw_flatten(feat_f)
        scores = tf.matmul(flat_g, flat_f, transpose_b=True)

        attention = tf.nn.softmax(scores)

        attended = tf.matmul(attention, hw_flatten(feat_h))
        gamma = tf.get_variable("gamma", [1], initializer=tf.constant_initializer(0.0))

        target_shape = [x.shape[0], x.shape[1], x.shape[2], channels // 2]
        attended = tf.reshape(attended, shape=target_shape)
        attended = conv(attended, channels, kernel=1, stride=1, sn=sn, scope='attn_conv')
        x = gamma * attended + x

    return x
256 |
257 | ##################################################################################
258 | # Sampling
259 | ##################################################################################
260 |
def global_avg_pooling(x):
    """Return the mean of `x` taken over axes 1 and 2."""
    return tf.reduce_mean(x, axis=[1, 2])
265 |
def global_sum_pooling(x):
    """Return the sum of `x` taken over axes 1 and 2."""
    return tf.reduce_sum(x, axis=[1, 2])
270 |
def max_pooling(x):
    """Apply 2x2 max pooling with stride 2 and 'SAME' padding."""
    return tf.layers.max_pooling2d(x, pool_size=2, strides=2, padding='SAME')
274 |
def up_sample(x, scale_factor=2):
    """Nearest-neighbour resize of a 4-D tensor, scaling axes 1 and 2 by `scale_factor`."""
    _, height, width, _ = x.get_shape().as_list()
    target = [height * scale_factor, width * scale_factor]
    return tf.image.resize_nearest_neighbor(x, size=target)
279 |
280 | ##################################################################################
281 | # Activation function
282 | ##################################################################################
283 |
def lrelu(x, alpha=0.2):
    """Leaky ReLU of `x` with negative slope `alpha`, via `tf.nn.leaky_relu`."""
    activated = tf.nn.leaky_relu(x, alpha)
    return activated
286 |
287 |
def relu(x):
    """ReLU of `x`, via `tf.nn.relu`."""
    activated = tf.nn.relu(x)
    return activated
290 |
291 |
def tanh(x):
    """Hyperbolic tangent of `x`, via `tf.tanh`."""
    activated = tf.tanh(x)
    return activated
294 |
295 | ##################################################################################
296 | # Normalization function
297 | ##################################################################################
298 |
def batch_norm(x, is_training=True, scope='batch_norm'):
    """Batch-normalise `x` with `tf.layers.batch_normalization`.

    Uses momentum 0.9 and epsilon 1e-05. `is_training` is forwarded as the
    layer's `training` flag and `scope` as its name.
    """
    bn_kwargs = dict(momentum=0.9, epsilon=1e-05, training=is_training, name=scope)
    return tf.layers.batch_normalization(x, **bn_kwargs)
305 |
def condition_batch_norm(x, z, is_training=True, scope='batch_norm'):
    """Batch normalisation whose offset and scale are predicted from `z`.

    `beta` (offset) and `gamma` (scale) each come from a dense layer applied
    to `z`, reshaped to [-1, 1, 1, c] where c is the last dim of the 4-D `x`.
    In training mode the batch moments over axes [0, 1, 2] normalise `x`, and
    the non-trainable population statistics are updated as a moving average
    with decay 0.9. Otherwise the stored population statistics are used.
    """
    with tf.variable_scope(scope):
        _, _, _, c = x.get_shape().as_list()
        decay = 0.9
        epsilon = 1e-05

        pop_mean = tf.get_variable("pop_mean", shape=[c], dtype=tf.float32,
                                   initializer=tf.constant_initializer(0.0), trainable=False)
        pop_var = tf.get_variable("pop_var", shape=[c], dtype=tf.float32,
                                  initializer=tf.constant_initializer(1.0), trainable=False)

        beta = tf.reshape(fully_conneted(z, units=c, scope='beta'), shape=[-1, 1, 1, c])
        gamma = tf.reshape(fully_conneted(z, units=c, scope='gamma'), shape=[-1, 1, 1, c])

        if not is_training:
            return tf.nn.batch_normalization(x, pop_mean, pop_var, beta, gamma, epsilon)

        batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2])
        update_mean = tf.assign(pop_mean, pop_mean * decay + batch_mean * (1 - decay))
        update_var = tf.assign(pop_var, pop_var * decay + batch_var * (1 - decay))

        # Tie the moving-average updates to the normalised output so they run with it.
        with tf.control_dependencies([update_mean, update_var]):
            return tf.nn.batch_normalization(x, batch_mean, batch_var, beta, gamma, epsilon)
330 |
331 |
def spectral_norm(w, iteration=1):
    """Return `w` divided by a power-iteration estimate of its largest singular value.

    `w` is viewed as a 2-D matrix with `w.shape[-1]` columns. A non-trainable
    variable "u" holds the running singular-vector estimate and is assigned
    the refreshed estimate whenever the normalised weight is evaluated. The
    result is reshaped back to the original shape of `w`.
    """
    original_shape = w.shape.as_list()
    out_dim = original_shape[-1]
    w_mat = tf.reshape(w, [-1, out_dim])

    u = tf.get_variable("u", [1, out_dim], initializer=tf.random_normal_initializer(), trainable=False)

    u_est = u
    v_est = None
    # Power iteration; usually iteration = 1 will be enough.
    for _ in range(iteration):
        v_est = tf.nn.l2_normalize(tf.matmul(u_est, tf.transpose(w_mat)))
        u_est = tf.nn.l2_normalize(tf.matmul(v_est, w_mat))

    # The singular-vector estimates are treated as constants for backprop.
    u_est = tf.stop_gradient(u_est)
    v_est = tf.stop_gradient(v_est)

    sigma = tf.matmul(tf.matmul(v_est, w_mat), tf.transpose(u_est))

    with tf.control_dependencies([u.assign(u_est)]):
        normalized = w_mat / sigma
        normalized = tf.reshape(normalized, original_shape)

    return normalized
362 |
363 | ##################################################################################
364 | # Loss function
365 | ##################################################################################
366 |
def discriminator_loss(loss_func, real, fake):
    """Discriminator loss for the GAN objective named by `loss_func`.

    Recognised names: anything containing 'wgan', plus 'lsgan', 'gan',
    'dragan' and 'hinge'. `real` / `fake` are the discriminator outputs for
    real and generated samples. Any other name leaves both terms at 0.
    """
    real_loss, fake_loss = 0, 0

    if 'wgan' in loss_func:
        real_loss = -tf.reduce_mean(real)
        fake_loss = tf.reduce_mean(fake)
    elif loss_func == 'lsgan':
        real_loss = tf.reduce_mean(tf.squared_difference(real, 1.0))
        fake_loss = tf.reduce_mean(tf.square(fake))
    elif loss_func in ('gan', 'dragan'):
        real_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(real), logits=real))
        fake_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(fake), logits=fake))
    elif loss_func == 'hinge':
        real_loss = tf.reduce_mean(relu(1.0 - real))
        fake_loss = tf.reduce_mean(relu(1.0 + fake))

    return real_loss + fake_loss
390 |
def generator_loss(loss_func, fake):
    """Generator loss for the GAN objective named by `loss_func`.

    Names containing 'wgan' and the name 'hinge' both use the negated mean of
    `fake`; 'lsgan' uses the squared distance to 1; 'gan' / 'dragan' use
    sigmoid cross-entropy against all-ones labels. Any other name yields 0.
    """
    if 'wgan' in loss_func or loss_func == 'hinge':
        return -tf.reduce_mean(fake)

    if loss_func == 'lsgan':
        return tf.reduce_mean(tf.squared_difference(fake, 1.0))

    if loss_func in ('gan', 'dragan'):
        return tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(fake), logits=fake))

    return 0
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import scipy.misc
2 | import numpy as np
3 | import os
4 | from glob import glob
5 |
6 | import tensorflow as tf
7 | import tensorflow.contrib.slim as slim
8 | from keras.datasets import cifar10, mnist
9 |
class ImageData:
    """Holds image-loading settings and turns an input into a normalised tensor."""

    def __init__(self, load_size, channels, custom_dataset):
        # Side length that images are resized to.
        self.load_size = load_size
        # Channel count requested from the JPEG decoder.
        self.channels = channels
        # When true, inputs are file names that must be read and decoded.
        self.custom_dataset = custom_dataset

    def image_processing(self, filename):
        """Resize to `load_size` x `load_size` and rescale values with `/ 127.5 - 1`.

        For a custom dataset `filename` is read from disk and JPEG-decoded;
        otherwise it is used directly as the image data.
        """
        if self.custom_dataset:
            raw = tf.read_file(filename)
            decoded = tf.image.decode_jpeg(raw, channels=self.channels)
        else:
            decoded = filename

        target = [self.load_size, self.load_size]
        resized = tf.image.resize_images(decoded, target)
        return tf.cast(resized, tf.float32) / 127.5 - 1
29 |
30 |
def load_mnist():
    """Return the MNIST train and test images stacked on axis 0, with a trailing axis added.

    Labels are loaded but discarded.
    """
    (train_images, _), (test_images, _) = mnist.load_data()
    images = np.concatenate((train_images, test_images), axis=0)
    return np.expand_dims(images, axis=-1)
37 |
def load_cifar10():
    """Return the CIFAR-10 train and test images stacked on axis 0.

    Labels are loaded but discarded.
    """
    (train_images, _), (test_images, _) = cifar10.load_data()
    return np.concatenate((train_images, test_images), axis=0)
43 |
def load_data(dataset_name):
    """Load a dataset by name.

    'mnist' and 'cifar10' return the arrays from `load_mnist` /
    `load_cifar10`. Any other name returns the list of paths matching
    ./dataset/<dataset_name>/*.* .
    """
    if dataset_name == 'mnist':
        return load_mnist()

    if dataset_name == 'cifar10':
        return load_cifar10()

    pattern = os.path.join("./dataset", dataset_name, '*.*')
    return glob(pattern)
54 |
55 |
def preprocessing(x, size):
    """Read image `x` as RGB, resize it to `size` x `size`, and apply `normalize`."""
    image = scipy.misc.imread(x, mode='RGB')
    resized = scipy.misc.imresize(image, [size, size])
    return normalize(resized)
61 |
def normalize(x):
    """Rescale `x` with `x / 127.5 - 1` (so 0 -> -1 and 255 -> 1)."""
    scaled = x / 127.5
    return scaled - 1
64 |
def save_images(images, size, image_path):
    """Apply `inverse_transform` to `images`, then tile and save them via `imsave`."""
    rescaled = inverse_transform(images)
    return imsave(rescaled, size, image_path)
67 |
def merge(images, size):
    """Tile a batch of images into a single grid image.

    Args:
        images: array-like of shape (N, H, W, C) with C in {1, 3, 4}, or
            (N, H, W) for single-channel images without a channel axis.
        size: (rows, cols) of the grid. Image `idx` is placed at row
            `idx // cols`, column `idx % cols`.

    Returns:
        A zero-initialised array of shape (H * rows, W * cols, C) for 3- or
        4-channel input, or (H * rows, W * cols) for single-channel input,
        with each image copied into its grid cell.

    Raises:
        ValueError: if `images` is not 3-D/4-D or has an unsupported channel
            count.
    """
    images = np.asarray(images)

    # The error message has always advertised HxW input; give such batches
    # an explicit channel axis so they take the single-channel path.
    if images.ndim == 3:
        images = images[..., np.newaxis]

    if images.ndim != 4 or images.shape[3] not in (1, 3, 4):
        raise ValueError('in merge(images,size) images parameter ''must have dimensions: HxW or HxWx3 or HxWx4')

    h, w, c = images.shape[1], images.shape[2], images.shape[3]
    rows, cols = size[0], size[1]
    single_channel = c == 1

    canvas_shape = (h * rows, w * cols) if single_channel else (h * rows, w * cols, c)
    canvas = np.zeros(canvas_shape)

    for idx, image in enumerate(images):
        row, col = divmod(idx, cols)
        tile = image[:, :, 0] if single_channel else image
        canvas[row * h:(row + 1) * h, col * w:(col + 1) * w] = tile

    return canvas
87 |
def imsave(images, size, path):
    """Tile `images` into a grid with `merge` and write it to `path`."""
    # image = np.squeeze(merge(images, size)) # drop the size-1 channel axis?
    grid = merge(images, size)
    return scipy.misc.imsave(path, grid)
91 |
92 |
def inverse_transform(images):
    """Map values with `(images + 1) / 2` (so -1 -> 0 and 1 -> 1)."""
    shifted = images + 1.
    return shifted / 2.
95 |
96 |
def check_folder(log_dir):
    """Create `log_dir` (including parents) if the path does not exist; return `log_dir`."""
    if os.path.exists(log_dir):
        return log_dir

    os.makedirs(log_dir)
    return log_dir
101 |
def show_all_variables():
    """Print a summary of all trainable variables via slim's model analyzer."""
    slim.model_analyzer.analyze_vars(tf.trainable_variables(), print_info=True)
105 |
def str2bool(x):
    """Return True iff the string `x` equals 'true', ignoring case.

    The previous check `x.lower() in ('true')` tested membership in a plain
    string (the parentheses do not make a tuple), so any substring of 'true'
    - including '' and 't' - was accepted as True.
    """
    return x.lower() == 'true'
108 |
109 | ##################################################################################
110 | # Regularization
111 | ##################################################################################
112 |
def orthogonal_regularizer(scale):
    """Build an orthogonality regularizer for 4-D conv kernels.

    Returns a function usable as a conv layer's kernel regularizer. It
    computes `scale * l2_loss(W^T W - I)` on the kernel flattened to 2-D.
    """

    def ortho_reg(w):
        """Orthogonality penalty for a single 4-D kernel `w`."""
        # Flatten the kernel to a 2-D matrix with one column per output channel.
        _, _, _, c = w.get_shape().as_list()
        w_mat = tf.reshape(w, [-1, c])

        # Deviation of the Gram matrix W^T W from the identity.
        gram = tf.matmul(tf.transpose(w_mat), w_mat)
        deviation = tf.subtract(gram, tf.eye(c))

        return scale * tf.nn.l2_loss(deviation)

    return ortho_reg
136 |
def orthogonal_regularizer_fully(scale):
    """Build an orthogonality regularizer for 2-D fully connected kernels.

    Returns a function that computes `scale * l2_loss(W^T W - I)` for a
    rank-2 weight matrix.
    """

    def ortho_reg_fully(w):
        """Orthogonality penalty for a single 2-D kernel `w`."""
        # The kernel is already 2-D; only its column count is needed.
        _, c = w.get_shape().as_list()

        # Deviation of the Gram matrix W^T W from the identity.
        gram = tf.matmul(tf.transpose(w), w)
        deviation = tf.subtract(gram, tf.eye(c))

        return scale * tf.nn.l2_loss(deviation)

    return ortho_reg_fully
--------------------------------------------------------------------------------