├── .gitignore
├── DRIT.py
├── LICENSE
├── README.md
├── assets
│   ├── comparison.png
│   ├── false.png
│   ├── final.gif
│   ├── result1.png
│   ├── result2.png
│   ├── test.png
│   ├── test_1.png
│   ├── test_2.png
│   ├── train_1.png
│   ├── train_2.png
│   └── true.png
├── main.py
├── ops.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/DRIT.py:
--------------------------------------------------------------------------------
1 | from ops import *
2 | from utils import *
3 | from glob import glob
4 | import time
5 | from tensorflow.contrib.data import prefetch_to_device, shuffle_and_repeat, map_and_batch
6 |
7 | class DRIT(object) :
def __init__(self, sess, args):
    """Copy the run configuration from ``args`` and index the training images.

    Args:
        sess: session used later by ``train``/``test``/``save``/``load``.
        args: parsed options; every attribute read below must be present.

    Side effects: creates ``<args.sample_dir>/<model_dir>`` through
    ``check_folder``, globs ``./dataset/<dataset>/trainA`` and ``trainB``,
    and prints a summary of the configuration.
    """
    self.model_name = 'DRIT'
    self.sess = sess

    # Output locations and dataset.
    self.checkpoint_dir = args.checkpoint_dir
    self.result_dir = args.result_dir
    self.log_dir = args.log_dir
    self.dataset_name = args.dataset
    self.augment_flag = args.augment_flag

    # Schedule.
    self.epoch = args.epoch
    self.iteration = args.iteration
    self.decay_flag = args.decay_flag
    self.decay_epoch = args.decay_epoch

    self.gan_type = args.gan_type

    self.batch_size = args.batch_size
    self.print_freq = args.print_freq
    self.save_freq = args.save_freq

    # Test-phase options.
    self.num_attribute = args.num_attribute
    self.guide_img = args.guide_img
    self.direction = args.direction

    self.img_size = args.img_size
    self.img_ch = args.img_ch

    # The content-discriminator learning rate is the base rate divided by 2.5.
    self.init_lr = args.lr
    self.content_init_lr = args.lr / 2.5
    self.ch = args.ch
    self.concat = args.concat

    # Loss weights.
    self.content_adv_w = args.content_adv_w
    self.domain_adv_w = args.domain_adv_w
    self.cycle_w = args.cycle_w
    self.recon_w = args.recon_w
    self.latent_w = args.latent_w
    self.kl_w = args.kl_w

    # Generator.
    self.n_layer = args.n_layer
    self.n_z = args.n_z

    # Discriminator.
    self.n_dis = args.n_dis
    self.n_scale = args.n_scale
    self.n_d_con = args.n_d_con
    self.multi = args.n_scale > 1
    self.sn = args.sn

    # model_dir is derived from the fields above, so build this path last.
    self.sample_dir = os.path.join(args.sample_dir, self.model_dir)
    check_folder(self.sample_dir)

    dataset_root = './dataset/{}'.format(self.dataset_name)
    self.trainA_dataset = glob('{}/trainA/*.*'.format(dataset_root))
    self.trainB_dataset = glob('{}/trainB/*.*'.format(dataset_root))
    self.dataset_num = max(len(self.trainA_dataset), len(self.trainB_dataset))

    report = (
        ("##### Information #####", (
            ("# gan type : ", self.gan_type),
            ("# dataset : ", self.dataset_name),
            ("# max dataset number : ", self.dataset_num),
            ("# batch_size : ", self.batch_size),
            ("# decay_flag : ", self.decay_flag),
            ("# epoch : ", self.epoch),
            ("# decay_epoch : ", self.decay_epoch),
            ("# iteration per epoch : ", self.iteration),
            ("# attribute in test phase : ", self.num_attribute),
        )),
        ("##### Generator #####", (
            ("# layer : ", self.n_layer),
            ("# z dimension : ", self.n_z),
            ("# concat : ", self.concat),
        )),
        ("##### Discriminator #####", (
            ("# discriminator layer : ", self.n_dis),
            ("# multi-scale Dis : ", self.n_scale),
            ("# updating iteration of con_dis : ", self.n_d_con),
            ("# spectral_norm : ", self.sn),
        )),
        ("##### Weight #####", (
            ("# domain_adv_weight : ", self.domain_adv_w),
            ("# content_adv_weight : ", self.content_adv_w),
            ("# cycle_weight : ", self.cycle_w),
            ("# recon_weight : ", self.recon_w),
            ("# latent_weight : ", self.latent_w),
            ("# kl_weight : ", self.kl_w),
        )),
    )
    for section_no, (title, rows) in enumerate(report):
        # A blank line separates consecutive sections (none after the last).
        if section_no:
            print()
        print(title)
        for label, value in rows:
            print(label, value)
103 | ##################################################################################
104 | # Encoder and Decoders
105 | ##################################################################################
106 |
def content_encoder(self, x, is_training=True, reuse=False, scope='content_encoder'):
    """Encode images into a content feature map.

    Layers:
    - a 7x7 stem conv;
    - two stride-2 conv / instance-norm / ReLU stages, each doubling the width
      (ending at ``4 * self.ch`` channels);
    - ``n_layer - 1`` residual blocks inside ``scope``;
    - one final residual block plus ``gaussian_noise_layer`` inside the fixed
      scope ``content_encoder_share`` (``tf.AUTO_REUSE``).

    NOTE(review): the shared scope is assumed to sit outside ``scope`` so that
    the A and B encoders share its variables. Indentation was lost in this
    dump, so confirm against the original file.
    """
    with tf.variable_scope(scope, reuse=reuse):
        h = conv(x, self.ch, kernel=7, stride=1, pad=3, pad_type='reflect', scope='conv')
        h = lrelu(h, 0.01)

        width = self.ch
        for stage in range(2):
            width *= 2
            h = conv(h, width, kernel=3, stride=2, pad=1, pad_type='reflect', scope='conv_{}'.format(stage))
            h = instance_norm(h, scope='ins_norm_{}'.format(stage))
            h = relu(h)

        for block in range(1, self.n_layer):
            h = resblock(h, width, scope='resblock_{}'.format(block))

    with tf.variable_scope('content_encoder_share', reuse=tf.AUTO_REUSE):
        h = resblock(h, width, scope='resblock_share')
        h = gaussian_noise_layer(h, is_training)

    return h
129 |
def attribute_encoder(self, x, reuse=False, scope='attribute_encoder'):
    """Encode images into an attribute code of ``self.n_z`` channels.

    Layers:
    - a 7x7 conv with ``ch`` channels;
    - a stride-2 conv with ``2 * ch`` channels;
    - ``n_layer - 1`` stride-2 convs with ``4 * ch`` channels (all ReLU);
    - ``global_avg_pooling`` followed by a 1x1 conv named ``attribute_logit``.
    """
    base = self.ch
    with tf.variable_scope(scope, reuse=reuse):
        h = relu(conv(x, base, kernel=7, stride=1, pad=3, pad_type='reflect', scope='conv'))
        h = relu(conv(h, base * 2, kernel=4, stride=2, pad=1, pad_type='reflect', scope='conv_0'))

        for idx in range(1, self.n_layer):
            h = relu(conv(h, base * 4, kernel=4, stride=2, pad=1, pad_type='reflect',
                          scope='conv_{}'.format(idx)))

        h = global_avg_pooling(h)
        return conv(h, channels=self.n_z, kernel=1, stride=1, scope='attribute_logit')
150 |
def attribute_encoder_concat(self, x, reuse=False, scope='attribute_encoder_concat'):
    """Encode images into the mean and log-variance of the attribute posterior.

    Layers:
    - a stride-2 4x4 conv with ``self.ch`` channels;
    - ``n_layer - 1`` ``basic_block`` layers with widths ``ch * 2, ch * 3, ...``;
    - LeakyReLU(0.2) and global average pooling;
    - two fully connected heads (``z_mean`` / ``z_logvar``) of size ``n_z``.

    Returns:
        (mean, logvar) tensors produced by the two ``fully_conneted`` heads.

    NOTE: the width used to be updated cumulatively
    (``channel = channel * (i + 1)``), giving ch*2, ch*6, ch*24, ... instead
    of the linear ch*(i+1) progression of the reference DRIT E_attr_concat.
    Checkpoints trained with the old widths (n_layer >= 3) will not restore.
    """
    with tf.variable_scope(scope, reuse=reuse):
        h = conv(x, self.ch, kernel=4, stride=2, pad=1, pad_type='reflect', scope='conv')

        for i in range(1, self.n_layer):
            # Linear, non-cumulative growth of the block width.
            h = basic_block(h, self.ch * (i + 1), scope='basic_block_' + str(i))

        h = lrelu(h, 0.2)
        h = global_avg_pooling(h)

        mean = fully_conneted(h, channels=self.n_z, scope='z_mean')
        logvar = fully_conneted(h, channels=self.n_z, scope='z_logvar')

        return mean, logvar
167 |
def MLP(self, z, reuse=False, scope='MLP'):
    """Map an attribute code through two ReLU dense layers to a wide vector.

    The hidden width is ``ch * n_layer``. The output width is
    ``ch * n_layer * n_layer``, which lets ``generator`` split it into
    ``n_layer`` equal chunks.
    """
    width = self.ch * self.n_layer
    with tf.variable_scope(scope, reuse=reuse):
        h = z
        for idx in range(2):
            h = relu(fully_conneted(h, width, scope='fully_{}'.format(idx)))
        return fully_conneted(h, width * self.n_layer, scope='fully_logit')
179 |
def generator(self, x, z, reuse=False, scope="generator"):
    """Decode content features ``x`` with attribute code ``z`` into an image.

    Steps:
    - ``z`` goes through ``MLP`` and is split into ``n_layer`` chunks, one per
      ``mis_resblock``. Presumably the block injects its chunk into the
      features; see ops.mis_resblock.
    - Two stride-2 deconv / layer-norm / ReLU stages, each halving the width.
    - A 1x1 deconv to ``img_ch`` channels, then ``tanh``.
    """
    width = self.ch * self.n_layer
    with tf.variable_scope(scope, reuse=reuse):
        codes = tf.split(self.MLP(z, reuse=reuse), num_or_size_splits=self.n_layer, axis=-1)

        h = x
        for idx, code in enumerate(codes):
            h = mis_resblock(h, code, width, scope='mis_resblock_' + str(idx))

        for idx in range(2):
            width //= 2
            h = deconv(h, width, kernel=3, stride=2, scope='deconv_' + str(idx))
            h = layer_norm(h, scope='layer_norm_' + str(idx))
            h = relu(h)

        h = deconv(h, channels=self.img_ch, kernel=1, stride=1, scope='G_logit')
        return tanh(h)
