├── .gitignore
├── LICENSE
├── README.md
├── SDIT.py
├── assets
│   ├── framework.png
│   ├── result.png
│   └── teaser.png
├── main.py
├── ops.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Junho Kim
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SDIT-Tensorflow
2 | ## Scalable and Diverse Cross-domain Image Translation (ACM-MM 2019)
3 |
4 | ![teaser](./assets/teaser.png)
5 |
8 | ### [Paper](https://arxiv.org/abs/1908.06881) | [Official Pytorch code](https://github.com/yaxingwang/SDIT)
9 |
10 | ## Usage
11 | ```
12 | ├── dataset
13 | └── YOUR_DATASET_NAME
14 | ├── train
15 | ├── class1 (class folder)
16 | ├── xxx.jpg (class1 image)
17 | ├── yyy.png
18 | ├── ...
19 | ├── class2
20 | ├── aaa.jpg (class2 image)
21 | ├── bbb.png
22 | ├── ...
23 | ├── class3
24 | ├── ...
25 | ├── test
26 | ├── zzz.jpg (any content image)
27 | ├── www.png
28 | ├── ...
29 |
30 | └── celebA
31 | ├── train
32 | ├── 000001.png
33 | ├── 000002.png
34 | └── ...
35 | ├── test
36 | ├── a.jpg (any test image you want)
37 | ├── b.png
38 | └── ...
39 | ├── list_attr_celeba.txt (For attribute information)
40 | ```
41 | ### Train
42 | * `python main.py --dataset celebA --phase train`
43 |
44 | ### Test
45 | * `python main.py --dataset celebA --phase test`
46 | * At test time, both the celebA test split and your own custom images are translated, as shown below.
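
A fuller invocation, as a sketch (every flag shown below is defined in `main.py`, and the values are its defaults):

```
# train with WGAN-GP and the attention bottleneck (both defaults)
python main.py --dataset celebA --phase train --gan_type wgan-gp --attention True --epoch 20 --iteration 10000 --batch_size 16

# test: translate the test images across the selected attributes, sampling 5 styles each
python main.py --dataset celebA --phase test --num_style 5
```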
47 |
48 |
49 | ## Comparison
50 | ![comparison](./assets/framework.png)
51 |
54 | ## Paper results
55 | ![results](./assets/result.png)
56 |
60 | ## Author
61 | [Junho Kim](http://bit.ly/jhkim_ai)
62 |
--------------------------------------------------------------------------------
/SDIT.py:
--------------------------------------------------------------------------------
1 | from ops import *
2 | from utils import *
3 | import time
4 | from tensorflow.contrib.data import prefetch_to_device, shuffle_and_repeat, map_and_batch
5 | import numpy as np
6 | from glob import glob
7 | from tqdm import tqdm
8 |
9 | class SDIT() :
10 | def __init__(self, sess, args):
11 | self.model_name = 'SDIT'
12 | self.sess = sess
13 | self.phase = args.phase
14 | self.checkpoint_dir = args.checkpoint_dir
15 | self.sample_dir = args.sample_dir
16 | self.result_dir = args.result_dir
17 | self.log_dir = args.log_dir
18 | self.dataset_name = args.dataset
19 | self.dataset_path = os.path.join('./dataset', self.dataset_name)
20 | self.augment_flag = args.augment_flag
21 |
22 | self.epoch = args.epoch
23 | self.iteration = args.iteration
24 | self.decay_flag = args.decay_flag
25 | self.decay_epoch = args.decay_epoch
26 |
27 | self.gan_type = args.gan_type
28 | self.attention = args.attention
29 |
30 | self.batch_size = args.batch_size
31 | self.print_freq = args.print_freq
32 | self.save_freq = args.save_freq
33 |
34 | self.init_lr = args.lr
35 | self.ch = args.ch
36 |
37 | if self.dataset_name == 'celebA-HQ' or self.dataset_name == 'celebA':
38 | self.label_list = args.label_list
39 | else :
40 | self.dataset_path = os.path.join(self.dataset_path, 'train')
41 | self.label_list = [os.path.basename(x) for x in glob(self.dataset_path + '/*')]
42 |
43 |
44 | self.c_dim = len(self.label_list)
45 |
46 | """ Weight """
47 | self.adv_weight = args.adv_weight
48 | self.rec_weight = args.rec_weight
49 | self.cls_weight = args.cls_weight
50 | self.noise_weight = args.noise_weight
51 | self.gp_weight = args.gp_weight
52 |
53 | self.sn = args.sn
54 |
55 | """ Generator """
56 | self.n_res = args.n_res
57 | self.style_dim = args.style_dim
58 | self.num_style = args.num_style
59 |
60 | """ Discriminator """
61 | self.n_dis = args.n_dis
62 | self.n_critic = args.n_critic
63 |
64 | self.img_height = args.img_height
65 | self.img_width = args.img_width
66 | self.img_ch = args.img_ch
67 |
68 | print()
69 |
70 | print("##### Information #####")
71 | print("# gan type : ", self.gan_type)
72 | print("# selected_attrs : ", self.label_list)
73 | print("# dataset : ", self.dataset_name)
74 | print("# batch_size : ", self.batch_size)
75 | print("# epoch : ", self.epoch)
76 | print("# iteration per epoch : ", self.iteration)
77 | print("# spectral normalization : ", self.sn)
78 |
79 | print()
80 |
81 | print("##### Generator #####")
82 | print("# residual blocks : ", self.n_res)
83 | print("# attention : ", self.attention)
84 |
85 | print()
86 |
87 | print("##### Discriminator #####")
88 | print("# discriminator layer : ", self.n_dis)
89 | print("# the number of critic : ", self.n_critic)
90 |
91 | ##################################################################################
92 | # Generator
93 | ##################################################################################
94 |
95 | def generator(self, x_init, c, style, reuse=False, scope="generator"):
96 | channel = self.ch
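# StarGAN-style conditioning: broadcast the target-domain label over the spatial dims and concatenate it to the input image as extra channels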
97 | c = tf.cast(tf.reshape(c, shape=[-1, 1, 1, c.shape[-1]]), tf.float32)
98 | c = tf.tile(c, [1, x_init.shape[1], x_init.shape[2], 1])
99 | x = tf.concat([x_init, c], axis=-1)
100 |
101 | with tf.variable_scope(scope, reuse=reuse) :
102 | """ Encoder """
103 | x = conv(x, channel, kernel=7, stride=1, pad=3, pad_type='reflect', use_bias=False, sn=self.sn, scope='conv')
104 | x = instance_norm(x, scope='ins_norm')
105 | x = relu(x)
106 |
107 | # Down-Sampling
108 | for i in range(2) :
109 | x = conv(x, channel*2, kernel=4, stride=2, pad=1, pad_type='reflect', use_bias=False, sn=self.sn, scope='conv_'+str(i))
110 | x = instance_norm(x, scope='down_ins_norm_'+str(i))
111 | x = relu(x)
112 |
113 | channel = channel * 2
114 |
115 | """ Bottleneck """
116 | # Encoder Bottleneck
117 | for i in range(self.n_res) :
118 | x = resblock(x, channel, use_bias=False, sn=self.sn, scope='encoder_resblock_' + str(i))
119 |
120 | attention = x
121 | adaptive = x
122 |
123 | # Adaptive Bottleneck
124 | mu, var = self.MLP(style, channel)
125 | for i in range(self.n_res - 2) :
126 | idx = 2 * i
127 | adaptive = adaptive_resblock(adaptive, channel, mu[idx], var[idx], mu[idx + 1], var[idx + 1], use_bias=True, sn=self.sn, scope='ada_resblock_' + str(i))
128 |
129 | if self.attention :
130 | # Attention Bottleneck
131 | for i in range(self.n_res - 1) :
132 | attention = resblock(attention, channel, use_bias=False, sn=self.sn, scope='attention_resblock_' + str(i))
133 |
134 | attention = conv(attention, 1, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=False, sn=self.sn, scope='attention_conv')
135 | attention = instance_norm(attention, scope='attention_ins_norm')
136 | attention = sigmoid(attention)
137 |
138 | x = attention * adaptive
139 |
140 | # attention_map = tf.concat([attention, attention, attention], axis=-1) * 2 - 1
141 | # attention_map = up_sample(attention_map, scale_factor=4)
142 |
143 | else :
144 | x = adaptive
145 |
146 | """ Decoder """
147 | # Up-Sampling
148 | for i in range(2):
149 | x = deconv(x, channel // 2, kernel=4, stride=2, use_bias=False, sn=self.sn, scope='deconv_' + str(i))
150 | x = instance_norm(x, scope='up_ins_norm' + str(i))
151 | x = relu(x)
152 |
153 | channel = channel // 2
154 |
155 | x = conv(x, channels=self.img_ch, kernel=7, stride=1, pad=3, pad_type='reflect', use_bias=False, sn=self.sn, scope='G_logit')
156 | x = tanh(x)
157 |
158 | return x
159 |
160 | def MLP(self, style, channel, scope='MLP'):
161 | with tf.variable_scope(scope):
162 | x = style
163 |
164 | for i in range(2):
165 | x = fully_connected(x, channel, sn=self.sn, scope='FC_' + str(i))
166 | x = relu(x)
167 |
168 | mu_list = []
169 | var_list = []
170 |
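# 8 = 2 AdaIN layers per block x (n_res - 2) adaptive resblocks, which matches the default n_res = 6 in main.py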
171 | for i in range(8):
172 | mu = fully_connected(x, channel, sn=self.sn, scope='FC_mu_' + str(i))
173 | var = fully_connected(x, channel, sn=self.sn, scope='FC_var_' + str(i))
174 |
175 | mu = tf.reshape(mu, shape=[-1, 1, 1, channel])
176 | var = tf.reshape(var, shape=[-1, 1, 1, channel])
177 |
178 | mu_list.append(mu)
179 | var_list.append(var)
180 |
181 | return mu_list, var_list
182 |
183 | ##################################################################################
184 | # Discriminator
185 | ##################################################################################
186 |
187 | def discriminator(self, x_init, reuse=False, scope="discriminator"):
188 | with tf.variable_scope(scope, reuse=reuse) :
189 | channel = self.ch
190 | x = conv(x_init, channel, kernel=4, stride=2, pad=1, pad_type='reflect', use_bias=True, sn=self.sn, scope='conv_0')
191 | x = lrelu(x, 0.01)
192 |
193 | for i in range(1, self.n_dis):
194 | x = conv(x, channel * 2, kernel=4, stride=2, pad=1, pad_type='reflect', use_bias=True, sn=self.sn, scope='conv_' + str(i))
195 | x = lrelu(x, 0.01)
196 |
197 | channel = channel * 2
198 |
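# Three heads on the shared features: a PatchGAN adversarial logit, domain-classification logits, and a style-code regressor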
199 | c_kernel = int(self.img_height / np.power(2, self.n_dis))
200 |
201 | logit = conv(x, channels=1, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=False, sn=self.sn, scope='D_logit')
202 |
203 | c = conv(x, channels=self.c_dim, kernel=c_kernel, stride=1, pad_type='reflect', use_bias=False, sn=self.sn, scope='D_label')
204 | c = tf.reshape(c, shape=[-1, self.c_dim])
205 |
206 | noise = conv(x, channels=self.style_dim, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=False, sn=self.sn, scope='D_noise')
207 | noise = fully_connected(noise, units=self.style_dim, use_bias=True, sn=self.sn, scope='fc_0')
208 | noise = relu(noise)
209 | noise = fully_connected(noise, units=self.style_dim, use_bias=True, sn=self.sn, scope='fc_1')
210 |
211 | return logit, c, noise
212 |
213 | ##################################################################################
214 | # Model
215 | ##################################################################################
216 |
217 | def gradient_penalty(self, real, fake, scope="discriminator"):
218 | if 'dragan' in self.gan_type:
219 | eps = tf.random_uniform(shape=tf.shape(real), minval=0., maxval=1.)
220 | _, x_var = tf.nn.moments(real, axes=[0, 1, 2, 3])
221 | x_std = tf.sqrt(x_var) # magnitude of noise decides the size of local region
222 |
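# DRAGAN: perturb real samples with noise and penalize gradients around the data manifold, instead of interpolating toward generator samples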
223 | fake = real + 0.5 * x_std * eps
224 |
225 | alpha = tf.random_uniform(shape=[self.batch_size, 1, 1, 1], minval=0., maxval=1.)
226 | interpolated = real + alpha * (fake - real)
227 |
228 | logit, _, _ = self.discriminator(interpolated, reuse=True, scope=scope)
229 |
230 |
231 | GP = 0
232 |
233 | grad = tf.gradients(logit, interpolated)[0] # gradient of D(interpolated)
234 | grad_norm = tf.norm(flatten(grad), axis=-1) # l2 norm
235 |
236 | # WGAN - LP
237 | if self.gan_type == 'wgan-lp' :
238 | GP = self.gp_weight * tf.reduce_mean(tf.square(tf.maximum(0.0, grad_norm - 1.)))
239 |
240 | elif self.gan_type == 'wgan-gp' or self.gan_type == 'dragan':
241 | GP = self.gp_weight * tf.reduce_mean(tf.square(grad_norm - 1.))
242 |
243 | return GP
244 |
245 | def build_model(self):
246 | label_fix_onehot_list = []
247 |
248 | """ Input Image"""
249 | if self.dataset_name == 'celebA-HQ' or self.dataset_name == 'celebA':
250 | img_class = ImageData_celebA(self.img_height, self.img_width, self.img_ch, self.dataset_path,
251 | self.label_list, self.augment_flag)
252 | img_class.preprocess(self.phase)
253 |
254 | else:
255 | img_class = Image_data(self.img_height, self.img_width, self.img_ch, self.dataset_path, self.label_list,
256 | self.augment_flag)
257 | img_class.preprocess()
258 |
259 | label_fix_onehot_list = img_class.label_onehot_list
260 | label_fix_onehot_list = tf.tile(tf.expand_dims(label_fix_onehot_list, axis=1), [1, self.batch_size, 1])
261 |
262 | dataset_num = len(img_class.image)
263 | print("Dataset number : ", dataset_num)
264 |
265 | if self.phase == 'train' :
266 | self.lr = tf.placeholder(tf.float32, name='learning_rate')
267 |
268 | if self.dataset_name == 'celebA-HQ' or self.dataset_name == 'celebA':
269 | img_and_label = tf.data.Dataset.from_tensor_slices(
270 | (img_class.image, img_class.label, img_class.train_label_onehot_list))
271 | else:
272 | img_and_label = tf.data.Dataset.from_tensor_slices((img_class.image, img_class.label))
273 |
274 | gpu_device = '/gpu:0'
275 | img_and_label = img_and_label.apply(shuffle_and_repeat(dataset_num)).apply(
276 | map_and_batch(img_class.image_processing, self.batch_size, num_parallel_batches=16,
277 | drop_remainder=True)).apply(prefetch_to_device(gpu_device, None))
278 |
279 | img_and_label_iterator = img_and_label.make_one_shot_iterator()
280 |
281 | if self.dataset_name == 'celebA-HQ' or self.dataset_name == 'celebA':
282 | self.x_real, label_org, label_fix_onehot_list = img_and_label_iterator.get_next()
283 | label_trg = tf.random_shuffle(label_org) # Target domain labels
284 | label_fix_onehot_list = tf.transpose(label_fix_onehot_list, perm=[1, 0, 2])
285 | else:
286 | self.x_real, label_org = img_and_label_iterator.get_next()
287 | label_trg = tf.random_shuffle(label_org) # Target domain labels
288 |
289 |
290 | """ Define Generator, Discriminator """
291 | fake_style_code = tf.random_normal(shape=[self.batch_size, self.style_dim])
292 | x_fake = self.generator(self.x_real, label_trg, fake_style_code) # real a
293 |
294 | recon_style_code = tf.random_normal(shape=[self.batch_size, self.style_dim])
295 | x_recon = self.generator(x_fake, label_org, recon_style_code, reuse=True) # real b
296 |
297 | real_logit, real_cls, _ = self.discriminator(self.x_real)
298 | fake_logit, fake_cls, fake_noise = self.discriminator(x_fake, reuse=True)
299 |
300 |
301 | """ Define Loss """
302 | if 'wgan' in self.gan_type or self.gan_type == 'dragan' :
303 | GP = self.gradient_penalty(real=self.x_real, fake=x_fake)
304 | else :
305 | GP = 0
306 |
307 | g_adv_loss = self.adv_weight * generator_loss(self.gan_type, fake_logit)
308 | g_cls_loss = self.cls_weight * classification_loss(logit=fake_cls, label=label_trg)
309 | g_rec_loss = self.rec_weight * L1_loss(self.x_real, x_recon)
310 | g_noise_loss = self.noise_weight * L1_loss(fake_style_code, fake_noise)
311 |
312 | d_adv_loss = self.adv_weight * discriminator_loss(self.gan_type, real_logit, fake_logit) + GP
313 | d_cls_loss = self.cls_weight * classification_loss(logit=real_cls, label=label_org)
314 | d_noise_loss = self.noise_weight * L1_loss(fake_style_code, fake_noise)
315 |
316 | self.d_loss = d_adv_loss + d_cls_loss + d_noise_loss
317 | self.g_loss = g_adv_loss + g_cls_loss + g_rec_loss + g_noise_loss
318 |
319 |
320 | """ Result Image """
321 | # identical for every dataset: sample num_style style codes and translate to each fixed target label
322 | self.x_fake_list = []
323 |
324 | for _ in range(self.num_style):
325 | random_style_code = tf.random_normal(shape=[self.batch_size, self.style_dim])
326 | self.x_fake_list.append(tf.map_fn(lambda c : self.generator(self.x_real, c, random_style_code, reuse=True), label_fix_onehot_list, dtype=tf.float32))
335 |
336 |
337 | """ Training """
338 | t_vars = tf.trainable_variables()
339 | G_vars = [var for var in t_vars if 'generator' in var.name]
340 | D_vars = [var for var in t_vars if 'discriminator' in var.name]
341 |
342 | self.g_optimizer = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.g_loss, var_list=G_vars)
343 | self.d_optimizer = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.d_loss, var_list=D_vars)
344 |
345 |
346 | """" Summary """
347 | self.Generator_loss = tf.summary.scalar("g_loss", self.g_loss)
348 | self.Discriminator_loss = tf.summary.scalar("d_loss", self.d_loss)
349 |
350 | self.g_adv_loss = tf.summary.scalar("g_adv_loss", g_adv_loss)
351 | self.g_cls_loss = tf.summary.scalar("g_cls_loss", g_cls_loss)
352 | self.g_rec_loss = tf.summary.scalar("g_rec_loss", g_rec_loss)
353 | self.g_noise_loss = tf.summary.scalar("g_noise_loss", g_noise_loss)
354 |
355 | self.d_adv_loss = tf.summary.scalar("d_adv_loss", d_adv_loss)
356 | self.d_cls_loss = tf.summary.scalar("d_cls_loss", d_cls_loss)
357 | self.d_noise_loss = tf.summary.scalar("d_noise_loss", d_noise_loss)
358 |
359 | self.g_summary_loss = tf.summary.merge([self.Generator_loss, self.g_adv_loss, self.g_cls_loss, self.g_rec_loss, self.g_noise_loss])
360 | self.d_summary_loss = tf.summary.merge([self.Discriminator_loss, self.d_adv_loss, self.d_cls_loss, self.d_noise_loss])
361 |
362 | else :
363 | """ Test """
364 | if self.dataset_name == 'celebA-HQ' or self.dataset_name == 'celebA':
365 | img_and_label = tf.data.Dataset.from_tensor_slices(
366 | (img_class.test_image, img_class.test_label, img_class.test_label_onehot_list))
367 | dataset_num = len(img_class.test_image)
368 |
369 | gpu_device = '/gpu:0'
370 | img_and_label = img_and_label.apply(shuffle_and_repeat(dataset_num)).apply(
371 | map_and_batch(img_class.image_processing, batch_size=self.batch_size, num_parallel_batches=16,
372 | drop_remainder=True)).apply(prefetch_to_device(gpu_device, None))
373 |
374 | img_and_label_iterator = img_and_label.make_one_shot_iterator()
375 |
376 | self.x_test, _, self.test_label_fix_onehot_list = img_and_label_iterator.get_next()
377 | self.test_img_placeholder = tf.placeholder(tf.float32, [1, self.img_height, self.img_width, self.img_ch])
378 | self.test_label_fix_placeholder = tf.placeholder(tf.float32, [self.c_dim, 1, self.c_dim])
379 |
380 | self.custom_image = tf.placeholder(tf.float32, [1, self.img_height, self.img_width, self.img_ch], name='custom_image') # Custom Image
381 | custom_label_fix_onehot_list = tf.transpose(np.expand_dims(label2onehot(self.label_list), axis=0), perm=[1, 0, 2]) # [c_dim, bs, c_dim]
382 |
383 | """ Test Image """
384 | test_random_style_code = tf.random_normal(shape=[1, self.style_dim])
385 |
386 | self.x_test_fake_list = tf.map_fn(lambda c : self.generator(self.test_img_placeholder, c, test_random_style_code), self.test_label_fix_placeholder, dtype=tf.float32)
387 | self.custom_fake_image = tf.map_fn(lambda c : self.generator(self.custom_image, c, test_random_style_code, reuse=True), custom_label_fix_onehot_list, dtype=tf.float32)
388 |
389 | else :
390 | self.custom_image = tf.placeholder(tf.float32, [1, self.img_height, self.img_width, self.img_ch], name='custom_image') # Custom Image
391 | custom_label_fix_onehot_list = tf.transpose(np.expand_dims(label2onehot(self.label_list), axis=0), perm=[1, 0, 2]) # [c_dim, bs, c_dim]
392 |
393 | test_random_style_code = tf.random_normal(shape=[1, self.style_dim])
394 | self.custom_fake_image = tf.map_fn(lambda c : self.generator(self.custom_image, c, test_random_style_code), custom_label_fix_onehot_list, dtype=tf.float32)
395 |
396 |
397 |
398 | def train(self):
399 | # initialize all variables
400 | tf.global_variables_initializer().run()
401 |
402 | # saver to save model
403 | self.saver = tf.train.Saver(max_to_keep=10)
404 |
405 | # summary writer
406 | self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_dir, self.sess.graph)
407 |
408 | # restore check-point if it exits
409 | could_load, checkpoint_counter = self.load(self.checkpoint_dir)
410 | if could_load:
411 | start_epoch = checkpoint_counter // self.iteration
412 | start_batch_id = checkpoint_counter - start_epoch * self.iteration
413 | counter = checkpoint_counter
414 | print(" [*] Load SUCCESS")
415 | else:
416 | start_epoch = 0
417 | start_batch_id = 0
418 | counter = 1
419 | print(" [!] Load failed...")
420 |
421 | self.sample_dir = os.path.join(self.sample_dir, self.model_dir)
422 | check_folder(self.sample_dir)
423 |
424 | # loop for epoch
425 | start_time = time.time()
426 | past_g_loss = -1.
427 | lr = self.init_lr
428 | for epoch in range(start_epoch, self.epoch):
429 | if self.decay_flag :
430 | lr = self.init_lr if epoch < self.decay_epoch else self.init_lr * (self.epoch - epoch) / (self.epoch - self.decay_epoch) # linear decay
431 |
432 | for idx in range(start_batch_id, self.iteration):
433 | train_feed_dict = {
434 | self.lr : lr
435 | }
436 |
437 | # Update D
438 | _, d_loss, summary_str = self.sess.run([self.d_optimizer, self.d_loss, self.d_summary_loss], feed_dict = train_feed_dict)
439 | self.writer.add_summary(summary_str, counter)
440 |
441 | # Update G
442 | g_loss = None
443 | if (counter - 1) % self.n_critic == 0 :
444 | real_images, fake_images, _, g_loss, summary_str = self.sess.run([self.x_real, self.x_fake_list, self.g_optimizer, self.g_loss, self.g_summary_loss], feed_dict = train_feed_dict)
445 | self.writer.add_summary(summary_str, counter)
446 | past_g_loss = g_loss
447 |
448 | # display training status
449 | counter += 1
450 | if g_loss is None :
451 | g_loss = past_g_loss
452 |
453 | print("Epoch: [%2d] [%5d/%5d] time: %4.4f d_loss: %.8f, g_loss: %.8f" % (epoch, idx, self.iteration, time.time() - start_time, d_loss, g_loss))
454 |
455 | if np.mod(idx+1, self.print_freq) == 0 :
456 | real_image = np.expand_dims(real_images[0], axis=0)
457 | save_images(real_image, [1, 1],
458 | './{}/real_{:03d}_{:05d}.jpg'.format(self.sample_dir, epoch, idx+1))
459 |
460 | merge_fake_x = None
461 |
462 | for ns in range(self.num_style) :
463 | fake_img = np.transpose(fake_images[ns], axes=[1, 0, 2, 3, 4])[0]
464 |
465 | if ns == 0 :
466 | merge_fake_x = return_images(fake_img, [1, self.c_dim]) # [self.img_height, self.img_width * self.c_dim, self.img_ch]
467 | else :
468 | x = return_images(fake_img, [1, self.c_dim])
469 | merge_fake_x = np.concatenate([merge_fake_x, x], axis=0)
470 |
471 | merge_fake_x = np.expand_dims(merge_fake_x, axis=0)
472 | save_images(merge_fake_x, [1, 1],
473 | './{}/fake_{:03d}_{:05d}.jpg'.format(self.sample_dir, epoch, idx+1))
474 |
475 | if np.mod(counter - 1, self.save_freq) == 0:
476 | self.save(self.checkpoint_dir, counter)
477 |
478 | # After an epoch, start_batch_id is set to zero
479 | # non-zero value is only for the first epoch after loading pre-trained model
480 | start_batch_id = 0
481 |
482 | # save model for final step
483 | self.save(self.checkpoint_dir, counter)
484 |
485 | @property
486 | def model_dir(self):
487 |
488 | if self.sn:
489 | sn = '_sn'
490 | else:
491 | sn = ''
492 |
493 | if self.attention:
494 | attention = '_attention'
495 | else:
496 | attention = ''
497 |
498 | return "{}_{}_{}_{}adv_{}rec_{}cls_{}noise{}{}".format(self.model_name, self.dataset_name, self.gan_type,
499 | self.adv_weight, self.rec_weight, self.cls_weight, self.noise_weight,
500 | sn, attention)
501 |
502 | def save(self, checkpoint_dir, step):
503 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)
504 |
505 | if not os.path.exists(checkpoint_dir):
506 | os.makedirs(checkpoint_dir)
507 |
508 | self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name + '.model'), global_step=step)
509 |
510 | def load(self, checkpoint_dir):
511 | print(" [*] Reading checkpoints...")
512 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)
513 |
514 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
515 | if ckpt and ckpt.model_checkpoint_path:
516 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
517 | self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
518 | counter = int(ckpt_name.split('-')[-1])
519 | print(" [*] Success to read {}".format(ckpt_name))
520 | return True, counter
521 | else:
522 | print(" [*] Failed to find a checkpoint")
523 | return False, 0
524 |
525 | def test(self):
526 | tf.global_variables_initializer().run()
527 | test_files = glob('./dataset/{}/{}/*.jpg'.format(self.dataset_name, 'test')) + glob('./dataset/{}/{}/*.png'.format(self.dataset_name, 'test'))
528 |
529 | self.saver = tf.train.Saver()
530 | could_load, checkpoint_counter = self.load(self.checkpoint_dir)
531 | self.result_dir = os.path.join(self.result_dir, self.model_dir)
532 | check_folder(self.result_dir)
533 |
534 | custom_image_folder = os.path.join(self.result_dir, 'custom_fake_images')
535 | check_folder(custom_image_folder)
536 |
537 | if could_load :
538 | print(" [*] Load SUCCESS")
539 | else :
540 | print(" [!] Load failed...")
541 |
542 | # write html for visual comparison
543 | index_path = os.path.join(self.result_dir, 'index.html')
544 | index = open(index_path, 'w')
545 | index.write("")
546 | index.write("name | input | output |
")
547 |
548 | # Custom Image
549 | for sample_file in tqdm(test_files):
550 | print("Processing image: " + sample_file)
551 | sample_image = load_test_image(sample_file, self.img_width, self.img_height, self.img_ch)
552 | image_path = os.path.join(custom_image_folder, '{}'.format(os.path.basename(sample_file)))
553 |
554 | merge_x = None
555 |
556 | for i in range(self.num_style) :
557 | fake_img = self.sess.run(self.custom_fake_image, feed_dict={self.custom_image: sample_image})
558 | fake_img = np.transpose(fake_img, axes=[1, 0, 2, 3, 4])[0]
559 |
560 | if i == 0:
561 | merge_x = return_images(fake_img, [1, self.c_dim]) # [self.img_height, self.img_width * self.c_dim, self.img_ch]
562 | else :
563 | x = return_images(fake_img, [1, self.c_dim])
564 | merge_x = np.concatenate([merge_x, x], axis=0)
565 |
566 | merge_x = np.expand_dims(merge_x, axis=0)
567 |
568 | save_images(merge_x, [1, 1], image_path)
569 |
570 | index.write("%s | " % os.path.basename(image_path))
571 | index.write(" | " % (sample_file if os.path.isabs(sample_file) else (
572 | '../..' + os.path.sep + sample_file), self.img_width, self.img_height))
573 |
574 | index.write(" | " % (image_path if os.path.isabs(image_path) else (
575 | '../..' + os.path.sep + image_path), self.img_width * self.c_dim, self.img_height * self.num_style))
576 | index.write("")
577 |
578 | if self.dataset_name == 'celebA-HQ' or self.dataset_name == 'celebA':
579 | # CelebA
580 | celebA_image_folder = os.path.join(self.result_dir, 'celebA_real_fake_images')
581 | check_folder(celebA_image_folder)
582 | real_images, real_label_fixes = self.sess.run([self.x_test, self.test_label_fix_onehot_list])
583 |
584 | for i in tqdm(range(len(real_images))) :
585 |
586 | real_path = os.path.join(celebA_image_folder, 'real_{}.png'.format(i))
587 | fake_path = os.path.join(celebA_image_folder, 'fake_{}.png'.format(i))
588 |
589 | real_img = np.expand_dims(real_images[i], axis=0)
590 | real_label_fix = np.expand_dims(real_label_fixes[i], axis=1)
591 |
592 | merge_x = None
593 |
594 | for ns in range(self.num_style) :
595 | fake_img = self.sess.run(self.x_test_fake_list, feed_dict={self.test_img_placeholder: real_img, self.test_label_fix_placeholder:real_label_fix})
596 | fake_img = np.transpose(fake_img, axes=[1, 0, 2, 3, 4])[0]
597 |
598 | if ns == 0:
599 | merge_x = return_images(fake_img, [1, self.c_dim]) # [self.img_height, self.img_width * self.c_dim, self.img_ch]
600 | else:
601 | x = return_images(fake_img, [1, self.c_dim])
602 | merge_x = np.concatenate([merge_x, x], axis=0)
603 |
604 | merge_x = np.expand_dims(merge_x, axis=0)
605 |
606 | save_images(real_img, [1, 1], real_path)
607 | save_images(merge_x, [1, 1], fake_path)
608 |
609 | index.write("%s | " % os.path.basename(real_path))
610 | index.write(" | " % (real_path if os.path.isabs(real_path) else (
611 | '../..' + os.path.sep + real_path), self.img_width, self.img_height))
612 |
613 | index.write(" | " % (fake_path if os.path.isabs(fake_path) else (
614 | '../..' + os.path.sep + fake_path), self.img_width * self.c_dim, self.img_height * self.num_style))
615 | index.write("")
616 |
617 | index.close()
--------------------------------------------------------------------------------
/assets/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/SDIT-Tensorflow/71f9917325d647d3d51e691e85f5a3079e068da0/assets/framework.png
--------------------------------------------------------------------------------
/assets/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/SDIT-Tensorflow/71f9917325d647d3d51e691e85f5a3079e068da0/assets/result.png
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/SDIT-Tensorflow/71f9917325d647d3d51e691e85f5a3079e068da0/assets/teaser.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from SDIT import SDIT
2 | import argparse
3 | from utils import *
4 |
5 | """parsing and configuration"""
6 | def parse_args():
7 | desc = "Tensorflow implementation of SDIT"
8 | parser = argparse.ArgumentParser(description=desc)
9 | parser.add_argument('--phase', type=str, default='train', help='train or test ?')
10 | parser.add_argument('--attention', type=str2bool, default=True, choices=[True, False])
11 | parser.add_argument('--dataset', type=str, default='celebA', help='dataset_name')
12 |
13 | parser.add_argument('--epoch', type=int, default=20, help='The number of epochs to run')
14 | parser.add_argument('--iteration', type=int, default=10000, help='The number of training iterations')
15 | # The total number of iterations is [epoch * iteration]
16 |
17 | parser.add_argument('--batch_size', type=int, default=16, help='The size of batch')
18 | parser.add_argument('--print_freq', type=int, default=1000, help='Image sampling frequency, in iterations')
19 | parser.add_argument('--save_freq', type=int, default=10000, help='Checkpoint save frequency, in iterations')
20 | parser.add_argument('--decay_flag', type=str2bool, default=True, help='The decay_flag')
21 | parser.add_argument('--decay_epoch', type=int, default=10, help='decay epoch')
22 |
23 | parser.add_argument('--lr', type=float, default=0.0001, help='The learning rate')
24 | parser.add_argument('--adv_weight', type=float, default=1, help='Weight about GAN')
25 | parser.add_argument('--rec_weight', type=float, default=10, help='Weight about Reconstruction')
26 | parser.add_argument('--cls_weight', type=float, default=10, help='Weight about Classification')
27 | parser.add_argument('--gp_weight', type=float, default=10, help='The gradient penalty lambda')
28 | parser.add_argument('--noise_weight', type=float, default=800, help='weight of noise for reconstruction loss')
29 |
30 | parser.add_argument('--gan_type', type=str, default='wgan-gp', help='gan / lsgan / wgan-gp / wgan-lp / dragan / hinge')
31 | parser.add_argument('--sn', type=str2bool, default=False, help='using spectral norm')
32 | parser.add_argument('--label_list', type=str, nargs='+', help='selected attributes for the CelebA dataset',
33 | default=['Blond_Hair', 'Brown_Hair', 'Male', 'Eyeglasses', 'Bangs'])
34 |
35 | parser.add_argument('--ch', type=int, default=64, help='base channel number per layer')
36 | parser.add_argument('--n_res', type=int, default=6, help='The number of resblock')
37 | parser.add_argument('--n_dis', type=int, default=6, help='The number of discriminator layer')
38 | parser.add_argument('--n_critic', type=int, default=5, help='number of D updates per each G update')
39 | parser.add_argument('--style_dim', type=int, default=8, help='length of style code')
40 |
41 | parser.add_argument('--num_style', type=int, default=5, help='number of styles to sample')
42 |
43 | parser.add_argument('--img_height', type=int, default=128, help='The height size of image')
44 | parser.add_argument('--img_width', type=int, default=128, help='The width size of image ')
45 | parser.add_argument('--img_ch', type=int, default=3, help='The size of image channel')
46 | parser.add_argument('--augment_flag', type=str2bool, default=True, help='Image augmentation use or not')
47 |
48 | parser.add_argument('--checkpoint_dir', type=str, default='checkpoint',
49 | help='Directory name to save the checkpoints')
50 | parser.add_argument('--result_dir', type=str, default='results',
51 | help='Directory name to save the generated images')
52 | parser.add_argument('--log_dir', type=str, default='logs',
53 | help='Directory name to save training logs')
54 | parser.add_argument('--sample_dir', type=str, default='samples',
55 | help='Directory name to save the samples on training')
56 |
57 | return check_args(parser.parse_args())
58 |
59 | """checking arguments"""
60 | def check_args(args):
61 | # --checkpoint_dir
62 | check_folder(args.checkpoint_dir)
63 |
64 | # --result_dir
65 | check_folder(args.result_dir)
66 |
67 | # --log_dir
68 | check_folder(args.log_dir)
69 |
70 | # --sample_dir
71 | check_folder(args.sample_dir)
72 |
73 | # --epoch
74 | try:
75 | assert args.epoch >= 1
76 | except AssertionError:
77 | print('number of epochs must be larger than or equal to one')
78 | return None
79 |
80 | # --batch_size
81 | try:
82 | assert args.batch_size >= 1
83 | except AssertionError:
84 | print('batch size must be larger than or equal to one')
85 | return None
86 | return args
85 |
86 | """main"""
87 | def main():
88 | # parse arguments
89 | args = parse_args()
90 | if args is None:
91 | exit()
92 |
93 | # open session
94 | with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
95 | gan = SDIT(sess, args)
96 |
97 | # build graph
98 | gan.build_model()
99 |
100 | # show network architecture
101 | show_all_variables()
102 |
103 | if args.phase == 'train' :
104 | gan.train()
105 | print(" [*] Training finished!")
106 |
107 | if args.phase == 'test' :
108 | gan.test()
109 | print(" [*] Test finished!")
110 |
111 | if __name__ == '__main__':
112 | main()
--------------------------------------------------------------------------------
/ops.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import tensorflow.contrib as tf_contrib
3 | from utils import pytorch_xavier_weight_factor, pytorch_kaiming_weight_factor
4 |
5 | ##################################################################################
6 | # Initialization
7 | ##################################################################################
8 |
9 | factor, mode, uniform = pytorch_xavier_weight_factor(gain=0.02, uniform=False)
10 | weight_init = tf_contrib.layers.variance_scaling_initializer(factor=factor, mode=mode, uniform=uniform)
11 | weight_regularizer = None
12 | weight_regularizer_fully = None
13 |
14 |
15 | ##################################################################################
16 | # Layers
17 | ##################################################################################
18 |
19 | # padding='SAME' ======> pad = floor[ (kernel - stride) / 2 ]
20 | def conv(x, channels, kernel=4, stride=2, pad=0, pad_type='zero', use_bias=True, sn=False, scope='conv_0'):
21 | with tf.variable_scope(scope):
22 | if pad > 0:
23 | h = x.get_shape().as_list()[1]
24 | if h % stride == 0:
25 | pad = pad * 2
26 | else:
27 | pad = max(kernel - (h % stride), 0)
28 |
29 | pad_top = pad // 2
30 | pad_bottom = pad - pad_top
31 | pad_left = pad // 2
32 | pad_right = pad - pad_left
33 |
34 | if pad_type == 'zero':
35 | x = tf.pad(x, [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]])
36 | if pad_type == 'reflect':
37 | x = tf.pad(x, [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]], mode='REFLECT')
38 |
39 | if sn:
40 | w = tf.get_variable("kernel", shape=[kernel, kernel, x.get_shape()[-1], channels], initializer=weight_init,
41 | regularizer=weight_regularizer)
42 | x = tf.nn.conv2d(input=x, filter=spectral_norm(w),
43 | strides=[1, stride, stride, 1], padding='VALID')
44 | if use_bias:
45 | bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0))
46 | x = tf.nn.bias_add(x, bias)
47 |
48 | else:
49 | x = tf.layers.conv2d(inputs=x, filters=channels,
50 | kernel_size=kernel, kernel_initializer=weight_init,
51 | kernel_regularizer=weight_regularizer,
52 | strides=stride, use_bias=use_bias)
53 |
54 | return x
55 |
56 |
57 | def deconv(x, channels, kernel=4, stride=2, padding='SAME', use_bias=True, sn=False, scope='deconv_0'):
58 | with tf.variable_scope(scope):
59 | x_shape = x.get_shape().as_list()
60 |
61 | if padding == 'SAME':
62 | output_shape = [x_shape[0], x_shape[1] * stride, x_shape[2] * stride, channels]
63 |
64 | else:
65 | output_shape = [x_shape[0], x_shape[1] * stride + max(kernel - stride, 0),
66 | x_shape[2] * stride + max(kernel - stride, 0), channels]
67 |
68 | if sn:
69 | w = tf.get_variable("kernel", shape=[kernel, kernel, channels, x.get_shape()[-1]], initializer=weight_init,
70 | regularizer=weight_regularizer)
71 | x = tf.nn.conv2d_transpose(x, filter=spectral_norm(w), output_shape=output_shape,
72 | strides=[1, stride, stride, 1], padding=padding)
73 |
74 | if use_bias:
75 | bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0))
76 | x = tf.nn.bias_add(x, bias)
77 |
78 | else:
79 | x = tf.layers.conv2d_transpose(inputs=x, filters=channels,
80 | kernel_size=kernel, kernel_initializer=weight_init,
81 | kernel_regularizer=weight_regularizer,
82 | strides=stride, padding=padding, use_bias=use_bias)
83 |
84 | return x
85 |
86 | def fully_connected(x, units, use_bias=True, sn=False, scope='linear'):
87 | with tf.variable_scope(scope):
88 | x = flatten(x)
89 | shape = x.get_shape().as_list()
90 | channels = shape[-1]
91 |
92 | if sn:
93 | w = tf.get_variable("kernel", [channels, units], tf.float32,
94 | initializer=weight_init, regularizer=weight_regularizer_fully)
95 | if use_bias:
96 | bias = tf.get_variable("bias", [units],
97 | initializer=tf.constant_initializer(0.0))
98 |
99 | x = tf.matmul(x, spectral_norm(w)) + bias
100 | else:
101 | x = tf.matmul(x, spectral_norm(w))
102 |
103 | else:
104 | x = tf.layers.dense(x, units=units, kernel_initializer=weight_init,
105 | kernel_regularizer=weight_regularizer_fully,
106 | use_bias=use_bias)
107 |
108 | return x
109 |
110 | def flatten(x) :
111 | return tf.layers.flatten(x)
112 |
113 | ##################################################################################
114 | # Residual-block
115 | ##################################################################################
116 |
117 | def resblock(x_init, channels, use_bias=True, sn=False, scope='resblock'):
118 | with tf.variable_scope(scope):
119 | with tf.variable_scope('res1'):
120 | x = conv(x_init, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias, sn=sn)
121 | x = instance_norm(x)
122 | x = relu(x)
123 |
124 | with tf.variable_scope('res2'):
125 | x = conv(x, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias, sn=sn)
126 | x = instance_norm(x)
127 |
128 | return x + x_init
129 |
130 | def adaptive_resblock(x_init, channels, gamma1, beta1, gamma2, beta2, use_bias=True, sn=False, scope='adaptive_resblock') :
131 | with tf.variable_scope(scope):
132 | with tf.variable_scope('res1'):
133 | x = conv(x_init, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias, sn=sn)
134 | x = adaptive_instance_norm(x, gamma1, beta1)
135 | x = relu(x)
136 |
137 | with tf.variable_scope('res2'):
138 | x = conv(x, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias, sn=sn)
139 | x = adaptive_instance_norm(x, gamma2, beta2)
140 |
141 | return x + x_init
142 |
143 |
144 | ##################################################################################
145 | # Activation function
146 | ##################################################################################
147 |
148 | def lrelu(x, alpha=0.2):
149 | return tf.nn.leaky_relu(x, alpha)
150 |
151 |
152 | def relu(x):
153 | return tf.nn.relu(x)
154 |
155 |
156 | def tanh(x):
157 | return tf.tanh(x)
158 |
159 | def sigmoid(x):
160 | return tf.sigmoid(x)
161 |
162 | ##################################################################################
163 | # Pooling & Resize
164 | ##################################################################################
165 |
166 | def up_sample(x, scale_factor=2):
167 | _, h, w, _ = x.get_shape().as_list()
168 | new_size = [h * scale_factor, w * scale_factor]
169 | return tf.image.resize_bilinear(x, size=new_size)
170 |
171 |
172 | ##################################################################################
173 | # Normalization function
174 | ##################################################################################
175 |
176 | def instance_norm(x, scope='instance_norm'):
177 | return tf_contrib.layers.instance_norm(x,
178 | epsilon=1e-05,
179 | center=True, scale=True,
180 | scope=scope)
181 |
182 | def adaptive_instance_norm(content, gamma, beta, epsilon=1e-5):
183 |
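# AdaIN: replace the per-sample, per-channel statistics of the content features with style-derived scale (gamma) and shift (beta)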
184 | c_mean, c_var = tf.nn.moments(content, axes=[1, 2], keep_dims=True)
185 | c_std = tf.sqrt(c_var + epsilon)
186 |
187 | return gamma * ((content - c_mean) / c_std) + beta
188 |
189 | def spectral_norm(w, iteration=1):
190 | w_shape = w.shape.as_list()
191 | w = tf.reshape(w, [-1, w_shape[-1]])
192 |
193 | u = tf.get_variable("u", [1, w_shape[-1]], initializer=tf.random_normal_initializer(), trainable=False)
194 |
195 | u_hat = u
196 | v_hat = None
197 | for i in range(iteration):
198 | """
199 | power iteration
200 | Usually iteration = 1 will be enough
201 | """
202 | v_ = tf.matmul(u_hat, tf.transpose(w))
203 | v_hat = tf.nn.l2_normalize(v_)
204 |
205 | u_ = tf.matmul(v_hat, w)
206 | u_hat = tf.nn.l2_normalize(u_)
207 |
208 | u_hat = tf.stop_gradient(u_hat)
209 | v_hat = tf.stop_gradient(v_hat)
210 |
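# sigma approximates the largest singular value of w via one step of power iteration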
211 | sigma = tf.matmul(tf.matmul(v_hat, w), tf.transpose(u_hat))
212 |
213 | with tf.control_dependencies([u.assign(u_hat)]):
214 | w_norm = w / sigma
215 | w_norm = tf.reshape(w_norm, w_shape)
216 |
217 | return w_norm
218 |
219 | ##################################################################################
220 | # Loss function
221 | ##################################################################################
222 |
223 | def discriminator_loss(loss_func, real_logit, fake_logit):
224 | real_loss = 0
225 | fake_loss = 0
226 |
227 | if 'wgan' in loss_func :
228 | real_loss = -tf.reduce_mean(real_logit)
229 | fake_loss = tf.reduce_mean(fake_logit)
230 |
231 | if loss_func == 'lsgan' :
232 | real_loss = tf.reduce_mean(tf.squared_difference(real_logit, 1.0))
233 | fake_loss = tf.reduce_mean(tf.square(fake_logit))
234 |
235 | if loss_func == 'gan' or loss_func == 'dragan' :
236 | real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(real_logit), logits=real_logit))
237 | fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(fake_logit), logits=fake_logit))
238 |
239 | if loss_func == 'hinge' :
240 | real_loss = tf.reduce_mean(relu(1.0 - real_logit))
241 | fake_loss = tf.reduce_mean(relu(1.0 + fake_logit))
242 |
243 | loss = real_loss + fake_loss
244 |
245 | return loss
246 |
247 | def generator_loss(loss_func, fake_logit):
248 | fake_loss = 0
249 |
250 | if 'wgan' in loss_func :
251 | fake_loss = -tf.reduce_mean(fake_logit)
252 |
253 | if loss_func == 'lsgan' :
254 | fake_loss = tf.reduce_mean(tf.squared_difference(fake_logit, 1.0))
255 |
256 | if loss_func == 'gan' or loss_func == 'dragan' :
257 | fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(fake_logit), logits=fake_logit))
258 |
259 | if loss_func == 'hinge' :
260 | fake_loss = -tf.reduce_mean(fake_logit)
261 |
262 | loss = fake_loss
263 |
264 | return loss
265 |
266 | def classification_loss(logit, label) :
267 | loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=label, logits=logit))
268 |
269 | return loss
270 |
271 | def L1_loss(x, y):
272 | loss = tf.reduce_mean(tf.abs(x - y))
273 |
274 | return loss
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import cv2
4 |
5 | import tensorflow as tf
6 | import tensorflow.contrib.slim as slim
7 | import random
8 | from glob import glob
9 | from tqdm import tqdm
10 |
11 | class Image_data:
12 |
13 | def __init__(self, img_height, img_width, channels, dataset_path, label_list, augment_flag):
14 | self.img_height = img_height
15 | self.img_width = img_width
16 | self.channels = channels
17 | self.augment_flag = augment_flag
18 |
19 | self.label_list = label_list
20 | self.dataset_path = dataset_path
21 |
22 | self.label_onehot_list = []
23 | self.image = []
24 | self.label = []
25 |
26 |
27 | def image_processing(self, filename, label):
28 | x = tf.read_file(filename)
29 | x_decode = tf.image.decode_jpeg(x, channels=self.channels, dct_method='INTEGER_ACCURATE')
30 | img = tf.image.resize_images(x_decode, [self.img_height, self.img_width])
31 | img = tf.cast(img, tf.float32) / 127.5 - 1
32 |
33 |
34 | if self.augment_flag :
35 | augment_height_size = self.img_height + (30 if self.img_height == 256 else int(self.img_height * 0.1))
36 | augment_width_size = self.img_width + (30 if self.img_width == 256 else int(self.img_width * 0.1))
37 |
38 | img = tf.cond(pred=tf.greater_equal(tf.random_uniform(shape=[], minval=0.0, maxval=1.0), 0.5),
39 | true_fn=lambda : augmentation(img, augment_height_size, augment_width_size),
40 | false_fn=lambda : img)
41 |
42 | return img, label
43 |
44 | def preprocess(self):
45 | # self.label_list = ['tiger', 'cat', 'dog', 'lion']
46 |
47 | v = 0
48 |
49 | for label in self.label_list :
50 | label_one_hot = list(get_one_hot(v, len(self.label_list))) # [1, 0, 0, 0, 0]
51 | self.label_onehot_list.append(label_one_hot)
52 | v = v+1
53 |
54 | image_list = glob(os.path.join(self.dataset_path, label) + '/*.png') + glob(os.path.join(self.dataset_path, label) + '/*.jpg')
55 | label_one_hot = [label_one_hot] * len(image_list)
56 |
57 | self.image.extend(image_list)
58 | self.label.extend(label_one_hot)
59 |
60 | class ImageData_celebA:
61 |
62 | def __init__(self, img_height, img_width, channels, dataset_path, label_list, augment_flag):
63 | self.img_height = img_height
64 | self.img_width = img_width
65 | self.channels = channels
66 | self.augment_flag = augment_flag
67 | self.label_list = label_list
68 |
69 | self.dataset_path = os.path.join(dataset_path, 'train')
70 | self.file_name_list = {os.path.basename(x) for x in glob(self.dataset_path + '/*.png')} # set: O(1) membership checks in preprocess()
71 | self.lines = open(os.path.join(dataset_path, 'list_attr_celeba.txt'), 'r').readlines()
72 |
73 | self.image = []
74 | self.label = []
75 |
76 | self.test_image = []
77 | self.test_label = []
78 |
79 | self.attr2idx = {}
80 | self.idx2attr = {}
81 |
82 | self.train_label_onehot_list = []
83 | self.test_label_onehot_list = []
84 |
85 | def image_processing(self, filename, label, fix_label):
86 | x = tf.read_file(filename)
87 | x_decode = tf.image.decode_jpeg(x, channels=self.channels, dct_method='INTEGER_ACCURATE')
88 | img = tf.image.resize_images(x_decode, [self.img_height, self.img_width])
89 | img = tf.cast(img, tf.float32) / 127.5 - 1
90 |
91 | if self.augment_flag :
92 | augment_height = self.img_height + (30 if self.img_height == 256 else int(self.img_height * 0.1))
93 | augment_width = self.img_width + (30 if self.img_width == 256 else int(self.img_width * 0.1))
94 |
95 | img = tf.cond(pred=tf.greater_equal(tf.random_uniform(shape=[], minval=0.0, maxval=1.0), 0.5),
96 | true_fn=lambda: augmentation(img, augment_height, augment_width),
97 | false_fn=lambda: img)
98 |
99 |
100 | return img, label, fix_label
101 |
102 | def preprocess(self, phase):
103 |
104 | all_attr_names = self.lines[1].split()
105 | for i, attr_name in enumerate(all_attr_names):
106 | self.attr2idx[attr_name] = i
107 | self.idx2attr[i] = attr_name
108 |
109 | lines = self.lines[2:]
110 | random.seed(1234)
111 | random.shuffle(lines)
112 |
113 | for i, line in enumerate(tqdm(lines)):
114 | split = line.split()
115 | if split[0] in self.file_name_list:
116 | filename = os.path.join(self.dataset_path, split[0])
117 | values = split[1:]
118 |
119 | label = []
120 |
121 | for attr_name in self.label_list:
122 | idx = self.attr2idx[attr_name]
123 |
124 | if values[idx] == '1':
125 | label.append(1.0)
126 | else:
127 | label.append(0.0)
128 |
129 | if i < 2000:
130 | self.test_image.append(filename)
131 | self.test_label.append(label)
132 | else:
133 | if phase == 'test' :
134 | break
135 | self.image.append(filename)
136 | self.label.append(label)
137 | # ['./dataset/celebA/train/019932.png', [1, 0, 0, 0, 1]]
138 |
139 | print()
140 |
141 | self.test_label_onehot_list = create_labels(self.test_label, self.label_list)
142 | if phase == 'train' :
143 | self.train_label_onehot_list = create_labels(self.label, self.label_list)
144 |
145 | print('\n Finished preprocessing the CelebA dataset...')
146 |
147 | def load_test_image(image_path, img_width, img_height, img_channel):
148 |
149 | if img_channel == 1 :
150 | img = cv2.imread(image_path, flags=cv2.IMREAD_GRAYSCALE)
151 | else :
152 | img = cv2.imread(image_path, flags=cv2.IMREAD_COLOR)
153 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
154 |
155 | img = cv2.resize(img, dsize=(img_width, img_height))
156 |
157 | if img_channel == 1 :
158 | img = np.expand_dims(img, axis=0)
159 | img = np.expand_dims(img, axis=-1)
160 | else :
161 | img = np.expand_dims(img, axis=0)
162 |
163 | img = img/127.5 - 1
164 |
165 | return img
166 |
167 | def load_one_hot_vector(label_list, target_label) :
168 | label_onehot_dict = {}
169 |
170 | v = 0
171 | for label_name in label_list :
172 | label_one_hot = list(get_one_hot(v, len(label_list)))
173 | label_onehot_dict[label_name] = [label_one_hot]
174 | v = v + 1 # advance the one-hot index for the next label
175 |
176 | x = label_onehot_dict[target_label]
177 |
178 | return x
179 |
180 | def label2onehot(label_list) :
181 | v = 0
182 | label_onehot_list = []
183 | for _ in label_list:
184 | label_one_hot = list(get_one_hot(v, len(label_list))) # [1, 0, 0, 0, 0]
185 | label_onehot_list.append(label_one_hot)
186 | v = v + 1
187 |
188 | return label_onehot_list
189 |
190 | def augmentation(image, augment_height, augment_width):
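# random horizontal flip, then upscale and random-crop back to the original size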
191 | seed = random.randint(0, 2 ** 31 - 1)
192 | ori_image_shape = tf.shape(image)
193 | image = tf.image.random_flip_left_right(image, seed=seed)
194 | image = tf.image.resize_images(image, [augment_height, augment_width])
195 | image = tf.random_crop(image, ori_image_shape, seed=seed)
196 | return image
197 |
198 |
199 | def save_images(images, size, image_path):
200 | return imsave(inverse_transform(images), size, image_path)
201 |
202 | def inverse_transform(images):
203 | return ((images+1.) / 2) * 255.0
204 |
205 |
206 | def imsave(images, size, path):
207 | images = merge(images, size)
208 | images = cv2.cvtColor(images.astype('uint8'), cv2.COLOR_RGB2BGR)
209 |
210 | return cv2.imwrite(path, images)
211 |
212 | def merge(images, size):
213 | h, w = images.shape[1], images.shape[2]
214 | img = np.zeros((h * size[0], w * size[1], 3))
215 | for idx, image in enumerate(images):
216 | i = idx % size[1]
217 | j = idx // size[1]
218 | img[h*j:h*(j+1), w*i:w*(i+1), :] = image
219 |
220 | return img
221 |
222 | def return_images(images, size) :
223 | x = merge(images, size)
224 |
225 | return x
226 |
227 | def check_folder(log_dir):
228 | if not os.path.exists(log_dir):
229 | os.makedirs(log_dir)
230 | return log_dir
231 |
232 | def show_all_variables():
233 | model_vars = tf.trainable_variables()
234 | slim.model_analyzer.analyze_vars(model_vars, print_info=True)
235 |
236 | def str2bool(x):
237 | return x.lower() == 'true'
238 |
239 | def get_one_hot(targets, nb_classes):
240 |
241 | x = np.eye(nb_classes)[targets]
242 |
243 | return x
244 |
245 | def create_labels(c_org, selected_attrs=None):
246 | """Generate target domain labels for debugging and testing."""
247 | # Get hair color indices.
248 | c_org = np.asarray(c_org)
249 | hair_color_indices = []
250 | for i, attr_name in enumerate(selected_attrs):
251 | if attr_name in ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair']:
252 | hair_color_indices.append(i)
253 |
254 | c_trg_list = []
255 |
256 | for i in range(len(selected_attrs)):
257 | c_trg = c_org.copy()
258 |
259 | if i in hair_color_indices: # Set one hair color to 1 and the rest to 0.
260 | c_trg[:, i] = 1.0
261 | for j in hair_color_indices:
262 | if j != i:
263 | c_trg[:, j] = 0.0
264 | else:
265 | c_trg[:, i] = (c_trg[:, i] == 0) # Reverse attribute value.
266 |
267 | c_trg_list.append(c_trg)
268 |
269 | c_trg_list = np.transpose(c_trg_list, axes=[1, 0, 2]) # [bs, c_dim, c_dim]
270 |
271 | return c_trg_list
272 |
273 | def pytorch_xavier_weight_factor(gain=0.02, uniform=False) :
274 |
275 | if uniform :
276 | factor = gain * gain
277 | mode = 'FAN_AVG'
278 | else :
279 | factor = (gain * gain) / 1.3
280 | mode = 'FAN_AVG'
281 |
282 | return factor, mode, uniform
283 |
284 | def pytorch_kaiming_weight_factor(a=0.0, activation_function='relu', uniform=False) :
285 |
286 | if activation_function == 'relu' :
287 | gain = np.sqrt(2.0)
288 | elif activation_function == 'leaky_relu' :
289 | gain = np.sqrt(2.0 / (1 + a ** 2))
290 | elif activation_function =='tanh' :
291 | gain = 5.0 / 3
292 | else :
293 | gain = 1.0
294 |
295 | if uniform :
296 | factor = gain * gain
297 | mode = 'FAN_IN'
298 | else :
299 | factor = (gain * gain) / 1.3
300 | mode = 'FAN_IN'
301 |
302 | return factor, mode, uniform
--------------------------------------------------------------------------------