├── .gitignore ├── DRIT.py ├── LICENSE ├── README.md ├── assets ├── comparison.png ├── false.png ├── final.gif ├── result1.png ├── result2.png ├── test.png ├── test_1.png ├── test_2.png ├── train_1.png ├── train_2.png └── true.png ├── main.py ├── ops.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /DRIT.py: -------------------------------------------------------------------------------- 1 | from ops import * 2 | from utils import * 3 | from glob import glob 4 | import time 5 | from tensorflow.contrib.data import prefetch_to_device, shuffle_and_repeat, map_and_batch 6 | 7 | class DRIT(object) : 8 | def __init__(self, sess, args): 9 | self.model_name = 'DRIT' 10 | self.sess = sess 11 | self.checkpoint_dir = args.checkpoint_dir 12 | self.result_dir = args.result_dir 13 | self.log_dir = args.log_dir 14 | self.sample_dir = args.sample_dir 15 | self.dataset_name = args.dataset 16 | self.augment_flag = args.augment_flag 17 | 18 | self.epoch = args.epoch 19 | self.iteration = args.iteration 20 | self.decay_flag = args.decay_flag 21 | self.decay_epoch = args.decay_epoch 22 | 23 | self.gan_type = args.gan_type 24 | 25 | self.batch_size = args.batch_size 26 | self.print_freq = args.print_freq 27 | self.save_freq = args.save_freq 28 | 29 | self.num_attribute = args.num_attribute # for test 30 | self.guide_img = args.guide_img 31 | self.direction = args.direction 32 | 33 | self.img_size = args.img_size 34 | self.img_ch = args.img_ch 35 | 36 | self.init_lr = args.lr 37 | self.content_init_lr = args.lr / 2.5 38 | self.ch = args.ch 39 | self.concat = args.concat 40 | 41 | """ Weight """ 42 | self.content_adv_w = args.content_adv_w 43 | self.domain_adv_w = args.domain_adv_w 44 | self.cycle_w = 
args.cycle_w 45 | self.recon_w = args.recon_w 46 | self.latent_w = args.latent_w 47 | self.kl_w = args.kl_w 48 | 49 | """ Generator """ 50 | self.n_layer = args.n_layer 51 | self.n_z = args.n_z 52 | 53 | """ Discriminator """ 54 | self.n_dis = args.n_dis 55 | self.n_scale = args.n_scale 56 | self.n_d_con = args.n_d_con 57 | self.multi = True if args.n_scale > 1 else False 58 | self.sn = args.sn 59 | 60 | self.sample_dir = os.path.join(args.sample_dir, self.model_dir) 61 | check_folder(self.sample_dir) 62 | 63 | self.trainA_dataset = glob('./dataset/{}/*.*'.format(self.dataset_name + '/trainA')) 64 | self.trainB_dataset = glob('./dataset/{}/*.*'.format(self.dataset_name + '/trainB')) 65 | self.dataset_num = max(len(self.trainA_dataset), len(self.trainB_dataset)) 66 | 67 | print("##### Information #####") 68 | print("# gan type : ", self.gan_type) 69 | print("# dataset : ", self.dataset_name) 70 | print("# max dataset number : ", self.dataset_num) 71 | print("# batch_size : ", self.batch_size) 72 | print("# decay_flag : ", self.decay_flag) 73 | print("# epoch : ", self.epoch) 74 | print("# decay_epoch : ", self.decay_epoch) 75 | print("# iteration per epoch : ", self.iteration) 76 | print("# attribute in test phase : ", self.num_attribute) 77 | 78 | print() 79 | 80 | print("##### Generator #####") 81 | print("# layer : ", self.n_layer) 82 | print("# z dimension : ", self.n_z) 83 | print("# concat : ", self.concat) 84 | 85 | print() 86 | 87 | print("##### Discriminator #####") 88 | print("# discriminator layer : ", self.n_dis) 89 | print("# multi-scale Dis : ", self.n_scale) 90 | print("# updating iteration of con_dis : ", self.n_d_con) 91 | print("# spectral_norm : ", self.sn) 92 | 93 | print() 94 | 95 | print("##### Weight #####") 96 | print("# domain_adv_weight : ", self.domain_adv_w) 97 | print("# content_adv_weight : ", self.content_adv_w) 98 | print("# cycle_weight : ", self.cycle_w) 99 | print("# recon_weight : ", self.recon_w) 100 | print("# latent_weight : ", self.latent_w) 101 | print("# kl_weight : ", self.kl_w) 102 | 103 | ################################################################################## 104 | # Encoder and Decoders 105 | ################################################################################## 106 | 107 | def content_encoder(self, x, is_training=True, reuse=False, scope='content_encoder'): 108 | channel = self.ch 109 | with tf.variable_scope(scope, reuse=reuse) : 110 | x = conv(x, channel, kernel=7, stride=1, pad=3, pad_type='reflect', scope='conv') 111 | x = lrelu(x, 0.01) 112 | 113 | for i in range(2) : 114 | x = conv(x, channel * 2, kernel=3, stride=2, pad=1, pad_type='reflect', scope='conv_' + str(i)) 115 | x = instance_norm(x, scope='ins_norm_' + str(i)) 116 | x = relu(x) 117 | 118 | channel = channel * 2 119 | 120 | 121 | for i in range(1, self.n_layer) : 122 | x = resblock(x, channel, scope='resblock_'+str(i)) 123 | 124 | with tf.variable_scope('content_encoder_share', reuse=tf.AUTO_REUSE) : 125 | x = resblock(x, channel, scope='resblock_share') 126 | x = gaussian_noise_layer(x, is_training) 127 | 128 | return x 129 | 130 | def attribute_encoder(self, x, reuse=False, scope='attribute_encoder'): 131 | channel = self.ch 132 | with tf.variable_scope(scope, reuse=reuse) : 133 | x = conv(x, channel, kernel=7, stride=1, pad=3, pad_type='reflect', scope='conv') 134 | x = relu(x) 135 | channel = channel * 2 136 | 137 | x = conv(x, channel, kernel=4, stride=2, pad=1, pad_type='reflect', scope='conv_0') 138 | x = relu(x) 139 | channel = channel * 2 140 | 
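# The remaining (n_layer - 1) stride-2 convs below keep downsampling at a fixed
# channel width; global average pooling plus a 1x1 conv then squash the feature
# map into the n_z-dimensional attribute vector this encoder returns.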
141 | 142 | for i in range(1, self.n_layer) : 143 | x = conv(x, channel, kernel=4, stride=2, pad=1, pad_type='reflect', scope='conv_' + str(i)) 144 | x = relu(x) 145 | 146 | x = global_avg_pooling(x) 147 | x = conv(x, channels=self.n_z, kernel=1, stride=1, scope='attribute_logit') 148 | 149 | return x 150 | 151 | def attribute_encoder_concat(self, x, reuse=False, scope='attribute_encoder_concat'): 152 | channel = self.ch 153 | with tf.variable_scope(scope, reuse=reuse) : 154 | x = conv(x, channel, kernel=4, stride=2, pad=1, pad_type='reflect', scope='conv') 155 | 156 | for i in range(1, self.n_layer) : 157 | channel = channel * (i+1) 158 | x = basic_block(x, channel, scope='basic_block_' + str(i)) 159 | 160 | x = lrelu(x, 0.2) 161 | x = global_avg_pooling(x) 162 | 163 | mean = fully_conneted(x, channels=self.n_z, scope='z_mean') 164 | logvar = fully_conneted(x, channels=self.n_z, scope='z_logvar') 165 | 166 | return mean, logvar 167 | 168 | def MLP(self, z, reuse=False, scope='MLP'): 169 | channel = self.ch * self.n_layer 170 | with tf.variable_scope(scope, reuse=reuse) : 171 | 172 | for i in range(2) : 173 | z = fully_conneted(z, channel, scope='fully_' + str(i)) 174 | z = relu(z) 175 | 176 | z = fully_conneted(z, channel*self.n_layer, scope='fully_logit') 177 | 178 | return z 179 | 180 | def generator(self, x, z, reuse=False, scope="generator"): 181 | channel = self.ch * self.n_layer 182 | with tf.variable_scope(scope, reuse=reuse) : 183 | z = self.MLP(z, reuse=reuse) 184 | z = tf.split(z, num_or_size_splits=self.n_layer, axis=-1) 185 | 186 | for i in range(self.n_layer) : 187 | x = mis_resblock(x, z[i], channel, scope='mis_resblock_' + str(i)) 188 | 189 | for i in range(2) : 190 | x = deconv(x, channel // 2, kernel=3, stride=2, scope='deconv_' + str(i)) 191 | x = layer_norm(x, scope='layer_norm_' + str(i)) 192 | x = relu(x) 193 | 194 | channel = channel // 2 195 | 196 | x = deconv(x, channels=self.img_ch, kernel=1, stride=1, scope='G_logit') 197 | x = tanh(x) 198 | 199 | return x 200 | 201 | def generator_concat(self, x, z, reuse=False, scope='generator_concat'): 202 | channel = self.ch * self.n_layer 203 | with tf.variable_scope('generator_concat_share', reuse=tf.AUTO_REUSE) : 204 | x = resblock(x, channel, scope='resblock') 205 | 206 | with tf.variable_scope(scope, reuse=reuse) : 207 | channel = channel + self.n_z 208 | x = expand_concat(x, z) 209 | 210 | for i in range(1, self.n_layer) : 211 | x = resblock(x, channel, scope='resblock_' + str(i)) 212 | 213 | for i in range(2) : 214 | channel = channel + self.n_z 215 | x = expand_concat(x, z) 216 | 217 | x = deconv(x, channel // 2, kernel=3, stride=2, scope='deconv_' + str(i)) 218 | x = layer_norm(x, scope='layer_norm_' + str(i)) 219 | x = relu(x) 220 | 221 | channel = channel // 2 222 | 223 | x = expand_concat(x, z) 224 | x = deconv(x, channels=self.img_ch, kernel=1, stride=1, scope='G_logit') 225 | x = tanh(x) 226 | 227 | return x 228 | 229 | 230 | 231 | ################################################################################## 232 | # Discriminator 233 | ################################################################################## 234 | 235 | def content_discriminator(self, x, reuse=False, scope='content_discriminator'): 236 | D_logit = [] 237 | with tf.variable_scope(scope, reuse=reuse) : 238 | channel = self.ch * self.n_layer 239 | for i in range(3) : 240 | x = conv(x, channel, kernel=7, stride=2, pad=1, pad_type='reflect', scope='conv_' + str(i)) 241 | x = instance_norm(x, scope='ins_norm_' + str(i)) 242 | x = 
lrelu(x, 0.01) 243 | 244 | x = conv(x, channel, kernel=4, stride=1, scope='conv_3') 245 | x = lrelu(x, 0.01) 246 | 247 | x = conv(x, channels=1, kernel=1, stride=1, scope='D_content_logit') 248 | D_logit.append(x) 249 | 250 | return D_logit 251 | 252 | def multi_discriminator(self, x_init, reuse=False, scope="multi_discriminator"): 253 | D_logit = [] 254 | with tf.variable_scope(scope, reuse=reuse) : 255 | for scale in range(self.n_scale) : 256 | channel = self.ch 257 | x = conv(x_init, channel, kernel=4, stride=2, pad=1, pad_type='reflect', sn=self.sn, scope='ms_' + str(scale) + 'conv_0') 258 | x = lrelu(x, 0.01) 259 | 260 | for i in range(1, self.n_dis): 261 | x = conv(x, channel * 2, kernel=4, stride=2, pad=1, pad_type='reflect', sn=self.sn, scope='ms_' + str(scale) +'conv_' + str(i)) 262 | x = lrelu(x, 0.01) 263 | 264 | channel = channel * 2 265 | 266 | x = conv(x, channels=1, kernel=1, stride=1, sn=self.sn, scope='ms_' + str(scale) + 'D_logit') 267 | D_logit.append(x) 268 | 269 | x_init = down_sample(x_init) 270 | 271 | return D_logit 272 | 273 | def discriminator(self, x, reuse=False, scope="discriminator"): 274 | D_logit = [] 275 | with tf.variable_scope(scope, reuse=reuse) : 276 | channel = self.ch 277 | x = conv(x, channel, kernel=3, stride=2, pad=1, pad_type='reflect', sn=self.sn, scope='conv') 278 | x = lrelu(x, 0.01) 279 | 280 | for i in range(1, self.n_dis) : 281 | x = conv(x, channel * 2, kernel=3, stride=2, pad=1, pad_type='reflect', sn=self.sn, scope='conv_' + str(i)) 282 | x = lrelu(x, 0.01) 283 | 284 | channel = channel * 2 285 | 286 | x = conv(x, channels=1, kernel=1, stride=1, sn=self.sn, scope='D_logit') 287 | D_logit.append(x) 288 | 289 | return D_logit 290 | 291 | ################################################################################## 292 | # Model 293 | ################################################################################## 294 | 295 | def Encoder_A(self, x_A, is_training=True, random_fake=False, reuse=False): 296 | mean = None 297 | logvar = None 298 | 299 | content_A = self.content_encoder(x_A, is_training=is_training, reuse=reuse, scope='content_encoder_A') 300 | 301 | if self.concat : 302 | mean, logvar = self.attribute_encoder_concat(x_A, reuse=reuse, scope='attribute_encoder_concat_A') 303 | if random_fake : 304 | attribute_A = mean 305 | else : 306 | attribute_A = z_sample(mean, logvar) 307 | else : 308 | attribute_A = self.attribute_encoder(x_A, reuse=reuse, scope='attribute_encoder_A') 309 | 310 | return content_A, attribute_A, mean, logvar 311 | 312 | def Encoder_B(self, x_B, is_training=True, random_fake=False, reuse=False): 313 | mean = None 314 | logvar = None 315 | 316 | content_B = self.content_encoder(x_B, is_training=is_training, reuse=reuse, scope='content_encoder_B') 317 | 318 | if self.concat: 319 | mean, logvar = self.attribute_encoder_concat(x_B, reuse=reuse, scope='attribute_encoder_concat_B') 320 | if random_fake : 321 | attribute_B = mean 322 | 323 | else : 324 | attribute_B = z_sample(mean, logvar) 325 | else: 326 | attribute_B = self.attribute_encoder(x_B, reuse=reuse, scope='attribute_encoder_B') 327 | 328 | return content_B, attribute_B, mean, logvar 329 | 330 | def Decoder_A(self, content_B, attribute_A, reuse=False): 331 | # x = fake_A, identity_A, random_fake_A 332 | # x = (B, A), (A, A), (B, z) 333 | if self.concat : 334 | x = self.generator_concat(x=content_B, z=attribute_A, reuse=reuse, scope='generator_concat_A') 335 | else : 336 | x = self.generator(x=content_B, z=attribute_A, reuse=reuse, 
scope='generator_A') 337 | 338 | return x 339 | 340 | def Decoder_B(self, content_A, attribute_B, reuse=False): 341 | # x = fake_B, identity_B, random_fake_B 342 | # x = (A, B), (B, B), (A, z) 343 | if self.concat : 344 | x = self.generator_concat(x=content_A, z=attribute_B, reuse=reuse, scope='generator_concat_B') 345 | else : 346 | x = self.generator(x=content_A, z=attribute_B, reuse=reuse, scope='generator_B') 347 | 348 | return x 349 | 350 | def discriminate_real(self, x_A, x_B): 351 | if self.multi : 352 | real_A_logit = self.multi_discriminator(x_A, scope='multi_discriminator_A') 353 | real_B_logit = self.multi_discriminator(x_B, scope='multi_discriminator_B') 354 | 355 | else : 356 | real_A_logit = self.discriminator(x_A, scope="discriminator_A") 357 | real_B_logit = self.discriminator(x_B, scope="discriminator_B") 358 | 359 | return real_A_logit, real_B_logit 360 | 361 | def discriminate_fake(self, x_ba, x_ab): 362 | if self.multi : 363 | fake_A_logit = self.multi_discriminator(x_ba, reuse=True, scope='multi_discriminator_A') 364 | fake_B_logit = self.multi_discriminator(x_ab, reuse=True, scope='multi_discriminator_B') 365 | 366 | else : 367 | fake_A_logit = self.discriminator(x_ba, reuse=True, scope="discriminator_A") 368 | fake_B_logit = self.discriminator(x_ab, reuse=True, scope="discriminator_B") 369 | 370 | return fake_A_logit, fake_B_logit 371 | 372 | def discriminate_content(self, content_A, content_B, reuse=False): 373 | content_A_logit = self.content_discriminator(content_A, reuse=reuse, scope='content_discriminator') 374 | content_B_logit = self.content_discriminator(content_B, reuse=True, scope='content_discriminator') 375 | 376 | return content_A_logit, content_B_logit 377 | 378 | 379 | def build_model(self): 380 | self.lr = tf.placeholder(tf.float32, name='lr') 381 | self.content_lr = tf.placeholder(tf.float32, name='content_lr') 382 | 383 | """ Input Image""" 384 | Image_Data_Class = ImageData(self.img_size, self.img_ch, self.augment_flag) 385 | 386 | trainA = tf.data.Dataset.from_tensor_slices(self.trainA_dataset) 387 | trainB = tf.data.Dataset.from_tensor_slices(self.trainB_dataset) 388 | 389 | gpu_device = '/gpu:0' 390 | trainA = trainA.apply(shuffle_and_repeat(self.dataset_num)).apply(map_and_batch(Image_Data_Class.image_processing, self.batch_size, num_parallel_batches=16, drop_remainder=True)).apply(prefetch_to_device(gpu_device, self.batch_size)) 391 | trainB = trainB.apply(shuffle_and_repeat(self.dataset_num)).apply(map_and_batch(Image_Data_Class.image_processing, self.batch_size, num_parallel_batches=16, drop_remainder=True)).apply(prefetch_to_device(gpu_device, self.batch_size)) 392 | 393 | 394 | trainA_iterator = trainA.make_one_shot_iterator() 395 | trainB_iterator = trainB.make_one_shot_iterator() 396 | 397 | 398 | self.domain_A = trainA_iterator.get_next() 399 | self.domain_B = trainB_iterator.get_next() 400 | 401 | 402 | """ Define Encoder, Generator, Discriminator """ 403 | random_z = tf.random_normal(shape=[self.batch_size, self.n_z], mean=0.0, stddev=1.0, dtype=tf.float32) 404 | 405 | # encode 406 | content_a, attribute_a, mean_a, logvar_a = self.Encoder_A(self.domain_A) 407 | content_b, attribute_b, mean_b, logvar_b = self.Encoder_B(self.domain_B) 408 | 409 | # decode (fake, identity, random) 410 | fake_a = self.Decoder_A(content_B=content_b, attribute_A=attribute_a) 411 | fake_b = self.Decoder_B(content_A=content_a, attribute_B=attribute_b) 412 | 413 | recon_a = self.Decoder_A(content_B=content_a, attribute_A=attribute_a, reuse=True) 414 | recon_b 
= self.Decoder_B(content_A=content_b, attribute_B=attribute_b, reuse=True) 415 | 416 | random_fake_a = self.Decoder_A(content_B=content_b, attribute_A=random_z, reuse=True) 417 | random_fake_b = self.Decoder_B(content_A=content_a, attribute_B=random_z, reuse=True) 418 | 419 | # encode & decode again for cycle-consistency 420 | content_fake_a, attribute_fake_a, _, _ = self.Encoder_A(fake_a, reuse=True) 421 | content_fake_b, attribute_fake_b, _, _ = self.Encoder_B(fake_b, reuse=True) 422 | 423 | cycle_a = self.Decoder_A(content_B=content_fake_b, attribute_A=attribute_fake_a, reuse=True) 424 | cycle_b = self.Decoder_B(content_A=content_fake_a, attribute_B=attribute_fake_b, reuse=True) 425 | 426 | # for latent regression 427 | _, attribute_fake_random_a, _, _ = self.Encoder_A(random_fake_a, random_fake=True, reuse=True) 428 | _, attribute_fake_random_b, _, _ = self.Encoder_B(random_fake_b, random_fake=True, reuse=True) 429 | 430 | 431 | # discriminate 432 | real_A_logit, real_B_logit = self.discriminate_real(self.domain_A, self.domain_B) 433 | fake_A_logit, fake_B_logit = self.discriminate_fake(fake_a, fake_b) 434 | random_fake_A_logit, random_fake_B_logit = self.discriminate_fake(random_fake_a, random_fake_b) 435 | content_A_logit, content_B_logit = self.discriminate_content(content_a, content_b) 436 | 437 | 438 | """ Define Loss """ 439 | g_adv_loss_a = generator_loss(self.gan_type, fake_A_logit) + generator_loss(self.gan_type, random_fake_A_logit) 440 | g_adv_loss_b = generator_loss(self.gan_type, fake_B_logit) + generator_loss(self.gan_type, random_fake_B_logit) 441 | 442 | g_con_loss_a = generator_loss(self.gan_type, content_A_logit, content=True) 443 | g_con_loss_b = generator_loss(self.gan_type, content_B_logit, content=True) 444 | 445 | g_cyc_loss_a = L1_loss(cycle_a, self.domain_A) 446 | g_cyc_loss_b = L1_loss(cycle_b, self.domain_B) 447 | 448 | g_rec_loss_a = L1_loss(recon_a, self.domain_A) 449 | g_rec_loss_b = L1_loss(recon_b, self.domain_B) 450 | 451 | g_latent_loss_a = L1_loss(attribute_fake_random_a, random_z) 452 | g_latent_loss_b = L1_loss(attribute_fake_random_b, random_z) 453 | 454 | if self.concat : 455 | g_kl_loss_a = kl_loss(mean_a, logvar_a) + l2_regularize(content_a) 456 | g_kl_loss_b = kl_loss(mean_b, logvar_b) + l2_regularize(content_b) 457 | else : 458 | g_kl_loss_a = l2_regularize(attribute_a) + l2_regularize(content_a) 459 | g_kl_loss_b = l2_regularize(attribute_b) + l2_regularize(content_b) 460 | 461 | 462 | d_adv_loss_a = discriminator_loss(self.gan_type, real_A_logit, fake_A_logit, random_fake_A_logit) 463 | d_adv_loss_b = discriminator_loss(self.gan_type, real_B_logit, fake_B_logit, random_fake_B_logit) 464 | 465 | d_con_loss = discriminator_loss(self.gan_type, content_A_logit, content_B_logit, content=True) 466 | 467 | Generator_A_domain_loss = self.domain_adv_w * g_adv_loss_a 468 | Generator_A_content_loss = self.content_adv_w * g_con_loss_a 469 | Generator_A_cycle_loss = self.cycle_w * g_cyc_loss_b 470 | Generator_A_recon_loss = self.recon_w * g_rec_loss_a 471 | Generator_A_latent_loss = self.latent_w * g_latent_loss_a 472 | Generator_A_kl_loss = self.kl_w * g_kl_loss_a 473 | 474 | Generator_A_loss = Generator_A_domain_loss + \ 475 | Generator_A_content_loss + \ 476 | Generator_A_cycle_loss + \ 477 | Generator_A_recon_loss + \ 478 | Generator_A_latent_loss + \ 479 | Generator_A_kl_loss 480 | 481 | Generator_B_domain_loss = self.domain_adv_w * g_adv_loss_b 482 | Generator_B_content_loss = self.content_adv_w * g_con_loss_b 483 | Generator_B_cycle_loss = 
self.cycle_w * g_cyc_loss_a 484 | Generator_B_recon_loss = self.recon_w * g_rec_loss_b 485 | Generator_B_latent_loss = self.latent_w * g_latent_loss_b 486 | Generator_B_kl_loss = self.kl_w * g_kl_loss_b 487 | 488 | Generator_B_loss = Generator_B_domain_loss + \ 489 | Generator_B_content_loss + \ 490 | Generator_B_cycle_loss + \ 491 | Generator_B_recon_loss + \ 492 | Generator_B_latent_loss + \ 493 | Generator_B_kl_loss 494 | 495 | Discriminator_A_loss = self.domain_adv_w * d_adv_loss_a 496 | Discriminator_B_loss = self.domain_adv_w * d_adv_loss_b 497 | Discriminator_content_loss = self.content_adv_w * d_con_loss 498 | 499 | self.Generator_loss = Generator_A_loss + Generator_B_loss 500 | self.Discriminator_loss = Discriminator_A_loss + Discriminator_B_loss 501 | self.Discriminator_content_loss = Discriminator_content_loss 502 | 503 | """ Training """ 504 | t_vars = tf.trainable_variables() 505 | G_vars = [var for var in t_vars if 'encoder' in var.name or 'generator' in var.name] 506 | D_vars = [var for var in t_vars if 'discriminator' in var.name and 'content' not in var.name] 507 | D_content_vars = [var for var in t_vars if 'content_discriminator' in var.name] 508 | 509 | grads, _ = tf.clip_by_global_norm(tf.gradients(self.Discriminator_content_loss, D_content_vars), clip_norm=5) 510 | 511 | self.G_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.Generator_loss, var_list=G_vars) 512 | self.D_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.Discriminator_loss, var_list=D_vars) 513 | self.D_content_optim = tf.train.AdamOptimizer(self.content_lr, beta1=0.5, beta2=0.999).apply_gradients(zip(grads, D_content_vars)) 514 | 515 | 516 | """" Summary """ 517 | self.lr_write = tf.summary.scalar("learning_rate", self.lr) 518 | 519 | self.all_G_loss = tf.summary.scalar("Generator_loss", self.Generator_loss) 520 | self.all_D_loss = tf.summary.scalar("Discriminator_loss", self.Discriminator_loss) 521 | 522 | self.G_A_loss = tf.summary.scalar("G_A_loss", Generator_A_loss) 523 | self.G_A_domain_loss = tf.summary.scalar("G_A_domain_loss", Generator_A_domain_loss) 524 | self.G_A_content_loss = tf.summary.scalar("G_A_content_loss", Generator_A_content_loss) 525 | self.G_A_cycle_loss = tf.summary.scalar("G_A_cycle_loss", Generator_A_cycle_loss) 526 | self.G_A_recon_loss = tf.summary.scalar("G_A_recon_loss", Generator_A_recon_loss) 527 | self.G_A_latent_loss = tf.summary.scalar("G_A_latent_loss", Generator_A_latent_loss) 528 | self.G_A_kl_loss = tf.summary.scalar("G_A_kl_loss", Generator_A_kl_loss) 529 | 530 | 531 | self.G_B_loss = tf.summary.scalar("G_B_loss", Generator_B_loss) 532 | self.G_B_domain_loss = tf.summary.scalar("G_B_domain_loss", Generator_B_domain_loss) 533 | self.G_B_content_loss = tf.summary.scalar("G_B_content_loss", Generator_B_content_loss) 534 | self.G_B_cycle_loss = tf.summary.scalar("G_B_cycle_loss", Generator_B_cycle_loss) 535 | self.G_B_recon_loss = tf.summary.scalar("G_B_recon_loss", Generator_B_recon_loss) 536 | self.G_B_latent_loss = tf.summary.scalar("G_B_latent_loss", Generator_B_latent_loss) 537 | self.G_B_kl_loss = tf.summary.scalar("G_B_kl_loss", Generator_B_kl_loss) 538 | 539 | self.D_A_loss = tf.summary.scalar("D_A_loss", Discriminator_A_loss) 540 | self.D_B_loss = tf.summary.scalar("D_B_loss", Discriminator_B_loss) 541 | 542 | self.G_loss = tf.summary.merge([self.G_A_loss, 543 | self.G_A_domain_loss, self.G_A_content_loss, 544 | self.G_A_cycle_loss, self.G_A_recon_loss, 545 | self.G_A_latent_loss, self.G_A_kl_loss, 546 
| 547 | self.G_B_loss, 548 | self.G_B_domain_loss, self.G_B_content_loss, 549 | self.G_B_cycle_loss, self.G_B_recon_loss, 550 | self.G_B_latent_loss, self.G_B_kl_loss, 551 | 552 | self.all_G_loss]) 553 | 554 | self.D_loss = tf.summary.merge([self.D_A_loss, 555 | self.D_B_loss, 556 | self.all_D_loss]) 557 | 558 | self.D_content_loss = tf.summary.scalar("Discriminator_content_loss", self.Discriminator_content_loss) 559 | 560 | 561 | 562 | """ Image """ 563 | self.fake_A = random_fake_a 564 | self.fake_B = random_fake_b 565 | 566 | self.real_A = self.domain_A 567 | self.real_B = self.domain_B 568 | 569 | 570 | """ Test """ 571 | self.test_image = tf.placeholder(tf.float32, [1, self.img_size, self.img_size, self.img_ch], name='test_image') 572 | self.test_random_z = tf.random_normal(shape=[1, self.n_z], mean=0.0, stddev=1.0, dtype=tf.float32) 573 | 574 | test_content_a, _, _, _ = self.Encoder_A(self.test_image, is_training=False, reuse=True) 575 | test_content_b, _, _, _ = self.Encoder_B(self.test_image, is_training=False, reuse=True) 576 | 577 | self.test_fake_A = self.Decoder_A(content_B=test_content_b, attribute_A=self.test_random_z, reuse=True) 578 | self.test_fake_B = self.Decoder_B(content_A=test_content_a, attribute_B=self.test_random_z, reuse=True) 579 | 580 | """ Guided Image Translation """ 581 | self.content_image = tf.placeholder(tf.float32, [1, self.img_size, self.img_size, self.img_ch], name='content_image') 582 | self.attribute_image = tf.placeholder(tf.float32, [1, self.img_size, self.img_size, self.img_ch], name='guide_attribute_image') 583 | 584 | if self.direction == 'a2b' : 585 | guide_content_A, _, _, _ = self.Encoder_A(self.content_image, is_training=False, reuse=True) 586 | _, guide_attribute_B, _, _ = self.Encoder_B(self.attribute_image, is_training=False, reuse=True) 587 | self.guide_fake_B = self.Decoder_B(content_A=guide_content_A, attribute_B=guide_attribute_B, reuse=True) 588 | 589 | else : 590 | guide_content_B, _, _, _ = self.Encoder_B(self.content_image, is_training=False, reuse=True) 591 | _, guide_attribute_A, _, _ = self.Encoder_A(self.attribute_image, is_training=False, reuse=True) 592 | self.guide_fake_A = self.Decoder_A(content_B=guide_content_B, attribute_A=guide_attribute_A, reuse=True) 593 | 594 | def train(self): 595 | # initialize all variables 596 | tf.global_variables_initializer().run() 597 | 598 | # saver to save model 599 | self.saver = tf.train.Saver() 600 | 601 | # summary writer 602 | self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_dir, self.sess.graph) 603 | 604 | # restore check-point if it exits 605 | could_load, checkpoint_counter = self.load(self.checkpoint_dir) 606 | if could_load: 607 | start_epoch = (int)(checkpoint_counter / self.iteration) 608 | start_batch_id = checkpoint_counter - start_epoch * self.iteration 609 | counter = checkpoint_counter 610 | print(" [*] Load SUCCESS") 611 | else: 612 | start_epoch = 0 613 | start_batch_id = 0 614 | counter = 1 615 | print(" [!] 
Load failed...") 616 | 617 | # loop for epoch 618 | start_time = time.time() 619 | lr = self.init_lr 620 | content_lr = self.content_init_lr 621 | for epoch in range(start_epoch, self.epoch): 622 | if self.decay_flag: 623 | lr = self.init_lr if epoch < self.decay_epoch else self.init_lr * (self.epoch - epoch) / (self.epoch - self.decay_epoch) # linear decay 624 | content_lr = self.content_init_lr if epoch < self.decay_epoch else self.content_init_lr * (self.epoch - epoch) / (self.epoch - self.decay_epoch) # linear decay 625 | 626 | for idx in range(start_batch_id, self.iteration): 627 | train_feed_dict = { 628 | self.lr : lr, 629 | self.content_lr : content_lr 630 | } 631 | 632 | summary_str = self.sess.run(self.lr_write, feed_dict=train_feed_dict) 633 | self.writer.add_summary(summary_str, counter) 634 | 635 | # Update content D 636 | _, d_con_loss, summary_str = self.sess.run([self.D_content_optim, self.Discriminator_content_loss, self.D_content_loss], feed_dict=train_feed_dict) 637 | self.writer.add_summary(summary_str, counter) 638 | 639 | if (counter - 1) % self.n_d_con == 0 : 640 | # Update D 641 | _, d_loss, summary_str = self.sess.run([self.D_optim, self.Discriminator_loss, self.D_loss], feed_dict = train_feed_dict) 642 | self.writer.add_summary(summary_str, counter) 643 | 644 | # Update G 645 | batch_A_images, batch_B_images, fake_A, fake_B, _, g_loss, summary_str = self.sess.run([self.real_A, self.real_B, self.fake_A, self.fake_B, self.G_optim, self.Generator_loss, self.G_loss], feed_dict = train_feed_dict) 646 | self.writer.add_summary(summary_str, counter) 647 | 648 | print("Epoch: [%2d] [%6d/%6d] time: %4.4f d_con_loss: %.8f, d_loss: %.8f, g_loss: %.8f" \ 649 | % (epoch, idx, self.iteration, time.time() - start_time, d_con_loss, d_loss, g_loss)) 650 | 651 | else : 652 | print("Epoch: [%2d] [%6d/%6d] time: %4.4f d_con_loss: %.8f" % (epoch, idx, self.iteration, time.time() - start_time, d_con_loss)) 653 | 654 | if np.mod(idx + 1, self.print_freq) == 0: 655 | save_images(batch_A_images, [self.batch_size, 1], 656 | './{}/real_A_{:03d}_{:05d}.jpg'.format(self.sample_dir, epoch, idx + 1)) 657 | # save_images(batch_B_images, [self.batch_size, 1], 658 | # './{}/real_B_{}_{:03d}_{:05d}.jpg'.format(self.sample_dir, gpu_id, epoch, idx+1)) 659 | 660 | # save_images(fake_A, [self.batch_size, 1], 661 | # './{}/fake_A_{}_{:03d}_{:05d}.jpg'.format(self.sample_dir, gpu_id, epoch, idx+1)) 662 | save_images(fake_B, [self.batch_size, 1], 663 | './{}/fake_B_{:03d}_{:05d}.jpg'.format(self.sample_dir, epoch, idx + 1)) 664 | 665 | # display training status 666 | counter += 1 667 | 668 | if np.mod(idx+1, self.save_freq) == 0 : 669 | self.save(self.checkpoint_dir, counter) 670 | 671 | # After an epoch, start_batch_id is set to zero 672 | # non-zero value is only for the first epoch after loading pre-trained model 673 | start_batch_id = 0 674 | 675 | # save model for final step 676 | self.save(self.checkpoint_dir, counter) 677 | 678 | @property 679 | def model_dir(self): 680 | if self.concat : 681 | concat = "_concat" 682 | else : 683 | concat = "" 684 | 685 | if self.sn : 686 | sn = "_sn" 687 | else : 688 | sn = "" 689 | 690 | return "{}{}_{}_{}_{}layer_{}dis_{}scale_{}con{}".format(self.model_name, concat, self.dataset_name, self.gan_type, 691 | self.n_layer, self.n_dis, self.n_scale, self.n_d_con, sn) 692 | 693 | def save(self, checkpoint_dir, step): 694 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir) 695 | 696 | if not os.path.exists(checkpoint_dir): 697 | os.makedirs(checkpoint_dir) 
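# tf.train.Saver appends the global step to the checkpoint filename
# (e.g. 'DRIT.model-1000'); load() below recovers the training counter by
# parsing this numeric suffix with ckpt_name.split('-')[-1].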
698 | 699 | self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name + '.model'), global_step=step) 700 | 701 | def load(self, checkpoint_dir): 702 | print(" [*] Reading checkpoints...") 703 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir) 704 | 705 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 706 | if ckpt and ckpt.model_checkpoint_path: 707 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path) 708 | self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name)) 709 | counter = int(ckpt_name.split('-')[-1]) 710 | print(" [*] Success to read {}".format(ckpt_name)) 711 | return True, counter 712 | else: 713 | print(" [*] Failed to find a checkpoint") 714 | return False, 0 715 | 716 | def test(self): 717 | tf.global_variables_initializer().run() 718 | test_A_files = glob('./dataset/{}/*.*'.format(self.dataset_name + '/testA')) 719 | test_B_files = glob('./dataset/{}/*.*'.format(self.dataset_name + '/testB')) 720 | 721 | self.saver = tf.train.Saver() 722 | could_load, checkpoint_counter = self.load(self.checkpoint_dir) 723 | self.result_dir = os.path.join(self.result_dir, self.model_dir) 724 | check_folder(self.result_dir) 725 | 726 | if could_load : 727 | print(" [*] Load SUCCESS") 728 | else : 729 | print(" [!] Load failed...") 730 | 731 | # write html for visual comparison 732 | index_path = os.path.join(self.result_dir, 'index.html') 733 | index = open(index_path, 'w') 734 | index.write("") 735 | index.write("") 736 | 737 | for sample_file in test_A_files : # A -> B 738 | print('Processing A image: ' + sample_file) 739 | sample_image = np.asarray(load_test_data(sample_file, size=self.img_size)) 740 | file_name = os.path.basename(sample_file).split(".")[0] 741 | file_extension = os.path.basename(sample_file).split(".")[1] 742 | 743 | for i in range(self.num_attribute) : 744 | image_path = os.path.join(self.result_dir, '{}_attribute{}.{}'.format(file_name, i, file_extension)) 745 | 746 | fake_img = self.sess.run(self.test_fake_B, feed_dict = {self.test_image : sample_image}) 747 | save_images(fake_img, [1, 1], image_path) 748 | 749 | index.write("" % os.path.basename(image_path)) 750 | index.write("" % (sample_file if os.path.isabs(sample_file) else ( 751 | '../..' + os.path.sep + sample_file), self.img_size, self.img_size)) 752 | index.write("" % (image_path if os.path.isabs(image_path) else ( 753 | '../..' + os.path.sep + image_path), self.img_size, self.img_size)) 754 | index.write("") 755 | 756 | for sample_file in test_B_files : # B -> A 757 | print('Processing B image: ' + sample_file) 758 | sample_image = np.asarray(load_test_data(sample_file, size=self.img_size)) 759 | file_name = os.path.basename(sample_file).split(".")[0] 760 | file_extension = os.path.basename(sample_file).split(".")[1] 761 | 762 | for i in range(self.num_attribute): 763 | image_path = os.path.join(self.result_dir, '{}_attribute{}.{}'.format(file_name, i, file_extension)) 764 | 765 | fake_img = self.sess.run(self.test_fake_A, feed_dict={self.test_image: sample_image}) 766 | save_images(fake_img, [1, 1], image_path) 767 | 768 | index.write("" % os.path.basename(image_path)) 769 | index.write("" % (sample_file if os.path.isabs(sample_file) else ( 770 | '../..' + os.path.sep + sample_file), self.img_size, self.img_size)) 771 | index.write("" % (image_path if os.path.isabs(image_path) else ( 772 | '../..' 
+ os.path.sep + image_path), self.img_size, self.img_size)) 773 | index.write("") 774 | index.close() 775 | 776 | def guide_test(self): 777 | tf.global_variables_initializer().run() 778 | test_A_files = glob('./dataset/{}/*.*'.format(self.dataset_name + '/testA')) 779 | test_B_files = glob('./dataset/{}/*.*'.format(self.dataset_name + '/testB')) 780 | 781 | attribute_file = np.asarray(load_test_data(self.guide_img, size=self.img_size)) 782 | 783 | self.saver = tf.train.Saver() 784 | could_load, checkpoint_counter = self.load(self.checkpoint_dir) 785 | self.result_dir = os.path.join(self.result_dir, self.model_dir, 'guide') 786 | check_folder(self.result_dir) 787 | 788 | if could_load: 789 | print(" [*] Load SUCCESS") 790 | else: 791 | print(" [!] Load failed...") 792 | 793 | # write html for visual comparison 794 | index_path = os.path.join(self.result_dir, 'index.html') 795 | index = open(index_path, 'w') 796 | index.write("
") 797 | index.write("") 798 | 799 | if self.direction == 'a2b' : 800 | for sample_file in test_A_files: # A -> B 801 | print('Processing A image: ' + sample_file) 802 | sample_image = np.asarray(load_test_data(sample_file, size=self.img_size)) 803 | image_path = os.path.join(self.result_dir, '{}'.format(os.path.basename(sample_file))) 804 | 805 | fake_img = self.sess.run(self.guide_fake_B, feed_dict={self.content_image: sample_image, self.attribute_image : attribute_file}) 806 | save_images(fake_img, [1, 1], image_path) 807 | 808 | index.write("" % os.path.basename(image_path)) 809 | index.write("" % (sample_file if os.path.isabs(sample_file) else ( 810 | '../../..' + os.path.sep + sample_file), self.img_size, self.img_size)) 811 | index.write("" % (image_path if os.path.isabs(image_path) else ( 812 | '../../..' + os.path.sep + image_path), self.img_size, self.img_size)) 813 | index.write("") 814 | 815 | else : 816 | for sample_file in test_B_files: # B -> A 817 | print('Processing B image: ' + sample_file) 818 | sample_image = np.asarray(load_test_data(sample_file, size=self.img_size)) 819 | image_path = os.path.join(self.result_dir, '{}'.format(os.path.basename(sample_file))) 820 | 821 | fake_img = self.sess.run(self.guide_fake_A, feed_dict={self.content_image: sample_image, self.attribute_image : attribute_file}) 822 | save_images(fake_img, [1, 1], image_path) 823 | 824 | index.write("" % os.path.basename(image_path)) 825 | index.write("" % (sample_file if os.path.isabs(sample_file) else ( 826 | '../../..' + os.path.sep + sample_file), self.img_size, self.img_size)) 827 | index.write("" % (image_path if os.path.isabs(image_path) else ( 828 | '../../..' + os.path.sep + image_path), self.img_size, self.img_size)) 829 | index.write("") 830 | index.close() 831 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Junho Kim (1993.01.12) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DRIT-Tensorflow 2 | Simple Tensorflow implementation of [Diverse Image-to-Image Translation via Disentangled Representations](https://arxiv.org/abs/1808.00948) (ECCV 2018 Oral) 3 | 4 | 5 | 6 | ## Pytorch version 7 | * [Author_pytorch_code](https://github.com/HsinYingLee/DRIT) 8 | 9 | ## Requirements 10 | * Tensorflow 1.8 11 | * python 3.6 12 | 13 | ## Usage 14 | ### Download Dataset 15 | * [cat2dog](http://vllab.ucmerced.edu/hylee/DRIT/datasets/cat2dog) 16 | * [portrait](http://vllab.ucmerced.edu/hylee/DRIT/datasets/portrait) 17 | * [CycleGAN](https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/) 18 | 19 | ``` 20 | ├── dataset 21 |    └── YOUR_DATASET_NAME 22 |    ├── trainA 23 |           ├── xxx.jpg (name, format doesn't matter) 24 | ├── yyy.png 25 | └── ... 26 |    ├── trainB 27 | ├── zzz.jpg 28 | ├── www.png 29 | └── ... 30 |    ├── testA 31 |    ├── aaa.jpg 32 | ├── bbb.png 33 | └── ... 34 |    └── testB 35 | ├── ccc.jpg 36 | ├── ddd.png 37 | └── ... 38 | 39 | ├── guide.jpg (example for guided image translation task) 40 | ``` 41 | 42 | ### Train 43 | ``` 44 | python main.py --phase train --dataset summer2winter --concat True 45 | ``` 46 | 47 | ### Test 48 | ``` 49 | python main.py --phase test --dataset summer2winter --concat True --num_attribute 3 50 | ``` 51 | 52 | ### Guide 53 | ``` 54 | python main.py --phase guide --dataset summer2winter --concat True --direction a2b --guide_img ./guide.jpg 55 | ``` 56 | 57 | ### Tips 58 | * --concat 59 | * `True` : for the **shape preserving translation** (summer <-> winter) **(default)** 60 | * `False` : for the **shape variation translation** (cat <-> dog) 61 | 62 | * --n_scale 63 | * Recommend `n_scale = 3` **(default)** 64 | * Using the `n_scale > 1`, a.k.a. 
`multiscale discriminator`, often gives better results 65 | 66 | * --n_dis 67 | * If you use the multi-discriminator, then recommend `n_dis = 4` **(default)** 68 | * If you don't use the multi-discriminator, then recommend `n_dis = 6` 69 | 70 | * --n_d_con 71 | * The author uses `n_d_con = 3` **(default)** 72 | * The model can still generate diverse results with `n_d_con = 1` 73 | 74 | * --num_attribute **(only for the test phase)** 75 | * If you set `num_attribute > 1`, then the output images are generated with diverse attributes 76 | 77 | ## Summary 78 | ### Comparison 79 | ![comparison](./assets/comparison.png) 80 | 81 | ### Architecture 82 | ![true](./assets/true.png) 83 | ![false](./assets/false.png) 84 | 85 | ### Train phase 86 | ![train_1](./assets/train_1.png) 87 | ![train_2](./assets/train_2.png) 88 | 89 | ### Test & Guide phase 90 | ![test](./assets/test.png) 91 | 92 | ## Results 93 | ![result_1](./assets/result1.png) 94 | ![result_2](./assets/result2.png) 95 | 96 | ## Related works 97 | * [UNIT-Tensorflow](https://github.com/taki0112/UNIT-Tensorflow) 98 | * [MUNIT-Tensorflow](https://github.com/taki0112/MUNIT-Tensorflow) 99 | 100 | ## Author 101 | Junho Kim 102 | -------------------------------------------------------------------------------- /assets/comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/DRIT-Tensorflow/384f6aac3e91898ee400c57418d2e7b3d6df1916/assets/comparison.png -------------------------------------------------------------------------------- /assets/false.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/DRIT-Tensorflow/384f6aac3e91898ee400c57418d2e7b3d6df1916/assets/false.png -------------------------------------------------------------------------------- /assets/final.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/DRIT-Tensorflow/384f6aac3e91898ee400c57418d2e7b3d6df1916/assets/final.gif -------------------------------------------------------------------------------- /assets/result1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/DRIT-Tensorflow/384f6aac3e91898ee400c57418d2e7b3d6df1916/assets/result1.png -------------------------------------------------------------------------------- /assets/result2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/DRIT-Tensorflow/384f6aac3e91898ee400c57418d2e7b3d6df1916/assets/result2.png -------------------------------------------------------------------------------- /assets/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/DRIT-Tensorflow/384f6aac3e91898ee400c57418d2e7b3d6df1916/assets/test.png -------------------------------------------------------------------------------- /assets/test_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/DRIT-Tensorflow/384f6aac3e91898ee400c57418d2e7b3d6df1916/assets/test_1.png -------------------------------------------------------------------------------- /assets/test_2.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/DRIT-Tensorflow/384f6aac3e91898ee400c57418d2e7b3d6df1916/assets/test_2.png -------------------------------------------------------------------------------- /assets/train_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/DRIT-Tensorflow/384f6aac3e91898ee400c57418d2e7b3d6df1916/assets/train_1.png -------------------------------------------------------------------------------- /assets/train_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/DRIT-Tensorflow/384f6aac3e91898ee400c57418d2e7b3d6df1916/assets/train_2.png -------------------------------------------------------------------------------- /assets/true.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/DRIT-Tensorflow/384f6aac3e91898ee400c57418d2e7b3d6df1916/assets/true.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from DRIT import DRIT 2 | import argparse 3 | from utils import * 4 | 5 | """parsing and configuration""" 6 | def parse_args(): 7 | desc = "Tensorflow implementation of DRIT" 8 | 9 | parser = argparse.ArgumentParser(description=desc) 10 | parser.add_argument('--phase', type=str, default='train', help='[train, test, guide]') 11 | parser.add_argument('--dataset', type=str, default='cat2dog', help='dataset_name') 12 | parser.add_argument('--augment_flag', type=str2bool, default=True, help='Image augmentation use or not') 13 | parser.add_argument('--decay_flag', type=str2bool, default=True, help='using learning rate decay') 14 | 15 | parser.add_argument('--epoch', type=int, default=50, help='The number of epochs to run') 16 | parser.add_argument('--decay_epoch', type=int, default=25, help='The number of decay epochs to run') 17 | parser.add_argument('--iteration', type=int, default=10000, help='The number of training iterations') 18 | parser.add_argument('--batch_size', type=int, default=1, help='The batch size') 19 | parser.add_argument('--print_freq', type=int, default=1000, help='The number of image_print_freq') 20 | parser.add_argument('--save_freq', type=int, default=1000, help='The number of ckpt_save_freq') 21 | 22 | parser.add_argument('--num_attribute', type=int, default=3, help='number of attributes to sample') 23 | parser.add_argument('--direction', type=str, default='a2b', help='direction of guided image translation') 24 | parser.add_argument('--guide_img', type=str, default='guide.jpg', help='Style guided image translation') 25 | 26 | parser.add_argument('--gan_type', type=str, default='gan', help='GAN loss type [gan / lsgan]') 27 | 28 | parser.add_argument('--lr', type=float, default=0.0001, help='The learning rate') 29 | parser.add_argument('--content_adv_w', type=int, default=1, help='weight of content adversarial loss') 30 | parser.add_argument('--domain_adv_w', type=int, default=1, help='weight of domain adversarial loss') 31 | parser.add_argument('--cycle_w', type=int, default=10, help='weight of cross-cycle reconstruction loss') 32 | parser.add_argument('--recon_w', type=int, default=10, help='weight of self-reconstruction loss') 33 | parser.add_argument('--latent_w', type=int, default=10, help='wight of latent regression loss') 34 | parser.add_argument('--kl_w', type=float, default=0.01, help='weight 
of kl-divergence loss') 35 | 36 | parser.add_argument('--ch', type=int, default=64, help='base channel number per layer') 37 | parser.add_argument('--concat', type=str2bool, default=False, help='using concat networks') 38 | 39 | # concat = False : for the shape variation translation (cat <-> dog) 40 | # concat = True : for the shape preserving translation (winter <-> summer) 41 | 42 | parser.add_argument('--n_z', type=int, default=8, help='length of z') 43 | parser.add_argument('--n_layer', type=int, default=4, help='number of layers in G, D') 44 | 45 | parser.add_argument('--n_dis', type=int, default=4, help='number of discriminator layer') 46 | 47 | # If you don't use multi-discriminator, then recommend n_dis = 6 48 | 49 | parser.add_argument('--n_scale', type=int, default=3, help='number of scales for discriminator') 50 | 51 | # using the multiscale discriminator often gets better results 52 | 53 | parser.add_argument('--n_d_con', type=int, default=3, help='# of iterations for updating content discrimnator') 54 | 55 | # model can still generate diverse results with n_d_con = 1 56 | 57 | parser.add_argument('--sn', type=str2bool, default=False, help='using spectral normalization') 58 | 59 | parser.add_argument('--img_size', type=int, default=256, help='The size of image') 60 | parser.add_argument('--img_ch', type=int, default=3, help='The size of image channel') 61 | 62 | parser.add_argument('--checkpoint_dir', type=str, default='checkpoint', 63 | help='Directory name to save the checkpoints') 64 | parser.add_argument('--result_dir', type=str, default='results', 65 | help='Directory name to save the generated images') 66 | parser.add_argument('--log_dir', type=str, default='logs', 67 | help='Directory name to save training logs') 68 | parser.add_argument('--sample_dir', type=str, default='samples', 69 | help='Directory name to save the samples on training') 70 | 71 | return check_args(parser.parse_args()) 72 | 73 | """checking arguments""" 74 | def check_args(args): 75 | # --checkpoint_dir 76 | check_folder(args.checkpoint_dir) 77 | 78 | # --result_dir 79 | check_folder(args.result_dir) 80 | 81 | # --result_dir 82 | check_folder(args.log_dir) 83 | 84 | # --sample_dir 85 | check_folder(args.sample_dir) 86 | 87 | # --epoch 88 | try: 89 | assert args.epoch >= 1 90 | except: 91 | print('number of epochs must be larger than or equal to one') 92 | 93 | # --batch_size 94 | try: 95 | assert args.batch_size >= 1 96 | except: 97 | print('batch size must be larger than or equal to one') 98 | return args 99 | 100 | """main""" 101 | def main(): 102 | # parse arguments 103 | args = parse_args() 104 | if args is None: 105 | exit() 106 | 107 | # open session 108 | with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess: 109 | gan = DRIT(sess, args) 110 | 111 | # build graph 112 | gan.build_model() 113 | 114 | # show network architecture 115 | show_all_variables() 116 | 117 | if args.phase == 'train' : 118 | # launch the graph in a session 119 | gan.train() 120 | print(" [*] Training finished!") 121 | 122 | if args.phase == 'test' : 123 | gan.test() 124 | print(" [*] Test finished!") 125 | 126 | if args.phase == 'guide' : 127 | gan.guide_test() 128 | print(" [*] Guide finished!") 129 | 130 | if __name__ == '__main__': 131 | main() 132 | -------------------------------------------------------------------------------- /ops.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib as tf_contrib 3 | 4 | # 
Xavier : tf_contrib.layers.xavier_initializer() 5 | # He : tf_contrib.layers.variance_scaling_initializer() 6 | # Normal : tf.random_normal_initializer(mean=0.0, stddev=0.02) 7 | # l2_decay : tf_contrib.layers.l2_regularizer(0.0001) 8 | 9 | weight_init = tf.random_normal_initializer(mean=0.0, stddev=0.02) 10 | weight_regularizer = None 11 | 12 | ################################################################################## 13 | # Layer 14 | ################################################################################## 15 | 16 | def conv(x, channels, kernel=4, stride=2, pad=0, pad_type='zero', use_bias=True, sn=False, scope='conv'): 17 | with tf.variable_scope(scope): 18 | if pad > 0 : 19 | if (kernel - stride) % 2 == 0: 20 | pad_top = pad 21 | pad_bottom = pad 22 | pad_left = pad 23 | pad_right = pad 24 | 25 | else: 26 | pad_top = pad 27 | pad_bottom = kernel - stride - pad_top 28 | pad_left = pad 29 | pad_right = kernel - stride - pad_left 30 | 31 | if pad_type == 'zero': 32 | x = tf.pad(x, [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]]) 33 | if pad_type == 'reflect': 34 | x = tf.pad(x, [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]], mode='REFLECT') 35 | 36 | if sn : 37 | w = tf.get_variable("kernel", shape=[kernel, kernel, x.get_shape()[-1], channels], initializer=weight_init, regularizer=weight_regularizer) 38 | x = tf.nn.conv2d(input=x, filter=spectral_norm(w), strides=[1, stride, stride, 1], padding='VALID') 39 | if use_bias : 40 | bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0)) 41 | x = tf.nn.bias_add(x, bias) 42 | 43 | else : 44 | x = tf.layers.conv2d(inputs=x, filters=channels, 45 | kernel_size=kernel, kernel_initializer=weight_init, 46 | kernel_regularizer=weight_regularizer, 47 | strides=stride, use_bias=use_bias) 48 | 49 | 50 | return x 51 | 52 | def deconv(x, channels, kernel=4, stride=2, padding='SAME', use_bias=True, sn=False, scope='deconv'): 53 | with tf.variable_scope(scope): 54 | x_shape = x.get_shape().as_list() 55 | 56 | if padding == 'SAME': 57 | output_shape = [x_shape[0], x_shape[1] * stride, x_shape[2] * stride, channels] 58 | 59 | else: 60 | output_shape =[x_shape[0], x_shape[1] * stride + max(kernel - stride, 0), x_shape[2] * stride + max(kernel - stride, 0), channels] 61 | 62 | if sn : 63 | w = tf.get_variable("kernel", shape=[kernel, kernel, channels, x.get_shape()[-1]], initializer=weight_init, regularizer=weight_regularizer) 64 | x = tf.nn.conv2d_transpose(x, filter=spectral_norm(w), output_shape=output_shape, strides=[1, stride, stride, 1], padding=padding) 65 | 66 | if use_bias : 67 | bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0)) 68 | x = tf.nn.bias_add(x, bias) 69 | 70 | else : 71 | x = tf.layers.conv2d_transpose(inputs=x, filters=channels, 72 | kernel_size=kernel, kernel_initializer=weight_init, kernel_regularizer=weight_regularizer, 73 | strides=stride, padding=padding, use_bias=use_bias) 74 | 75 | return x 76 | 77 | def fully_conneted(x, channels, use_bias=True, sn=False, scope='fully'): 78 | with tf.variable_scope(scope): 79 | x = tf.layers.flatten(x) 80 | shape = x.get_shape().as_list() 81 | x_channel = shape[-1] 82 | 83 | if sn : 84 | w = tf.get_variable("kernel", [x_channel, channels], tf.float32, initializer=weight_init, regularizer=weight_regularizer) 85 | if use_bias : 86 | bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0)) 87 | 88 | x = tf.matmul(x, spectral_norm(w)) + bias 89 | else : 90 | x = 
tf.matmul(x, spectral_norm(w)) 91 | 92 | else : 93 | x = tf.layers.dense(x, units=channels, kernel_initializer=weight_init, kernel_regularizer=weight_regularizer, use_bias=use_bias) 94 | 95 | return x 96 | 97 | def gaussian_noise_layer(x, is_training=False): 98 | if is_training : 99 | noise = tf.random_normal(shape=tf.shape(x), mean=0.0, stddev=1.0, dtype=tf.float32) 100 | return x + noise 101 | 102 | else : 103 | return x 104 | 105 | ################################################################################## 106 | # Block 107 | ################################################################################## 108 | 109 | def resblock(x_init, channels, use_bias=True, sn=False, scope='resblock'): 110 | with tf.variable_scope(scope): 111 | with tf.variable_scope('res1'): 112 | x = conv(x_init, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias, sn=sn) 113 | x = instance_norm(x) 114 | x = relu(x) 115 | 116 | with tf.variable_scope('res2'): 117 | x = conv(x, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias, sn=sn) 118 | x = instance_norm(x) 119 | 120 | return x + x_init 121 | 122 | def basic_block(x_init, channels, use_bias=True, sn=False, scope='basic_block') : 123 | with tf.variable_scope(scope) : 124 | x = lrelu(x_init, 0.2) 125 | x = conv(x, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias, sn=sn) 126 | 127 | x = lrelu(x, 0.2) 128 | x = conv_avg(x, channels, use_bias=use_bias, sn=sn) 129 | 130 | shortcut = avg_conv(x_init, channels, use_bias=use_bias, sn=sn) 131 | 132 | return x + shortcut 133 | 134 | def mis_resblock(x_init, z, channels, use_bias=True, sn=False, scope='mis_resblock') : 135 | with tf.variable_scope(scope) : 136 | z = tf.reshape(z, shape=[-1, 1, 1, z.shape[-1]]) 137 | z = tf.tile(z, multiples=[1, x_init.shape[1], x_init.shape[2], 1]) # expand 138 | 139 | with tf.variable_scope('mis1') : 140 | x = conv(x_init, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias, sn=sn, scope='conv3x3') 141 | x = instance_norm(x) 142 | 143 | x = tf.concat([x, z], axis=-1) 144 | x = conv(x, channels * 2, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='conv1x1_0') 145 | x = relu(x) 146 | 147 | x = conv(x, channels, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='conv1x1_1') 148 | x = relu(x) 149 | 150 | with tf.variable_scope('mis2') : 151 | x = conv(x, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias, sn=sn, scope='conv3x3') 152 | x = instance_norm(x) 153 | 154 | x = tf.concat([x, z], axis=-1) 155 | x = conv(x, channels * 2, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='conv1x1_0') 156 | x = relu(x) 157 | 158 | x = conv(x, channels, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='conv1x1_1') 159 | x = relu(x) 160 | 161 | return x + x_init 162 | 163 | def avg_conv(x, channels, use_bias=True, sn=False, scope='avg_conv') : 164 | with tf.variable_scope(scope) : 165 | x = avg_pooling(x, kernel=2, stride=2) 166 | x = conv(x, channels, kernel=1, stride=1, use_bias=use_bias, sn=sn) 167 | 168 | return x 169 | 170 | def conv_avg(x, channels, use_bias=True, sn=False, scope='conv_avg') : 171 | with tf.variable_scope(scope) : 172 | x = conv(x, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias, sn=sn) 173 | x = avg_pooling(x, kernel=2, stride=2) 174 | 175 | return x 176 | 177 | def expand_concat(x, z) : 178 | z = tf.reshape(z, shape=[z.shape[0], 1, 1, -1]) 179 | z = tf.tile(z, multiples=[1, x.shape[1], x.shape[2], 1]) # expand 180 
181 |
182 |     return x
183 |
184 | ##################################################################################
185 | # Sampling
186 | ##################################################################################
187 |
188 | def down_sample(x) : # stride-2 average pooling with a 3x3 window
189 |     return avg_pooling(x, kernel=3, stride=2, pad=1)
190 |
191 | def avg_pooling(x, kernel=2, stride=2, pad=0) :
192 |     if pad > 0 :
193 |         if (kernel - stride) % 2 == 0:
194 |             pad_top = pad
195 |             pad_bottom = pad
196 |             pad_left = pad
197 |             pad_right = pad
198 |
199 |         else:
200 |             pad_top = pad
201 |             pad_bottom = kernel - stride - pad_top
202 |             pad_left = pad
203 |             pad_right = kernel - stride - pad_left
204 |
205 |         x = tf.pad(x, [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]])
206 |
207 |     return tf.layers.average_pooling2d(x, pool_size=kernel, strides=stride, padding='VALID')
208 |
209 | def global_avg_pooling(x):
210 |     gap = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
211 |
212 |     return gap
213 |
214 | def z_sample(mean, logvar) : # reparameterization trick: mean + exp(logvar / 2) * eps, eps ~ N(0, I)
215 |     eps = tf.random_normal(shape=tf.shape(mean), mean=0.0, stddev=1.0, dtype=tf.float32)
216 |
217 |     return mean + tf.exp(logvar * 0.5) * eps
218 |
219 | ##################################################################################
220 | # Activation function
221 | ##################################################################################
222 |
223 | def lrelu(x, alpha=0.01):
224 |     # default slope matches PyTorch's LeakyReLU (0.01)
225 |     return tf.nn.leaky_relu(x, alpha)
226 |
227 |
228 | def relu(x):
229 |     return tf.nn.relu(x)
230 |
231 |
232 | def tanh(x):
233 |     return tf.tanh(x)
234 |
235 | ##################################################################################
236 | # Normalization function
237 | ##################################################################################
238 |
239 | def instance_norm(x, scope='instance_norm'):
240 |     return tf_contrib.layers.instance_norm(x,
241 |                                            epsilon=1e-05,
242 |                                            center=True, scale=True,
243 |                                            scope=scope)
244 |
245 | def layer_norm(x, scope='layer_norm') :
246 |     return tf_contrib.layers.layer_norm(x,
247 |                                         center=True, scale=True,
248 |                                         scope=scope)
249 |
250 | def spectral_norm(w, iteration=1):
251 |     w_shape = w.shape.as_list()
252 |     w = tf.reshape(w, [-1, w_shape[-1]])
253 |
254 |     u = tf.get_variable("u", [1, w_shape[-1]], initializer=tf.random_normal_initializer(), trainable=False)
255 |
256 |     u_hat = u
257 |     v_hat = None
258 |     for i in range(iteration):
259 |         """
260 |         power iteration
261 |         Usually iteration = 1 will be enough
262 |         """
263 |         v_ = tf.matmul(u_hat, tf.transpose(w))
264 |         v_hat = tf.nn.l2_normalize(v_)
265 |
266 |         u_ = tf.matmul(v_hat, w)
267 |         u_hat = tf.nn.l2_normalize(u_)
268 |
269 |     u_hat = tf.stop_gradient(u_hat)
270 |     v_hat = tf.stop_gradient(v_hat)
271 |
272 |     sigma = tf.matmul(tf.matmul(v_hat, w), tf.transpose(u_hat)) # sigma estimates the largest singular value of w
273 |
274 |     with tf.control_dependencies([u.assign(u_hat)]): # persist the power-iteration state across steps
275 |         w_norm = w / sigma
276 |         w_norm = tf.reshape(w_norm, w_shape)
277 |
278 |
279 |     return w_norm
280 |
281 | ##################################################################################
282 | # Loss function
283 | ##################################################################################
284 |
285 | def discriminator_loss(type, real, fake, fake_random=None, content=False):
286 |     n_scale = len(real)
287 |     loss = []
288 |
289 |     real_loss = 0
290 |     fake_loss = 0
291 |     fake_random_loss = 0
292 |
293 |     if content :
294 |         for i in range(n_scale):
295 |             if type == 'lsgan' :
296 |                 real_loss = tf.reduce_mean(tf.squared_difference(real[i], 1.0))
297 |                 fake_loss = tf.reduce_mean(tf.square(fake[i]))
298 |
299 |             if type == 'gan' :
300 |                 real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(real[i]), logits=real[i]))
301 |                 fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(fake[i]), logits=fake[i]))
302 |
303 |             loss.append(real_loss + fake_loss)
304 |
305 |     else :
306 |         for i in range(n_scale) :
307 |             if type == 'lsgan' :
308 |                 real_loss = tf.reduce_mean(tf.squared_difference(real[i], 1.0))
309 |                 fake_loss = tf.reduce_mean(tf.square(fake[i]))
310 |                 fake_random_loss = tf.reduce_mean(tf.square(fake_random[i]))
311 |
312 |             if type == 'gan' :
313 |                 real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(real[i]), logits=real[i]))
314 |                 fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(fake[i]), logits=fake[i]))
315 |                 fake_random_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(fake_random[i]), logits=fake_random[i]))
316 |
317 |             loss.append(real_loss * 2 + fake_loss + fake_random_loss) # real is weighted x2 since it is paired against both fake and fake_random
318 |
319 |     return sum(loss)
320 |
321 |
322 | def generator_loss(type, fake, content=False):
323 |     n_scale = len(fake)
324 |     loss = []
325 |
326 |     fake_loss = 0
327 |
328 |     if content :
329 |         for i in range(n_scale):
330 |             if type == 'lsgan' :
331 |                 fake_loss = tf.reduce_mean(tf.squared_difference(fake[i], 0.5)) # target 0.5: maximally confuse the content discriminator
332 |
333 |             if type == 'gan' :
334 |                 fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=0.5 * tf.ones_like(fake[i]), logits=fake[i]))
335 |
336 |             loss.append(fake_loss)
337 |     else :
338 |         for i in range(n_scale) :
339 |             if type == 'lsgan' :
340 |                 fake_loss = tf.reduce_mean(tf.squared_difference(fake[i], 1.0))
341 |
342 |             if type == 'gan' :
343 |                 fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(fake[i]), logits=fake[i]))
344 |
345 |             loss.append(fake_loss)
346 |
347 |
348 |     return sum(loss)
349 |
350 |
351 | def l2_regularize(x) :
352 |     loss = tf.reduce_mean(tf.square(x))
353 |
354 |     return loss
355 |
356 | def kl_loss(mu, logvar) : # closed-form KL( N(mu, exp(logvar)) || N(0, I) )
357 |     loss = 0.5 * tf.reduce_sum(tf.square(mu) + tf.exp(logvar) - 1 - logvar, axis=-1)
358 |     loss = tf.reduce_mean(loss)
359 |
360 |
361 |     return loss
362 |
363 | def L1_loss(x, y):
364 |     loss = tf.reduce_mean(tf.abs(x - y))
365 |
366 |     return loss
367 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.contrib import slim
3 | from scipy import misc # NOTE: scipy.misc image I/O (imread/imresize/imsave) was removed from newer SciPy releases; pin an older SciPy or port to imageio/Pillow
4 | import os, random
5 | import numpy as np
6 |
7 | class ImageData:
8 |
9 |     def __init__(self, img_size, channels, augment_flag=False):
10 |         self.img_size = img_size
11 |         self.channels = channels
12 |         self.augment_flag = augment_flag
13 |
14 |     def image_processing(self, filename):
15 |         x = tf.read_file(filename)
16 |         x_decode = tf.image.decode_jpeg(x, channels=self.channels)
17 |         img = tf.image.resize_images(x_decode, [self.img_size, self.img_size])
18 |         img = tf.cast(img, tf.float32) / 127.5 - 1 # scale to [-1, 1]
19 |
20 |         if self.augment_flag :
21 |             if self.img_size < 256 :
22 |                 augment_size = 256
23 |             else :
24 |                 augment_size = self.img_size + 30
25 |             p = random.random()
26 |             if p > 0.5: # apply augmentation with probability 0.5
27 |                 img = augmentation(img, augment_size)
28 |
29 |         return img
30 |
31 |
32 | def load_test_data(image_path, size=256):
33 |     img = misc.imread(image_path, mode='RGB')
34 |     img = misc.imresize(img, [size, size])
35 |     img = np.expand_dims(img, axis=0)
36 |     img = preprocessing(img)
37 |
38 |     return img
39 |
40 | def preprocessing(x):
41 |     x = x/127.5 - 1 # scale to [-1, 1]
42 |     return x
43 |
44 | def augmentation(image, aug_img_size): # random flip, upscale, then random-crop back to the original size
45 |     seed = random.randint(0, 2 ** 31 - 1)
46 |     ori_image_shape = tf.shape(image)
47 |     image = tf.image.random_flip_left_right(image, seed=seed)
48 |     image = tf.image.resize_images(image, [aug_img_size, aug_img_size])
49 |     image = tf.random_crop(image, ori_image_shape, seed=seed)
50 |     return image
51 |
52 | def save_images(images, size, image_path):
53 |     return imsave(inverse_transform(images), size, image_path)
54 |
55 | def inverse_transform(images):
56 |     return (images+1.) / 2 # [-1, 1] -> [0, 1]
57 |
58 | def imsave(images, size, path):
59 |     return misc.imsave(path, merge(images, size))
60 |
61 | def merge(images, size): # tile a batch of images into a size[0] x size[1] grid
62 |     h, w = images.shape[1], images.shape[2]
63 |     img = np.zeros((h * size[0], w * size[1], 3))
64 |     for idx, image in enumerate(images):
65 |         i = idx % size[1]
66 |         j = idx // size[1]
67 |         img[h*j:h*(j+1), w*i:w*(i+1), :] = image
68 |
69 |     return img
70 |
71 | def show_all_variables():
72 |     model_vars = tf.trainable_variables()
73 |     slim.model_analyzer.analyze_vars(model_vars, print_info=True)
74 |
75 | def check_folder(log_dir):
76 |     if not os.path.exists(log_dir):
77 |         os.makedirs(log_dir)
78 |     return log_dir
79 |
80 | def str2bool(x):
81 |     return x.lower() == 'true' # fixed: ('true') without a comma is a plain string, so the original "in" test matched substrings
--------------------------------------------------------------------------------
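A minimal usage sketch for the ImageData loader above, assuming TensorFlow 1.x; the dataset directory, image size, and batch size are illustrative placeholders, not values taken from this repository. image_processing decodes, resizes, and scales each file to [-1, 1], so it can be mapped directly over a tf.data pipeline:

import tensorflow as tf
from glob import glob
from utils import ImageData

# Placeholder layout: a folder of JPEGs (adjust to your own dataset).
filenames = glob('./dataset/demo/trainA/*.*')
loader = ImageData(img_size=256, channels=3, augment_flag=False)

dataset = tf.data.Dataset.from_tensor_slices(filenames)
dataset = dataset.shuffle(len(filenames)).repeat()
dataset = dataset.map(loader.image_processing, num_parallel_calls=4)
dataset = dataset.batch(1)

batch = dataset.make_one_shot_iterator().get_next()  # shape [1, 256, 256, 3], values in [-1, 1]

with tf.Session() as sess:
    images = sess.run(batch)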
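Similarly, a small sanity check for the VAE-style helpers in ops.py, under the same TF 1.x assumption (the toy values below are arbitrary): z_sample implements the reparameterization trick, z = mu + exp(logvar / 2) * eps with eps ~ N(0, I), and kl_loss is the closed-form KL divergence between N(mu, exp(logvar)) and a standard normal.

import numpy as np
import tensorflow as tf
from ops import z_sample, kl_loss

mu = tf.constant([[0.5, -1.0]], dtype=tf.float32)
logvar = tf.constant([[0.0, 0.4]], dtype=tf.float32)

z = z_sample(mu, logvar)  # stochastic, but differentiable w.r.t. mu and logvar
kl = kl_loss(mu, logvar)  # 0.5 * sum(mu^2 + exp(logvar) - 1 - logvar)

with tf.Session() as sess:
    z_val, kl_val = sess.run([z, kl])

# Closed-form reference computed in NumPy; should match kl_val.
expected = 0.5 * np.sum(np.square([0.5, -1.0]) + np.exp([0.0, 0.4]) - 1.0 - np.array([0.0, 0.4]))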