200 |
def generator_concat(self, x, z, reuse=False, scope='generator_concat'):
    """Decode content ``x`` into an image, re-concatenating ``z`` before each stage.

    Steps:
    - A residual block in the fixed scope ``generator_concat_share``
      (``tf.AUTO_REUSE``) runs first.
    - Then ``z`` is merged in via ``expand_concat`` (presumably tiled spatially
      and concatenated on channels; see ops.expand_concat) before:
      the residual blocks, each of the two stride-2 deconv stages, and the
      final 1x1 deconv.
    - The tracked width grows by ``n_z`` at each merge and halves at each
      deconv.
    """
    width = self.ch * self.n_layer
    with tf.variable_scope('generator_concat_share', reuse=tf.AUTO_REUSE):
        h = resblock(x, width, scope='resblock')

    with tf.variable_scope(scope, reuse=reuse):
        width += self.n_z
        h = expand_concat(h, z)

        for idx in range(1, self.n_layer):
            h = resblock(h, width, scope='resblock_' + str(idx))

        for idx in range(2):
            h = expand_concat(h, z)
            width = (width + self.n_z) // 2
            h = deconv(h, width, kernel=3, stride=2, scope='deconv_' + str(idx))
            h = layer_norm(h, scope='layer_norm_' + str(idx))
            h = relu(h)

        h = expand_concat(h, z)
        h = deconv(h, channels=self.img_ch, kernel=1, stride=1, scope='G_logit')
        return tanh(h)
228 |
229 |
230 |
231 | ##################################################################################
232 | # Discriminator
233 | ##################################################################################
234 |
def content_discriminator(self, x, reuse=False, scope='content_discriminator'):
    """Score content features.

    Layers:
    - three 7x7 stride-2 conv / instance-norm / LeakyReLU(0.01) stages at
      ``ch * n_layer`` channels;
    - a 4x4 conv with LeakyReLU;
    - a 1x1 logit conv.

    Returns:
        A single-element list holding the logit tensor.
    """
    width = self.ch * self.n_layer
    with tf.variable_scope(scope, reuse=reuse):
        h = x
        for idx in range(3):
            h = conv(h, width, kernel=7, stride=2, pad=1, pad_type='reflect', scope='conv_' + str(idx))
            h = instance_norm(h, scope='ins_norm_' + str(idx))
            h = lrelu(h, 0.01)

        h = lrelu(conv(h, width, kernel=4, stride=1, scope='conv_3'), 0.01)
        logit = conv(h, channels=1, kernel=1, stride=1, scope='D_content_logit')

    return [logit]
251 |
def multi_discriminator(self, x_init, reuse=False, scope="multi_discriminator"):
    """Apply ``n_scale`` independent patch discriminators at decreasing resolutions.

    At each scale:
    - a stride-2 conv with ``ch`` channels;
    - ``n_dis - 1`` stride-2 convs, each doubling the width (LeakyReLU 0.01);
    - a 1x1 logit conv.

    The input is passed through ``down_sample`` after every scale, including
    the last one.

    Returns:
        A list with one logit tensor per scale.
    """
    logits = []
    with tf.variable_scope(scope, reuse=reuse):
        image = x_init
        for scale in range(self.n_scale):
            prefix = 'ms_' + str(scale)

            h = conv(image, self.ch, kernel=4, stride=2, pad=1, pad_type='reflect', sn=self.sn, scope=prefix + 'conv_0')
            h = lrelu(h, 0.01)

            width = self.ch
            for idx in range(1, self.n_dis):
                width *= 2
                h = conv(h, width, kernel=4, stride=2, pad=1, pad_type='reflect', sn=self.sn,
                         scope=prefix + 'conv_' + str(idx))
                h = lrelu(h, 0.01)

            logits.append(conv(h, channels=1, kernel=1, stride=1, sn=self.sn, scope=prefix + 'D_logit'))

            image = down_sample(image)

    return logits
272 |
def discriminator(self, x, reuse=False, scope="discriminator"):
    """Single-scale patch discriminator.

    Layers:
    - a 3x3 stride-2 conv with ``ch`` channels;
    - ``n_dis - 1`` further stride-2 convs, each doubling the width
      (LeakyReLU 0.01);
    - a 1x1 logit conv.

    Returns:
        A single-element list holding the logit tensor.
    """
    logits = []
    with tf.variable_scope(scope, reuse=reuse):
        h = lrelu(conv(x, self.ch, kernel=3, stride=2, pad=1, pad_type='reflect', sn=self.sn, scope='conv'), 0.01)

        width = self.ch
        for idx in range(1, self.n_dis):
            width *= 2
            h = conv(h, width, kernel=3, stride=2, pad=1, pad_type='reflect', sn=self.sn, scope='conv_' + str(idx))
            h = lrelu(h, 0.01)

        logits.append(conv(h, channels=1, kernel=1, stride=1, sn=self.sn, scope='D_logit'))

    return logits
290 |
291 | ##################################################################################
292 | # Model
293 | ##################################################################################
294 |
def Encoder_A(self, x_A, is_training=True, random_fake=False, reuse=False):
    """Encode domain-A images.

    Returns:
        ``(content, attribute, mean, logvar)``. ``mean``/``logvar`` are None
        unless ``self.concat``. In concat mode the attribute is
        ``z_sample(mean, logvar)``, or the ``mean`` itself when
        ``random_fake`` is set (used for latent regression).
    """
    content_A = self.content_encoder(x_A, is_training=is_training, reuse=reuse, scope='content_encoder_A')

    if not self.concat:
        attribute_A = self.attribute_encoder(x_A, reuse=reuse, scope='attribute_encoder_A')
        return content_A, attribute_A, None, None

    mean, logvar = self.attribute_encoder_concat(x_A, reuse=reuse, scope='attribute_encoder_concat_A')
    attribute_A = mean if random_fake else z_sample(mean, logvar)
    return content_A, attribute_A, mean, logvar
311 |
def Encoder_B(self, x_B, is_training=True, random_fake=False, reuse=False):
    """Encode domain-B images.

    Returns:
        ``(content, attribute, mean, logvar)``. ``mean``/``logvar`` are None
        unless ``self.concat``. In concat mode the attribute is the posterior
        ``mean`` when ``random_fake`` is set, otherwise ``z_sample(mean, logvar)``.
    """
    content_B = self.content_encoder(x_B, is_training=is_training, reuse=reuse, scope='content_encoder_B')

    mean = logvar = None
    if self.concat:
        mean, logvar = self.attribute_encoder_concat(x_B, reuse=reuse, scope='attribute_encoder_concat_B')
        if random_fake:
            attribute_B = mean
        else:
            attribute_B = z_sample(mean, logvar)
    else:
        attribute_B = self.attribute_encoder(x_B, reuse=reuse, scope='attribute_encoder_B')

    return content_B, attribute_B, mean, logvar
329 |
def Decoder_A(self, content_B, attribute_A, reuse=False):
    """Decode into domain A from a content code and a domain-A attribute code.

    Used for fake_A (B content, A attribute), the A reconstruction (A, A) and
    random_fake_A (B content, prior sample). Dispatches to ``generator_concat``
    or ``generator`` according to ``self.concat``.
    """
    if self.concat:
        build, scope = self.generator_concat, 'generator_concat_A'
    else:
        build, scope = self.generator, 'generator_A'
    return build(x=content_B, z=attribute_A, reuse=reuse, scope=scope)
339 |
def Decoder_B(self, content_A, attribute_B, reuse=False):
    """Decode into domain B from a content code and a domain-B attribute code.

    Used for fake_B (A content, B attribute), the B reconstruction (B, B) and
    random_fake_B (A content, prior sample). Dispatches to ``generator_concat``
    or ``generator`` according to ``self.concat``.
    """
    build = self.generator_concat if self.concat else self.generator
    scope = 'generator_concat_B' if self.concat else 'generator_B'
    return build(x=content_A, z=attribute_B, reuse=reuse, scope=scope)
349 |
def discriminate_real(self, x_A, x_B):
    """Create (reuse=False) the per-domain discriminators and score real images.

    Uses ``multi_discriminator`` when ``self.multi`` (``n_scale > 1``),
    otherwise ``discriminator``. Domain A is built first, then domain B.
    """
    if self.multi:
        net, prefix = self.multi_discriminator, 'multi_discriminator'
    else:
        net, prefix = self.discriminator, 'discriminator'

    real_A_logit = net(x_A, scope=prefix + '_A')
    real_B_logit = net(x_B, scope=prefix + '_B')
    return real_A_logit, real_B_logit
360 |
def discriminate_fake(self, x_ba, x_ab):
    """Score translated images with the already-built domain discriminators.

    ``x_ba`` (an image translated into A) is scored by the A discriminator,
    and ``x_ab`` by the B discriminator. Both use ``reuse=True``, so
    ``discriminate_real`` must have run first.
    """
    net = self.multi_discriminator if self.multi else self.discriminator
    prefix = 'multi_discriminator' if self.multi else 'discriminator'

    fake_A_logit = net(x_ba, reuse=True, scope=prefix + '_A')
    fake_B_logit = net(x_ab, reuse=True, scope=prefix + '_B')
    return fake_A_logit, fake_B_logit
371 |
def discriminate_content(self, content_A, content_B, reuse=False):
    """Score both content codes with one shared content discriminator.

    The A call honours ``reuse``. The B call always reuses the variables that
    the A call just created or reused.
    """
    shared_scope = 'content_discriminator'
    logits_A = self.content_discriminator(content_A, reuse=reuse, scope=shared_scope)
    logits_B = self.content_discriminator(content_B, reuse=True, scope=shared_scope)
    return logits_A, logits_B
