├── .DS_Store ├── .gitignore ├── LICENSE ├── README.md ├── SDIT.py ├── assets ├── framework.png ├── result.png └── teaser.png ├── main.py ├── ops.py └── utils.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/SDIT-Tensorflow/71f9917325d647d3d51e691e85f5a3079e068da0/.DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Junho Kim 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SDIT-Tensorflow 2 | ## : Scalable and Diverse Cross-domain Image Translation (ACM-MM 2019) 3 | 4 |
5 | 6 |
7 | 8 | ### [Paper](https://arxiv.org/abs/1908.06881) | [Official Pytorch code](https://github.com/yaxingwang/SDIT) 9 | 10 | ## Usage 11 | ``` 12 | ├── dataset 13 |    └── YOUR_DATASET_NAME 14 |    ├── train 15 |           ├── class1 (class folder) 16 | ├── xxx.jpg (class1 image) 17 | ├── yyy.png 18 | ├── ... 19 | ├── class2 20 | ├── aaa.jpg (class2 image) 21 | ├── bbb.png 22 | ├── ... 23 | ├── class3 24 | ├── ... 25 |    ├── test 26 | ├── zzz.jpg (any content image) 27 | ├── www.png 28 | ├── ... 29 | 30 | └── celebA 31 | ├── train 32 | ├── 000001.png 33 | ├── 000002.png 34 | └── ... 35 | ├── test 36 | ├── a.jpg (The test image that you wanted) 37 | ├── b.png 38 | └── ... 39 | ├── list_attr_celeba.txt (For attribute information) 40 | ``` 41 | ### Train 42 | * python main.py --dataset celebA --phase train 43 | 44 | ### Test 45 | * python main.py --dataset celebA --phase test 46 | * The celebA test image and the image you wanted run simultaneously 47 | 48 | 49 | ## Comparison 50 |
51 | 52 |
53 | 54 | ## Paper results 55 |
56 | 57 |
58 | 59 | 60 | ## Author 61 | [Junho Kim](http://bit.ly/jhkim_ai) 62 | -------------------------------------------------------------------------------- /SDIT.py: -------------------------------------------------------------------------------- 1 | from ops import * 2 | from utils import * 3 | import time 4 | from tensorflow.contrib.data import prefetch_to_device, shuffle_and_repeat, map_and_batch 5 | import numpy as np 6 | from glob import glob 7 | from tqdm import tqdm 8 | 9 | class SDIT() : 10 | def __init__(self, sess, args): 11 | self.model_name = 'SDIT' 12 | self.sess = sess 13 | self.phase = args.phase 14 | self.checkpoint_dir = args.checkpoint_dir 15 | self.sample_dir = args.sample_dir 16 | self.result_dir = args.result_dir 17 | self.log_dir = args.log_dir 18 | self.dataset_name = args.dataset 19 | self.dataset_path = os.path.join('./dataset', self.dataset_name) 20 | self.augment_flag = args.augment_flag 21 | 22 | self.epoch = args.epoch 23 | self.iteration = args.iteration 24 | self.decay_flag = args.decay_flag 25 | self.decay_epoch = args.decay_epoch 26 | 27 | self.gan_type = args.gan_type 28 | self.attention = args.attention 29 | 30 | self.batch_size = args.batch_size 31 | self.print_freq = args.print_freq 32 | self.save_freq = args.save_freq 33 | 34 | self.init_lr = args.lr 35 | self.ch = args.ch 36 | 37 | if self.dataset_name == 'celebA-HQ' or self.dataset_name == 'celebA': 38 | self.label_list = args.label_list 39 | else : 40 | self.dataset_path = os.path.join(self.dataset_path, 'train') 41 | self.label_list = [os.path.basename(x) for x in glob(self.dataset_path + '/*')] 42 | 43 | 44 | self.c_dim = len(self.label_list) 45 | 46 | """ Weight """ 47 | self.adv_weight = args.adv_weight 48 | self.rec_weight = args.rec_weight 49 | self.cls_weight = args.cls_weight 50 | self.noise_weight = args.noise_weight 51 | self.gp_weight = args.gp_weight 52 | 53 | self.sn = args.sn 54 | 55 | """ Generator """ 56 | self.n_res = args.n_res 57 | self.style_dim = args.style_dim 58 | self.num_style = args.num_style 59 | 60 | """ Discriminator """ 61 | self.n_dis = args.n_dis 62 | self.n_critic = args.n_critic 63 | 64 | self.img_height = args.img_height 65 | self.img_width = args.img_width 66 | self.img_ch = args.img_ch 67 | 68 | print() 69 | 70 | print("##### Information #####") 71 | print("# gan type : ", self.gan_type) 72 | print("# selected_attrs : ", self.label_list) 73 | print("# dataset : ", self.dataset_name) 74 | print("# batch_size : ", self.batch_size) 75 | print("# epoch : ", self.epoch) 76 | print("# iteration per epoch : ", self.iteration) 77 | print("# spectral normalization : ", self.sn) 78 | 79 | print() 80 | 81 | print("##### Generator #####") 82 | print("# residual blocks : ", self.n_res) 83 | print("# attention : ", self.attention) 84 | 85 | print() 86 | 87 | print("##### Discriminator #####") 88 | print("# discriminator layer : ", self.n_dis) 89 | print("# the number of critic : ", self.n_critic) 90 | 91 | ################################################################################## 92 | # Generator 93 | ################################################################################## 94 | 95 | def generator(self, x_init, c, style, reuse=False, scope="generator"): 96 | channel = self.ch 97 | c = tf.cast(tf.reshape(c, shape=[-1, 1, 1, c.shape[-1]]), tf.float32) 98 | c = tf.tile(c, [1, x_init.shape[1], x_init.shape[2], 1]) 99 | x = tf.concat([x_init, c], axis=-1) 100 | 101 | with tf.variable_scope(scope, reuse=reuse) : 102 | """ Encoder """ 103 | x = conv(x, channel, kernel=7, stride=1, pad=3, pad_type='reflect', use_bias=False, sn=self.sn, scope='conv') 104 | x = instance_norm(x, scope='ins_norm') 105 | x = relu(x) 106 | 107 | # Down-Sampling 108 | for i in range(2) : 109 | x = conv(x, channel*2, kernel=4, stride=2, pad=1, pad_type='reflect', use_bias=False, sn=self.sn, scope='conv_'+str(i)) 110 | x = instance_norm(x, scope='down_ins_norm_'+str(i)) 111 | x = relu(x) 112 | 113 | channel = channel * 2 114 | 115 | """ Bottleneck """ 116 | # Encoder Bottleneck 117 | for i in range(self.n_res) : 118 | x = resblock(x, channel, use_bias=False, sn=self.sn, scope='encoder_resblock_' + str(i)) 119 | 120 | attention = x 121 | adaptive = x 122 | 123 | # Adaptive Bottleneck 124 | mu, var = self.MLP(style, channel) 125 | for i in range(self.n_res - 2) : 126 | idx = 2 * i 127 | adaptive = adaptive_resblock(adaptive, channel, mu[idx], var[idx], mu[idx + 1], var[idx + 1], use_bias=True, sn=self.sn, scope='ada_resbloack_' + str(i)) 128 | 129 | if self.attention : 130 | # Attention Bottleneck 131 | for i in range(self.n_res - 1) : 132 | attention = resblock(attention, channel, use_bias=False, sn=self.sn, scope='attention_resblock_' + str(i)) 133 | 134 | attention = conv(attention, 1, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=False, sn=self.sn, scope='attention_conv') 135 | attention = instance_norm(attention, scope='attention_ins_norm') 136 | attention = sigmoid(attention) 137 | 138 | x = attention * adaptive 139 | 140 | # attention_map = tf.concat([attention, attention, attention], axis=-1) * 2 - 1 141 | # attention_map = up_sample(attention_map, scale_factor=4) 142 | 143 | else : 144 | x = adaptive 145 | 146 | """ Decoder """ 147 | # Up-Sampling 148 | for i in range(2): 149 | x = deconv(x, channel // 2, kernel=4, stride=2, use_bias=False, sn=self.sn, scope='deconv_' + str(i)) 150 | x = instance_norm(x, scope='up_ins_norm' + str(i)) 151 | x = relu(x) 152 | 153 | channel = channel // 2 154 | 155 | x = conv(x, channels=self.img_ch, kernel=7, stride=1, pad=3, pad_type='reflect', use_bias=False, sn=self.sn, scope='G_logit') 156 | x = tanh(x) 157 | 158 | return x 159 | 160 | def MLP(self, style, channel, scope='MLP'): 161 | with tf.variable_scope(scope): 162 | x = style 163 | 164 | for i in range(2): 165 | x = fully_connected(x, channel, sn=self.sn, scope='FC_' + str(i)) 166 | x = relu(x) 167 | 168 | mu_list = [] 169 | var_list = [] 170 | 171 | for i in range(8): 172 | mu = fully_connected(x, channel, sn=self.sn, scope='FC_mu_' + str(i)) 173 | var = fully_connected(x, channel, sn=self.sn, scope='FC_var_' + str(i)) 174 | 175 | mu = tf.reshape(mu, shape=[-1, 1, 1, channel]) 176 | var = tf.reshape(var, shape=[-1, 1, 1, channel]) 177 | 178 | mu_list.append(mu) 179 | var_list.append(var) 180 | 181 | return mu_list, var_list 182 | 183 | ################################################################################## 184 | # Discriminator 185 | ################################################################################## 186 | 187 | def discriminator(self, x_init, reuse=False, scope="discriminator"): 188 | with tf.variable_scope(scope, reuse=reuse) : 189 | channel = self.ch 190 | x = conv(x_init, channel, kernel=4, stride=2, pad=1, pad_type='reflect', use_bias=True, sn=self.sn, scope='conv_0') 191 | x = lrelu(x, 0.01) 192 | 193 | for i in range(1, self.n_dis): 194 | x = conv(x, channel * 2, kernel=4, stride=2, pad=1, pad_type='reflect', use_bias=True, sn=self.sn, scope='conv_' + str(i)) 195 | x = lrelu(x, 0.01) 196 | 197 | channel = channel * 2 198 | 199 | c_kernel = int(self.img_height / np.power(2, self.n_dis)) 200 | 201 | logit = conv(x, channels=1, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=False, sn=self.sn, scope='D_logit') 202 | 203 | c = conv(x, channels=self.c_dim, kernel=c_kernel, stride=1, pad_type='reflect', use_bias=False, sn=self.sn, scope='D_label') 204 | c = tf.reshape(c, shape=[-1, self.c_dim]) 205 | 206 | noise = conv(x, channels=self.style_dim, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=False, sn=self.sn, scope='D_noise') 207 | noise = fully_connected(noise, units=self.style_dim, use_bias=True, sn=self.sn, scope='fc_0') 208 | noise = relu(noise) 209 | noise = fully_connected(noise, units=self.style_dim, use_bias=True, sn=self.sn, scope='fc_1') 210 | 211 | return logit, c, noise 212 | 213 | ################################################################################## 214 | # Model 215 | ################################################################################## 216 | 217 | def gradient_panalty(self, real, fake, scope="discriminator"): 218 | if self.gan_type.__contains__('dragan'): 219 | eps = tf.random_uniform(shape=tf.shape(real), minval=0., maxval=1.) 220 | _, x_var = tf.nn.moments(real, axes=[0, 1, 2, 3]) 221 | x_std = tf.sqrt(x_var) # magnitude of noise decides the size of local region 222 | 223 | fake = real + 0.5 * x_std * eps 224 | 225 | alpha = tf.random_uniform(shape=[self.batch_size, 1, 1, 1], minval=0., maxval=1.) 226 | interpolated = real + alpha * (fake - real) 227 | 228 | logit, _, _ = self.discriminator(interpolated, reuse=True, scope=scope) 229 | 230 | 231 | GP = 0 232 | 233 | grad = tf.gradients(logit, interpolated)[0] # gradient of D(interpolated) 234 | grad_norm = tf.norm(flatten(grad), axis=-1) # l2 norm 235 | 236 | # WGAN - LP 237 | if self.gan_type == 'wgan-lp' : 238 | GP = self.gp_weight * tf.reduce_mean(tf.square(tf.maximum(0.0, grad_norm - 1.))) 239 | 240 | elif self.gan_type == 'wgan-gp' or self.gan_type == 'dragan': 241 | GP = self.gp_weight * tf.reduce_mean(tf.square(grad_norm - 1.)) 242 | 243 | return GP 244 | 245 | def build_model(self): 246 | label_fix_onehot_list = [] 247 | 248 | """ Input Image""" 249 | if self.dataset_name == 'celebA-HQ' or self.dataset_name == 'celebA': 250 | img_class = ImageData_celebA(self.img_height, self.img_width, self.img_ch, self.dataset_path, 251 | self.label_list, self.augment_flag) 252 | img_class.preprocess(self.phase) 253 | 254 | else: 255 | img_class = Image_data(self.img_height, self.img_width, self.img_ch, self.dataset_path, self.label_list, 256 | self.augment_flag) 257 | img_class.preprocess() 258 | 259 | label_fix_onehot_list = img_class.label_onehot_list 260 | label_fix_onehot_list = tf.tile(tf.expand_dims(label_fix_onehot_list, axis=1), [1, self.batch_size, 1]) 261 | 262 | dataset_num = len(img_class.image) 263 | print("Dataset number : ", dataset_num) 264 | 265 | if self.phase == 'train' : 266 | self.lr = tf.placeholder(tf.float32, name='learning_rate') 267 | 268 | if self.dataset_name == 'celebA-HQ' or self.dataset_name == 'celebA': 269 | img_and_label = tf.data.Dataset.from_tensor_slices( 270 | (img_class.image, img_class.label, img_class.train_label_onehot_list)) 271 | else: 272 | img_and_label = tf.data.Dataset.from_tensor_slices((img_class.image, img_class.label)) 273 | 274 | gpu_device = '/gpu:0' 275 | img_and_label = img_and_label.apply(shuffle_and_repeat(dataset_num)).apply( 276 | map_and_batch(img_class.image_processing, self.batch_size, num_parallel_batches=16, 277 | drop_remainder=True)).apply(prefetch_to_device(gpu_device, None)) 278 | 279 | img_and_label_iterator = img_and_label.make_one_shot_iterator() 280 | 281 | if self.dataset_name == 'celebA-HQ' or self.dataset_name == 'celebA': 282 | self.x_real, label_org, label_fix_onehot_list = img_and_label_iterator.get_next() 283 | label_trg = tf.random_shuffle(label_org) # Target domain labels 284 | label_fix_onehot_list = tf.transpose(label_fix_onehot_list, perm=[1, 0, 2]) 285 | else: 286 | self.x_real, label_org = img_and_label_iterator.get_next() 287 | label_trg = tf.random_shuffle(label_org) # Target domain labels 288 | 289 | 290 | """ Define Generator, Discriminator """ 291 | fake_style_code = tf.random_normal(shape=[self.batch_size, self.style_dim]) 292 | x_fake = self.generator(self.x_real, label_trg, fake_style_code) # real a 293 | 294 | recon_style_code = tf.random_normal(shape=[self.batch_size, self.style_dim]) 295 | x_recon = self.generator(x_fake, label_org, recon_style_code, reuse=True) # real b 296 | 297 | real_logit, real_cls, _ = self.discriminator(self.x_real) 298 | fake_logit, fake_cls, fake_noise = self.discriminator(x_fake, reuse=True) 299 | 300 | 301 | """ Define Loss """ 302 | if self.gan_type.__contains__('wgan') or self.gan_type == 'dragan' : 303 | GP = self.gradient_panalty(real=self.x_real, fake=x_fake) 304 | else : 305 | GP = 0 306 | 307 | g_adv_loss = self.adv_weight * generator_loss(self.gan_type, fake_logit) 308 | g_cls_loss = self.cls_weight * classification_loss(logit=fake_cls, label=label_trg) 309 | g_rec_loss = self.rec_weight * L1_loss(self.x_real, x_recon) 310 | g_noise_loss = self.noise_weight * L1_loss(fake_style_code, fake_noise) 311 | 312 | d_adv_loss = self.adv_weight * discriminator_loss(self.gan_type, real_logit, fake_logit) + GP 313 | d_cls_loss = self.cls_weight * classification_loss(logit=real_cls, label=label_org) 314 | d_noise_loss = self.noise_weight * L1_loss(fake_style_code, fake_noise) 315 | 316 | self.d_loss = d_adv_loss + d_cls_loss + d_noise_loss 317 | self.g_loss = g_adv_loss + g_cls_loss + g_rec_loss + g_noise_loss 318 | 319 | 320 | """ Result Image """ 321 | if self.dataset_name == 'celebA-HQ' or self.dataset_name == 'celebA': 322 | self.x_fake_list = [] 323 | 324 | for _ in range(self.num_style): 325 | random_style_code = tf.random_normal(shape=[self.batch_size, self.style_dim]) 326 | self.x_fake_list.append(tf.map_fn(lambda c : self.generator(self.x_real, c, random_style_code, reuse=True), label_fix_onehot_list, dtype=tf.float32)) 327 | 328 | else : 329 | self.x_fake_list = [] 330 | 331 | for _ in range(self.num_style) : 332 | random_style_code = tf.random_normal(shape=[self.batch_size, self.style_dim]) 333 | self.x_fake_list.append(tf.map_fn(lambda c : self.generator(self.x_real, c, random_style_code, reuse=True), label_fix_onehot_list, dtype=tf.float32)) 334 | 335 | 336 | 337 | """ Training """ 338 | t_vars = tf.trainable_variables() 339 | G_vars = [var for var in t_vars if 'generator' in var.name] 340 | D_vars = [var for var in t_vars if 'discriminator' in var.name] 341 | 342 | self.g_optimizer = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.g_loss, var_list=G_vars) 343 | self.d_optimizer = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.d_loss, var_list=D_vars) 344 | 345 | 346 | """" Summary """ 347 | self.Generator_loss = tf.summary.scalar("g_loss", self.g_loss) 348 | self.Discriminator_loss = tf.summary.scalar("d_loss", self.d_loss) 349 | 350 | self.g_adv_loss = tf.summary.scalar("g_adv_loss", g_adv_loss) 351 | self.g_cls_loss = tf.summary.scalar("g_cls_loss", g_cls_loss) 352 | self.g_rec_loss = tf.summary.scalar("g_rec_loss", g_rec_loss) 353 | self.g_noise_loss = tf.summary.scalar("g_noise_loss", g_noise_loss) 354 | 355 | self.d_adv_loss = tf.summary.scalar("d_adv_loss", d_adv_loss) 356 | self.d_cls_loss = tf.summary.scalar("d_cls_loss", d_cls_loss) 357 | self.d_noise_loss = tf.summary.scalar("d_noise_loss", d_noise_loss) 358 | 359 | self.g_summary_loss = tf.summary.merge([self.Generator_loss, self.g_adv_loss, self.g_cls_loss, self.g_rec_loss, self.g_noise_loss]) 360 | self.d_summary_loss = tf.summary.merge([self.Discriminator_loss, self.d_adv_loss, self.d_cls_loss, self.d_noise_loss]) 361 | 362 | else : 363 | """ Test """ 364 | if self.dataset_name == 'celebA-HQ' or self.dataset_name == 'celebA': 365 | img_and_label = tf.data.Dataset.from_tensor_slices( 366 | (img_class.test_image, img_class.test_label, img_class.test_label_onehot_list)) 367 | dataset_num = len(img_class.test_image) 368 | 369 | gpu_device = '/gpu:0' 370 | img_and_label = img_and_label.apply(shuffle_and_repeat(dataset_num)).apply( 371 | map_and_batch(img_class.image_processing, batch_size=self.batch_size, num_parallel_batches=16, 372 | drop_remainder=True)).apply(prefetch_to_device(gpu_device, None)) 373 | 374 | img_and_label_iterator = img_and_label.make_one_shot_iterator() 375 | 376 | self.x_test, _, self.test_label_fix_onehot_list = img_and_label_iterator.get_next() 377 | self.test_img_placeholder = tf.placeholder(tf.float32, [1, self.img_height, self.img_width, self.img_ch]) 378 | self.test_label_fix_placeholder = tf.placeholder(tf.float32, [self.c_dim, 1, self.c_dim]) 379 | 380 | self.custom_image = tf.placeholder(tf.float32, [1, self.img_height, self.img_width, self.img_ch], name='custom_image') # Custom Image 381 | custom_label_fix_onehot_list = tf.transpose(np.expand_dims(label2onehot(self.label_list), axis=0), perm=[1, 0, 2]) # [c_dim, bs, c_dim] 382 | 383 | """ Test Image """ 384 | test_random_style_code = tf.random_normal(shape=[1, self.style_dim]) 385 | 386 | self.x_test_fake_list = tf.map_fn(lambda c : self.generator(self.test_img_placeholder, c, test_random_style_code), self.test_label_fix_placeholder, dtype=tf.float32) 387 | self.custom_fake_image = tf.map_fn(lambda c : self.generator(self.custom_image, c, test_random_style_code, reuse=True), custom_label_fix_onehot_list, dtype=tf.float32) 388 | 389 | else : 390 | self.custom_image = tf.placeholder(tf.float32, [1, self.img_height, self.img_width, self.img_ch], name='custom_image') # Custom Image 391 | custom_label_fix_onehot_list = tf.transpose(np.expand_dims(label2onehot(self.label_list), axis=0), perm=[1, 0, 2]) # [c_dim, bs, c_dim] 392 | 393 | test_random_style_code = tf.random_normal(shape=[1, self.style_dim]) 394 | self.custom_fake_image = tf.map_fn(lambda c : self.generator(self.custom_image, c, test_random_style_code), custom_label_fix_onehot_list, dtype=tf.float32) 395 | 396 | 397 | 398 | def train(self): 399 | # initialize all variables 400 | tf.global_variables_initializer().run() 401 | 402 | # saver to save model 403 | self.saver = tf.train.Saver(max_to_keep=10) 404 | 405 | # summary writer 406 | self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_dir, self.sess.graph) 407 | 408 | # restore check-point if it exits 409 | could_load, checkpoint_counter = self.load(self.checkpoint_dir) 410 | if could_load: 411 | start_epoch = (int)(checkpoint_counter / self.iteration) 412 | start_batch_id = checkpoint_counter - start_epoch * self.iteration 413 | counter = checkpoint_counter 414 | print(" [*] Load SUCCESS") 415 | else: 416 | start_epoch = 0 417 | start_batch_id = 0 418 | counter = 1 419 | print(" [!] Load failed...") 420 | 421 | self.sample_dir = os.path.join(self.sample_dir, self.model_dir) 422 | check_folder(self.sample_dir) 423 | 424 | # loop for epoch 425 | start_time = time.time() 426 | past_g_loss = -1. 427 | lr = self.init_lr 428 | for epoch in range(start_epoch, self.epoch): 429 | if self.decay_flag : 430 | lr = self.init_lr if epoch < self.decay_epoch else self.init_lr * (self.epoch - epoch) / (self.epoch - self.decay_epoch) # linear decay 431 | 432 | for idx in range(start_batch_id, self.iteration): 433 | train_feed_dict = { 434 | self.lr : lr 435 | } 436 | 437 | # Update D 438 | _, d_loss, summary_str = self.sess.run([self.d_optimizer, self.d_loss, self.d_summary_loss], feed_dict = train_feed_dict) 439 | self.writer.add_summary(summary_str, counter) 440 | 441 | # Update G 442 | g_loss = None 443 | if (counter - 1) % self.n_critic == 0 : 444 | real_images, fake_images, _, g_loss, summary_str = self.sess.run([self.x_real, self.x_fake_list, self.g_optimizer, self.g_loss, self.g_summary_loss], feed_dict = train_feed_dict) 445 | self.writer.add_summary(summary_str, counter) 446 | past_g_loss = g_loss 447 | 448 | # display training status 449 | counter += 1 450 | if g_loss == None : 451 | g_loss = past_g_loss 452 | 453 | print("Epoch: [%2d] [%5d/%5d] time: %4.4f d_loss: %.8f, g_loss: %.8f" % (epoch, idx, self.iteration, time.time() - start_time, d_loss, g_loss)) 454 | 455 | if np.mod(idx+1, self.print_freq) == 0 : 456 | real_image = np.expand_dims(real_images[0], axis=0) 457 | save_images(real_image, [1, 1], 458 | './{}/real_{:03d}_{:05d}.jpg'.format(self.sample_dir, epoch, idx+1)) 459 | 460 | merge_fake_x = None 461 | 462 | for ns in range(self.num_style) : 463 | fake_img = np.transpose(fake_images[ns], axes=[1, 0, 2, 3, 4])[0] 464 | 465 | if ns == 0 : 466 | merge_fake_x = return_images(fake_img, [1, self.c_dim]) # [self.img_height, self.img_width * self.c_dim, self.img_ch] 467 | else : 468 | x = return_images(fake_img, [1, self.c_dim]) 469 | merge_fake_x = np.concatenate([merge_fake_x, x], axis=0) 470 | 471 | merge_fake_x = np.expand_dims(merge_fake_x, axis=0) 472 | save_images(merge_fake_x, [1, 1], 473 | './{}/fake_{:03d}_{:05d}.jpg'.format(self.sample_dir, epoch, idx+1)) 474 | 475 | if np.mod(counter - 1, self.save_freq) == 0: 476 | self.save(self.checkpoint_dir, counter) 477 | 478 | # After an epoch, start_batch_id is set to zero 479 | # non-zero value is only for the first epoch after loading pre-trained model 480 | start_batch_id = 0 481 | 482 | # save model for final step 483 | self.save(self.checkpoint_dir, counter) 484 | 485 | @property 486 | def model_dir(self): 487 | 488 | if self.sn: 489 | sn = '_sn' 490 | else: 491 | sn = '' 492 | 493 | if self.attention: 494 | attention = '_attention' 495 | else: 496 | attention = '' 497 | 498 | return "{}_{}_{}_{}adv_{}rec_{}cls_{}noise{}{}".format(self.model_name, self.dataset_name, self.gan_type, 499 | self.adv_weight, self.rec_weight, self.cls_weight, self.noise_weight, 500 | sn, attention) 501 | 502 | def save(self, checkpoint_dir, step): 503 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir) 504 | 505 | if not os.path.exists(checkpoint_dir): 506 | os.makedirs(checkpoint_dir) 507 | 508 | self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name + '.model'), global_step=step) 509 | 510 | def load(self, checkpoint_dir): 511 | print(" [*] Reading checkpoints...") 512 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir) 513 | 514 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 515 | if ckpt and ckpt.model_checkpoint_path: 516 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path) 517 | self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name)) 518 | counter = int(ckpt_name.split('-')[-1]) 519 | print(" [*] Success to read {}".format(ckpt_name)) 520 | return True, counter 521 | else: 522 | print(" [*] Failed to find a checkpoint") 523 | return False, 0 524 | 525 | def test(self): 526 | tf.global_variables_initializer().run() 527 | test_files = glob('./dataset/{}/{}/*.jpg'.format(self.dataset_name, 'test')) + glob('./dataset/{}/{}/*.png'.format(self.dataset_name, 'test')) 528 | 529 | self.saver = tf.train.Saver() 530 | could_load, checkpoint_counter = self.load(self.checkpoint_dir) 531 | self.result_dir = os.path.join(self.result_dir, self.model_dir) 532 | check_folder(self.result_dir) 533 | 534 | custom_image_folder = os.path.join(self.result_dir, 'custom_fake_images') 535 | check_folder(custom_image_folder) 536 | 537 | if could_load : 538 | print(" [*] Load SUCCESS") 539 | else : 540 | print(" [!] Load failed...") 541 | 542 | # write html for visual comparison 543 | index_path = os.path.join(self.result_dir, 'index.html') 544 | index = open(index_path, 'w') 545 | index.write("") 546 | index.write("") 547 | 548 | # Custom Image 549 | for sample_file in tqdm(test_files): 550 | print("Processing image: " + sample_file) 551 | sample_image = load_test_image(sample_file, self.img_width, self.img_height, self.img_ch) 552 | image_path = os.path.join(custom_image_folder, '{}'.format(os.path.basename(sample_file))) 553 | 554 | merge_x = None 555 | 556 | for i in range(self.num_style) : 557 | fake_img = self.sess.run(self.custom_fake_image, feed_dict={self.custom_image: sample_image}) 558 | fake_img = np.transpose(fake_img, axes=[1, 0, 2, 3, 4])[0] 559 | 560 | if i == 0: 561 | merge_x = return_images(fake_img, [1, self.c_dim]) # [self.img_height, self.img_width * self.c_dim, self.img_ch] 562 | else : 563 | x = return_images(fake_img, [1, self.c_dim]) 564 | merge_x = np.concatenate([merge_x, x], axis=0) 565 | 566 | merge_x = np.expand_dims(merge_x, axis=0) 567 | 568 | save_images(merge_x, [1, 1], image_path) 569 | 570 | index.write("" % os.path.basename(image_path)) 571 | index.write("" % (sample_file if os.path.isabs(sample_file) else ( 572 | '../..' + os.path.sep + sample_file), self.img_width, self.img_height)) 573 | 574 | index.write("" % (image_path if os.path.isabs(image_path) else ( 575 | '../..' + os.path.sep + image_path), self.img_width * self.c_dim, self.img_height * self.num_style)) 576 | index.write("") 577 | 578 | if self.dataset_name == 'celebA-HQ' or self.dataset_name == 'celebA': 579 | # CelebA 580 | celebA_image_folder = os.path.join(self.result_dir, 'celebA_real_fake_images') 581 | check_folder(celebA_image_folder) 582 | real_images, real_label_fixes = self.sess.run([self.x_test, self.test_label_fix_onehot_list]) 583 | 584 | for i in tqdm(range(len(real_images))) : 585 | 586 | real_path = os.path.join(celebA_image_folder, 'real_{}.png'.format(i)) 587 | fake_path = os.path.join(celebA_image_folder, 'fake_{}.png'.format(i)) 588 | 589 | real_img = np.expand_dims(real_images[i], axis=0) 590 | real_label_fix = np.expand_dims(real_label_fixes[i], axis=1) 591 | 592 | merge_x = None 593 | 594 | for ns in range(self.num_style) : 595 | fake_img = self.sess.run(self.x_test_fake_list, feed_dict={self.test_img_placeholder: real_img, self.test_label_fix_placeholder:real_label_fix}) 596 | fake_img = np.transpose(fake_img, axes=[1, 0, 2, 3, 4])[0] 597 | 598 | if ns == 0: 599 | merge_x = return_images(fake_img, [1, self.c_dim]) # [self.img_height, self.img_width * self.c_dim, self.img_ch] 600 | else: 601 | x = return_images(fake_img, [1, self.c_dim]) 602 | merge_x = np.concatenate([merge_x, x], axis=0) 603 | 604 | merge_x = np.expand_dims(merge_x, axis=0) 605 | 606 | save_images(real_img, [1, 1], real_path) 607 | save_images(merge_x, [1, 1], fake_path) 608 | 609 | index.write("" % os.path.basename(real_path)) 610 | index.write("" % (real_path if os.path.isabs(real_path) else ( 611 | '../..' + os.path.sep + real_path), self.img_width, self.img_height)) 612 | 613 | index.write("" % (fake_path if os.path.isabs(fake_path) else ( 614 | '../..' + os.path.sep + fake_path), self.img_width * self.c_dim, self.img_height * self.num_style)) 615 | index.write("") 616 | 617 | index.close() -------------------------------------------------------------------------------- /assets/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/SDIT-Tensorflow/71f9917325d647d3d51e691e85f5a3079e068da0/assets/framework.png -------------------------------------------------------------------------------- /assets/result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/SDIT-Tensorflow/71f9917325d647d3d51e691e85f5a3079e068da0/assets/result.png -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/SDIT-Tensorflow/71f9917325d647d3d51e691e85f5a3079e068da0/assets/teaser.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from SDIT import SDIT 2 | import argparse 3 | from utils import * 4 | 5 | """parsing and configuration""" 6 | def parse_args(): 7 | desc = "Tensorflow implementation of SDIT" 8 | parser = argparse.ArgumentParser(description=desc) 9 | parser.add_argument('--phase', type=str, default='train', help='train or test ?') 10 | parser.add_argument('--attention', type=str2bool, default=True, choices=[True, False]) 11 | parser.add_argument('--dataset', type=str, default='celebA', help='dataset_name') 12 | 13 | parser.add_argument('--epoch', type=int, default=20, help='The number of epochs to run') 14 | parser.add_argument('--iteration', type=int, default=10000, help='The number of training iterations') 15 | # The total number of iterations is [epoch * iteration] 16 | 17 | parser.add_argument('--batch_size', type=int, default=16, help='The size of batch size') 18 | parser.add_argument('--print_freq', type=int, default=1000, help='The number of image_print_freq') 19 | parser.add_argument('--save_freq', type=int, default=10000, help='The number of ckpt_save_freq') 20 | parser.add_argument('--decay_flag', type=str2bool, default=True, help='The decay_flag') 21 | parser.add_argument('--decay_epoch', type=int, default=10, help='decay epoch') 22 | 23 | parser.add_argument('--lr', type=float, default=0.0001, help='The learning rate') 24 | parser.add_argument('--adv_weight', type=float, default=1, help='Weight about GAN') 25 | parser.add_argument('--rec_weight', type=float, default=10, help='Weight about Reconstruction') 26 | parser.add_argument('--cls_weight', type=float, default=10, help='Weight about Classification') 27 | parser.add_argument('--gp_weight', type=float, default=10, help='The gradient penalty lambda') 28 | parser.add_argument('--noise_weight', type=float, default=800, help='weight of noise for reconstruction loss') 29 | 30 | parser.add_argument('--gan_type', type=str, default='wgan-gp', help='gan / lsgan / wgan-gp / wgan-lp / dragan / hinge') 31 | parser.add_argument('--sn', type=str2bool, default=False, help='using spectral norm') 32 | parser.add_argument('--label_list', type=str, nargs='+', help='selected attributes for the CelebA dataset', 33 | default=['Blond_Hair', 'Brown_Hair', 'Male', 'Eyeglasses', 'Bangs']) 34 | 35 | parser.add_argument('--ch', type=int, default=64, help='base channel number per layer') 36 | parser.add_argument('--n_res', type=int, default=6, help='The number of resblock') 37 | parser.add_argument('--n_dis', type=int, default=6, help='The number of discriminator layer') 38 | parser.add_argument('--n_critic', type=int, default=5, help='number of D updates per each G update') 39 | parser.add_argument('--style_dim', type=int, default=8, help='length of style code') 40 | 41 | parser.add_argument('--num_style', type=int, default=5, help='number of styles to sample') 42 | 43 | parser.add_argument('--img_height', type=int, default=128, help='The height size of image') 44 | parser.add_argument('--img_width', type=int, default=128, help='The width size of image ') 45 | parser.add_argument('--img_ch', type=int, default=3, help='The size of image channel') 46 | parser.add_argument('--augment_flag', type=str2bool, default=True, help='Image augmentation use or not') 47 | 48 | parser.add_argument('--checkpoint_dir', type=str, default='checkpoint', 49 | help='Directory name to save the checkpoints') 50 | parser.add_argument('--result_dir', type=str, default='results', 51 | help='Directory name to save the generated images') 52 | parser.add_argument('--log_dir', type=str, default='logs', 53 | help='Directory name to save training logs') 54 | parser.add_argument('--sample_dir', type=str, default='samples', 55 | help='Directory name to save the samples on training') 56 | 57 | return check_args(parser.parse_args()) 58 | 59 | """checking arguments""" 60 | def check_args(args): 61 | # --checkpoint_dir 62 | check_folder(args.checkpoint_dir) 63 | 64 | # --result_dir 65 | check_folder(args.result_dir) 66 | 67 | # --result_dir 68 | check_folder(args.log_dir) 69 | 70 | # --sample_dir 71 | check_folder(args.sample_dir) 72 | 73 | # --epoch 74 | try: 75 | assert args.epoch >= 1 76 | except: 77 | print('number of epochs must be larger than or equal to one') 78 | 79 | # --batch_size 80 | try: 81 | assert args.batch_size >= 1 82 | except: 83 | print('batch size must be larger than or equal to one') 84 | return args 85 | 86 | """main""" 87 | def main(): 88 | # parse arguments 89 | args = parse_args() 90 | if args is None: 91 | exit() 92 | 93 | # open session 94 | with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess: 95 | gan = SDIT(sess, args) 96 | 97 | # build graph 98 | gan.build_model() 99 | 100 | # show network architecture 101 | show_all_variables() 102 | 103 | if args.phase == 'train' : 104 | gan.train() 105 | print(" [*] Training finished!") 106 | 107 | if args.phase == 'test' : 108 | gan.test() 109 | print(" [*] Test finished!") 110 | 111 | if __name__ == '__main__': 112 | main() -------------------------------------------------------------------------------- /ops.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib as tf_contrib 3 | from utils import pytorch_xavier_weight_factor, pytorch_kaiming_weight_factor 4 | 5 | ################################################################################## 6 | # Initialization 7 | ################################################################################## 8 | 9 | factor, mode, uniform = pytorch_xavier_weight_factor(gain=0.02, uniform=False) 10 | weight_init = tf_contrib.layers.variance_scaling_initializer(factor=factor, mode=mode, uniform=uniform) 11 | weight_regularizer = None 12 | weight_regularizer_fully = None 13 | 14 | 15 | ################################################################################## 16 | # Layers 17 | ################################################################################## 18 | 19 | # padding='SAME' ======> pad = floor[ (kernel - stride) / 2 ] 20 | def conv(x, channels, kernel=4, stride=2, pad=0, pad_type='zero', use_bias=True, sn=False, scope='conv_0'): 21 | with tf.variable_scope(scope): 22 | if pad > 0: 23 | h = x.get_shape().as_list()[1] 24 | if h % stride == 0: 25 | pad = pad * 2 26 | else: 27 | pad = max(kernel - (h % stride), 0) 28 | 29 | pad_top = pad // 2 30 | pad_bottom = pad - pad_top 31 | pad_left = pad // 2 32 | pad_right = pad - pad_left 33 | 34 | if pad_type == 'zero': 35 | x = tf.pad(x, [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]]) 36 | if pad_type == 'reflect': 37 | x = tf.pad(x, [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]], mode='REFLECT') 38 | 39 | if sn: 40 | w = tf.get_variable("kernel", shape=[kernel, kernel, x.get_shape()[-1], channels], initializer=weight_init, 41 | regularizer=weight_regularizer) 42 | x = tf.nn.conv2d(input=x, filter=spectral_norm(w), 43 | strides=[1, stride, stride, 1], padding='VALID') 44 | if use_bias: 45 | bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0)) 46 | x = tf.nn.bias_add(x, bias) 47 | 48 | else: 49 | x = tf.layers.conv2d(inputs=x, filters=channels, 50 | kernel_size=kernel, kernel_initializer=weight_init, 51 | kernel_regularizer=weight_regularizer, 52 | strides=stride, use_bias=use_bias) 53 | 54 | return x 55 | 56 | 57 | def deconv(x, channels, kernel=4, stride=2, padding='SAME', use_bias=True, sn=False, scope='deconv_0'): 58 | with tf.variable_scope(scope): 59 | x_shape = x.get_shape().as_list() 60 | 61 | if padding == 'SAME': 62 | output_shape = [x_shape[0], x_shape[1] * stride, x_shape[2] * stride, channels] 63 | 64 | else: 65 | output_shape = [x_shape[0], x_shape[1] * stride + max(kernel - stride, 0), 66 | x_shape[2] * stride + max(kernel - stride, 0), channels] 67 | 68 | if sn: 69 | w = tf.get_variable("kernel", shape=[kernel, kernel, channels, x.get_shape()[-1]], initializer=weight_init, 70 | regularizer=weight_regularizer) 71 | x = tf.nn.conv2d_transpose(x, filter=spectral_norm(w), output_shape=output_shape, 72 | strides=[1, stride, stride, 1], padding=padding) 73 | 74 | if use_bias: 75 | bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0)) 76 | x = tf.nn.bias_add(x, bias) 77 | 78 | else: 79 | x = tf.layers.conv2d_transpose(inputs=x, filters=channels, 80 | kernel_size=kernel, kernel_initializer=weight_init, 81 | kernel_regularizer=weight_regularizer, 82 | strides=stride, padding=padding, use_bias=use_bias) 83 | 84 | return x 85 | 86 | def fully_connected(x, units, use_bias=True, sn=False, scope='linear'): 87 | with tf.variable_scope(scope): 88 | x = flatten(x) 89 | shape = x.get_shape().as_list() 90 | channels = shape[-1] 91 | 92 | if sn: 93 | w = tf.get_variable("kernel", [channels, units], tf.float32, 94 | initializer=weight_init, regularizer=weight_regularizer_fully) 95 | if use_bias: 96 | bias = tf.get_variable("bias", [units], 97 | initializer=tf.constant_initializer(0.0)) 98 | 99 | x = tf.matmul(x, spectral_norm(w)) + bias 100 | else: 101 | x = tf.matmul(x, spectral_norm(w)) 102 | 103 | else: 104 | x = tf.layers.dense(x, units=units, kernel_initializer=weight_init, 105 | kernel_regularizer=weight_regularizer_fully, 106 | use_bias=use_bias) 107 | 108 | return x 109 | 110 | def flatten(x) : 111 | return tf.layers.flatten(x) 112 | 113 | ################################################################################## 114 | # Residual-block 115 | ################################################################################## 116 | 117 | def resblock(x_init, channels, use_bias=True, sn=False, scope='resblock'): 118 | with tf.variable_scope(scope): 119 | with tf.variable_scope('res1'): 120 | x = conv(x_init, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias, sn=sn) 121 | x = instance_norm(x) 122 | x = relu(x) 123 | 124 | with tf.variable_scope('res2'): 125 | x = conv(x, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias, sn=sn) 126 | x = instance_norm(x) 127 | 128 | return x + x_init 129 | 130 | def adaptive_resblock(x_init, channels, gamma1, beta1, gamma2, beta2, use_bias=True, sn=False, scope='adaptive_resblock') : 131 | with tf.variable_scope(scope): 132 | with tf.variable_scope('res1'): 133 | x = conv(x_init, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias, sn=sn) 134 | x = adaptive_instance_norm(x, gamma1, beta1) 135 | x = relu(x) 136 | 137 | with tf.variable_scope('res2'): 138 | x = conv(x, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias, sn=sn) 139 | x = adaptive_instance_norm(x, gamma2, beta2) 140 | 141 | return x + x_init 142 | 143 | 144 | ################################################################################## 145 | # Activation function 146 | ################################################################################## 147 | 148 | def lrelu(x, alpha=0.2): 149 | return tf.nn.leaky_relu(x, alpha) 150 | 151 | 152 | def relu(x): 153 | return tf.nn.relu(x) 154 | 155 | 156 | def tanh(x): 157 | return tf.tanh(x) 158 | 159 | def sigmoid(x): 160 | return tf.sigmoid(x) 161 | 162 | ################################################################################## 163 | # Pooling & Resize 164 | ################################################################################## 165 | 166 | def up_sample(x, scale_factor=2): 167 | _, h, w, _ = x.get_shape().as_list() 168 | new_size = [h * scale_factor, w * scale_factor] 169 | return tf.image.resize_bilinear(x, size=new_size) 170 | 171 | 172 | ################################################################################## 173 | # Normalization function 174 | ################################################################################## 175 | 176 | def instance_norm(x, scope='instance_norm'): 177 | return tf_contrib.layers.instance_norm(x, 178 | epsilon=1e-05, 179 | center=True, scale=True, 180 | scope=scope) 181 | 182 | def adaptive_instance_norm(content, gamma, beta, epsilon=1e-5): 183 | 184 | c_mean, c_var = tf.nn.moments(content, axes=[1, 2], keep_dims=True) 185 | c_std = tf.sqrt(c_var + epsilon) 186 | 187 | return gamma * ((content - c_mean) / c_std) + beta 188 | 189 | def spectral_norm(w, iteration=1): 190 | w_shape = w.shape.as_list() 191 | w = tf.reshape(w, [-1, w_shape[-1]]) 192 | 193 | u = tf.get_variable("u", [1, w_shape[-1]], initializer=tf.random_normal_initializer(), trainable=False) 194 | 195 | u_hat = u 196 | v_hat = None 197 | for i in range(iteration): 198 | """ 199 | power iteration 200 | Usually iteration = 1 will be enough 201 | """ 202 | v_ = tf.matmul(u_hat, tf.transpose(w)) 203 | v_hat = tf.nn.l2_normalize(v_) 204 | 205 | u_ = tf.matmul(v_hat, w) 206 | u_hat = tf.nn.l2_normalize(u_) 207 | 208 | u_hat = tf.stop_gradient(u_hat) 209 | v_hat = tf.stop_gradient(v_hat) 210 | 211 | sigma = tf.matmul(tf.matmul(v_hat, w), tf.transpose(u_hat)) 212 | 213 | with tf.control_dependencies([u.assign(u_hat)]): 214 | w_norm = w / sigma 215 | w_norm = tf.reshape(w_norm, w_shape) 216 | 217 | return w_norm 218 | 219 | ################################################################################## 220 | # Loss function 221 | ################################################################################## 222 | 223 | def discriminator_loss(loss_func, real_logit, fake_logit): 224 | real_loss = 0 225 | fake_loss = 0 226 | 227 | if loss_func.__contains__('wgan') : 228 | real_loss = -tf.reduce_mean(real_logit) 229 | fake_loss = tf.reduce_mean(fake_logit) 230 | 231 | if loss_func == 'lsgan' : 232 | real_loss = tf.reduce_mean(tf.squared_difference(real_logit, 1.0)) 233 | fake_loss = tf.reduce_mean(tf.square(fake_logit)) 234 | 235 | if loss_func == 'gan' or loss_func == 'dragan' : 236 | real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(real_logit), logits=real_logit)) 237 | fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(fake_logit), logits=fake_logit)) 238 | 239 | if loss_func == 'hinge' : 240 | real_loss = tf.reduce_mean(relu(1.0 - real_logit)) 241 | fake_loss = tf.reduce_mean(relu(1.0 + fake_logit)) 242 | 243 | loss = real_loss + fake_loss 244 | 245 | return loss 246 | 247 | def generator_loss(loss_func, fake_logit): 248 | fake_loss = 0 249 | 250 | if loss_func.__contains__('wgan') : 251 | fake_loss = -tf.reduce_mean(fake_logit) 252 | 253 | if loss_func == 'lsgan' : 254 | fake_loss = tf.reduce_mean(tf.squared_difference(fake_logit, 1.0)) 255 | 256 | if loss_func == 'gan' or loss_func == 'dragan' : 257 | fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(fake_logit), logits=fake_logit)) 258 | 259 | if loss_func == 'hinge' : 260 | fake_loss = -tf.reduce_mean(fake_logit) 261 | 262 | loss = fake_loss 263 | 264 | return loss 265 | 266 | def classification_loss(logit, label) : 267 | loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=label, logits=logit)) 268 | 269 | return loss 270 | 271 | def L1_loss(x, y): 272 | loss = tf.reduce_mean(tf.abs(x - y)) 273 | 274 | return loss -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import cv2 4 | 5 | import tensorflow as tf 6 | import tensorflow.contrib.slim as slim 7 | import random 8 | from glob import glob 9 | from tqdm import tqdm 10 | 11 | class Image_data: 12 | 13 | def __init__(self, img_height, img_width, channels, dataset_path, label_list, augment_flag): 14 | self.img_height = img_height 15 | self.img_width = img_width 16 | self.channels = channels 17 | self.augment_flag = augment_flag 18 | 19 | self.label_list = label_list 20 | self.dataset_path = dataset_path 21 | 22 | self.label_onehot_list = [] 23 | self.image = [] 24 | self.label = [] 25 | 26 | 27 | def image_processing(self, filename, label): 28 | x = tf.read_file(filename) 29 | x_decode = tf.image.decode_jpeg(x, channels=self.channels, dct_method='INTEGER_ACCURATE') 30 | img = tf.image.resize_images(x_decode, [self.img_height, self.img_width]) 31 | img = tf.cast(img, tf.float32) / 127.5 - 1 32 | 33 | 34 | if self.augment_flag : 35 | augment_height_size = self.img_height + (30 if self.img_height == 256 else int(self.img_height * 0.1)) 36 | augment_width_size = self.img_width + (30 if self.img_width == 256 else int(self.img_width * 0.1)) 37 | 38 | img = tf.cond(pred=tf.greater_equal(tf.random_uniform(shape=[], minval=0.0, maxval=1.0), 0.5), 39 | true_fn=lambda : augmentation(img, augment_height_size, augment_width_size), 40 | false_fn=lambda : img) 41 | 42 | return img, label 43 | 44 | def preprocess(self): 45 | # self.label_list = ['tiger', 'cat', 'dog', 'lion'] 46 | 47 | v = 0 48 | 49 | for label in self.label_list : # fabric 50 | label_one_hot = list(get_one_hot(v, len(self.label_list))) # [1, 0, 0, 0, 0] 51 | self.label_onehot_list.append(label_one_hot) 52 | v = v+1 53 | 54 | image_list = glob(os.path.join(self.dataset_path, label) + '/*.png') + glob(os.path.join(self.dataset_path, label) + '/*.jpg') 55 | label_one_hot = [label_one_hot] * len(image_list) 56 | 57 | self.image.extend(image_list) 58 | self.label.extend(label_one_hot) 59 | 60 | class ImageData_celebA: 61 | 62 | def __init__(self, img_height, img_width, channels, dataset_path, label_list, augment_flag): 63 | self.img_height = img_height 64 | self.img_width = img_width 65 | self.channels = channels 66 | self.augment_flag = augment_flag 67 | self.label_list = label_list 68 | 69 | self.dataset_path = os.path.join(dataset_path, 'train') 70 | self.file_name_list = [os.path.basename(x) for x in glob(self.dataset_path + '/*.png')] 71 | self.lines = open(os.path.join(dataset_path, 'list_attr_celeba.txt'), 'r').readlines() 72 | 73 | self.image = [] 74 | self.label = [] 75 | 76 | self.test_image = [] 77 | self.test_label = [] 78 | 79 | self.attr2idx = {} 80 | self.idx2attr = {} 81 | 82 | self.train_label_onehot_list = [] 83 | self.test_label_onehot_list = [] 84 | 85 | def image_processing(self, filename, label, fix_label): 86 | x = tf.read_file(filename) 87 | x_decode = tf.image.decode_jpeg(x, channels=self.channels, dct_method='INTEGER_ACCURATE') 88 | img = tf.image.resize_images(x_decode, [self.img_height, self.img_width]) 89 | img = tf.cast(img, tf.float32) / 127.5 - 1 90 | 91 | if self.augment_flag : 92 | augment_height = self.img_height + (30 if self.img_height == 256 else int(self.img_height * 0.1)) 93 | augment_width = self.img_width + (30 if self.img_width == 256 else int(self.img_width * 0.1)) 94 | 95 | img = tf.cond(pred=tf.greater_equal(tf.random_uniform(shape=[], minval=0.0, maxval=1.0), 0.5), 96 | true_fn=lambda: augmentation(img, augment_height, augment_width), 97 | false_fn=lambda: img) 98 | 99 | 100 | return img, label, fix_label 101 | 102 | def preprocess(self, phase): 103 | 104 | all_attr_names = self.lines[1].split() 105 | for i, attr_name in enumerate(all_attr_names): 106 | self.attr2idx[attr_name] = i 107 | self.idx2attr[i] = attr_name 108 | 109 | lines = self.lines[2:] 110 | random.seed(1234) 111 | random.shuffle(lines) 112 | 113 | for i, line in enumerate(tqdm(lines)): 114 | split = line.split() 115 | if split[0] in self.file_name_list: 116 | filename = os.path.join(self.dataset_path, split[0]) 117 | values = split[1:] 118 | 119 | label = [] 120 | 121 | for attr_name in self.label_list: 122 | idx = self.attr2idx[attr_name] 123 | 124 | if values[idx] == '1': 125 | label.append(1.0) 126 | else: 127 | label.append(0.0) 128 | 129 | if i < 2000: 130 | self.test_image.append(filename) 131 | self.test_label.append(label) 132 | else: 133 | if phase == 'test' : 134 | break 135 | self.image.append(filename) 136 | self.label.append(label) 137 | # ['./dataset/celebA/train/019932.png', [1, 0, 0, 0, 1]] 138 | 139 | print() 140 | 141 | self.test_label_onehot_list = create_labels(self.test_label, self.label_list) 142 | if phase == 'train' : 143 | self.train_label_onehot_list = create_labels(self.label, self.label_list) 144 | 145 | print('\n Finished preprocessing the CelebA dataset...') 146 | 147 | def load_test_image(image_path, img_width, img_height, img_channel): 148 | 149 | if img_channel == 1 : 150 | img = cv2.imread(image_path, flags=cv2.IMREAD_GRAYSCALE) 151 | else : 152 | img = cv2.imread(image_path, flags=cv2.IMREAD_COLOR) 153 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 154 | 155 | img = cv2.resize(img, dsize=(img_width, img_height)) 156 | 157 | if img_channel == 1 : 158 | img = np.expand_dims(img, axis=0) 159 | img = np.expand_dims(img, axis=-1) 160 | else : 161 | img = np.expand_dims(img, axis=0) 162 | 163 | img = img/127.5 - 1 164 | 165 | return img 166 | 167 | def load_one_hot_vector(label_list, target_label) : 168 | label_onehot_dict = {} 169 | 170 | v = 0 171 | for label_name in label_list : 172 | label_one_hot = list(get_one_hot(v, len(label_list))) 173 | label_onehot_dict[label_name] = [label_one_hot] 174 | 175 | x = label_onehot_dict[target_label] 176 | 177 | 178 | return x 179 | 180 | def label2onehot(label_list) : 181 | v = 0 182 | label_onehot_list = [] 183 | for _ in label_list: # fabric 184 | label_one_hot = list(get_one_hot(v, len(label_list))) # [1, 0, 0, 0, 0] 185 | label_onehot_list.append(label_one_hot) 186 | v = v + 1 187 | 188 | return label_onehot_list 189 | 190 | def augmentation(image, augment_height, augment_width): 191 | seed = random.randint(0, 2 ** 31 - 1) 192 | ori_image_shape = tf.shape(image) 193 | image = tf.image.random_flip_left_right(image, seed=seed) 194 | image = tf.image.resize_images(image, [augment_height, augment_width]) 195 | image = tf.random_crop(image, ori_image_shape, seed=seed) 196 | return image 197 | 198 | 199 | def save_images(images, size, image_path): 200 | return imsave(inverse_transform(images), size, image_path) 201 | 202 | def inverse_transform(images): 203 | return ((images+1.) / 2) * 255.0 204 | 205 | 206 | def imsave(images, size, path): 207 | images = merge(images, size) 208 | images = cv2.cvtColor(images.astype('uint8'), cv2.COLOR_RGB2BGR) 209 | 210 | return cv2.imwrite(path, images) 211 | 212 | def merge(images, size): 213 | h, w = images.shape[1], images.shape[2] 214 | img = np.zeros((h * size[0], w * size[1], 3)) 215 | for idx, image in enumerate(images): 216 | i = idx % size[1] 217 | j = idx // size[1] 218 | img[h*j:h*(j+1), w*i:w*(i+1), :] = image 219 | 220 | return img 221 | 222 | def return_images(images, size) : 223 | x = merge(images, size) 224 | 225 | return x 226 | 227 | def check_folder(log_dir): 228 | if not os.path.exists(log_dir): 229 | os.makedirs(log_dir) 230 | return log_dir 231 | 232 | def show_all_variables(): 233 | model_vars = tf.trainable_variables() 234 | slim.model_analyzer.analyze_vars(model_vars, print_info=True) 235 | 236 | def str2bool(x): 237 | return x.lower() in ('true') 238 | 239 | def get_one_hot(targets, nb_classes): 240 | 241 | x = np.eye(nb_classes)[targets] 242 | 243 | return x 244 | 245 | def create_labels(c_org, selected_attrs=None): 246 | """Generate target domain labels for debugging and testing.""" 247 | # Get hair color indices. 248 | c_org = np.asarray(c_org) 249 | hair_color_indices = [] 250 | for i, attr_name in enumerate(selected_attrs): 251 | if attr_name in ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair']: 252 | hair_color_indices.append(i) 253 | 254 | c_trg_list = [] 255 | 256 | for i in range(len(selected_attrs)): 257 | c_trg = c_org.copy() 258 | 259 | if i in hair_color_indices: # Set one hair color to 1 and the rest to 0. 260 | c_trg[:, i] = 1.0 261 | for j in hair_color_indices: 262 | if j != i: 263 | c_trg[:, j] = 0.0 264 | else: 265 | c_trg[:, i] = (c_trg[:, i] == 0) # Reverse attribute value. 266 | 267 | c_trg_list.append(c_trg) 268 | 269 | c_trg_list = np.transpose(c_trg_list, axes=[1, 0, 2]) # [bs, c_dim, c_dim] 270 | 271 | return c_trg_list 272 | 273 | def pytorch_xavier_weight_factor(gain=0.02, uniform=False) : 274 | 275 | if uniform : 276 | factor = gain * gain 277 | mode = 'FAN_AVG' 278 | else : 279 | factor = (gain * gain) / 1.3 280 | mode = 'FAN_AVG' 281 | 282 | return factor, mode, uniform 283 | 284 | def pytorch_kaiming_weight_factor(a=0.0, activation_function='relu', uniform=False) : 285 | 286 | if activation_function == 'relu' : 287 | gain = np.sqrt(2.0) 288 | elif activation_function == 'leaky_relu' : 289 | gain = np.sqrt(2.0 / (1 + a ** 2)) 290 | elif activation_function =='tanh' : 291 | gain = 5.0 / 3 292 | else : 293 | gain = 1.0 294 | 295 | if uniform : 296 | factor = gain * gain 297 | mode = 'FAN_IN' 298 | else : 299 | factor = (gain * gain) / 1.3 300 | mode = 'FAN_IN' 301 | 302 | return factor, mode, uniform --------------------------------------------------------------------------------
nameinputoutput
%s
%s