├── .gitignore
├── LICENSE
├── README.md
├── UGATIT.py
├── assets
│   ├── ablation.png
│   ├── discriminator_fix.png
│   ├── generator_fix.png
│   ├── kid_fix2.png
│   ├── teaser.png
│   └── user_study.png
├── main.py
├── ops.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Junho Kim
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## U-GAT-IT — Official TensorFlow Implementation (ICLR 2020) 2 | ### : Unsupervised Generative Attentional Networks with Adaptive Layer-Instance Normalization for Image-to-Image Translation 3 | 4 |
5 | 6 |
7 |
8 | ### [Paper](https://arxiv.org/abs/1907.10830) | [Official PyTorch code](https://github.com/znxlwm/UGATIT-pytorch)
9 | This repository provides the **official TensorFlow implementation** of the following paper:
10 |
11 | > **U-GAT-IT: Unsupervised Generative Attentional Networks with Adaptive Layer-Instance Normalization for Image-to-Image Translation**<br>
12 | > **Junho Kim (NCSOFT)**, Minjae Kim (NCSOFT), Hyeonwoo Kang (NCSOFT), Kwanghee Lee (Boeing Korea)
13 | >
14 | > **Abstract** *We propose a novel method for unsupervised image-to-image translation, which incorporates a new attention module and a new learnable normalization function in an end-to-end manner. The attention module guides our model to focus on more important regions distinguishing between source and target domains based on the attention map obtained by the auxiliary classifier. Unlike previous attention-based methods which cannot handle the geometric changes between domains, our model can translate both images requiring holistic changes and images requiring large shape changes. Moreover, our new AdaLIN (Adaptive Layer-Instance Normalization) function helps our attention-guided model to flexibly control the amount of change in shape and texture by learned parameters depending on datasets. Experimental results show the superiority of the proposed method compared to the existing state-of-the-art models with a fixed network architecture and hyper-parameters.*
15 |
16 | ## Requirements
17 | * python == 3.6
18 | * tensorflow == 1.14
19 |
20 | ## Pretrained model
21 | > We release 50-epoch and 100-epoch checkpoints so that people can test more widely.
22 | * [selfie2anime checkpoint (50 epoch)](https://drive.google.com/file/d/1V6GbSItG3HZKv3quYs7AP0rr1kOCT3QO/view?usp=sharing)
23 | * [selfie2anime checkpoint (100 epoch)](https://drive.google.com/file/d/19xQK2onIy-3S5W5K-XIh85pAg_RNvBVf/view?usp=sharing)
24 |
25 | ## Dataset
26 | * [selfie2anime dataset](https://drive.google.com/file/d/1xOWj1UVgp6NKMT3HbPhBbtq2A4EDkghF/view?usp=sharing)
27 |
28 | ## Web page
29 | * [Selfie2Anime](https://selfie2anime.com) by [Nathan Glover](https://github.com/t04glovern)
30 | * [Selfie2Waifu](https://waifu.lofiu.com) by [creke](https://github.com/creke)
31 |
32 | ## Telegram Bot
33 | * [Selfie2AnimeBot](https://t.me/selfie2animebot) by [Alex Spirin](https://github.com/sxela)
34 |
35 | ## Usage
36 | ```
37 | ├── dataset
38 |    └── YOUR_DATASET_NAME
39 |        ├── trainA
40 |            ├── xxx.jpg (name, format doesn't matter)
41 |            ├── yyy.png
42 |            └── ...
43 |        ├── trainB
44 |            ├── zzz.jpg
45 |            ├── www.png
46 |            └── ...
47 |        ├── testA
48 |            ├── aaa.jpg
49 |            ├── bbb.png
50 |            └── ...
51 |        └── testB
52 |            ├── ccc.jpg
53 |            ├── ddd.png
54 |            └── ...
55 | ```
56 |
57 | ### Train
58 | ```
59 | > python main.py --dataset selfie2anime
60 | ```
61 | * If GPU memory is **not sufficient**, set `--light` to **True**
62 | * The light version may **not** perform as well
63 | * The paper's results were obtained with `--light` set to **False**
64 |
65 | ### Test
66 | ```
67 | > python main.py --dataset selfie2anime --phase test
68 | ```
69 |
70 | ## Architecture
71 | <div align="center">
72 | 73 |
74 | 75 | --- 76 | 77 |
78 | 79 |
80 | 81 | ## Results 82 | ### Ablation study 83 |
84 | 85 |
86 | 87 | ### User study 88 |
89 | 90 |
91 | 92 | ### Kernel Inception Distance (KID) 93 |
94 | 95 |
96 | 97 | ## Citation 98 | If you find this code useful for your research, please cite our paper: 99 | 100 | ``` 101 | @inproceedings{ 102 | Kim2020U-GAT-IT:, 103 | title={U-GAT-IT: Unsupervised Generative Attentional Networks with Adaptive Layer-Instance Normalization for Image-to-Image Translation}, 104 | author={Junho Kim and Minjae Kim and Hyeonwoo Kang and Kwang Hee Lee}, 105 | booktitle={International Conference on Learning Representations}, 106 | year={2020}, 107 | url={https://openreview.net/forum?id=BJlZ5ySKPH} 108 | } 109 | ``` 110 | 111 | ## Author 112 | [Junho Kim](http://bit.ly/jhkim_ai), Minjae Kim, Hyeonwoo Kang, Kwanghee Lee 113 | -------------------------------------------------------------------------------- /UGATIT.py: -------------------------------------------------------------------------------- 1 | from ops import * 2 | from utils import * 3 | from glob import glob 4 | import time 5 | from tensorflow.contrib.data import prefetch_to_device, shuffle_and_repeat, map_and_batch 6 | import numpy as np 7 | 8 | class UGATIT(object) : 9 | def __init__(self, sess, args): 10 | self.light = args.light 11 | 12 | if self.light : 13 | self.model_name = 'UGATIT_light' 14 | else : 15 | self.model_name = 'UGATIT' 16 | 17 | self.sess = sess 18 | self.phase = args.phase 19 | self.checkpoint_dir = args.checkpoint_dir 20 | self.result_dir = args.result_dir 21 | self.log_dir = args.log_dir 22 | self.dataset_name = args.dataset 23 | self.augment_flag = args.augment_flag 24 | 25 | self.epoch = args.epoch 26 | self.iteration = args.iteration 27 | self.decay_flag = args.decay_flag 28 | self.decay_epoch = args.decay_epoch 29 | 30 | self.gan_type = args.gan_type 31 | 32 | self.batch_size = args.batch_size 33 | self.print_freq = args.print_freq 34 | self.save_freq = args.save_freq 35 | 36 | self.init_lr = args.lr 37 | self.ch = args.ch 38 | 39 | """ Weight """ 40 | self.adv_weight = args.adv_weight 41 | self.cycle_weight = args.cycle_weight 42 | self.identity_weight = args.identity_weight 43 | self.cam_weight = args.cam_weight 44 | self.ld = args.GP_ld 45 | self.smoothing = args.smoothing 46 | 47 | """ Generator """ 48 | self.n_res = args.n_res 49 | 50 | """ Discriminator """ 51 | self.n_dis = args.n_dis 52 | self.n_critic = args.n_critic 53 | self.sn = args.sn 54 | 55 | self.img_size = args.img_size 56 | self.img_ch = args.img_ch 57 | 58 | 59 | self.sample_dir = os.path.join(args.sample_dir, self.model_dir) 60 | check_folder(self.sample_dir) 61 | 62 | # self.trainA, self.trainB = prepare_data(dataset_name=self.dataset_name, size=self.img_size 63 | self.trainA_dataset = glob('./dataset/{}/*.*'.format(self.dataset_name + '/trainA')) 64 | self.trainB_dataset = glob('./dataset/{}/*.*'.format(self.dataset_name + '/trainB')) 65 | self.dataset_num = max(len(self.trainA_dataset), len(self.trainB_dataset)) 66 | 67 | print() 68 | 69 | print("##### Information #####") 70 | print("# light : ", self.light) 71 | print("# gan type : ", self.gan_type) 72 | print("# dataset : ", self.dataset_name) 73 | print("# max dataset number : ", self.dataset_num) 74 | print("# batch_size : ", self.batch_size) 75 | print("# epoch : ", self.epoch) 76 | print("# iteration per epoch : ", self.iteration) 77 | print("# smoothing : ", self.smoothing) 78 | 79 | print() 80 | 81 | print("##### Generator #####") 82 | print("# residual blocks : ", self.n_res) 83 | 84 | print() 85 | 86 | print("##### Discriminator #####") 87 | print("# discriminator layer : ", self.n_dis) 88 | print("# the number of critic : ", self.n_critic) 89 | 
print("# spectral normalization : ", self.sn) 90 | 91 | print() 92 | 93 | print("##### Weight #####") 94 | print("# adv_weight : ", self.adv_weight) 95 | print("# cycle_weight : ", self.cycle_weight) 96 | print("# identity_weight : ", self.identity_weight) 97 | print("# cam_weight : ", self.cam_weight) 98 | 99 | ################################################################################## 100 | # Generator 101 | ################################################################################## 102 | 103 | def generator(self, x_init, reuse=False, scope="generator"): 104 | channel = self.ch 105 | with tf.variable_scope(scope, reuse=reuse) : 106 | x = conv(x_init, channel, kernel=7, stride=1, pad=3, pad_type='reflect', scope='conv') 107 | x = instance_norm(x, scope='ins_norm') 108 | x = relu(x) 109 | 110 | # Down-Sampling 111 | for i in range(2) : 112 | x = conv(x, channel*2, kernel=3, stride=2, pad=1, pad_type='reflect', scope='conv_'+str(i)) 113 | x = instance_norm(x, scope='ins_norm_'+str(i)) 114 | x = relu(x) 115 | 116 | channel = channel * 2 117 | 118 | # Down-Sampling Bottleneck 119 | for i in range(self.n_res): 120 | x = resblock(x, channel, scope='resblock_' + str(i)) 121 | 122 | 123 | # Class Activation Map 124 | cam_x = global_avg_pooling(x) 125 | cam_gap_logit, cam_x_weight = fully_connected_with_w(cam_x, scope='CAM_logit') 126 | x_gap = tf.multiply(x, cam_x_weight) 127 | 128 | cam_x = global_max_pooling(x) 129 | cam_gmp_logit, cam_x_weight = fully_connected_with_w(cam_x, reuse=True, scope='CAM_logit') 130 | x_gmp = tf.multiply(x, cam_x_weight) 131 | 132 | 133 | cam_logit = tf.concat([cam_gap_logit, cam_gmp_logit], axis=-1) 134 | x = tf.concat([x_gap, x_gmp], axis=-1) 135 | 136 | x = conv(x, channel, kernel=1, stride=1, scope='conv_1x1') 137 | x = relu(x) 138 | 139 | heatmap = tf.squeeze(tf.reduce_sum(x, axis=-1)) 140 | 141 | # Gamma, Beta block 142 | gamma, beta = self.MLP(x, reuse=reuse) 143 | 144 | # Up-Sampling Bottleneck 145 | for i in range(self.n_res): 146 | x = adaptive_ins_layer_resblock(x, channel, gamma, beta, smoothing=self.smoothing, scope='adaptive_resblock' + str(i)) 147 | 148 | # Up-Sampling 149 | for i in range(2) : 150 | x = up_sample(x, scale_factor=2) 151 | x = conv(x, channel//2, kernel=3, stride=1, pad=1, pad_type='reflect', scope='up_conv_'+str(i)) 152 | x = layer_instance_norm(x, scope='layer_ins_norm_'+str(i)) 153 | x = relu(x) 154 | 155 | channel = channel // 2 156 | 157 | 158 | x = conv(x, channels=3, kernel=7, stride=1, pad=3, pad_type='reflect', scope='G_logit') 159 | x = tanh(x) 160 | 161 | return x, cam_logit, heatmap 162 | 163 | def MLP(self, x, use_bias=True, reuse=False, scope='MLP'): 164 | channel = self.ch * self.n_res 165 | 166 | if self.light : 167 | x = global_avg_pooling(x) 168 | 169 | with tf.variable_scope(scope, reuse=reuse): 170 | for i in range(2) : 171 | x = fully_connected(x, channel, use_bias, scope='linear_' + str(i)) 172 | x = relu(x) 173 | 174 | 175 | gamma = fully_connected(x, channel, use_bias, scope='gamma') 176 | beta = fully_connected(x, channel, use_bias, scope='beta') 177 | 178 | gamma = tf.reshape(gamma, shape=[self.batch_size, 1, 1, channel]) 179 | beta = tf.reshape(beta, shape=[self.batch_size, 1, 1, channel]) 180 | 181 | return gamma, beta 182 | 183 | ################################################################################## 184 | # Discriminator 185 | ################################################################################## 186 | 187 | def discriminator(self, x_init, reuse=False, 
scope="discriminator"): 188 | D_logit = [] 189 | D_CAM_logit = [] 190 | with tf.variable_scope(scope, reuse=reuse) : 191 | local_x, local_cam, local_heatmap = self.discriminator_local(x_init, reuse=reuse, scope='local') 192 | global_x, global_cam, global_heatmap = self.discriminator_global(x_init, reuse=reuse, scope='global') 193 | 194 | D_logit.extend([local_x, global_x]) 195 | D_CAM_logit.extend([local_cam, global_cam]) 196 | 197 | return D_logit, D_CAM_logit, local_heatmap, global_heatmap 198 | 199 | def discriminator_global(self, x_init, reuse=False, scope='discriminator_global'): 200 | with tf.variable_scope(scope, reuse=reuse): 201 | channel = self.ch 202 | x = conv(x_init, channel, kernel=4, stride=2, pad=1, pad_type='reflect', sn=self.sn, scope='conv_0') 203 | x = lrelu(x, 0.2) 204 | 205 | for i in range(1, self.n_dis - 1): 206 | x = conv(x, channel * 2, kernel=4, stride=2, pad=1, pad_type='reflect', sn=self.sn, scope='conv_' + str(i)) 207 | x = lrelu(x, 0.2) 208 | 209 | channel = channel * 2 210 | 211 | x = conv(x, channel * 2, kernel=4, stride=1, pad=1, pad_type='reflect', sn=self.sn, scope='conv_last') 212 | x = lrelu(x, 0.2) 213 | 214 | channel = channel * 2 215 | 216 | cam_x = global_avg_pooling(x) 217 | cam_gap_logit, cam_x_weight = fully_connected_with_w(cam_x, sn=self.sn, scope='CAM_logit') 218 | x_gap = tf.multiply(x, cam_x_weight) 219 | 220 | cam_x = global_max_pooling(x) 221 | cam_gmp_logit, cam_x_weight = fully_connected_with_w(cam_x, sn=self.sn, reuse=True, scope='CAM_logit') 222 | x_gmp = tf.multiply(x, cam_x_weight) 223 | 224 | cam_logit = tf.concat([cam_gap_logit, cam_gmp_logit], axis=-1) 225 | x = tf.concat([x_gap, x_gmp], axis=-1) 226 | 227 | x = conv(x, channel, kernel=1, stride=1, scope='conv_1x1') 228 | x = lrelu(x, 0.2) 229 | 230 | heatmap = tf.squeeze(tf.reduce_sum(x, axis=-1)) 231 | 232 | 233 | x = conv(x, channels=1, kernel=4, stride=1, pad=1, pad_type='reflect', sn=self.sn, scope='D_logit') 234 | 235 | return x, cam_logit, heatmap 236 | 237 | def discriminator_local(self, x_init, reuse=False, scope='discriminator_local'): 238 | with tf.variable_scope(scope, reuse=reuse) : 239 | channel = self.ch 240 | x = conv(x_init, channel, kernel=4, stride=2, pad=1, pad_type='reflect', sn=self.sn, scope='conv_0') 241 | x = lrelu(x, 0.2) 242 | 243 | for i in range(1, self.n_dis - 2 - 1): 244 | x = conv(x, channel * 2, kernel=4, stride=2, pad=1, pad_type='reflect', sn=self.sn, scope='conv_' + str(i)) 245 | x = lrelu(x, 0.2) 246 | 247 | channel = channel * 2 248 | 249 | x = conv(x, channel * 2, kernel=4, stride=1, pad=1, pad_type='reflect', sn=self.sn, scope='conv_last') 250 | x = lrelu(x, 0.2) 251 | 252 | channel = channel * 2 253 | 254 | cam_x = global_avg_pooling(x) 255 | cam_gap_logit, cam_x_weight = fully_connected_with_w(cam_x, sn=self.sn, scope='CAM_logit') 256 | x_gap = tf.multiply(x, cam_x_weight) 257 | 258 | cam_x = global_max_pooling(x) 259 | cam_gmp_logit, cam_x_weight = fully_connected_with_w(cam_x, sn=self.sn, reuse=True, scope='CAM_logit') 260 | x_gmp = tf.multiply(x, cam_x_weight) 261 | 262 | cam_logit = tf.concat([cam_gap_logit, cam_gmp_logit], axis=-1) 263 | x = tf.concat([x_gap, x_gmp], axis=-1) 264 | 265 | x = conv(x, channel, kernel=1, stride=1, scope='conv_1x1') 266 | x = lrelu(x, 0.2) 267 | 268 | heatmap = tf.squeeze(tf.reduce_sum(x, axis=-1)) 269 | 270 | x = conv(x, channels=1, kernel=4, stride=1, pad=1, pad_type='reflect', sn=self.sn, scope='D_logit') 271 | 272 | return x, cam_logit, heatmap 273 | 274 | 
##################################################################################
275 | # Model
276 | ##################################################################################
277 |
278 | def generate_a2b(self, x_A, reuse=False):
279 | out, cam, _ = self.generator(x_A, reuse=reuse, scope="generator_B")
280 |
281 | return out, cam
282 |
283 | def generate_b2a(self, x_B, reuse=False):
284 | out, cam, _ = self.generator(x_B, reuse=reuse, scope="generator_A")
285 |
286 | return out, cam
287 |
288 | def discriminate_real(self, x_A, x_B):
289 | real_A_logit, real_A_cam_logit, _, _ = self.discriminator(x_A, scope="discriminator_A")
290 | real_B_logit, real_B_cam_logit, _, _ = self.discriminator(x_B, scope="discriminator_B")
291 |
292 | return real_A_logit, real_A_cam_logit, real_B_logit, real_B_cam_logit
293 |
294 | def discriminate_fake(self, x_ba, x_ab):
295 | fake_A_logit, fake_A_cam_logit, _, _ = self.discriminator(x_ba, reuse=True, scope="discriminator_A")
296 | fake_B_logit, fake_B_cam_logit, _, _ = self.discriminator(x_ab, reuse=True, scope="discriminator_B")
297 |
298 | return fake_A_logit, fake_A_cam_logit, fake_B_logit, fake_B_cam_logit
299 |
300 | def gradient_penalty(self, real, fake, scope="discriminator_A"):
301 | if self.gan_type.__contains__('dragan'):
302 | eps = tf.random_uniform(shape=tf.shape(real), minval=0., maxval=1.)
303 | _, x_var = tf.nn.moments(real, axes=[0, 1, 2, 3])
304 | x_std = tf.sqrt(x_var) # magnitude of noise decides the size of local region
305 |
306 | fake = real + 0.5 * x_std * eps
307 |
308 | alpha = tf.random_uniform(shape=[self.batch_size, 1, 1, 1], minval=0., maxval=1.)
309 | interpolated = real + alpha * (fake - real)
310 |
311 | logit, cam_logit, _, _ = self.discriminator(interpolated, reuse=True, scope=scope)
312 |
313 |
314 | GP = []
315 | cam_GP = []
316 |
317 | for i in range(2) :
318 | grad = tf.gradients(logit[i], interpolated)[0] # gradient of D(interpolated)
319 | grad_norm = tf.norm(flatten(grad), axis=1) # l2 norm
320 |
321 | # WGAN - LP
322 | if self.gan_type == 'wgan-lp' :
323 | GP.append(self.ld * tf.reduce_mean(tf.square(tf.maximum(0.0, grad_norm - 1.))))
324 |
325 | elif self.gan_type == 'wgan-gp' or self.gan_type == 'dragan':
326 | GP.append(self.ld * tf.reduce_mean(tf.square(grad_norm - 1.)))
327 |
328 | for i in range(2) :
329 | grad = tf.gradients(cam_logit[i], interpolated)[0] # gradient of D(interpolated)
330 | grad_norm = tf.norm(flatten(grad), axis=1) # l2 norm
331 |
332 | # WGAN - LP
333 | if self.gan_type == 'wgan-lp' :
334 | cam_GP.append(self.ld * tf.reduce_mean(tf.square(tf.maximum(0.0, grad_norm - 1.))))
335 |
336 | elif self.gan_type == 'wgan-gp' or self.gan_type == 'dragan':
337 | cam_GP.append(self.ld * tf.reduce_mean(tf.square(grad_norm - 1.)))
338 |
339 |
340 | return sum(GP), sum(cam_GP)
341 |
342 | def build_model(self):
343 | if self.phase == 'train' :
344 | self.lr = tf.placeholder(tf.float32, name='learning_rate')
345 |
346 |
347 | """ Input Image """
348 | Image_Data_Class = ImageData(self.img_size, self.img_ch, self.augment_flag)
349 |
350 | trainA = tf.data.Dataset.from_tensor_slices(self.trainA_dataset)
351 | trainB = tf.data.Dataset.from_tensor_slices(self.trainB_dataset)
352 |
353 |
354 | gpu_device = '/gpu:0'
355 | trainA = trainA.apply(shuffle_and_repeat(self.dataset_num)).apply(map_and_batch(Image_Data_Class.image_processing, self.batch_size, num_parallel_batches=16, drop_remainder=True)).apply(prefetch_to_device(gpu_device, None))
356 | trainB = trainB.apply(shuffle_and_repeat(self.dataset_num)).apply(map_and_batch(Image_Data_Class.image_processing, self.batch_size, num_parallel_batches=16, drop_remainder=True)).apply(prefetch_to_device(gpu_device, None))
357 |
358 |
359 | trainA_iterator = trainA.make_one_shot_iterator()
360 | trainB_iterator = trainB.make_one_shot_iterator()
361 |
362 | self.domain_A = trainA_iterator.get_next()
363 | self.domain_B = trainB_iterator.get_next()
364 |
365 | """ Define Generator, Discriminator """
366 | x_ab, cam_ab = self.generate_a2b(self.domain_A) # A -> B translation
367 | x_ba, cam_ba = self.generate_b2a(self.domain_B) # B -> A translation
368 |
369 | x_aba, _ = self.generate_b2a(x_ab, reuse=True) # A -> B -> A cycle
370 | x_bab, _ = self.generate_a2b(x_ba, reuse=True) # B -> A -> B cycle
371 |
372 | x_aa, cam_aa = self.generate_b2a(self.domain_A, reuse=True) # A -> A identity mapping
373 | x_bb, cam_bb = self.generate_a2b(self.domain_B, reuse=True) # B -> B identity mapping
374 |
375 | real_A_logit, real_A_cam_logit, real_B_logit, real_B_cam_logit = self.discriminate_real(self.domain_A, self.domain_B)
376 | fake_A_logit, fake_A_cam_logit, fake_B_logit, fake_B_cam_logit = self.discriminate_fake(x_ba, x_ab)
377 |
378 |
379 | """ Define Loss """
380 | if self.gan_type.__contains__('wgan') or self.gan_type == 'dragan' :
381 | GP_A, GP_CAM_A = self.gradient_penalty(real=self.domain_A, fake=x_ba, scope="discriminator_A")
382 | GP_B, GP_CAM_B = self.gradient_penalty(real=self.domain_B, fake=x_ab, scope="discriminator_B")
383 | else :
384 | GP_A, GP_CAM_A = 0, 0
385 | GP_B, GP_CAM_B = 0, 0
386 |
387 | G_ad_loss_A = (generator_loss(self.gan_type, fake_A_logit) + generator_loss(self.gan_type, fake_A_cam_logit))
388 | G_ad_loss_B = (generator_loss(self.gan_type, fake_B_logit) + generator_loss(self.gan_type, fake_B_cam_logit))
389 |
390 | D_ad_loss_A = (discriminator_loss(self.gan_type, real_A_logit, fake_A_logit) + discriminator_loss(self.gan_type, real_A_cam_logit, fake_A_cam_logit) + GP_A + GP_CAM_A)
391 | D_ad_loss_B = (discriminator_loss(self.gan_type, real_B_logit, fake_B_logit) + discriminator_loss(self.gan_type, real_B_cam_logit, fake_B_cam_logit) + GP_B + GP_CAM_B)
392 |
393 | reconstruction_A = L1_loss(x_aba, self.domain_A) # reconstruction
394 | reconstruction_B = L1_loss(x_bab, self.domain_B) # reconstruction
395 |
396 | identity_A = L1_loss(x_aa, self.domain_A)
397 | identity_B = L1_loss(x_bb, self.domain_B)
398 |
399 | cam_A = cam_loss(source=cam_ba, non_source=cam_aa)
400 | cam_B = cam_loss(source=cam_ab, non_source=cam_bb)
401 |
402 | Generator_A_gan = self.adv_weight * G_ad_loss_A
403 | Generator_A_cycle = self.cycle_weight * reconstruction_B
404 | Generator_A_identity = self.identity_weight * identity_A
405 | Generator_A_cam = self.cam_weight * cam_A
406 |
407 |
408 | Generator_B_gan = self.adv_weight * G_ad_loss_B
409 | Generator_B_cycle = self.cycle_weight * reconstruction_A
410 | Generator_B_identity = self.identity_weight * identity_B
411 | Generator_B_cam = self.cam_weight * cam_B
412 |
413 |
414 | Generator_A_loss = Generator_A_gan + Generator_A_cycle + Generator_A_identity + Generator_A_cam
415 | Generator_B_loss = Generator_B_gan + Generator_B_cycle + Generator_B_identity + Generator_B_cam
416 |
417 |
418 | Discriminator_A_loss = self.adv_weight * D_ad_loss_A
419 | Discriminator_B_loss = self.adv_weight * D_ad_loss_B
420 |
421 | self.Generator_loss = Generator_A_loss + Generator_B_loss + regularization_loss('generator')
422 | self.Discriminator_loss = Discriminator_A_loss + Discriminator_B_loss + regularization_loss('discriminator')
423 |
424 |
425 | """ Result Image """
426 | self.fake_A = x_ba
427 | self.fake_B = x_ab
428 |
429 | self.real_A = self.domain_A
430 | self.real_B = self.domain_B
431 |
432 |
433 | """ Training """
434 | t_vars = tf.trainable_variables()
435 | G_vars = [var for var in t_vars if 'generator' in var.name]
436 | D_vars = [var for var in t_vars if 'discriminator' in var.name]
437 |
438 | self.G_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.Generator_loss, var_list=G_vars)
439 | self.D_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.Discriminator_loss, var_list=D_vars)
440 |
441 |
442 | """ Summary """
443 | self.all_G_loss = tf.summary.scalar("Generator_loss", self.Generator_loss)
444 | self.all_D_loss = tf.summary.scalar("Discriminator_loss", self.Discriminator_loss)
445 |
446 | self.G_A_loss = tf.summary.scalar("G_A_loss", Generator_A_loss)
447 | self.G_A_gan = tf.summary.scalar("G_A_gan", Generator_A_gan)
448 | self.G_A_cycle = tf.summary.scalar("G_A_cycle", Generator_A_cycle)
449 | self.G_A_identity = tf.summary.scalar("G_A_identity", Generator_A_identity)
450 | self.G_A_cam = tf.summary.scalar("G_A_cam", Generator_A_cam)
451 |
452 | self.G_B_loss = tf.summary.scalar("G_B_loss", Generator_B_loss)
453 | self.G_B_gan = tf.summary.scalar("G_B_gan", Generator_B_gan)
454 | self.G_B_cycle = tf.summary.scalar("G_B_cycle", Generator_B_cycle)
455 | self.G_B_identity = tf.summary.scalar("G_B_identity", Generator_B_identity)
456 | self.G_B_cam = tf.summary.scalar("G_B_cam", Generator_B_cam)
457 |
458 | self.D_A_loss = tf.summary.scalar("D_A_loss", Discriminator_A_loss)
459 | self.D_B_loss = tf.summary.scalar("D_B_loss", Discriminator_B_loss)
460 |
461 | self.rho_var = []
462 | for var in tf.trainable_variables():
463 | if 'rho' in var.name:
464 | self.rho_var.append(tf.summary.histogram(var.name, var))
465 | self.rho_var.append(tf.summary.scalar(var.name + "_min", tf.reduce_min(var)))
466 | self.rho_var.append(tf.summary.scalar(var.name + "_max", tf.reduce_max(var)))
467 | self.rho_var.append(tf.summary.scalar(var.name + "_mean", tf.reduce_mean(var)))
468 |
469 | g_summary_list = [self.G_A_loss, self.G_A_gan, self.G_A_cycle, self.G_A_identity, self.G_A_cam,
470 | self.G_B_loss, self.G_B_gan, self.G_B_cycle, self.G_B_identity, self.G_B_cam,
471 | self.all_G_loss]
472 |
473 | g_summary_list.extend(self.rho_var)
474 | d_summary_list = [self.D_A_loss, self.D_B_loss, self.all_D_loss]
475 |
476 | self.G_loss = tf.summary.merge(g_summary_list)
477 | self.D_loss = tf.summary.merge(d_summary_list)
478 |
479 | else :
480 | """ Test """
481 | self.test_domain_A = tf.placeholder(tf.float32, [1, self.img_size, self.img_size, self.img_ch], name='test_domain_A')
482 | self.test_domain_B = tf.placeholder(tf.float32, [1, self.img_size, self.img_size, self.img_ch], name='test_domain_B')
483 |
484 |
485 | self.test_fake_B, _ = self.generate_a2b(self.test_domain_A)
486 | self.test_fake_A, _ = self.generate_b2a(self.test_domain_B)
487 |
488 |
489 | def train(self):
490 | # initialize all variables
491 | tf.global_variables_initializer().run()
492 |
493 | # saver to save model
494 | self.saver = tf.train.Saver()
495 |
496 | # summary writer
497 | self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_dir, self.sess.graph)
498 |
499 |
500 | # restore the checkpoint if it exists
501 | could_load, checkpoint_counter = self.load(self.checkpoint_dir)
502 | if could_load:
503 | start_epoch = int(checkpoint_counter / self.iteration)
504 | start_batch_id = checkpoint_counter -
start_epoch * self.iteration 505 | counter = checkpoint_counter 506 | print(" [*] Load SUCCESS") 507 | else: 508 | start_epoch = 0 509 | start_batch_id = 0 510 | counter = 1 511 | print(" [!] Load failed...") 512 | 513 | # loop for epoch 514 | start_time = time.time() 515 | past_g_loss = -1. 516 | lr = self.init_lr 517 | for epoch in range(start_epoch, self.epoch): 518 | # lr = self.init_lr if epoch < self.decay_epoch else self.init_lr * (self.epoch - epoch) / (self.epoch - self.decay_epoch) 519 | if self.decay_flag : 520 | #lr = self.init_lr * pow(0.5, epoch // self.decay_epoch) 521 | lr = self.init_lr if epoch < self.decay_epoch else self.init_lr * (self.epoch - epoch) / (self.epoch - self.decay_epoch) 522 | for idx in range(start_batch_id, self.iteration): 523 | train_feed_dict = { 524 | self.lr : lr 525 | } 526 | 527 | # Update D 528 | _, d_loss, summary_str = self.sess.run([self.D_optim, 529 | self.Discriminator_loss, self.D_loss], feed_dict = train_feed_dict) 530 | self.writer.add_summary(summary_str, counter) 531 | 532 | # Update G 533 | g_loss = None 534 | if (counter - 1) % self.n_critic == 0 : 535 | batch_A_images, batch_B_images, fake_A, fake_B, _, g_loss, summary_str = self.sess.run([self.real_A, self.real_B, 536 | self.fake_A, self.fake_B, 537 | self.G_optim, 538 | self.Generator_loss, self.G_loss], feed_dict = train_feed_dict) 539 | self.writer.add_summary(summary_str, counter) 540 | past_g_loss = g_loss 541 | 542 | # display training status 543 | counter += 1 544 | if g_loss == None : 545 | g_loss = past_g_loss 546 | print("Epoch: [%2d] [%5d/%5d] time: %4.4f d_loss: %.8f, g_loss: %.8f" % (epoch, idx, self.iteration, time.time() - start_time, d_loss, g_loss)) 547 | 548 | if np.mod(idx+1, self.print_freq) == 0 : 549 | save_images(batch_A_images, [self.batch_size, 1], 550 | './{}/real_A_{:03d}_{:05d}.png'.format(self.sample_dir, epoch, idx+1)) 551 | # save_images(batch_B_images, [self.batch_size, 1], 552 | # './{}/real_B_{:03d}_{:05d}.png'.format(self.sample_dir, epoch, idx+1)) 553 | 554 | # save_images(fake_A, [self.batch_size, 1], 555 | # './{}/fake_A_{:03d}_{:05d}.png'.format(self.sample_dir, epoch, idx+1)) 556 | save_images(fake_B, [self.batch_size, 1], 557 | './{}/fake_B_{:03d}_{:05d}.png'.format(self.sample_dir, epoch, idx+1)) 558 | 559 | if np.mod(idx + 1, self.save_freq) == 0: 560 | self.save(self.checkpoint_dir, counter) 561 | 562 | 563 | 564 | # After an epoch, start_batch_id is set to zero 565 | # non-zero value is only for the first epoch after loading pre-trained model 566 | start_batch_id = 0 567 | 568 | # save model for final step 569 | self.save(self.checkpoint_dir, counter) 570 | 571 | @property 572 | def model_dir(self): 573 | n_res = str(self.n_res) + 'resblock' 574 | n_dis = str(self.n_dis) + 'dis' 575 | 576 | if self.smoothing : 577 | smoothing = '_smoothing' 578 | else : 579 | smoothing = '' 580 | 581 | if self.sn : 582 | sn = '_sn' 583 | else : 584 | sn = '' 585 | 586 | return "{}_{}_{}_{}_{}_{}_{}_{}_{}_{}{}{}".format(self.model_name, self.dataset_name, 587 | self.gan_type, n_res, n_dis, 588 | self.n_critic, 589 | self.adv_weight, self.cycle_weight, self.identity_weight, self.cam_weight, sn, smoothing) 590 | 591 | def save(self, checkpoint_dir, step): 592 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir) 593 | 594 | if not os.path.exists(checkpoint_dir): 595 | os.makedirs(checkpoint_dir) 596 | 597 | self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name + '.model'), global_step=step) 598 | 599 | def load(self, 
checkpoint_dir):
600 | print(" [*] Reading checkpoints...")
601 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)
602 |
603 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
604 | if ckpt and ckpt.model_checkpoint_path:
605 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
606 | self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
607 | counter = int(ckpt_name.split('-')[-1])
608 | print(" [*] Successfully read {}".format(ckpt_name))
609 | return True, counter
610 | else:
611 | print(" [*] Failed to find a checkpoint")
612 | return False, 0
613 |
614 | def test(self):
615 | tf.global_variables_initializer().run()
616 | test_A_files = glob('./dataset/{}/*.*'.format(self.dataset_name + '/testA'))
617 | test_B_files = glob('./dataset/{}/*.*'.format(self.dataset_name + '/testB'))
618 |
619 | self.saver = tf.train.Saver()
620 | could_load, checkpoint_counter = self.load(self.checkpoint_dir)
621 | self.result_dir = os.path.join(self.result_dir, self.model_dir)
622 | check_folder(self.result_dir)
623 |
624 | if could_load :
625 | print(" [*] Load SUCCESS")
626 | else :
627 | print(" [!] Load failed...")
628 |
629 | # write html for visual comparison
630 | index_path = os.path.join(self.result_dir, 'index.html')
631 | index = open(index_path, 'w')
632 | index.write("<html><body><table><tr>")
633 | index.write("<th>name</th><th>input</th><th>output</th></tr>")
634 |
635 | for sample_file in test_A_files : # A -> B
636 | print('Processing A image: ' + sample_file)
637 | sample_image = np.asarray(load_test_data(sample_file, size=self.img_size))
638 | image_path = os.path.join(self.result_dir,'{0}'.format(os.path.basename(sample_file)))
639 |
640 | fake_img = self.sess.run(self.test_fake_B, feed_dict = {self.test_domain_A : sample_image})
641 | save_images(fake_img, [1, 1], image_path)
642 |
643 | index.write("<td>%s</td>" % os.path.basename(image_path))
644 |
645 | index.write("<td><img src='%s' width='%d' height='%d'></td>" % (sample_file if os.path.isabs(sample_file) else (
646 | '../..' + os.path.sep + sample_file), self.img_size, self.img_size))
647 | index.write("<td><img src='%s' width='%d' height='%d'></td>" % (image_path if os.path.isabs(image_path) else (
648 | '../..' + os.path.sep + image_path), self.img_size, self.img_size))
649 | index.write("</tr>")
650 |
651 | for sample_file in test_B_files : # B -> A
652 | print('Processing B image: ' + sample_file)
653 | sample_image = np.asarray(load_test_data(sample_file, size=self.img_size))
654 | image_path = os.path.join(self.result_dir,'{0}'.format(os.path.basename(sample_file)))
655 |
656 | fake_img = self.sess.run(self.test_fake_A, feed_dict = {self.test_domain_B : sample_image})
657 |
658 | save_images(fake_img, [1, 1], image_path)
659 | index.write("<td>%s</td>" % os.path.basename(image_path))
660 | index.write("<td><img src='%s' width='%d' height='%d'></td>" % (sample_file if os.path.isabs(sample_file) else (
661 | '../..' + os.path.sep + sample_file), self.img_size, self.img_size))
662 | index.write("<td><img src='%s' width='%d' height='%d'></td>" % (image_path if os.path.isabs(image_path) else (
663 | '../..' + os.path.sep + image_path), self.img_size, self.img_size))
664 | index.write("</tr>")
665 | index.close()
666 |
--------------------------------------------------------------------------------
/assets/ablation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/UGATIT/d508e8f5188e47000d79d8aecada0cc9119e0d56/assets/ablation.png
--------------------------------------------------------------------------------
/assets/discriminator_fix.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/UGATIT/d508e8f5188e47000d79d8aecada0cc9119e0d56/assets/discriminator_fix.png
--------------------------------------------------------------------------------
/assets/generator_fix.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/UGATIT/d508e8f5188e47000d79d8aecada0cc9119e0d56/assets/generator_fix.png
--------------------------------------------------------------------------------
/assets/kid_fix2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/UGATIT/d508e8f5188e47000d79d8aecada0cc9119e0d56/assets/kid_fix2.png
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/UGATIT/d508e8f5188e47000d79d8aecada0cc9119e0d56/assets/teaser.png
--------------------------------------------------------------------------------
/assets/user_study.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/UGATIT/d508e8f5188e47000d79d8aecada0cc9119e0d56/assets/user_study.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from UGATIT import UGATIT
2 | import argparse
3 | from utils import *
4 |
5 | """parsing and configuration"""
6 |
7 | def parse_args():
8 | desc = "TensorFlow implementation of U-GAT-IT"
9 | parser = argparse.ArgumentParser(description=desc)
10 | parser.add_argument('--phase', type=str, default='train', help='[train / test]')
11 | parser.add_argument('--light', type=str2bool, default=False, help='[U-GAT-IT full version / U-GAT-IT light version]')
12 | parser.add_argument('--dataset', type=str, default='selfie2anime', help='dataset_name')
13 |
14 | parser.add_argument('--epoch', type=int, default=100, help='The number of epochs to run')
15 | parser.add_argument('--iteration', type=int, default=10000, help='The number of training iterations')
16 | parser.add_argument('--batch_size', type=int, default=1, help='The batch size')
17 | parser.add_argument('--print_freq', type=int, default=1000, help='Image print frequency (in iterations)')
18 | parser.add_argument('--save_freq', type=int, default=1000, help='Checkpoint save frequency (in iterations)')
19 | parser.add_argument('--decay_flag', type=str2bool, default=True, help='Whether to decay the learning rate')
20 | parser.add_argument('--decay_epoch',
type=int, default=50, help='The epoch when learning-rate decay starts')
21 |
22 | parser.add_argument('--lr', type=float, default=0.0001, help='The learning rate')
23 | parser.add_argument('--GP_ld', type=int, default=10, help='The gradient penalty lambda')
24 | parser.add_argument('--adv_weight', type=int, default=1, help='Weight for the adversarial (GAN) loss')
25 | parser.add_argument('--cycle_weight', type=int, default=10, help='Weight for the cycle-consistency loss')
26 | parser.add_argument('--identity_weight', type=int, default=10, help='Weight for the identity loss')
27 | parser.add_argument('--cam_weight', type=int, default=1000, help='Weight for the CAM loss')
28 | parser.add_argument('--gan_type', type=str, default='lsgan', help='[gan / lsgan / wgan-gp / wgan-lp / dragan / hinge]')
29 |
30 | parser.add_argument('--smoothing', type=str2bool, default=True, help='AdaLIN smoothing effect')
31 |
32 | parser.add_argument('--ch', type=int, default=64, help='base channel number per layer')
33 | parser.add_argument('--n_res', type=int, default=4, help='The number of residual blocks')
34 | parser.add_argument('--n_dis', type=int, default=6, help='The number of discriminator layers')
35 | parser.add_argument('--n_critic', type=int, default=1, help='The number of critic updates per generator update')
36 | parser.add_argument('--sn', type=str2bool, default=True, help='Whether to use spectral normalization')
37 |
38 | parser.add_argument('--img_size', type=int, default=256, help='The size of image')
39 | parser.add_argument('--img_ch', type=int, default=3, help='The number of image channels')
40 | parser.add_argument('--augment_flag', type=str2bool, default=True, help='Whether to use image augmentation')
41 |
42 | parser.add_argument('--checkpoint_dir', type=str, default='checkpoint',
43 | help='Directory name to save the checkpoints')
44 | parser.add_argument('--result_dir', type=str, default='results',
45 | help='Directory name to save the generated images')
46 | parser.add_argument('--log_dir', type=str, default='logs',
47 | help='Directory name to save training logs')
48 | parser.add_argument('--sample_dir', type=str, default='samples',
49 | help='Directory name to save the samples on training')
50 |
51 | return check_args(parser.parse_args())
52 |
53 | """checking arguments"""
54 | def check_args(args):
55 | # --checkpoint_dir
56 | check_folder(args.checkpoint_dir)
57 |
58 | # --result_dir
59 | check_folder(args.result_dir)
60 |
61 | # --log_dir
62 | check_folder(args.log_dir)
63 |
64 | # --sample_dir
65 | check_folder(args.sample_dir)
66 |
67 | # --epoch
68 | try:
69 | assert args.epoch >= 1
70 | except:
71 | print('number of epochs must be larger than or equal to one')
72 |
73 | # --batch_size
74 | try:
75 | assert args.batch_size >= 1
76 | except:
77 | print('batch size must be larger than or equal to one')
78 | return args
79 |
80 | """main"""
81 | def main():
82 | # parse arguments
83 | args = parse_args()
84 | if args is None:
85 | exit()
86 |
87 | # open session
88 | with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
89 | gan = UGATIT(sess, args)
90 |
91 | # build graph
92 | gan.build_model()
93 |
94 | # show network architecture
95 | show_all_variables()
96 |
97 | if args.phase == 'train' :
98 | gan.train()
99 | print(" [*] Training finished!")
100 |
101 | if args.phase == 'test' :
102 | gan.test()
103 | print(" [*] Test finished!")
104 |
105 | if __name__ == '__main__':
106 | main()
107 |
--------------------------------------------------------------------------------
/ops.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import
tensorflow.contrib as tf_contrib 3 | 4 | # Xavier : tf_contrib.layers.xavier_initializer() 5 | # He : tf_contrib.layers.variance_scaling_initializer() 6 | # Normal : tf.random_normal_initializer(mean=0.0, stddev=0.02) 7 | # l2_decay : tf_contrib.layers.l2_regularizer(0.0001) 8 | 9 | weight_init = tf.random_normal_initializer(mean=0.0, stddev=0.02) 10 | weight_regularizer = tf_contrib.layers.l2_regularizer(scale=0.0001) 11 | 12 | ################################################################################## 13 | # Layer 14 | ################################################################################## 15 | 16 | def conv(x, channels, kernel=4, stride=2, pad=0, pad_type='zero', use_bias=True, sn=False, scope='conv_0'): 17 | with tf.variable_scope(scope): 18 | if pad > 0 : 19 | if (kernel - stride) % 2 == 0: 20 | pad_top = pad 21 | pad_bottom = pad 22 | pad_left = pad 23 | pad_right = pad 24 | 25 | else: 26 | pad_top = pad 27 | pad_bottom = kernel - stride - pad_top 28 | pad_left = pad 29 | pad_right = kernel - stride - pad_left 30 | 31 | if pad_type == 'zero': 32 | x = tf.pad(x, [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]]) 33 | if pad_type == 'reflect': 34 | x = tf.pad(x, [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]], mode='REFLECT') 35 | 36 | if sn : 37 | w = tf.get_variable("kernel", shape=[kernel, kernel, x.get_shape()[-1], channels], initializer=weight_init, 38 | regularizer=weight_regularizer) 39 | x = tf.nn.conv2d(input=x, filter=spectral_norm(w), 40 | strides=[1, stride, stride, 1], padding='VALID') 41 | if use_bias : 42 | bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0)) 43 | x = tf.nn.bias_add(x, bias) 44 | 45 | else : 46 | x = tf.layers.conv2d(inputs=x, filters=channels, 47 | kernel_size=kernel, kernel_initializer=weight_init, 48 | kernel_regularizer=weight_regularizer, 49 | strides=stride, use_bias=use_bias) 50 | 51 | 52 | return x 53 | 54 | def fully_connected_with_w(x, use_bias=True, sn=False, reuse=False, scope='linear'): 55 | with tf.variable_scope(scope, reuse=reuse): 56 | x = flatten(x) 57 | bias = 0.0 58 | shape = x.get_shape().as_list() 59 | channels = shape[-1] 60 | 61 | w = tf.get_variable("kernel", [channels, 1], tf.float32, 62 | initializer=weight_init, regularizer=weight_regularizer) 63 | 64 | if sn : 65 | w = spectral_norm(w) 66 | 67 | if use_bias : 68 | bias = tf.get_variable("bias", [1], 69 | initializer=tf.constant_initializer(0.0)) 70 | 71 | x = tf.matmul(x, w) + bias 72 | else : 73 | x = tf.matmul(x, w) 74 | 75 | if use_bias : 76 | weights = tf.gather(tf.transpose(tf.nn.bias_add(w, bias)), 0) 77 | else : 78 | weights = tf.gather(tf.transpose(w), 0) 79 | 80 | return x, weights 81 | 82 | def fully_connected(x, units, use_bias=True, sn=False, scope='linear'): 83 | with tf.variable_scope(scope): 84 | x = flatten(x) 85 | shape = x.get_shape().as_list() 86 | channels = shape[-1] 87 | 88 | if sn: 89 | w = tf.get_variable("kernel", [channels, units], tf.float32, 90 | initializer=weight_init, regularizer=weight_regularizer) 91 | if use_bias: 92 | bias = tf.get_variable("bias", [units], 93 | initializer=tf.constant_initializer(0.0)) 94 | 95 | x = tf.matmul(x, spectral_norm(w)) + bias 96 | else: 97 | x = tf.matmul(x, spectral_norm(w)) 98 | 99 | else : 100 | x = tf.layers.dense(x, units=units, kernel_initializer=weight_init, kernel_regularizer=weight_regularizer, use_bias=use_bias) 101 | 102 | return x 103 | 104 | def flatten(x) : 105 | return tf.layers.flatten(x) 106 | 107 | 
################################################################################## 108 | # Residual-block 109 | ################################################################################## 110 | 111 | def resblock(x_init, channels, use_bias=True, scope='resblock_0'): 112 | with tf.variable_scope(scope): 113 | with tf.variable_scope('res1'): 114 | x = conv(x_init, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias) 115 | x = instance_norm(x) 116 | x = relu(x) 117 | 118 | with tf.variable_scope('res2'): 119 | x = conv(x, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias) 120 | x = instance_norm(x) 121 | 122 | return x + x_init 123 | 124 | def adaptive_ins_layer_resblock(x_init, channels, gamma, beta, use_bias=True, smoothing=True, scope='adaptive_resblock') : 125 | with tf.variable_scope(scope): 126 | with tf.variable_scope('res1'): 127 | x = conv(x_init, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias) 128 | x = adaptive_instance_layer_norm(x, gamma, beta, smoothing) 129 | x = relu(x) 130 | 131 | with tf.variable_scope('res2'): 132 | x = conv(x, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias) 133 | x = adaptive_instance_layer_norm(x, gamma, beta, smoothing) 134 | 135 | return x + x_init 136 | 137 | 138 | ################################################################################## 139 | # Sampling 140 | ################################################################################## 141 | 142 | def up_sample(x, scale_factor=2): 143 | _, h, w, _ = x.get_shape().as_list() 144 | new_size = [h * scale_factor, w * scale_factor] 145 | return tf.image.resize_nearest_neighbor(x, size=new_size) 146 | 147 | 148 | def global_avg_pooling(x): 149 | gap = tf.reduce_mean(x, axis=[1, 2]) 150 | return gap 151 | 152 | def global_max_pooling(x): 153 | gmp = tf.reduce_max(x, axis=[1, 2]) 154 | return gmp 155 | 156 | ################################################################################## 157 | # Activation function 158 | ################################################################################## 159 | 160 | def lrelu(x, alpha=0.01): 161 | # pytorch alpha is 0.01 162 | return tf.nn.leaky_relu(x, alpha) 163 | 164 | 165 | def relu(x): 166 | return tf.nn.relu(x) 167 | 168 | 169 | def tanh(x): 170 | return tf.tanh(x) 171 | 172 | def sigmoid(x) : 173 | return tf.sigmoid(x) 174 | 175 | ################################################################################## 176 | # Normalization function 177 | ################################################################################## 178 | 179 | def adaptive_instance_layer_norm(x, gamma, beta, smoothing=True, scope='instance_layer_norm') : 180 | with tf.variable_scope(scope): 181 | ch = x.shape[-1] 182 | eps = 1e-5 183 | 184 | ins_mean, ins_sigma = tf.nn.moments(x, axes=[1, 2], keep_dims=True) 185 | x_ins = (x - ins_mean) / (tf.sqrt(ins_sigma + eps)) 186 | 187 | ln_mean, ln_sigma = tf.nn.moments(x, axes=[1, 2, 3], keep_dims=True) 188 | x_ln = (x - ln_mean) / (tf.sqrt(ln_sigma + eps)) 189 | 190 | rho = tf.get_variable("rho", [ch], initializer=tf.constant_initializer(1.0), constraint=lambda x: tf.clip_by_value(x, clip_value_min=0.0, clip_value_max=1.0)) 191 | 192 | if smoothing : 193 | rho = tf.clip_by_value(rho - tf.constant(0.1), 0.0, 1.0) 194 | 195 | x_hat = rho * x_ins + (1 - rho) * x_ln 196 | 197 | 198 | x_hat = x_hat * gamma + beta 199 | 200 | return x_hat 201 | 202 | def instance_norm(x, scope='instance_norm'): 203 | return 
tf_contrib.layers.instance_norm(x, 204 | epsilon=1e-05, 205 | center=True, scale=True, 206 | scope=scope) 207 | 208 | def layer_norm(x, scope='layer_norm') : 209 | return tf_contrib.layers.layer_norm(x, 210 | center=True, scale=True, 211 | scope=scope) 212 | 213 | def layer_instance_norm(x, scope='layer_instance_norm') : 214 | with tf.variable_scope(scope): 215 | ch = x.shape[-1] 216 | eps = 1e-5 217 | 218 | ins_mean, ins_sigma = tf.nn.moments(x, axes=[1, 2], keep_dims=True) 219 | x_ins = (x - ins_mean) / (tf.sqrt(ins_sigma + eps)) 220 | 221 | ln_mean, ln_sigma = tf.nn.moments(x, axes=[1, 2, 3], keep_dims=True) 222 | x_ln = (x - ln_mean) / (tf.sqrt(ln_sigma + eps)) 223 | 224 | rho = tf.get_variable("rho", [ch], initializer=tf.constant_initializer(0.0), constraint=lambda x: tf.clip_by_value(x, clip_value_min=0.0, clip_value_max=1.0)) 225 | 226 | gamma = tf.get_variable("gamma", [ch], initializer=tf.constant_initializer(1.0)) 227 | beta = tf.get_variable("beta", [ch], initializer=tf.constant_initializer(0.0)) 228 | 229 | x_hat = rho * x_ins + (1 - rho) * x_ln 230 | 231 | x_hat = x_hat * gamma + beta 232 | 233 | return x_hat 234 | 235 | def spectral_norm(w, iteration=1): 236 | w_shape = w.shape.as_list() 237 | w = tf.reshape(w, [-1, w_shape[-1]]) 238 | 239 | u = tf.get_variable("u", [1, w_shape[-1]], initializer=tf.random_normal_initializer(), trainable=False) 240 | 241 | u_hat = u 242 | v_hat = None 243 | for i in range(iteration): 244 | """ 245 | power iteration 246 | Usually iteration = 1 will be enough 247 | """ 248 | v_ = tf.matmul(u_hat, tf.transpose(w)) 249 | v_hat = tf.nn.l2_normalize(v_) 250 | 251 | u_ = tf.matmul(v_hat, w) 252 | u_hat = tf.nn.l2_normalize(u_) 253 | 254 | u_hat = tf.stop_gradient(u_hat) 255 | v_hat = tf.stop_gradient(v_hat) 256 | 257 | sigma = tf.matmul(tf.matmul(v_hat, w), tf.transpose(u_hat)) 258 | 259 | with tf.control_dependencies([u.assign(u_hat)]): 260 | w_norm = w / sigma 261 | w_norm = tf.reshape(w_norm, w_shape) 262 | 263 | 264 | return w_norm 265 | 266 | ################################################################################## 267 | # Loss function 268 | ################################################################################## 269 | 270 | def L1_loss(x, y): 271 | loss = tf.reduce_mean(tf.abs(x - y)) 272 | 273 | return loss 274 | 275 | def cam_loss(source, non_source) : 276 | 277 | identity_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(source), logits=source)) 278 | non_identity_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(non_source), logits=non_source)) 279 | 280 | loss = identity_loss + non_identity_loss 281 | 282 | return loss 283 | 284 | def regularization_loss(scope_name) : 285 | """ 286 | If you want to use "Regularization" 287 | g_loss += regularization_loss('generator') 288 | d_loss += regularization_loss('discriminator') 289 | """ 290 | collection_regularization = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES) 291 | 292 | loss = [] 293 | for item in collection_regularization : 294 | if scope_name in item.name : 295 | loss.append(item) 296 | 297 | return tf.reduce_sum(loss) 298 | 299 | 300 | def discriminator_loss(loss_func, real, fake): 301 | loss = [] 302 | real_loss = 0 303 | fake_loss = 0 304 | 305 | for i in range(2) : 306 | if loss_func.__contains__('wgan') : 307 | real_loss = -tf.reduce_mean(real[i]) 308 | fake_loss = tf.reduce_mean(fake[i]) 309 | 310 | if loss_func == 'lsgan' : 311 | real_loss = tf.reduce_mean(tf.squared_difference(real[i], 
1.0)) 312 | fake_loss = tf.reduce_mean(tf.square(fake[i])) 313 | 314 | if loss_func == 'gan' or loss_func == 'dragan' : 315 | real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(real[i]), logits=real[i])) 316 | fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(fake[i]), logits=fake[i])) 317 | 318 | if loss_func == 'hinge' : 319 | real_loss = tf.reduce_mean(relu(1.0 - real[i])) 320 | fake_loss = tf.reduce_mean(relu(1.0 + fake[i])) 321 | 322 | loss.append(real_loss + fake_loss) 323 | 324 | return sum(loss) 325 | 326 | def generator_loss(loss_func, fake): 327 | loss = [] 328 | fake_loss = 0 329 | 330 | for i in range(2) : 331 | if loss_func.__contains__('wgan') : 332 | fake_loss = -tf.reduce_mean(fake[i]) 333 | 334 | if loss_func == 'lsgan' : 335 | fake_loss = tf.reduce_mean(tf.squared_difference(fake[i], 1.0)) 336 | 337 | if loss_func == 'gan' or loss_func == 'dragan' : 338 | fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(fake[i]), logits=fake[i])) 339 | 340 | if loss_func == 'hinge' : 341 | fake_loss = -tf.reduce_mean(fake[i]) 342 | 343 | loss.append(fake_loss) 344 | 345 | return sum(loss) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.contrib import slim 3 | import cv2 4 | import os, random 5 | import numpy as np 6 | 7 | class ImageData: 8 | 9 | def __init__(self, load_size, channels, augment_flag): 10 | self.load_size = load_size 11 | self.channels = channels 12 | self.augment_flag = augment_flag 13 | 14 | def image_processing(self, filename): 15 | x = tf.read_file(filename) 16 | x_decode = tf.image.decode_jpeg(x, channels=self.channels) 17 | img = tf.image.resize_images(x_decode, [self.load_size, self.load_size]) 18 | img = tf.cast(img, tf.float32) / 127.5 - 1 19 | 20 | if self.augment_flag : 21 | augment_size = self.load_size + (30 if self.load_size == 256 else 15) 22 | p = random.random() 23 | if p > 0.5: 24 | img = augmentation(img, augment_size) 25 | 26 | return img 27 | 28 | def load_test_data(image_path, size=256): 29 | img = cv2.imread(image_path, flags=cv2.IMREAD_COLOR) 30 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 31 | 32 | img = cv2.resize(img, dsize=(size, size)) 33 | 34 | img = np.expand_dims(img, axis=0) 35 | img = img/127.5 - 1 36 | 37 | return img 38 | 39 | def augmentation(image, augment_size): 40 | seed = random.randint(0, 2 ** 31 - 1) 41 | ori_image_shape = tf.shape(image) 42 | image = tf.image.random_flip_left_right(image, seed=seed) 43 | image = tf.image.resize_images(image, [augment_size, augment_size]) 44 | image = tf.random_crop(image, ori_image_shape, seed=seed) 45 | return image 46 | 47 | def save_images(images, size, image_path): 48 | return imsave(inverse_transform(images), size, image_path) 49 | 50 | def inverse_transform(images): 51 | return ((images+1.) 
/ 2) * 255.0
52 |
53 |
54 | def imsave(images, size, path):
55 | images = merge(images, size)
56 | images = cv2.cvtColor(images.astype('uint8'), cv2.COLOR_RGB2BGR)
57 |
58 | return cv2.imwrite(path, images)
59 |
60 | def merge(images, size):
61 | h, w = images.shape[1], images.shape[2]
62 | img = np.zeros((h * size[0], w * size[1], 3))
63 | for idx, image in enumerate(images):
64 | i = idx % size[1]
65 | j = idx // size[1]
66 | img[h*j:h*(j+1), w*i:w*(i+1), :] = image
67 |
68 | return img
69 |
70 | def show_all_variables():
71 | model_vars = tf.trainable_variables()
72 | slim.model_analyzer.analyze_vars(model_vars, print_info=True)
73 |
74 | def check_folder(log_dir):
75 | if not os.path.exists(log_dir):
76 | os.makedirs(log_dir)
77 | return log_dir
78 |
79 | def str2bool(x):
80 | return x.lower() == 'true'
81 |
--------------------------------------------------------------------------------