377 |
378 |
def build_model(self):
    """Build the training graph, losses, optimizers, summaries and test graphs.

    Training batches for both domains come from one-shot ``tf.data``
    iterators over ``self.trainA_dataset`` / ``self.trainB_dataset``. The
    test and guided-translation graphs are fed through placeholders and reuse
    the training variables.

    Defines on ``self``:
    - the ``lr``/``content_lr`` placeholders;
    - ``domain_A``/``domain_B``;
    - ``Generator_loss``, ``Discriminator_loss``, ``Discriminator_content_loss``;
    - the ``G_optim``/``D_optim``/``D_content_optim`` train ops and the
      summary ops;
    - the image fetches used by ``train``;
    - ``test_*`` and ``guide_fake_A`` or ``guide_fake_B`` (depending on
      ``self.direction``).

    Fix: the per-domain cycle terms were swapped (``Generator_A_cycle_loss``
    used ``g_cyc_loss_b`` and vice versa). The total generator loss is
    unchanged; only the ``G_A_*``/``G_B_*`` summaries are now attributed to
    the right domain, like every other per-domain term.
    """
    self.lr = tf.placeholder(tf.float32, name='lr')
    self.content_lr = tf.placeholder(tf.float32, name='content_lr')

    # ---- Input images ------------------------------------------------
    Image_Data_Class = ImageData(self.img_size, self.img_ch, self.augment_flag)

    trainA = tf.data.Dataset.from_tensor_slices(self.trainA_dataset)
    trainB = tf.data.Dataset.from_tensor_slices(self.trainB_dataset)

    gpu_device = '/gpu:0'
    trainA = (trainA
              .apply(shuffle_and_repeat(self.dataset_num))
              .apply(map_and_batch(Image_Data_Class.image_processing, self.batch_size,
                                   num_parallel_batches=16, drop_remainder=True))
              .apply(prefetch_to_device(gpu_device, self.batch_size)))
    trainB = (trainB
              .apply(shuffle_and_repeat(self.dataset_num))
              .apply(map_and_batch(Image_Data_Class.image_processing, self.batch_size,
                                   num_parallel_batches=16, drop_remainder=True))
              .apply(prefetch_to_device(gpu_device, self.batch_size)))

    trainA_iterator = trainA.make_one_shot_iterator()
    trainB_iterator = trainB.make_one_shot_iterator()

    self.domain_A = trainA_iterator.get_next()
    self.domain_B = trainB_iterator.get_next()

    # ---- Encoders, generators, discriminators --------------------------
    # One prior sample feeds both random translations and is also the target
    # of the latent-regression loss.
    random_z = tf.random_normal(shape=[self.batch_size, self.n_z], mean=0.0, stddev=1.0, dtype=tf.float32)

    # encode
    content_a, attribute_a, mean_a, logvar_a = self.Encoder_A(self.domain_A)
    content_b, attribute_b, mean_b, logvar_b = self.Encoder_B(self.domain_B)

    # decode: cross-domain translation, self-reconstruction, prior-sampled translation
    fake_a = self.Decoder_A(content_B=content_b, attribute_A=attribute_a)
    fake_b = self.Decoder_B(content_A=content_a, attribute_B=attribute_b)

    recon_a = self.Decoder_A(content_B=content_a, attribute_A=attribute_a, reuse=True)
    recon_b = self.Decoder_B(content_A=content_b, attribute_B=attribute_b, reuse=True)

    random_fake_a = self.Decoder_A(content_B=content_b, attribute_A=random_z, reuse=True)
    random_fake_b = self.Decoder_B(content_A=content_a, attribute_B=random_z, reuse=True)

    # encode & decode again for cycle-consistency
    content_fake_a, attribute_fake_a, _, _ = self.Encoder_A(fake_a, reuse=True)
    content_fake_b, attribute_fake_b, _, _ = self.Encoder_B(fake_b, reuse=True)

    cycle_a = self.Decoder_A(content_B=content_fake_b, attribute_A=attribute_fake_a, reuse=True)
    cycle_b = self.Decoder_B(content_A=content_fake_a, attribute_B=attribute_fake_b, reuse=True)

    # latent regression (with random_fake=True the concat encoders return the mean)
    _, attribute_fake_random_a, _, _ = self.Encoder_A(random_fake_a, random_fake=True, reuse=True)
    _, attribute_fake_random_b, _, _ = self.Encoder_B(random_fake_b, random_fake=True, reuse=True)

    # discriminate
    real_A_logit, real_B_logit = self.discriminate_real(self.domain_A, self.domain_B)
    fake_A_logit, fake_B_logit = self.discriminate_fake(fake_a, fake_b)
    random_fake_A_logit, random_fake_B_logit = self.discriminate_fake(random_fake_a, random_fake_b)
    content_A_logit, content_B_logit = self.discriminate_content(content_a, content_b)

    # ---- Losses ------------------------------------------------------
    g_adv_loss_a = generator_loss(self.gan_type, fake_A_logit) + generator_loss(self.gan_type, random_fake_A_logit)
    g_adv_loss_b = generator_loss(self.gan_type, fake_B_logit) + generator_loss(self.gan_type, random_fake_B_logit)

    g_con_loss_a = generator_loss(self.gan_type, content_A_logit, content=True)
    g_con_loss_b = generator_loss(self.gan_type, content_B_logit, content=True)

    g_cyc_loss_a = L1_loss(cycle_a, self.domain_A)
    g_cyc_loss_b = L1_loss(cycle_b, self.domain_B)

    g_rec_loss_a = L1_loss(recon_a, self.domain_A)
    g_rec_loss_b = L1_loss(recon_b, self.domain_B)

    g_latent_loss_a = L1_loss(attribute_fake_random_a, random_z)
    g_latent_loss_b = L1_loss(attribute_fake_random_b, random_z)

    if self.concat:
        g_kl_loss_a = kl_loss(mean_a, logvar_a) + l2_regularize(content_a)
        g_kl_loss_b = kl_loss(mean_b, logvar_b) + l2_regularize(content_b)
    else:
        g_kl_loss_a = l2_regularize(attribute_a) + l2_regularize(content_a)
        g_kl_loss_b = l2_regularize(attribute_b) + l2_regularize(content_b)

    d_adv_loss_a = discriminator_loss(self.gan_type, real_A_logit, fake_A_logit, random_fake_A_logit)
    d_adv_loss_b = discriminator_loss(self.gan_type, real_B_logit, fake_B_logit, random_fake_B_logit)

    d_con_loss = discriminator_loss(self.gan_type, content_A_logit, content_B_logit, content=True)

    Generator_A_domain_loss = self.domain_adv_w * g_adv_loss_a
    Generator_A_content_loss = self.content_adv_w * g_con_loss_a
    Generator_A_cycle_loss = self.cycle_w * g_cyc_loss_a  # was g_cyc_loss_b
    Generator_A_recon_loss = self.recon_w * g_rec_loss_a
    Generator_A_latent_loss = self.latent_w * g_latent_loss_a
    Generator_A_kl_loss = self.kl_w * g_kl_loss_a

    Generator_A_loss = (Generator_A_domain_loss +
                        Generator_A_content_loss +
                        Generator_A_cycle_loss +
                        Generator_A_recon_loss +
                        Generator_A_latent_loss +
                        Generator_A_kl_loss)

    Generator_B_domain_loss = self.domain_adv_w * g_adv_loss_b
    Generator_B_content_loss = self.content_adv_w * g_con_loss_b
    Generator_B_cycle_loss = self.cycle_w * g_cyc_loss_b  # was g_cyc_loss_a
    Generator_B_recon_loss = self.recon_w * g_rec_loss_b
    Generator_B_latent_loss = self.latent_w * g_latent_loss_b
    Generator_B_kl_loss = self.kl_w * g_kl_loss_b

    Generator_B_loss = (Generator_B_domain_loss +
                        Generator_B_content_loss +
                        Generator_B_cycle_loss +
                        Generator_B_recon_loss +
                        Generator_B_latent_loss +
                        Generator_B_kl_loss)

    Discriminator_A_loss = self.domain_adv_w * d_adv_loss_a
    Discriminator_B_loss = self.domain_adv_w * d_adv_loss_b
    Discriminator_content_loss = self.content_adv_w * d_con_loss

    self.Generator_loss = Generator_A_loss + Generator_B_loss
    self.Discriminator_loss = Discriminator_A_loss + Discriminator_B_loss
    self.Discriminator_content_loss = Discriminator_content_loss

    # ---- Training ops ------------------------------------------------
    # Variables are partitioned by scope name. The optimizers are created in
    # this fixed order, which keeps Adam slot/variable names stable for
    # existing checkpoints.
    t_vars = tf.trainable_variables()
    G_vars = [var for var in t_vars if 'encoder' in var.name or 'generator' in var.name]
    D_vars = [var for var in t_vars if 'discriminator' in var.name and 'content' not in var.name]
    D_content_vars = [var for var in t_vars if 'content_discriminator' in var.name]

    # The content discriminator's gradients are clipped to a global norm of 5.
    grads, _ = tf.clip_by_global_norm(tf.gradients(self.Discriminator_content_loss, D_content_vars), clip_norm=5)

    self.G_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.Generator_loss, var_list=G_vars)
    self.D_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.Discriminator_loss, var_list=D_vars)
    self.D_content_optim = tf.train.AdamOptimizer(self.content_lr, beta1=0.5, beta2=0.999).apply_gradients(zip(grads, D_content_vars))

    # ---- Summaries ---------------------------------------------------
    self.lr_write = tf.summary.scalar("learning_rate", self.lr)

    self.all_G_loss = tf.summary.scalar("Generator_loss", self.Generator_loss)
    self.all_D_loss = tf.summary.scalar("Discriminator_loss", self.Discriminator_loss)

    self.G_A_loss = tf.summary.scalar("G_A_loss", Generator_A_loss)
    self.G_A_domain_loss = tf.summary.scalar("G_A_domain_loss", Generator_A_domain_loss)
    self.G_A_content_loss = tf.summary.scalar("G_A_content_loss", Generator_A_content_loss)
    self.G_A_cycle_loss = tf.summary.scalar("G_A_cycle_loss", Generator_A_cycle_loss)
    self.G_A_recon_loss = tf.summary.scalar("G_A_recon_loss", Generator_A_recon_loss)
    self.G_A_latent_loss = tf.summary.scalar("G_A_latent_loss", Generator_A_latent_loss)
    self.G_A_kl_loss = tf.summary.scalar("G_A_kl_loss", Generator_A_kl_loss)

    self.G_B_loss = tf.summary.scalar("G_B_loss", Generator_B_loss)
    self.G_B_domain_loss = tf.summary.scalar("G_B_domain_loss", Generator_B_domain_loss)
    self.G_B_content_loss = tf.summary.scalar("G_B_content_loss", Generator_B_content_loss)
    self.G_B_cycle_loss = tf.summary.scalar("G_B_cycle_loss", Generator_B_cycle_loss)
    self.G_B_recon_loss = tf.summary.scalar("G_B_recon_loss", Generator_B_recon_loss)
    self.G_B_latent_loss = tf.summary.scalar("G_B_latent_loss", Generator_B_latent_loss)
    self.G_B_kl_loss = tf.summary.scalar("G_B_kl_loss", Generator_B_kl_loss)

    self.D_A_loss = tf.summary.scalar("D_A_loss", Discriminator_A_loss)
    self.D_B_loss = tf.summary.scalar("D_B_loss", Discriminator_B_loss)

    self.G_loss = tf.summary.merge([self.G_A_loss,
                                    self.G_A_domain_loss, self.G_A_content_loss,
                                    self.G_A_cycle_loss, self.G_A_recon_loss,
                                    self.G_A_latent_loss, self.G_A_kl_loss,

                                    self.G_B_loss,
                                    self.G_B_domain_loss, self.G_B_content_loss,
                                    self.G_B_cycle_loss, self.G_B_recon_loss,
                                    self.G_B_latent_loss, self.G_B_kl_loss,

                                    self.all_G_loss])

    self.D_loss = tf.summary.merge([self.D_A_loss,
                                    self.D_B_loss,
                                    self.all_D_loss])

    self.D_content_loss = tf.summary.scalar("Discriminator_content_loss", self.Discriminator_content_loss)

    # ---- Images fetched by train() -----------------------------------
    self.fake_A = random_fake_a
    self.fake_B = random_fake_b

    self.real_A = self.domain_A
    self.real_B = self.domain_B

    # ---- Test: translate one image with a random attribute -----------
    self.test_image = tf.placeholder(tf.float32, [1, self.img_size, self.img_size, self.img_ch], name='test_image')
    self.test_random_z = tf.random_normal(shape=[1, self.n_z], mean=0.0, stddev=1.0, dtype=tf.float32)

    test_content_a, _, _, _ = self.Encoder_A(self.test_image, is_training=False, reuse=True)
    test_content_b, _, _, _ = self.Encoder_B(self.test_image, is_training=False, reuse=True)

    self.test_fake_A = self.Decoder_A(content_B=test_content_b, attribute_A=self.test_random_z, reuse=True)
    self.test_fake_B = self.Decoder_B(content_A=test_content_a, attribute_B=self.test_random_z, reuse=True)

    # ---- Guided translation: attribute taken from a reference image --
    self.content_image = tf.placeholder(tf.float32, [1, self.img_size, self.img_size, self.img_ch], name='content_image')
    self.attribute_image = tf.placeholder(tf.float32, [1, self.img_size, self.img_size, self.img_ch], name='guide_attribute_image')

    if self.direction == 'a2b':
        guide_content_A, _, _, _ = self.Encoder_A(self.content_image, is_training=False, reuse=True)
        _, guide_attribute_B, _, _ = self.Encoder_B(self.attribute_image, is_training=False, reuse=True)
        self.guide_fake_B = self.Decoder_B(content_A=guide_content_A, attribute_B=guide_attribute_B, reuse=True)
    else:
        guide_content_B, _, _, _ = self.Encoder_B(self.content_image, is_training=False, reuse=True)
        _, guide_attribute_A, _, _ = self.Encoder_A(self.attribute_image, is_training=False, reuse=True)
        self.guide_fake_A = self.Decoder_A(content_B=guide_content_B, attribute_A=guide_attribute_A, reuse=True)
