├── .gitignore
├── FUNIT.py
├── LICENSE
├── README.md
├── assets
│   ├── animal.gif
│   ├── architecture.png
│   ├── funit_example.jpg
│   ├── our_result.png
│   └── process.png
├── main.py
├── ops.py
└── utils.py
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
--------------------------------------------------------------------------------
/FUNIT.py:
--------------------------------------------------------------------------------
from ops import *
from utils import *
import time
from tensorflow.contrib.data import prefetch_to_device, shuffle_and_repeat, map_and_batch
import numpy as np
from tqdm import tqdm


class FUNIT(object):
    def __init__(self, sess, args):

        self.phase = args.phase
        self.model_name = 'FUNIT'

        self.sess = sess
        self.checkpoint_dir = args.checkpoint_dir
        self.result_dir = args.result_dir
        self.log_dir = args.log_dir
        self.dataset_name = args.dataset
        self.augment_flag = args.augment_flag

        self.gpu_num = args.gpu_num

        self.iteration = args.iteration // args.gpu_num

        self.batch_size = args.batch_size
        self.print_freq = args.print_freq
        self.save_freq = args.save_freq

        self.lr = args.lr
        self.ch = args.ch
        self.ema_decay = args.ema_decay

        self.K = args.K

        self.gan_type = args.gan_type

        """ Weight """
        self.adv_weight = args.adv_weight
        self.recon_weight = args.recon_weight
        self.feature_weight = args.feature_weight
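        # The FUNIT objective combines an adversarial loss, a content reconstruction
        # loss, and a discriminator feature-matching loss; the defaults in main.py
        # (adv 1, recon 0.1, feature 1) match the weights reported in the paper.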

        """ Generator """
        self.latent_dim = args.latent_dim

        """ Discriminator """
        self.sn = args.sn

        self.img_height = args.img_height
        self.img_width = args.img_width

        self.img_ch = args.img_ch

        self.sample_dir = os.path.join(args.sample_dir, self.model_dir)
        check_folder(self.sample_dir)

        self.dataset_path = os.path.join('./dataset', self.dataset_name, 'train')
        self.class_dim = len(glob(self.dataset_path + '/*'))

        print()

        print("##### Information #####")
        print("# dataset : ", self.dataset_name)
        print("# batch_size : ", self.batch_size)
        print("# max iteration : ", self.iteration)
        print("# gpu num : ", self.gpu_num)

        print()

        print("##### Generator #####")
        print("# latent_dim : ", self.latent_dim)

        print()

        print("##### Discriminator #####")
        print("# spectral normalization : ", self.sn)

        print()

        print("##### Weight #####")
        print("# adv_weight : ", self.adv_weight)
        print("# feature_weight : ", self.feature_weight)
        print("# recon_weight : ", self.recon_weight)

        print()

    ##################################################################################
    # Generator
    ##################################################################################

    def content_encoder(self, x_init, reuse=tf.AUTO_REUSE, scope='content_encoder'):
        channel = self.ch
        with tf.variable_scope(scope, reuse=reuse):
            x = conv(x_init, channel, kernel=7, stride=1, pad=3, pad_type='reflect', scope='conv')
            x = instance_norm(x, scope='ins_norm')
            x = relu(x)

            for i in range(3):
                x = conv(x, channel * 2, kernel=4, stride=2, pad=1, pad_type='reflect', scope='conv_' + str(i))
                x = instance_norm(x, scope='ins_norm_' + str(i))
                x = relu(x)

                channel = channel * 2

            for i in range(2):
                x = resblock(x, channel, scope='resblock_' + str(i))

            return x

    def class_encoder(self, x_init, reuse=tf.AUTO_REUSE, scope='class_encoder'):
        channel = self.ch
        with tf.variable_scope(scope, reuse=reuse):
            x = conv(x_init, channel, kernel=7, stride=1, pad=3, pad_type='reflect', scope='conv')
            x = relu(x)

            for i in range(2):
                x = conv(x, channel * 2, kernel=4, stride=2, pad=1, pad_type='reflect', scope='conv_' + str(i))
                x = relu(x)

                channel = channel * 2

            for i in range(2):
                x = conv(x, channel, kernel=4, stride=2, pad=1, pad_type='reflect', scope='fix_conv_' + str(i))
                x = relu(x)

            x = global_avg_pooling(x)
            x = conv(x, channels=self.latent_dim, kernel=1, stride=1, scope='style_logit')

            return x

    def generator(self, content, style, reuse=tf.AUTO_REUSE, scope="generator"):
        channel = self.ch * 8  # 512
        with tf.variable_scope(scope, reuse=reuse):
            x = content

            mu, var = self.MLP(style, channel // 2, scope='MLP')

            for i in range(2):
                idx = 2 * i
                x = adaptive_resblock(x, channel, mu[idx], var[idx], mu[idx + 1], var[idx + 1], scope='ada_resblock_' + str(i))

            for i in range(3):
                x = up_sample(x, scale_factor=2)
                x = conv(x, channel // 2, kernel=5, stride=1, pad=2, pad_type='reflect', scope='up_conv_' + str(i))
                x = instance_norm(x, scope='ins_norm_' + str(i))
                x = relu(x)

                channel = channel // 2
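            # With 128x128 inputs the content code is 16x16x512 (three stride-2
            # downsamples in the content encoder), so the three 2x upsamples above
            # restore full resolution before the final RGB projection below.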
            x = conv(x, channels=self.img_ch, kernel=7, stride=1, pad=3, pad_type='reflect', scope='g_logit')
            x = tanh(x)

            return x

    def MLP(self, style, channel, scope='MLP'):
        with tf.variable_scope(scope):
            x = style

            for i in range(2):
                x = fully_connected(x, channel, scope='FC_' + str(i))
                x = relu(x)

            mu_list = []
            var_list = []

            for i in range(4):
                mu = fully_connected(x, channel * 2, scope='FC_mu_' + str(i))
                var = fully_connected(x, channel * 2, scope='FC_var_' + str(i))

                mu = tf.reshape(mu, shape=[-1, 1, 1, channel * 2])
                var = tf.reshape(var, shape=[-1, 1, 1, channel * 2])

                mu_list.append(mu)
                var_list.append(var)

            return mu_list, var_list

    ##################################################################################
    # Discriminator
    ##################################################################################

    def discriminator(self, x_init, class_onehot, reuse=tf.AUTO_REUSE, scope="discriminator"):
        channel = self.ch
        class_onehot = tf.reshape(class_onehot, shape=[self.batch_size, 1, 1, -1])

        with tf.variable_scope(scope, reuse=reuse):
            x = conv(x_init, channel, kernel=7, stride=1, pad=3, pad_type='reflect', sn=self.sn, scope='conv')

            for i in range(4):
                x = pre_resblock(x, channel * 2, sn=self.sn, scope='front_resblock_0_' + str(i))
                x = pre_resblock(x, channel * 2, sn=self.sn, scope='front_resblock_1_' + str(i))
                x = down_sample_avg(x, scale_factor=2)

                channel = channel * 2

            for i in range(2):
                x = pre_resblock(x, channel, sn=self.sn, scope='back_resblock_' + str(i))

            x_feature = x
            x = lrelu(x, 0.2)

            x = conv(x, channels=self.class_dim, kernel=1, stride=1, sn=self.sn, scope='d_logit')
            x = tf.reduce_sum(x * class_onehot, axis=-1, keepdims=True)  # select the logit map of the target class, e.g. one-hot [1, 0, 0, 0, 0]

            return x, x_feature

    ##################################################################################
    # Model
    ##################################################################################

    def build_model(self):

        if self.phase == 'train':
            """ Input Image """
            img_data_class = Image_data(self.img_height, self.img_width, self.img_ch, self.dataset_path, self.augment_flag)
            img_data_class.preprocess()

            self.dataset_num = len(img_data_class.image_list)

            img_and_class = tf.data.Dataset.from_tensor_slices((img_data_class.image_list, img_data_class.class_list))

            gpu_device = '/gpu:0'
            img_and_class = img_and_class.apply(shuffle_and_repeat(self.dataset_num)).apply(
                map_and_batch(img_data_class.image_processing, batch_size=self.batch_size * self.gpu_num,
                              num_parallel_batches=16, drop_remainder=True)).apply(prefetch_to_device(gpu_device, None))

            img_and_class_iterator = img_and_class.make_one_shot_iterator()

            self.content_img, self.content_class = img_and_class_iterator.get_next()
            self.style_img, self.style_class = img_and_class_iterator.get_next()

            self.content_img = tf.split(self.content_img, num_or_size_splits=self.gpu_num)
            self.content_class = tf.split(self.content_class, num_or_size_splits=self.gpu_num)
            self.style_img = tf.split(self.style_img, num_or_size_splits=self.gpu_num)
            self.style_class = tf.split(self.style_class, num_or_size_splits=self.gpu_num)
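            # Content and style batches come from two get_next() calls on the same
            # shuffled iterator, so each step pairs random source/target classes;
            # the tensors are split above so each GPU tower gets its own slice.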

            self.fake_img = []

            d_adv_losses = []
            g_adv_losses = []
            g_recon_losses = []
            g_feature_losses = []

            for gpu_id in range(self.gpu_num):
                with tf.device(tf.DeviceSpec(device_type="GPU", device_index=gpu_id)):
                    with tf.variable_scope(tf.get_variable_scope(), reuse=(gpu_id > 0)):
                        """ Define Generator, Discriminator """
                        content_code = self.content_encoder(self.content_img[gpu_id])
                        style_class_code = self.class_encoder(self.style_img[gpu_id])
                        content_class_code = self.class_encoder(self.content_img[gpu_id])

                        fake_img = self.generator(content_code, style_class_code)
                        recon_img = self.generator(content_code, content_class_code)

                        real_logit, style_feature_map = self.discriminator(self.style_img[gpu_id], self.style_class[gpu_id])
                        fake_logit, fake_feature_map = self.discriminator(fake_img, self.style_class[gpu_id])

                        recon_logit, recon_feature_map = self.discriminator(recon_img, self.content_class[gpu_id])
                        _, content_feature_map = self.discriminator(self.content_img[gpu_id], self.content_class[gpu_id])

                        """ Define Loss """
                        d_adv_loss = self.adv_weight * discriminator_loss(self.gan_type, real_logit, fake_logit, self.style_img[gpu_id])
                        g_adv_loss = 0.5 * self.adv_weight * (generator_loss(self.gan_type, fake_logit) + generator_loss(self.gan_type, recon_logit))

                        g_recon_loss = self.recon_weight * L1_loss(self.content_img[gpu_id], recon_img)

                        content_feature_map = tf.reduce_mean(tf.reduce_mean(content_feature_map, axis=2), axis=1)
                        recon_feature_map = tf.reduce_mean(tf.reduce_mean(recon_feature_map, axis=2), axis=1)
                        fake_feature_map = tf.reduce_mean(tf.reduce_mean(fake_feature_map, axis=2), axis=1)
                        style_feature_map = tf.reduce_mean(tf.reduce_mean(style_feature_map, axis=2), axis=1)

                        g_feature_loss = self.feature_weight * (L1_loss(recon_feature_map, content_feature_map) + L1_loss(fake_feature_map, style_feature_map))

                        d_adv_losses.append(d_adv_loss)
                        g_adv_losses.append(g_adv_loss)
                        g_recon_losses.append(g_recon_loss)
                        g_feature_losses.append(g_feature_loss)

                        self.fake_img.append(fake_img)

            self.g_loss = tf.reduce_mean(g_adv_losses) + \
                          tf.reduce_mean(g_recon_losses) + \
                          tf.reduce_mean(g_feature_losses) + regularization_loss('encoder') + regularization_loss('generator')

            self.d_loss = tf.reduce_mean(d_adv_losses) + regularization_loss('discriminator')

            """ Training """
            t_vars = tf.trainable_variables()
            G_vars = [var for var in t_vars if 'encoder' in var.name or 'generator' in var.name]
            D_vars = [var for var in t_vars if 'discriminator' in var.name]

            if self.gpu_num == 1:
                prev_G_optim = tf.train.RMSPropOptimizer(self.lr, decay=0.99, epsilon=1e-8).minimize(self.g_loss, var_list=G_vars)
                self.D_optim = tf.train.RMSPropOptimizer(self.lr, decay=0.99, epsilon=1e-8).minimize(self.d_loss, var_list=D_vars)
                # PyTorch reference: decay=0.99, epsilon=1e-8
            else:
                prev_G_optim = tf.train.RMSPropOptimizer(self.lr, decay=0.99, epsilon=1e-8).minimize(self.g_loss, var_list=G_vars, colocate_gradients_with_ops=True)
                self.D_optim = tf.train.RMSPropOptimizer(self.lr, decay=0.99, epsilon=1e-8).minimize(self.d_loss, var_list=D_vars, colocate_gradients_with_ops=True)
                # PyTorch reference: decay=0.99, epsilon=1e-8

            self.ema = tf.train.ExponentialMovingAverage(decay=self.ema_decay)
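            # Exponential moving average of the generator/encoder weights:
            # shadow = decay * shadow + (1 - decay) * variable. test() restores
            # these averaged weights, which typically yields more stable samples.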
            with tf.control_dependencies([prev_G_optim]):
                self.G_optim = self.ema.apply(G_vars)

            """ Summary """
            self.summary_g_loss = tf.summary.scalar("g_loss", self.g_loss)
            self.summary_d_loss = tf.summary.scalar("d_loss", self.d_loss)

            self.summary_g_adv_loss = tf.summary.scalar("g_adv_loss", tf.reduce_mean(g_adv_losses))
            self.summary_g_recon_loss = tf.summary.scalar("g_recon_loss", tf.reduce_mean(g_recon_losses))
            self.summary_g_feature_loss = tf.summary.scalar("g_feature_loss", tf.reduce_mean(g_feature_losses))

            g_summary_list = [self.summary_g_loss,
                              self.summary_g_adv_loss,
                              self.summary_g_recon_loss, self.summary_g_feature_loss]

            d_summary_list = [self.summary_d_loss]

            self.summary_merge_g_loss = tf.summary.merge(g_summary_list)
            self.summary_merge_d_loss = tf.summary.merge(d_summary_list)

        else:
            """ Test """
            self.ema = tf.train.ExponentialMovingAverage(decay=self.ema_decay)
            self.test_content_img = tf.placeholder(tf.float32, [1, self.img_height, self.img_width, self.img_ch])
            self.test_class_img = tf.placeholder(tf.float32, [self.K, self.img_height, self.img_width, self.img_ch])

            test_content_code = self.content_encoder(self.test_content_img)
            test_style_class_code = tf.reduce_mean(self.class_encoder(self.test_class_img), axis=0, keepdims=True)

            self.test_fake_img = self.generator(test_content_code, test_style_class_code)

    def train(self):
        # initialize all variables
        tf.global_variables_initializer().run()

        # saver to save model
        self.saver = tf.train.Saver(max_to_keep=20)

        # summary writer
        self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_dir, self.sess.graph)

        # restore checkpoint if it exists
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        if could_load:
            start_batch_id = checkpoint_counter
            counter = checkpoint_counter
            print(" [*] Load SUCCESS")
        else:
            start_batch_id = 0
            counter = 1
            print(" [!] Load failed...")
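        # Each iteration runs one discriminator step, then one generator step;
        # the generator step also updates the EMA shadow variables through the
        # control dependency created in build_model().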

        # loop for epoch
        start_time = time.time()
        for idx in range(start_batch_id, self.iteration):

            # Update D
            _, d_loss, summary_str = self.sess.run([self.D_optim, self.d_loss, self.summary_merge_d_loss])
            self.writer.add_summary(summary_str, counter)

            # Update G
            content_images, style_images, fake_x_images, _, g_loss, summary_str = self.sess.run(
                [self.content_img[0], self.style_img[0], self.fake_img[0],
                 self.G_optim,
                 self.g_loss, self.summary_merge_g_loss])

            self.writer.add_summary(summary_str, counter)

            # display training status
            counter += 1
            print("iter: [%6d/%6d] time: %4.4f d_loss: %.8f, g_loss: %.8f" % (idx, self.iteration, time.time() - start_time, d_loss, g_loss))

            if np.mod(idx + 1, self.print_freq) == 0:
                content_images = np.expand_dims(content_images[0], axis=0)
                style_images = np.expand_dims(style_images[0], axis=0)
                fake_x_images = np.expand_dims(fake_x_images[0], axis=0)

                merge_images = np.concatenate([content_images, style_images, fake_x_images], axis=0)

                save_images(merge_images, [1, 3],
                            './{}/merge_{:07d}.jpg'.format(self.sample_dir, idx + 1))

                # save_images(content_images, [1, 1],
                #             './{}/content_{:07d}.jpg'.format(self.sample_dir, idx + 1))
                #
                # save_images(style_images, [1, 1],
                #             './{}/style_{:07d}.jpg'.format(self.sample_dir, idx + 1))
                #
                # save_images(fake_x_images, [1, 1],
                #             './{}/fake_{:07d}.jpg'.format(self.sample_dir, idx + 1))

            if np.mod(counter - 1, self.save_freq) == 0:
                self.save(self.checkpoint_dir, counter)

        # save model for final step
        self.save(self.checkpoint_dir, counter)

    @property
    def model_dir(self):
        if self.sn:
            sn = '_sn'
        else:
            sn = ''

        return "{}_{}_{}_{}adv_{}feature_{}recon{}".format(self.model_name, self.dataset_name, self.gan_type,
                                                           self.adv_weight, self.feature_weight, self.recon_weight,
                                                           sn)

    def save(self, checkpoint_dir, step):
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)

        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name + '.model'), global_step=step)

    def load(self, checkpoint_dir):
        print(" [*] Reading checkpoints...")
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
            counter = int(ckpt_name.split('-')[-1])
            print(" [*] Successfully read {}".format(ckpt_name))
            return True, counter
        else:
            print(" [*] Failed to find a checkpoint")
            return False, 0

    def test(self):
        tf.global_variables_initializer().run()

        content_images = glob('./dataset/{}/{}/{}/*.*'.format(self.dataset_name, 'test', 'content'))
        class_images = glob('./dataset/{}/{}/{}/*.*'.format(self.dataset_name, 'test', 'class'))

        t_vars = tf.trainable_variables()
        G_vars = [var for var in t_vars if 'encoder' in var.name or 'generator' in var.name]

        shadow_G_vars_dict = {}

        for g_var in G_vars:
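            # Map each EMA shadow-variable name to its live variable so the Saver
            # below restores the averaged weights into the generator/encoders.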
            shadow_G_vars_dict[self.ema.average_name(g_var)] = g_var

        self.saver = tf.train.Saver(shadow_G_vars_dict)
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        self.result_dir = os.path.join(self.result_dir, self.model_dir)
        check_folder(self.result_dir)

        if could_load:
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        # write html for visual comparison
        index_path = os.path.join(self.result_dir, 'index.html')
        index = open(index_path, 'w')
        index.write("<html><body><table><tr>")
        index.write("<th>name</th><th>content</th><th>style</th><th>output</th></tr>")

        for sample_content_image in tqdm(content_images):
            sample_image = load_test_image(sample_content_image, self.img_width, self.img_height)

            random_class_images = np.random.choice(class_images, size=self.K, replace=False)
            sample_class_image = np.concatenate([load_test_image(x, self.img_width, self.img_height) for x in random_class_images])

            fake_path = os.path.join(self.result_dir, '{}'.format(os.path.basename(sample_content_image)))
            class_path = os.path.join(self.result_dir, 'style_{}'.format(os.path.basename(sample_content_image)))

            fake_img = self.sess.run(self.test_fake_img, feed_dict={self.test_content_img: sample_image, self.test_class_img: sample_class_image})

            save_images(fake_img, [1, 1], fake_path)
            save_images(sample_class_image, [1, self.K], class_path)

            index.write("<td>%s</td>" % os.path.basename(sample_content_image))
            index.write(
                "<td><img src='%s' width='%d' height='%d'></td>" % (sample_content_image if os.path.isabs(sample_content_image) else (
                    '../..' + os.path.sep + sample_content_image), self.img_width, self.img_height))
            index.write(
                "<td><img src='%s' width='%d' height='%d'></td>" % (class_path if os.path.isabs(class_path) else (
                    '../..' + os.path.sep + class_path), self.img_width * self.K, self.img_height))
            index.write(
                "<td><img src='%s' width='%d' height='%d'></td>" % (fake_path if os.path.isabs(fake_path) else (
                    '../..' + os.path.sep + fake_path), self.img_width, self.img_height))
            index.write("</tr>")

        index.close()
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 Junho Kim

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# FUNIT-Tensorflow
## Few-Shot Unsupervised Image-to-Image Translation (ICCV 2019)

<div align="center">
  <img src="./assets/animal.gif">
</div>

### [Paper](https://arxiv.org/abs/1905.01723) | [Official Pytorch code](https://github.com/NVlabs/FUNIT)

### [Other Pytorch Implementation](https://github.com/znxlwm/FUNIT-pytorch)

## Usage
```
├── dataset
│   └── YOUR_DATASET_NAME
│       ├── train
│       │   ├── class1 (class folder)
│       │   │   ├── xxx.jpg (class1 image)
│       │   │   ├── yyy.png
│       │   │   └── ...
│       │   ├── class2
│       │   │   ├── aaa.jpg (class2 image)
│       │   │   ├── bbb.png
│       │   │   └── ...
│       │   └── class3
│       │       └── ...
│       └── test
│           ├── content (content folder)
│           │   ├── zzz.jpg (any content image)
│           │   ├── www.png
│           │   └── ...
│           └── class (class folder)
│               ├── ccc.jpg (unseen target class image)
│               ├── ddd.jpg
│               └── ...
```

### Train
```
> python main.py --dataset flower
```

### Test
```
> python main.py --dataset flower --phase test
```

## Architecture
![architecture](./assets/architecture.png)

## Our result
![our_result](./assets/our_result.png)

## Paper result
![paper_result](./assets/funit_example.jpg)

## Author
[Junho Kim](http://bit.ly/jhkim_ai)
--------------------------------------------------------------------------------
/assets/animal.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/FUNIT-Tensorflow/6a160e5690544359133fc0860cedf2a61dbdcaf9/assets/animal.gif
--------------------------------------------------------------------------------
/assets/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/FUNIT-Tensorflow/6a160e5690544359133fc0860cedf2a61dbdcaf9/assets/architecture.png
--------------------------------------------------------------------------------
/assets/funit_example.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/FUNIT-Tensorflow/6a160e5690544359133fc0860cedf2a61dbdcaf9/assets/funit_example.jpg
--------------------------------------------------------------------------------
/assets/our_result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/FUNIT-Tensorflow/6a160e5690544359133fc0860cedf2a61dbdcaf9/assets/our_result.png
--------------------------------------------------------------------------------
/assets/process.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/FUNIT-Tensorflow/6a160e5690544359133fc0860cedf2a61dbdcaf9/assets/process.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
from FUNIT import FUNIT
import argparse
from utils import *

"""parsing and configuration"""
def parse_args():
    desc = "Tensorflow implementation of FUNIT"
    parser = argparse.ArgumentParser(description=desc)
    parser.add_argument('--phase', type=str, default='train', choices=('train', 'test'), help='phase name')
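    # The dataset name must match a directory under ./dataset/ laid out as in the
    # README "Usage" section; class folders are discovered automatically.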
    parser.add_argument('--dataset', type=str, default='flower', help='dataset name')

    parser.add_argument('--iteration', type=int, default=800000, help='The number of training iterations')
    parser.add_argument('--batch_size', type=int, default=64, help='The batch size per gpu')
    parser.add_argument('--print_freq', type=int, default=1000, help='Image sample frequency (in iterations)')
    parser.add_argument('--save_freq', type=int, default=10000, help='Checkpoint save frequency (in iterations)')
    parser.add_argument('--gpu_num', type=int, default=1, help='The number of gpus')

    parser.add_argument('--lr', type=float, default=0.0001, help='The learning rate')
    parser.add_argument('--ema_decay', type=float, default=0.999, help='ema decay value')
    parser.add_argument('--K', type=int, default=5, help='The number of target class images (K-shot) at test time')

    parser.add_argument('--gan_type', type=str, default='hinge', help='[gan / lsgan / hinge]')

    parser.add_argument('--adv_weight', type=float, default=1, help='Weight for the adversarial loss')
    parser.add_argument('--feature_weight', type=float, default=1, help='Weight for the feature-matching loss')
    parser.add_argument('--recon_weight', type=float, default=0.1, help='Weight for the reconstruction loss')

    parser.add_argument('--latent_dim', type=int, default=64, help='The dimension of the class code')
    parser.add_argument('--ch', type=int, default=64, help='base channel number per layer')

    parser.add_argument('--sn', type=str2bool, default=False, help='using spectral norm')

    parser.add_argument('--img_height', type=int, default=128, help='The height of the image')
    parser.add_argument('--img_width', type=int, default=128, help='The width of the image')
    parser.add_argument('--img_ch', type=int, default=3, help='The number of image channels')
    parser.add_argument('--augment_flag', type=str2bool, default=True, help='Image augmentation use or not')

    parser.add_argument('--checkpoint_dir', type=str, default='checkpoint',
                        help='Directory name to save the checkpoints')
    parser.add_argument('--result_dir', type=str, default='results',
                        help='Directory name to save the generated images')
    parser.add_argument('--log_dir', type=str, default='logs',
                        help='Directory name to save training logs')
    parser.add_argument('--sample_dir', type=str, default='samples',
                        help='Directory name to save the samples on training')

    return check_args(parser.parse_args())

"""checking arguments"""
def check_args(args):
    # --checkpoint_dir
    check_folder(args.checkpoint_dir)

    # --result_dir
    check_folder(args.result_dir)

    # --log_dir
    check_folder(args.log_dir)

    # --sample_dir
    check_folder(args.sample_dir)

    # --batch_size
    try:
        assert args.batch_size >= 1
    except AssertionError:
        print('batch size must be larger than or equal to one')
    return args

"""main"""
def main():
    # parse arguments
    args = parse_args()
    if args is None:
        exit()

    # open session
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        gan = FUNIT(sess, args)

        # build graph
        gan.build_model()

        # show network architecture
        show_all_variables()

        if args.phase == 'train':
            gan.train()
            print(" [*] Training finished!")

        if args.phase == 'test':
            gan.test()
            print(" [*] Test finished!")


if __name__ == '__main__':
    main()
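# Example invocations (the dataset name "animal_faces" is hypothetical):
#   python main.py --dataset animal_faces --gpu_num 2 --batch_size 32
#   python main.py --dataset animal_faces --phase test --K 5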
--------------------------------------------------------------------------------
/ops.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import tensorflow.contrib as tf_contrib
from utils import pytorch_kaiming_weight_factor

##################################################################################
# Initialization
##################################################################################

factor, mode, uniform = pytorch_kaiming_weight_factor(a=0.0, uniform=False)
weight_init = tf_contrib.layers.variance_scaling_initializer(factor=factor, mode=mode, uniform=uniform)
# weight_init = tf.random_normal_initializer(mean=0.0, stddev=0.02)

weight_regularizer = tf.contrib.layers.l2_regularizer(0.0001)
weight_regularizer_fully = tf.contrib.layers.l2_regularizer(0.0001)

##################################################################################
# Layer
##################################################################################

def conv(x, channels, kernel=4, stride=2, pad=0, pad_type='zero', use_bias=True, sn=False, scope='conv_0'):
    with tf.variable_scope(scope):
        if pad > 0:
            h = x.get_shape().as_list()[1]
            if h % stride == 0:
                pad = pad * 2
            else:
                pad = max(kernel - (h % stride), 0)

            pad_top = pad // 2
            pad_bottom = pad - pad_top
            pad_left = pad // 2
            pad_right = pad - pad_left

            if pad_type == 'zero':
                x = tf.pad(x, [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]])
            if pad_type == 'reflect':
                x = tf.pad(x, [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]], mode='REFLECT')

        if sn:
            w = tf.get_variable("kernel", shape=[kernel, kernel, x.get_shape()[-1], channels], initializer=weight_init,
                                regularizer=weight_regularizer)
            x = tf.nn.conv2d(input=x, filter=spectral_norm(w),
                             strides=[1, stride, stride, 1], padding='VALID')
            if use_bias:
                bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0))
                x = tf.nn.bias_add(x, bias)
        else:
            x = tf.layers.conv2d(inputs=x, filters=channels,
                                 kernel_size=kernel, kernel_initializer=weight_init,
                                 kernel_regularizer=weight_regularizer,
                                 strides=stride, use_bias=use_bias)

        return x


def fully_connected(x, units, use_bias=True, sn=False, scope='linear'):
    with tf.variable_scope(scope):
        x = flatten(x)
        shape = x.get_shape().as_list()
        channels = shape[-1]

        if sn:
            w = tf.get_variable("kernel", [channels, units], tf.float32,
                                initializer=weight_init, regularizer=weight_regularizer_fully)
            if use_bias:
                bias = tf.get_variable("bias", [units],
                                       initializer=tf.constant_initializer(0.0))

                x = tf.matmul(x, spectral_norm(w)) + bias
            else:
                x = tf.matmul(x, spectral_norm(w))
        else:
            x = tf.layers.dense(x, units=units, kernel_initializer=weight_init,
                                kernel_regularizer=weight_regularizer_fully,
                                use_bias=use_bias)

        return x


def flatten(x):
    return tf.layers.flatten(x)


##################################################################################
# Residual-block
##################################################################################

def resblock(x_init, channels, use_bias=True, sn=False, scope='resblock'):
    with tf.variable_scope(scope):
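        # Standard residual block: two 3x3 conv + instance-norm layers, with the
        # input added back at the end (out = x_init + F(x_init)).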
        with tf.variable_scope('res1'):
            x = conv(x_init, channels, kernel=3, stride=1, pad=1, use_bias=use_bias, sn=sn)
            x = instance_norm(x)
            x = relu(x)

        with tf.variable_scope('res2'):
            x = conv(x, channels, kernel=3, stride=1, pad=1, use_bias=use_bias, sn=sn)
            x = instance_norm(x)

        return x + x_init

def pre_resblock(x_init, channels, use_bias=True, sn=False, scope='resblock'):
    with tf.variable_scope(scope):
        _, _, _, init_channel = x_init.get_shape().as_list()

        with tf.variable_scope('res1'):
            x = lrelu(x_init, 0.2)
            x = conv(x, channels, kernel=3, stride=1, pad=1, use_bias=use_bias, sn=sn)

        with tf.variable_scope('res2'):
            x = lrelu(x, 0.2)
            x = conv(x, channels, kernel=3, stride=1, pad=1, use_bias=use_bias, sn=sn)

        if init_channel != channels:
            with tf.variable_scope('shortcut'):
                x_init = conv(x_init, channels, kernel=1, stride=1, use_bias=False, sn=sn)

        return x + x_init

def adaptive_resblock(x_init, channels, gamma1, beta1, gamma2, beta2, use_bias=True, sn=False, scope='adaptive_resblock'):
    with tf.variable_scope(scope):
        with tf.variable_scope('res1'):
            x = conv(x_init, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias, sn=sn)
            x = adaptive_instance_norm(x, gamma1, beta1)
            x = relu(x)

        with tf.variable_scope('res2'):
            x = conv(x, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias, sn=sn)
            x = adaptive_instance_norm(x, gamma2, beta2)

        return x + x_init


##################################################################################
# Sampling
##################################################################################

def up_sample(x, scale_factor=2):
    _, h, w, _ = x.get_shape().as_list()
    new_size = [h * scale_factor, w * scale_factor]
    return tf.image.resize_nearest_neighbor(x, size=new_size)

def down_sample_avg(x, scale_factor=2):
    return tf.layers.average_pooling2d(x, pool_size=3, strides=scale_factor, padding='SAME')

def global_avg_pooling(x):
    gap = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
    return gap


##################################################################################
# Activation function
##################################################################################

def lrelu(x, alpha=0.01):
    # pytorch default alpha is 0.01
    return tf.nn.leaky_relu(x, alpha)


def relu(x):
    return tf.nn.relu(x)


def tanh(x):
    return tf.tanh(x)


##################################################################################
# Normalization function
##################################################################################

def instance_norm(x, scope='instance_norm'):
    return tf_contrib.layers.instance_norm(x,
                                           epsilon=1e-05,
                                           center=True, scale=True,
                                           scope=scope)

def param_free_norm(x, epsilon=1e-5):
    x_mean, x_var = tf.nn.moments(x, axes=[1, 2], keep_dims=True)
    x_std = tf.sqrt(x_var + epsilon)

    return (x - x_mean) / x_std

def adaptive_instance_norm(content, gamma, beta, epsilon=1e-5):
    # gamma, beta = style_mean, style_std from MLP

    x = param_free_norm(content, epsilon)
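    # AdaIN: normalize the per-channel statistics of the content feature map,
    # then rescale with (gamma, beta) predicted from the class code by the MLP,
    # i.e. out = gamma * (x - mu(x)) / sigma(x) + beta.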
    return gamma * x + beta

def spectral_norm(w, iteration=1):
    w_shape = w.shape.as_list()
    w = tf.reshape(w, [-1, w_shape[-1]])

    u = tf.get_variable("u", [1, w_shape[-1]], initializer=tf.random_normal_initializer(), trainable=False)

    u_hat = u
    v_hat = None
    for i in range(iteration):
        """
        power iteration
        Usually iteration = 1 will be enough
        """
        v_ = tf.matmul(u_hat, tf.transpose(w))
        v_hat = tf.nn.l2_normalize(v_)

        u_ = tf.matmul(v_hat, w)
        u_hat = tf.nn.l2_normalize(u_)

    u_hat = tf.stop_gradient(u_hat)
    v_hat = tf.stop_gradient(v_hat)

    sigma = tf.matmul(tf.matmul(v_hat, w), tf.transpose(u_hat))

    with tf.control_dependencies([u.assign(u_hat)]):
        w_norm = w / sigma
        w_norm = tf.reshape(w_norm, w_shape)

    return w_norm


##################################################################################
# Loss function
##################################################################################

def L1_loss(x, y):
    loss = tf.reduce_mean(tf.abs(x - y))  # x, y: [batch, h, w, c]

    return loss

def discriminator_loss(gan_type, real_logit, fake_logit, real_images):
    real_loss = 0
    fake_loss = 0

    if gan_type == 'lsgan':
        real_loss = tf.reduce_mean(tf.squared_difference(real_logit, 1.0))
        fake_loss = tf.reduce_mean(tf.square(fake_logit))

    if gan_type == 'gan':
        real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(real_logit), logits=real_logit))
        fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(fake_logit), logits=fake_logit))

    if gan_type == 'hinge':
        real_loss = tf.reduce_mean(relu(1 - real_logit))
        fake_loss = tf.reduce_mean(relu(1 + fake_logit))

    return real_loss + fake_loss + real_gp(real_images, real_logit)

def real_gp(real_images, real_logit):
    # R1 gradient penalty on real images
    grad_out = tf.gradients(tf.reduce_mean(real_logit), [real_images])[0]
    grad_out2 = tf.square(grad_out)

    r1_penalty = 10 * tf.reduce_mean(tf.reduce_sum(grad_out2, axis=[1, 2, 3]))

    return r1_penalty

def generator_loss(gan_type, fake_logit):
    fake_loss = 0

    if gan_type == 'lsgan':
        fake_loss = tf.reduce_mean(tf.squared_difference(fake_logit, 1.0))

    if gan_type == 'gan':
        fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(fake_logit), logits=fake_logit))

    if gan_type == 'hinge':
        fake_loss = -tf.reduce_mean(fake_logit)

    return fake_loss


def regularization_loss(scope_name):
    """
    If you want to use "Regularization"
    g_loss += regularization_loss('generator')
    d_loss += regularization_loss('discriminator')
    """
    collection_regularization = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)

    loss = []
    for item in collection_regularization:
        if scope_name in item.name:
            loss.append(item)

    return tf.reduce_sum(loss)
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import tensorflow as tf
from tensorflow.contrib import slim
import os
import numpy as np
from glob import glob
import cv2
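# Pipeline overview: Image_data.preprocess() collects (path, one-hot class) pairs
# from <dataset_path>/<class>/, and image_processing() decodes, resizes, and
# rescales each image to [-1, 1], with optional flip-and-crop augmentation.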

class Image_data:

    def __init__(self, img_height, img_width, channels, dataset_path, augment_flag):
        self.img_height = img_height
        self.img_width = img_width
        self.channels = channels
        self.augment_flag = augment_flag

        self.dataset_path = dataset_path

        self.image_list = []
        self.class_list = []

    def image_processing(self, filename, label):
        x = tf.read_file(filename)
        x_decode = tf.image.decode_jpeg(x, channels=self.channels, dct_method='INTEGER_ACCURATE')
        img = preprocess_fit_train_image(x_decode, self.img_height, self.img_width)

        if self.augment_flag:
            augment_height_size = self.img_height + (30 if self.img_height == 256 else int(self.img_height * 0.1))
            augment_width_size = self.img_width + (30 if self.img_width == 256 else int(self.img_width * 0.1))

            img = tf.cond(pred=tf.greater_equal(tf.random_uniform(shape=[], minval=0.0, maxval=1.0), 0.5),
                          true_fn=lambda: augmentation(img, augment_height_size, augment_width_size),
                          false_fn=lambda: img)

        return img, label

    def preprocess(self):
        self.class_label = [os.path.basename(x) for x in glob(self.dataset_path + '/*')]

        v = 0

        for class_label in self.class_label:
            class_one_hot = list(get_one_hot(v, len(self.class_label)))  # e.g. [1, 0, 0, 0, 0]
            v = v + 1

            image_list = glob(os.path.join(self.dataset_path, class_label) + '/*.png') + glob(os.path.join(self.dataset_path, class_label) + '/*.jpg')
            class_one_hot = [class_one_hot] * len(image_list)

            self.image_list.extend(image_list)
            self.class_list.extend(class_one_hot)

def load_test_image(image_path, img_width, img_height):

    img = cv2.imread(image_path, flags=cv2.IMREAD_COLOR)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    img = cv2.resize(img, dsize=(img_width, img_height))
    img = np.expand_dims(img, axis=0)

    img = adjust_dynamic_range(img)

    return img


def preprocessing(x):
    x = x / 127.5 - 1  # -1 ~ 1
    return x

def preprocess_fit_train_image(images, height, width):
    images = tf.image.resize(images, size=[height, width], method=tf.image.ResizeMethod.BILINEAR)
    images = adjust_dynamic_range(images)

    return images

def adjust_dynamic_range(images):
    drange_in = [0.0, 255.0]
    drange_out = [-1.0, 1.0]
    scale = (drange_out[1] - drange_out[0]) / (drange_in[1] - drange_in[0])
    bias = drange_out[0] - drange_in[0] * scale
    images = images * scale + bias
    return images

def augmentation(image, augment_height, augment_width):
    seed = np.random.randint(0, 2 ** 31 - 1)

    ori_image_shape = tf.shape(image)
    image = tf.image.random_flip_left_right(image, seed=seed)
    image = tf.image.resize(image, size=[augment_height, augment_width], method=tf.image.ResizeMethod.BILINEAR)
    image = tf.random_crop(image, ori_image_shape, seed=seed)

    return image

def save_images(images, size, image_path):
    return imsave(images, size, image_path)

def imsave(images, size, path):
    images = merge(images, size)
    images = post_process_generator_output(images)
    images = cv2.cvtColor(images.astype('uint8'), cv2.COLOR_RGB2BGR)
    cv2.imwrite(path, images)

def post_process_generator_output(generator_output):

    drange_min, drange_max = -1.0, 1.0
    scale = 255.0 / (drange_max - drange_min)
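    # Inverse of adjust_dynamic_range: maps [-1, 1] back to [0, 255]; the extra
    # 0.5 in the bias makes the later uint8 cast round instead of truncate.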
    scaled_image = generator_output * scale + (0.5 - drange_min * scale)
    scaled_image = np.clip(scaled_image, 0, 255)

    return scaled_image

def merge(images, size):
    h, w = images.shape[1], images.shape[2]
    c = images.shape[3]
    img = np.zeros((h * size[0], w * size[1], c))
    for idx, image in enumerate(images):
        i = idx % size[1]
        j = idx // size[1]
        img[h*j:h*(j+1), w*i:w*(i+1), :] = image

    return img

def show_all_variables():
    model_vars = tf.trainable_variables()
    slim.model_analyzer.analyze_vars(model_vars, print_info=True)

def check_folder(log_dir):
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    return log_dir

def str2bool(x):
    return x.lower() in ('true',)

def get_one_hot(targets, nb_classes):

    x = np.eye(nb_classes)[targets]

    return x

def pytorch_xavier_weight_factor(gain=0.02, uniform=False):

    if uniform:
        factor = gain * gain
        mode = 'FAN_AVG'
    else:
        factor = (gain * gain) / 1.3
        mode = 'FAN_AVG'

    return factor, mode, uniform

def pytorch_kaiming_weight_factor(a=0.0, activation_function='leaky_relu', uniform=False):

    if activation_function == 'relu':
        gain = np.sqrt(2.0)
    elif activation_function == 'leaky_relu':
        gain = np.sqrt(2.0 / (1 + a ** 2))
    elif activation_function == 'tanh':
        gain = 5.0 / 3
    else:
        gain = 1.0

    if uniform:
        factor = gain * gain
        mode = 'FAN_IN'
    else:
        factor = (gain * gain) / 1.3
        mode = 'FAN_IN'

    return factor, mode, uniform
--------------------------------------------------------------------------------