├── .DS_Store ├── .gitignore ├── DatasetAPI │   ├── UNIT.py │   ├── UNIT_multi_gpu.py │   ├── main.py │   ├── main_multi_gpu.py │   ├── ops.py │   └── utils.py ├── LICENSE ├── README.md ├── UNIT.py ├── UNIT_multi_gpu.py ├── assests │   ├── .DS_Store │   ├── architecture.png │   ├── cat_species.gif │   ├── cat_trans.png │   ├── compare.png │   ├── cycle.png │   ├── dog_breed.gif │   ├── dog_trans.png │   ├── faces.png │   ├── fail.png │   ├── framework.png │   ├── gan_model.png │   ├── slide │   │   ├── compare.png │   │   ├── cycle.png │   │   ├── framework.png │   │   ├── gan_model.png │   │   ├── training_objective.png │   │   └── vae_model.png │   ├── success.png │   ├── training_objective__.png │   └── vae_model.png ├── main.py ├── main_multi_gpu.py ├── ops.py └── utils.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/UNIT-Tensorflow/4d7430a6f0bd3bea72d821e14db6e6442c02ed32/.DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ -------------------------------------------------------------------------------- /DatasetAPI/UNIT.py: -------------------------------------------------------------------------------- 1 | from ops import * 2 | from utils import * 3 | from glob import glob 4 | import time 5 | from tensorflow.contrib.data import batch_and_drop_remainder 6 | 7 | class UNIT(object) : 8 | def __init__(self, sess, args): 9 | self.model_name = 'UNIT' 10 | self.sess = sess 11 | self.checkpoint_dir = args.checkpoint_dir 12 | self.result_dir = args.result_dir 13 | self.log_dir = args.log_dir 14 | self.sample_dir = args.sample_dir 15 | self.dataset_name = args.dataset 16 | self.augment_flag = args.augment_flag 17 | 18 | self.epoch = args.epoch 19 | self.iteration = args.iteration 20 | self.gan_type = args.gan_type 21 | 22 | self.batch_size = args.batch_size 23 | self.print_freq =
args.print_freq 24 | self.save_freq = args.save_freq 25 | 26 | self.img_size = args.img_size 27 | self.img_ch = args.img_ch 28 | 29 | self.init_lr = args.lr 30 | self.ch = args.ch 31 | 32 | """ Weight about VAE """ 33 | self.KL_weight = args.KL_weight # lambda 1 34 | self.L1_weight = args.L1_weight # lambda 2 35 | 36 | """ Weight about VAE Cycle""" 37 | self.KL_cycle_weight = args.KL_cycle_weight # lambda 3 38 | self.L1_cycle_weight = args.L1_cycle_weight # lambda 4 39 | 40 | """ Weight about GAN """ 41 | self.GAN_weight = args.GAN_weight # lambda 0 42 | 43 | """ Encoder """ 44 | self.n_encoder = args.n_encoder 45 | self.n_enc_resblock = args.n_enc_resblock 46 | self.n_enc_share = args.n_enc_share 47 | 48 | """ Generator """ 49 | self.n_gen_share = args.n_gen_share 50 | self.n_gen_resblock = args.n_gen_resblock 51 | self.n_gen_decoder = args.n_gen_decoder 52 | 53 | """ Discriminator """ 54 | self.n_dis = args.n_dis 55 | 56 | self.sample_dir = os.path.join(args.sample_dir, self.model_dir) 57 | check_folder(self.sample_dir) 58 | 59 | self.trainA_dataset = glob('./dataset/{}/*.*'.format(self.dataset_name + '/trainA')) 60 | self.trainB_dataset = glob('./dataset/{}/*.*'.format(self.dataset_name + '/trainB')) 61 | self.dataset_num = max(len(self.trainA_dataset), len(self.trainB_dataset)) 62 | 63 | print("##### Information #####") 64 | print("# gan type : ", self.gan_type) 65 | print("# dataset : ", self.dataset_name) 66 | print("# max dataset number : ", self.dataset_num) 67 | print("# batch_size : ", self.batch_size) 68 | print("# epoch : ", self.epoch) 69 | print("# iteration per epoch : ", self.iteration) 70 | 71 | print() 72 | 73 | print("##### Encoder #####") 74 | print("# encoder blocks : ", self.n_encoder) 75 | print("# encoder resblock : ", self.n_enc_resblock) 76 | print("# encoder share : ", self.n_enc_share) 77 | 78 | print() 79 | 80 | print("##### Decoder #####") 81 | print("# decoder share : ", self.n_gen_share) 82 | print("# decoder resblock : ", self.n_gen_resblock) 83 | print("# decoder blocks : ", self.n_gen_decoder) 84 | 85 | print() 86 | 87 | print("##### Discriminator #####") 88 | print("# Discriminator layer : ", self.n_dis) 89 | 90 | ############################################################################## 91 | # BEGIN of ENCODERS 92 | 93 | def encoder(self, x, reuse=False, scope="encoder"): 94 | channel = self.ch 95 | with tf.variable_scope(scope, reuse=reuse): 96 | x = conv(x, channel, kernel=7, stride=1, pad=3, scope='conv_0') 97 | x = lrelu(x, 0.01) 98 | 99 | for i in range(1, self.n_encoder): 100 | x = conv(x, channel * 2, kernel=3, stride=2, pad=1, scope='conv_' + str(i)) 101 | x = lrelu(x, 0.01) 102 | channel *= 2 103 | 104 | # channel = 256 105 | for i in range(0, self.n_enc_resblock): 106 | x = resblock(x, channel, scope='resblock_'+str(i)) 107 | 108 | return x 109 | # END of ENCODERS 110 | ############################################################################## 111 | 112 | ############################################################################## 113 | # BEGIN of SHARED LAYERS 114 | # Shared residual-blocks 115 | 116 | def share_encoder(self, x, reuse=False, scope="share_encoder"): 117 | channel = self.ch * pow(2, self.n_encoder - 1) 118 | with tf.variable_scope(scope, reuse=reuse): 119 | for i in range(0, self.n_enc_share): 120 | x = resblock(x, channel, scope='resblock_' + str(i)) 121 | 122 | x = gaussian_noise_layer(x) 123 | 124 | return x 125 | 126 | def share_generator(self, x, reuse=False, scope="share_generator"): 127 | channel = self.ch * 
pow(2, self.n_encoder - 1) 128 | with tf.variable_scope(scope, reuse=reuse): 129 | for i in range(0, self.n_gen_share): 130 | x = resblock(x, channel, scope='resblock_' + str(i)) 131 | 132 | return x 133 | # END of SHARED LAYERS 134 | ############################################################################## 135 | 136 | ############################################################################## 137 | # BEGIN of DECODERS 138 | 139 | def generator(self, x, reuse=False, scope="generator"): 140 | channel = self.ch * pow(2, self.n_encoder - 1) 141 | with tf.variable_scope(scope, reuse=reuse): 142 | for i in range(0, self.n_gen_resblock): 143 | x = resblock(x, channel, scope='resblock_' + str(i)) 144 | 145 | for i in range(0, self.n_gen_decoder - 1): 146 | x = deconv(x, channel // 2, kernel=3, stride=2, scope='deconv_' + str(i)) 147 | x = lrelu(x, 0.01) 148 | channel = channel // 2 149 | 150 | x = deconv(x, channels=3, kernel=1, stride=1, scope='G_logit') 151 | x = tanh(x) 152 | 153 | return x 154 | # END of DECODERS 155 | ############################################################################## 156 | 157 | ############################################################################## 158 | # BEGIN of DISCRIMINATORS 159 | 160 | def discriminator(self, x, reuse=False, scope="discriminator"): 161 | channel = self.ch 162 | with tf.variable_scope(scope, reuse=reuse): 163 | x = conv(x, channel, kernel=3, stride=2, pad=1, scope='conv_0') 164 | x = lrelu(x, 0.01) 165 | 166 | for i in range(1, self.n_dis): 167 | x = conv(x, channel * 2, kernel=3, stride=2, pad=1, scope='conv_' + str(i)) 168 | x = lrelu(x, 0.01) 169 | channel *= 2 170 | 171 | x = conv(x, channels=1, kernel=1, stride=1, scope='D_logit') 172 | 173 | return x 174 | 175 | # END of DISCRIMINATORS 176 | ############################################################################## 177 | 178 | def translation(self, x_A, x_B): 179 | out = tf.concat([self.encoder(x_A, scope="encoder_A"), self.encoder(x_B, scope="encoder_B")], axis=0) 180 | shared = self.share_encoder(out) 181 | out = self.share_generator(shared) 182 | 183 | out_A = self.generator(out, scope="generator_A") 184 | out_B = self.generator(out, scope="generator_B") 185 | 186 | x_Aa, x_Ba = tf.split(out_A, 2, axis=0) 187 | x_Ab, x_Bb = tf.split(out_B, 2, axis=0) 188 | 189 | return x_Aa, x_Ba, x_Ab, x_Bb, shared 190 | 191 | def generate_a2b(self, x_A): 192 | out = self.encoder(x_A, reuse=True, scope="encoder_A") 193 | shared = self.share_encoder(out, reuse=True) 194 | out = self.share_generator(shared, reuse=True) 195 | out = self.generator(out, reuse=True, scope="generator_B") 196 | 197 | return out, shared 198 | 199 | def generate_b2a(self, x_B): 200 | out = self.encoder(x_B, reuse=True, scope="encoder_B") 201 | shared = self.share_encoder(out, reuse=True) 202 | out = self.share_generator(shared, reuse=True) 203 | out = self.generator(out, reuse=True, scope="generator_A") 204 | 205 | return out, shared 206 | 207 | def discriminate_real(self, x_A, x_B): 208 | real_A_logit = self.discriminator(x_A, scope="discriminator_A") 209 | real_B_logit = self.discriminator(x_B, scope="discriminator_B") 210 | 211 | return real_A_logit, real_B_logit 212 | 213 | def discriminate_fake(self, x_ba, x_ab): 214 | fake_A_logit = self.discriminator(x_ba, reuse=True, scope="discriminator_A") 215 | fake_B_logit = self.discriminator(x_ab, reuse=True, scope="discriminator_B") 216 | 217 | return fake_A_logit, fake_B_logit 218 | 219 | def build_model(self): 220 | self.lr = tf.placeholder(tf.float32, 
name='learning_rate') 221 | 222 | """ Input Image""" 223 | Image_Data_Class = ImageData(self.img_size, self.img_ch, self.augment_flag) 224 | 225 | trainA = tf.data.Dataset.from_tensor_slices(self.trainA_dataset) 226 | trainB = tf.data.Dataset.from_tensor_slices(self.trainB_dataset) 227 | 228 | trainA = trainA.prefetch(self.batch_size).shuffle(self.dataset_num).map(Image_Data_Class.image_processing, num_parallel_calls=8).apply(batch_and_drop_remainder(self.batch_size)).repeat() 229 | trainB = trainB.prefetch(self.batch_size).shuffle(self.dataset_num).map(Image_Data_Class.image_processing, num_parallel_calls=8).apply(batch_and_drop_remainder(self.batch_size)).repeat() 230 | 231 | trainA_iterator = trainA.make_one_shot_iterator() 232 | trainB_iterator = trainB.make_one_shot_iterator() 233 | 234 | 235 | self.domain_A = trainA_iterator.get_next() 236 | self.domain_B = trainB_iterator.get_next() 237 | 238 | 239 | """ Define Encoder, Generator, Discriminator """ 240 | x_aa, x_ba, x_ab, x_bb, shared = self.translation(self.domain_A, self.domain_B) 241 | x_bab, shared_bab = self.generate_a2b(x_ba) 242 | x_aba, shared_aba = self.generate_b2a(x_ab) 243 | 244 | real_A_logit, real_B_logit = self.discriminate_real(self.domain_A, self.domain_B) 245 | 246 | 247 | fake_A_logit, fake_B_logit = self.discriminate_fake(x_ba, x_ab) 248 | 249 | """ Define Loss """ 250 | G_ad_loss_a = generator_loss(self.gan_type, fake_A_logit) 251 | G_ad_loss_b = generator_loss(self.gan_type, fake_B_logit) 252 | 253 | D_ad_loss_a = discriminator_loss(self.gan_type, real_A_logit, fake_A_logit) 254 | D_ad_loss_b = discriminator_loss(self.gan_type, real_B_logit, fake_B_logit) 255 | 256 | enc_loss = KL_divergence(shared) 257 | enc_bab_loss = KL_divergence(shared_bab) 258 | enc_aba_loss = KL_divergence(shared_aba) 259 | 260 | l1_loss_a = L1_loss(x_aa, self.domain_A) # identity 261 | l1_loss_b = L1_loss(x_bb, self.domain_B) # identity 262 | l1_loss_aba = L1_loss(x_aba, self.domain_A) # reconstruction 263 | l1_loss_bab = L1_loss(x_bab, self.domain_B) # reconstruction 264 | 265 | Generator_A_loss = self.GAN_weight * G_ad_loss_a + \ 266 | self.L1_weight * l1_loss_a + \ 267 | self.L1_cycle_weight * l1_loss_aba + \ 268 | self.KL_weight * enc_loss + \ 269 | self.KL_cycle_weight * enc_bab_loss 270 | 271 | Generator_B_loss = self.GAN_weight * G_ad_loss_b + \ 272 | self.L1_weight * l1_loss_b + \ 273 | self.L1_cycle_weight * l1_loss_bab + \ 274 | self.KL_weight * enc_loss + \ 275 | self.KL_cycle_weight * enc_aba_loss 276 | 277 | Discriminator_A_loss = self.GAN_weight * D_ad_loss_a 278 | Discriminator_B_loss = self.GAN_weight * D_ad_loss_b 279 | 280 | self.Generator_loss = Generator_A_loss + Generator_B_loss 281 | self.Discriminator_loss = Discriminator_A_loss + Discriminator_B_loss 282 | 283 | """ Training """ 284 | t_vars = tf.trainable_variables() 285 | G_vars = [var for var in t_vars if 'generator' in var.name or 'encoder' in var.name] 286 | D_vars = [var for var in t_vars if 'discriminator' in var.name] 287 | 288 | self.G_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.Generator_loss, var_list=G_vars) 289 | self.D_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.Discriminator_loss, var_list=D_vars) 290 | 291 | """" Summary """ 292 | self.all_G_loss = tf.summary.scalar("Generator_loss", self.Generator_loss) 293 | self.all_D_loss = tf.summary.scalar("Discriminator_loss", self.Discriminator_loss) 294 | self.G_A_loss = tf.summary.scalar("G_A_loss", Generator_A_loss) 295 | self.G_B_loss 
= tf.summary.scalar("G_B_loss", Generator_B_loss) 296 | self.D_A_loss = tf.summary.scalar("D_A_loss", Discriminator_A_loss) 297 | self.D_B_loss = tf.summary.scalar("D_B_loss", Discriminator_B_loss) 298 | 299 | self.G_loss = tf.summary.merge([self.G_A_loss, self.G_B_loss, self.all_G_loss]) 300 | self.D_loss = tf.summary.merge([self.D_A_loss, self.D_B_loss, self.all_D_loss]) 301 | 302 | """ Image """ 303 | self.fake_A = x_ba 304 | self.fake_B = x_ab 305 | 306 | self.real_A = self.domain_A 307 | self.real_B = self.domain_B 308 | 309 | """ Test """ 310 | self.test_image = tf.placeholder(tf.float32, [1, self.img_size, self.img_size, self.img_ch], name='test_image') 311 | 312 | self.test_fake_A, _ = self.generate_b2a(self.test_image) 313 | self.test_fake_B, _ = self.generate_a2b(self.test_image) 314 | 315 | def train(self): 316 | # initialize all variables 317 | tf.global_variables_initializer().run() 318 | 319 | # saver to save model 320 | self.saver = tf.train.Saver() 321 | 322 | # summary writer 323 | self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_dir, self.sess.graph) 324 | 325 | # restore checkpoint if it exists 326 | could_load, checkpoint_counter = self.load(self.checkpoint_dir) 327 | if could_load: 328 | start_epoch = checkpoint_counter // self.iteration 329 | start_batch_id = checkpoint_counter - start_epoch * self.iteration 330 | counter = checkpoint_counter 331 | print(" [*] Load SUCCESS") 332 | else: 333 | start_epoch = 0 334 | start_batch_id = 0 335 | counter = 1 336 | print(" [!] Load failed...") 337 | 338 | # loop for epoch 339 | start_time = time.time() 340 | lr = self.init_lr 341 | for epoch in range(start_epoch, self.epoch): 342 | for idx in range(start_batch_id, self.iteration): 343 | train_feed_dict = { 344 | self.lr : lr 345 | } 346 | 347 | # Update D 348 | _, d_loss, summary_str = self.sess.run([self.D_optim, self.Discriminator_loss, self.D_loss], feed_dict = train_feed_dict) 349 | self.writer.add_summary(summary_str, counter) 350 | 351 | # Update G 352 | batch_A_images, batch_B_images, fake_A, fake_B, _, g_loss, summary_str = self.sess.run([self.real_A, self.real_B, self.fake_A, self.fake_B, self.G_optim, self.Generator_loss, self.G_loss], feed_dict = train_feed_dict) 353 | self.writer.add_summary(summary_str, counter) 354 | 355 | # display training status 356 | counter += 1 357 | print("Epoch: [%2d] [%6d/%6d] time: %4.4f d_loss: %.8f, g_loss: %.8f" \ 358 | % (epoch, idx, self.iteration, time.time() - start_time, d_loss, g_loss)) 359 | 360 | if np.mod(idx+1, self.print_freq) == 0 : 361 | save_images(batch_A_images, [self.batch_size, 1], 362 | './{}/real_A_{:02d}_{:06d}.jpg'.format(self.sample_dir, epoch, idx+1)) 363 | # save_images(batch_B_images, [self.batch_size, 1], 364 | # './{}/real_B_{}_{:02d}_{:06d}.jpg'.format(self.sample_dir, gpu_id, epoch, idx+1)) 365 | 366 | # save_images(fake_A, [self.batch_size, 1], 367 | # './{}/fake_A_{}_{:02d}_{:06d}.jpg'.format(self.sample_dir, gpu_id, epoch, idx+1)) 368 | save_images(fake_B, [self.batch_size, 1], 369 | './{}/fake_B_{:02d}_{:06d}.jpg'.format(self.sample_dir, epoch, idx+1)) 370 | 371 | if np.mod(idx+1, self.save_freq) == 0 : 372 | self.save(self.checkpoint_dir, counter) 373 | 374 | # After an epoch, start_batch_id is set to zero 375 | # a non-zero value is used only for the first epoch after loading a pre-trained model 376 | start_batch_id = 0 377 | 378 | # save model for final step 379 | self.save(self.checkpoint_dir, counter) 380 | 381 | @property 382 | def model_dir(self): 383 | return
"{}_{}_{}".format(self.model_name, self.dataset_name, self.gan_type) 384 | 385 | def save(self, checkpoint_dir, step): 386 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir) 387 | 388 | if not os.path.exists(checkpoint_dir): 389 | os.makedirs(checkpoint_dir) 390 | 391 | self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name + '.model'), global_step=step) 392 | 393 | def load(self, checkpoint_dir): 394 | import re 395 | print(" [*] Reading checkpoints...") 396 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir) 397 | 398 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 399 | if ckpt and ckpt.model_checkpoint_path: 400 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path) 401 | self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name)) 402 | counter = int(next(re.finditer("(\d+)(?!.*\d)", ckpt_name)).group(0)) 403 | print(" [*] Success to read {}".format(ckpt_name)) 404 | return True, counter 405 | else: 406 | print(" [*] Failed to find a checkpoint") 407 | return False, 0 408 | 409 | def test(self): 410 | tf.global_variables_initializer().run() 411 | test_A_files = glob('./dataset/{}/*.*'.format(self.dataset_name + '/testA')) 412 | test_B_files = glob('./dataset/{}/*.*'.format(self.dataset_name + '/testB')) 413 | 414 | self.saver = tf.train.Saver() 415 | could_load, checkpoint_counter = self.load(self.checkpoint_dir) 416 | self.result_dir = os.path.join(self.result_dir, self.model_dir) 417 | check_folder(self.result_dir) 418 | 419 | if could_load : 420 | print(" [*] Load SUCCESS") 421 | else : 422 | print(" [!] Load failed...") 423 | 424 | # write html for visual comparison 425 | index_path = os.path.join(self.result_dir, 'index.html') 426 | index = open(index_path, 'w') 427 | index.write("") 428 | index.write("") 429 | 430 | for sample_file in test_A_files : # A -> B 431 | print('Processing A image: ' + sample_file) 432 | sample_image = np.asarray(load_test_data(sample_file, size=self.img_size)) 433 | image_path = os.path.join(self.result_dir, '{0}'.format(os.path.basename(sample_file))) 434 | 435 | fake_img = self.sess.run(self.test_fake_B, feed_dict={self.test_image: sample_image}) 436 | save_images(fake_img, [1, 1], image_path) 437 | 438 | index.write("" % os.path.basename(image_path)) 439 | index.write("" % (sample_file if os.path.isabs(sample_file) else ( 440 | '../..' + os.path.sep + sample_file), self.img_size, self.img_size)) 441 | index.write("" % (image_path if os.path.isabs(image_path) else ( 442 | '../..' + os.path.sep + image_path), self.img_size, self.img_size)) 443 | index.write("") 444 | 445 | for sample_file in test_B_files : # B -> A 446 | print('Processing B image: ' + sample_file) 447 | sample_image = np.asarray(load_test_data(sample_file, size=self.img_size)) 448 | image_path = os.path.join(self.result_dir, '{0}'.format(os.path.basename(sample_file))) 449 | 450 | fake_img = self.sess.run(self.test_fake_A, feed_dict={self.test_image: sample_image}) 451 | save_images(fake_img, [1, 1], image_path) 452 | 453 | index.write("" % os.path.basename(image_path)) 454 | index.write("" % (sample_file if os.path.isabs(sample_file) else ( 455 | '../..' + os.path.sep + sample_file), self.img_size, self.img_size)) 456 | index.write("" % (image_path if os.path.isabs(image_path) else ( 457 | '../..' 
+ os.path.sep + image_path), self.img_size, self.img_size)) 458 | index.write("") 459 | 460 | index.close() 461 | -------------------------------------------------------------------------------- /DatasetAPI/UNIT_multi_gpu.py: -------------------------------------------------------------------------------- 1 | from ops import * 2 | from utils import * 3 | from glob import glob 4 | import time 5 | from tensorflow.contrib.data import batch_and_drop_remainder 6 | 7 | class UNIT(object) : 8 | def __init__(self, sess, args): 9 | self.model_name = 'UNIT' 10 | self.sess = sess 11 | self.checkpoint_dir = args.checkpoint_dir 12 | self.result_dir = args.result_dir 13 | self.log_dir = args.log_dir 14 | self.sample_dir = args.sample_dir 15 | self.dataset_name = args.dataset 16 | self.augment_flag = args.augment_flag 17 | 18 | self.epoch = args.epoch 19 | self.iteration = args.iteration 20 | self.gan_type = args.gan_type 21 | 22 | self.batch_size_per_gpu = args.batch_size 23 | self.batch_size = args.batch_size * args.gpu_num 24 | self.gpu_num = args.gpu_num 25 | self.print_freq = args.print_freq 26 | self.save_freq = args.save_freq 27 | 28 | self.img_size = args.img_size 29 | self.img_ch = args.img_ch 30 | 31 | self.init_lr = args.lr 32 | self.ch = args.ch 33 | 34 | """ Weight about VAE """ 35 | self.KL_weight = args.KL_weight # lambda 1 36 | self.L1_weight = args.L1_weight # lambda 2 37 | 38 | """ Weight about VAE Cycle""" 39 | self.KL_cycle_weight = args.KL_cycle_weight # lambda 3 40 | self.L1_cycle_weight = args.L1_cycle_weight # lambda 4 41 | 42 | """ Weight about GAN """ 43 | self.GAN_weight = args.GAN_weight # lambda 0 44 | 45 | """ Encoder """ 46 | self.n_encoder = args.n_encoder 47 | self.n_enc_resblock = args.n_enc_resblock 48 | self.n_enc_share = args.n_enc_share 49 | 50 | """ Generator """ 51 | self.n_gen_share = args.n_gen_share 52 | self.n_gen_resblock = args.n_gen_resblock 53 | self.n_gen_decoder = args.n_gen_decoder 54 | 55 | """ Discriminator """ 56 | self.n_dis = args.n_dis 57 | 58 | self.sample_dir = os.path.join(args.sample_dir, self.model_dir) 59 | check_folder(self.sample_dir) 60 | 61 | self.trainA_dataset = glob('./dataset/{}/*.*'.format(self.dataset_name + '/trainA')) 62 | self.trainB_dataset = glob('./dataset/{}/*.*'.format(self.dataset_name + '/trainB')) 63 | self.dataset_num = max(len(self.trainA_dataset), len(self.trainB_dataset)) 64 | 65 | print("##### Information #####") 66 | print("# gan type : ", self.gan_type) 67 | print("# dataset : ", self.dataset_name) 68 | print("# max dataset number : ", self.dataset_num) 69 | print("# batch_size : ", self.batch_size) 70 | print("# epoch : ", self.epoch) 71 | print("# iteration per epoch : ", self.iteration) 72 | 73 | print() 74 | 75 | print("##### Encoder #####") 76 | print("# encoder blocks : ", self.n_encoder) 77 | print("# encoder resblock : ", self.n_enc_resblock) 78 | print("# encoder share : ", self.n_enc_share) 79 | 80 | print() 81 | 82 | print("##### Decoder #####") 83 | print("# decoder share : ", self.n_gen_share) 84 | print("# decoder resblock : ", self.n_gen_resblock) 85 | print("# decoder blocks : ", self.n_gen_decoder) 86 | 87 | print() 88 | 89 | print("##### Discriminator #####") 90 | print("# Discriminator layer : ", self.n_dis) 91 | 92 | ############################################################################## 93 | # BEGIN of ENCODERS 94 | 95 | def encoder(self, x, reuse=False, scope="encoder"): 96 | channel = self.ch 97 | with tf.variable_scope(scope, reuse=reuse): 98 | x = conv(x, channel, kernel=7, stride=1, 
pad=3, scope='conv_0') 99 | x = lrelu(x, 0.01) 100 | 101 | for i in range(1, self.n_encoder): 102 | x = conv(x, channel * 2, kernel=3, stride=2, pad=1, scope='conv_' + str(i)) 103 | x = lrelu(x, 0.01) 104 | channel *= 2 105 | 106 | # channel = 256 107 | for i in range(0, self.n_enc_resblock): 108 | x = resblock(x, channel, scope='resblock_'+str(i)) 109 | 110 | return x 111 | # END of ENCODERS 112 | ############################################################################## 113 | 114 | ############################################################################## 115 | # BEGIN of SHARED LAYERS 116 | # Shared residual-blocks 117 | 118 | def share_encoder(self, x, reuse=False, scope="share_encoder"): 119 | channel = self.ch * pow(2, self.n_encoder - 1) 120 | with tf.variable_scope(scope, reuse=reuse): 121 | for i in range(0, self.n_enc_share): 122 | x = resblock(x, channel, scope='resblock_' + str(i)) 123 | 124 | x = gaussian_noise_layer(x) 125 | 126 | return x 127 | 128 | def share_generator(self, x, reuse=False, scope="share_generator"): 129 | channel = self.ch * pow(2, self.n_encoder - 1) 130 | with tf.variable_scope(scope, reuse=reuse): 131 | for i in range(0, self.n_gen_share): 132 | x = resblock(x, channel, scope='resblock_' + str(i)) 133 | 134 | return x 135 | # END of SHARED LAYERS 136 | ############################################################################## 137 | 138 | ############################################################################## 139 | # BEGIN of DECODERS 140 | 141 | def generator(self, x, reuse=False, scope="generator"): 142 | channel = self.ch * pow(2, self.n_encoder - 1) 143 | with tf.variable_scope(scope, reuse=reuse): 144 | for i in range(0, self.n_gen_resblock): 145 | x = resblock(x, channel, scope='resblock_' + str(i)) 146 | 147 | for i in range(0, self.n_gen_decoder - 1): 148 | x = deconv(x, channel // 2, kernel=3, stride=2, scope='deconv_' + str(i)) 149 | x = lrelu(x, 0.01) 150 | channel = channel // 2 151 | 152 | x = deconv(x, channels=3, kernel=1, stride=1, scope='G_logit') 153 | x = tanh(x) 154 | 155 | return x 156 | # END of DECODERS 157 | ############################################################################## 158 | 159 | ############################################################################## 160 | # BEGIN of DISCRIMINATORS 161 | 162 | def discriminator(self, x, reuse=False, scope="discriminator"): 163 | channel = self.ch 164 | with tf.variable_scope(scope, reuse=reuse): 165 | x = conv(x, channel, kernel=3, stride=2, pad=1, scope='conv_0') 166 | x = lrelu(x, 0.01) 167 | 168 | for i in range(1, self.n_dis): 169 | x = conv(x, channel * 2, kernel=3, stride=2, pad=1, scope='conv_' + str(i)) 170 | x = lrelu(x, 0.01) 171 | channel *= 2 172 | 173 | x = conv(x, channels=1, kernel=1, stride=1, scope='D_logit') 174 | 175 | return x 176 | 177 | # END of DISCRIMINATORS 178 | ############################################################################## 179 | 180 | def translation(self, x_A, x_B): 181 | out = tf.concat([self.encoder(x_A, scope="encoder_A"), self.encoder(x_B, scope="encoder_B")], axis=0) 182 | shared = self.share_encoder(out) 183 | out = self.share_generator(shared) 184 | 185 | out_A = self.generator(out, scope="generator_A") 186 | out_B = self.generator(out, scope="generator_B") 187 | 188 | x_Aa, x_Ba = tf.split(out_A, 2, axis=0) 189 | x_Ab, x_Bb = tf.split(out_B, 2, axis=0) 190 | 191 | return x_Aa, x_Ba, x_Ab, x_Bb, shared 192 | 193 | def generate_a2b(self, x_A): 194 | out = self.encoder(x_A, reuse=True, scope="encoder_A") 
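# (a2b under UNIT's shared-latent assumption: the shared VAE layers below add unit Gaussian noise to form the common latent code, which generator_B then decodes into domain B)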
195 | shared = self.share_encoder(out, reuse=True) 196 | out = self.share_generator(shared, reuse=True) 197 | out = self.generator(out, reuse=True, scope="generator_B") 198 | 199 | return out, shared 200 | 201 | def generate_b2a(self, x_B): 202 | out = self.encoder(x_B, reuse=True, scope="encoder_B") 203 | shared = self.share_encoder(out, reuse=True) 204 | out = self.share_generator(shared, reuse=True) 205 | out = self.generator(out, reuse=True, scope="generator_A") 206 | 207 | return out, shared 208 | 209 | def discriminate_real(self, x_A, x_B): 210 | real_A_logit = self.discriminator(x_A, scope="discriminator_A") 211 | real_B_logit = self.discriminator(x_B, scope="discriminator_B") 212 | 213 | return real_A_logit, real_B_logit 214 | 215 | def discriminate_fake(self, x_ba, x_ab): 216 | fake_A_logit = self.discriminator(x_ba, reuse=True, scope="discriminator_A") 217 | fake_B_logit = self.discriminator(x_ab, reuse=True, scope="discriminator_B") 218 | 219 | return fake_A_logit, fake_B_logit 220 | 221 | def build_model(self): 222 | self.lr = tf.placeholder(tf.float32, name='learning_rate') 223 | 224 | """ Input Image""" 225 | Image_Data_Class = ImageData(self.img_size, self.img_ch, self.augment_flag) 226 | 227 | trainA = tf.data.Dataset.from_tensor_slices(self.trainA_dataset) 228 | trainB = tf.data.Dataset.from_tensor_slices(self.trainB_dataset) 229 | 230 | trainA = trainA.prefetch(self.batch_size).shuffle(self.dataset_num).map(Image_Data_Class.image_processing, num_parallel_calls=8).apply(batch_and_drop_remainder(self.batch_size)).repeat() 231 | trainB = trainB.prefetch(self.batch_size).shuffle(self.dataset_num).map(Image_Data_Class.image_processing, num_parallel_calls=8).apply(batch_and_drop_remainder(self.batch_size)).repeat() 232 | 233 | trainA_iterator = trainA.make_one_shot_iterator() 234 | trainB_iterator = trainB.make_one_shot_iterator() 235 | 236 | self.domain_A = trainA_iterator.get_next() 237 | self.domain_B = trainB_iterator.get_next() 238 | 239 | domain_A = tf.split(self.domain_A, self.gpu_num) 240 | domain_B = tf.split(self.domain_B, self.gpu_num) 241 | 242 | G_A_losses = [] 243 | G_B_losses = [] 244 | D_A_losses = [] 245 | D_B_losses = [] 246 | 247 | G_losses = [] 248 | D_losses = [] 249 | 250 | self.fake_A = [] 251 | self.fake_B = [] 252 | 253 | self.real_A = [] 254 | self.real_B = [] 255 | 256 | for gpu_id in range(self.gpu_num): 257 | with tf.device(tf.DeviceSpec(device_type="GPU", device_index=gpu_id)): 258 | with tf.variable_scope(tf.get_variable_scope(), reuse=(gpu_id > 0)): 259 | """ Define Encoder, Generator, Discriminator """ 260 | x_aa, x_ba, x_ab, x_bb, shared = self.translation(domain_A[gpu_id], domain_B[gpu_id]) 261 | x_bab, shared_bab = self.generate_a2b(x_ba) 262 | x_aba, shared_aba = self.generate_b2a(x_ab) 263 | 264 | real_A_logit, real_B_logit = self.discriminate_real(domain_A[gpu_id], domain_B[gpu_id]) 265 | 266 | 267 | fake_A_logit, fake_B_logit = self.discriminate_fake(x_ba, x_ab) 268 | 269 | """ Define Loss """ 270 | G_ad_loss_a = generator_loss(self.gan_type, fake_A_logit) 271 | G_ad_loss_b = generator_loss(self.gan_type, fake_B_logit) 272 | 273 | D_ad_loss_a = discriminator_loss(self.gan_type, real_A_logit, fake_A_logit) 274 | D_ad_loss_b = discriminator_loss(self.gan_type, real_B_logit, fake_B_logit) 275 | 276 | enc_loss = KL_divergence(shared) 277 | enc_bab_loss = KL_divergence(shared_bab) 278 | enc_aba_loss = KL_divergence(shared_aba) 279 | 280 | l1_loss_a = L1_loss(x_aa, domain_A[gpu_id]) # identity 281 | l1_loss_b = L1_loss(x_bb, domain_B[gpu_id]) # 
identity 282 | l1_loss_aba = L1_loss(x_aba, domain_A[gpu_id]) # reconstruction 283 | l1_loss_bab = L1_loss(x_bab, domain_B[gpu_id]) # reconstruction 284 | 285 | Generator_A_loss_split = self.GAN_weight * G_ad_loss_a + \ 286 | self.L1_weight * l1_loss_a + \ 287 | self.L1_cycle_weight * l1_loss_aba + \ 288 | self.KL_weight * enc_loss + \ 289 | self.KL_cycle_weight * enc_bab_loss 290 | 291 | Generator_B_loss_split = self.GAN_weight * G_ad_loss_b + \ 292 | self.L1_weight * l1_loss_b + \ 293 | self.L1_cycle_weight * l1_loss_bab + \ 294 | self.KL_weight * enc_loss + \ 295 | self.KL_cycle_weight * enc_aba_loss 296 | 297 | Discriminator_A_loss_split = self.GAN_weight * D_ad_loss_a 298 | Discriminator_B_loss_split = self.GAN_weight * D_ad_loss_b 299 | 300 | Generator_loss_split = Generator_A_loss_split + Generator_B_loss_split 301 | Discriminator_loss_split = Discriminator_A_loss_split + Discriminator_B_loss_split 302 | 303 | G_A_losses.append(Generator_A_loss_split) 304 | G_B_losses.append(Generator_B_loss_split) 305 | D_A_losses.append(Discriminator_A_loss_split) 306 | D_B_losses.append(Discriminator_B_loss_split) 307 | 308 | G_losses.append(Generator_loss_split) 309 | D_losses.append(Discriminator_loss_split) 310 | 311 | self.fake_A.append(x_ba) 312 | self.fake_B.append(x_ab) 313 | 314 | self.real_A.append(domain_A[gpu_id]) 315 | self.real_B.append(domain_B[gpu_id]) 316 | 317 | Generator_A_loss = tf.reduce_mean(G_A_losses) 318 | Generator_B_loss = tf.reduce_mean(G_B_losses) 319 | Discriminator_A_loss = tf.reduce_mean(D_A_losses) 320 | Discriminator_B_loss = tf.reduce_mean(D_B_losses) 321 | 322 | self.Generator_loss = tf.reduce_mean(G_losses) 323 | self.Discriminator_loss = tf.reduce_mean(D_losses) 324 | 325 | 326 | """ Training """ 327 | t_vars = tf.trainable_variables() 328 | G_vars = [var for var in t_vars if 'generator' in var.name or 'encoder' in var.name] 329 | D_vars = [var for var in t_vars if 'discriminator' in var.name] 330 | 331 | self.G_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.Generator_loss, var_list=G_vars) 332 | self.D_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.Discriminator_loss, var_list=D_vars) 333 | 334 | """" Summary """ 335 | self.all_G_loss = tf.summary.scalar("Generator_loss", self.Generator_loss) 336 | self.all_D_loss = tf.summary.scalar("Discriminator_loss", self.Discriminator_loss) 337 | self.G_A_loss = tf.summary.scalar("G_A_loss", Generator_A_loss) 338 | self.G_B_loss = tf.summary.scalar("G_B_loss", Generator_B_loss) 339 | self.D_A_loss = tf.summary.scalar("D_A_loss", Discriminator_A_loss) 340 | self.D_B_loss = tf.summary.scalar("D_B_loss", Discriminator_B_loss) 341 | 342 | self.G_loss = tf.summary.merge([self.G_A_loss, self.G_B_loss, self.all_G_loss]) 343 | self.D_loss = tf.summary.merge([self.D_A_loss, self.D_B_loss, self.all_D_loss]) 344 | 345 | """ Image """ 346 | self.fake_A = tf.squeeze(self.fake_A) 347 | self.fake_B = tf.squeeze(self.fake_B) 348 | 349 | self.real_A = tf.squeeze(self.real_A) 350 | self.real_B = tf.squeeze(self.real_B) 351 | 352 | """ Test """ 353 | self.test_image = tf.placeholder(tf.float32, [1, self.img_size, self.img_size, self.img_ch], name='test_image') 354 | 355 | self.test_fake_A, _ = self.generate_b2a(self.test_image) 356 | self.test_fake_B, _ = self.generate_a2b(self.test_image) 357 | 358 | def train(self): 359 | # initialize all variables 360 | tf.global_variables_initializer().run() 361 | 362 | # saver to save model 363 | self.saver = tf.train.Saver() 364 | 365 | # 
summary writer 366 | self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_dir, self.sess.graph) 367 | 368 | # restore checkpoint if it exists 369 | could_load, checkpoint_counter = self.load(self.checkpoint_dir) 370 | if could_load: 371 | start_epoch = checkpoint_counter // self.iteration 372 | start_batch_id = checkpoint_counter - start_epoch * self.iteration 373 | counter = checkpoint_counter 374 | print(" [*] Load SUCCESS") 375 | else: 376 | start_epoch = 0 377 | start_batch_id = 0 378 | counter = 1 379 | print(" [!] Load failed...") 380 | 381 | # loop for epoch 382 | start_time = time.time() 383 | lr = self.init_lr 384 | for epoch in range(start_epoch, self.epoch): 385 | for idx in range(start_batch_id, self.iteration): 386 | train_feed_dict = { 387 | self.lr : lr 388 | } 389 | 390 | # Update D 391 | _, d_loss, summary_str = self.sess.run([self.D_optim, self.Discriminator_loss, self.D_loss], feed_dict = train_feed_dict) 392 | self.writer.add_summary(summary_str, counter) 393 | 394 | # Update G 395 | batch_A_images, batch_B_images, fake_A, fake_B, _, g_loss, summary_str = self.sess.run([self.real_A, self.real_B, self.fake_A, self.fake_B, self.G_optim, self.Generator_loss, self.G_loss], feed_dict = train_feed_dict) 396 | self.writer.add_summary(summary_str, counter) 397 | 398 | # display training status 399 | counter += 1 400 | print("Epoch: [%2d] [%6d/%6d] time: %4.4f d_loss: %.8f, g_loss: %.8f" \ 401 | % (epoch, idx, self.iteration, time.time() - start_time, d_loss, g_loss)) 402 | 403 | if np.mod(idx+1, self.print_freq) == 0 : 404 | save_images(batch_A_images, [self.batch_size, 1], 405 | './{}/real_A_{:02d}_{:06d}.jpg'.format(self.sample_dir, epoch, idx+1)) 406 | # save_images(batch_B_images, [self.batch_size, 1], 407 | # './{}/real_B_{}_{:02d}_{:06d}.jpg'.format(self.sample_dir, gpu_id, epoch, idx+1)) 408 | 409 | # save_images(fake_A, [self.batch_size, 1], 410 | # './{}/fake_A_{}_{:02d}_{:06d}.jpg'.format(self.sample_dir, gpu_id, epoch, idx+1)) 411 | save_images(fake_B, [self.batch_size, 1], 412 | './{}/fake_B_{:02d}_{:06d}.jpg'.format(self.sample_dir, epoch, idx+1)) 413 | 414 | if np.mod(idx+1, self.save_freq) == 0 : 415 | self.save(self.checkpoint_dir, counter) 416 | 417 | # After an epoch, start_batch_id is set to zero 418 | # a non-zero value is used only for the first epoch after loading a pre-trained model 419 | start_batch_id = 0 420 | 421 | # save model for final step 422 | self.save(self.checkpoint_dir, counter) 423 | 424 | @property 425 | def model_dir(self): 426 | return "{}_{}_{}".format(self.model_name, self.dataset_name, self.gan_type) 427 | 428 | def save(self, checkpoint_dir, step): 429 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir) 430 | 431 | if not os.path.exists(checkpoint_dir): 432 | os.makedirs(checkpoint_dir) 433 | 434 | self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name + '.model'), global_step=step) 435 | 436 | def load(self, checkpoint_dir): 437 | import re 438 | print(" [*] Reading checkpoints...") 439 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir) 440 | 441 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 442 | if ckpt and ckpt.model_checkpoint_path: 443 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path) 444 | self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name)) 445 | counter = int(next(re.finditer("(\d+)(?!.*\d)", ckpt_name)).group(0)) 446 | print(" [*] Successfully read {}".format(ckpt_name)) 447 | return True, counter 448 | else: 449 | print(" [*] Failed to find a checkpoint") 450 | return False, 0 451 | 452 | def test(self): 453 | tf.global_variables_initializer().run() 454 | test_A_files = glob('./dataset/{}/*.*'.format(self.dataset_name + '/testA')) 455 | test_B_files = glob('./dataset/{}/*.*'.format(self.dataset_name + '/testB')) 456 | 457 | self.saver = tf.train.Saver() 458 | could_load, checkpoint_counter = self.load(self.checkpoint_dir) 459 | self.result_dir = os.path.join(self.result_dir, self.model_dir) 460 | check_folder(self.result_dir) 461 | 462 | if could_load : 463 | print(" [*] Load SUCCESS") 464 | else : 465 | print(" [!] Load failed...") 466 | 467 | # write html for visual comparison 468 | index_path = os.path.join(self.result_dir, 'index.html') 469 | index = open(index_path, 'w') 470 | index.write("<html><body><table><tr>") 471 | index.write("<th>name</th><th>input</th><th>output</th></tr>") 472 | 473 | for sample_file in test_A_files : # A -> B 474 | print('Processing A image: ' + sample_file) 475 | sample_image = np.asarray(load_test_data(sample_file, size=self.img_size)) 476 | image_path = os.path.join(self.result_dir, '{0}'.format(os.path.basename(sample_file))) 477 | 478 | fake_img = self.sess.run(self.test_fake_B, feed_dict={self.test_image: sample_image}) 479 | save_images(fake_img, [1, 1], image_path) 480 | 481 | index.write("<td>%s</td>" % os.path.basename(image_path)) 482 | index.write("<td><img src='%s' width='%d' height='%d'></td>" % (sample_file if os.path.isabs(sample_file) else ( 483 | '../..' + os.path.sep + sample_file), self.img_size, self.img_size)) 484 | index.write("<td><img src='%s' width='%d' height='%d'></td>" % (image_path if os.path.isabs(image_path) else ( 485 | '../..' + os.path.sep + image_path), self.img_size, self.img_size)) 486 | index.write("</tr>") 487 | 488 | for sample_file in test_B_files : # B -> A 489 | print('Processing B image: ' + sample_file) 490 | sample_image = np.asarray(load_test_data(sample_file, size=self.img_size)) 491 | image_path = os.path.join(self.result_dir, '{0}'.format(os.path.basename(sample_file))) 492 | 493 | fake_img = self.sess.run(self.test_fake_A, feed_dict={self.test_image: sample_image}) 494 | save_images(fake_img, [1, 1], image_path) 495 | 496 | index.write("<td>%s</td>" % os.path.basename(image_path)) 497 | index.write("<td><img src='%s' width='%d' height='%d'></td>" % (sample_file if os.path.isabs(sample_file) else ( 498 | '../..' + os.path.sep + sample_file), self.img_size, self.img_size)) 499 | index.write("<td><img src='%s' width='%d' height='%d'></td>" % (image_path if os.path.isabs(image_path) else ( 500 | '../..' + os.path.sep + image_path), self.img_size, self.img_size)) 501 | index.write("</tr>") 502 | 503 | index.close() 504 | -------------------------------------------------------------------------------- /DatasetAPI/main.py: -------------------------------------------------------------------------------- 1 | from UNIT import UNIT 2 | import argparse 3 | from utils import * 4 | 5 | """parsing and configuration""" 6 | def parse_args(): 7 | desc = "Tensorflow implementation of UNIT" 8 | parser = argparse.ArgumentParser(description=desc) 9 | parser.add_argument('--phase', type=str, default='train', help='train or test ?') 10 | parser.add_argument('--dataset', type=str, default='summer2winter', help='dataset_name') 11 | parser.add_argument('--augment_flag', type=bool, default=False, help='Image augmentation use or not') 12 | 13 | parser.add_argument('--epoch', type=int, default=5, help='The number of epochs to run') 14 | parser.add_argument('--iteration', type=int, default=100000, help='The number of training iterations') 15 | parser.add_argument('--batch_size', type=int, default=1, help='The batch size') 16 | parser.add_argument('--print_freq', type=int, default=1000, help='The number of image_print_freq') 17 | parser.add_argument('--save_freq', type=int, default=1000, help='The number of ckpt_save_freq') 18 | 19 | parser.add_argument('--lr', type=float, default=0.0001, help='The learning rate') 20 | parser.add_argument('--GAN_weight', type=float, default=10.0, help='Weight about GAN, lambda0') 21 | parser.add_argument('--KL_weight', type=float, default=0.1, help='Weight about VAE, lambda1') 22 | parser.add_argument('--L1_weight', type=float, default=100.0, help='Weight about VAE, lambda2' ) 23 | parser.add_argument('--KL_cycle_weight', type=float, default=0.1, help='Weight about VAE Cycle, lambda3') 24 | parser.add_argument('--L1_cycle_weight', type=float, default=100.0, help='Weight about VAE Cycle, lambda4') 25 | 26 | parser.add_argument('--gan_type', type=str, default='gan', help='GAN loss type [gan / lsgan]') 27 | 28 | parser.add_argument('--ch', type=int, default=64, help='base channel number per layer') 29 | parser.add_argument('--n_encoder', type=int, default=3, help='The number of encoder') 30 | parser.add_argument('--n_enc_resblock', type=int, default=3, help='The number of encoder_resblock') 31 | parser.add_argument('--n_enc_share', type=int, default=1, help='The number of share_encoder') 32 | parser.add_argument('--n_gen_share', type=int, default=1, help='The number of share_generator') 33 | parser.add_argument('--n_gen_resblock', type=int, default=3, help='The number of generator_resblock') 34 | parser.add_argument('--n_gen_decoder', type=int, default=3, help='The number of generator_decoder') 35 | 36 | parser.add_argument('--n_dis', type=int, default=6, help='The number of discriminator layer') 37 | 38 | parser.add_argument('--img_size', type=int, default=256, help='The size of image') 39 | parser.add_argument('--img_ch', type=int, default=3, help='The size of image channel') 40 | 41 | parser.add_argument('--checkpoint_dir', type=str, default='checkpoint', 42 | help='Directory name to save the checkpoints') 43 | parser.add_argument('--result_dir', type=str, default='results', 44 | help='Directory name to save the generated images') 45 | parser.add_argument('--log_dir', type=str, default='logs', 46 | help='Directory name to save training logs') 47 | parser.add_argument('--sample_dir', type=str, default='samples', 48 | help='Directory name to save the samples on training') 49 | 50 | return check_args(parser.parse_args())
51 | 52 | """checking arguments""" 53 | def check_args(args): 54 | # --checkpoint_dir 55 | check_folder(args.checkpoint_dir) 56 | 57 | # --result_dir 58 | check_folder(args.result_dir) 59 | 60 | # --log_dir 61 | check_folder(args.log_dir) 62 | 63 | # --sample_dir 64 | check_folder(args.sample_dir) 65 | 66 | # --epoch 67 | try: 68 | assert args.epoch >= 1 69 | except: 70 | print('number of epochs must be larger than or equal to one') 71 | 72 | # --batch_size 73 | try: 74 | assert args.batch_size >= 1 75 | except: 76 | print('batch size must be larger than or equal to one') 77 | return args 78 | 79 | """main""" 80 | def main(): 81 | # parse arguments 82 | args = parse_args() 83 | if args is None: 84 | exit() 85 | 86 | # open session 87 | with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess: 88 | gan = UNIT(sess, args) 89 | 90 | # build graph 91 | gan.build_model() 92 | 93 | # show network architecture 94 | show_all_variables() 95 | 96 | if args.phase == 'train' : 97 | # launch the graph in a session 98 | gan.train() 99 | print(" [*] Training finished!") 100 | 101 | if args.phase == 'test' : 102 | gan.test() 103 | print(" [*] Test finished!") 104 | 105 | if __name__ == '__main__': 106 | main() -------------------------------------------------------------------------------- /DatasetAPI/main_multi_gpu.py: -------------------------------------------------------------------------------- 1 | from UNIT_multi_gpu import UNIT 2 | import argparse 3 | from utils import * 4 | 5 | """parsing and configuration""" 6 | def parse_args(): 7 | desc = "Tensorflow implementation of UNIT" 8 | parser = argparse.ArgumentParser(description=desc) 9 | parser.add_argument('--phase', type=str, default='train', help='train or test ?') 10 | parser.add_argument('--dataset', type=str, default='summer2winter', help='dataset_name') 11 | parser.add_argument('--augment_flag', type=bool, default=False, help='Image augmentation use or not') 12 | 13 |
parser.add_argument('--epoch', type=int, default=5, help='The number of epochs to run') 14 | parser.add_argument('--iteration', type=int, default=100000, help='The number of training iterations') 15 | parser.add_argument('--batch_size', type=int, default=1, help='The batch size') 16 | parser.add_argument('--gpu_num', type=int, default=8, help='The number of gpu') 17 | parser.add_argument('--print_freq', type=int, default=1000, help='The number of image_print_freq') 18 | parser.add_argument('--save_freq', type=int, default=1000, help='The number of ckpt_save_freq') 19 | 20 | parser.add_argument('--lr', type=float, default=0.0001, help='The learning rate') 21 | parser.add_argument('--GAN_weight', type=float, default=10.0, help='Weight about GAN, lambda0') 22 | parser.add_argument('--KL_weight', type=float, default=0.1, help='Weight about VAE, lambda1') 23 | parser.add_argument('--L1_weight', type=float, default=100.0, help='Weight about VAE, lambda2' ) 24 | parser.add_argument('--KL_cycle_weight', type=float, default=0.1, help='Weight about VAE Cycle, lambda3') 25 | parser.add_argument('--L1_cycle_weight', type=float, default=100.0, help='Weight about VAE Cycle, lambda4') 26 | 27 | parser.add_argument('--gan_type', type=str, default='gan', help='GAN loss type [gan / lsgan]') 28 | 29 | parser.add_argument('--ch', type=int, default=64, help='base channel number per layer') 30 | parser.add_argument('--n_encoder', type=int, default=3, help='The number of encoder') 31 | parser.add_argument('--n_enc_resblock', type=int, default=3, help='The number of encoder_resblock') 32 | parser.add_argument('--n_enc_share', type=int, default=1, help='The number of share_encoder') 33 | parser.add_argument('--n_gen_share', type=int, default=1, help='The number of share_generator') 34 | parser.add_argument('--n_gen_resblock', type=int, default=3, help='The number of generator_resblock') 35 | parser.add_argument('--n_gen_decoder', type=int, default=3, help='The number of generator_decoder') 36 | 37 | parser.add_argument('--n_dis', type=int, default=6, help='The number of discriminator layer') 38 | 39 | parser.add_argument('--img_size', type=int, default=256, help='The size of image') 40 | parser.add_argument('--img_ch', type=int, default=3, help='The size of image channel') 41 | 42 | parser.add_argument('--checkpoint_dir', type=str, default='checkpoint', 43 | help='Directory name to save the checkpoints') 44 | parser.add_argument('--result_dir', type=str, default='results', 45 | help='Directory name to save the generated images') 46 | parser.add_argument('--log_dir', type=str, default='logs', 47 | help='Directory name to save training logs') 48 | parser.add_argument('--sample_dir', type=str, default='samples', 49 | help='Directory name to save the samples on training') 50 | 51 | return check_args(parser.parse_args()) 52 | 53 | """checking arguments""" 54 | def check_args(args): 55 | # --checkpoint_dir 56 | check_folder(args.checkpoint_dir) 57 | 58 | # --result_dir 59 | check_folder(args.result_dir) 60 | 61 | # --log_dir 62 | check_folder(args.log_dir) 63 | 64 | # --sample_dir 65 | check_folder(args.sample_dir) 66 | 67 | # --epoch 68 | try: 69 | assert args.epoch >= 1 70 | except: 71 | print('number of epochs must be larger than or equal to one') 72 | 73 | # --batch_size 74 | try: 75 | assert args.batch_size >= 1 76 | except: 77 | print('batch size must be larger than or equal to one') 78 | return args 79 | 80 | """main""" 81 | def main(): 82 | # parse arguments 83 | args = parse_args() 84 | if args is None: 85 |
exit() 86 | 87 | # open session 88 | with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess: 89 | gan = UNIT(sess, args) 90 | 91 | # build graph 92 | gan.build_model() 93 | 94 | # show network architecture 95 | show_all_variables() 96 | 97 | if args.phase == 'train' : 98 | # launch the graph in a session 99 | gan.train() 100 | print(" [*] Training finished!") 101 | 102 | if args.phase == 'test' : 103 | gan.test() 104 | print(" [*] Test finished!") 105 | 106 | if __name__ == '__main__': 107 | main() -------------------------------------------------------------------------------- /DatasetAPI/ops.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib as tf_contrib 3 | 4 | weight_init = tf.random_normal_initializer(mean=0.0, stddev=0.02) 5 | weight_regularizer = tf_contrib.layers.l2_regularizer(scale=0.0001) 6 | 7 | ################################################################################## 8 | # Layer 9 | ################################################################################## 10 | 11 | def conv(x, channels, kernel=4, stride=2, pad=0, pad_type='zero', use_bias=True, scope='conv'): 12 | with tf.variable_scope(scope): 13 | if pad_type == 'zero' : 14 | x = tf.pad(x, [[0, 0], [pad, pad], [pad, pad], [0, 0]]) 15 | if pad_type == 'reflect' : 16 | x = tf.pad(x, [[0, 0], [pad, pad], [pad, pad], [0, 0]], mode='REFLECT') 17 | 18 | x = tf.layers.conv2d(inputs=x, filters=channels, 19 | kernel_size=kernel, kernel_initializer=weight_init, 20 | kernel_regularizer=weight_regularizer, 21 | strides=stride, use_bias=use_bias) 22 | 23 | return x 24 | 25 | def deconv(x, channels, kernel=3, stride=2, use_bias=True, scope='deconv_0') : 26 | with tf.variable_scope(scope): 27 | x = tf.layers.conv2d_transpose(inputs=x, filters=channels, 28 | kernel_size=kernel, kernel_initializer=weight_init, 29 | kernel_regularizer=weight_regularizer, 30 | strides=stride, use_bias=use_bias, padding='SAME') 31 | 32 | return x 33 | 34 | def linear(x, units, use_bias=True, scope='linear'): 35 | with tf.variable_scope(scope): 36 | x = flatten(x) 37 | x = tf.layers.dense(x, units=units, kernel_initializer=weight_init, kernel_regularizer=weight_regularizer, use_bias=use_bias) 38 | 39 | return x 40 | 41 | def flatten(x) : 42 | return tf.layers.flatten(x) 43 | 44 | def gaussian_noise_layer(mu): 45 | sigma = 1.0 46 | gaussian_random_vector = tf.random_normal(shape=tf.shape(mu), mean=0.0, stddev=1.0, dtype=tf.float32) 47 | return mu + sigma * gaussian_random_vector 48 | 49 | 50 | ################################################################################## 51 | # Residual-block 52 | ################################################################################## 53 | 54 | def resblock(x_init, channels, use_bias=True, scope='resblock'): 55 | with tf.variable_scope(scope): 56 | with tf.variable_scope('res1'): 57 | x = conv(x_init, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias) 58 | x = instance_norm(x) 59 | x = relu(x) 60 | 61 | with tf.variable_scope('res2'): 62 | x = conv(x, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias) 63 | x = instance_norm(x) 64 | 65 | return x + x_init 66 | 67 | ################################################################################## 68 | # Activation function 69 | ################################################################################## 70 | 71 | def lrelu(x, alpha=0.01): 72 | # pytorch alpha is 0.01 73 | return 
tf.nn.leaky_relu(x, alpha) 74 | 75 | 76 | def relu(x): 77 | return tf.nn.relu(x) 78 | 79 | 80 | def tanh(x): 81 | return tf.tanh(x) 82 | 83 | ################################################################################## 84 | # Normalization function 85 | ################################################################################## 86 | 87 | def instance_norm(x, scope='instance_norm'): 88 | return tf_contrib.layers.instance_norm(x, 89 | epsilon=1e-05, 90 | center=True, scale=True, 91 | scope=scope) 92 | 93 | ################################################################################## 94 | # Loss function 95 | ################################################################################## 96 | 97 | def discriminator_loss(type, real, fake): # type is the gan loss type: 'gan' or 'lsgan' 98 | real_loss = 0 99 | fake_loss = 0 100 | 101 | if type == 'lsgan' : 102 | real_loss = tf.reduce_mean(tf.squared_difference(real, 1.0)) 103 | fake_loss = tf.reduce_mean(tf.square(fake)) 104 | 105 | if type == 'gan' : 106 | real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(real), logits=real)) 107 | fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(fake), logits=fake)) 108 | 109 | loss = real_loss + fake_loss 110 | 111 | return loss 112 | 113 | 114 | def generator_loss(type, fake): # type is the gan loss type: 'gan' or 'lsgan' 115 | fake_loss = 0 116 | 117 | if type == 'lsgan' : 118 | fake_loss = tf.reduce_mean(tf.squared_difference(fake, 1.0)) 119 | 120 | if type == 'gan' : 121 | fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(fake), logits=fake)) 122 | 123 | loss = fake_loss 124 | 125 | 126 | return loss 127 | 128 | 129 | def L1_loss(x, y): 130 | loss = tf.reduce_mean(tf.abs(x - y)) 131 | 132 | return loss 133 | 134 | def KL_divergence(mu) : 135 | # KL_divergence = 0.5 * tf.reduce_sum(tf.square(mu) + tf.square(sigma) - tf.log(1e-8 + tf.square(sigma)) - 1, axis = -1) 136 | # loss = tf.reduce_mean(KL_divergence) 137 | mu_2 = tf.square(mu) # gaussian_noise_layer fixes sigma = 1, so the KL term to N(0, I) reduces to the mean of mu^2 (up to constants) 138 | loss = tf.reduce_mean(mu_2) 139 | 140 | return loss -------------------------------------------------------------------------------- /DatasetAPI/utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.contrib import slim 3 | from scipy import misc 4 | import os, random 5 | import numpy as np 6 | 7 | # https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/ 8 | # https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/ 9 | 10 | class ImageData: 11 | 12 | def __init__(self, load_size, channels, augment_flag=False): 13 | self.load_size = load_size 14 | self.channels = channels 15 | self.augment_flag = augment_flag 16 | 17 | def image_processing(self, filename): 18 | x = tf.read_file(filename) 19 | x_decode = tf.image.decode_jpeg(x, channels=self.channels) 20 | img = tf.image.resize_images(x_decode, [self.load_size, self.load_size]) 21 | img = tf.cast(img, tf.float32) / 127.5 - 1 22 | 23 | if self.augment_flag : 24 | augment_size = self.load_size + (30 if self.load_size == 256 else 15) 25 | p = random.random() 26 | if p > 0.5: 27 | img = augmentation(img, augment_size) 28 | 29 | return img 30 | 31 | 32 | def load_test_data(image_path, size=256): 33 | img = misc.imread(image_path, mode='RGB') 34 | img = misc.imresize(img, [size, size]) 35 | img = np.expand_dims(img, axis=0) 36 | img = preprocessing(img) 37 | 38 | return img 39 | 40 | def preprocessing(x): 41 | x = x/127.5 - 1 # -1 ~ 1 42 | return x 43 | 44 | def
augmentation(image, augment_size): 45 | seed = random.randint(0, 2 ** 31 - 1) 46 | ori_image_shape = tf.shape(image) 47 | image = tf.image.random_flip_left_right(image, seed=seed) 48 | image = tf.image.resize_images(image, [augment_size, augment_size]) 49 | image = tf.random_crop(image, ori_image_shape, seed=seed) 50 | return image 51 | 52 | def save_images(images, size, image_path): 53 | return imsave(inverse_transform(images), size, image_path) 54 | 55 | def inverse_transform(images): 56 | return (images+1.) / 2 57 | 58 | def imsave(images, size, path): 59 | return misc.imsave(path, merge(images, size)) 60 | 61 | def merge(images, size): 62 | h, w = images.shape[1], images.shape[2] 63 | img = np.zeros((h * size[0], w * size[1], 3)) 64 | for idx, image in enumerate(images): 65 | i = idx % size[1] 66 | j = idx // size[1] 67 | img[h*j:h*(j+1), w*i:w*(i+1), :] = image 68 | 69 | return img 70 | 71 | def show_all_variables(): 72 | model_vars = tf.trainable_variables() 73 | slim.model_analyzer.analyze_vars(model_vars, print_info=True) 74 | 75 | def check_folder(log_dir): 76 | if not os.path.exists(log_dir): 77 | os.makedirs(log_dir) 78 | return log_dir 79 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Junho Kim (1993.01.12) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # UNIT-Tensorflow 2 | Simple Tensorflow implementation of ["Unsupervised Image to Image Translation Networks"](https://arxiv.org/abs/1703.00848) (NIPS 2017 Spotlight) 3 | 4 | ## Requirements 5 | * Tensorflow 1.4 6 | * Python 3.6 7 | 8 | ## Usage 9 | ```bash 10 | ├── dataset 11 |    └── YOUR_DATASET_NAME 12 |        ├── trainA 13 |        │   ├── xxx.jpg (name, format doesn't matter) 14 |        │   ├── yyy.png 15 |        │   └── ... 16 |        ├── trainB 17 |        │   ├── zzz.jpg 18 |        │   ├── www.png 19 |        │   └── ... 20 |        ├── testA 21 |        │   ├── aaa.jpg 22 |        │   ├── bbb.png 23 |        │   └── ... 24 |        └── testB 25 |            ├── ccc.jpg 26 |            ├── ddd.png 27 |            └── ...
28 | ``` 29 | 30 | ```bash 31 | > python main.py --phase train --dataset cat2tiger 32 | ``` 33 | * See `main.py` for other arguments 34 | * For the multi-GPU version, use `main_multi_gpu.py` (`batch_size` is the batch size per GPU) 35 | * For a faster UNIT, use the `DatasetAPI` version (the code is simpler, too!) 36 | 37 | ## Issue 38 | ### Too slow!!! 39 | * Training is slow mainly because checkpoints are stored 40 | * To speed it up, do not save a checkpoint every iteration 41 | 42 | ## Architecture 43 | ![architecture](./assests/architecture.png) 44 | 45 | ## Framework 46 | ![framework](./assests/framework.png) 47 | 48 | ## Model 49 | ![compare](./assests/compare.png) 50 | 51 | ![vae](./assests/vae_model.png) 52 | 53 | ![gan](./assests/gan_model.png) 54 | 55 | ![cycle](./assests/cycle.png) 56 | 57 | ## Training Objective 58 | ![objective](./assests/training_objective__.png) 59 | 60 | ## Result 61 | ### Success 62 | ![success](./assests/success.png) 63 | 64 | ### Fail 65 | ![fail](./assests/fail.png) 66 | 67 | ## Related works 68 | * [CycleGAN-Tensorflow](https://github.com/taki0112/CycleGAN-Tensorflow) 69 | * [DiscoGAN-Tensorflow](https://github.com/taki0112/DiscoGAN-Tensorflow) 70 | * [MUNIT-Tensorflow](https://github.com/taki0112/MUNIT-Tensorflow) 71 | * [StarGAN-Tensorflow](https://github.com/taki0112/StarGAN-Tensorflow) 72 | * [DRIT-Tensorflow](https://github.com/taki0112/DRIT-Tensorflow) 73 | 74 | ## Reference 75 | * [UNIT-Pytorch](https://github.com/mingyuliutw/UNIT) 76 | * [Multi-GPU-Tensorflow](https://github.com/golbin/TensorFlow-Multi-GPUs) 77 | * [DatasetAPI-Tensorflow](https://github.com/taki0112/Tensorflow-DatasetAPI) 78 | 79 | ## Author 80 | Junho Kim 81 | -------------------------------------------------------------------------------- /UNIT.py: -------------------------------------------------------------------------------- 1 | from ops import * 2 | from utils import * 3 | from glob import glob 4 | import time 5 | 6 | class UNIT(object): 7 | def __init__(self, sess, args): 8 | self.model_name = 'UNIT' 9 | self.sess = sess 10 | self.checkpoint_dir = args.checkpoint_dir 11 | self.result_dir = args.result_dir 12 | self.log_dir = args.log_dir 13 | self.sample_dir = args.sample_dir 14 | self.dataset_name = args.dataset 15 | 16 | self.epoch = args.epoch # 100000 17 | self.batch_size = args.batch_size # 1 18 | 19 | self.lr = args.lr # 0.0001 20 | """ Weight about VAE """ 21 | self.KL_weight = args.KL_weight # lambda 1 22 | self.L1_weight = args.L1_weight # lambda 2 23 | 24 | """ Weight about VAE Cycle""" 25 | self.KL_cycle_weight = args.KL_cycle_weight # lambda 3 26 | self.L1_cycle_weight = args.L1_cycle_weight # lambda 4 27 | 28 | """ Weight about GAN """ 29 | self.GAN_weight = args.GAN_weight # lambda 0 30 | 31 | 32 | """ Encoder """ 33 | self.ch = args.ch # base channel number per layer 34 | self.n_encoder = args.n_encoder 35 | self.n_enc_resblock = args.n_enc_resblock 36 | self.n_enc_share = args.n_enc_share 37 | 38 | """ Generator """ 39 | self.n_gen_share = args.n_gen_share 40 | self.n_gen_resblock = args.n_gen_resblock 41 | self.n_gen_decoder = args.n_gen_decoder 42 | 43 | """ Discriminator """ 44 | self.n_dis = args.n_dis # + 2 45 | 46 | self.res_dropout = args.res_dropout 47 | self.smoothing = args.smoothing 48 | self.lsgan = args.lsgan 49 | self.norm = args.norm 50 | self.replay_memory = args.replay_memory 51 | self.pool_size = args.pool_size 52 | self.img_size = args.img_size 53 | self.channel = args.img_ch 54 | self.augment_flag = args.augment_flag
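        # data augmentation: each image is randomly flipped, resized up to augment_size, then cropped back to img_size (see augmentation() in utils.py)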
55 | self.augment_size = self.img_size + (30 if self.img_size == 256 else 15) 56 | self.normal_weight_init = args.normal_weight_init 57 | 58 | self.trainA, self.trainB = prepare_data(dataset_name=self.dataset_name, size=self.img_size) 59 | self.num_batches = max(len(self.trainA) // self.batch_size, len(self.trainB) // self.batch_size) 60 | 61 | ############################################################################## 62 | # BEGIN of ENCODERS 63 | def encoder(self, x, is_training=True, reuse=False, scope="encoder"): 64 | channel = self.ch 65 | with tf.variable_scope(scope, reuse=reuse) : 66 | x = conv(x, channel, kernel=7, stride=1, pad=3, normal_weight_init=self.normal_weight_init, activation_fn='leaky', scope='conv_0') 67 | 68 | for i in range(1, self.n_encoder) : 69 | x = conv(x, channel*2, kernel=3, stride=2, pad=1, normal_weight_init=self.normal_weight_init, activation_fn='leaky', scope='conv_'+str(i)) 70 | channel *= 2 71 | 72 | # channel = 256 73 | for i in range(0, self.n_enc_resblock) : 74 | x = resblock(x, channel, kernel=3, stride=1, pad=1, dropout_ratio=self.res_dropout, 75 | normal_weight_init=self.normal_weight_init, 76 | is_training=is_training, norm_fn=self.norm, scope='resblock_'+str(i)) 77 | 78 | return x 79 | # END of ENCODERS 80 | ############################################################################## 81 | 82 | ############################################################################## 83 | # BEGIN of SHARED LAYERS 84 | # Shared residual-blocks 85 | def share_encoder(self, x, is_training=True, reuse=False, scope="share_encoder"): 86 | channel = self.ch * pow(2, self.n_encoder-1) 87 | with tf.variable_scope(scope, reuse=reuse) : 88 | for i in range(0, self.n_enc_share) : 89 | x = resblock(x, channel, kernel=3, stride=1, pad=1, dropout_ratio=self.res_dropout, 90 | normal_weight_init=self.normal_weight_init, 91 | is_training=is_training, norm_fn=self.norm, scope='resblock_'+str(i)) 92 | 93 | x = gaussian_noise_layer(x) 94 | 95 | return x 96 | 97 | def share_generator(self, x, is_training=True, reuse=False, scope="share_generator"): 98 | channel = self.ch * pow(2, self.n_encoder-1) 99 | with tf.variable_scope(scope, reuse=reuse) : 100 | for i in range(0, self.n_gen_share) : 101 | x = resblock(x, channel, kernel=3, stride=1, pad=1, dropout_ratio=self.res_dropout, 102 | normal_weight_init=self.normal_weight_init, 103 | is_training=is_training, norm_fn=self.norm, scope='resblock_'+str(i)) 104 | 105 | return x 106 | # END of SHARED LAYERS 107 | ############################################################################## 108 | 109 | ############################################################################## 110 | # BEGIN of DECODERS 111 | def generator(self, x, is_training=True, reuse=False, scope="generator"): 112 | channel = self.ch * pow(2, self.n_encoder - 1) 113 | with tf.variable_scope(scope, reuse=reuse) : 114 | for i in range(0, self.n_gen_resblock) : 115 | x = resblock(x, channel, kernel=3, stride=1, pad=1, dropout_ratio=self.res_dropout, 116 | normal_weight_init=self.normal_weight_init, 117 | is_training=is_training, norm_fn=self.norm, scope='resblock_'+str(i)) 118 | 119 | for i in range(0, self.n_gen_decoder-1) : 120 | x = deconv(x, channel//2, kernel=3, stride=2, normal_weight_init=self.normal_weight_init, activation_fn='leaky', scope='deconv_'+str(i)) 121 | channel = channel // 2 122 | 123 | x = deconv(x, self.channel, kernel=1, stride=1, normal_weight_init=self.normal_weight_init, activation_fn='tanh', scope='deconv_tanh') 124 | 125 | return x 
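    # shape check with the defaults (img_size=256, ch=64, n_encoder=3): the encoder's two stride-2 convs downsample to 64x64x256, the two stride-2 deconvs above restore 256x256, and the final 1x1 deconv maps back to img_ch channels with tanh outputs in [-1, 1]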
126 | # END of DECODERS 127 | ############################################################################## 128 | 129 | ############################################################################## 130 | # BEGIN of DISCRIMINATORS 131 | def discriminator(self, x, reuse=False, scope="discriminator"): 132 | channel = self.ch 133 | with tf.variable_scope(scope, reuse=reuse): 134 | x = conv(x, channel, kernel=3, stride=2, pad=1, normal_weight_init=self.normal_weight_init, activation_fn='leaky', scope='conv_0') 135 | 136 | for i in range(1, self.n_dis) : 137 | x = conv(x, channel*2, kernel=3, stride=2, pad=1, normal_weight_init=self.normal_weight_init, activation_fn='leaky', scope='conv_'+str(i)) 138 | channel *= 2 139 | 140 | x = conv(x, channels=1, kernel=1, stride=1, pad=0, normal_weight_init=self.normal_weight_init, activation_fn=None, scope='dis_logit') 141 | 142 | return x 143 | # END of DISCRIMINATORS 144 | ############################################################################## 145 | 146 | def translation(self, x_A, x_B): 147 | out = tf.concat([self.encoder(x_A, self.is_training, scope="encoder_A"), self.encoder(x_B, self.is_training, scope="encoder_B")], axis=0) 148 | shared = self.share_encoder(out, self.is_training) 149 | out = self.share_generator(shared, self.is_training) 150 | 151 | out_A = self.generator(out, self.is_training, scope="generator_A") 152 | out_B = self.generator(out, self.is_training, scope="generator_B") 153 | 154 | x_Aa, x_Ba = tf.split(out_A, 2, axis=0) 155 | x_Ab, x_Bb = tf.split(out_B, 2, axis=0) 156 | 157 | return x_Aa, x_Ba, x_Ab, x_Bb, shared 158 | 159 | def generate_a2b(self, x_A): 160 | out = self.encoder(x_A, self.is_training, reuse=True, scope="encoder_A") 161 | shared = self.share_encoder(out, self.is_training, reuse=True) 162 | out = self.share_generator(shared, self.is_training, reuse=True) 163 | out = self.generator(out, self.is_training, reuse=True, scope="generator_B") 164 | 165 | return out, shared 166 | 167 | def generate_b2a(self, x_B): 168 | out = self.encoder(x_B, self.is_training, reuse=True, scope="encoder_B") 169 | shared = self.share_encoder(out, self.is_training, reuse=True) 170 | out = self.share_generator(shared, self.is_training, reuse=True) 171 | out = self.generator(out, self.is_training, reuse=True, scope="generator_A") 172 | 173 | return out, shared 174 | 175 | def discriminate_real(self, x_A, x_B): 176 | real_A_logit = self.discriminator(x_A, scope="discriminator_A") 177 | real_B_logit = self.discriminator(x_B, scope="discriminator_B") 178 | 179 | return real_A_logit, real_B_logit 180 | 181 | def discriminate_fake(self, x_ba, x_ab): 182 | fake_A_logit = self.discriminator(x_ba, reuse=True, scope="discriminator_A") 183 | fake_B_logit = self.discriminator(x_ab, reuse=True, scope="discriminator_B") 184 | 185 | return fake_A_logit, fake_B_logit 186 | 187 | def discriminate_fake_pool(self, x_ba, x_ab): 188 | fake_A_pool_logit = self.discriminator(self.fake_A_pool.query(x_ba), reuse=True, scope="discriminator_A") # replay memory 189 | fake_B_pool_logit = self.discriminator(self.fake_B_pool.query(x_ab), reuse=True, scope="discriminator_B") # replay memory 190 | 191 | return fake_A_pool_logit, fake_B_pool_logit 192 | 193 | def build_model(self): 194 | self.is_training = tf.placeholder(tf.bool) 195 | self.prob = tf.placeholder(tf.float32) 196 | self.condition = tf.logical_and(tf.greater(self.prob, tf.constant(0.5)), self.is_training) 197 | 198 | """ Input Image""" 199 | domain_A = self.domain_A = tf.placeholder(tf.float32, 
[self.batch_size, self.img_size, self.img_size, self.channel], name='domain_A') # real A 200 | domain_B = self.domain_B = tf.placeholder(tf.float32, [self.batch_size, self.img_size, self.img_size, self.channel], name='domain_B') # real B 201 | 202 | if self.augment_flag : 203 | """ Augmentation """ 204 | domain_A = tf.cond( 205 | self.condition, 206 | lambda : augmentation(domain_A, self.augment_size), 207 | lambda : domain_A 208 | ) 209 | 210 | domain_B = tf.cond( 211 | self.condition, 212 | lambda : augmentation(domain_B, self.augment_size), 213 | lambda : domain_B 214 | ) 215 | 216 | 217 | """ Define Encoder, Generator, Discriminator """ 218 | x_aa, x_ba, x_ab, x_bb, shared = self.translation(domain_A, domain_B) 219 | x_bab, shared_bab = self.generate_a2b(x_ba) 220 | x_aba, shared_aba = self.generate_b2a(x_ab) 221 | 222 | real_A_logit, real_B_logit = self.discriminate_real(domain_A, domain_B) 223 | 224 | if self.replay_memory : 225 | self.fake_A_pool = ImagePool(self.pool_size) # pool of generated A 226 | self.fake_B_pool = ImagePool(self.pool_size) # pool of generated B 227 | fake_A_logit, fake_B_logit = self.discriminate_fake_pool(x_ba, x_ab) 228 | else : 229 | fake_A_logit, fake_B_logit = self.discriminate_fake(x_ba, x_ab) 230 | 231 | 232 | 233 | """ Define Loss """ 234 | G_ad_loss_a = generator_loss(fake_A_logit, smoothing=self.smoothing, use_lsgan=self.lsgan) 235 | G_ad_loss_b = generator_loss(fake_B_logit, smoothing=self.smoothing, use_lsgan=self.lsgan) 236 | 237 | D_ad_loss_a = discriminator_loss(real_A_logit, fake_A_logit, smoothing=self.smoothing, use_lasgan=self.lsgan) 238 | D_ad_loss_b = discriminator_loss(real_B_logit, fake_B_logit, smoothing=self.smoothing, use_lasgan=self.lsgan) 239 | 240 | enc_loss = KL_divergence(shared) 241 | enc_bab_loss = KL_divergence(shared_bab) 242 | enc_aba_loss = KL_divergence(shared_aba) 243 | 244 | l1_loss_a = L1_loss(x_aa, domain_A) # identity 245 | l1_loss_b = L1_loss(x_bb, domain_B) # identity 246 | l1_loss_aba = L1_loss(x_aba, domain_A) # reconstruction 247 | l1_loss_bab = L1_loss(x_bab, domain_B) # reconstruction 248 | 249 | Generator_A_loss = self.GAN_weight * G_ad_loss_a + \ 250 | self.L1_weight * l1_loss_a + \ 251 | self.L1_cycle_weight * l1_loss_aba + \ 252 | self.KL_weight * enc_loss + \ 253 | self.KL_cycle_weight * enc_bab_loss 254 | 255 | Generator_B_loss = self.GAN_weight * G_ad_loss_b + \ 256 | self.L1_weight * l1_loss_b + \ 257 | self.L1_cycle_weight * l1_loss_bab + \ 258 | self.KL_weight * enc_loss + \ 259 | self.KL_cycle_weight * enc_aba_loss 260 | 261 | Discriminator_A_loss = self.GAN_weight * D_ad_loss_a 262 | Discriminator_B_loss = self.GAN_weight * D_ad_loss_b 263 | 264 | self.Generator_loss = Generator_A_loss + Generator_B_loss 265 | self.Discriminator_loss = Discriminator_A_loss + Discriminator_B_loss 266 | 267 | 268 | """ Training """ 269 | t_vars = tf.trainable_variables() 270 | G_vars = [var for var in t_vars if ('generator' in var.name) or ('encoder' in var.name)] 271 | D_vars = [var for var in t_vars if 'discriminator' in var.name] 272 | 273 | 274 | # with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): 275 | self.G_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.Generator_loss, var_list=G_vars) 276 | self.D_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.Discriminator_loss, var_list=D_vars) 277 | 278 | """ Summary """ 279 | self.all_G_loss = tf.summary.scalar("Generator_loss", self.Generator_loss) 280 | self.all_D_loss = 
tf.summary.scalar("Discriminator_loss", self.Discriminator_loss) 281 | self.G_A_loss = tf.summary.scalar("G_A_loss", Generator_A_loss) 282 | self.G_B_loss = tf.summary.scalar("G_B_loss", Generator_B_loss) 283 | self.D_A_loss = tf.summary.scalar("D_A_loss", Discriminator_A_loss) 284 | self.D_B_loss = tf.summary.scalar("D_B_loss", Discriminator_B_loss) 285 | 286 | self.G_loss = tf.summary.merge([self.G_A_loss, self.G_B_loss, self.all_G_loss]) 287 | self.D_loss = tf.summary.merge([self.D_A_loss, self.D_B_loss, self.all_D_loss]) 288 | 289 | """ Generated Image """ 290 | self.fake_B, _ = self.generate_a2b(domain_A) # for test 291 | self.fake_A, _ = self.generate_b2a(domain_B) # for test 292 | 293 | def train(self): 294 | # initialize all variables 295 | tf.global_variables_initializer().run() 296 | 297 | # saver to save model 298 | self.saver = tf.train.Saver() 299 | 300 | # summary writer 301 | self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_name, self.sess.graph) 302 | 303 | 304 | # restore check-point if it exists 305 | could_load, checkpoint_counter = self.load(self.checkpoint_dir) 306 | if could_load: 307 | start_epoch = (int)(checkpoint_counter / self.num_batches) 308 | start_batch_id = checkpoint_counter - start_epoch * self.num_batches 309 | counter = checkpoint_counter 310 | print(" [*] Load SUCCESS") 311 | else: 312 | start_epoch = 0 313 | start_batch_id = 0 314 | counter = 1 315 | print(" [!] Load failed...") 316 | 317 | # loop for epoch 318 | start_time = time.time() 319 | for epoch in range(start_epoch, self.epoch): 320 | # get batch data 321 | for idx in range(start_batch_id, self.num_batches): 322 | random_index_A = np.random.choice(len(self.trainA), size=self.batch_size, replace=False) 323 | random_index_B = np.random.choice(len(self.trainB), size=self.batch_size, replace=False) 324 | batch_A_images = self.trainA[random_index_A] 325 | batch_B_images = self.trainB[random_index_B] 326 | p = np.random.uniform(low=0.0, high=1.0) 327 | 328 | 329 | train_feed_dict = { 330 | self.domain_A : batch_A_images, 331 | self.domain_B : batch_B_images, 332 | self.prob : p, 333 | self.is_training : True 334 | } 335 | 336 | # Update D 337 | _, d_loss, summary_str = self.sess.run([self.D_optim, self.Discriminator_loss, self.D_loss], feed_dict = train_feed_dict) 338 | self.writer.add_summary(summary_str, counter) 339 | 340 | # Update G 341 | fake_A, fake_B, _, g_loss, summary_str = self.sess.run([self.fake_A, self.fake_B, self.G_optim, self.Generator_loss, self.G_loss], feed_dict = train_feed_dict) 342 | self.writer.add_summary(summary_str, counter) 343 | 344 | # display training status 345 | counter += 1 346 | print("Epoch: [%2d] [%4d/%4d] time: %4.4f d_loss: %.8f, g_loss: %.8f" \ 347 | % (epoch, idx, self.num_batches, time.time() - start_time, d_loss, g_loss)) 348 | 349 | if np.mod(counter, 100) == 0 : 350 | save_images(batch_A_images, [self.batch_size, 1], 351 | './{}/real_A_{:02d}_{:04d}.jpg'.format(self.sample_dir, epoch, idx+2)) 352 | save_images(batch_B_images, [self.batch_size, 1], 353 | './{}/real_B_{:02d}_{:04d}.jpg'.format(self.sample_dir, epoch, idx+2)) 354 | 355 | save_images(fake_A, [self.batch_size, 1], 356 | './{}/fake_A_{:02d}_{:04d}.jpg'.format(self.sample_dir, epoch, idx+2)) 357 | save_images(fake_B, [self.batch_size, 1], 358 | './{}/fake_B_{:02d}_{:04d}.jpg'.format(self.sample_dir, epoch, idx+2)) 359 | 360 | # After an epoch, start_batch_id is set to zero 361 | # non-zero value is only for the first epoch after loading pre-trained model 362 | start_batch_id = 
0 363 | 364 | # save model 365 | self.save(self.checkpoint_dir, counter) 366 | 367 | # save model for final step 368 | self.save(self.checkpoint_dir, counter) 369 | 370 | 371 | @property 372 | def model_dir(self): 373 | return "{}_{}_{}".format( 374 | self.model_name, self.dataset_name, self.norm) 375 | 376 | def save(self, checkpoint_dir, step): 377 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name) 378 | 379 | if not os.path.exists(checkpoint_dir): 380 | os.makedirs(checkpoint_dir) 381 | 382 | self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name + '.model'), global_step=step) 383 | 384 | def load(self, checkpoint_dir): 385 | import re 386 | print(" [*] Reading checkpoints...") 387 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name) 388 | 389 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 390 | if ckpt and ckpt.model_checkpoint_path: 391 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path) 392 | self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name)) 393 | counter = int(next(re.finditer(r"(\d+)(?!.*\d)", ckpt_name)).group(0)) 394 | print(" [*] Successfully read {}".format(ckpt_name)) 395 | return True, counter 396 | else: 397 | print(" [*] Failed to find a checkpoint") 398 | return False, 0 399 | 400 | def test(self): 401 | tf.global_variables_initializer().run() 402 | test_A_files = glob('./dataset/{}/*.*'.format(self.dataset_name + '/testA')) 403 | test_B_files = glob('./dataset/{}/*.*'.format(self.dataset_name + '/testB')) 404 | 405 | """ 406 | testA, testB = test_data(dataset_name=self.dataset_name, size=self.img_size) 407 | test_A_images = testA[:] 408 | test_B_images = testB[:] 409 | """ 410 | self.saver = tf.train.Saver() 411 | could_load, checkpoint_counter = self.load(self.checkpoint_dir) 412 | 413 | if could_load : 414 | print(" [*] Load SUCCESS") 415 | else : 416 | print(" [!] Load failed...") 417 | 418 | # write html for visual comparison 419 | index_path = os.path.join(self.result_dir, 'index.html') 420 | index = open(index_path, 'w') 421 | index.write("
<html><body><table><tr>") 422 | index.write("<th>name</th><th>input</th><th>output</th></tr>") 423 | 424 | for sample_file in test_A_files : # A -> B 425 | print('Processing A image: ' + sample_file) 426 | sample_image = np.asarray(load_test_data(sample_file)) 427 | image_path = os.path.join(self.result_dir,'{0}'.format(os.path.basename(sample_file))) 428 | 429 | fake_img = self.sess.run(self.fake_B, feed_dict = {self.domain_A : sample_image, self.prob : 0.0, self.is_training : False}) 430 | 431 | save_images(fake_img, [1, 1], image_path) 432 | index.write("<td>%s</td>" % os.path.basename(image_path)) 433 | index.write("<td><img src='%s' width='%d' height='%d'></td>" % (sample_file if os.path.isabs(sample_file) else ( 434 | '..' + os.path.sep + sample_file), self.img_size, self.img_size)) 435 | index.write("<td><img src='%s' width='%d' height='%d'></td>" % (image_path if os.path.isabs(image_path) else ( 436 | '..' + os.path.sep + image_path), self.img_size, self.img_size)) 437 | index.write("</tr>") 438 | 439 | for sample_file in test_B_files : # B -> A 440 | print('Processing B image: ' + sample_file) 441 | sample_image = np.asarray(load_test_data(sample_file)) 442 | image_path = os.path.join(self.result_dir,'{0}'.format(os.path.basename(sample_file))) 443 | 444 | fake_img = self.sess.run(self.fake_A, feed_dict = {self.domain_B : sample_image, self.prob : 0.0, self.is_training : False}) 445 | 446 | save_images(fake_img, [1, 1], image_path) 447 | index.write("<td>%s</td>" % os.path.basename(image_path)) 448 | index.write("<td><img src='%s' width='%d' height='%d'></td>" % (sample_file if os.path.isabs(sample_file) else ( 449 | '..' + os.path.sep + sample_file), self.img_size, self.img_size)) 450 | index.write("<td><img src='%s' width='%d' height='%d'></td>" % (image_path if os.path.isabs(image_path) else ( 451 | '..' + os.path.sep + image_path), self.img_size, self.img_size)) 452 | index.write("</tr>") 453 | index.close() -------------------------------------------------------------------------------- /UNIT_multi_gpu.py: -------------------------------------------------------------------------------- 1 | from ops import * 2 | from utils import * 3 | from glob import glob 4 | import time 5 | 6 | class UNIT(object): 7 | def __init__(self, sess, args): 8 | self.model_name = 'UNIT' 9 | self.sess = sess 10 | self.checkpoint_dir = args.checkpoint_dir 11 | self.result_dir = args.result_dir 12 | self.log_dir = args.log_dir 13 | self.sample_dir = args.sample_dir 14 | self.dataset_name = args.dataset 15 | 16 | self.epoch = args.epoch # 100000 17 | self.batch_size_per_gpu = args.batch_size 18 | self.batch_size = args.batch_size * args.gpu_num 19 | self.gpu_num = args.gpu_num 20 | 21 | self.lr = args.lr # 0.0001 22 | """ Weight about VAE """ 23 | self.KL_weight = args.KL_weight # lambda 1 24 | self.L1_weight = args.L1_weight # lambda 2 25 | 26 | """ Weight about VAE Cycle""" 27 | self.KL_cycle_weight = args.KL_cycle_weight # lambda 3 28 | self.L1_cycle_weight = args.L1_cycle_weight # lambda 4 29 | 30 | """ Weight about GAN """ 31 | self.GAN_weight = args.GAN_weight # lambda 0 32 | 33 | 34 | """ Encoder """ 35 | self.ch = args.ch # base channel number per layer 36 | self.n_encoder = args.n_encoder 37 | self.n_enc_resblock = args.n_enc_resblock 38 | self.n_enc_share = args.n_enc_share 39 | 40 | """ Generator """ 41 | self.n_gen_share = args.n_gen_share 42 | self.n_gen_resblock = args.n_gen_resblock 43 | self.n_gen_decoder = args.n_gen_decoder 44 | 45 | """ Discriminator """ 46 | self.n_dis = args.n_dis # + 2 47 | 48 | self.res_dropout = args.res_dropout 49 | self.smoothing = args.smoothing 50 | self.lsgan = args.lsgan 51 | self.norm = args.norm 52 | self.replay_memory = args.replay_memory 53 | self.pool_size = args.pool_size 54 | self.img_size = args.img_size 55 | self.channel = 
args.img_ch 56 | self.augment_flag = args.augment_flag 57 | self.augment_size = self.img_size + (30 if self.img_size == 256 else 15) 58 | self.normal_weight_init = args.normal_weight_init 59 | 60 | self.trainA, self.trainB = prepare_data(dataset_name=self.dataset_name, size=self.img_size) 61 | self.num_batches = max(len(self.trainA) // self.batch_size, len(self.trainB) // self.batch_size) 62 | 63 | ############################################################################## 64 | # BEGIN of ENCODERS 65 | def encoder(self, x, is_training=True, reuse=False, scope="encoder"): 66 | channel = self.ch 67 | with tf.variable_scope(scope, reuse=reuse) : 68 | x = conv(x, channel, kernel=7, stride=1, pad=3, normal_weight_init=self.normal_weight_init, activation_fn='leaky', scope='conv_0') 69 | 70 | for i in range(1, self.n_encoder) : 71 | x = conv(x, channel*2, kernel=3, stride=2, pad=1, normal_weight_init=self.normal_weight_init, activation_fn='leaky', scope='conv_'+str(i)) 72 | channel *= 2 73 | 74 | # channel = 256 75 | for i in range(0, self.n_enc_resblock) : 76 | x = resblock(x, channel, kernel=3, stride=1, pad=1, dropout_ratio=self.res_dropout, 77 | normal_weight_init=self.normal_weight_init, 78 | is_training=is_training, norm_fn=self.norm, scope='resblock_'+str(i)) 79 | 80 | return x 81 | # END of ENCODERS 82 | ############################################################################## 83 | 84 | ############################################################################## 85 | # BEGIN of SHARED LAYERS 86 | # Shared residual-blocks 87 | def share_encoder(self, x, is_training=True, reuse=False, scope="share_encoder"): 88 | channel = self.ch * pow(2, self.n_encoder-1) 89 | with tf.variable_scope(scope, reuse=reuse) : 90 | for i in range(0, self.n_enc_share) : 91 | x = resblock(x, channel, kernel=3, stride=1, pad=1, dropout_ratio=self.res_dropout, 92 | normal_weight_init=self.normal_weight_init, 93 | is_training=is_training, norm_fn=self.norm, scope='resblock_'+str(i)) 94 | 95 | x = gaussian_noise_layer(x) 96 | 97 | return x 98 | 99 | def share_generator(self, x, is_training=True, reuse=False, scope="share_generator"): 100 | channel = self.ch * pow(2, self.n_encoder-1) 101 | with tf.variable_scope(scope, reuse=reuse) : 102 | for i in range(0, self.n_gen_share) : 103 | x = resblock(x, channel, kernel=3, stride=1, pad=1, dropout_ratio=self.res_dropout, 104 | normal_weight_init=self.normal_weight_init, 105 | is_training=is_training, norm_fn=self.norm, scope='resblock_'+str(i)) 106 | 107 | return x 108 | # END of SHARED LAYERS 109 | ############################################################################## 110 | 111 | ############################################################################## 112 | # BEGIN of DECODERS 113 | def generator(self, x, is_training=True, reuse=False, scope="generator"): 114 | channel = self.ch * pow(2, self.n_encoder - 1) 115 | with tf.variable_scope(scope, reuse=reuse) : 116 | for i in range(0, self.n_gen_resblock) : 117 | x = resblock(x, channel, kernel=3, stride=1, pad=1, dropout_ratio=self.res_dropout, 118 | normal_weight_init=self.normal_weight_init, 119 | is_training=is_training, norm_fn=self.norm, scope='resblock_'+str(i)) 120 | 121 | for i in range(0, self.n_gen_decoder-1) : 122 | x = deconv(x, channel//2, kernel=3, stride=2, normal_weight_init=self.normal_weight_init, activation_fn='leaky', scope='deconv_'+str(i)) 123 | channel = channel // 2 124 | 125 | x = deconv(x, self.channel, kernel=1, stride=1, normal_weight_init=self.normal_weight_init, 
activation_fn='tanh', scope='deconv_tanh') 126 | 127 | return x 128 | # END of DECODERS 129 | ############################################################################## 130 | 131 | ############################################################################## 132 | # BEGIN of DISCRIMINATORS 133 | def discriminator(self, x, reuse=False, scope="discriminator"): 134 | channel = self.ch 135 | with tf.variable_scope(scope, reuse=reuse): 136 | x = conv(x, channel, kernel=3, stride=2, pad=1, normal_weight_init=self.normal_weight_init, activation_fn='leaky', scope='conv_0') 137 | 138 | for i in range(1, self.n_dis) : 139 | x = conv(x, channel*2, kernel=3, stride=2, pad=1, normal_weight_init=self.normal_weight_init, activation_fn='leaky', scope='conv_'+str(i)) 140 | channel *= 2 141 | 142 | x = conv(x, channels=1, kernel=1, stride=1, pad=0, normal_weight_init=self.normal_weight_init, activation_fn=None, scope='dis_logit') 143 | 144 | return x 145 | # END of DISCRIMINATORS 146 | ############################################################################## 147 | 148 | def translation(self, x_A, x_B): 149 | out = tf.concat([self.encoder(x_A, self.is_training, scope="encoder_A"), self.encoder(x_B, self.is_training, scope="encoder_B")], axis=0) 150 | shared = self.share_encoder(out, self.is_training) 151 | out = self.share_generator(shared, self.is_training) 152 | 153 | out_A = self.generator(out, self.is_training, scope="generator_A") 154 | out_B = self.generator(out, self.is_training, scope="generator_B") 155 | 156 | x_Aa, x_Ba = tf.split(out_A, 2, axis=0) 157 | x_Ab, x_Bb = tf.split(out_B, 2, axis=0) 158 | 159 | return x_Aa, x_Ba, x_Ab, x_Bb, shared 160 | 161 | def generate_a2b(self, x_A): 162 | out = self.encoder(x_A, self.is_training, reuse=True, scope="encoder_A") 163 | shared = self.share_encoder(out, self.is_training, reuse=True) 164 | out = self.share_generator(shared, self.is_training, reuse=True) 165 | out = self.generator(out, self.is_training, reuse=True, scope="generator_B") 166 | 167 | return out, shared 168 | 169 | def generate_b2a(self, x_B): 170 | out = self.encoder(x_B, self.is_training, reuse=True, scope="encoder_B") 171 | shared = self.share_encoder(out, self.is_training, reuse=True) 172 | out = self.share_generator(shared, self.is_training, reuse=True) 173 | out = self.generator(out, self.is_training, reuse=True, scope="generator_A") 174 | 175 | return out, shared 176 | 177 | def discriminate_real(self, x_A, x_B): 178 | real_A_logit = self.discriminator(x_A, scope="discriminator_A") 179 | real_B_logit = self.discriminator(x_B, scope="discriminator_B") 180 | 181 | return real_A_logit, real_B_logit 182 | 183 | def discriminate_fake(self, x_ba, x_ab): 184 | fake_A_logit = self.discriminator(x_ba, reuse=True, scope="discriminator_A") 185 | fake_B_logit = self.discriminator(x_ab, reuse=True, scope="discriminator_B") 186 | 187 | return fake_A_logit, fake_B_logit 188 | 189 | def discriminate_fake_pool(self, x_ba, x_ab): 190 | fake_A_pool_logit = self.discriminator(self.fake_A_pool.query(x_ba), reuse=True, scope="discriminator_A") # replay memory 191 | fake_B_pool_logit = self.discriminator(self.fake_B_pool.query(x_ab), reuse=True, scope="discriminator_B") # replay memory 192 | 193 | return fake_A_pool_logit, fake_B_pool_logit 194 | 195 | def build_model(self): 196 | self.is_training = tf.placeholder(tf.bool) 197 | self.prob = tf.placeholder(tf.float32) 198 | self.condition = tf.logical_and(tf.greater(self.prob, tf.constant(0.5)), self.is_training) 199 | 200 | """ Input Image""" 201 | 
domain_A = self.domain_A = tf.placeholder(tf.float32, [self.batch_size, self.img_size, self.img_size, self.channel], name='domain_A') # real A 202 | domain_B = self.domain_B = tf.placeholder(tf.float32, [self.batch_size, self.img_size, self.img_size, self.channel], name='domain_B') # real B 203 | 204 | self.test_domain_A = tf.placeholder(tf.float32, [1, self.img_size, self.img_size, self.channel], name='test_domain_A') 205 | self.test_domain_B = tf.placeholder(tf.float32, [1, self.img_size, self.img_size, self.channel], name='test_domain_B') 206 | 207 | if self.augment_flag : 208 | """ Augmentation """ 209 | domain_A = tf.cond( 210 | self.condition, 211 | lambda : augmentation(domain_A, self.augment_size), 212 | lambda : domain_A 213 | ) 214 | 215 | domain_B = tf.cond( 216 | self.condition, 217 | lambda : augmentation(domain_B, self.augment_size), 218 | lambda : domain_B 219 | ) 220 | 221 | domain_A = tf.split(domain_A, self.gpu_num) 222 | domain_B = tf.split(domain_B, self.gpu_num) 223 | 224 | G_A_losses= [] 225 | G_B_losses = [] 226 | D_A_losses = [] 227 | D_B_losses = [] 228 | 229 | G_losses = [] 230 | D_losses = [] 231 | 232 | self.fake_A = [] 233 | self.fake_B = [] 234 | for gpu_id in range(self.gpu_num) : 235 | with tf.device(tf.DeviceSpec(device_type="GPU", device_index=gpu_id)) : 236 | with tf.variable_scope(tf.get_variable_scope(), reuse=(gpu_id > 0)) : 237 | """ Define Encoder, Generator, Discriminator """ 238 | x_aa, x_ba, x_ab, x_bb, shared = self.translation(domain_A[gpu_id], domain_B[gpu_id]) 239 | x_bab, shared_bab = self.generate_a2b(x_ba) 240 | x_aba, shared_aba = self.generate_b2a(x_ab) 241 | 242 | real_A_logit, real_B_logit = self.discriminate_real(domain_A[gpu_id], domain_B[gpu_id]) 243 | 244 | if self.replay_memory : 245 | self.fake_A_pool = ImagePool(self.pool_size) # pool of generated A 246 | self.fake_B_pool = ImagePool(self.pool_size) # pool of generated B 247 | fake_A_logit, fake_B_logit = self.discriminate_fake_pool(x_ba, x_ab) 248 | else : 249 | fake_A_logit, fake_B_logit = self.discriminate_fake(x_ba, x_ab) 250 | 251 | 252 | 253 | """ Define Loss """ 254 | G_ad_loss_a = generator_loss(fake_A_logit, smoothing=self.smoothing, use_lsgan=self.lsgan) 255 | G_ad_loss_b = generator_loss(fake_B_logit, smoothing=self.smoothing, use_lsgan=self.lsgan) 256 | 257 | D_ad_loss_a = discriminator_loss(real_A_logit, fake_A_logit, smoothing=self.smoothing, use_lasgan=self.lsgan) 258 | D_ad_loss_b = discriminator_loss(real_B_logit, fake_B_logit, smoothing=self.smoothing, use_lasgan=self.lsgan) 259 | 260 | enc_loss = KL_divergence(shared) 261 | enc_bab_loss = KL_divergence(shared_bab) 262 | enc_aba_loss = KL_divergence(shared_aba) 263 | 264 | l1_loss_a = L1_loss(x_aa, domain_A[gpu_id]) # identity 265 | l1_loss_b = L1_loss(x_bb, domain_B[gpu_id]) # identity 266 | l1_loss_aba = L1_loss(x_aba, domain_A[gpu_id]) # reconstruction 267 | l1_loss_bab = L1_loss(x_bab, domain_B[gpu_id]) # reconstruction 268 | 269 | Generator_A_loss_split = self.GAN_weight * G_ad_loss_a + \ 270 | self.L1_weight * l1_loss_a + \ 271 | self.L1_cycle_weight * l1_loss_aba + \ 272 | self.KL_weight * enc_loss + \ 273 | self.KL_cycle_weight * enc_bab_loss 274 | 275 | Generator_B_loss_split = self.GAN_weight * G_ad_loss_b + \ 276 | self.L1_weight * l1_loss_b + \ 277 | self.L1_cycle_weight * l1_loss_bab + \ 278 | self.KL_weight * enc_loss + \ 279 | self.KL_cycle_weight * enc_aba_loss 280 | 281 | Discriminator_A_loss_split = self.GAN_weight * D_ad_loss_a 282 | Discriminator_B_loss_split = self.GAN_weight * D_ad_loss_b 
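                    # per-tower losses are appended to the *_losses lists just below and averaged across GPUs with tf.reduce_mean after this device loop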
283 | 284 | Generator_loss_split = Generator_A_loss_split + Generator_B_loss_split 285 | Discriminator_loss_split = Discriminator_A_loss_split + Discriminator_B_loss_split 286 | 287 | """ Generated Image """ 288 | fake_B, _ = self.generate_a2b(domain_A[gpu_id]) # for test 289 | fake_A, _ = self.generate_b2a(domain_B[gpu_id]) # for test 290 | 291 | G_A_losses.append(Generator_A_loss_split) 292 | G_B_losses.append(Generator_B_loss_split) 293 | D_A_losses.append(Discriminator_A_loss_split) 294 | D_B_losses.append(Discriminator_B_loss_split) 295 | 296 | G_losses.append(Generator_loss_split) 297 | D_losses.append(Discriminator_loss_split) 298 | 299 | self.fake_A.append(fake_A) 300 | self.fake_B.append(fake_B) 301 | 302 | Generator_A_loss = tf.reduce_mean(G_A_losses) 303 | Generator_B_loss = tf.reduce_mean(G_B_losses) 304 | Discriminator_A_loss = tf.reduce_mean(D_A_losses) 305 | Discriminator_B_loss = tf.reduce_mean(D_B_losses) 306 | 307 | self.Generator_loss = tf.reduce_mean(G_losses) 308 | self.Discriminator_loss = tf.reduce_mean(D_losses) 309 | 310 | self.fake_A = tf.concat(self.fake_A, axis=0) 311 | self.fake_B = tf.concat(self.fake_B, axis=0) 312 | 313 | self.test_fake_B, _ = self.generate_a2b(self.test_domain_A) 314 | self.test_fake_A, _ = self.generate_b2a(self.test_domain_B) 315 | 316 | """ Training """ 317 | t_vars = tf.trainable_variables() 318 | G_vars = [var for var in t_vars if ('generator' in var.name) or ('encoder' in var.name)] 319 | D_vars = [var for var in t_vars if 'discriminator' in var.name] 320 | 321 | 322 | # with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): 323 | self.G_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.Generator_loss, colocate_gradients_with_ops=True, var_list=G_vars) 324 | self.D_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.Discriminator_loss, colocate_gradients_with_ops=True, var_list=D_vars) 325 | """ Summary """ 326 | self.all_G_loss = tf.summary.scalar("Generator_loss", self.Generator_loss) 327 | self.all_D_loss = tf.summary.scalar("Discriminator_loss", self.Discriminator_loss) 328 | self.G_A_loss = tf.summary.scalar("G_A_loss", Generator_A_loss) 329 | self.G_B_loss = tf.summary.scalar("G_B_loss", Generator_B_loss) 330 | self.D_A_loss = tf.summary.scalar("D_A_loss", Discriminator_A_loss) 331 | self.D_B_loss = tf.summary.scalar("D_B_loss", Discriminator_B_loss) 332 | 333 | self.G_loss = tf.summary.merge([self.G_A_loss, self.G_B_loss, self.all_G_loss]) 334 | self.D_loss = tf.summary.merge([self.D_A_loss, self.D_B_loss, self.all_D_loss]) 335 | 336 | 337 | def train(self): 338 | # initialize all variables 339 | tf.global_variables_initializer().run() 340 | 341 | # saver to save model 342 | self.saver = tf.train.Saver() 343 | 344 | # summary writer 345 | self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_name, self.sess.graph) 346 | 347 | 348 | # restore check-point if it exists 349 | could_load, checkpoint_counter = self.load(self.checkpoint_dir) 350 | if could_load: 351 | start_epoch = (int)(checkpoint_counter / self.num_batches) 352 | start_batch_id = checkpoint_counter - start_epoch * self.num_batches 353 | counter = checkpoint_counter 354 | print(" [*] Load SUCCESS") 355 | else: 356 | start_epoch = 0 357 | start_batch_id = 0 358 | counter = 1 359 | print(" [!] 
Load failed...") 360 | 361 | # loop for epoch 362 | start_time = time.time() 363 | for epoch in range(start_epoch, self.epoch): 364 | # get batch data 365 | for idx in range(start_batch_id, self.num_batches): 366 | random_index_A = np.random.choice(len(self.trainA), size=self.batch_size, replace=False) 367 | random_index_B = np.random.choice(len(self.trainB), size=self.batch_size, replace=False) 368 | batch_A_images = self.trainA[random_index_A] 369 | batch_B_images = self.trainB[random_index_B] 370 | p = np.random.uniform(low=0.0, high=1.0) 371 | 372 | 373 | train_feed_dict = { 374 | self.domain_A : batch_A_images, 375 | self.domain_B : batch_B_images, 376 | self.prob : p, 377 | self.is_training : True 378 | } 379 | 380 | # Update D 381 | _, d_loss, summary_str = self.sess.run([self.D_optim, self.Discriminator_loss, self.D_loss], feed_dict = train_feed_dict) 382 | self.writer.add_summary(summary_str, counter) 383 | 384 | # Update G 385 | fake_A, fake_B, _, g_loss, summary_str = self.sess.run([self.fake_A, self.fake_B, self.G_optim, self.Generator_loss, self.G_loss], feed_dict = train_feed_dict) 386 | self.writer.add_summary(summary_str, counter) 387 | 388 | # display training status 389 | counter += 1 390 | print("Epoch: [%2d] [%4d/%4d] time: %4.4f d_loss: %.8f, g_loss: %.8f" \ 391 | % (epoch, idx, self.num_batches, time.time() - start_time, d_loss, g_loss)) 392 | 393 | if np.mod(counter, 10) == 0 : 394 | batch_A_images = np.split(batch_A_images, self.gpu_num) 395 | batch_B_images = np.split(batch_B_images, self.gpu_num) 396 | fake_A = np.split(fake_A, self.gpu_num) 397 | fake_B = np.split(fake_B, self.gpu_num) 398 | 399 | for gpu_id in range(self.gpu_num) : 400 | save_images(batch_A_images[gpu_id], [self.batch_size_per_gpu, 1], 401 | './{}/real_A_{}_{:02d}_{:04d}.jpg'.format(self.sample_dir, gpu_id, epoch, idx+2)) 402 | save_images(batch_B_images[gpu_id], [self.batch_size_per_gpu, 1], 403 | './{}/real_B_{}_{:02d}_{:04d}.jpg'.format(self.sample_dir, gpu_id, epoch, idx+2)) 404 | 405 | save_images(fake_A[gpu_id], [self.batch_size_per_gpu, 1], 406 | './{}/fake_A_{}_{:02d}_{:04d}.jpg'.format(self.sample_dir, gpu_id, epoch, idx+2)) 407 | save_images(fake_B[gpu_id], [self.batch_size_per_gpu, 1], 408 | './{}/fake_B_{}_{:02d}_{:04d}.jpg'.format(self.sample_dir, gpu_id, epoch, idx+2)) 409 | 410 | # After an epoch, start_batch_id is set to zero 411 | # non-zero value is only for the first epoch after loading pre-trained model 412 | start_batch_id = 0 413 | 414 | # save model 415 | self.save(self.checkpoint_dir, counter) 416 | 417 | # save model for final step 418 | self.save(self.checkpoint_dir, counter) 419 | 420 | 421 | @property 422 | def model_dir(self): 423 | return "{}_{}_{}".format( 424 | self.model_name, self.dataset_name, self.norm) 425 | 426 | def save(self, checkpoint_dir, step): 427 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name) 428 | 429 | if not os.path.exists(checkpoint_dir): 430 | os.makedirs(checkpoint_dir) 431 | 432 | self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name + '.model'), global_step=step) 433 | 434 | def load(self, checkpoint_dir): 435 | import re 436 | print(" [*] Reading checkpoints...") 437 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name) 438 | 439 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 440 | if ckpt and ckpt.model_checkpoint_path: 441 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path) 442 | self.saver.restore(self.sess, os.path.join(checkpoint_dir, 
ckpt_name)) 443 | counter = int(next(re.finditer(r"(\d+)(?!.*\d)", ckpt_name)).group(0)) 444 | print(" [*] Successfully read {}".format(ckpt_name)) 445 | return True, counter 446 | else: 447 | print(" [*] Failed to find a checkpoint") 448 | return False, 0 449 | 450 | def test(self): 451 | tf.global_variables_initializer().run() 452 | test_A_files = glob('./dataset/{}/*.*'.format(self.dataset_name + '/testA')) 453 | test_B_files = glob('./dataset/{}/*.*'.format(self.dataset_name + '/testB')) 454 | 455 | """ 456 | testA, testB = test_data(dataset_name=self.dataset_name, size=self.img_size) 457 | test_A_images = testA[:] 458 | test_B_images = testB[:] 459 | """ 460 | self.saver = tf.train.Saver() 461 | could_load, checkpoint_counter = self.load(self.checkpoint_dir) 462 | 463 | if could_load : 464 | print(" [*] Load SUCCESS") 465 | else : 466 | print(" [!] Load failed...") 467 | 468 | # write html for visual comparison 469 | index_path = os.path.join(self.result_dir, 'index.html') 470 | index = open(index_path, 'w') 471 | index.write("
<html><body><table><tr>") 472 | index.write("<th>name</th><th>input</th><th>output</th></tr>") 473 | 474 | for sample_file in test_A_files : # A -> B 475 | print('Processing A image: ' + sample_file) 476 | sample_image = np.asarray(load_test_data(sample_file)) 477 | image_path = os.path.join(self.result_dir,'{0}'.format(os.path.basename(sample_file))) 478 | 479 | fake_img = self.sess.run(self.test_fake_B, feed_dict = {self.test_domain_A : sample_image, self.is_training : False}) 480 | 481 | save_images(fake_img, [1, 1], image_path) 482 | index.write("<td>%s</td>" % os.path.basename(image_path)) 483 | index.write("<td><img src='%s' width='%d' height='%d'></td>" % (sample_file if os.path.isabs(sample_file) else ( 484 | '..' + os.path.sep + sample_file), self.img_size, self.img_size)) 485 | index.write("<td><img src='%s' width='%d' height='%d'></td>" % (image_path if os.path.isabs(image_path) else ( 486 | '..' + os.path.sep + image_path), self.img_size, self.img_size)) 487 | index.write("</tr>") 488 | 489 | for sample_file in test_B_files : # B -> A 490 | print('Processing B image: ' + sample_file) 491 | sample_image = np.asarray(load_test_data(sample_file)) 492 | image_path = os.path.join(self.result_dir,'{0}'.format(os.path.basename(sample_file))) 493 | 494 | fake_img = self.sess.run(self.test_fake_A, feed_dict = {self.test_domain_B : sample_image, self.is_training : False}) 495 | 496 | save_images(fake_img, [1, 1], image_path) 497 | index.write("<td>%s</td>" % os.path.basename(image_path)) 498 | index.write("<td><img src='%s' width='%d' height='%d'></td>" % (sample_file if os.path.isabs(sample_file) else ( 499 | '..' + os.path.sep + sample_file), self.img_size, self.img_size)) 500 | index.write("<td><img src='%s' width='%d' height='%d'></td>" % (image_path if os.path.isabs(image_path) else ( 501 | '..' + os.path.sep + image_path), self.img_size, self.img_size)) 502 | index.write("</tr>") 503 | index.close() -------------------------------------------------------------------------------- /assests/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/UNIT-Tensorflow/4d7430a6f0bd3bea72d821e14db6e6442c02ed32/assests/.DS_Store -------------------------------------------------------------------------------- /assests/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/UNIT-Tensorflow/4d7430a6f0bd3bea72d821e14db6e6442c02ed32/assests/architecture.png -------------------------------------------------------------------------------- /assests/cat_species.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/UNIT-Tensorflow/4d7430a6f0bd3bea72d821e14db6e6442c02ed32/assests/cat_species.gif -------------------------------------------------------------------------------- /assests/cat_trans.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/UNIT-Tensorflow/4d7430a6f0bd3bea72d821e14db6e6442c02ed32/assests/cat_trans.png -------------------------------------------------------------------------------- /assests/compare.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/UNIT-Tensorflow/4d7430a6f0bd3bea72d821e14db6e6442c02ed32/assests/compare.png -------------------------------------------------------------------------------- /assests/cycle.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/UNIT-Tensorflow/4d7430a6f0bd3bea72d821e14db6e6442c02ed32/assests/cycle.png
-------------------------------------------------------------------------------- /assests/dog_breed.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/UNIT-Tensorflow/4d7430a6f0bd3bea72d821e14db6e6442c02ed32/assests/dog_breed.gif -------------------------------------------------------------------------------- /assests/dog_trans.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/UNIT-Tensorflow/4d7430a6f0bd3bea72d821e14db6e6442c02ed32/assests/dog_trans.png -------------------------------------------------------------------------------- /assests/faces.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/UNIT-Tensorflow/4d7430a6f0bd3bea72d821e14db6e6442c02ed32/assests/faces.png -------------------------------------------------------------------------------- /assests/fail.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/UNIT-Tensorflow/4d7430a6f0bd3bea72d821e14db6e6442c02ed32/assests/fail.png -------------------------------------------------------------------------------- /assests/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/UNIT-Tensorflow/4d7430a6f0bd3bea72d821e14db6e6442c02ed32/assests/framework.png -------------------------------------------------------------------------------- /assests/gan_model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/UNIT-Tensorflow/4d7430a6f0bd3bea72d821e14db6e6442c02ed32/assests/gan_model.png -------------------------------------------------------------------------------- /assests/slide/compare.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/UNIT-Tensorflow/4d7430a6f0bd3bea72d821e14db6e6442c02ed32/assests/slide/compare.png -------------------------------------------------------------------------------- /assests/slide/cycle.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/UNIT-Tensorflow/4d7430a6f0bd3bea72d821e14db6e6442c02ed32/assests/slide/cycle.png -------------------------------------------------------------------------------- /assests/slide/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/UNIT-Tensorflow/4d7430a6f0bd3bea72d821e14db6e6442c02ed32/assests/slide/framework.png -------------------------------------------------------------------------------- /assests/slide/gan_model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/UNIT-Tensorflow/4d7430a6f0bd3bea72d821e14db6e6442c02ed32/assests/slide/gan_model.png -------------------------------------------------------------------------------- /assests/slide/training_objective.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/UNIT-Tensorflow/4d7430a6f0bd3bea72d821e14db6e6442c02ed32/assests/slide/training_objective.png -------------------------------------------------------------------------------- /assests/slide/vae_model.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/UNIT-Tensorflow/4d7430a6f0bd3bea72d821e14db6e6442c02ed32/assests/slide/vae_model.png -------------------------------------------------------------------------------- /assests/success.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/UNIT-Tensorflow/4d7430a6f0bd3bea72d821e14db6e6442c02ed32/assests/success.png -------------------------------------------------------------------------------- /assests/training_objective__.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/UNIT-Tensorflow/4d7430a6f0bd3bea72d821e14db6e6442c02ed32/assests/training_objective__.png -------------------------------------------------------------------------------- /assests/vae_model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/UNIT-Tensorflow/4d7430a6f0bd3bea72d821e14db6e6442c02ed32/assests/vae_model.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from UNIT import UNIT 2 | import argparse 3 | from utils import * 4 | 5 | """parsing and configuration""" 6 | def parse_args(): 7 | desc = "Tensorflow implementation of UNIT" 8 | parser = argparse.ArgumentParser(description=desc) 9 | parser.add_argument('--phase', type=str, default='train', help='train or test?') 10 | parser.add_argument('--dataset', type=str, default='cat2dog', help='dataset_name') 11 | 12 | parser.add_argument('--epoch', type=int, default=200, help='The number of epochs to run') 13 | parser.add_argument('--batch_size', type=int, default=1, help='The size of batch') 14 | 15 | parser.add_argument('--lr', type=float, default=0.0001, help='The learning rate') 16 | parser.add_argument('--GAN_weight', type=float, default=10.0, help='Weight about GAN, lambda0') 17 | parser.add_argument('--KL_weight', type=float, default=0.1, help='Weight about VAE, lambda1') 18 | parser.add_argument('--L1_weight', type=float, default=100.0, help='Weight about VAE, lambda2' ) 19 | parser.add_argument('--KL_cycle_weight', type=float, default=0.1, help='Weight about VAE Cycle, lambda3') 20 | parser.add_argument('--L1_cycle_weight', type=float, default=100.0, help='Weight about VAE Cycle, lambda4') 21 | 22 | parser.add_argument('--ch', type=int, default=64, help='base channel number per layer') 23 | parser.add_argument('--n_encoder', type=int, default=3, help='The number of encoder') 24 | parser.add_argument('--n_enc_resblock', type=int, default=3, help='The number of encoder_resblock') 25 | parser.add_argument('--n_enc_share', type=int, default=1, help='The number of share_encoder') 26 | parser.add_argument('--n_gen_share', type=int, default=1, help='The number of share_generator') 27 | parser.add_argument('--n_gen_resblock', type=int, default=3, help='The number of generator_resblock') 28 | parser.add_argument('--n_gen_decoder', type=int, default=3, help='The number of generator_decoder') 29 | parser.add_argument('--n_dis', type=int, default=6, help='The number of discriminator layers') 30 | 31 | parser.add_argument('--res_dropout', type=float, default=0.0, help='The dropout ratio of Resblock') 32 | parser.add_argument('--smoothing', type=bool, default=False, help='smoothing loss use or not')
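    # NOTE: argparse's type=bool calls bool() on the raw argument string, so any non-empty value (even 'False') parses as True; keep these boolean flags at their defaults or pass an empty string to get False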
33 | parser.add_argument('--lsgan', type=bool, default=False, help='lsgan loss use or not') 34 | parser.add_argument('--norm', type=str, default='instance', help='The norm type') 35 | parser.add_argument('--replay_memory', type=bool, default=False, help='discriminator pool use or not') 36 | parser.add_argument('--pool_size', type=int, default=50, help='The size of image buffer that stores previously generated images') 37 | parser.add_argument('--img_size', type=int, default=256, help='The size of image') 38 | parser.add_argument('--img_ch', type=int, default=3, help='The number of image channels') 39 | parser.add_argument('--augment_flag', type=bool, default=True, help='Image augmentation use or not') 40 | parser.add_argument('--normal_weight_init', type=bool, default=True, help='normal initialization use or not') 41 | 42 | parser.add_argument('--checkpoint_dir', type=str, default='checkpoint', 43 | help='Directory name to save the checkpoints') 44 | parser.add_argument('--result_dir', type=str, default='results', 45 | help='Directory name to save the generated images') 46 | parser.add_argument('--log_dir', type=str, default='logs', 47 | help='Directory name to save training logs') 48 | parser.add_argument('--sample_dir', type=str, default='samples', 49 | help='Directory name to save the samples on training') 50 | 51 | return check_args(parser.parse_args()) 52 | 53 | """checking arguments""" 54 | def check_args(args): 55 | # --checkpoint_dir 56 | check_folder(args.checkpoint_dir) 57 | 58 | # --result_dir 59 | check_folder(args.result_dir) 60 | 61 | # --log_dir 62 | check_folder(args.log_dir) 63 | 64 | # --sample_dir 65 | check_folder(args.sample_dir) 66 | 67 | # --epoch 68 | try: 69 | assert args.epoch >= 1 70 | except: 71 | print('number of epochs must be larger than or equal to one') 72 | 73 | # --batch_size 74 | try: 75 | assert args.batch_size >= 1 76 | except: 77 | print('batch size must be larger than or equal to one') 78 | return args 79 | 80 | """main""" 81 | def main(): 82 | # parse arguments 83 | args = parse_args() 84 | if args is None: 85 | exit() 86 | 87 | # open session 88 | with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess: 89 | gan = UNIT(sess, args) 90 | 91 | # build graph 92 | gan.build_model() 93 | 94 | # show network architecture 95 | show_all_variables() 96 | 97 | if args.phase == 'train' : 98 | # launch the graph in a session 99 | gan.train() 100 | print(" [*] Training finished!") 101 | 102 | if args.phase == 'test' : 103 | gan.test() 104 | print(" [*] Test finished!") 105 | 106 | if __name__ == '__main__': 107 | main() -------------------------------------------------------------------------------- /main_multi_gpu.py: -------------------------------------------------------------------------------- 1 | from UNIT_multi_gpu import UNIT 2 | import argparse 3 | from utils import * 4 | 5 | """parsing and configuration""" 6 | def parse_args(): 7 | desc = "Tensorflow implementation of UNIT" 8 | parser = argparse.ArgumentParser(description=desc) 9 | parser.add_argument('--phase', type=str, default='train', help='train or test?') 10 | parser.add_argument('--dataset', type=str, default='cat2dog', help='dataset_name') 11 | 12 | parser.add_argument('--epoch', type=int, default=200, help='The number of epochs to run') 13 | parser.add_argument('--batch_size', type=int, default=1, help='The size of batch per gpu') 14 | parser.add_argument('--gpu_num', type=int, default=8, help='The number of GPUs') 15 | 16 | parser.add_argument('--lr', type=float, 
/ops.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import tensorflow.contrib as tf_contrib
3 | from tensorflow.contrib.layers import variance_scaling_initializer as he_init
4 | 
5 | def conv(x, channels, kernel=3, stride=2, pad=0, normal_weight_init=False, activation_fn='leaky', scope='conv_0') :
6 |     with tf.variable_scope(scope) :
7 |         x = tf.pad(x, [[0,0], [pad, pad], [pad, pad], [0,0]])  # explicit zero-padding; conv2d below runs in VALID mode
8 | 
9 |         if normal_weight_init :
10 |             x = tf.layers.conv2d(inputs=x, filters=channels, kernel_size=kernel, kernel_initializer=tf.truncated_normal_initializer(stddev=0.02),
11 |                                  strides=stride, kernel_regularizer=tf_contrib.layers.l2_regularizer(scale=0.0001))
12 | 
13 |         else :
14 |             if activation_fn == 'relu' :
15 |                 x = tf.layers.conv2d(inputs=x, filters=channels, kernel_size=kernel, kernel_initializer=he_init(), strides=stride,
16 |                                      kernel_regularizer=tf_contrib.layers.l2_regularizer(scale=0.0001))
17 |             else :
18 |                 x = tf.layers.conv2d(inputs=x, filters=channels, kernel_size=kernel, strides=stride,
19 |                                      kernel_regularizer=tf_contrib.layers.l2_regularizer(scale=0.0001))
20 | 
21 | 
22 |         x = activation(x, activation_fn)
23 | 
24 |         return x
25 | 
26 | def deconv(x, channels, kernel=3, stride=2, normal_weight_init=False, activation_fn='leaky', scope='deconv_0') :
27 |     with tf.variable_scope(scope):
28 |         if normal_weight_init:
29 |             x = tf.layers.conv2d_transpose(inputs=x, filters=channels, kernel_size=kernel,
30 |                                            kernel_initializer=tf.truncated_normal_initializer(stddev=0.02),
31 |                                            strides=stride, padding='SAME', kernel_regularizer=tf_contrib.layers.l2_regularizer(scale=0.0001))
32 | 
33 |         else:
34 |             if activation_fn == 'relu' :
35 |                 x = tf.layers.conv2d_transpose(inputs=x, filters=channels, kernel_size=kernel, kernel_initializer=he_init(), strides=stride, padding='SAME',
36 |                                                kernel_regularizer=tf_contrib.layers.l2_regularizer(scale=0.0001))
37 |             else :
38 |                 x = tf.layers.conv2d_transpose(inputs=x, filters=channels, kernel_size=kernel, strides=stride, padding='SAME',
39 |                                                kernel_regularizer=tf_contrib.layers.l2_regularizer(scale=0.0001))
40 | 
41 |         x = activation(x, activation_fn)
42 | 
43 |         return x
44 | 
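# Sanity check (illustrative, not part of the original file): conv() pads
# manually with tf.pad and then applies a VALID convolution (tf.layers.conv2d
# defaults to padding='valid'), so the output size follows
# out = (in + 2*pad - kernel) // stride + 1.
def conv_out_size(size, kernel, stride, pad):
    return (size + 2 * pad - kernel) // stride + 1

assert conv_out_size(256, kernel=7, stride=1, pad=3) == 256  # size-preserving
assert conv_out_size(256, kernel=3, stride=2, pad=1) == 128  # halves the resolution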
45 | def resblock(x_init, channels, kernel=3, stride=1, pad=1, dropout_ratio=0.0, normal_weight_init=False, is_training=True, norm_fn='instance', scope='resblock_0') :
46 |     assert norm_fn in ['instance', 'batch', 'weight', 'spectral', None]
47 |     with tf.variable_scope(scope) :
48 |         with tf.variable_scope('res1') :
49 |             x = tf.pad(x_init, [[0, 0], [pad, pad], [pad, pad], [0, 0]])
50 | 
51 |             if normal_weight_init :
52 |                 x = tf.layers.conv2d(inputs=x, filters=channels, kernel_size=kernel,
53 |                                      kernel_initializer=tf.truncated_normal_initializer(stddev=0.02),
54 |                                      strides=stride, kernel_regularizer=tf_contrib.layers.l2_regularizer(scale=0.0001))
55 |             else :
56 |                 x = tf.layers.conv2d(inputs=x, filters=channels, kernel_size=kernel, kernel_initializer=he_init(),
57 |                                      strides=stride, kernel_regularizer=tf_contrib.layers.l2_regularizer(scale=0.0001))
58 | 
59 |             if norm_fn == 'instance' :
60 |                 x = instance_norm(x, 'res1_instance')
61 |             if norm_fn == 'batch' :
62 |                 x = batch_norm(x, is_training, 'res1_batch')
63 | 
64 |             x = relu(x)
65 |         with tf.variable_scope('res2') :
66 |             x = tf.pad(x, [[0, 0], [pad, pad], [pad, pad], [0, 0]])
67 | 
68 |             if normal_weight_init :
69 |                 x = tf.layers.conv2d(inputs=x, filters=channels, kernel_size=kernel,
70 |                                      kernel_initializer=tf.truncated_normal_initializer(stddev=0.02),
71 |                                      strides=stride, kernel_regularizer=tf_contrib.layers.l2_regularizer(scale=0.0001))
72 |             else :
73 |                 x = tf.layers.conv2d(inputs=x, filters=channels, kernel_size=kernel, kernel_initializer=he_init(),
74 |                                      strides=stride, kernel_regularizer=tf_contrib.layers.l2_regularizer(scale=0.0001))
75 | 
76 |             if norm_fn == 'instance' :
77 |                 x = instance_norm(x, 'res2_instance')
78 |             if norm_fn == 'batch' :
79 |                 x = batch_norm(x, is_training, 'res2_batch')
80 | 
81 |             if dropout_ratio > 0.0 :
82 |                 x = tf.layers.dropout(x, rate=dropout_ratio, training=is_training)
83 | 
84 |         return x + x_init  # identity skip: x_init must already have `channels` channels
85 | 
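# Sanity check (illustrative, not part of the original file): resblock() adds
# the block input back with no projection on the skip path (return x + x_init),
# so it is only valid when x_init already has `channels` channels; with the
# kernel=3, stride=1, pad=1 defaults the two convs also keep H and W unchanged.
def resblock_out_shape(n, h, w, c_in, channels, kernel=3, stride=1, pad=1):
    assert c_in == channels, 'no 1x1 projection on the skip path'
    h_out = (h + 2 * pad - kernel) // stride + 1  # == h for the defaults
    w_out = (w + 2 * pad - kernel) // stride + 1  # == w for the defaults
    return (n, h_out, w_out, channels)

assert resblock_out_shape(1, 64, 64, 256, 256) == (1, 64, 64, 256)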
86 | def activation(x, activation_fn='leaky') :
87 |     assert activation_fn in ['relu', 'leaky', 'tanh', 'sigmoid', 'swish', None]
88 |     if activation_fn == 'leaky':
89 |         x = lrelu(x)
90 | 
91 |     if activation_fn == 'relu':
92 |         x = relu(x)
93 | 
94 |     if activation_fn == 'sigmoid':
95 |         x = sigmoid(x)
96 | 
97 |     if activation_fn == 'tanh' :
98 |         x = tanh(x)
99 | 
100 |     if activation_fn == 'swish' :
101 |         x = swish(x)
102 | 
103 |     return x
104 | 
105 | def lrelu(x, alpha=0.01) :
106 |     # pytorch alpha is 0.01
107 |     return tf.nn.leaky_relu(x, alpha)
108 | 
109 | def relu(x) :
110 |     return tf.nn.relu(x)
111 | 
112 | def sigmoid(x) :
113 |     return tf.sigmoid(x)
114 | 
115 | def tanh(x) :
116 |     return tf.tanh(x)
117 | 
118 | def swish(x) :
119 |     return x * sigmoid(x)
120 | 
121 | def batch_norm(x, is_training=False, scope='batch_norm') :
122 |     return tf_contrib.layers.batch_norm(x,
123 |                                         decay=0.9, epsilon=1e-05,
124 |                                         center=True, scale=True, updates_collections=None,
125 |                                         is_training=is_training, scope=scope)
126 | 
127 | def instance_norm(x, scope='instance') :
128 |     return tf_contrib.layers.instance_norm(x,
129 |                                            epsilon=1e-05,
130 |                                            center=True, scale=True,
131 |                                            scope=scope)
132 | 
133 | def gaussian_noise_layer(mu):
134 |     sigma = 1.0  # the latent code is modeled as N(mu, I): fixed unit variance, reparameterized below
135 |     gaussian_random_vector = tf.random_normal(shape=tf.shape(mu), mean=0.0, stddev=1.0, dtype=tf.float32)
136 |     return mu + sigma * gaussian_random_vector
137 | 
138 | def KL_divergence(mu) :
139 |     # KL_divergence = 0.5 * tf.reduce_sum(tf.square(mu) + tf.square(sigma) - tf.log(1e-8 + tf.square(sigma)) - 1, axis = -1)
140 |     # loss = tf.reduce_mean(KL_divergence)
141 |     mu_2 = tf.square(mu)
142 |     loss = tf.reduce_mean(mu_2)  # with sigma fixed to 1, the KL above reduces to mean(mu^2) up to constants
143 | 
144 |     return loss
145 | 
146 | def L1_loss(x, y) :
147 |     loss = tf.reduce_mean(tf.abs(x - y))
148 |     return loss
149 | 
150 | def discriminator_loss(real, fake, smoothing=False, use_lasgan=False) :  # 'use_lasgan' (sic) kept as-is for existing callers; it toggles the LSGAN loss
151 |     if use_lasgan :
152 |         if smoothing :
153 |             real_loss = tf.reduce_mean(tf.squared_difference(real, 0.9)) * 0.5
154 |         else :
155 |             real_loss = tf.reduce_mean(tf.squared_difference(real, 1.0)) * 0.5
156 | 
157 |         fake_loss = tf.reduce_mean(tf.square(fake)) * 0.5
158 |     else :
159 |         if smoothing :
160 |             real_labels = tf.fill(tf.shape(real), 0.9)
161 |         else :
162 |             real_labels = tf.ones_like(real)
163 | 
164 |         fake_labels = tf.zeros_like(fake)
165 | 
166 |         real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=real_labels, logits=real))
167 |         fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=fake_labels, logits=fake))
168 | 
169 |     loss = real_loss + fake_loss
170 | 
171 |     return loss
172 | 
173 | def generator_loss(fake, smoothing=False, use_lsgan=False) :
174 |     if use_lsgan :
175 |         if smoothing :
176 |             loss = tf.reduce_mean(tf.squared_difference(fake, 0.9)) * 0.5
177 |         else :
178 |             loss = tf.reduce_mean(tf.squared_difference(fake, 1.0)) * 0.5
179 |     else :
180 |         if smoothing :
181 |             fake_labels = tf.fill(tf.shape(fake), 0.9)
182 |         else :
183 |             fake_labels = tf.ones_like(fake)
184 | 
185 |         loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=fake_labels, logits=fake))
186 | 
187 |     return loss
188 | 
189 | 
--------------------------------------------------------------------------------
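gaussian_noise_layer fixes sigma = 1.0, i.e. the latent code is treated as N(mu, I) with the variance not learned. Under that assumption the full KL term quoted in the comment inside KL_divergence collapses to a function of mu alone; the constant 0.5 and the sum-vs-mean choice are absorbed into the KL_weight / KL_cycle_weight lambdas. A numpy verification of that reduction:

import numpy as np

mu = np.random.randn(4, 16).astype(np.float32)
sigma = np.ones_like(mu)  # gaussian_noise_layer hard-codes unit variance

# full KL( N(mu, sigma^2) || N(0, I) ), as in the commented-out formula
full_kl = 0.5 * np.sum(np.square(mu) + np.square(sigma) - np.log(1e-8 + np.square(sigma)) - 1, axis=-1)
reduced = 0.5 * np.sum(np.square(mu), axis=-1)  # what remains when sigma == 1
print(np.allclose(full_kl, reduced, atol=1e-4))  # True, up to the 1e-8 epsilon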
/utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.contrib import slim
3 | from scipy import misc
4 | import os, random
5 | import numpy as np
6 | 
7 | 
8 | class ImagePool:
9 |     """ History of generated images
10 |         Same logic as https://github.com/junyanz/CycleGAN/blob/master/util/image_pool.lua
11 |     """
12 | 
13 |     def __init__(self, pool_size):
14 |         self.pool_size = pool_size
15 |         self.images = []
16 | 
17 |     def query(self, image):
18 |         if self.pool_size == 0:
19 |             return image
20 | 
21 |         if len(self.images) < self.pool_size:
22 |             self.images.append(image)
23 |             return image
24 |         else:
25 |             p = random.random()
26 |             if p > 0.5:
27 |                 # use old image
28 |                 random_id = random.randrange(0, self.pool_size)
29 |                 tmp = self.images[random_id].copy()
30 |                 self.images[random_id] = image.copy()
31 |                 return tmp
32 |             else:
33 |                 return image
34 | 
35 | 
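# A minimal demonstration of ImagePool's behavior (illustrative, not part of
# the original file), using numpy arrays as stand-ins for generated images;
# the --replay_memory flag controls whether the pool is used at all. Guarded
# so it only runs when executed directly, never on import:
if __name__ == '__main__':
    random.seed(0)  # make the 50/50 swap in query() reproducible
    _pool = ImagePool(pool_size=2)
    for _step in range(5):
        _fake = np.full((1, 4, 4, 3), float(_step), dtype=np.float32)
        _out = _pool.query(_fake)
        # the first pool_size calls fill the pool and return the new image;
        # after that, each call returns either the new image or, with p = 0.5,
        # a random older one (which is then replaced by the new image)
        print(_step, _out[0, 0, 0, 0])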
36 | def prepare_data(dataset_name, size):
37 |     data_path = os.path.join("./dataset", dataset_name)
38 | 
39 |     trainA = []
40 |     trainB = []
41 |     for path, dirs, files in os.walk(data_path):
42 |         for file in files:
43 |             image = os.path.join(path, file)
44 |             if 'trainA' in path :
45 |                 trainA.append(misc.imresize(misc.imread(image, mode='RGB'), [size, size]))
46 |             if 'trainB' in path :
47 |                 trainB.append(misc.imresize(misc.imread(image, mode='RGB'), [size, size]))
48 | 
49 | 
50 |     trainA = preprocessing(np.asarray(trainA))
51 |     trainB = preprocessing(np.asarray(trainB))
52 | 
53 |     np.random.shuffle(trainA)
54 |     np.random.shuffle(trainB)
55 | 
56 |     return trainA, trainB
57 | 
58 | def test_data(dataset_name, size) :
59 |     data_path = os.path.join("./dataset", dataset_name)
60 |     testA = []
61 |     testB = []
62 |     for path, dirs, files in os.walk(data_path) :
63 |         for file in files :
64 |             image = os.path.join(path, file)
65 |             if 'testA' in path :
66 |                 testA.append(misc.imresize(misc.imread(image, mode='RGB'), [size, size]))
67 |             if 'testB' in path :
68 |                 testB.append(misc.imresize(misc.imread(image, mode='RGB'), [size, size]))
69 | 
70 |     testA = preprocessing(np.asarray(testA))
71 |     testB = preprocessing(np.asarray(testB))
72 | 
73 |     return testA, testB
74 | 
75 | def load_test_data(image_path, size=256):
76 |     img = misc.imread(image_path, mode='RGB')
77 |     img = misc.imresize(img, [size, size])
78 |     img = np.expand_dims(img, axis=0)
79 |     img = preprocessing(img)
80 | 
81 |     return img
82 | 
83 | def preprocessing(x):
84 |     """
85 |     # Create Normal distribution
86 |     x = x.astype('float32')
87 |     x[:, :, :, 0] = (x[:, :, :, 0] - np.mean(x[:, :, :, 0])) / np.std(x[:, :, :, 0])
88 |     x[:, :, :, 1] = (x[:, :, :, 1] - np.mean(x[:, :, :, 1])) / np.std(x[:, :, :, 1])
89 |     x[:, :, :, 2] = (x[:, :, :, 2] - np.mean(x[:, :, :, 2])) / np.std(x[:, :, :, 2])
90 |     """
91 |     x = x/127.5 - 1  # scale pixel values from [0, 255] to [-1, 1]
92 |     return x
93 | 
94 | def augmentation(image, augment_size):
95 |     seed = random.randint(0, 2 ** 31 - 1)
96 |     ori_image_shape = tf.shape(image)
97 |     image = tf.image.resize_images(image, [augment_size, augment_size])
98 |     image = tf.random_crop(image, ori_image_shape, seed=seed)
99 |     image = tf.map_fn(lambda x: tf.image.random_flip_left_right(x, seed), image)
100 |     return image
101 | 
102 | def save_images(images, size, image_path):
103 |     return imsave(inverse_transform(images), size, image_path)
104 | 
105 | def inverse_transform(images):
106 |     return (images+1.) / 2
107 | 
108 | def imsave(images, size, path):
109 |     return misc.imsave(path, merge(images, size))
110 | 
111 | def merge(images, size):
112 |     h, w = images.shape[1], images.shape[2]
113 |     img = np.zeros((h * size[0], w * size[1], 3))
114 |     for idx, image in enumerate(images):
115 |         i = idx % size[1]
116 |         j = idx // size[1]
117 |         img[h*j:h*(j+1), w*i:w*(i+1), :] = image
118 | 
119 |     return img
120 | 
121 | def show_all_variables():
122 |     model_vars = tf.trainable_variables()
123 |     slim.model_analyzer.analyze_vars(model_vars, print_info=True)
124 | 
125 | def check_folder(log_dir):
126 |     if not os.path.exists(log_dir):
127 |         os.makedirs(log_dir)
128 |     return log_dir
129 | 
130 | def str2bool(v):
131 |     # argparse passes flag values through as strings, and bool('False') is True;
132 |     # the boolean command-line flags therefore use this converter instead of type=bool
133 |     return v.lower() in ('true', 't', '1', 'yes')
--------------------------------------------------------------------------------
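preprocessing maps pixel values into [-1, 1] (the usual range for a tanh-output generator) and inverse_transform maps network output from [-1, 1] back to [0, 1] before imsave writes it out. A round-trip check of the two mappings:

import numpy as np

x = np.array([0., 127.5, 255.])  # representative pixel values
y = x / 127.5 - 1                # preprocessing: [0, 255] -> [-1, 1]
z = (y + 1.) / 2                 # inverse_transform: [-1, 1] -> [0, 1]
print(y)         # [-1.  0.  1.]
print(z * 255.)  # [  0.  127.5 255. ] -- the original values recovered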