593 |
def train(self):
    """Run the DRIT training loop, resuming from the latest checkpoint if any.

    Every step updates the content discriminator. When
    ``(counter - 1) % n_d_con == 0`` it also updates the domain
    discriminators, then the encoders/generators. With ``decay_flag`` set,
    both learning rates decay linearly once ``epoch >= decay_epoch``.

    Outputs:
    - sample images every ``print_freq`` steps (taken from the most recent
      generator step);
    - a checkpoint every ``save_freq`` steps and once at the end.

    Fixes:
    - Resuming no longer skips one batch.
    - Samples are skipped, instead of raising NameError, when no generator
      step has run yet (possible after resuming mid-cycle).
    """
    # initialize all variables
    tf.global_variables_initializer().run()

    # saver to save model
    self.saver = tf.train.Saver()

    # summary writer
    self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_dir, self.sess.graph)

    # ``counter`` is the 1-based index of the next step. A fresh run starts at
    # 1, and checkpoints are written after the increment, so a checkpoint
    # tagged ``k + 1`` means ``k`` steps have completed.
    could_load, checkpoint_counter = self.load(self.checkpoint_dir)
    if could_load:
        completed = max(checkpoint_counter - 1, 0)
        start_epoch, start_batch_id = divmod(completed, self.iteration)
        counter = completed + 1
        print(" [*] Load SUCCESS")
    else:
        start_epoch = 0
        start_batch_id = 0
        counter = 1
        print(" [!] Load failed...")

    # loop for epoch
    start_time = time.time()
    lr = self.init_lr
    content_lr = self.content_init_lr

    # Images from the latest generator step. They stay None until the first
    # G update, so sample saving is skipped rather than hitting an unbound name.
    batch_A_images = None
    fake_B = None

    for epoch in range(start_epoch, self.epoch):
        if self.decay_flag:
            lr = self.init_lr if epoch < self.decay_epoch else self.init_lr * (self.epoch - epoch) / (self.epoch - self.decay_epoch)  # linear decay
            content_lr = self.content_init_lr if epoch < self.decay_epoch else self.content_init_lr * (self.epoch - epoch) / (self.epoch - self.decay_epoch)  # linear decay

        for idx in range(start_batch_id, self.iteration):
            train_feed_dict = {
                self.lr: lr,
                self.content_lr: content_lr
            }

            summary_str = self.sess.run(self.lr_write, feed_dict=train_feed_dict)
            self.writer.add_summary(summary_str, counter)

            # Update content D
            _, d_con_loss, summary_str = self.sess.run(
                [self.D_content_optim, self.Discriminator_content_loss, self.D_content_loss],
                feed_dict=train_feed_dict)
            self.writer.add_summary(summary_str, counter)

            if (counter - 1) % self.n_d_con == 0:
                # Update D
                _, d_loss, summary_str = self.sess.run(
                    [self.D_optim, self.Discriminator_loss, self.D_loss], feed_dict=train_feed_dict)
                self.writer.add_summary(summary_str, counter)

                # Update G
                batch_A_images, fake_B, _, g_loss, summary_str = self.sess.run(
                    [self.real_A, self.fake_B, self.G_optim, self.Generator_loss, self.G_loss],
                    feed_dict=train_feed_dict)
                self.writer.add_summary(summary_str, counter)

                print("Epoch: [%2d] [%6d/%6d] time: %4.4f d_con_loss: %.8f, d_loss: %.8f, g_loss: %.8f"
                      % (epoch, idx, self.iteration, time.time() - start_time, d_con_loss, d_loss, g_loss))
            else:
                print("Epoch: [%2d] [%6d/%6d] time: %4.4f d_con_loss: %.8f" % (epoch, idx, self.iteration, time.time() - start_time, d_con_loss))

            if (idx + 1) % self.print_freq == 0 and batch_A_images is not None:
                save_images(batch_A_images, [self.batch_size, 1],
                            './{}/real_A_{:03d}_{:05d}.jpg'.format(self.sample_dir, epoch, idx + 1))
                save_images(fake_B, [self.batch_size, 1],
                            './{}/fake_B_{:03d}_{:05d}.jpg'.format(self.sample_dir, epoch, idx + 1))

            counter += 1

            if (idx + 1) % self.save_freq == 0:
                self.save(self.checkpoint_dir, counter)

        # A non-zero start batch applies only to the first epoch after resuming.
        start_batch_id = 0

    # save model for final step
    self.save(self.checkpoint_dir, counter)
677 |
@property
def model_dir(self):
    """Run-specific directory name encoding the main hyper-parameters.

    Format: ``<model_name>[_concat]_<dataset>_<gan_type>_<n_layer>layer_``
    ``<n_dis>dis_<n_scale>scale_<n_d_con>con[_sn]``.
    """
    concat_tag = "_concat" if self.concat else ""
    sn_tag = "_sn" if self.sn else ""
    return "{}{}_{}_{}_{}layer_{}dis_{}scale_{}con{}".format(
        self.model_name, concat_tag, self.dataset_name, self.gan_type,
        self.n_layer, self.n_dis, self.n_scale, self.n_d_con, sn_tag)
692 |
def save(self, checkpoint_dir, step):
    """Write a checkpoint to ``<checkpoint_dir>/<model_dir>/<model_name>.model-<step>``.

    The target directory is created if needed. ``exist_ok=True`` avoids the
    race between checking for the directory and creating it.
    """
    target_dir = os.path.join(checkpoint_dir, self.model_dir)
    os.makedirs(target_dir, exist_ok=True)
    self.saver.save(self.sess, os.path.join(target_dir, self.model_name + '.model'), global_step=step)
700 |
def load(self, checkpoint_dir):
    """Restore the newest checkpoint under ``<checkpoint_dir>/<model_dir>``.

    Returns:
        ``(True, step)`` on success, where ``step`` is parsed from the text
        after the last ``-`` in the checkpoint file name. ``(False, 0)`` when
        no checkpoint state is found.
    """
    print(" [*] Reading checkpoints...")
    model_ckpt_dir = os.path.join(checkpoint_dir, self.model_dir)

    state = tf.train.get_checkpoint_state(model_ckpt_dir)
    if not (state and state.model_checkpoint_path):
        print(" [*] Failed to find a checkpoint")
        return False, 0

    ckpt_name = os.path.basename(state.model_checkpoint_path)
    self.saver.restore(self.sess, os.path.join(model_ckpt_dir, ckpt_name))
    step = int(ckpt_name.rsplit('-', 1)[-1])
    print(" [*] Success to read {}".format(ckpt_name))
    return True, step
