├── .DS_Store ├── .gitignore ├── BigGAN_128.py ├── BigGAN_256.py ├── BigGAN_512.py ├── LICENSE ├── README.md ├── assets ├── 128.png ├── 256.png ├── 512.png ├── architecture.png └── main.png ├── main.py ├── ops.py └── utils.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/BigGAN-Tensorflow/d64d62ecd2b0761d08ff9d8c51241e963be06183/.DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /BigGAN_128.py: -------------------------------------------------------------------------------- 1 | import time 2 | from ops import * 3 | from utils import * 4 | from tensorflow.contrib.data import prefetch_to_device, shuffle_and_repeat, map_and_batch 5 | from tensorflow.contrib.opt import MovingAverageOptimizer 6 | 7 | 8 | class BigGAN_128(object): 9 | 10 | def __init__(self, sess, args): 11 | self.model_name = "BigGAN" # name for checkpoint 12 | self.sess = sess 13 | self.dataset_name = args.dataset 14 | self.checkpoint_dir = args.checkpoint_dir 15 | self.sample_dir = args.sample_dir 16 | self.result_dir = args.result_dir 17 | self.log_dir = args.log_dir 18 | 19 | self.epoch = args.epoch 20 | self.iteration = args.iteration 21 | self.batch_size = args.batch_size 22 
| self.print_freq = args.print_freq 23 | self.save_freq = args.save_freq 24 | self.img_size = args.img_size 25 | 26 | """ Generator """ 27 | self.ch = args.ch 28 | self.z_dim = args.z_dim # dimension of noise-vector 29 | self.gan_type = args.gan_type 30 | 31 | """ Discriminator """ 32 | self.n_critic = args.n_critic 33 | self.sn = args.sn 34 | self.ld = args.ld 35 | 36 | self.sample_num = args.sample_num # number of generated images to be saved 37 | self.test_num = args.test_num 38 | 39 | # train 40 | self.g_learning_rate = args.g_lr 41 | self.d_learning_rate = args.d_lr 42 | self.beta1 = args.beta1 43 | self.beta2 = args.beta2 44 | self.moving_decay = args.moving_decay 45 | 46 | self.custom_dataset = False 47 | 48 | if self.dataset_name == 'mnist': 49 | self.c_dim = 1 50 | self.data = load_mnist() 51 | 52 | elif self.dataset_name == 'cifar10': 53 | self.c_dim = 3 54 | self.data = load_cifar10() 55 | 56 | else: 57 | self.c_dim = 3 58 | self.data = load_data(dataset_name=self.dataset_name) 59 | self.custom_dataset = True 60 | 61 | self.dataset_num = len(self.data) 62 | 63 | self.sample_dir = os.path.join(self.sample_dir, self.model_dir) 64 | check_folder(self.sample_dir) 65 | 66 | print() 67 | 68 | print("##### Information #####") 69 | print("# BigGAN 128") 70 | print("# gan type : ", self.gan_type) 71 | print("# dataset : ", self.dataset_name) 72 | print("# dataset number : ", self.dataset_num) 73 | print("# batch_size : ", self.batch_size) 74 | print("# epoch : ", self.epoch) 75 | print("# iteration per epoch : ", self.iteration) 76 | 77 | print() 78 | 79 | print("##### Generator #####") 80 | print("# spectral normalization : ", self.sn) 81 | print("# learning rate : ", self.g_learning_rate) 82 | 83 | print() 84 | 85 | print("##### Discriminator #####") 86 | print("# the number of critic : ", self.n_critic) 87 | print("# spectral normalization : ", self.sn) 88 | print("# learning rate : ", self.d_learning_rate) 89 | 90 | 
################################################################################## 91 | # Generator 92 | ################################################################################## 93 | 94 | def generator(self, z, is_training=True, reuse=False): 95 | with tf.variable_scope("generator", reuse=reuse): 96 | # 6 chunks: z_split[0] feeds the dense layer, z_split[1:] condition the five resblock_up blocks 97 | if self.z_dim == 128: 98 | split_dim = 20 99 | split_dim_remainder = self.z_dim - (split_dim * 5) 100 | 101 | z_split = tf.split(z, num_or_size_splits=[split_dim] * 5 + [split_dim_remainder], axis=-1) 102 | 103 | else: 104 | split_dim = self.z_dim // 6 105 | split_dim_remainder = self.z_dim - (split_dim * 6) 106 | 107 | if split_dim_remainder == 0 : 108 | z_split = tf.split(z, num_or_size_splits=[split_dim] * 6, axis=-1) 109 | else : 110 | z_split = tf.split(z, num_or_size_splits=[split_dim] * 5 + [split_dim + split_dim_remainder], axis=-1) # last chunk absorbs the remainder so the sizes sum to z_dim, as tf.split requires 111 | 112 | 113 | ch = 16 * self.ch 114 | x = fully_conneted(z_split[0], units=4 * 4 * ch, sn=self.sn, scope='dense') 115 | x = tf.reshape(x, shape=[-1, 4, 4, ch]) 116 | 117 | x = resblock_up_condition(x, z_split[1], channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_up_16') 118 | ch = ch // 2 119 | 120 | x = resblock_up_condition(x, z_split[2], channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_up_8') 121 | ch = ch // 2 122 | 123 | x = resblock_up_condition(x, z_split[3], channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_up_4') 124 | ch = ch // 2 125 | 126 | x = resblock_up_condition(x, z_split[4], channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_up_2') 127 | 128 | # Non-Local Block 129 | x = self_attention_2(x, channels=ch, sn=self.sn, scope='self_attention') 130 | ch = ch // 2 131 | 132 | x = resblock_up_condition(x, z_split[5], channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_up_1') 133 | 134 | x = batch_norm(x, is_training) 135 | x = relu(x) 136 | x = conv(x,
channels=self.c_dim, kernel=3, stride=1, pad=1, use_bias=False, sn=self.sn, scope='G_logit') 137 | 138 | x = tanh(x) 139 | 140 | return x 141 | 142 | ################################################################################## 143 | # Discriminator 144 | ################################################################################## 145 | 146 | def discriminator(self, x, is_training=True, reuse=False): 147 | with tf.variable_scope("discriminator", reuse=reuse): 148 | ch = self.ch 149 | 150 | x = resblock_down(x, channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_down_1') 151 | 152 | # Non-Local Block 153 | x = self_attention_2(x, channels=ch, sn=self.sn, scope='self_attention') 154 | ch = ch * 2 155 | 156 | x = resblock_down(x, channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_down_2') 157 | ch = ch * 2 158 | 159 | x = resblock_down(x, channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_down_4') 160 | ch = ch * 2 161 | 162 | x = resblock_down(x, channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_down_8') 163 | ch = ch * 2 164 | 165 | x = resblock_down(x, channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_down_16') 166 | 167 | x = resblock(x, channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock') 168 | x = relu(x) 169 | 170 | x = global_sum_pooling(x) 171 | 172 | x = fully_conneted(x, units=1, sn=self.sn, scope='D_logit') 173 | 174 | return x 175 | 176 | def gradient_penalty(self, real, fake): 177 | if self.gan_type.__contains__('dragan'): 178 | eps = tf.random_uniform(shape=tf.shape(real), minval=0., maxval=1.) 
179 | _, x_var = tf.nn.moments(real, axes=[0, 1, 2, 3]) 180 | x_std = tf.sqrt(x_var) # magnitude of noise decides the size of local region 181 | 182 | fake = real + 0.5 * x_std * eps 183 | 184 | alpha = tf.random_uniform(shape=[self.batch_size, 1, 1, 1], minval=0., maxval=1.) 185 | interpolated = real + alpha * (fake - real) 186 | 187 | logit = self.discriminator(interpolated, reuse=True) 188 | 189 | grad = tf.gradients(logit, interpolated)[0] # gradient of D(interpolated) 190 | grad_norm = tf.norm(flatten(grad), axis=1) # l2 norm 191 | 192 | GP = 0 193 | 194 | # WGAN - LP 195 | if self.gan_type == 'wgan-lp': 196 | GP = self.ld * tf.reduce_mean(tf.square(tf.maximum(0.0, grad_norm - 1.))) 197 | 198 | elif self.gan_type == 'wgan-gp' or self.gan_type == 'dragan': 199 | GP = self.ld * tf.reduce_mean(tf.square(grad_norm - 1.)) 200 | 201 | return GP 202 | 203 | ################################################################################## 204 | # Model 205 | ################################################################################## 206 | 207 | def build_model(self): 208 | """ Graph Input """ 209 | # images 210 | Image_Data_Class = ImageData(self.img_size, self.c_dim, self.custom_dataset) 211 | inputs = tf.data.Dataset.from_tensor_slices(self.data) 212 | 213 | gpu_device = '/gpu:0' 214 | inputs = inputs.\ 215 | apply(shuffle_and_repeat(self.dataset_num)).\ 216 | apply(map_and_batch(Image_Data_Class.image_processing, self.batch_size, num_parallel_batches=16, drop_remainder=True)).\ 217 | apply(prefetch_to_device(gpu_device, self.batch_size)) 218 | 219 | inputs_iterator = inputs.make_one_shot_iterator() 220 | 221 | self.inputs = inputs_iterator.get_next() 222 | 223 | # noises 224 | self.z = tf.truncated_normal(shape=[self.batch_size, 1, 1, self.z_dim], name='random_z') 225 | 226 | """ Loss Function """ 227 | # output of D for real images 228 | real_logits = self.discriminator(self.inputs) 229 | 230 | # output of D for fake images 231 | fake_images = 
self.generator(self.z) 232 | fake_logits = self.discriminator(fake_images, reuse=True) 233 | 234 | if self.gan_type.__contains__('wgan') or self.gan_type == 'dragan': 235 | GP = self.gradient_penalty(real=self.inputs, fake=fake_images) 236 | else: 237 | GP = 0 238 | 239 | # get loss for discriminator 240 | self.d_loss = discriminator_loss(self.gan_type, real=real_logits, fake=fake_logits) + GP 241 | 242 | # get loss for generator 243 | self.g_loss = generator_loss(self.gan_type, fake=fake_logits) 244 | 245 | """ Training """ 246 | # divide trainable variables into a group for D and a group for G 247 | t_vars = tf.trainable_variables() 248 | d_vars = [var for var in t_vars if 'discriminator' in var.name] 249 | g_vars = [var for var in t_vars if 'generator' in var.name] 250 | 251 | # optimizers 252 | with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): 253 | self.d_optim = tf.train.AdamOptimizer(self.d_learning_rate, beta1=self.beta1, beta2=self.beta2).minimize(self.d_loss, var_list=d_vars) 254 | 255 | self.opt = MovingAverageOptimizer(tf.train.AdamOptimizer(self.g_learning_rate, beta1=self.beta1, beta2=self.beta2), average_decay=self.moving_decay) 256 | 257 | self.g_optim = self.opt.minimize(self.g_loss, var_list=g_vars) 258 | 259 | """" Testing """ 260 | # for test 261 | self.fake_images = self.generator(self.z, is_training=False, reuse=True) 262 | 263 | """ Summary """ 264 | self.d_sum = tf.summary.scalar("d_loss", self.d_loss) 265 | self.g_sum = tf.summary.scalar("g_loss", self.g_loss) 266 | 267 | ################################################################################## 268 | # Train 269 | ################################################################################## 270 | 271 | def train(self): 272 | # initialize all variables 273 | tf.global_variables_initializer().run() 274 | 275 | # saver to save model 276 | self.saver = self.opt.swapping_saver() 277 | 278 | # summary writer 279 | self.writer = 
tf.summary.FileWriter(self.log_dir + '/' + self.model_dir, self.sess.graph) 280 | 281 | # restore check-point if it exits 282 | could_load, checkpoint_counter = self.load(self.checkpoint_dir) 283 | if could_load: 284 | start_epoch = (int)(checkpoint_counter / self.iteration) 285 | start_batch_id = checkpoint_counter - start_epoch * self.iteration 286 | counter = checkpoint_counter 287 | print(" [*] Load SUCCESS") 288 | else: 289 | start_epoch = 0 290 | start_batch_id = 0 291 | counter = 1 292 | print(" [!] Load failed...") 293 | 294 | # loop for epoch 295 | start_time = time.time() 296 | past_g_loss = -1. 297 | for epoch in range(start_epoch, self.epoch): 298 | # get batch data 299 | for idx in range(start_batch_id, self.iteration): 300 | 301 | # update D network 302 | _, summary_str, d_loss = self.sess.run([self.d_optim, self.d_sum, self.d_loss]) 303 | self.writer.add_summary(summary_str, counter) 304 | 305 | # update G network 306 | g_loss = None 307 | if (counter - 1) % self.n_critic == 0: 308 | _, summary_str, g_loss = self.sess.run([self.g_optim, self.g_sum, self.g_loss]) 309 | self.writer.add_summary(summary_str, counter) 310 | past_g_loss = g_loss 311 | 312 | # display training status 313 | counter += 1 314 | if g_loss == None: 315 | g_loss = past_g_loss 316 | print("Epoch: [%2d] [%5d/%5d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \ 317 | % (epoch, idx, self.iteration, time.time() - start_time, d_loss, g_loss)) 318 | 319 | # save training results for every 300 steps 320 | if np.mod(idx + 1, self.print_freq) == 0: 321 | samples = self.sess.run(self.fake_images) 322 | tot_num_samples = min(self.sample_num, self.batch_size) 323 | manifold_h = int(np.floor(np.sqrt(tot_num_samples))) 324 | manifold_w = int(np.floor(np.sqrt(tot_num_samples))) 325 | save_images(samples[:manifold_h * manifold_w, :, :, :], 326 | [manifold_h, manifold_w], 327 | './' + self.sample_dir + '/' + self.model_name + '_train_{:02d}_{:05d}.png'.format( 328 | epoch, idx + 1)) 329 | 330 | if 
np.mod(idx + 1, self.save_freq) == 0: 331 | self.save(self.checkpoint_dir, counter) 332 | 333 | # After an epoch, start_batch_id is set to zero 334 | # non-zero value is only for the first epoch after loading pre-trained model 335 | start_batch_id = 0 336 | 337 | # save model 338 | self.save(self.checkpoint_dir, counter) 339 | 340 | # show temporal results 341 | # self.visualize_results(epoch) 342 | 343 | # save model for final step 344 | self.save(self.checkpoint_dir, counter) 345 | 346 | @property 347 | def model_dir(self): 348 | if self.sn : 349 | sn = '_sn' 350 | else : 351 | sn = '' 352 | 353 | return "{}_{}_{}_{}_{}{}".format( 354 | self.model_name, self.dataset_name, self.gan_type, self.img_size, self.z_dim, sn) 355 | 356 | def save(self, checkpoint_dir, step): 357 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir) 358 | 359 | if not os.path.exists(checkpoint_dir): 360 | os.makedirs(checkpoint_dir) 361 | 362 | self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name + '.model'), global_step=step) 363 | 364 | def load(self, checkpoint_dir): 365 | print(" [*] Reading checkpoints...") 366 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir) 367 | 368 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 369 | if ckpt and ckpt.model_checkpoint_path: 370 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path) 371 | self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name)) 372 | counter = int(ckpt_name.split('-')[-1]) 373 | print(" [*] Success to read {}".format(ckpt_name)) 374 | return True, counter 375 | else: 376 | print(" [*] Failed to find a checkpoint") 377 | return False, 0 378 | 379 | def visualize_results(self, epoch): 380 | tot_num_samples = min(self.sample_num, self.batch_size) 381 | image_frame_dim = int(np.floor(np.sqrt(tot_num_samples))) 382 | 383 | """ random condition, random noise """ 384 | 385 | samples = self.sess.run(self.fake_images) 386 | 387 | save_images(samples[:image_frame_dim * 
image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim], 388 | self.sample_dir + '/' + self.model_name + '_epoch%02d' % epoch + '_visualize.png') 389 | 390 | def test(self): 391 | tf.global_variables_initializer().run() 392 | 393 | self.saver = tf.train.Saver() 394 | could_load, checkpoint_counter = self.load(self.checkpoint_dir) 395 | result_dir = os.path.join(self.result_dir, self.model_dir) 396 | check_folder(result_dir) 397 | 398 | if could_load: 399 | print(" [*] Load SUCCESS") 400 | else: 401 | print(" [!] Load failed...") 402 | 403 | tot_num_samples = min(self.sample_num, self.batch_size) 404 | image_frame_dim = int(np.floor(np.sqrt(tot_num_samples))) 405 | 406 | """ random condition, random noise """ 407 | 408 | for i in range(self.test_num): 409 | samples = self.sess.run(self.fake_images) 410 | 411 | save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], 412 | [image_frame_dim, image_frame_dim], 413 | result_dir + '/' + self.model_name + '_test_{}.png'.format(i)) 414 | -------------------------------------------------------------------------------- /BigGAN_256.py: -------------------------------------------------------------------------------- 1 | import time 2 | from ops import * 3 | from utils import * 4 | from tensorflow.contrib.data import prefetch_to_device, shuffle_and_repeat, map_and_batch 5 | from tensorflow.contrib.opt import MovingAverageOptimizer 6 | 7 | 8 | class BigGAN_256(object): 9 | 10 | def __init__(self, sess, args): 11 | self.model_name = "BigGAN" # name for checkpoint 12 | self.sess = sess 13 | self.dataset_name = args.dataset 14 | self.checkpoint_dir = args.checkpoint_dir 15 | self.sample_dir = args.sample_dir 16 | self.result_dir = args.result_dir 17 | self.log_dir = args.log_dir 18 | 19 | self.epoch = args.epoch 20 | self.iteration = args.iteration 21 | self.batch_size = args.batch_size 22 | self.print_freq = args.print_freq 23 | self.save_freq = args.save_freq 24 | self.img_size = args.img_size 25 | 26 | """ 
Generator """ 27 | self.ch = args.ch 28 | self.z_dim = args.z_dim # dimension of noise-vector 29 | self.gan_type = args.gan_type 30 | 31 | """ Discriminator """ 32 | self.n_critic = args.n_critic 33 | self.sn = args.sn 34 | self.ld = args.ld 35 | 36 | self.sample_num = args.sample_num # number of generated images to be saved 37 | self.test_num = args.test_num 38 | 39 | # train 40 | self.g_learning_rate = args.g_lr 41 | self.d_learning_rate = args.d_lr 42 | self.beta1 = args.beta1 43 | self.beta2 = args.beta2 44 | self.moving_decay = args.moving_decay 45 | 46 | self.custom_dataset = False 47 | 48 | if self.dataset_name == 'mnist': 49 | self.c_dim = 1 50 | self.data = load_mnist() 51 | 52 | elif self.dataset_name == 'cifar10': 53 | self.c_dim = 3 54 | self.data = load_cifar10() 55 | 56 | else: 57 | self.c_dim = 3 58 | self.data = load_data(dataset_name=self.dataset_name) 59 | self.custom_dataset = True 60 | 61 | self.dataset_num = len(self.data) 62 | 63 | self.sample_dir = os.path.join(self.sample_dir, self.model_dir) 64 | check_folder(self.sample_dir) 65 | 66 | print() 67 | 68 | print("##### Information #####") 69 | print("# BigGAN 256") 70 | print("# gan type : ", self.gan_type) 71 | print("# dataset : ", self.dataset_name) 72 | print("# dataset number : ", self.dataset_num) 73 | print("# batch_size : ", self.batch_size) 74 | print("# epoch : ", self.epoch) 75 | print("# iteration per epoch : ", self.iteration) 76 | 77 | print() 78 | 79 | print("##### Generator #####") 80 | print("# spectral normalization : ", self.sn) 81 | print("# learning rate : ", self.g_learning_rate) 82 | 83 | print() 84 | 85 | print("##### Discriminator #####") 86 | print("# the number of critic : ", self.n_critic) 87 | print("# spectral normalization : ", self.sn) 88 | print("# learning rate : ", self.d_learning_rate) 89 | 90 | ################################################################################## 91 | # Generator 92 | 
################################################################################## 93 | 94 | def generator(self, z, is_training=True, reuse=False): 95 | with tf.variable_scope("generator", reuse=reuse): 96 | # 7 chunks: z_split[0] feeds the dense layer, z_split[1:] condition the six resblock_up blocks 97 | if self.z_dim == 128: 98 | split_dim = 18 99 | split_dim_remainder = self.z_dim - (split_dim * 6) 100 | 101 | z_split = tf.split(z, num_or_size_splits=[split_dim] * 6 + [split_dim_remainder], axis=-1) 102 | 103 | else: 104 | split_dim = self.z_dim // 7 105 | split_dim_remainder = self.z_dim - (split_dim * 7) 106 | 107 | if split_dim_remainder == 0 : 108 | z_split = tf.split(z, num_or_size_splits=[split_dim] * 7, axis=-1) 109 | else : 110 | z_split = tf.split(z, num_or_size_splits=[split_dim] * 6 + [split_dim + split_dim_remainder], axis=-1) # last chunk absorbs the remainder so the sizes sum to z_dim, as tf.split requires 111 | 112 | 113 | ch = 16 * self.ch 114 | x = fully_conneted(z_split[0], units=4 * 4 * ch, sn=self.sn, scope='dense') 115 | x = tf.reshape(x, shape=[-1, 4, 4, ch]) 116 | 117 | x = resblock_up_condition(x, z_split[1], channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_up_16') 118 | ch = ch // 2 119 | 120 | x = resblock_up_condition(x, z_split[2], channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_up_8_0') 121 | x = resblock_up_condition(x, z_split[3], channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_up_8_1') 122 | ch = ch // 2 123 | 124 | x = resblock_up_condition(x, z_split[4], channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_up_4') 125 | ch = ch // 2 126 | 127 | x = resblock_up_condition(x, z_split[5], channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_up_2') 128 | 129 | # Non-Local Block 130 | x = self_attention_2(x, channels=ch, sn=self.sn, scope='self_attention') 131 | ch = ch // 2 132 | 133 | x = resblock_up_condition(x, z_split[6], channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_up_1') 134 | 135 | x = batch_norm(x, is_training) 136
| x = relu(x) 137 | x = conv(x, channels=self.c_dim, kernel=3, stride=1, pad=1, use_bias=False, sn=self.sn, scope='G_logit') 138 | 139 | x = tanh(x) 140 | 141 | return x 142 | 143 | ################################################################################## 144 | # Discriminator 145 | ################################################################################## 146 | 147 | def discriminator(self, x, is_training=True, reuse=False): 148 | with tf.variable_scope("discriminator", reuse=reuse): 149 | ch = self.ch 150 | 151 | x = resblock_down(x, channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_down_1') 152 | ch = ch * 2 153 | 154 | x = resblock_down(x, channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_down_2') 155 | 156 | # Non-Local Block 157 | x = self_attention_2(x, channels=ch, sn=self.sn, scope='self_attention') 158 | ch = ch * 2 159 | 160 | x = resblock_down(x, channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_down_4') 161 | ch = ch * 2 162 | 163 | x = resblock_down(x, channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_down_8_0') 164 | x = resblock_down(x, channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_down_8_1') 165 | ch = ch * 2 166 | 167 | x = resblock_down(x, channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_down_16') 168 | 169 | x = resblock(x, channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock') 170 | x = relu(x) 171 | 172 | x = global_sum_pooling(x) 173 | 174 | x = fully_conneted(x, units=1, sn=self.sn, scope='D_logit') 175 | 176 | return x 177 | 178 | def gradient_penalty(self, real, fake): 179 | if self.gan_type.__contains__('dragan'): 180 | eps = tf.random_uniform(shape=tf.shape(real), minval=0., maxval=1.) 
181 | _, x_var = tf.nn.moments(real, axes=[0, 1, 2, 3]) 182 | x_std = tf.sqrt(x_var) # magnitude of noise decides the size of local region 183 | 184 | fake = real + 0.5 * x_std * eps 185 | 186 | alpha = tf.random_uniform(shape=[self.batch_size, 1, 1, 1], minval=0., maxval=1.) 187 | interpolated = real + alpha * (fake - real) 188 | 189 | logit = self.discriminator(interpolated, reuse=True) 190 | 191 | grad = tf.gradients(logit, interpolated)[0] # gradient of D(interpolated) 192 | grad_norm = tf.norm(flatten(grad), axis=1) # l2 norm 193 | 194 | GP = 0 195 | 196 | # WGAN - LP 197 | if self.gan_type == 'wgan-lp': 198 | GP = self.ld * tf.reduce_mean(tf.square(tf.maximum(0.0, grad_norm - 1.))) 199 | 200 | elif self.gan_type == 'wgan-gp' or self.gan_type == 'dragan': 201 | GP = self.ld * tf.reduce_mean(tf.square(grad_norm - 1.)) 202 | 203 | return GP 204 | 205 | ################################################################################## 206 | # Model 207 | ################################################################################## 208 | 209 | def build_model(self): 210 | """ Graph Input """ 211 | # images 212 | Image_Data_Class = ImageData(self.img_size, self.c_dim, self.custom_dataset) 213 | inputs = tf.data.Dataset.from_tensor_slices(self.data) 214 | 215 | gpu_device = '/gpu:0' 216 | inputs = inputs.\ 217 | apply(shuffle_and_repeat(self.dataset_num)).\ 218 | apply(map_and_batch(Image_Data_Class.image_processing, self.batch_size, num_parallel_batches=16, drop_remainder=True)).\ 219 | apply(prefetch_to_device(gpu_device, self.batch_size)) 220 | 221 | inputs_iterator = inputs.make_one_shot_iterator() 222 | 223 | self.inputs = inputs_iterator.get_next() 224 | 225 | # noises 226 | self.z = tf.truncated_normal(shape=[self.batch_size, 1, 1, self.z_dim], name='random_z') 227 | 228 | """ Loss Function """ 229 | # output of D for real images 230 | real_logits = self.discriminator(self.inputs) 231 | 232 | # output of D for fake images 233 | fake_images = 
self.generator(self.z) 234 | fake_logits = self.discriminator(fake_images, reuse=True) 235 | 236 | if self.gan_type.__contains__('wgan') or self.gan_type == 'dragan': 237 | GP = self.gradient_penalty(real=self.inputs, fake=fake_images) 238 | else: 239 | GP = 0 240 | 241 | # get loss for discriminator 242 | self.d_loss = discriminator_loss(self.gan_type, real=real_logits, fake=fake_logits) + GP 243 | 244 | # get loss for generator 245 | self.g_loss = generator_loss(self.gan_type, fake=fake_logits) 246 | 247 | """ Training """ 248 | # divide trainable variables into a group for D and a group for G 249 | t_vars = tf.trainable_variables() 250 | d_vars = [var for var in t_vars if 'discriminator' in var.name] 251 | g_vars = [var for var in t_vars if 'generator' in var.name] 252 | 253 | # optimizers 254 | with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): 255 | self.d_optim = tf.train.AdamOptimizer(self.d_learning_rate, beta1=self.beta1, beta2=self.beta2).minimize(self.d_loss, var_list=d_vars) 256 | 257 | self.opt = MovingAverageOptimizer(tf.train.AdamOptimizer(self.g_learning_rate, beta1=self.beta1, beta2=self.beta2), average_decay=self.moving_decay) 258 | 259 | self.g_optim = self.opt.minimize(self.g_loss, var_list=g_vars) 260 | 261 | """" Testing """ 262 | # for test 263 | self.fake_images = self.generator(self.z, is_training=False, reuse=True) 264 | 265 | """ Summary """ 266 | self.d_sum = tf.summary.scalar("d_loss", self.d_loss) 267 | self.g_sum = tf.summary.scalar("g_loss", self.g_loss) 268 | 269 | ################################################################################## 270 | # Train 271 | ################################################################################## 272 | 273 | def train(self): 274 | # initialize all variables 275 | tf.global_variables_initializer().run() 276 | 277 | # saver to save model 278 | self.saver = self.opt.swapping_saver() 279 | 280 | # summary writer 281 | self.writer = 
tf.summary.FileWriter(self.log_dir + '/' + self.model_dir, self.sess.graph) 282 | 283 | # restore check-point if it exits 284 | could_load, checkpoint_counter = self.load(self.checkpoint_dir) 285 | if could_load: 286 | start_epoch = (int)(checkpoint_counter / self.iteration) 287 | start_batch_id = checkpoint_counter - start_epoch * self.iteration 288 | counter = checkpoint_counter 289 | print(" [*] Load SUCCESS") 290 | else: 291 | start_epoch = 0 292 | start_batch_id = 0 293 | counter = 1 294 | print(" [!] Load failed...") 295 | 296 | # loop for epoch 297 | start_time = time.time() 298 | past_g_loss = -1. 299 | for epoch in range(start_epoch, self.epoch): 300 | # get batch data 301 | for idx in range(start_batch_id, self.iteration): 302 | # update D network 303 | _, summary_str, d_loss = self.sess.run([self.d_optim, self.d_sum, self.d_loss]) 304 | self.writer.add_summary(summary_str, counter) 305 | 306 | # update G network 307 | g_loss = None 308 | if (counter - 1) % self.n_critic == 0: 309 | _, summary_str, g_loss = self.sess.run([self.g_optim, self.g_sum, self.g_loss]) 310 | self.writer.add_summary(summary_str, counter) 311 | past_g_loss = g_loss 312 | 313 | # display training status 314 | counter += 1 315 | if g_loss == None: 316 | g_loss = past_g_loss 317 | print("Epoch: [%2d] [%5d/%5d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \ 318 | % (epoch, idx, self.iteration, time.time() - start_time, d_loss, g_loss)) 319 | 320 | # save training results for every 300 steps 321 | if np.mod(idx + 1, self.print_freq) == 0: 322 | samples = self.sess.run(self.fake_images) 323 | tot_num_samples = min(self.sample_num, self.batch_size) 324 | manifold_h = int(np.floor(np.sqrt(tot_num_samples))) 325 | manifold_w = int(np.floor(np.sqrt(tot_num_samples))) 326 | save_images(samples[:manifold_h * manifold_w, :, :, :], 327 | [manifold_h, manifold_w], 328 | './' + self.sample_dir + '/' + self.model_name + '_train_{:02d}_{:05d}.png'.format( 329 | epoch, idx + 1)) 330 | 331 | if np.mod(idx 
+ 1, self.save_freq) == 0: 332 | self.save(self.checkpoint_dir, counter) 333 | 334 | # After an epoch, start_batch_id is set to zero 335 | # non-zero value is only for the first epoch after loading pre-trained model 336 | start_batch_id = 0 337 | 338 | # save model 339 | self.save(self.checkpoint_dir, counter) 340 | 341 | # show temporal results 342 | # self.visualize_results(epoch) 343 | 344 | # save model for final step 345 | self.save(self.checkpoint_dir, counter) 346 | 347 | @property 348 | def model_dir(self): 349 | if self.sn : 350 | sn = '_sn' 351 | else : 352 | sn = '' 353 | 354 | return "{}_{}_{}_{}_{}{}".format( 355 | self.model_name, self.dataset_name, self.gan_type, self.img_size, self.z_dim, sn) 356 | 357 | def save(self, checkpoint_dir, step): 358 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir) 359 | 360 | if not os.path.exists(checkpoint_dir): 361 | os.makedirs(checkpoint_dir) 362 | 363 | self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name + '.model'), global_step=step) 364 | 365 | def load(self, checkpoint_dir): 366 | print(" [*] Reading checkpoints...") 367 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir) 368 | 369 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 370 | if ckpt and ckpt.model_checkpoint_path: 371 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path) 372 | self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name)) 373 | counter = int(ckpt_name.split('-')[-1]) 374 | print(" [*] Success to read {}".format(ckpt_name)) 375 | return True, counter 376 | else: 377 | print(" [*] Failed to find a checkpoint") 378 | return False, 0 379 | 380 | def visualize_results(self, epoch): 381 | tot_num_samples = min(self.sample_num, self.batch_size) 382 | image_frame_dim = int(np.floor(np.sqrt(tot_num_samples))) 383 | 384 | """ random condition, random noise """ 385 | 386 | samples = self.sess.run(self.fake_images) 387 | 388 | save_images(samples[:image_frame_dim * 
image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim], 389 | self.sample_dir + '/' + self.model_name + '_epoch%02d' % epoch + '_visualize.png') 390 | 391 | def test(self): 392 | tf.global_variables_initializer().run() 393 | 394 | self.saver = tf.train.Saver() 395 | could_load, checkpoint_counter = self.load(self.checkpoint_dir) 396 | result_dir = os.path.join(self.result_dir, self.model_dir) 397 | check_folder(result_dir) 398 | 399 | if could_load: 400 | print(" [*] Load SUCCESS") 401 | else: 402 | print(" [!] Load failed...") 403 | 404 | tot_num_samples = min(self.sample_num, self.batch_size) 405 | image_frame_dim = int(np.floor(np.sqrt(tot_num_samples))) 406 | 407 | """ random condition, random noise """ 408 | 409 | for i in range(self.test_num): 410 | samples = self.sess.run(self.fake_images) 411 | 412 | save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], 413 | [image_frame_dim, image_frame_dim], 414 | result_dir + '/' + self.model_name + '_test_{}.png'.format(i)) 415 | -------------------------------------------------------------------------------- /BigGAN_512.py: -------------------------------------------------------------------------------- 1 | import time 2 | from ops import * 3 | from utils import * 4 | from tensorflow.contrib.data import prefetch_to_device, shuffle_and_repeat, map_and_batch 5 | from tensorflow.contrib.opt import MovingAverageOptimizer 6 | 7 | 8 | class BigGAN_512(object): 9 | 10 | def __init__(self, sess, args): 11 | self.model_name = "BigGAN" # name for checkpoint 12 | self.sess = sess 13 | self.dataset_name = args.dataset 14 | self.checkpoint_dir = args.checkpoint_dir 15 | self.sample_dir = args.sample_dir 16 | self.result_dir = args.result_dir 17 | self.log_dir = args.log_dir 18 | 19 | self.epoch = args.epoch 20 | self.iteration = args.iteration 21 | self.batch_size = args.batch_size 22 | self.print_freq = args.print_freq 23 | self.save_freq = args.save_freq 24 | self.img_size = args.img_size 25 | 26 | """ 
Generator """ 27 | self.ch = args.ch 28 | self.z_dim = args.z_dim # dimension of noise-vector 29 | self.gan_type = args.gan_type 30 | 31 | """ Discriminator """ 32 | self.n_critic = args.n_critic 33 | self.sn = args.sn 34 | self.ld = args.ld 35 | 36 | self.sample_num = args.sample_num # number of generated images to be saved 37 | self.test_num = args.test_num 38 | 39 | # train 40 | self.g_learning_rate = args.g_lr 41 | self.d_learning_rate = args.d_lr 42 | self.beta1 = args.beta1 43 | self.beta2 = args.beta2 44 | self.moving_decay = args.moving_decay 45 | 46 | self.custom_dataset = False 47 | 48 | if self.dataset_name == 'mnist': 49 | self.c_dim = 1 50 | self.data = load_mnist() 51 | 52 | elif self.dataset_name == 'cifar10': 53 | self.c_dim = 3 54 | self.data = load_cifar10() 55 | 56 | else: 57 | self.c_dim = 3 58 | self.data = load_data(dataset_name=self.dataset_name) 59 | self.custom_dataset = True 60 | 61 | self.dataset_num = len(self.data) 62 | 63 | self.sample_dir = os.path.join(self.sample_dir, self.model_dir) 64 | check_folder(self.sample_dir) 65 | 66 | print() 67 | 68 | print("##### Information #####") 69 | print("# BigGAN 512") 70 | print("# gan type : ", self.gan_type) 71 | print("# dataset : ", self.dataset_name) 72 | print("# dataset number : ", self.dataset_num) 73 | print("# batch_size : ", self.batch_size) 74 | print("# epoch : ", self.epoch) 75 | print("# iteration per epoch : ", self.iteration) 76 | 77 | print() 78 | 79 | print("##### Generator #####") 80 | print("# spectral normalization : ", self.sn) 81 | print("# learning rate : ", self.g_learning_rate) 82 | 83 | print() 84 | 85 | print("##### Discriminator #####") 86 | print("# the number of critic : ", self.n_critic) 87 | print("# spectral normalization : ", self.sn) 88 | print("# learning rate : ", self.d_learning_rate) 89 | 90 | ################################################################################## 91 | # Generator 92 | 
################################################################################## 93 | 94 | def generator(self, z, is_training=True, reuse=False): 95 | with tf.variable_scope("generator", reuse=reuse): 96 | # 8 97 | if self.z_dim == 128 : 98 | split_dim = 16 99 | split_dim_remainder = self.z_dim - (split_dim * 7) 100 | 101 | z_split = tf.split(z, num_or_size_splits=[split_dim] * 7 + [split_dim_remainder], axis=-1) 102 | 103 | else : 104 | split_dim = self.z_dim // 8 105 | split_dim_remainder = self.z_dim - (split_dim * 8) 106 | 107 | if split_dim_remainder == 0 : 108 | z_split = tf.split(z, num_or_size_splits=[split_dim] * 8, axis=-1) 109 | else : 110 | z_split = tf.split(z, num_or_size_splits=[split_dim] * 7 + [split_dim_remainder], axis=-1) 111 | 112 | 113 | ch = 16 * self.ch 114 | x = fully_conneted(z_split[0], units=4 * 4 * ch, sn=self.sn, scope='dense') 115 | x = tf.reshape(x, shape=[-1, 4, 4, ch]) 116 | 117 | x = resblock_up_condition(x, z_split[1], channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_up_16') 118 | ch = ch // 2 119 | 120 | x = resblock_up_condition(x, z_split[2], channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_up_8_0') 121 | x = resblock_up_condition(x, z_split[3], channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_up_8_1') 122 | ch = ch // 2 123 | 124 | x = resblock_up_condition(x, z_split[4], channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_up_4') 125 | 126 | # Non-Local Block 127 | x = self_attention_2(x, channels=ch, sn=self.sn, scope='self_attention') 128 | ch = ch // 2 129 | 130 | x = resblock_up_condition(x, z_split[5], channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_up_2') 131 | ch = ch // 2 132 | 133 | x = resblock_up_condition(x, z_split[6], channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_up_1_0') 134 | x = resblock_up_condition(x, 
z_split[7], channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_up_1_1') 135 | 136 | x = batch_norm(x, is_training) 137 | x = relu(x) 138 | x = conv(x, channels=self.c_dim, kernel=3, stride=1, pad=1, use_bias=False, sn=self.sn, scope='G_logit') 139 | 140 | x = tanh(x) 141 | 142 | return x 143 | 144 | ################################################################################## 145 | # Discriminator 146 | ################################################################################## 147 | 148 | def discriminator(self, x, is_training=True, reuse=False): 149 | with tf.variable_scope("discriminator", reuse=reuse): 150 | ch = self.ch 151 | 152 | x = resblock_down(x, channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_down_1_0') 153 | x = resblock_down(x, channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_down_1_1') 154 | ch = ch * 2 155 | 156 | x = resblock_down(x, channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_down_2') 157 | 158 | # Non-Local Block 159 | x = self_attention_2(x, channels=ch, sn=self.sn, scope='self_attention') 160 | ch = ch * 2 161 | 162 | x = resblock_down(x, channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_down_4') 163 | ch = ch * 2 164 | 165 | x = resblock_down(x, channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_down_8_0') 166 | x = resblock_down(x, channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_down_8_1') 167 | ch = ch * 2 168 | 169 | x = resblock_down(x, channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_down_16') 170 | 171 | x = resblock(x, channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock') 172 | x = relu(x) 173 | 174 | x = global_sum_pooling(x) 175 | 176 | x = fully_conneted(x, units=1, sn=self.sn, scope='D_logit') 177 | 178 | return x 179 | 180 | def 
gradient_penalty(self, real, fake): 181 | if self.gan_type.__contains__('dragan'): 182 | eps = tf.random_uniform(shape=tf.shape(real), minval=0., maxval=1.) 183 | _, x_var = tf.nn.moments(real, axes=[0, 1, 2, 3]) 184 | x_std = tf.sqrt(x_var) # magnitude of noise decides the size of local region 185 | 186 | fake = real + 0.5 * x_std * eps 187 | 188 | alpha = tf.random_uniform(shape=[self.batch_size, 1, 1, 1], minval=0., maxval=1.) 189 | interpolated = real + alpha * (fake - real) 190 | 191 | logit = self.discriminator(interpolated, reuse=True) 192 | 193 | grad = tf.gradients(logit, interpolated)[0] # gradient of D(interpolated) 194 | grad_norm = tf.norm(flatten(grad), axis=1) # l2 norm 195 | 196 | GP = 0 197 | 198 | # WGAN - LP 199 | if self.gan_type == 'wgan-lp': 200 | GP = self.ld * tf.reduce_mean(tf.square(tf.maximum(0.0, grad_norm - 1.))) 201 | 202 | elif self.gan_type == 'wgan-gp' or self.gan_type == 'dragan': 203 | GP = self.ld * tf.reduce_mean(tf.square(grad_norm - 1.)) 204 | 205 | return GP 206 | 207 | ################################################################################## 208 | # Model 209 | ################################################################################## 210 | 211 | def build_model(self): 212 | """ Graph Input """ 213 | # images 214 | Image_Data_Class = ImageData(self.img_size, self.c_dim, self.custom_dataset) 215 | inputs = tf.data.Dataset.from_tensor_slices(self.data) 216 | 217 | gpu_device = '/gpu:0' 218 | inputs = inputs.\ 219 | apply(shuffle_and_repeat(self.dataset_num)).\ 220 | apply(map_and_batch(Image_Data_Class.image_processing, self.batch_size, num_parallel_batches=16, drop_remainder=True)).\ 221 | apply(prefetch_to_device(gpu_device, self.batch_size)) 222 | 223 | inputs_iterator = inputs.make_one_shot_iterator() 224 | 225 | self.inputs = inputs_iterator.get_next() 226 | 227 | # noises 228 | self.z = tf.truncated_normal(shape=[self.batch_size, 1, 1, self.z_dim], name='random_z') 229 | 230 | """ Loss Function """ 231 | 
# output of D for real images 232 | real_logits = self.discriminator(self.inputs) 233 | 234 | # output of D for fake images 235 | fake_images = self.generator(self.z) 236 | fake_logits = self.discriminator(fake_images, reuse=True) 237 | 238 | if self.gan_type.__contains__('wgan') or self.gan_type == 'dragan': 239 | GP = self.gradient_penalty(real=self.inputs, fake=fake_images) 240 | else: 241 | GP = 0 242 | 243 | # get loss for discriminator 244 | self.d_loss = discriminator_loss(self.gan_type, real=real_logits, fake=fake_logits) + GP 245 | 246 | # get loss for generator 247 | self.g_loss = generator_loss(self.gan_type, fake=fake_logits) 248 | 249 | """ Training """ 250 | # divide trainable variables into a group for D and a group for G 251 | t_vars = tf.trainable_variables() 252 | d_vars = [var for var in t_vars if 'discriminator' in var.name] 253 | g_vars = [var for var in t_vars if 'generator' in var.name] 254 | 255 | # optimizers 256 | with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): 257 | self.d_optim = tf.train.AdamOptimizer(self.d_learning_rate, beta1=self.beta1, beta2=self.beta2).minimize(self.d_loss, var_list=d_vars) 258 | 259 | self.opt = MovingAverageOptimizer(tf.train.AdamOptimizer(self.g_learning_rate, beta1=self.beta1, beta2=self.beta2), average_decay=self.moving_decay) 260 | self.g_optim = self.opt.minimize(self.g_loss, var_list=g_vars) 261 | 262 | """" Testing """ 263 | # for test 264 | self.fake_images = self.generator(self.z, is_training=False, reuse=True) 265 | 266 | """ Summary """ 267 | self.d_sum = tf.summary.scalar("d_loss", self.d_loss) 268 | self.g_sum = tf.summary.scalar("g_loss", self.g_loss) 269 | 270 | ################################################################################## 271 | # Train 272 | ################################################################################## 273 | 274 | def train(self): 275 | # initialize all variables 276 | tf.global_variables_initializer().run() 277 | 278 | # saver 
to save model 279 | self.saver = self.opt.swapping_saver() 280 | 281 | # summary writer 282 | self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_dir, self.sess.graph) 283 | 284 | # restore check-point if it exits 285 | could_load, checkpoint_counter = self.load(self.checkpoint_dir) 286 | if could_load: 287 | start_epoch = (int)(checkpoint_counter / self.iteration) 288 | start_batch_id = checkpoint_counter - start_epoch * self.iteration 289 | counter = checkpoint_counter 290 | print(" [*] Load SUCCESS") 291 | else: 292 | start_epoch = 0 293 | start_batch_id = 0 294 | counter = 1 295 | print(" [!] Load failed...") 296 | 297 | # loop for epoch 298 | start_time = time.time() 299 | past_g_loss = -1. 300 | for epoch in range(start_epoch, self.epoch): 301 | # get batch data 302 | for idx in range(start_batch_id, self.iteration): 303 | 304 | # update D network 305 | _, summary_str, d_loss = self.sess.run([self.d_optim, self.d_sum, self.d_loss]) 306 | self.writer.add_summary(summary_str, counter) 307 | 308 | # update G network 309 | g_loss = None 310 | if (counter - 1) % self.n_critic == 0: 311 | _, summary_str, g_loss = self.sess.run([self.g_optim, self.g_sum, self.g_loss]) 312 | self.writer.add_summary(summary_str, counter) 313 | past_g_loss = g_loss 314 | 315 | # display training status 316 | counter += 1 317 | if g_loss == None: 318 | g_loss = past_g_loss 319 | print("Epoch: [%2d] [%5d/%5d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \ 320 | % (epoch, idx, self.iteration, time.time() - start_time, d_loss, g_loss)) 321 | 322 | # save training results for every 300 steps 323 | if np.mod(idx + 1, self.print_freq) == 0: 324 | samples = self.sess.run(self.fake_images) 325 | tot_num_samples = min(self.sample_num, self.batch_size) 326 | manifold_h = int(np.floor(np.sqrt(tot_num_samples))) 327 | manifold_w = int(np.floor(np.sqrt(tot_num_samples))) 328 | save_images(samples[:manifold_h * manifold_w, :, :, :], 329 | [manifold_h, manifold_w], 330 | './' + 
self.sample_dir + '/' + self.model_name + '_train_{:02d}_{:05d}.png'.format( 331 | epoch, idx + 1)) 332 | 333 | if np.mod(idx + 1, self.save_freq) == 0: 334 | self.save(self.checkpoint_dir, counter) 335 | 336 | # After an epoch, start_batch_id is set to zero 337 | # non-zero value is only for the first epoch after loading pre-trained model 338 | start_batch_id = 0 339 | 340 | # save model 341 | self.save(self.checkpoint_dir, counter) 342 | 343 | # show temporal results 344 | # self.visualize_results(epoch) 345 | 346 | # save model for final step 347 | self.save(self.checkpoint_dir, counter) 348 | 349 | @property 350 | def model_dir(self): 351 | if self.sn : 352 | sn = '_sn' 353 | else : 354 | sn = '' 355 | 356 | return "{}_{}_{}_{}_{}{}".format( 357 | self.model_name, self.dataset_name, self.gan_type, self.img_size, self.z_dim, sn) 358 | 359 | def save(self, checkpoint_dir, step): 360 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir) 361 | 362 | if not os.path.exists(checkpoint_dir): 363 | os.makedirs(checkpoint_dir) 364 | 365 | self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name + '.model'), global_step=step) 366 | 367 | def load(self, checkpoint_dir): 368 | print(" [*] Reading checkpoints...") 369 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir) 370 | 371 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 372 | if ckpt and ckpt.model_checkpoint_path: 373 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path) 374 | self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name)) 375 | counter = int(ckpt_name.split('-')[-1]) 376 | print(" [*] Success to read {}".format(ckpt_name)) 377 | return True, counter 378 | else: 379 | print(" [*] Failed to find a checkpoint") 380 | return False, 0 381 | 382 | def visualize_results(self, epoch): 383 | tot_num_samples = min(self.sample_num, self.batch_size) 384 | image_frame_dim = int(np.floor(np.sqrt(tot_num_samples))) 385 | 386 | """ random condition, random 
noise """ 387 | 388 | samples = self.sess.run(self.fake_images) 389 | 390 | save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim], 391 | self.sample_dir + '/' + self.model_name + '_epoch%02d' % epoch + '_visualize.png') 392 | 393 | def test(self): 394 | tf.global_variables_initializer().run() 395 | 396 | self.saver = tf.train.Saver() 397 | could_load, checkpoint_counter = self.load(self.checkpoint_dir) 398 | result_dir = os.path.join(self.result_dir, self.model_dir) 399 | check_folder(result_dir) 400 | 401 | if could_load: 402 | print(" [*] Load SUCCESS") 403 | else: 404 | print(" [!] Load failed...") 405 | 406 | tot_num_samples = min(self.sample_num, self.batch_size) 407 | image_frame_dim = int(np.floor(np.sqrt(tot_num_samples))) 408 | 409 | """ random condition, random noise """ 410 | 411 | for i in range(self.test_num): 412 | samples = self.sess.run(self.fake_images) 413 | 414 | save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], 415 | [image_frame_dim, image_frame_dim], 416 | result_dir + '/' + self.model_name + '_test_{}.png'.format(i)) 417 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Junho Kim (1993.01.12) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BigGAN-Tensorflow 2 | Simple Tensorflow implementation of ["Large Scale GAN Training for High Fidelity Natural Image Synthesis" (BigGAN)](https://arxiv.org/abs/1809.11096) 3 | 4 | ![main](./assets/main.png) 5 | 6 | ## Issue 7 | * **The paper** used `orthogonal initialization`, but `I used truncated normal initialization (stddev=0.02).` The reason is that training did not proceed properly with orthogonal initialization. 8 | * I have applied a hierarchical latent space, but **not** a class embedding. 9 | 10 | ## Usage 11 | ### dataset 12 | * `mnist` and `cifar10` are loaded through keras 13 | * For `your dataset`, put images like this: 14 | 15 | ``` 16 | ├── dataset 17 |    └── YOUR_DATASET_NAME 18 | ├── xxx.jpg (name, format doesn't matter) 19 | ├── yyy.png 20 | └── ...
21 | ``` 22 | ### train 23 | * python main.py --phase train --dataset celebA-HQ --gan_type hinge 24 | 25 | ### test 26 | * python main.py --phase test --dataset celebA-HQ --gan_type hinge 27 | 28 | ## Architecture 29 | 30 | 31 | ### 128x128 32 | 33 | 34 | ### 256x256 35 | 36 | 37 | ### 512x512 38 | 39 | 40 | ## Author 41 | Junho Kim 42 | -------------------------------------------------------------------------------- /assets/128.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/BigGAN-Tensorflow/d64d62ecd2b0761d08ff9d8c51241e963be06183/assets/128.png -------------------------------------------------------------------------------- /assets/256.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/BigGAN-Tensorflow/d64d62ecd2b0761d08ff9d8c51241e963be06183/assets/256.png -------------------------------------------------------------------------------- /assets/512.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/BigGAN-Tensorflow/d64d62ecd2b0761d08ff9d8c51241e963be06183/assets/512.png -------------------------------------------------------------------------------- /assets/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/BigGAN-Tensorflow/d64d62ecd2b0761d08ff9d8c51241e963be06183/assets/architecture.png -------------------------------------------------------------------------------- /assets/main.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/BigGAN-Tensorflow/d64d62ecd2b0761d08ff9d8c51241e963be06183/assets/main.png -------------------------------------------------------------------------------- /main.py: 
-------------------------------------------------------------------------------- 1 | from BigGAN_512 import BigGAN_512 2 | from BigGAN_256 import BigGAN_256 3 | from BigGAN_128 import BigGAN_128 4 | import argparse 5 | from utils import * 6 | 7 | """parsing and configuration""" 8 | def parse_args(): 9 | desc = "Tensorflow implementation of BigGAN" 10 | parser = argparse.ArgumentParser(description=desc) 11 | parser.add_argument('--phase', type=str, default='train', help='train or test ?') 12 | parser.add_argument('--dataset', type=str, default='celebA-HQ', help='[mnist / cifar10 / custom_dataset]') 13 | 14 | parser.add_argument('--epoch', type=int, default=50, help='The number of epochs to run') 15 | parser.add_argument('--iteration', type=int, default=10000, help='The number of training iterations') 16 | parser.add_argument('--batch_size', type=int, default=2048, help='The size of batch per gpu') 17 | parser.add_argument('--ch', type=int, default=96, help='base channel number per layer') 18 | 19 | # SAGAN 20 | # batch_size = 256 21 | # base channel = 64 22 | # epoch = 100 (1M iterations) 23 | 24 | parser.add_argument('--print_freq', type=int, default=1000, help='The number of image_print_freqy') 25 | parser.add_argument('--save_freq', type=int, default=1000, help='The number of ckpt_save_freq') 26 | 27 | parser.add_argument('--g_lr', type=float, default=0.00005, help='learning rate for generator') 28 | parser.add_argument('--d_lr', type=float, default=0.0002, help='learning rate for discriminator') 29 | 30 | # if lower batch size 31 | # g_lr = 0.0001 32 | # d_lr = 0.0004 33 | 34 | # if larger batch size 35 | # g_lr = 0.00005 36 | # d_lr = 0.0002 37 | 38 | parser.add_argument('--beta1', type=float, default=0.0, help='beta1 for Adam optimizer') 39 | parser.add_argument('--beta2', type=float, default=0.9, help='beta2 for Adam optimizer') 40 | parser.add_argument('--moving_decay', type=float, default=0.9999, help='moving average decay for generator') 41 | 42 | 
parser.add_argument('--z_dim', type=int, default=128, help='Dimension of noise vector') 43 | parser.add_argument('--sn', type=str2bool, default=True, help='using spectral norm') 44 | 45 | parser.add_argument('--gan_type', type=str, default='hinge', help='[gan / lsgan / wgan-gp / wgan-lp / dragan / hinge]') 46 | parser.add_argument('--ld', type=float, default=10.0, help='The gradient penalty lambda') 47 | 48 | parser.add_argument('--n_critic', type=int, default=2, help='The number of critic') 49 | 50 | parser.add_argument('--img_size', type=int, default=512, help='The size of image') 51 | parser.add_argument('--sample_num', type=int, default=64, help='The number of sample images') 52 | 53 | parser.add_argument('--test_num', type=int, default=10, help='The number of images generated by the test') 54 | 55 | parser.add_argument('--checkpoint_dir', type=str, default='checkpoint', 56 | help='Directory name to save the checkpoints') 57 | parser.add_argument('--result_dir', type=str, default='results', 58 | help='Directory name to save the generated images') 59 | parser.add_argument('--log_dir', type=str, default='logs', 60 | help='Directory name to save training logs') 61 | parser.add_argument('--sample_dir', type=str, default='samples', 62 | help='Directory name to save the samples on training') 63 | 64 | return check_args(parser.parse_args()) 65 | 66 | """checking arguments""" 67 | def check_args(args): 68 | # --checkpoint_dir 69 | check_folder(args.checkpoint_dir) 70 | 71 | # --result_dir 72 | check_folder(args.result_dir) 73 | 74 | # --result_dir 75 | check_folder(args.log_dir) 76 | 77 | # --sample_dir 78 | check_folder(args.sample_dir) 79 | 80 | # --epoch 81 | try: 82 | assert args.epoch >= 1 83 | except: 84 | print('number of epochs must be larger than or equal to one') 85 | 86 | # --batch_size 87 | try: 88 | assert args.batch_size >= 1 89 | except: 90 | print('batch size must be larger than or equal to one') 91 | return args 92 | 93 | 94 | """main""" 95 | def 
main(): 96 | # parse arguments 97 | args = parse_args() 98 | if args is None: 99 | exit() 100 | 101 | # open session 102 | with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess: 103 | # default gan = BigGAN_128 104 | 105 | if args.img_size == 512 : 106 | gan = BigGAN_512(sess, args) 107 | elif args.img_size == 256 : 108 | gan = BigGAN_256(sess, args) 109 | else : 110 | gan = BigGAN_128(sess, args) 111 | 112 | # build graph 113 | gan.build_model() 114 | 115 | # show network architecture 116 | show_all_variables() 117 | 118 | if args.phase == 'train' : 119 | # launch the graph in a session 120 | gan.train() 121 | 122 | # visualize learned generator 123 | gan.visualize_results(args.epoch - 1) 124 | 125 | print(" [*] Training finished!") 126 | 127 | if args.phase == 'test' : 128 | gan.test() 129 | print(" [*] Test finished!") 130 | 131 | if __name__ == '__main__': 132 | main() -------------------------------------------------------------------------------- /ops.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from utils import orthogonal_regularizer_fully, orthogonal_regularizer 3 | 4 | ################################################################################## 5 | # Initialization 6 | ################################################################################## 7 | 8 | # Xavier : tf_contrib.layers.xavier_initializer() 9 | # He : tf_contrib.layers.variance_scaling_initializer() 10 | # Normal : tf.random_normal_initializer(mean=0.0, stddev=0.02) 11 | # Truncated_normal : tf.truncated_normal_initializer(mean=0.0, stddev=0.02) 12 | # Orthogonal : tf.orthogonal_initializer(1.0) / relu = sqrt(2), the others = 1.0 13 | 14 | ################################################################################## 15 | # Regularization 16 | ################################################################################## 17 | 18 | # l2_decay : tf_contrib.layers.l2_regularizer(0.0001) 
19 | # orthogonal_regularizer : orthogonal_regularizer(0.0001) / orthogonal_regularizer_fully(0.0001) 20 | 21 | weight_init = tf.truncated_normal_initializer(mean=0.0, stddev=0.02) 22 | weight_regularizer = orthogonal_regularizer(0.0001) 23 | weight_regularizer_fully = orthogonal_regularizer_fully(0.0001) 24 | 25 | # Regularization only G in BigGAN 26 | 27 | ################################################################################## 28 | # Layer 29 | ################################################################################## 30 | 31 | # pad = ceil[ (kernel - stride) / 2 ] 32 | 33 | def conv(x, channels, kernel=4, stride=2, pad=0, pad_type='zero', use_bias=True, sn=False, scope='conv_0'): 34 | with tf.variable_scope(scope): 35 | if pad > 0: 36 | h = x.get_shape().as_list()[1] 37 | if h % stride == 0: 38 | pad = pad * 2 39 | else: 40 | pad = max(kernel - (h % stride), 0) 41 | 42 | pad_top = pad // 2 43 | pad_bottom = pad - pad_top 44 | pad_left = pad // 2 45 | pad_right = pad - pad_left 46 | 47 | if pad_type == 'zero' : 48 | x = tf.pad(x, [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]]) 49 | if pad_type == 'reflect' : 50 | x = tf.pad(x, [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]], mode='REFLECT') 51 | 52 | if sn : 53 | if scope.__contains__('generator') : 54 | w = tf.get_variable("kernel", shape=[kernel, kernel, x.get_shape()[-1], channels], initializer=weight_init, 55 | regularizer=weight_regularizer) 56 | else : 57 | w = tf.get_variable("kernel", shape=[kernel, kernel, x.get_shape()[-1], channels], initializer=weight_init, 58 | regularizer=None) 59 | 60 | x = tf.nn.conv2d(input=x, filter=spectral_norm(w), 61 | strides=[1, stride, stride, 1], padding='VALID') 62 | if use_bias : 63 | bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0)) 64 | x = tf.nn.bias_add(x, bias) 65 | 66 | else : 67 | if scope.__contains__('generator'): 68 | x = tf.layers.conv2d(inputs=x, filters=channels, 69 | 
kernel_size=kernel, kernel_initializer=weight_init, 70 | kernel_regularizer=weight_regularizer, 71 | strides=stride, use_bias=use_bias) 72 | else : 73 | x = tf.layers.conv2d(inputs=x, filters=channels, 74 | kernel_size=kernel, kernel_initializer=weight_init, 75 | kernel_regularizer=None, 76 | strides=stride, use_bias=use_bias) 77 | 78 | 79 | return x 80 | 81 | 82 | def deconv(x, channels, kernel=4, stride=2, padding='SAME', use_bias=True, sn=False, scope='deconv_0'): 83 | with tf.variable_scope(scope): 84 | x_shape = x.get_shape().as_list() 85 | 86 | if padding == 'SAME': 87 | output_shape = [x_shape[0], x_shape[1] * stride, x_shape[2] * stride, channels] 88 | 89 | else: 90 | output_shape =[x_shape[0], x_shape[1] * stride + max(kernel - stride, 0), x_shape[2] * stride + max(kernel - stride, 0), channels] 91 | 92 | if sn : 93 | w = tf.get_variable("kernel", shape=[kernel, kernel, channels, x.get_shape()[-1]], initializer=weight_init, regularizer=weight_regularizer) 94 | x = tf.nn.conv2d_transpose(x, filter=spectral_norm(w), output_shape=output_shape, strides=[1, stride, stride, 1], padding=padding) 95 | 96 | if use_bias : 97 | bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0)) 98 | x = tf.nn.bias_add(x, bias) 99 | 100 | else : 101 | x = tf.layers.conv2d_transpose(inputs=x, filters=channels, 102 | kernel_size=kernel, kernel_initializer=weight_init, kernel_regularizer=weight_regularizer, 103 | strides=stride, padding=padding, use_bias=use_bias) 104 | 105 | return x 106 | 107 | def fully_conneted(x, units, use_bias=True, sn=False, scope='fully_0'): 108 | with tf.variable_scope(scope): 109 | x = flatten(x) 110 | shape = x.get_shape().as_list() 111 | channels = shape[-1] 112 | 113 | if sn : 114 | if scope.__contains__('generator'): 115 | w = tf.get_variable("kernel", [channels, units], tf.float32, initializer=weight_init, regularizer=weight_regularizer_fully) 116 | else : 117 | w = tf.get_variable("kernel", [channels, units], tf.float32, 
initializer=weight_init, regularizer=None) 118 | 119 | if use_bias : 120 | bias = tf.get_variable("bias", [units], initializer=tf.constant_initializer(0.0)) 121 | 122 | x = tf.matmul(x, spectral_norm(w)) + bias 123 | else : 124 | x = tf.matmul(x, spectral_norm(w)) 125 | 126 | else : 127 | if scope.__contains__('generator'): 128 | x = tf.layers.dense(x, units=units, kernel_initializer=weight_init, 129 | kernel_regularizer=weight_regularizer_fully, use_bias=use_bias) 130 | else : 131 | x = tf.layers.dense(x, units=units, kernel_initializer=weight_init, 132 | kernel_regularizer=None, use_bias=use_bias) 133 | 134 | return x 135 | 136 | def flatten(x) : 137 | return tf.layers.flatten(x) 138 | 139 | def hw_flatten(x) : 140 | return tf.reshape(x, shape=[x.shape[0], -1, x.shape[-1]]) 141 | 142 | ################################################################################## 143 | # Residual-block, Self-Attention-block 144 | ################################################################################## 145 | 146 | def resblock(x_init, channels, use_bias=True, is_training=True, sn=False, scope='resblock'): 147 | with tf.variable_scope(scope): 148 | with tf.variable_scope('res1'): 149 | x = conv(x_init, channels, kernel=3, stride=1, pad=1, use_bias=use_bias, sn=sn) 150 | x = batch_norm(x, is_training) 151 | x = relu(x) 152 | 153 | with tf.variable_scope('res2'): 154 | x = conv(x, channels, kernel=3, stride=1, pad=1, use_bias=use_bias, sn=sn) 155 | x = batch_norm(x, is_training) 156 | 157 | return x + x_init 158 | 159 | def resblock_up(x_init, channels, use_bias=True, is_training=True, sn=False, scope='resblock_up'): 160 | with tf.variable_scope(scope): 161 | with tf.variable_scope('res1'): 162 | x = batch_norm(x_init, is_training) 163 | x = relu(x) 164 | x = deconv(x, channels, kernel=3, stride=2, use_bias=use_bias, sn=sn) 165 | 166 | with tf.variable_scope('res2') : 167 | x = batch_norm(x, is_training) 168 | x = relu(x) 169 | x = deconv(x, channels, kernel=3, 
stride=1, use_bias=use_bias, sn=sn) 170 | 171 | with tf.variable_scope('skip') : 172 | x_init = deconv(x_init, channels, kernel=3, stride=2, use_bias=use_bias, sn=sn) 173 | 174 | 175 | return x + x_init 176 | 177 | def resblock_up_condition(x_init, z, channels, use_bias=True, is_training=True, sn=False, scope='resblock_up'): 178 | with tf.variable_scope(scope): 179 | with tf.variable_scope('res1'): 180 | x = condition_batch_norm(x_init, z, is_training) 181 | x = relu(x) 182 | x = deconv(x, channels, kernel=3, stride=2, use_bias=use_bias, sn=sn) 183 | 184 | with tf.variable_scope('res2') : 185 | x = condition_batch_norm(x, z, is_training) 186 | x = relu(x) 187 | x = deconv(x, channels, kernel=3, stride=1, use_bias=use_bias, sn=sn) 188 | 189 | with tf.variable_scope('skip') : 190 | x_init = deconv(x_init, channels, kernel=3, stride=2, use_bias=use_bias, sn=sn) 191 | 192 | 193 | return x + x_init 194 | 195 | 196 | def resblock_down(x_init, channels, use_bias=True, is_training=True, sn=False, scope='resblock_down'): 197 | with tf.variable_scope(scope): 198 | with tf.variable_scope('res1'): 199 | x = batch_norm(x_init, is_training) 200 | x = relu(x) 201 | x = conv(x, channels, kernel=3, stride=2, pad=1, use_bias=use_bias, sn=sn) 202 | 203 | with tf.variable_scope('res2') : 204 | x = batch_norm(x, is_training) 205 | x = relu(x) 206 | x = conv(x, channels, kernel=3, stride=1, pad=1, use_bias=use_bias, sn=sn) 207 | 208 | with tf.variable_scope('skip') : 209 | x_init = conv(x_init, channels, kernel=3, stride=2, pad=1, use_bias=use_bias, sn=sn) 210 | 211 | 212 | return x + x_init 213 | 214 | def self_attention(x, channels, sn=False, scope='self_attention'): 215 | with tf.variable_scope(scope): 216 | f = conv(x, channels // 8, kernel=1, stride=1, sn=sn, scope='f_conv') # [bs, h, w, c'] 217 | g = conv(x, channels // 8, kernel=1, stride=1, sn=sn, scope='g_conv') # [bs, h, w, c'] 218 | h = conv(x, channels, kernel=1, stride=1, sn=sn, scope='h_conv') # [bs, h, w, c] 219 | 220 | # 
N = h * w 221 | s = tf.matmul(hw_flatten(g), hw_flatten(f), transpose_b=True) # # [bs, N, N] 222 | 223 | beta = tf.nn.softmax(s) # attention map 224 | 225 | o = tf.matmul(beta, hw_flatten(h)) # [bs, N, C] 226 | gamma = tf.get_variable("gamma", [1], initializer=tf.constant_initializer(0.0)) 227 | 228 | o = tf.reshape(o, shape=x.shape) # [bs, h, w, C] 229 | x = gamma * o + x 230 | 231 | return x 232 | 233 | def self_attention_2(x, channels, sn=False, scope='self_attention'): 234 | with tf.variable_scope(scope): 235 | f = conv(x, channels // 8, kernel=1, stride=1, sn=sn, scope='f_conv') # [bs, h, w, c'] 236 | f = max_pooling(f) 237 | 238 | g = conv(x, channels // 8, kernel=1, stride=1, sn=sn, scope='g_conv') # [bs, h, w, c'] 239 | 240 | h = conv(x, channels // 2, kernel=1, stride=1, sn=sn, scope='h_conv') # [bs, h, w, c] 241 | h = max_pooling(h) 242 | 243 | # N = h * w 244 | s = tf.matmul(hw_flatten(g), hw_flatten(f), transpose_b=True) # # [bs, N, N] 245 | 246 | beta = tf.nn.softmax(s) # attention map 247 | 248 | o = tf.matmul(beta, hw_flatten(h)) # [bs, N, C] 249 | gamma = tf.get_variable("gamma", [1], initializer=tf.constant_initializer(0.0)) 250 | 251 | o = tf.reshape(o, shape=[x.shape[0], x.shape[1], x.shape[2], channels // 2]) # [bs, h, w, C] 252 | o = conv(o, channels, kernel=1, stride=1, sn=sn, scope='attn_conv') 253 | x = gamma * o + x 254 | 255 | return x 256 | 257 | ################################################################################## 258 | # Sampling 259 | ################################################################################## 260 | 261 | def global_avg_pooling(x): 262 | gap = tf.reduce_mean(x, axis=[1, 2]) 263 | 264 | return gap 265 | 266 | def global_sum_pooling(x) : 267 | gsp = tf.reduce_sum(x, axis=[1, 2]) 268 | 269 | return gsp 270 | 271 | def max_pooling(x) : 272 | x = tf.layers.max_pooling2d(x, pool_size=2, strides=2, padding='SAME') 273 | return x 274 | 275 | def up_sample(x, scale_factor=2): 276 | _, h, w, _ = 
x.get_shape().as_list() 277 | new_size = [h * scale_factor, w * scale_factor] 278 | return tf.image.resize_nearest_neighbor(x, size=new_size) 279 | 280 | ################################################################################## 281 | # Activation function 282 | ################################################################################## 283 | 284 | def lrelu(x, alpha=0.2): 285 | return tf.nn.leaky_relu(x, alpha) 286 | 287 | 288 | def relu(x): 289 | return tf.nn.relu(x) 290 | 291 | 292 | def tanh(x): 293 | return tf.tanh(x) 294 | 295 | ################################################################################## 296 | # Normalization function 297 | ################################################################################## 298 | 299 | def batch_norm(x, is_training=True, scope='batch_norm'): 300 | return tf.layers.batch_normalization(x, 301 | momentum=0.9, 302 | epsilon=1e-05, 303 | training=is_training, 304 | name=scope) 305 | 306 | def condition_batch_norm(x, z, is_training=True, scope='batch_norm'): 307 | with tf.variable_scope(scope) : 308 | _, _, _, c = x.get_shape().as_list() 309 | decay = 0.9 310 | epsilon = 1e-05 311 | 312 | test_mean = tf.get_variable("pop_mean", shape=[c], dtype=tf.float32, initializer=tf.constant_initializer(0.0), trainable=False) 313 | test_var = tf.get_variable("pop_var", shape=[c], dtype=tf.float32, initializer=tf.constant_initializer(1.0), trainable=False) 314 | 315 | beta = fully_conneted(z, units=c, scope='beta') 316 | gamma = fully_conneted(z, units=c, scope='gamma') 317 | 318 | beta = tf.reshape(beta, shape=[-1, 1, 1, c]) 319 | gamma = tf.reshape(gamma, shape=[-1, 1, 1, c]) 320 | 321 | if is_training: 322 | batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2]) 323 | ema_mean = tf.assign(test_mean, test_mean * decay + batch_mean * (1 - decay)) 324 | ema_var = tf.assign(test_var, test_var * decay + batch_var * (1 - decay)) 325 | 326 | with tf.control_dependencies([ema_mean, ema_var]): 327 | return 
tf.nn.batch_normalization(x, batch_mean, batch_var, beta, gamma, epsilon) 328 | else: 329 | return tf.nn.batch_normalization(x, test_mean, test_var, beta, gamma, epsilon) 330 | 331 | 332 | def spectral_norm(w, iteration=1): 333 | w_shape = w.shape.as_list() 334 | w = tf.reshape(w, [-1, w_shape[-1]]) 335 | 336 | u = tf.get_variable("u", [1, w_shape[-1]], initializer=tf.random_normal_initializer(), trainable=False) 337 | 338 | u_hat = u 339 | v_hat = None 340 | for i in range(iteration): 341 | """ 342 | power iteration 343 | Usually iteration = 1 will be enough 344 | """ 345 | 346 | v_ = tf.matmul(u_hat, tf.transpose(w)) 347 | v_hat = tf.nn.l2_normalize(v_) 348 | 349 | u_ = tf.matmul(v_hat, w) 350 | u_hat = tf.nn.l2_normalize(u_) 351 | 352 | u_hat = tf.stop_gradient(u_hat) 353 | v_hat = tf.stop_gradient(v_hat) 354 | 355 | sigma = tf.matmul(tf.matmul(v_hat, w), tf.transpose(u_hat)) 356 | 357 | with tf.control_dependencies([u.assign(u_hat)]): 358 | w_norm = w / sigma 359 | w_norm = tf.reshape(w_norm, w_shape) 360 | 361 | return w_norm 362 | 363 | ################################################################################## 364 | # Loss function 365 | ################################################################################## 366 | 367 | def discriminator_loss(loss_func, real, fake): 368 | real_loss = 0 369 | fake_loss = 0 370 | 371 | if loss_func.__contains__('wgan') : 372 | real_loss = -tf.reduce_mean(real) 373 | fake_loss = tf.reduce_mean(fake) 374 | 375 | if loss_func == 'lsgan' : 376 | real_loss = tf.reduce_mean(tf.squared_difference(real, 1.0)) 377 | fake_loss = tf.reduce_mean(tf.square(fake)) 378 | 379 | if loss_func == 'gan' or loss_func == 'dragan' : 380 | real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(real), logits=real)) 381 | fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(fake), logits=fake)) 382 | 383 | if loss_func == 'hinge' : 384 | real_loss = 
tf.reduce_mean(relu(1.0 - real)) 385 | fake_loss = tf.reduce_mean(relu(1.0 + fake)) 386 | 387 | loss = real_loss + fake_loss 388 | 389 | return loss 390 | 391 | def generator_loss(loss_func, fake): 392 | fake_loss = 0 393 | 394 | if loss_func.__contains__('wgan') : 395 | fake_loss = -tf.reduce_mean(fake) 396 | 397 | if loss_func == 'lsgan' : 398 | fake_loss = tf.reduce_mean(tf.squared_difference(fake, 1.0)) 399 | 400 | if loss_func == 'gan' or loss_func == 'dragan' : 401 | fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(fake), logits=fake)) 402 | 403 | if loss_func == 'hinge' : 404 | fake_loss = -tf.reduce_mean(fake) 405 | 406 | loss = fake_loss 407 | 408 | return loss -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import scipy.misc 2 | import numpy as np 3 | import os 4 | from glob import glob 5 | 6 | import tensorflow as tf 7 | import tensorflow.contrib.slim as slim 8 | from keras.datasets import cifar10, mnist 9 | 10 | class ImageData: 11 | 12 | def __init__(self, load_size, channels, custom_dataset): 13 | self.load_size = load_size 14 | self.channels = channels 15 | self.custom_dataset = custom_dataset 16 | 17 | def image_processing(self, filename): 18 | 19 | if not self.custom_dataset : 20 | x_decode = filename 21 | else : 22 | x = tf.read_file(filename) 23 | x_decode = tf.image.decode_jpeg(x, channels=self.channels) 24 | 25 | img = tf.image.resize_images(x_decode, [self.load_size, self.load_size]) 26 | img = tf.cast(img, tf.float32) / 127.5 - 1 27 | 28 | return img 29 | 30 | 31 | def load_mnist(): 32 | (train_data, train_labels), (test_data, test_labels) = mnist.load_data() 33 | x = np.concatenate((train_data, test_data), axis=0) 34 | x = np.expand_dims(x, axis=-1) 35 | 36 | return x 37 | 38 | def load_cifar10() : 39 | (train_data, train_labels), (test_data, test_labels) = cifar10.load_data() 
40 | x = np.concatenate((train_data, test_data), axis=0) 41 | 42 | return x 43 | 44 | def load_data(dataset_name) : 45 | if dataset_name == 'mnist' : 46 | x = load_mnist() 47 | elif dataset_name == 'cifar10' : 48 | x = load_cifar10() 49 | else : 50 | 51 | x = glob(os.path.join("./dataset", dataset_name, '*.*')) 52 | 53 | return x 54 | 55 | 56 | def preprocessing(x, size): 57 | x = scipy.misc.imread(x, mode='RGB') 58 | x = scipy.misc.imresize(x, [size, size]) 59 | x = normalize(x) 60 | return x 61 | 62 | def normalize(x) : 63 | return x/127.5 - 1 64 | 65 | def save_images(images, size, image_path): 66 | return imsave(inverse_transform(images), size, image_path) 67 | 68 | def merge(images, size): 69 | h, w = images.shape[1], images.shape[2] 70 | if (images.shape[3] in (3,4)): 71 | c = images.shape[3] 72 | img = np.zeros((h * size[0], w * size[1], c)) 73 | for idx, image in enumerate(images): 74 | i = idx % size[1] 75 | j = idx // size[1] 76 | img[j * h:j * h + h, i * w:i * w + w, :] = image 77 | return img 78 | elif images.shape[3]==1: 79 | img = np.zeros((h * size[0], w * size[1])) 80 | for idx, image in enumerate(images): 81 | i = idx % size[1] 82 | j = idx // size[1] 83 | img[j * h:j * h + h, i * w:i * w + w] = image[:,:,0] 84 | return img 85 | else: 86 | raise ValueError('in merge(images,size) images parameter ''must have dimensions: HxW or HxWx3 or HxWx4') 87 | 88 | def imsave(images, size, path): 89 | # image = np.squeeze(merge(images, size)) # 채널이 1인거 제거 ? 90 | return scipy.misc.imsave(path, merge(images, size)) 91 | 92 | 93 | def inverse_transform(images): 94 | return (images+1.)/2. 
def check_folder(log_dir):
    """Ensure the directory ``log_dir`` exists and return it unchanged.

    Missing parent directories are created as well. If ``log_dir`` already
    exists (as a directory or anything else) nothing is done.

    Args:
        log_dir: Path of the directory to create when absent.

    Returns:
        The same ``log_dir`` value, so the call can be used inline.
    """
    if not os.path.exists(log_dir):
        # exist_ok avoids a crash if another process/thread creates the
        # directory between the exists() check and makedirs().
        os.makedirs(log_dir, exist_ok=True)
    return log_dir


def show_all_variables():
    """Print a summary of every trainable variable in the default graph.

    Delegates to ``slim.model_analyzer.analyze_vars`` with ``print_info=True``.
    """
    model_vars = tf.trainable_variables()
    slim.model_analyzer.analyze_vars(model_vars, print_info=True)


def str2bool(x):
    """Convert a boolean-like string to ``bool``.

    Comparison is case-insensitive and ignores surrounding whitespace:
    ``"true"``, ``"t"``, ``"yes"``, ``"y"`` and ``"1"`` give ``True``; any
    other string gives ``False``. A value that is already a ``bool`` is
    returned unchanged. Presumably used as an ``argparse`` ``type=`` converter
    (the caller is not visible here).
    """
    if isinstance(x, bool):
        return x
    # A real tuple (note the element list), so this is an exact membership
    # test rather than a substring test against the string 'true'.
    return x.strip().lower() in ('true', 't', 'yes', 'y', '1')

##################################################################################
# Regularization
##################################################################################

def _orthogonality_penalty(w, scale):
    """Return ``scale * tf.nn.l2_loss(W^T W - I)`` for the kernel ``w``.

    ``w`` is flattened to 2-D while keeping its last axis (reshaped to
    ``[-1, c]`` with ``c = w.shape[-1]``), and ``I`` is the ``c x c`` identity.
    ``tf.nn.l2_loss`` is half the sum of squares, so the result equals
    ``scale / 2 * ||W^T W - I||_F^2``.

    NOTE(review): the full Gram matrix is penalized, diagonal included; the
    BigGAN paper's modified orthogonal regularization masks the diagonal out.
    Confirm which variant is intended.
    """
    c = w.get_shape().as_list()[-1]
    w = tf.reshape(w, [-1, c])
    # Build the identity in the kernel's dtype (taken after reshape, so it is
    # a plain tensor dtype) so the subtraction never mixes dtypes.
    identity = tf.eye(c, dtype=w.dtype)
    w_mul = tf.matmul(tf.transpose(w), w)  # W^T W, shape [c, c]
    reg = tf.subtract(w_mul, identity)
    return scale * tf.nn.l2_loss(reg)


def orthogonal_regularizer(scale):
    """Build an orthogonal-regularization function for convolution kernels.

    Args:
        scale: Multiplier applied to the penalty.

    Returns:
        A callable ``ortho_reg(w)`` usable as a ``kernel_regularizer``. It
        reshapes ``w`` to ``[-1, c]``, where ``c`` is the last dimension (the
        output channels for a TF ``[kh, kw, in, out]`` conv kernel), and
        returns ``scale * tf.nn.l2_loss(W^T W - I)``.
    """
    def ortho_reg(w):
        """Orthogonality penalty for conv kernel ``w``."""
        return _orthogonality_penalty(w, scale)

    return ortho_reg


def orthogonal_regularizer_fully(scale):
    """Build an orthogonal-regularization function for fully connected kernels.

    Args:
        scale: Multiplier applied to the penalty.

    Returns:
        A callable ``ortho_reg_fully(w)`` usable as a ``kernel_regularizer``
        for a dense kernel (presumably ``[in_units, out_units]``). It returns
        ``scale * tf.nn.l2_loss(W^T W - I)`` with ``I`` sized to the last
        dimension of ``w``.
    """
    def ortho_reg_fully(w):
        """Orthogonality penalty for dense kernel ``w``."""
        return _orthogonality_penalty(w, scale)

    return ortho_reg_fully