715 |
def test(self):
    """Translate every test image in both directions and write an HTML report.

    Images in ``./dataset/<dataset_name>/testA`` are fed through
    ``self.test_fake_B`` (A -> B) and those in ``testB`` through
    ``self.test_fake_A`` (B -> A), ``self.num_attribute`` times each.

    Each output is saved as ``<stem>_attribute<i><ext>`` in
    ``<result_dir>/<model_dir>``, and ``self.result_dir`` is updated to that
    path. ``index.html`` in that directory gets one row (name | input | output)
    per output image.
    """
    import html  # local import: only used to escape text written into the report

    tf.global_variables_initializer().run()
    test_A_files = glob('./dataset/{}/*.*'.format(self.dataset_name + '/testA'))
    test_B_files = glob('./dataset/{}/*.*'.format(self.dataset_name + '/testB'))

    self.saver = tf.train.Saver()
    could_load, _ = self.load(self.checkpoint_dir)
    self.result_dir = os.path.join(self.result_dir, self.model_dir)
    check_folder(self.result_dir)

    if could_load:
        print(" [*] Load SUCCESS")
    else:
        print(" [!] Load failed...")

    def _img_cell(path):
        # Build one <td><img></td> cell. Relative paths are rebased with '../..'
        # because index.html sits two directories below the working directory
        # (assuming result_dir is a single relative component like the default 'results').
        if not os.path.isabs(path):
            path = '../..' + os.path.sep + path
        return "<td><img src='%s' width='%d' height='%d'></td>" % (
            html.escape(path, quote=True), self.img_size, self.img_size)

    # (input files, domain label used in log messages, output tensor)
    jobs = [(test_A_files, 'A', self.test_fake_B),   # A -> B
            (test_B_files, 'B', self.test_fake_A)]   # B -> A

    # write html for visual comparison; `with` guarantees the file is closed on error
    index_path = os.path.join(self.result_dir, 'index.html')
    with open(index_path, 'w') as index:
        index.write("<html><body><table><tr>")
        index.write("<th>name</th><th>input</th><th>output</th></tr>\n")

        for files, domain, fake_tensor in jobs:
            for sample_file in files:
                print('Processing ' + domain + ' image: ' + sample_file)
                sample_image = np.asarray(load_test_data(sample_file, size=self.img_size))
                # splitext keeps multi-dot stems intact and tolerates a missing extension
                file_name, file_extension = os.path.splitext(os.path.basename(sample_file))

                for i in range(self.num_attribute):
                    image_path = os.path.join(self.result_dir, '{}_attribute{}{}'.format(file_name, i, file_extension))

                    fake_img = self.sess.run(fake_tensor, feed_dict={self.test_image: sample_image})
                    save_images(fake_img, [1, 1], image_path)

                    index.write("<tr><td>%s</td>" % html.escape(os.path.basename(image_path)))
                    index.write(_img_cell(sample_file))
                    index.write(_img_cell(image_path))
                    index.write("</tr>\n")

        index.write("</table></body></html>\n")
775 |
def guide_test(self):
    """Translate test images using the attribute of a single guide image.

    Inputs and outputs:
        - Reads ``self.guide_img`` as the attribute image.
        - If ``self.direction == 'a2b'``, every image in
          ``./dataset/<dataset_name>/testA`` is run through
          ``self.guide_fake_B``. Otherwise every image in ``testB`` is run
          through ``self.guide_fake_A``.
        - Outputs keep the input file name and are written to
          ``<result_dir>/<model_dir>/guide``; ``self.result_dir`` is updated
          to that path.
        - ``index.html`` there gets one row (name | input | output) per image.
    """
    import html  # local import: only used to escape text written into the report

    tf.global_variables_initializer().run()
    test_A_files = glob('./dataset/{}/*.*'.format(self.dataset_name + '/testA'))
    test_B_files = glob('./dataset/{}/*.*'.format(self.dataset_name + '/testB'))

    attribute_file = np.asarray(load_test_data(self.guide_img, size=self.img_size))

    self.saver = tf.train.Saver()
    could_load, _ = self.load(self.checkpoint_dir)
    self.result_dir = os.path.join(self.result_dir, self.model_dir, 'guide')
    check_folder(self.result_dir)

    if could_load:
        print(" [*] Load SUCCESS")
    else:
        print(" [!] Load failed...")

    # Any direction other than 'a2b' is treated as B -> A (unchanged behavior).
    if self.direction == 'a2b':
        domain, files, fake_tensor = 'A', test_A_files, self.guide_fake_B   # A -> B
    else:
        domain, files, fake_tensor = 'B', test_B_files, self.guide_fake_A   # B -> A

    def _img_cell(path):
        # Build one <td><img></td> cell. Relative paths are rebased with '../../..'
        # because index.html sits three directories below the working directory
        # (assuming result_dir is a single relative component like the default 'results').
        if not os.path.isabs(path):
            path = '../../..' + os.path.sep + path
        return "<td><img src='%s' width='%d' height='%d'></td>" % (
            html.escape(path, quote=True), self.img_size, self.img_size)

    # write html for visual comparison; `with` guarantees the file is closed on error
    index_path = os.path.join(self.result_dir, 'index.html')
    with open(index_path, 'w') as index:
        index.write("<html><body><table><tr>")
        index.write("<th>name</th><th>input</th><th>output</th></tr>\n")

        for sample_file in files:
            print('Processing ' + domain + ' image: ' + sample_file)
            sample_image = np.asarray(load_test_data(sample_file, size=self.img_size))
            image_path = os.path.join(self.result_dir, '{}'.format(os.path.basename(sample_file)))

            fake_img = self.sess.run(fake_tensor, feed_dict={self.content_image: sample_image, self.attribute_image: attribute_file})
            save_images(fake_img, [1, 1], image_path)

            index.write("<tr><td>%s</td>" % html.escape(os.path.basename(image_path)))
            index.write(_img_cell(sample_file))
            index.write(_img_cell(image_path))
            index.write("</tr>\n")

        index.write("</table></body></html>\n")
831 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Junho Kim (1993.01.12)
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DRIT-Tensorflow
2 | Simple Tensorflow implementation of [Diverse Image-to-Image Translation via Disentangled Representations](https://arxiv.org/abs/1808.00948) (ECCV 2018 Oral)
3 |
4 |
5 |
6 | ## Pytorch version
7 | * [Author_pytorch_code](https://github.com/HsinYingLee/DRIT)
8 |
9 | ## Requirements
10 | * Tensorflow 1.8
11 | * python 3.6
12 |
13 | ## Usage
14 | ### Download Dataset
15 | * [cat2dog](http://vllab.ucmerced.edu/hylee/DRIT/datasets/cat2dog)
16 | * [portrait](http://vllab.ucmerced.edu/hylee/DRIT/datasets/portrait)
17 | * [CycleGAN](https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/)
18 |
19 | ```
20 | ├── dataset
21 | └── YOUR_DATASET_NAME
22 | ├── trainA
23 | ├── xxx.jpg (name, format doesn't matter)
24 | ├── yyy.png
25 | └── ...
26 | ├── trainB
27 | ├── zzz.jpg
28 | ├── www.png
29 | └── ...
30 | ├── testA
31 | ├── aaa.jpg
32 | ├── bbb.png
33 | └── ...
34 | └── testB
35 | ├── ccc.jpg
36 | ├── ddd.png
37 | └── ...
38 |
39 | ├── guide.jpg (example for guided image translation task)
40 | ```
41 |
42 | ### Train
43 | ```
44 | python main.py --phase train --dataset summer2winter --concat True
45 | ```
46 |
47 | ### Test
48 | ```
49 | python main.py --phase test --dataset summer2winter --concat True --num_attribute 3
50 | ```
51 |
52 | ### Guide
53 | ```
54 | python main.py --phase guide --dataset summer2winter --concat True --direction a2b --guide_img ./guide.jpg
55 | ```
56 |
57 | ### Tips
58 | * --concat
59 | * `True` : for the **shape preserving translation** (summer <-> winter)
60 | * `False` : for the **shape variation translation** (cat <-> dog) **(default)**
61 |
62 | * --n_scale
63 | * Recommend `n_scale = 3` **(default)**
64 | * Using `n_scale > 1` (a.k.a. the multiscale discriminator) often gives better results
65 |
66 | * --n_dis
67 | * If you use the multi-discriminator, `n_dis = 4` is recommended **(default)**
68 | * If you don't use the multi-discriminator, `n_dis = 6` is recommended
69 |
70 | * --n_d_con
71 | * The authors use `n_d_con = 3` **(default)**
72 | * Model can still generate diverse results with `n_d_con = 1`
73 |
74 | * --num_attribute **(only for the test phase)**
75 | * If you use `num_attribute > 1`, several diverse output images are generated for each input image
76 |
77 | ## Summary
78 | ### Comparison
79 | 
80 |
81 | ### Architecture
82 | 
83 | 
84 |
85 | ### Train phase
86 | 
87 | 
88 |
89 | ### Test & Guide phase
90 | 
91 |
92 | ## Results
93 | 
94 | 
95 |
96 | ## Related works
97 | * [UNIT-Tensorflow](https://github.com/taki0112/UNIT-Tensorflow)
98 | * [MUNIT-Tensorflow](https://github.com/taki0112/MUNIT-Tensorflow)
99 |
100 | ## Author
101 | Junho Kim
102 |
--------------------------------------------------------------------------------
/assets/comparison.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/DRIT-Tensorflow/384f6aac3e91898ee400c57418d2e7b3d6df1916/assets/comparison.png
--------------------------------------------------------------------------------
/assets/false.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/DRIT-Tensorflow/384f6aac3e91898ee400c57418d2e7b3d6df1916/assets/false.png
--------------------------------------------------------------------------------
/assets/final.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/DRIT-Tensorflow/384f6aac3e91898ee400c57418d2e7b3d6df1916/assets/final.gif
--------------------------------------------------------------------------------
/assets/result1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/DRIT-Tensorflow/384f6aac3e91898ee400c57418d2e7b3d6df1916/assets/result1.png
--------------------------------------------------------------------------------
/assets/result2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/DRIT-Tensorflow/384f6aac3e91898ee400c57418d2e7b3d6df1916/assets/result2.png
--------------------------------------------------------------------------------
/assets/test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/DRIT-Tensorflow/384f6aac3e91898ee400c57418d2e7b3d6df1916/assets/test.png
--------------------------------------------------------------------------------
/assets/test_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/DRIT-Tensorflow/384f6aac3e91898ee400c57418d2e7b3d6df1916/assets/test_1.png
--------------------------------------------------------------------------------
/assets/test_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/DRIT-Tensorflow/384f6aac3e91898ee400c57418d2e7b3d6df1916/assets/test_2.png
--------------------------------------------------------------------------------
/assets/train_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/DRIT-Tensorflow/384f6aac3e91898ee400c57418d2e7b3d6df1916/assets/train_1.png
--------------------------------------------------------------------------------
/assets/train_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/DRIT-Tensorflow/384f6aac3e91898ee400c57418d2e7b3d6df1916/assets/train_2.png
--------------------------------------------------------------------------------
/assets/true.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/DRIT-Tensorflow/384f6aac3e91898ee400c57418d2e7b3d6df1916/assets/true.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from DRIT import DRIT
2 | import argparse
3 | from utils import *
4 |
5 | """parsing and configuration"""
def parse_args():
    """Build the DRIT command-line parser, parse ``sys.argv`` and check the result.

    Returns:
        Whatever ``check_args`` returns for the parsed ``argparse.Namespace``.
    """
    desc = "Tensorflow implementation of DRIT"

    parser = argparse.ArgumentParser(description=desc)
    parser.add_argument('--phase', type=str, default='train', choices=['train', 'test', 'guide'],
                        help='[train, test, guide]')
    parser.add_argument('--dataset', type=str, default='cat2dog', help='dataset_name')
    parser.add_argument('--augment_flag', type=str2bool, default=True, help='Image augmentation use or not')
    parser.add_argument('--decay_flag', type=str2bool, default=True, help='using learning rate decay')

    parser.add_argument('--epoch', type=int, default=50, help='The number of epochs to run')
    parser.add_argument('--decay_epoch', type=int, default=25, help='The number of decay epochs to run')
    parser.add_argument('--iteration', type=int, default=10000, help='The number of training iterations')
    parser.add_argument('--batch_size', type=int, default=1, help='The batch size')
    parser.add_argument('--print_freq', type=int, default=1000, help='The number of image_print_freq')
    parser.add_argument('--save_freq', type=int, default=1000, help='The number of ckpt_save_freq')

    parser.add_argument('--num_attribute', type=int, default=3, help='number of attributes to sample')
    parser.add_argument('--direction', type=str, default='a2b', help='direction of guided image translation')
    parser.add_argument('--guide_img', type=str, default='guide.jpg', help='Style guided image translation')

    parser.add_argument('--gan_type', type=str, default='gan', choices=['gan', 'lsgan'],
                        help='GAN loss type [gan / lsgan]')

    # Loss weights are parsed as float so fractional weights (e.g. 2.5) are accepted;
    # defaults are unchanged.
    parser.add_argument('--lr', type=float, default=0.0001, help='The learning rate')
    parser.add_argument('--content_adv_w', type=float, default=1, help='weight of content adversarial loss')
    parser.add_argument('--domain_adv_w', type=float, default=1, help='weight of domain adversarial loss')
    parser.add_argument('--cycle_w', type=float, default=10, help='weight of cross-cycle reconstruction loss')
    parser.add_argument('--recon_w', type=float, default=10, help='weight of self-reconstruction loss')
    parser.add_argument('--latent_w', type=float, default=10, help='weight of latent regression loss')
    parser.add_argument('--kl_w', type=float, default=0.01, help='weight of kl-divergence loss')

    parser.add_argument('--ch', type=int, default=64, help='base channel number per layer')
    # concat = False : for the shape variation translation (cat <-> dog)
    # concat = True  : for the shape preserving translation (winter <-> summer)
    parser.add_argument('--concat', type=str2bool, default=False, help='using concat networks')

    parser.add_argument('--n_z', type=int, default=8, help='length of z')
    parser.add_argument('--n_layer', type=int, default=4, help='number of layers in G, D')

    # If you don't use the multi-discriminator, n_dis = 6 is recommended.
    parser.add_argument('--n_dis', type=int, default=4, help='number of discriminator layer')

    # Using the multiscale discriminator often gives better results.
    parser.add_argument('--n_scale', type=int, default=3, help='number of scales for discriminator')

    # The model can still generate diverse results with n_d_con = 1.
    parser.add_argument('--n_d_con', type=int, default=3, help='# of iterations for updating content discriminator')

    parser.add_argument('--sn', type=str2bool, default=False, help='using spectral normalization')

    parser.add_argument('--img_size', type=int, default=256, help='The size of image')
    parser.add_argument('--img_ch', type=int, default=3, help='The size of image channel')

    parser.add_argument('--checkpoint_dir', type=str, default='checkpoint',
                        help='Directory name to save the checkpoints')
    parser.add_argument('--result_dir', type=str, default='results',
                        help='Directory name to save the generated images')
    parser.add_argument('--log_dir', type=str, default='logs',
                        help='Directory name to save training logs')
    parser.add_argument('--sample_dir', type=str, default='samples',
                        help='Directory name to save the samples on training')

    return check_args(parser.parse_args())
72 |
73 | """checking arguments"""
def check_args(args):
    """Validate parsed arguments, then create the output directories.

    Args:
        args: parsed namespace; reads ``epoch``, ``batch_size``,
            ``checkpoint_dir``, ``result_dir``, ``log_dir`` and ``sample_dir``.

    Returns:
        ``args`` when every check passes. Otherwise the reason is printed and
        ``None`` is returned, so the caller can abort. Previously the checks
        only printed a message and bad values were still returned.
    """
    # Explicit checks instead of try/assert/bare-except: asserts are stripped
    # under `python -O`. Validating first also avoids creating folders for an
    # invalid configuration.
    if args.epoch < 1:
        print('number of epochs must be larger than or equal to one')
        return None

    if args.batch_size < 1:
        print('batch size must be larger than or equal to one')
        return None

    # --checkpoint_dir, --result_dir, --log_dir, --sample_dir
    for folder in (args.checkpoint_dir, args.result_dir, args.log_dir, args.sample_dir):
        check_folder(folder)

    return args
99 |
100 | """main"""
def main():
    """Parse arguments, build the DRIT graph and run the requested phase.

    Exits with status 1 when argument checking fails or the phase is unknown,
    before any graph is built.
    """
    import sys  # local import: only used for a non-zero exit status

    # parse arguments
    args = parse_args()
    if args is None:
        sys.exit(1)

    # phase -> (DRIT method name, completion message)
    phase_actions = {
        'train': ('train', " [*] Training finished!"),
        'test': ('test', " [*] Test finished!"),
        'guide': ('guide_test', " [*] Guide finished!"),
    }
    if args.phase not in phase_actions:
        print(" [!] Unknown phase '{}', expected one of [train, test, guide]".format(args.phase))
        sys.exit(1)

    # open session
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        gan = DRIT(sess, args)

        # build graph
        gan.build_model()

        # show network architecture
        show_all_variables()

        method_name, done_message = phase_actions[args.phase]
        getattr(gan, method_name)()
        print(done_message)


if __name__ == '__main__':
    main()
132 |
--------------------------------------------------------------------------------
/ops.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import tensorflow.contrib as tf_contrib
3 |
# Alternative initializers / regularizers that can be swapped into the
# assignments below:
#   Xavier   : tf_contrib.layers.xavier_initializer()
#   He       : tf_contrib.layers.variance_scaling_initializer()
#   Normal   : tf.random_normal_initializer(mean=0.0, stddev=0.02)
#   l2_decay : tf_contrib.layers.l2_regularizer(0.0001)

# Kernel initializer used by conv / deconv / fully_conneted: N(0, 0.02^2).
weight_init = tf.random_normal_initializer(stddev=0.02, mean=0.0)

# Kernel regularizer used by the same layers; None disables weight decay.
weight_regularizer = None
11 |
12 | ##################################################################################
13 | # Layer
14 | ##################################################################################
15 |
def conv(x, channels, kernel=4, stride=2, pad=0, pad_type='zero', use_bias=True, sn=False, scope='conv'):
    """2-D convolution with explicit padding and optional spectral normalization.

    Args:
        x: input tensor. Axes 1 and 2 are padded and the last axis is the input
            channel count, i.e. NHWC.
        channels: number of output channels.
        kernel: square kernel size.
        stride: spatial stride.
        pad: padding added before a VALID convolution. When ``kernel - stride``
            is odd, the bottom/right side receives ``kernel - stride - pad``.
        pad_type: 'zero' or 'reflect' (only consulted when ``pad > 0``).
        use_bias: add a learned bias.
        sn: create the kernel manually and normalize it with ``spectral_norm``.
        scope: variable scope name.

    Raises:
        ValueError: if ``pad > 0`` and ``pad_type`` is not 'zero' or 'reflect'
            (previously padding was silently skipped in that case).
    """
    with tf.variable_scope(scope):
        if pad > 0:
            if pad_type == 'zero':
                pad_mode = 'CONSTANT'
            elif pad_type == 'reflect':
                pad_mode = 'REFLECT'
            else:
                raise ValueError("Unsupported pad_type: {!r} (expected 'zero' or 'reflect')".format(pad_type))

            # Symmetric padding when (kernel - stride) is even, otherwise the
            # trailing side is shrunk so the output size matches.
            if (kernel - stride) % 2 == 0:
                pad_before, pad_after = pad, pad
            else:
                pad_before, pad_after = pad, kernel - stride - pad

            x = tf.pad(x, [[0, 0], [pad_before, pad_after], [pad_before, pad_after], [0, 0]], mode=pad_mode)

        if sn:
            w = tf.get_variable("kernel", shape=[kernel, kernel, x.get_shape()[-1], channels], initializer=weight_init, regularizer=weight_regularizer)
            x = tf.nn.conv2d(input=x, filter=spectral_norm(w), strides=[1, stride, stride, 1], padding='VALID')
            if use_bias:
                bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0))
                x = tf.nn.bias_add(x, bias)
        else:
            # tf.layers.conv2d defaults to 'valid' padding; explicit padding was applied above.
            x = tf.layers.conv2d(inputs=x, filters=channels,
                                 kernel_size=kernel, kernel_initializer=weight_init,
                                 kernel_regularizer=weight_regularizer,
                                 strides=stride, use_bias=use_bias)

        return x
51 |
def deconv(x, channels, kernel=4, stride=2, padding='SAME', use_bias=True, sn=False, scope='deconv'):
    """Transposed 2-D convolution with optional spectral normalization.

    With 'SAME' padding the spatial size (axes 1 and 2) is multiplied by
    ``stride``. Otherwise ``max(kernel - stride, 0)`` is added on top of that.
    The explicit output shape is only consumed by the spectral-norm path.
    """
    with tf.variable_scope(scope):
        in_shape = x.get_shape().as_list()

        extra = 0 if padding == 'SAME' else max(kernel - stride, 0)
        output_shape = [in_shape[0],
                        in_shape[1] * stride + extra,
                        in_shape[2] * stride + extra,
                        channels]

        if not sn:
            return tf.layers.conv2d_transpose(inputs=x, filters=channels,
                                              kernel_size=kernel, kernel_initializer=weight_init,
                                              kernel_regularizer=weight_regularizer,
                                              strides=stride, padding=padding, use_bias=use_bias)

        # conv2d_transpose kernels are laid out as [h, w, out_channels, in_channels].
        w = tf.get_variable("kernel", shape=[kernel, kernel, channels, x.get_shape()[-1]], initializer=weight_init, regularizer=weight_regularizer)
        out = tf.nn.conv2d_transpose(x, filter=spectral_norm(w), output_shape=output_shape, strides=[1, stride, stride, 1], padding=padding)

        if use_bias:
            bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0))
            out = tf.nn.bias_add(out, bias)

        return out
76 |
def fully_conneted(x, channels, use_bias=True, sn=False, scope='fully'):
    """Flatten ``x`` and apply a dense layer with ``channels`` outputs.

    The (misspelled) name is kept because callers reference it. With ``sn`` the
    kernel is created manually and normalized by ``spectral_norm``.
    """
    with tf.variable_scope(scope):
        flat = tf.layers.flatten(x)
        in_dim = flat.get_shape().as_list()[-1]

        if not sn:
            return tf.layers.dense(flat, units=channels, kernel_initializer=weight_init, kernel_regularizer=weight_regularizer, use_bias=use_bias)

        w = tf.get_variable("kernel", [in_dim, channels], tf.float32, initializer=weight_init, regularizer=weight_regularizer)
        if use_bias:
            bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0))
            return tf.matmul(flat, spectral_norm(w)) + bias
        return tf.matmul(flat, spectral_norm(w))
96 |
def gaussian_noise_layer(x, is_training=False):
    """Add standard-normal noise to ``x`` when training; otherwise return ``x`` unchanged."""
    if not is_training:
        return x
    noise = tf.random_normal(shape=tf.shape(x), mean=0.0, stddev=1.0, dtype=tf.float32)
    return x + noise
104 |
105 | ##################################################################################
106 | # Block
107 | ##################################################################################
108 |
def resblock(x_init, channels, use_bias=True, sn=False, scope='resblock'):
    """Residual block with an identity skip connection.

    Two sub-scopes ('res1', 'res2') each apply a reflect-padded 3x3 stride-1
    conv and instance norm. ReLU follows only the first one. The result is
    added to ``x_init``.
    """
    with tf.variable_scope(scope):
        h = x_init
        for sub_scope in ('res1', 'res2'):
            with tf.variable_scope(sub_scope):
                h = conv(h, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias, sn=sn)
                h = instance_norm(h)
                if sub_scope == 'res1':
                    h = relu(h)

        return h + x_init
121 |
def basic_block(x_init, channels, use_bias=True, sn=False, scope='basic_block'):
    """Downsampling residual block.

    Main path: lrelu(0.2) -> 3x3 conv -> lrelu(0.2) -> ``conv_avg`` (conv then
    2x2 average pool).
    Shortcut: ``avg_conv`` (2x2 average pool then 1x1 conv).
    The two paths are summed.
    """
    with tf.variable_scope(scope):
        main = conv(lrelu(x_init, 0.2), channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias, sn=sn)
        main = conv_avg(lrelu(main, 0.2), channels, use_bias=use_bias, sn=sn)

        skip = avg_conv(x_init, channels, use_bias=use_bias, sn=sn)

        return main + skip
133 |
def mis_resblock(x_init, z, channels, use_bias=True, sn=False, scope='mis_resblock'):
    """Residual block that injects the code ``z`` at every spatial position.

    ``z`` is reshaped to (batch, 1, 1, z_dim) and tiled to the height/width of
    ``x_init``. Each of the sub-scopes 'mis1' and 'mis2' then runs:

        1. reflect-padded 3x3 conv + instance norm,
        2. concatenation with the tiled ``z``,
        3. 1x1 conv (2*channels) + ReLU,
        4. 1x1 conv (channels) + ReLU.

    The output is added to ``x_init``.
    """
    with tf.variable_scope(scope):
        z_grid = tf.reshape(z, shape=[-1, 1, 1, z.shape[-1]])
        z_grid = tf.tile(z_grid, multiples=[1, x_init.shape[1], x_init.shape[2], 1])  # expand

        h = x_init
        for sub_scope in ('mis1', 'mis2'):
            with tf.variable_scope(sub_scope):
                h = conv(h, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias, sn=sn, scope='conv3x3')
                h = instance_norm(h)

                h = tf.concat([h, z_grid], axis=-1)
                h = relu(conv(h, channels * 2, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='conv1x1_0'))
                h = relu(conv(h, channels, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='conv1x1_1'))

        return h + x_init
162 |
def avg_conv(x, channels, use_bias=True, sn=False, scope='avg_conv'):
    """2x2 stride-2 average pool followed by a 1x1 stride-1 conv."""
    with tf.variable_scope(scope):
        pooled = avg_pooling(x, kernel=2, stride=2)
        return conv(pooled, channels, kernel=1, stride=1, use_bias=use_bias, sn=sn)
169 |
def conv_avg(x, channels, use_bias=True, sn=False, scope='conv_avg'):
    """Reflect-padded 3x3 stride-1 conv followed by a 2x2 stride-2 average pool."""
    with tf.variable_scope(scope):
        features = conv(x, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias, sn=sn)
        return avg_pooling(features, kernel=2, stride=2)
176 |
def expand_concat(x, z):
    """Tile ``z`` over the spatial grid of ``x`` and append it on the last axis.

    ``z`` is reshaped to (z.shape[0], 1, 1, -1), so its leading dimension must be
    statically known.
    """
    grid_shape = [1, x.shape[1], x.shape[2], 1]
    z_map = tf.tile(tf.reshape(z, shape=[z.shape[0], 1, 1, -1]), multiples=grid_shape)  # expand
    return tf.concat([x, z_map], axis=-1)
183 |
184 | ##################################################################################
185 | # Sampling
186 | ##################################################################################
187 |
def down_sample(x):
    """3x3, stride-2 average pool with ``pad=1`` (padding handled by ``avg_pooling``)."""
    return avg_pooling(x, pad=1, stride=2, kernel=3)
190 |
def avg_pooling(x, kernel=2, stride=2, pad=0):
    """Average pooling over axes 1 and 2 with optional explicit zero padding.

    When ``pad > 0``, the leading side is padded by ``pad``. The trailing side
    is also padded by ``pad`` if ``kernel - stride`` is even, otherwise by
    ``kernel - stride - pad``. Pooling itself uses VALID padding.
    """
    if pad > 0:
        even_gap = (kernel - stride) % 2 == 0
        lead = pad
        trail = pad if even_gap else kernel - stride - pad
        x = tf.pad(x, [[0, 0], [lead, trail], [lead, trail], [0, 0]])

    return tf.layers.average_pooling2d(x, pool_size=kernel, strides=stride, padding='VALID')
208 |
def global_avg_pooling(x):
    """Mean over axes 1 and 2, keeping them as size-1 dimensions."""
    return tf.reduce_mean(x, axis=[1, 2], keepdims=True)
213 |
def z_sample(mean, logvar):
    """Reparameterized sample: ``mean + exp(logvar / 2) * eps`` with ``eps ~ N(0, 1)``."""
    eps = tf.random_normal(shape=tf.shape(mean), mean=0.0, stddev=1.0, dtype=tf.float32)
    std = tf.exp(logvar * 0.5)
    return mean + std * eps
218 |
219 | ##################################################################################
220 | # Activation function
221 | ##################################################################################
222 |
def lrelu(x, alpha=0.01):
    """Leaky ReLU; the 0.01 default mirrors PyTorch's default negative slope."""
    return tf.nn.leaky_relu(x, alpha=alpha)
226 |
227 |
def relu(x):
    """Element-wise ReLU."""
    activated = tf.nn.relu(x)
    return activated
230 |
231 |
def tanh(x):
    """Element-wise hyperbolic tangent."""
    squashed = tf.tanh(x)
    return squashed
234 |
235 | ##################################################################################
236 | # Normalization function
237 | ##################################################################################
238 |
def instance_norm(x, scope='instance_norm'):
    """Instance normalization with learned offset and scale (epsilon 1e-5)."""
    return tf_contrib.layers.instance_norm(x, scope=scope, center=True, scale=True, epsilon=1e-05)
244 |
def layer_norm(x, scope='layer_norm'):
    """Layer normalization with learned offset and scale."""
    return tf_contrib.layers.layer_norm(x, scope=scope, scale=True, center=True)
249 |
def spectral_norm(w, iteration=1):
    """Divide ``w`` by a power-iteration estimate of its largest singular value.

    ``w`` is viewed as a 2-D matrix of shape (-1, last_dim). A non-trainable
    variable ``u`` of shape (1, last_dim), created in the current variable
    scope, stores the singular-vector estimate between runs. A control
    dependency writes the new estimate back to ``u`` whenever the normalized
    weight is computed.

    ``iteration`` power-iteration steps are run; 1 is usually enough. The
    estimates are wrapped in ``stop_gradient``. The result has the original
    shape of ``w``.
    """
    original_shape = w.shape.as_list()
    w_mat = tf.reshape(w, [-1, original_shape[-1]])

    u = tf.get_variable("u", [1, original_shape[-1]], initializer=tf.random_normal_initializer(), trainable=False)

    u_hat, v_hat = u, None
    for _ in range(iteration):
        # power iteration
        v_hat = tf.nn.l2_normalize(tf.matmul(u_hat, tf.transpose(w_mat)))
        u_hat = tf.nn.l2_normalize(tf.matmul(v_hat, w_mat))

    u_hat = tf.stop_gradient(u_hat)
    v_hat = tf.stop_gradient(v_hat)

    sigma = tf.matmul(tf.matmul(v_hat, w_mat), tf.transpose(u_hat))

    with tf.control_dependencies([u.assign(u_hat)]):
        normalized = w_mat / sigma
        normalized = tf.reshape(normalized, original_shape)

    return normalized
280 |
281 | ##################################################################################
282 | # Loss function
283 | ##################################################################################
284 |
def discriminator_loss(type, real, fake, fake_random=None, content=False):
    """Sum of per-scale discriminator losses.

    Args:
        type: 'gan' (sigmoid cross-entropy on logits) or 'lsgan' (least squares).
        real: list of per-scale discriminator outputs on real data.
        fake: list of per-scale outputs on generated data (indexed like ``real``).
        fake_random: per-scale outputs for a second set of generated data; used
            only when ``content`` is False.
        content: if True, each scale contributes ``real_loss + fake_loss``.
            Otherwise it contributes ``2 * real_loss + fake_loss + fake_random_loss``.

    Returns:
        The sum over scales (0 when ``real`` is empty).

    Raises:
        ValueError: if ``type`` is not 'gan' or 'lsgan'. Previously an unknown
            type silently contributed zero loss.
    """
    if type not in ('gan', 'lsgan'):
        raise ValueError("Unsupported GAN loss type: {!r} (expected 'gan' or 'lsgan')".format(type))

    def _real_loss(logit):
        # Target 1 for real samples.
        if type == 'lsgan':
            return tf.reduce_mean(tf.squared_difference(logit, 1.0))
        return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(logit), logits=logit))

    def _fake_loss(logit):
        # Target 0 for generated samples.
        if type == 'lsgan':
            return tf.reduce_mean(tf.square(logit))
        return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(logit), logits=logit))

    loss = []
    for i in range(len(real)):
        if content:
            loss.append(_real_loss(real[i]) + _fake_loss(fake[i]))
        else:
            loss.append(_real_loss(real[i]) * 2 + _fake_loss(fake[i]) + _fake_loss(fake_random[i]))

    return sum(loss)
320 |
321 |
def generator_loss(type, fake, content=False):
    """Sum of per-scale adversarial losses for the generator side.

    Args:
        type: 'gan' (sigmoid cross-entropy on logits) or 'lsgan' (least squares).
        fake: list of per-scale discriminator outputs on generated data.
        content: if True, the target is 0.5 instead of 1.0. This presumably
            pushes the content discriminator towards not being able to tell the
            domains apart; confirm against the DRIT paper.

    Returns:
        The sum over scales (0 when ``fake`` is empty).

    Raises:
        ValueError: if ``type`` is not 'gan' or 'lsgan'. Previously an unknown
            type silently produced zero loss.
    """
    if type not in ('gan', 'lsgan'):
        raise ValueError("Unsupported GAN loss type: {!r} (expected 'gan' or 'lsgan')".format(type))

    target = 0.5 if content else 1.0

    loss = []
    for logit in fake:
        if type == 'lsgan':
            loss.append(tf.reduce_mean(tf.squared_difference(logit, target)))
        else:
            labels = target * tf.ones_like(logit) if content else tf.ones_like(logit)
            loss.append(tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logit)))

    return sum(loss)
349 |
350 |
def l2_regularize(x):
    """Mean of the squared entries of ``x``."""
    return tf.reduce_mean(tf.square(x))
355 |
def kl_loss(mu, logvar):
    """KL divergence of N(mu, exp(logvar)) from N(0, I).

    The divergence is summed over the last axis and averaged over all
    remaining elements.
    """
    per_item = 0.5 * tf.reduce_sum(tf.square(mu) + tf.exp(logvar) - 1 - logvar, axis=-1)
    return tf.reduce_mean(per_item)
362 |
def L1_loss(x, y):
    """Return the mean absolute difference between ``x`` and ``y``."""
    diff = x - y
    return tf.reduce_mean(tf.abs(diff))
367 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.contrib import slim
3 | from scipy import misc
4 | import os, random
5 | import numpy as np
6 |
class ImageData:
    """Builds TensorFlow ops that load, normalize and optionally augment images.

    Constructor args:
        img_size: height and width every image is resized to.
        channels: channel count requested from ``tf.image.decode_jpeg``.
        augment_flag: when True, ``image_processing`` applies ``augmentation``
            to roughly half of the images.
    """

    def __init__(self, img_size, channels, augment_flag=False):
        self.img_size = img_size
        self.channels = channels
        self.augment_flag = augment_flag

    def image_processing(self, filename):
        """Read one image file and return it as a float32 tensor in [-1, 1].

        Args:
            filename: string tensor holding the path of a JPEG file.

        Returns:
            Tensor of shape [img_size, img_size, channels] scaled to [-1, 1].
        """
        x = tf.read_file(filename)
        x_decode = tf.image.decode_jpeg(x, channels=self.channels)
        img = tf.image.resize_images(x_decode, [self.img_size, self.img_size])
        img = tf.cast(img, tf.float32) / 127.5 - 1

        if self.augment_flag:
            # Images are enlarged before the random crop; small ones go to 256.
            augment_size = 256 if self.img_size < 256 else self.img_size + 30

            # The coin flip must be a graph op. This method only *builds* ops,
            # so a Python random.random() here would be evaluated once at
            # graph-construction time. The same decision (always or never
            # augment) would then apply to every image the graph processes.
            resized = img
            coin = tf.random_uniform([], minval=0.0, maxval=1.0)
            img = tf.cond(coin > 0.5,
                          lambda: augmentation(resized, augment_size),
                          lambda: resized)
            # tf.cond can weaken the static shape when the branches differ;
            # restore it so downstream layers still see known dimensions.
            img.set_shape(resized.get_shape())

        return img
30 |
31 |
def load_test_data(image_path, size=256):
    """Load one image as RGB, resize it and normalize it as a batch of one.

    Args:
        image_path: path of the image file.
        size: output height and width in pixels.

    Returns:
        Array of shape [1, size, size, 3] scaled by ``preprocessing``.
    """
    # NOTE(review): scipy.misc.imread / imresize were removed from recent
    # SciPy releases; this requires an old SciPy version.
    rgb = misc.imread(image_path, mode='RGB')
    rgb = misc.imresize(rgb, [size, size])
    batch = rgb[np.newaxis, ...]
    return preprocessing(batch)
39 |
def preprocessing(x):
    """Rescale pixel values from [0, 255] to [-1, 1]."""
    scaled = x / 127.5
    return scaled - 1
43 |
def augmentation(image, aug_img_size):
    """Randomly mirror ``image``, enlarge it, then crop back to its original shape.

    Args:
        image: image tensor (presumably [height, width, channels] -- confirm).
        aug_img_size: side length the image is resized to before cropping.

    Returns:
        Tensor whose dynamic shape equals that of ``image``.
    """
    op_seed = random.randint(0, 2 ** 31 - 1)
    original_shape = tf.shape(image)
    # The same op-level seed is shared by the flip and the crop.
    flipped = tf.image.random_flip_left_right(image, seed=op_seed)
    enlarged = tf.image.resize_images(flipped, [aug_img_size, aug_img_size])
    return tf.random_crop(enlarged, original_shape, seed=op_seed)
51 |
def save_images(images, size, image_path):
    """Map ``images`` from [-1, 1] to [0, 1], tile them as a ``size`` grid and write to ``image_path``."""
    rescaled = inverse_transform(images)
    return imsave(rescaled, size, image_path)
54 |
def inverse_transform(images):
    """Map values from [-1, 1] to [0, 1]."""
    shifted = images + 1.
    return shifted / 2
57 |
def imsave(images, size, path):
    """Tile ``images`` into a ``size`` grid with ``merge`` and write it to ``path``."""
    # NOTE(review): scipy.misc.imsave was removed in SciPy 1.2, so this needs
    # an older SciPy. It also rescales float input to the 0-255 range itself.
    grid = merge(images, size)
    return misc.imsave(path, grid)
60 |
def merge(images, size):
    """Tile a batch of images into one grid image, filling row by row.

    Args:
        images: array of shape [N, h, w, C]. Each tile is written into a
            3-channel canvas, so C must broadcast to 3.
        size: (rows, cols) of the grid.

    Returns:
        float64 array of shape [h * rows, w * cols, 3]. Unused cells stay 0.
    """
    height, width = images.shape[1:3]
    rows, cols = size[0], size[1]
    canvas = np.zeros((height * rows, width * cols, 3))
    for index, tile in enumerate(images):
        row, col = divmod(index, cols)
        top, left = row * height, col * width
        canvas[top:top + height, left:left + width, :] = tile
    return canvas
70 |
def show_all_variables():
    """Print a summary of every trainable variable using slim's model analyzer."""
    slim.model_analyzer.analyze_vars(tf.trainable_variables(), print_info=True)
74 |
def check_folder(log_dir):
    """Ensure directory ``log_dir`` exists, creating parents as needed, and return it.

    Raises:
        FileExistsError: if ``log_dir`` exists but is not a directory.
    """
    # exist_ok=True avoids the race between a separate exists() check and
    # makedirs() when several processes create the same directory.
    os.makedirs(log_dir, exist_ok=True)
    return log_dir
79 |
def str2bool(x):
    """Parse a command-line boolean: True only for 'true' (case-insensitive).

    Values that are already ``bool`` are returned unchanged. Any other
    string is False.
    """
    if isinstance(x, bool):
        return x
    # Exact comparison: ``('true')`` is a plain string, not a tuple, so an
    # ``in`` test against it accepted any substring such as '', 't' or 'rue'.
    return x.lower() == 'true'
--------------------------------------------------------------------------------