├── .DS_Store
├── .gitignore
├── DatasetAPI
│   ├── UNIT.py
│   ├── UNIT_multi_gpu.py
│   ├── main.py
│   ├── main_multi_gpu.py
│   ├── ops.py
│   └── utils.py
├── LICENSE
├── README.md
├── UNIT.py
├── UNIT_multi_gpu.py
├── assests
│   ├── .DS_Store
│   ├── architecture.png
│   ├── cat_species.gif
│   ├── cat_trans.png
│   ├── compare.png
│   ├── cycle.png
│   ├── dog_breed.gif
│   ├── dog_trans.png
│   ├── faces.png
│   ├── fail.png
│   ├── framework.png
│   ├── gan_model.png
│   ├── slide
│   │   ├── compare.png
│   │   ├── cycle.png
│   │   ├── framework.png
│   │   ├── gan_model.png
│   │   ├── training_objective.png
│   │   └── vae_model.png
│   ├── success.png
│   ├── training_objective__.png
│   └── vae_model.png
├── main.py
├── main_multi_gpu.py
├── ops.py
└── utils.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/UNIT-Tensorflow/4d7430a6f0bd3bea72d821e14db6e6442c02ed32/.DS_Store
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 |
49 | # Translations
50 | *.mo
51 | *.pot
52 |
53 | # Django stuff:
54 | *.log
55 | local_settings.py
56 |
57 | # Flask stuff:
58 | instance/
59 | .webassets-cache
60 |
61 | # Scrapy stuff:
62 | .scrapy
63 |
64 | # Sphinx documentation
65 | docs/_build/
66 |
67 | # PyBuilder
68 | target/
69 |
70 | # Jupyter Notebook
71 | .ipynb_checkpoints
72 |
73 | # pyenv
74 | .python-version
75 |
76 | # celery beat schedule file
77 | celerybeat-schedule
78 |
79 | # SageMath parsed files
80 | *.sage.py
81 |
82 | # dotenv
83 | .env
84 |
85 | # virtualenv
86 | .venv
87 | venv/
88 | ENV/
89 |
90 | # Spyder project settings
91 | .spyderproject
92 | .spyproject
93 |
94 | # Rope project settings
95 | .ropeproject
96 |
97 | # mkdocs documentation
98 | /site
99 |
100 | # mypy
101 | .mypy_cache/
102 |
--------------------------------------------------------------------------------
/DatasetAPI/UNIT.py:
--------------------------------------------------------------------------------
1 | from ops import *
2 | from utils import *
3 | from glob import glob
4 | import time
5 | from tensorflow.contrib.data import batch_and_drop_remainder
6 |
7 | class UNIT(object) :
8 | def __init__(self, sess, args):
9 | self.model_name = 'UNIT'
10 | self.sess = sess
11 | self.checkpoint_dir = args.checkpoint_dir
12 | self.result_dir = args.result_dir
13 | self.log_dir = args.log_dir
14 | self.sample_dir = args.sample_dir
15 | self.dataset_name = args.dataset
16 | self.augment_flag = args.augment_flag
17 |
18 | self.epoch = args.epoch
19 | self.iteration = args.iteration
20 | self.gan_type = args.gan_type
21 |
22 | self.batch_size = args.batch_size
23 | self.print_freq = args.print_freq
24 | self.save_freq = args.save_freq
25 |
26 | self.img_size = args.img_size
27 | self.img_ch = args.img_ch
28 |
29 | self.init_lr = args.lr
30 | self.ch = args.ch
31 |
32 | """ Weight about VAE """
33 | self.KL_weight = args.KL_weight # lambda 1
34 | self.L1_weight = args.L1_weight # lambda 2
35 |
36 | """ Weight about VAE Cycle"""
37 | self.KL_cycle_weight = args.KL_cycle_weight # lambda 3
38 | self.L1_cycle_weight = args.L1_cycle_weight # lambda 4
39 |
40 | """ Weight about GAN """
41 | self.GAN_weight = args.GAN_weight # lambda 0
42 |
43 | """ Encoder """
44 | self.n_encoder = args.n_encoder
45 | self.n_enc_resblock = args.n_enc_resblock
46 | self.n_enc_share = args.n_enc_share
47 |
48 | """ Generator """
49 | self.n_gen_share = args.n_gen_share
50 | self.n_gen_resblock = args.n_gen_resblock
51 | self.n_gen_decoder = args.n_gen_decoder
52 |
53 | """ Discriminator """
54 | self.n_dis = args.n_dis
55 |
56 | self.sample_dir = os.path.join(args.sample_dir, self.model_dir)
57 | check_folder(self.sample_dir)
58 |
59 | self.trainA_dataset = glob('./dataset/{}/*.*'.format(self.dataset_name + '/trainA'))
60 | self.trainB_dataset = glob('./dataset/{}/*.*'.format(self.dataset_name + '/trainB'))
61 | self.dataset_num = max(len(self.trainA_dataset), len(self.trainB_dataset))
62 |
63 | print("##### Information #####")
64 | print("# gan type : ", self.gan_type)
65 | print("# dataset : ", self.dataset_name)
66 | print("# max dataset number : ", self.dataset_num)
67 | print("# batch_size : ", self.batch_size)
68 | print("# epoch : ", self.epoch)
69 | print("# iteration per epoch : ", self.iteration)
70 |
71 | print()
72 |
73 | print("##### Encoder #####")
74 | print("# encoder blocks : ", self.n_encoder)
75 | print("# encoder resblock : ", self.n_enc_resblock)
76 | print("# encoder share : ", self.n_enc_share)
77 |
78 | print()
79 |
80 | print("##### Decoder #####")
81 | print("# decoder share : ", self.n_gen_share)
82 | print("# decoder resblock : ", self.n_gen_resblock)
83 | print("# decoder blocks : ", self.n_gen_decoder)
84 |
85 | print()
86 |
87 | print("##### Discriminator #####")
88 | print("# Discriminator layer : ", self.n_dis)
89 |
90 | ##############################################################################
91 | # BEGIN of ENCODERS
92 |
93 | def encoder(self, x, reuse=False, scope="encoder"):
94 | channel = self.ch
95 | with tf.variable_scope(scope, reuse=reuse):
96 | x = conv(x, channel, kernel=7, stride=1, pad=3, scope='conv_0')
97 | x = lrelu(x, 0.01)
98 |
99 | for i in range(1, self.n_encoder):
100 | x = conv(x, channel * 2, kernel=3, stride=2, pad=1, scope='conv_' + str(i))
101 | x = lrelu(x, 0.01)
102 | channel *= 2
103 |
104 | # channel = 256
105 | for i in range(0, self.n_enc_resblock):
106 | x = resblock(x, channel, scope='resblock_'+str(i))
107 |
108 | return x
109 | # END of ENCODERS
110 | ##############################################################################
111 |
112 | ##############################################################################
113 | # BEGIN of SHARED LAYERS
114 | # Shared residual-blocks
115 |
116 | def share_encoder(self, x, reuse=False, scope="share_encoder"):
117 | channel = self.ch * pow(2, self.n_encoder - 1)
118 | with tf.variable_scope(scope, reuse=reuse):
119 | for i in range(0, self.n_enc_share):
120 | x = resblock(x, channel, scope='resblock_' + str(i))
121 |
122 | x = gaussian_noise_layer(x)
123 |
124 | return x
125 |
126 | def share_generator(self, x, reuse=False, scope="share_generator"):
127 | channel = self.ch * pow(2, self.n_encoder - 1)
128 | with tf.variable_scope(scope, reuse=reuse):
129 | for i in range(0, self.n_gen_share):
130 | x = resblock(x, channel, scope='resblock_' + str(i))
131 |
132 | return x
133 | # END of SHARED LAYERS
134 | ##############################################################################
135 |
136 | ##############################################################################
137 | # BEGIN of DECODERS
138 |
139 | def generator(self, x, reuse=False, scope="generator"):
140 | channel = self.ch * pow(2, self.n_encoder - 1)
141 | with tf.variable_scope(scope, reuse=reuse):
142 | for i in range(0, self.n_gen_resblock):
143 | x = resblock(x, channel, scope='resblock_' + str(i))
144 |
145 | for i in range(0, self.n_gen_decoder - 1):
146 | x = deconv(x, channel // 2, kernel=3, stride=2, scope='deconv_' + str(i))
147 | x = lrelu(x, 0.01)
148 | channel = channel // 2
149 |
150 | x = deconv(x, channels=3, kernel=1, stride=1, scope='G_logit')
151 | x = tanh(x)
152 |
153 | return x
154 | # END of DECODERS
155 | ##############################################################################
156 |
157 | ##############################################################################
158 | # BEGIN of DISCRIMINATORS
159 |
160 | def discriminator(self, x, reuse=False, scope="discriminator"):
161 | channel = self.ch
162 | with tf.variable_scope(scope, reuse=reuse):
163 | x = conv(x, channel, kernel=3, stride=2, pad=1, scope='conv_0')
164 | x = lrelu(x, 0.01)
165 |
166 | for i in range(1, self.n_dis):
167 | x = conv(x, channel * 2, kernel=3, stride=2, pad=1, scope='conv_' + str(i))
168 | x = lrelu(x, 0.01)
169 | channel *= 2
170 |
171 | x = conv(x, channels=1, kernel=1, stride=1, scope='D_logit')
172 |
173 | return x
174 |
175 | # END of DISCRIMINATORS
176 | ##############################################################################
177 |
178 | def translation(self, x_A, x_B):
179 | out = tf.concat([self.encoder(x_A, scope="encoder_A"), self.encoder(x_B, scope="encoder_B")], axis=0)
180 | shared = self.share_encoder(out)
181 | out = self.share_generator(shared)
182 |
183 | out_A = self.generator(out, scope="generator_A")
184 | out_B = self.generator(out, scope="generator_B")
185 |
186 | x_Aa, x_Ba = tf.split(out_A, 2, axis=0)
187 | x_Ab, x_Bb = tf.split(out_B, 2, axis=0)
188 |
189 | return x_Aa, x_Ba, x_Ab, x_Bb, shared
190 |
191 | def generate_a2b(self, x_A):
192 | out = self.encoder(x_A, reuse=True, scope="encoder_A")
193 | shared = self.share_encoder(out, reuse=True)
194 | out = self.share_generator(shared, reuse=True)
195 | out = self.generator(out, reuse=True, scope="generator_B")
196 |
197 | return out, shared
198 |
199 | def generate_b2a(self, x_B):
200 | out = self.encoder(x_B, reuse=True, scope="encoder_B")
201 | shared = self.share_encoder(out, reuse=True)
202 | out = self.share_generator(shared, reuse=True)
203 | out = self.generator(out, reuse=True, scope="generator_A")
204 |
205 | return out, shared
206 |
207 | def discriminate_real(self, x_A, x_B):
208 | real_A_logit = self.discriminator(x_A, scope="discriminator_A")
209 | real_B_logit = self.discriminator(x_B, scope="discriminator_B")
210 |
211 | return real_A_logit, real_B_logit
212 |
213 | def discriminate_fake(self, x_ba, x_ab):
214 | fake_A_logit = self.discriminator(x_ba, reuse=True, scope="discriminator_A")
215 | fake_B_logit = self.discriminator(x_ab, reuse=True, scope="discriminator_B")
216 |
217 | return fake_A_logit, fake_B_logit
218 |
219 | def build_model(self):
220 | self.lr = tf.placeholder(tf.float32, name='learning_rate')
221 |
222 | """ Input Image"""
223 | Image_Data_Class = ImageData(self.img_size, self.img_ch, self.augment_flag)
224 |
225 | trainA = tf.data.Dataset.from_tensor_slices(self.trainA_dataset)
226 | trainB = tf.data.Dataset.from_tensor_slices(self.trainB_dataset)
227 |
228 | trainA = trainA.prefetch(self.batch_size).shuffle(self.dataset_num).map(Image_Data_Class.image_processing, num_parallel_calls=8).apply(batch_and_drop_remainder(self.batch_size)).repeat()
229 | trainB = trainB.prefetch(self.batch_size).shuffle(self.dataset_num).map(Image_Data_Class.image_processing, num_parallel_calls=8).apply(batch_and_drop_remainder(self.batch_size)).repeat()
230 |
231 | trainA_iterator = trainA.make_one_shot_iterator()
232 | trainB_iterator = trainB.make_one_shot_iterator()
233 |
234 |
235 | self.domain_A = trainA_iterator.get_next()
236 | self.domain_B = trainB_iterator.get_next()
237 |
238 |
239 | """ Define Encoder, Generator, Discriminator """
240 | x_aa, x_ba, x_ab, x_bb, shared = self.translation(self.domain_A, self.domain_B)
241 | x_bab, shared_bab = self.generate_a2b(x_ba)
242 | x_aba, shared_aba = self.generate_b2a(x_ab)
243 |
244 | real_A_logit, real_B_logit = self.discriminate_real(self.domain_A, self.domain_B)
245 |
246 |
247 | fake_A_logit, fake_B_logit = self.discriminate_fake(x_ba, x_ab)
248 |
249 | """ Define Loss """
250 | G_ad_loss_a = generator_loss(self.gan_type, fake_A_logit)
251 | G_ad_loss_b = generator_loss(self.gan_type, fake_B_logit)
252 |
253 | D_ad_loss_a = discriminator_loss(self.gan_type, real_A_logit, fake_A_logit)
254 | D_ad_loss_b = discriminator_loss(self.gan_type, real_B_logit, fake_B_logit)
255 |
256 | enc_loss = KL_divergence(shared)
257 | enc_bab_loss = KL_divergence(shared_bab)
258 | enc_aba_loss = KL_divergence(shared_aba)
259 |
260 | l1_loss_a = L1_loss(x_aa, self.domain_A) # identity
261 | l1_loss_b = L1_loss(x_bb, self.domain_B) # identity
262 | l1_loss_aba = L1_loss(x_aba, self.domain_A) # reconstruction
263 | l1_loss_bab = L1_loss(x_bab, self.domain_B) # reconstruction
264 |
265 | Generator_A_loss = self.GAN_weight * G_ad_loss_a + \
266 | self.L1_weight * l1_loss_a + \
267 | self.L1_cycle_weight * l1_loss_aba + \
268 | self.KL_weight * enc_loss + \
269 | self.KL_cycle_weight * enc_bab_loss
270 |
271 | Generator_B_loss = self.GAN_weight * G_ad_loss_b + \
272 | self.L1_weight * l1_loss_b + \
273 | self.L1_cycle_weight * l1_loss_bab + \
274 | self.KL_weight * enc_loss + \
275 | self.KL_cycle_weight * enc_aba_loss
276 |
277 | Discriminator_A_loss = self.GAN_weight * D_ad_loss_a
278 | Discriminator_B_loss = self.GAN_weight * D_ad_loss_b
279 |
280 | self.Generator_loss = Generator_A_loss + Generator_B_loss
281 | self.Discriminator_loss = Discriminator_A_loss + Discriminator_B_loss
282 |
283 | """ Training """
284 | t_vars = tf.trainable_variables()
285 | G_vars = [var for var in t_vars if 'generator' in var.name or 'encoder' in var.name]
286 | D_vars = [var for var in t_vars if 'discriminator' in var.name]
287 |
288 | self.G_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.Generator_loss, var_list=G_vars)
289 | self.D_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.Discriminator_loss, var_list=D_vars)
290 |
291 | """" Summary """
292 | self.all_G_loss = tf.summary.scalar("Generator_loss", self.Generator_loss)
293 | self.all_D_loss = tf.summary.scalar("Discriminator_loss", self.Discriminator_loss)
294 | self.G_A_loss = tf.summary.scalar("G_A_loss", Generator_A_loss)
295 | self.G_B_loss = tf.summary.scalar("G_B_loss", Generator_B_loss)
296 | self.D_A_loss = tf.summary.scalar("D_A_loss", Discriminator_A_loss)
297 | self.D_B_loss = tf.summary.scalar("D_B_loss", Discriminator_B_loss)
298 |
299 | self.G_loss = tf.summary.merge([self.G_A_loss, self.G_B_loss, self.all_G_loss])
300 | self.D_loss = tf.summary.merge([self.D_A_loss, self.D_B_loss, self.all_D_loss])
301 |
302 | """ Image """
303 | self.fake_A = x_ba
304 | self.fake_B = x_ab
305 |
306 | self.real_A = self.domain_A
307 | self.real_B = self.domain_B
308 |
309 | """ Test """
310 | self.test_image = tf.placeholder(tf.float32, [1, self.img_size, self.img_size, self.img_ch], name='test_image')
311 |
312 | self.test_fake_A, _ = self.generate_b2a(self.test_image)
313 | self.test_fake_B, _ = self.generate_a2b(self.test_image)
314 |
315 | def train(self):
316 | # initialize all variables
317 | tf.global_variables_initializer().run()
318 |
319 | # saver to save model
320 | self.saver = tf.train.Saver()
321 |
322 | # summary writer
323 | self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_dir, self.sess.graph)
324 |
325 | # restore check-point if it exists
326 | could_load, checkpoint_counter = self.load(self.checkpoint_dir)
327 | if could_load:
328 | start_epoch = (int)(checkpoint_counter / self.iteration)
329 | start_batch_id = checkpoint_counter - start_epoch * self.iteration
330 | counter = checkpoint_counter
331 | print(" [*] Load SUCCESS")
332 | else:
333 | start_epoch = 0
334 | start_batch_id = 0
335 | counter = 1
336 | print(" [!] Load failed...")
337 |
338 | # loop for epoch
339 | start_time = time.time()
340 | lr = self.init_lr
341 | for epoch in range(start_epoch, self.epoch):
342 | for idx in range(start_batch_id, self.iteration):
343 | train_feed_dict = {
344 | self.lr : lr
345 | }
346 |
347 | # Update D
348 | _, d_loss, summary_str = self.sess.run([self.D_optim, self.Discriminator_loss, self.D_loss], feed_dict = train_feed_dict)
349 | self.writer.add_summary(summary_str, counter)
350 |
351 | # Update G
352 | batch_A_images, batch_B_images, fake_A, fake_B, _, g_loss, summary_str = self.sess.run([self.real_A, self.real_B, self.fake_A, self.fake_B, self.G_optim, self.Generator_loss, self.G_loss], feed_dict = train_feed_dict)
353 | self.writer.add_summary(summary_str, counter)
354 |
355 | # display training status
356 | counter += 1
357 | print("Epoch: [%2d] [%6d/%6d] time: %4.4f d_loss: %.8f, g_loss: %.8f" \
358 | % (epoch, idx, self.iteration, time.time() - start_time, d_loss, g_loss))
359 |
360 | if np.mod(idx+1, self.print_freq) == 0 :
361 | save_images(batch_A_images, [self.batch_size, 1],
362 | './{}/real_A_{:02d}_{:06d}.jpg'.format(self.sample_dir, epoch, idx+1))
363 | # save_images(batch_B_images, [self.batch_size, 1],
364 | # './{}/real_B_{}_{:02d}_{:06d}.jpg'.format(self.sample_dir, gpu_id, epoch, idx+1))
365 |
366 | # save_images(fake_A, [self.batch_size, 1],
367 | # './{}/fake_A_{}_{:02d}_{:06d}.jpg'.format(self.sample_dir, gpu_id, epoch, idx+1))
368 | save_images(fake_B, [self.batch_size, 1],
369 | './{}/fake_B_{:02d}_{:06d}.jpg'.format(self.sample_dir, epoch, idx+1))
370 |
371 | if np.mod(idx+1, self.save_freq) == 0 :
372 | self.save(self.checkpoint_dir, counter)
373 |
374 | # After an epoch, start_batch_id is set to zero
375 | # non-zero value is only for the first epoch after loading pre-trained model
376 | start_batch_id = 0
377 |
378 | # save model for final step
379 | self.save(self.checkpoint_dir, counter)
380 |
381 | @property
382 | def model_dir(self):
383 | return "{}_{}_{}".format(self.model_name, self.dataset_name, self.gan_type)
384 |
385 | def save(self, checkpoint_dir, step):
386 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)
387 |
388 | if not os.path.exists(checkpoint_dir):
389 | os.makedirs(checkpoint_dir)
390 |
391 | self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name + '.model'), global_step=step)
392 |
393 | def load(self, checkpoint_dir):
394 | import re
395 | print(" [*] Reading checkpoints...")
396 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)
397 |
398 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
399 | if ckpt and ckpt.model_checkpoint_path:
400 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
401 | self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
402 | counter = int(next(re.finditer(r"(\d+)(?!.*\d)", ckpt_name)).group(0))
403 | print(" [*] Success to read {}".format(ckpt_name))
404 | return True, counter
405 | else:
406 | print(" [*] Failed to find a checkpoint")
407 | return False, 0
408 |
409 | def test(self):
410 | tf.global_variables_initializer().run()
411 | test_A_files = glob('./dataset/{}/*.*'.format(self.dataset_name + '/testA'))
412 | test_B_files = glob('./dataset/{}/*.*'.format(self.dataset_name + '/testB'))
413 |
414 | self.saver = tf.train.Saver()
415 | could_load, checkpoint_counter = self.load(self.checkpoint_dir)
416 | self.result_dir = os.path.join(self.result_dir, self.model_dir)
417 | check_folder(self.result_dir)
418 |
419 | if could_load :
420 | print(" [*] Load SUCCESS")
421 | else :
422 | print(" [!] Load failed...")
423 |
424 | # write html for visual comparison
425 | index_path = os.path.join(self.result_dir, 'index.html')
426 | index = open(index_path, 'w')
427 | index.write("
")
428 | index.write("name | input | output |
")
429 |
430 | for sample_file in test_A_files : # A -> B
431 | print('Processing A image: ' + sample_file)
432 | sample_image = np.asarray(load_test_data(sample_file, size=self.img_size))
433 | image_path = os.path.join(self.result_dir, '{0}'.format(os.path.basename(sample_file)))
434 |
435 | fake_img = self.sess.run(self.test_fake_B, feed_dict={self.test_image: sample_image})
436 | save_images(fake_img, [1, 1], image_path)
437 |
438 | index.write("<td>%s</td>" % os.path.basename(image_path))
439 | index.write("<td><img src='%s' width='%d' height='%d'></td>" % (sample_file if os.path.isabs(sample_file) else (
440 | '../..' + os.path.sep + sample_file), self.img_size, self.img_size))
441 | index.write("<td><img src='%s' width='%d' height='%d'></td>" % (image_path if os.path.isabs(image_path) else (
442 | '../..' + os.path.sep + image_path), self.img_size, self.img_size))
443 | index.write("</tr>")
444 |
445 | for sample_file in test_B_files : # B -> A
446 | print('Processing B image: ' + sample_file)
447 | sample_image = np.asarray(load_test_data(sample_file, size=self.img_size))
448 | image_path = os.path.join(self.result_dir, '{0}'.format(os.path.basename(sample_file)))
449 |
450 | fake_img = self.sess.run(self.test_fake_A, feed_dict={self.test_image: sample_image})
451 | save_images(fake_img, [1, 1], image_path)
452 |
453 | index.write("<td>%s</td>" % os.path.basename(image_path))
454 | index.write("<td><img src='%s' width='%d' height='%d'></td>" % (sample_file if os.path.isabs(sample_file) else (
455 | '../..' + os.path.sep + sample_file), self.img_size, self.img_size))
456 | index.write("<td><img src='%s' width='%d' height='%d'></td>" % (image_path if os.path.isabs(image_path) else (
457 | '../..' + os.path.sep + image_path), self.img_size, self.img_size))
458 | index.write("</tr>")
459 |
460 | index.close()
461 |
--------------------------------------------------------------------------------
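
A note on `translation()` above: the encodings of domain A and domain B are stacked along the batch axis so that `share_encoder`/`share_generator` process both domains in a single pass, and the two `tf.split` calls at the end rely on domain A occupying the first half of that batch. A minimal NumPy sketch (illustration only, not part of the repo) of why the split recovers the per-domain halves:

```python
import numpy as np

batch = 2
enc_A = np.zeros((batch, 8, 8, 4))  # stand-in for encoder_A output
enc_B = np.ones((batch, 8, 8, 4))   # stand-in for encoder_B output

out = np.concatenate([enc_A, enc_B], axis=0)  # shape (2*batch, 8, 8, 4)
first, second = np.split(out, 2, axis=0)      # what tf.split(out, 2, axis=0) does

# domain A stays in the first half and domain B in the second, which is why
# x_Aa/x_Ab come from the A inputs and x_Ba/x_Bb come from the B inputs
assert (first == enc_A).all() and (second == enc_B).all()
```
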
/DatasetAPI/UNIT_multi_gpu.py:
--------------------------------------------------------------------------------
1 | from ops import *
2 | from utils import *
3 | from glob import glob
4 | import time
5 | from tensorflow.contrib.data import batch_and_drop_remainder
6 |
7 | class UNIT(object) :
8 | def __init__(self, sess, args):
9 | self.model_name = 'UNIT'
10 | self.sess = sess
11 | self.checkpoint_dir = args.checkpoint_dir
12 | self.result_dir = args.result_dir
13 | self.log_dir = args.log_dir
14 | self.sample_dir = args.sample_dir
15 | self.dataset_name = args.dataset
16 | self.augment_flag = args.augment_flag
17 |
18 | self.epoch = args.epoch
19 | self.iteration = args.iteration
20 | self.gan_type = args.gan_type
21 |
22 | self.batch_size_per_gpu = args.batch_size
23 | self.batch_size = args.batch_size * args.gpu_num
24 | self.gpu_num = args.gpu_num
25 | self.print_freq = args.print_freq
26 | self.save_freq = args.save_freq
27 |
28 | self.img_size = args.img_size
29 | self.img_ch = args.img_ch
30 |
31 | self.init_lr = args.lr
32 | self.ch = args.ch
33 |
34 | """ Weight about VAE """
35 | self.KL_weight = args.KL_weight # lambda 1
36 | self.L1_weight = args.L1_weight # lambda 2
37 |
38 | """ Weight about VAE Cycle"""
39 | self.KL_cycle_weight = args.KL_cycle_weight # lambda 3
40 | self.L1_cycle_weight = args.L1_cycle_weight # lambda 4
41 |
42 | """ Weight about GAN """
43 | self.GAN_weight = args.GAN_weight # lambda 0
44 |
45 | """ Encoder """
46 | self.n_encoder = args.n_encoder
47 | self.n_enc_resblock = args.n_enc_resblock
48 | self.n_enc_share = args.n_enc_share
49 |
50 | """ Generator """
51 | self.n_gen_share = args.n_gen_share
52 | self.n_gen_resblock = args.n_gen_resblock
53 | self.n_gen_decoder = args.n_gen_decoder
54 |
55 | """ Discriminator """
56 | self.n_dis = args.n_dis
57 |
58 | self.sample_dir = os.path.join(args.sample_dir, self.model_dir)
59 | check_folder(self.sample_dir)
60 |
61 | self.trainA_dataset = glob('./dataset/{}/*.*'.format(self.dataset_name + '/trainA'))
62 | self.trainB_dataset = glob('./dataset/{}/*.*'.format(self.dataset_name + '/trainB'))
63 | self.dataset_num = max(len(self.trainA_dataset), len(self.trainB_dataset))
64 |
65 | print("##### Information #####")
66 | print("# gan type : ", self.gan_type)
67 | print("# dataset : ", self.dataset_name)
68 | print("# max dataset number : ", self.dataset_num)
69 | print("# batch_size : ", self.batch_size)
70 | print("# epoch : ", self.epoch)
71 | print("# iteration per epoch : ", self.iteration)
72 |
73 | print()
74 |
75 | print("##### Encoder #####")
76 | print("# encoder blocks : ", self.n_encoder)
77 | print("# encoder resblock : ", self.n_enc_resblock)
78 | print("# encoder share : ", self.n_enc_share)
79 |
80 | print()
81 |
82 | print("##### Decoder #####")
83 | print("# decoder share : ", self.n_gen_share)
84 | print("# decoder resblock : ", self.n_gen_resblock)
85 | print("# decoder blocks : ", self.n_gen_decoder)
86 |
87 | print()
88 |
89 | print("##### Discriminator #####")
90 | print("# Discriminator layer : ", self.n_dis)
91 |
92 | ##############################################################################
93 | # BEGIN of ENCODERS
94 |
95 | def encoder(self, x, reuse=False, scope="encoder"):
96 | channel = self.ch
97 | with tf.variable_scope(scope, reuse=reuse):
98 | x = conv(x, channel, kernel=7, stride=1, pad=3, scope='conv_0')
99 | x = lrelu(x, 0.01)
100 |
101 | for i in range(1, self.n_encoder):
102 | x = conv(x, channel * 2, kernel=3, stride=2, pad=1, scope='conv_' + str(i))
103 | x = lrelu(x, 0.01)
104 | channel *= 2
105 |
106 | # channel = 256
107 | for i in range(0, self.n_enc_resblock):
108 | x = resblock(x, channel, scope='resblock_'+str(i))
109 |
110 | return x
111 | # END of ENCODERS
112 | ##############################################################################
113 |
114 | ##############################################################################
115 | # BEGIN of SHARED LAYERS
116 | # Shared residual-blocks
117 |
118 | def share_encoder(self, x, reuse=False, scope="share_encoder"):
119 | channel = self.ch * pow(2, self.n_encoder - 1)
120 | with tf.variable_scope(scope, reuse=reuse):
121 | for i in range(0, self.n_enc_share):
122 | x = resblock(x, channel, scope='resblock_' + str(i))
123 |
124 | x = gaussian_noise_layer(x)
125 |
126 | return x
127 |
128 | def share_generator(self, x, reuse=False, scope="share_generator"):
129 | channel = self.ch * pow(2, self.n_encoder - 1)
130 | with tf.variable_scope(scope, reuse=reuse):
131 | for i in range(0, self.n_gen_share):
132 | x = resblock(x, channel, scope='resblock_' + str(i))
133 |
134 | return x
135 | # END of SHARED LAYERS
136 | ##############################################################################
137 |
138 | ##############################################################################
139 | # BEGIN of DECODERS
140 |
141 | def generator(self, x, reuse=False, scope="generator"):
142 | channel = self.ch * pow(2, self.n_encoder - 1)
143 | with tf.variable_scope(scope, reuse=reuse):
144 | for i in range(0, self.n_gen_resblock):
145 | x = resblock(x, channel, scope='resblock_' + str(i))
146 |
147 | for i in range(0, self.n_gen_decoder - 1):
148 | x = deconv(x, channel // 2, kernel=3, stride=2, scope='deconv_' + str(i))
149 | x = lrelu(x, 0.01)
150 | channel = channel // 2
151 |
152 | x = deconv(x, channels=3, kernel=1, stride=1, scope='G_logit')
153 | x = tanh(x)
154 |
155 | return x
156 | # END of DECODERS
157 | ##############################################################################
158 |
159 | ##############################################################################
160 | # BEGIN of DISCRIMINATORS
161 |
162 | def discriminator(self, x, reuse=False, scope="discriminator"):
163 | channel = self.ch
164 | with tf.variable_scope(scope, reuse=reuse):
165 | x = conv(x, channel, kernel=3, stride=2, pad=1, scope='conv_0')
166 | x = lrelu(x, 0.01)
167 |
168 | for i in range(1, self.n_dis):
169 | x = conv(x, channel * 2, kernel=3, stride=2, pad=1, scope='conv_' + str(i))
170 | x = lrelu(x, 0.01)
171 | channel *= 2
172 |
173 | x = conv(x, channels=1, kernel=1, stride=1, scope='D_logit')
174 |
175 | return x
176 |
177 | # END of DISCRIMINATORS
178 | ##############################################################################
179 |
180 | def translation(self, x_A, x_B):
181 | out = tf.concat([self.encoder(x_A, scope="encoder_A"), self.encoder(x_B, scope="encoder_B")], axis=0)
182 | shared = self.share_encoder(out)
183 | out = self.share_generator(shared)
184 |
185 | out_A = self.generator(out, scope="generator_A")
186 | out_B = self.generator(out, scope="generator_B")
187 |
188 | x_Aa, x_Ba = tf.split(out_A, 2, axis=0)
189 | x_Ab, x_Bb = tf.split(out_B, 2, axis=0)
190 |
191 | return x_Aa, x_Ba, x_Ab, x_Bb, shared
192 |
193 | def generate_a2b(self, x_A):
194 | out = self.encoder(x_A, reuse=True, scope="encoder_A")
195 | shared = self.share_encoder(out, reuse=True)
196 | out = self.share_generator(shared, reuse=True)
197 | out = self.generator(out, reuse=True, scope="generator_B")
198 |
199 | return out, shared
200 |
201 | def generate_b2a(self, x_B):
202 | out = self.encoder(x_B, reuse=True, scope="encoder_B")
203 | shared = self.share_encoder(out, reuse=True)
204 | out = self.share_generator(shared, reuse=True)
205 | out = self.generator(out, reuse=True, scope="generator_A")
206 |
207 | return out, shared
208 |
209 | def discriminate_real(self, x_A, x_B):
210 | real_A_logit = self.discriminator(x_A, scope="discriminator_A")
211 | real_B_logit = self.discriminator(x_B, scope="discriminator_B")
212 |
213 | return real_A_logit, real_B_logit
214 |
215 | def discriminate_fake(self, x_ba, x_ab):
216 | fake_A_logit = self.discriminator(x_ba, reuse=True, scope="discriminator_A")
217 | fake_B_logit = self.discriminator(x_ab, reuse=True, scope="discriminator_B")
218 |
219 | return fake_A_logit, fake_B_logit
220 |
221 | def build_model(self):
222 | self.lr = tf.placeholder(tf.float32, name='learning_rate')
223 |
224 | """ Input Image"""
225 | Image_Data_Class = ImageData(self.img_size, self.img_ch, self.augment_flag)
226 |
227 | trainA = tf.data.Dataset.from_tensor_slices(self.trainA_dataset)
228 | trainB = tf.data.Dataset.from_tensor_slices(self.trainB_dataset)
229 |
230 | trainA = trainA.prefetch(self.batch_size).shuffle(self.dataset_num).map(Image_Data_Class.image_processing, num_parallel_calls=8).apply(batch_and_drop_remainder(self.batch_size)).repeat()
231 | trainB = trainB.prefetch(self.batch_size).shuffle(self.dataset_num).map(Image_Data_Class.image_processing, num_parallel_calls=8).apply(batch_and_drop_remainder(self.batch_size)).repeat()
232 |
233 | trainA_iterator = trainA.make_one_shot_iterator()
234 | trainB_iterator = trainB.make_one_shot_iterator()
235 |
236 | self.domain_A = trainA_iterator.get_next()
237 | self.domain_B = trainB_iterator.get_next()
238 |
239 | domain_A = tf.split(self.domain_A, self.gpu_num)
240 | domain_B = tf.split(self.domain_B, self.gpu_num)
241 |
242 | G_A_losses = []
243 | G_B_losses = []
244 | D_A_losses = []
245 | D_B_losses = []
246 |
247 | G_losses = []
248 | D_losses = []
249 |
250 | self.fake_A = []
251 | self.fake_B = []
252 |
253 | self.real_A = []
254 | self.real_B = []
255 |
256 | for gpu_id in range(self.gpu_num):
257 | with tf.device(tf.DeviceSpec(device_type="GPU", device_index=gpu_id)):
258 | with tf.variable_scope(tf.get_variable_scope(), reuse=(gpu_id > 0)):
259 | """ Define Encoder, Generator, Discriminator """
260 | x_aa, x_ba, x_ab, x_bb, shared = self.translation(domain_A[gpu_id], domain_B[gpu_id])
261 | x_bab, shared_bab = self.generate_a2b(x_ba)
262 | x_aba, shared_aba = self.generate_b2a(x_ab)
263 |
264 | real_A_logit, real_B_logit = self.discriminate_real(domain_A[gpu_id], domain_B[gpu_id])
265 |
266 |
267 | fake_A_logit, fake_B_logit = self.discriminate_fake(x_ba, x_ab)
268 |
269 | """ Define Loss """
270 | G_ad_loss_a = generator_loss(self.gan_type, fake_A_logit)
271 | G_ad_loss_b = generator_loss(self.gan_type, fake_B_logit)
272 |
273 | D_ad_loss_a = discriminator_loss(self.gan_type, real_A_logit, fake_A_logit)
274 | D_ad_loss_b = discriminator_loss(self.gan_type, real_B_logit, fake_B_logit)
275 |
276 | enc_loss = KL_divergence(shared)
277 | enc_bab_loss = KL_divergence(shared_bab)
278 | enc_aba_loss = KL_divergence(shared_aba)
279 |
280 | l1_loss_a = L1_loss(x_aa, domain_A[gpu_id]) # identity
281 | l1_loss_b = L1_loss(x_bb, domain_B[gpu_id]) # identity
282 | l1_loss_aba = L1_loss(x_aba, domain_A[gpu_id]) # reconstruction
283 | l1_loss_bab = L1_loss(x_bab, domain_B[gpu_id]) # reconstruction
284 |
285 | Generator_A_loss_split = self.GAN_weight * G_ad_loss_a + \
286 | self.L1_weight * l1_loss_a + \
287 | self.L1_cycle_weight * l1_loss_aba + \
288 | self.KL_weight * enc_loss + \
289 | self.KL_cycle_weight * enc_bab_loss
290 |
291 | Generator_B_loss_split = self.GAN_weight * G_ad_loss_b + \
292 | self.L1_weight * l1_loss_b + \
293 | self.L1_cycle_weight * l1_loss_bab + \
294 | self.KL_weight * enc_loss + \
295 | self.KL_cycle_weight * enc_aba_loss
296 |
297 | Discriminator_A_loss_split = self.GAN_weight * D_ad_loss_a
298 | Discriminator_B_loss_split = self.GAN_weight * D_ad_loss_b
299 |
300 | Generator_loss_split = Generator_A_loss_split + Generator_B_loss_split
301 | Discriminator_loss_split = Discriminator_A_loss_split + Discriminator_B_loss_split
302 |
303 | G_A_losses.append(Generator_A_loss_split)
304 | G_B_losses.append(Generator_B_loss_split)
305 | D_A_losses.append(Discriminator_A_loss_split)
306 | D_B_losses.append(Discriminator_B_loss_split)
307 |
308 | G_losses.append(Generator_loss_split)
309 | D_losses.append(Discriminator_loss_split)
310 |
311 | self.fake_A.append(x_ba)
312 | self.fake_B.append(x_ab)
313 |
314 | self.real_A.append(domain_A[gpu_id])
315 | self.real_B.append(domain_B[gpu_id])
316 |
317 | Generator_A_loss = tf.reduce_mean(G_A_losses)
318 | Generator_B_loss = tf.reduce_mean(G_B_losses)
319 | Discriminator_A_loss = tf.reduce_mean(D_A_losses)
320 | Discriminator_B_loss = tf.reduce_mean(D_B_losses)
321 |
322 | self.Generator_loss = tf.reduce_mean(G_losses)
323 | self.Discriminator_loss = tf.reduce_mean(D_losses)
324 |
325 |
326 | """ Training """
327 | t_vars = tf.trainable_variables()
328 | G_vars = [var for var in t_vars if 'generator' in var.name or 'encoder' in var.name]
329 | D_vars = [var for var in t_vars if 'discriminator' in var.name]
330 |
331 | self.G_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.Generator_loss, var_list=G_vars)
332 | self.D_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.Discriminator_loss, var_list=D_vars)
333 |
334 | """" Summary """
335 | self.all_G_loss = tf.summary.scalar("Generator_loss", self.Generator_loss)
336 | self.all_D_loss = tf.summary.scalar("Discriminator_loss", self.Discriminator_loss)
337 | self.G_A_loss = tf.summary.scalar("G_A_loss", Generator_A_loss)
338 | self.G_B_loss = tf.summary.scalar("G_B_loss", Generator_B_loss)
339 | self.D_A_loss = tf.summary.scalar("D_A_loss", Discriminator_A_loss)
340 | self.D_B_loss = tf.summary.scalar("D_B_loss", Discriminator_B_loss)
341 |
342 | self.G_loss = tf.summary.merge([self.G_A_loss, self.G_B_loss, self.all_G_loss])
343 | self.D_loss = tf.summary.merge([self.D_A_loss, self.D_B_loss, self.all_D_loss])
344 |
345 | """ Image """
346 | self.fake_A = tf.squeeze(self.fake_A)
347 | self.fake_B = tf.squeeze(self.fake_B)
348 |
349 | self.real_A = tf.squeeze(self.real_A)
350 | self.real_B = tf.squeeze(self.real_B)
351 |
352 | """ Test """
353 | self.test_image = tf.placeholder(tf.float32, [1, self.img_size, self.img_size, self.img_ch], name='test_image')
354 |
355 | self.test_fake_A, _ = self.generate_b2a(self.test_image)
356 | self.test_fake_B, _ = self.generate_a2b(self.test_image)
357 |
358 | def train(self):
359 | # initialize all variables
360 | tf.global_variables_initializer().run()
361 |
362 | # saver to save model
363 | self.saver = tf.train.Saver()
364 |
365 | # summary writer
366 | self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_dir, self.sess.graph)
367 |
368 | # restore check-point if it exists
369 | could_load, checkpoint_counter = self.load(self.checkpoint_dir)
370 | if could_load:
371 | start_epoch = (int)(checkpoint_counter / self.iteration)
372 | start_batch_id = checkpoint_counter - start_epoch * self.iteration
373 | counter = checkpoint_counter
374 | print(" [*] Load SUCCESS")
375 | else:
376 | start_epoch = 0
377 | start_batch_id = 0
378 | counter = 1
379 | print(" [!] Load failed...")
380 |
381 | # loop for epoch
382 | start_time = time.time()
383 | lr = self.init_lr
384 | for epoch in range(start_epoch, self.epoch):
385 | for idx in range(start_batch_id, self.iteration):
386 | train_feed_dict = {
387 | self.lr : lr
388 | }
389 |
390 | # Update D
391 | _, d_loss, summary_str = self.sess.run([self.D_optim, self.Discriminator_loss, self.D_loss], feed_dict = train_feed_dict)
392 | self.writer.add_summary(summary_str, counter)
393 |
394 | # Update G
395 | batch_A_images, batch_B_images, fake_A, fake_B, _, g_loss, summary_str = self.sess.run([self.real_A, self.real_B, self.fake_A, self.fake_B, self.G_optim, self.Generator_loss, self.G_loss], feed_dict = train_feed_dict)
396 | self.writer.add_summary(summary_str, counter)
397 |
398 | # display training status
399 | counter += 1
400 | print("Epoch: [%2d] [%6d/%6d] time: %4.4f d_loss: %.8f, g_loss: %.8f" \
401 | % (epoch, idx, self.iteration, time.time() - start_time, d_loss, g_loss))
402 |
403 | if np.mod(idx+1, self.print_freq) == 0 :
404 | save_images(batch_A_images, [self.batch_size, 1],
405 | './{}/real_A_{:02d}_{:06d}.jpg'.format(self.sample_dir, epoch, idx+1))
406 | # save_images(batch_B_images, [self.batch_size, 1],
407 | # './{}/real_B_{}_{:02d}_{:06d}.jpg'.format(self.sample_dir, gpu_id, epoch, idx+1))
408 |
409 | # save_images(fake_A, [self.batch_size, 1],
410 | # './{}/fake_A_{}_{:02d}_{:06d}.jpg'.format(self.sample_dir, gpu_id, epoch, idx+1))
411 | save_images(fake_B, [self.batch_size, 1],
412 | './{}/fake_B_{:02d}_{:06d}.jpg'.format(self.sample_dir, epoch, idx+1))
413 |
414 | if np.mod(idx+1, self.save_freq) == 0 :
415 | self.save(self.checkpoint_dir, counter)
416 |
417 | # After an epoch, start_batch_id is set to zero
418 | # non-zero value is only for the first epoch after loading pre-trained model
419 | start_batch_id = 0
420 |
421 | # save model for final step
422 | self.save(self.checkpoint_dir, counter)
423 |
424 | @property
425 | def model_dir(self):
426 | return "{}_{}_{}".format(self.model_name, self.dataset_name, self.gan_type)
427 |
428 | def save(self, checkpoint_dir, step):
429 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)
430 |
431 | if not os.path.exists(checkpoint_dir):
432 | os.makedirs(checkpoint_dir)
433 |
434 | self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name + '.model'), global_step=step)
435 |
436 | def load(self, checkpoint_dir):
437 | import re
438 | print(" [*] Reading checkpoints...")
439 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)
440 |
441 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
442 | if ckpt and ckpt.model_checkpoint_path:
443 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
444 | self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
445 | counter = int(next(re.finditer(r"(\d+)(?!.*\d)", ckpt_name)).group(0))
446 | print(" [*] Success to read {}".format(ckpt_name))
447 | return True, counter
448 | else:
449 | print(" [*] Failed to find a checkpoint")
450 | return False, 0
451 |
452 | def test(self):
453 | tf.global_variables_initializer().run()
454 | test_A_files = glob('./dataset/{}/*.*'.format(self.dataset_name + '/testA'))
455 | test_B_files = glob('./dataset/{}/*.*'.format(self.dataset_name + '/testB'))
456 |
457 | self.saver = tf.train.Saver()
458 | could_load, checkpoint_counter = self.load(self.checkpoint_dir)
459 | self.result_dir = os.path.join(self.result_dir, self.model_dir)
460 | check_folder(self.result_dir)
461 |
462 | if could_load :
463 | print(" [*] Load SUCCESS")
464 | else :
465 | print(" [!] Load failed...")
466 |
467 | # write html for visual comparison
468 | index_path = os.path.join(self.result_dir, 'index.html')
469 | index = open(index_path, 'w')
470 | index.write("<html><body><table><tr>")
471 | index.write("<th>name</th><th>input</th><th>output</th></tr>")
472 |
473 | for sample_file in test_A_files : # A -> B
474 | print('Processing A image: ' + sample_file)
475 | sample_image = np.asarray(load_test_data(sample_file, size=self.img_size))
476 | image_path = os.path.join(self.result_dir, '{0}'.format(os.path.basename(sample_file)))
477 |
478 | fake_img = self.sess.run(self.test_fake_B, feed_dict={self.test_image: sample_image})
479 | save_images(fake_img, [1, 1], image_path)
480 |
481 | index.write("<td>%s</td>" % os.path.basename(image_path))
482 | index.write("<td><img src='%s' width='%d' height='%d'></td>" % (sample_file if os.path.isabs(sample_file) else (
483 | '../..' + os.path.sep + sample_file), self.img_size, self.img_size))
484 | index.write("<td><img src='%s' width='%d' height='%d'></td>" % (image_path if os.path.isabs(image_path) else (
485 | '../..' + os.path.sep + image_path), self.img_size, self.img_size))
486 | index.write("</tr>")
487 |
488 | for sample_file in test_B_files : # B -> A
489 | print('Processing B image: ' + sample_file)
490 | sample_image = np.asarray(load_test_data(sample_file, size=self.img_size))
491 | image_path = os.path.join(self.result_dir, '{0}'.format(os.path.basename(sample_file)))
492 |
493 | fake_img = self.sess.run(self.test_fake_A, feed_dict={self.test_image: sample_image})
494 | save_images(fake_img, [1, 1], image_path)
495 |
496 | index.write("<td>%s</td>" % os.path.basename(image_path))
497 | index.write("<td><img src='%s' width='%d' height='%d'></td>" % (sample_file if os.path.isabs(sample_file) else (
498 | '../..' + os.path.sep + sample_file), self.img_size, self.img_size))
499 | index.write("<td><img src='%s' width='%d' height='%d'></td>" % (image_path if os.path.isabs(image_path) else (
500 | '../..' + os.path.sep + image_path), self.img_size, self.img_size))
501 | index.write("</tr>")
502 |
503 | index.close()
504 |
--------------------------------------------------------------------------------
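
The multi-GPU variant differs from `UNIT.py` mainly inside `build_model()`: the batched input is `tf.split` into one shard per GPU, every tower is built under `reuse=(gpu_id > 0)` so all GPUs share one set of variables, and the per-tower losses are averaged before a single Adam step. A condensed sketch of that pattern (TensorFlow 1.x, as this repo requires; `build_towers` and `model_fn` are hypothetical names, not from the repo):

```python
import tensorflow as tf  # TensorFlow 1.x

def build_towers(model_fn, inputs, gpu_num):
    # In-graph data parallelism: one variable set, one tower per GPU.
    tower_losses = []
    splits = tf.split(inputs, gpu_num)  # one shard of the batch per device
    for gpu_id in range(gpu_num):
        with tf.device('/gpu:%d' % gpu_id):
            # Tower 0 creates the variables; reuse=True makes later towers share them.
            with tf.variable_scope(tf.get_variable_scope(), reuse=(gpu_id > 0)):
                tower_losses.append(model_fn(splits[gpu_id]))
    # Averaging the tower losses lets one optimizer step update the shared variables.
    return tf.reduce_mean(tower_losses)
```
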
/DatasetAPI/main.py:
--------------------------------------------------------------------------------
1 | from UNIT import UNIT
2 | import argparse
3 | from utils import *
4 |
5 | """parsing and configuration"""
6 | def parse_args():
7 | desc = "Tensorflow implementation of UNIT"
8 | parser = argparse.ArgumentParser(description=desc)
9 | parser.add_argument('--phase', type=str, default='train', help='train or test ?')
10 | parser.add_argument('--dataset', type=str, default='summer2winter', help='dataset_name')
11 | parser.add_argument('--augment_flag', type=bool, default=False, help='Image augmentation use or not')
12 |
13 | parser.add_argument('--epoch', type=int, default=5, help='The number of epochs to run')
14 | parser.add_argument('--iteration', type=int, default=100000, help='The number of training iterations')
15 | parser.add_argument('--batch_size', type=int, default=1, help='The batch size')
16 | parser.add_argument('--print_freq', type=int, default=1000, help='The number of image_print_freq')
17 | parser.add_argument('--save_freq', type=int, default=1000, help='The number of ckpt_save_freq')
18 |
19 | parser.add_argument('--lr', type=float, default=0.0001, help='The learning rate')
20 | parser.add_argument('--GAN_weight', type=float, default=10.0, help='Weight about GAN, lambda0')
21 | parser.add_argument('--KL_weight', type=float, default=0.1, help='Weight about VAE, lambda1')
22 | parser.add_argument('--L1_weight', type=float, default=100.0, help='Weight about VAE, lambda2' )
23 | parser.add_argument('--KL_cycle_weight', type=float, default=0.1, help='Weight about VAE Cycle, lambda3')
24 | parser.add_argument('--L1_cycle_weight', type=float, default=100.0, help='Weight about VAE Cycle, lambda4')
25 |
26 | parser.add_argument('--gan_type', type=str, default='gan', help='GAN loss type [gan / lsgan]')
27 |
28 | parser.add_argument('--ch', type=int, default=64, help='base channel number per layer')
29 | parser.add_argument('--n_encoder', type=int, default=3, help='The number of encoder')
30 | parser.add_argument('--n_enc_resblock', type=int, default=3, help='The number of encoder_resblock')
31 | parser.add_argument('--n_enc_share', type=int, default=1, help='The number of share_encoder')
32 | parser.add_argument('--n_gen_share', type=int, default=1, help='The number of share_generator')
33 | parser.add_argument('--n_gen_resblock', type=int, default=3, help='The number of generator_resblock')
34 | parser.add_argument('--n_gen_decoder', type=int, default=3, help='The number of generator_decoder')
35 |
36 | parser.add_argument('--n_dis', type=int, default=6, help='The number of discriminator layer')
37 |
38 | parser.add_argument('--img_size', type=int, default=256, help='The size of image')
39 | parser.add_argument('--img_ch', type=int, default=3, help='The size of image channel')
40 |
41 | parser.add_argument('--checkpoint_dir', type=str, default='checkpoint',
42 | help='Directory name to save the checkpoints')
43 | parser.add_argument('--result_dir', type=str, default='results',
44 | help='Directory name to save the generated images')
45 | parser.add_argument('--log_dir', type=str, default='logs',
46 | help='Directory name to save training logs')
47 | parser.add_argument('--sample_dir', type=str, default='samples',
48 | help='Directory name to save the samples on training')
49 |
50 | return check_args(parser.parse_args())
51 |
52 | """checking arguments"""
53 | def check_args(args):
54 | # --checkpoint_dir
55 | check_folder(args.checkpoint_dir)
56 |
57 | # --result_dir
58 | check_folder(args.result_dir)
59 |
60 | # --log_dir
61 | check_folder(args.log_dir)
62 |
63 | # --sample_dir
64 | check_folder(args.sample_dir)
65 |
66 | # --epoch
67 | try:
68 | assert args.epoch >= 1
69 | except:
70 | print('number of epochs must be larger than or equal to one')
71 |
72 | # --batch_size
73 | try:
74 | assert args.batch_size >= 1
75 | except:
76 | print('batch size must be larger than or equal to one')
77 | return args
78 |
79 | """main"""
80 | def main():
81 | # parse arguments
82 | args = parse_args()
83 | if args is None:
84 | exit()
85 |
86 | # open session
87 | with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
88 | gan = UNIT(sess, args)
89 |
90 | # build graph
91 | gan.build_model()
92 |
93 | # show network architecture
94 | show_all_variables()
95 |
96 | if args.phase == 'train' :
97 | # launch the graph in a session
98 | gan.train()
99 | print(" [*] Training finished!")
100 |
101 | if args.phase == 'test' :
102 | gan.test()
103 | print(" [*] Test finished!")
104 |
105 | if __name__ == '__main__':
106 | main()
--------------------------------------------------------------------------------
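
Given the flags defined in `parse_args()` above, a typical pair of invocations looks like this (assembled from the flags above as an illustration; the dataset layout under `./dataset/` follows the README's Usage section):

```bash
# train: reads ./dataset/summer2winter/trainA and ./dataset/summer2winter/trainB
python main.py --phase train --dataset summer2winter --batch_size 1

# test: translates testA/testB and writes images plus an index.html comparison page
python main.py --phase test --dataset summer2winter
```

Note that `--dataset` and `--gan_type` must match between the train and test runs, because the checkpoint directory name is derived from both (see the `model_dir` property in `UNIT.py`).
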
/DatasetAPI/main_multi_gpu.py:
--------------------------------------------------------------------------------
1 | from UNIT_multi_gpu import UNIT
2 | import argparse
3 | from utils import *
4 |
5 | """parsing and configuration"""
6 | def parse_args():
7 | desc = "Tensorflow implementation of UNIT"
8 | parser = argparse.ArgumentParser(description=desc)
9 | parser.add_argument('--phase', type=str, default='train', help='train or test ?')
10 | parser.add_argument('--dataset', type=str, default='summer2winter', help='dataset_name')
11 | parser.add_argument('--augment_flag', type=bool, default=False, help='Image augmentation use or not')
12 |
13 | parser.add_argument('--epoch', type=int, default=5, help='The number of epochs to run')
14 | parser.add_argument('--iteration', type=int, default=100000, help='The number of training iterations')
15 | parser.add_argument('--batch_size', type=int, default=1, help='The batch size')
16 | parser.add_argument('--gpu_num', type=int, default=8, help='The number of gpu')
17 | parser.add_argument('--print_freq', type=int, default=1000, help='The number of image_print_freq')
18 | parser.add_argument('--save_freq', type=int, default=1000, help='The number of ckpt_save_freq')
19 |
20 | parser.add_argument('--lr', type=float, default=0.0001, help='The learning rate')
21 | parser.add_argument('--GAN_weight', type=float, default=10.0, help='Weight about GAN, lambda0')
22 | parser.add_argument('--KL_weight', type=float, default=0.1, help='Weight about VAE, lambda1')
23 | parser.add_argument('--L1_weight', type=float, default=100.0, help='Weight about VAE, lambda2' )
24 | parser.add_argument('--KL_cycle_weight', type=float, default=0.1, help='Weight about VAE Cycle, lambda3')
25 | parser.add_argument('--L1_cycle_weight', type=float, default=100.0, help='Weight about VAE Cycle, lambda4')
26 |
27 | parser.add_argument('--gan_type', type=str, default='gan', help='GAN loss type [gan / lsgan]')
28 |
29 | parser.add_argument('--ch', type=int, default=64, help='base channel number per layer')
30 | parser.add_argument('--n_encoder', type=int, default=3, help='The number of encoder')
31 | parser.add_argument('--n_enc_resblock', type=int, default=3, help='The number of encoder_resblock')
32 | parser.add_argument('--n_enc_share', type=int, default=1, help='The number of share_encoder')
33 | parser.add_argument('--n_gen_share', type=int, default=1, help='The number of share_generator')
34 | parser.add_argument('--n_gen_resblock', type=int, default=3, help='The number of generator_resblock')
35 | parser.add_argument('--n_gen_decoder', type=int, default=3, help='The number of generator_decoder')
36 |
37 | parser.add_argument('--n_dis', type=int, default=6, help='The number of discriminator layer')
38 |
39 | parser.add_argument('--img_size', type=int, default=256, help='The size of image')
40 | parser.add_argument('--img_ch', type=int, default=3, help='The size of image channel')
41 |
42 | parser.add_argument('--checkpoint_dir', type=str, default='checkpoint',
43 | help='Directory name to save the checkpoints')
44 | parser.add_argument('--result_dir', type=str, default='results',
45 | help='Directory name to save the generated images')
46 | parser.add_argument('--log_dir', type=str, default='logs',
47 | help='Directory name to save training logs')
48 | parser.add_argument('--sample_dir', type=str, default='samples',
49 | help='Directory name to save the samples on training')
50 |
51 | return check_args(parser.parse_args())
52 |
53 | """checking arguments"""
54 | def check_args(args):
55 | # --checkpoint_dir
56 | check_folder(args.checkpoint_dir)
57 |
58 | # --result_dir
59 | check_folder(args.result_dir)
60 |
61 | # --log_dir
62 | check_folder(args.log_dir)
63 |
64 | # --sample_dir
65 | check_folder(args.sample_dir)
66 |
67 | # --epoch
68 | try:
69 | assert args.epoch >= 1
70 | except:
71 | print('number of epochs must be larger than or equal to one')
72 |
73 | # --batch_size
74 | try:
75 | assert args.batch_size >= 1
76 | except:
77 | print('batch size must be larger than or equal to one')
78 | return args
79 |
80 | """main"""
81 | def main():
82 | # parse arguments
83 | args = parse_args()
84 | if args is None:
85 | exit()
86 |
87 | # open session
88 | with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
89 | gan = UNIT(sess, args)
90 |
91 | # build graph
92 | gan.build_model()
93 |
94 | # show network architecture
95 | show_all_variables()
96 |
97 | if args.phase == 'train' :
98 | # launch the graph in a session
99 | gan.train()
100 | print(" [*] Training finished!")
101 |
102 | if args.phase == 'test' :
103 | gan.test()
104 | print(" [*] Test finished!")
105 |
106 | if __name__ == '__main__':
107 | main()
--------------------------------------------------------------------------------
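
The multi-GPU entry point adds only `--gpu_num`. As `UNIT_multi_gpu.py` above shows, the effective batch fed to the input pipeline is `batch_size * gpu_num`, which `tf.split` then shards across devices, so `--batch_size` is the per-GPU batch. An illustrative invocation (assembled from the flags above, not from the repo docs):

```bash
# 4 GPUs with a per-GPU batch of 1: the pipeline batches 4 images per step
python main_multi_gpu.py --phase train --dataset summer2winter --batch_size 1 --gpu_num 4
```
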
/DatasetAPI/ops.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import tensorflow.contrib as tf_contrib
3 |
4 | weight_init = tf.random_normal_initializer(mean=0.0, stddev=0.02)
5 | weight_regularizer = tf_contrib.layers.l2_regularizer(scale=0.0001)
6 |
7 | ##################################################################################
8 | # Layer
9 | ##################################################################################
10 |
11 | def conv(x, channels, kernel=4, stride=2, pad=0, pad_type='zero', use_bias=True, scope='conv'):
12 | with tf.variable_scope(scope):
13 | if pad_type == 'zero' :
14 | x = tf.pad(x, [[0, 0], [pad, pad], [pad, pad], [0, 0]])
15 | if pad_type == 'reflect' :
16 | x = tf.pad(x, [[0, 0], [pad, pad], [pad, pad], [0, 0]], mode='REFLECT')
17 |
18 | x = tf.layers.conv2d(inputs=x, filters=channels,
19 | kernel_size=kernel, kernel_initializer=weight_init,
20 | kernel_regularizer=weight_regularizer,
21 | strides=stride, use_bias=use_bias)
22 |
23 | return x
24 |
25 | def deconv(x, channels, kernel=3, stride=2, use_bias=True, scope='deconv_0') :
26 | with tf.variable_scope(scope):
27 | x = tf.layers.conv2d_transpose(inputs=x, filters=channels,
28 | kernel_size=kernel, kernel_initializer=weight_init,
29 | kernel_regularizer=weight_regularizer,
30 | strides=stride, use_bias=use_bias, padding='SAME')
31 |
32 | return x
33 |
34 | def linear(x, units, use_bias=True, scope='linear'):
35 | with tf.variable_scope(scope):
36 | x = flatten(x)
37 | x = tf.layers.dense(x, units=units, kernel_initializer=weight_init, kernel_regularizer=weight_regularizer, use_bias=use_bias)
38 |
39 | return x
40 |
41 | def flatten(x) :
42 | return tf.layers.flatten(x)
43 |
44 | def gaussian_noise_layer(mu):
45 | sigma = 1.0
46 | gaussian_random_vector = tf.random_normal(shape=tf.shape(mu), mean=0.0, stddev=1.0, dtype=tf.float32)
47 | return mu + sigma * gaussian_random_vector
48 |
49 |
50 | ##################################################################################
51 | # Residual-block
52 | ##################################################################################
53 |
54 | def resblock(x_init, channels, use_bias=True, scope='resblock'):
55 | with tf.variable_scope(scope):
56 | with tf.variable_scope('res1'):
57 | x = conv(x_init, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias)
58 | x = instance_norm(x)
59 | x = relu(x)
60 |
61 | with tf.variable_scope('res2'):
62 | x = conv(x, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias)
63 | x = instance_norm(x)
64 |
65 | return x + x_init
66 |
67 | ##################################################################################
68 | # Activation function
69 | ##################################################################################
70 |
71 | def lrelu(x, alpha=0.01):
72 | # pytorch alpha is 0.01
73 | return tf.nn.leaky_relu(x, alpha)
74 |
75 |
76 | def relu(x):
77 | return tf.nn.relu(x)
78 |
79 |
80 | def tanh(x):
81 | return tf.tanh(x)
82 |
83 | ##################################################################################
84 | # Normalization function
85 | ##################################################################################
86 |
87 | def instance_norm(x, scope='instance_norm'):
88 | return tf_contrib.layers.instance_norm(x,
89 | epsilon=1e-05,
90 | center=True, scale=True,
91 | scope=scope)
92 |
93 | ##################################################################################
94 | # Loss function
95 | ##################################################################################
96 |
97 | def discriminator_loss(type, real, fake):
98 | real_loss = 0
99 | fake_loss = 0
100 |
101 | if type == 'lsgan' :
102 | real_loss = tf.reduce_mean(tf.squared_difference(real, 1.0))
103 | fake_loss = tf.reduce_mean(tf.square(fake))
104 |
105 | if type == 'gan' :
106 | real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(real), logits=real))
107 | fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(fake), logits=fake))
108 |
109 | loss = real_loss + fake_loss
110 |
111 | return loss
112 |
113 |
114 | def generator_loss(type, fake):
115 | fake_loss = 0
116 |
117 | if type == 'lsgan' :
118 | fake_loss = tf.reduce_mean(tf.squared_difference(fake, 1.0))
119 |
120 | if type == 'gan' :
121 | fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(fake), logits=fake))
122 |
123 | loss = fake_loss
124 |
125 |
126 | return loss
127 |
128 |
129 | def L1_loss(x, y):
130 | loss = tf.reduce_mean(tf.abs(x - y))
131 |
132 | return loss
133 |
134 | def KL_divergence(mu) :
135 | # KL_divergence = 0.5 * tf.reduce_sum(tf.square(mu) + tf.square(sigma) - tf.log(1e-8 + tf.square(sigma)) - 1, axis = -1)
136 | # loss = tf.reduce_mean(KL_divergence)
137 | mu_2 = tf.square(mu)
138 | loss = tf.reduce_mean(mu_2)
139 |
140 | return loss
--------------------------------------------------------------------------------
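
On the simplified `KL_divergence` above: `gaussian_noise_layer` implements the reparameterization trick with the standard deviation fixed at 1 (`z = mu + eps`, `eps ~ N(0, I)`), and for `sigma = 1` the commented-out full Gaussian KL collapses to `0.5 * mu^2` per dimension, so penalizing `reduce_mean(mu_2)` is the same objective up to a constant factor. A small NumPy check (illustration only, not part of the repo):

```python
import numpy as np

mu = np.random.randn(4)
sigma = np.ones_like(mu)  # gaussian_noise_layer fixes sigma = 1

# full Gaussian KL(N(mu, sigma^2) || N(0, 1)) per dimension
kl_full = 0.5 * (mu**2 + sigma**2 - np.log(sigma**2) - 1)

# with sigma = 1 the sigma terms cancel, leaving 0.5 * mu^2
assert np.allclose(kl_full, 0.5 * mu**2)
```
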
/DatasetAPI/utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.contrib import slim
3 | from scipy import misc
4 | import os, random
5 | import numpy as np
6 |
7 | # https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/
8 | # https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/
9 |
10 | class ImageData:
11 |
12 | def __init__(self, load_size, channels, augment_flag=False):
13 | self.load_size = load_size
14 | self.channels = channels
15 | self.augment_flag = augment_flag
16 |
17 | def image_processing(self, filename):
18 | x = tf.read_file(filename)
19 | x_decode = tf.image.decode_jpeg(x, channels=self.channels)
20 | img = tf.image.resize_images(x_decode, [self.load_size, self.load_size])
21 | img = tf.cast(img, tf.float32) / 127.5 - 1
22 |
23 | if self.augment_flag :
24 | augment_size = self.load_size + (30 if self.load_size == 256 else 15)
25 | p = random.random()
26 | if p > 0.5:
27 | img = augmentation(img, augment_size)
28 |
29 | return img
30 |
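For context, this is how `ImageData` is meant to plug into a `tf.data` input pipeline in the DatasetAPI variant; the file list, sizes, and buffer values below are illustrative assumptions, not the repo's exact wiring:

```python
import tensorflow as tf
from glob import glob

trainA_files = glob('./dataset/YOUR_DATASET_NAME/trainA/*.*')  # hypothetical path
image_data = ImageData(load_size=256, channels=3, augment_flag=True)

dataset = (tf.data.Dataset.from_tensor_slices(trainA_files)
           .map(image_data.image_processing)
           .shuffle(1000).repeat().batch(1))
trainA_iterator = dataset.make_one_shot_iterator()  # TF 1.x iterator API
trainA_batch = trainA_iterator.get_next()           # [1, 256, 256, 3], values in [-1, 1]
```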
31 |
32 | def load_test_data(image_path, size=256):
33 | img = misc.imread(image_path, mode='RGB')
34 | img = misc.imresize(img, [size, size])
35 | img = np.expand_dims(img, axis=0)
36 | img = preprocessing(img)
37 |
38 | return img
39 |
40 | def preprocessing(x):
41 | x = x/127.5 - 1 # -1 ~ 1
42 | return x
43 |
44 | def augmentation(image, augment_size):
45 | seed = random.randint(0, 2 ** 31 - 1)
46 | ori_image_shape = tf.shape(image)
47 | image = tf.image.random_flip_left_right(image, seed=seed)
48 | image = tf.image.resize_images(image, [augment_size, augment_size])
49 | image = tf.random_crop(image, ori_image_shape, seed=seed)
50 | return image
51 |
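A small usage note on `augmentation`: the image is randomly flipped, upscaled to `augment_size`, then randomly cropped back to its original shape, so the output size is unchanged. For a 256x256 input the callers pass 256 + 30 = 286:

```python
import tensorflow as tf

img = tf.zeros([256, 256, 3])              # hypothetical input image tensor
aug = augmentation(img, augment_size=286)  # still [256, 256, 3] after the crop
```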
52 | def save_images(images, size, image_path):
53 | return imsave(inverse_transform(images), size, image_path)
54 |
55 | def inverse_transform(images):
56 | return (images+1.) / 2
57 |
58 | def imsave(images, size, path):
59 | return misc.imsave(path, merge(images, size))
60 |
61 | def merge(images, size):
62 | h, w = images.shape[1], images.shape[2]
63 | img = np.zeros((h * size[0], w * size[1], 3))
64 | for idx, image in enumerate(images):
65 | i = idx % size[1]
66 | j = idx // size[1]
67 | img[h*j:h*(j+1), w*i:w*(i+1), :] = image
68 |
69 | return img
70 |
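`merge` tiles a batch into a `size[0]`-row by `size[1]`-column grid, so `save_images(images, [batch_size, 1], path)` stacks the batch vertically; a quick shape check:

```python
import numpy as np

batch = np.random.rand(4, 64, 64, 3)  # four 64x64 RGB images
grid = merge(batch, size=[2, 2])      # 2 rows x 2 columns
assert grid.shape == (128, 128, 3)
```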
71 | def show_all_variables():
72 | model_vars = tf.trainable_variables()
73 | slim.model_analyzer.analyze_vars(model_vars, print_info=True)
74 |
75 | def check_folder(log_dir):
76 | if not os.path.exists(log_dir):
77 | os.makedirs(log_dir)
78 | return log_dir
79 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Junho Kim (1993.01.12)
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # UNIT-Tensorflow
2 | Simple Tensorflow implementation of ["Unsupervised Image to Image Translation Networks"](https://arxiv.org/abs/1703.00848) (NIPS 2017 Spotlight)
3 |
4 | ## Requirements
5 | * Tensorflow 1.4
6 | * Python 3.6
7 |
8 | ## Usage
9 | ```bash
10 | ├── dataset
11 | └── YOUR_DATASET_NAME
12 | ├── trainA
13 | ├── xxx.jpg (name, format doesn't matter)
14 | ├── yyy.png
15 | └── ...
16 | ├── trainB
17 | ├── zzz.jpg
18 | ├── www.png
19 | └── ...
20 | ├── testA
21 | ├── aaa.jpg
22 | ├── bbb.png
23 | └── ...
24 | └── testB
25 | ├── ccc.jpg
26 | ├── ddd.png
27 | └── ...
28 | ```
29 |
30 | ```bash
31 | > python main.py --phase train --dataset cat2tiger
32 | ```
33 | * See `main.py` for other arguments
34 | * For the multi-GPU version, use `main_multi_gpu.py` (`batch_size` is the batch size per GPU)
35 | * For a faster UNIT, use the `DatasetAPI` version (the code is simpler!)
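For example, once training has finished, the same entry point runs translation on the test folders:

```bash
> python main.py --phase test --dataset cat2tiger
```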
36 |
37 | ## Issue
38 | ### Too slow!!!
39 | * Training is slow mainly because a checkpoint is saved at every iteration
40 | * To speed it up, do not save a checkpoint every iteration
41 |
42 | ## Architecture
43 | 
44 |
45 | ## Framework
46 | 
47 |
48 | ## Model
49 | 
50 |
51 | 
52 |
53 | 
54 |
55 | 
56 |
57 | ## Training Objective
58 | 
59 |
60 | ## Result
61 | ### Success
62 | 
63 |
64 | ### Fail
65 | 
66 |
67 | ## Related works
68 | * [CycleGAN-Tensorflow](https://github.com/taki0112/CycleGAN-Tensorflow)
69 | * [DiscoGAN-Tensorflow](https://github.com/taki0112/DiscoGAN-Tensorflow)
70 | * [MUNIT-Tensorflow](https://github.com/taki0112/MUNIT-Tensorflow)
71 | * [StarGAN-Tensorflow](https://github.com/taki0112/StarGAN-Tensorflow)
72 | * [DRIT-Tensorflow](https://github.com/taki0112/DRIT-Tensorflow)
73 |
74 | ## Reference
75 | * [UNIT-Pytorch](https://github.com/mingyuliutw/UNIT)
76 | * [Multi-GPU-Tensorflow](https://github.com/golbin/TensorFlow-Multi-GPUs)
77 | * [DatasetAPI-Tensorflow](https://github.com/taki0112/Tensorflow-DatasetAPI)
78 |
79 | ## Author
80 | Junho Kim
81 |
--------------------------------------------------------------------------------
/UNIT.py:
--------------------------------------------------------------------------------
1 | from ops import *
2 | from utils import *
3 | from glob import glob
4 | import time
5 |
6 | class UNIT(object):
7 | def __init__(self, sess, args):
8 | self.model_name = 'UNIT'
9 | self.sess = sess
10 | self.checkpoint_dir = args.checkpoint_dir
11 | self.result_dir = args.result_dir
12 | self.log_dir = args.log_dir
13 | self.sample_dir = args.sample_dir
14 | self.dataset_name = args.dataset
15 |
16 | self.epoch = args.epoch # 100000
17 | self.batch_size = args.batch_size # 1
18 |
19 | self.lr = args.lr # 0.0001
20 | """ Weight about VAE """
21 | self.KL_weight = args.KL_weight # lambda 1
22 | self.L1_weight = args.L1_weight # lambda 2
23 |
24 | """ Weight about VAE Cycle"""
25 | self.KL_cycle_weight = args.KL_cycle_weight # lambda 3
26 | self.L1_cycle_weight = args.L1_cycle_weight # lambda 4
27 |
28 | """ Weight about GAN """
29 | self.GAN_weight = args.GAN_weight # lambda 0
30 |
31 |
32 | """ Encoder """
33 | self.ch = args.ch # base channel number per layer
34 | self.n_encoder = args.n_encoder
35 | self.n_enc_resblock = args.n_enc_resblock
36 | self.n_enc_share = args.n_enc_share
37 |
38 | """ Generator """
39 | self.n_gen_share = args.n_gen_share
40 | self.n_gen_resblock = args.n_gen_resblock
41 | self.n_gen_decoder = args.n_gen_decoder
42 |
43 | """ Discriminator """
44 | self.n_dis = args.n_dis # + 2
45 |
46 | self.res_dropout = args.res_dropout
47 | self.smoothing = args.smoothing
48 | self.lsgan = args.lsgan
49 | self.norm = args.norm
50 | self.replay_memory = args.replay_memory
51 | self.pool_size = args.pool_size
52 | self.img_size = args.img_size
53 | self.channel = args.img_ch
54 | self.augment_flag = args.augment_flag
55 | self.augment_size = self.img_size + (30 if self.img_size == 256 else 15)
56 | self.normal_weight_init = args.normal_weight_init
57 |
58 | self.trainA, self.trainB = prepare_data(dataset_name=self.dataset_name, size=self.img_size)
59 | self.num_batches = max(len(self.trainA) // self.batch_size, len(self.trainB) // self.batch_size)
60 |
61 | ##############################################################################
62 | # BEGIN of ENCODERS
63 | def encoder(self, x, is_training=True, reuse=False, scope="encoder"):
64 | channel = self.ch
65 | with tf.variable_scope(scope, reuse=reuse) :
66 | x = conv(x, channel, kernel=7, stride=1, pad=3, normal_weight_init=self.normal_weight_init, activation_fn='leaky', scope='conv_0')
67 |
68 | for i in range(1, self.n_encoder) :
69 | x = conv(x, channel*2, kernel=3, stride=2, pad=1, normal_weight_init=self.normal_weight_init, activation_fn='leaky', scope='conv_'+str(i))
70 | channel *= 2
71 |
72 | # channel = 256
73 | for i in range(0, self.n_enc_resblock) :
74 | x = resblock(x, channel, kernel=3, stride=1, pad=1, dropout_ratio=self.res_dropout,
75 | normal_weight_init=self.normal_weight_init,
76 | is_training=is_training, norm_fn=self.norm, scope='resblock_'+str(i))
77 |
78 | return x
79 | # END of ENCODERS
80 | ##############################################################################
81 |
82 | ##############################################################################
83 | # BEGIN of SHARED LAYERS
84 | # Shared residual-blocks
85 | def share_encoder(self, x, is_training=True, reuse=False, scope="share_encoder"):
86 | channel = self.ch * pow(2, self.n_encoder-1)
87 | with tf.variable_scope(scope, reuse=reuse) :
88 | for i in range(0, self.n_enc_share) :
89 | x = resblock(x, channel, kernel=3, stride=1, pad=1, dropout_ratio=self.res_dropout,
90 | normal_weight_init=self.normal_weight_init,
91 | is_training=is_training, norm_fn=self.norm, scope='resblock_'+str(i))
92 |
93 | x = gaussian_noise_layer(x)
94 |
95 | return x
96 |
97 | def share_generator(self, x, is_training=True, reuse=False, scope="share_generator"):
98 | channel = self.ch * pow(2, self.n_encoder-1)
99 | with tf.variable_scope(scope, reuse=reuse) :
100 | for i in range(0, self.n_gen_share) :
101 | x = resblock(x, channel, kernel=3, stride=1, pad=1, dropout_ratio=self.res_dropout,
102 | normal_weight_init=self.normal_weight_init,
103 | is_training=is_training, norm_fn=self.norm, scope='resblock_'+str(i))
104 |
105 | return x
106 | # END of SHARED LAYERS
107 | ##############################################################################
108 |
109 | ##############################################################################
110 | # BEGIN of DECODERS
111 | def generator(self, x, is_training=True, reuse=False, scope="generator"):
112 | channel = self.ch * pow(2, self.n_encoder - 1)
113 | with tf.variable_scope(scope, reuse=reuse) :
114 | for i in range(0, self.n_gen_resblock) :
115 | x = resblock(x, channel, kernel=3, stride=1, pad=1, dropout_ratio=self.res_dropout,
116 | normal_weight_init=self.normal_weight_init,
117 | is_training=is_training, norm_fn=self.norm, scope='resblock_'+str(i))
118 |
119 | for i in range(0, self.n_gen_decoder-1) :
120 | x = deconv(x, channel//2, kernel=3, stride=2, normal_weight_init=self.normal_weight_init, activation_fn='leaky', scope='deconv_'+str(i))
121 | channel = channel // 2
122 |
123 | x = deconv(x, self.channel, kernel=1, stride=1, normal_weight_init=self.normal_weight_init, activation_fn='tanh', scope='deconv_tanh')
124 |
125 | return x
126 | # END of DECODERS
127 | ##############################################################################
128 |
129 | ##############################################################################
130 | # BEGIN of DISCRIMINATORS
131 | def discriminator(self, x, reuse=False, scope="discriminator"):
132 | channel = self.ch
133 | with tf.variable_scope(scope, reuse=reuse):
134 | x = conv(x, channel, kernel=3, stride=2, pad=1, normal_weight_init=self.normal_weight_init, activation_fn='leaky', scope='conv_0')
135 |
136 | for i in range(1, self.n_dis) :
137 | x = conv(x, channel*2, kernel=3, stride=2, pad=1, normal_weight_init=self.normal_weight_init, activation_fn='leaky', scope='conv_'+str(i))
138 | channel *= 2
139 |
140 | x = conv(x, channels=1, kernel=1, stride=1, pad=0, normal_weight_init=self.normal_weight_init, activation_fn=None, scope='dis_logit')
141 |
142 | return x
143 | # END of DISCRIMINATORS
144 | ##############################################################################
145 |
146 | def translation(self, x_A, x_B):
147 | out = tf.concat([self.encoder(x_A, self.is_training, scope="encoder_A"), self.encoder(x_B, self.is_training, scope="encoder_B")], axis=0)
148 | shared = self.share_encoder(out, self.is_training)
149 | out = self.share_generator(shared, self.is_training)
150 |
151 | out_A = self.generator(out, self.is_training, scope="generator_A")
152 | out_B = self.generator(out, self.is_training, scope="generator_B")
153 |
154 | x_Aa, x_Ba = tf.split(out_A, 2, axis=0)
155 | x_Ab, x_Bb = tf.split(out_B, 2, axis=0)
156 |
157 | return x_Aa, x_Ba, x_Ab, x_Bb, shared
158 |
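A note on the concat/split trick in `translation` above: stacking the two domains along the batch axis lets a single pass through the shared layers serve both, and the splits recover the four outputs, `x_Aa` (A reconstructed as A), `x_Ba` (B translated to A), `x_Ab` (A translated to B), and `x_Bb` (B reconstructed as B). A shape-only sketch with hypothetical feature sizes:

```python
import tensorflow as tf

feat_A = tf.zeros([1, 64, 64, 256])        # encoded domain-A features
feat_B = tf.zeros([1, 64, 64, 256])        # encoded domain-B features
both = tf.concat([feat_A, feat_B], axis=0) # [2, 64, 64, 256]
a_out, b_out = tf.split(both, 2, axis=0)   # two [1, 64, 64, 256] tensors
```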
159 | def generate_a2b(self, x_A):
160 | out = self.encoder(x_A, self.is_training, reuse=True, scope="encoder_A")
161 | shared = self.share_encoder(out, self.is_training, reuse=True)
162 | out = self.share_generator(shared, self.is_training, reuse=True)
163 | out = self.generator(out, self.is_training, reuse=True, scope="generator_B")
164 |
165 | return out, shared
166 |
167 | def generate_b2a(self, x_B):
168 | out = self.encoder(x_B, self.is_training, reuse=True, scope="encoder_B")
169 | shared = self.share_encoder(out, self.is_training, reuse=True)
170 | out = self.share_generator(shared, self.is_training, reuse=True)
171 | out = self.generator(out, self.is_training, reuse=True, scope="generator_A")
172 |
173 | return out, shared
174 |
175 | def discriminate_real(self, x_A, x_B):
176 | real_A_logit = self.discriminator(x_A, scope="discriminator_A")
177 | real_B_logit = self.discriminator(x_B, scope="discriminator_B")
178 |
179 | return real_A_logit, real_B_logit
180 |
181 | def discriminate_fake(self, x_ba, x_ab):
182 | fake_A_logit = self.discriminator(x_ba, reuse=True, scope="discriminator_A")
183 | fake_B_logit = self.discriminator(x_ab, reuse=True, scope="discriminator_B")
184 |
185 | return fake_A_logit, fake_B_logit
186 |
187 | def discriminate_fake_pool(self, x_ba, x_ab):
188 | fake_A_pool_logit = self.discriminator(self.fake_A_pool.query(x_ba), reuse=True, scope="discriminator_A") # replay memory
189 | fake_B_pool_logit = self.discriminator(self.fake_B_pool.query(x_ab), reuse=True, scope="discriminator_B") # replay memory
190 |
191 | return fake_A_pool_logit, fake_B_pool_logit
192 |
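`ImagePool` is defined in the top-level `utils.py` (not shown in this file) and implements the replay-memory trick familiar from CycleGAN: roughly half the time the discriminator is shown a previously generated image instead of the freshest one, which helps stabilize training. A minimal sketch of the idea, not the repo's exact implementation:

```python
import random

class SimpleImagePool:
    """Keeps up to pool_size past fakes; query() returns either the new
    image or a random stored one (swapping the new one in), 50/50."""
    def __init__(self, pool_size=50):
        self.pool_size = pool_size
        self.images = []

    def query(self, image):
        if len(self.images) < self.pool_size:
            self.images.append(image)
            return image
        if random.random() > 0.5:
            idx = random.randrange(self.pool_size)
            old = self.images[idx]
            self.images[idx] = image
            return old
        return image
```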
193 | def build_model(self):
194 | self.is_training = tf.placeholder(tf.bool)
195 | self.prob = tf.placeholder(tf.float32)
196 | self.condition = tf.logical_and(tf.greater(self.prob, tf.constant(0.5)), self.is_training)
197 |
198 | """ Input Image"""
199 | domain_A = self.domain_A = tf.placeholder(tf.float32, [self.batch_size, self.img_size, self.img_size, self.channel], name='domain_A') # real A
200 | domain_B = self.domain_B = tf.placeholder(tf.float32, [self.batch_size, self.img_size, self.img_size, self.channel], name='domain_B') # real B
201 |
202 | if self.augment_flag :
203 | """ Augmentation """
204 | domain_A = tf.cond(
205 | self.condition,
206 | lambda : augmentation(domain_A, self.augment_size),
207 | lambda : domain_A
208 | )
209 |
210 | domain_B = tf.cond(
211 | self.condition,
212 | lambda : augmentation(domain_B, self.augment_size),
213 | lambda : domain_B
214 | )
215 |
216 |
217 | """ Define Encoder, Generator, Discriminator """
218 | x_aa, x_ba, x_ab, x_bb, shared = self.translation(domain_A, domain_B)
219 | x_bab, shared_bab = self.generate_a2b(x_ba)
220 | x_aba, shared_aba = self.generate_b2a(x_ab)
221 |
222 | real_A_logit, real_B_logit = self.discriminate_real(domain_A, domain_B)
223 |
224 | if self.replay_memory :
225 | self.fake_A_pool = ImagePool(self.pool_size) # pool of generated A
226 | self.fake_B_pool = ImagePool(self.pool_size) # pool of generated B
227 | fake_A_logit, fake_B_logit = self.discriminate_fake_pool(x_ba, x_ab)
228 | else :
229 | fake_A_logit, fake_B_logit = self.discriminate_fake(x_ba, x_ab)
230 |
231 |
232 |
233 | """ Define Loss """
234 | G_ad_loss_a = generator_loss(fake_A_logit, smoothing=self.smoothing, use_lsgan=self.lsgan)
235 | G_ad_loss_b = generator_loss(fake_B_logit, smoothing=self.smoothing, use_lsgan=self.lsgan)
236 |
237 | D_ad_loss_a = discriminator_loss(real_A_logit, fake_A_logit, smoothing=self.smoothing, use_lasgan=self.lsgan)
238 | D_ad_loss_b = discriminator_loss(real_B_logit, fake_B_logit, smoothing=self.smoothing, use_lasgan=self.lsgan)
239 |
240 | enc_loss = KL_divergence(shared)
241 | enc_bab_loss = KL_divergence(shared_bab)
242 | enc_aba_loss = KL_divergence(shared_aba)
243 |
244 | l1_loss_a = L1_loss(x_aa, domain_A) # identity
245 | l1_loss_b = L1_loss(x_bb, domain_B) # identity
246 | l1_loss_aba = L1_loss(x_aba, domain_A) # reconstruction
247 | l1_loss_bab = L1_loss(x_bab, domain_B) # reconstruction
248 |
249 | Generator_A_loss = self.GAN_weight * G_ad_loss_a + \
250 | self.L1_weight * l1_loss_a + \
251 | self.L1_cycle_weight * l1_loss_aba + \
252 | self.KL_weight * enc_loss + \
253 | self.KL_cycle_weight * enc_bab_loss
254 |
255 | Generator_B_loss = self.GAN_weight * G_ad_loss_b + \
256 | self.L1_weight * l1_loss_b + \
257 | self.L1_cycle_weight * l1_loss_bab + \
258 | self.KL_weight * enc_loss + \
259 | self.KL_cycle_weight * enc_aba_loss
260 |
261 | Discriminator_A_loss = self.GAN_weight * D_ad_loss_a
262 | Discriminator_B_loss = self.GAN_weight * D_ad_loss_b
263 |
264 | self.Generator_loss = Generator_A_loss + Generator_B_loss
265 | self.Discriminator_loss = Discriminator_A_loss + Discriminator_B_loss
266 |
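For reference, these weighted sums implement the UNIT training objective with the lambda numbering from the `__init__` comments; note the cross-pairing, where generator A's cycle L1 term uses the a→b→a reconstruction while its cycle KL term uses the b→a→b latent:

$$\mathcal{L}_{G_A} = \lambda_0\,\mathcal{L}_{GAN}^{a} + \lambda_1\,\mathcal{L}_{KL} + \lambda_2\,\mathcal{L}_{L1}^{aa} + \lambda_3\,\mathcal{L}_{KL}^{bab} + \lambda_4\,\mathcal{L}_{L1}^{aba}$$

and symmetrically for generator B with the a/b roles swapped.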
267 |
268 | """ Training """
269 | t_vars = tf.trainable_variables()
270 | G_vars = [var for var in t_vars if ('generator' in var.name) or ('encoder' in var.name)]
271 | D_vars = [var for var in t_vars if 'discriminator' in var.name]
272 |
273 |
274 | # with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
275 | self.G_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.Generator_loss, var_list=G_vars)
276 | self.D_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.Discriminator_loss, var_list=D_vars)
277 |
278 | """" Summary """
279 | self.all_G_loss = tf.summary.scalar("Generator_loss", self.Generator_loss)
280 | self.all_D_loss = tf.summary.scalar("Discriminator_loss", self.Discriminator_loss)
281 | self.G_A_loss = tf.summary.scalar("G_A_loss", Generator_A_loss)
282 | self.G_B_loss = tf.summary.scalar("G_B_loss", Generator_B_loss)
283 | self.D_A_loss = tf.summary.scalar("D_A_loss", Discriminator_A_loss)
284 | self.D_B_loss = tf.summary.scalar("D_B_loss", Discriminator_B_loss)
285 |
286 | self.G_loss = tf.summary.merge([self.G_A_loss, self.G_B_loss, self.all_G_loss])
287 | self.D_loss = tf.summary.merge([self.D_A_loss, self.D_B_loss, self.all_D_loss])
288 |
289 | """ Generated Image """
290 | self.fake_B, _ = self.generate_a2b(domain_A) # for test
291 | self.fake_A, _ = self.generate_b2a(domain_B) # for test
292 |
293 | def train(self):
294 | # initialize all variables
295 | tf.global_variables_initializer().run()
296 |
297 | # saver to save model
298 | self.saver = tf.train.Saver()
299 |
300 | # summary writer
301 | self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_name, self.sess.graph)
302 |
303 |
304 | # restore checkpoint if it exists
305 | could_load, checkpoint_counter = self.load(self.checkpoint_dir)
306 | if could_load:
307 | start_epoch = checkpoint_counter // self.num_batches
308 | start_batch_id = checkpoint_counter - start_epoch * self.num_batches
309 | counter = checkpoint_counter
310 | print(" [*] Load SUCCESS")
311 | else:
312 | start_epoch = 0
313 | start_batch_id = 0
314 | counter = 1
315 | print(" [!] Load failed...")
316 |
317 | # loop for epoch
318 | start_time = time.time()
319 | for epoch in range(start_epoch, self.epoch):
320 | # get batch data
321 | for idx in range(start_batch_id, self.num_batches):
322 | random_index_A = np.random.choice(len(self.trainA), size=self.batch_size, replace=False)
323 | random_index_B = np.random.choice(len(self.trainB), size=self.batch_size, replace=False)
324 | batch_A_images = self.trainA[random_index_A]
325 | batch_B_images = self.trainB[random_index_B]
326 | p = np.random.uniform(low=0.0, high=1.0)
327 |
328 |
329 | train_feed_dict = {
330 | self.domain_A : batch_A_images,
331 | self.domain_B : batch_B_images,
332 | self.prob : p,
333 | self.is_training : True
334 | }
335 |
336 | # Update D
337 | _, d_loss, summary_str = self.sess.run([self.D_optim, self.Discriminator_loss, self.D_loss], feed_dict = train_feed_dict)
338 | self.writer.add_summary(summary_str, counter)
339 |
340 | # Update G
341 | fake_A, fake_B, _, g_loss, summary_str = self.sess.run([self.fake_A, self.fake_B, self.G_optim, self.Generator_loss, self.G_loss], feed_dict = train_feed_dict)
342 | self.writer.add_summary(summary_str, counter)
343 |
344 | # display training status
345 | counter += 1
346 | print("Epoch: [%2d] [%4d/%4d] time: %4.4f d_loss: %.8f, g_loss: %.8f" \
347 | % (epoch, idx, self.num_batches, time.time() - start_time, d_loss, g_loss))
348 |
349 | if np.mod(counter, 100) == 0 :
350 | save_images(batch_A_images, [self.batch_size, 1],
351 | './{}/real_A_{:02d}_{:04d}.jpg'.format(self.sample_dir, epoch, idx+2))
352 | save_images(batch_B_images, [self.batch_size, 1],
353 | './{}/real_B_{:02d}_{:04d}.jpg'.format(self.sample_dir, epoch, idx+2))
354 |
355 | save_images(fake_A, [self.batch_size, 1],
356 | './{}/fake_A_{:02d}_{:04d}.jpg'.format(self.sample_dir, epoch, idx+2))
357 | save_images(fake_B, [self.batch_size, 1],
358 | './{}/fake_B_{:02d}_{:04d}.jpg'.format(self.sample_dir, epoch, idx+2))
359 |
360 | # After an epoch, start_batch_id is set to zero
361 | # non-zero value is only for the first epoch after loading pre-trained model
362 | start_batch_id = 0
363 |
364 | # save model
365 | self.save(self.checkpoint_dir, counter)
366 |
367 | # save model for final step
368 | self.save(self.checkpoint_dir, counter)
369 |
370 |
371 | @property
372 | def model_dir(self):
373 | return "{}_{}_{}".format(
374 | self.model_name, self.dataset_name, self.norm)
375 |
376 | def save(self, checkpoint_dir, step):
377 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)
378 |
379 | if not os.path.exists(checkpoint_dir):
380 | os.makedirs(checkpoint_dir)
381 |
382 | self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name + '.model'), global_step=step)
383 |
384 | def load(self, checkpoint_dir):
385 | import re
386 | print(" [*] Reading checkpoints...")
387 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)
388 |
389 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
390 | if ckpt and ckpt.model_checkpoint_path:
391 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
392 | self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
393 | counter = int(next(re.finditer(r"(\d+)(?!.*\d)", ckpt_name)).group(0))
394 | print(" [*] Successfully read {}".format(ckpt_name))
395 | return True, counter
396 | else:
397 | print(" [*] Failed to find a checkpoint")
398 | return False, 0
399 |
400 | def test(self):
401 | tf.global_variables_initializer().run()
402 | test_A_files = glob('./dataset/{}/*.*'.format(self.dataset_name + '/testA'))
403 | test_B_files = glob('./dataset/{}/*.*'.format(self.dataset_name + '/testB'))
404 |
405 | """
406 | testA, testB = test_data(dataset_name=self.dataset_name, size=self.img_size)
407 | test_A_images = testA[:]
408 | test_B_images = testB[:]
409 | """
410 | self.saver = tf.train.Saver()
411 | could_load, checkpoint_counter = self.load(self.checkpoint_dir)
412 |
413 | if could_load :
414 | print(" [*] Load SUCCESS")
415 | else :
416 | print(" [!] Load failed...")
417 |
418 | # write html for visual comparison
419 | index_path = os.path.join(self.result_dir, 'index.html')
420 | index = open(index_path, 'w')
421 | index.write("<html><body><table><tr>")
422 | index.write("<th>name</th><th>input</th><th>output</th></tr>")
423 |
424 | for sample_file in test_A_files : # A -> B
425 | print('Processing A image: ' + sample_file)
426 | sample_image = np.asarray(load_test_data(sample_file))
427 | image_path = os.path.join(self.result_dir,'{0}'.format(os.path.basename(sample_file)))
428 |
429 | fake_img = self.sess.run(self.fake_B, feed_dict = {self.domain_A : sample_image, self.prob : 0.0, self.is_training : False})
430 |
431 | save_images(fake_img, [1, 1], image_path)
432 | index.write("<td>%s</td>" % os.path.basename(image_path))
433 | index.write("<td><img src='%s' width='%d' height='%d'></td>" % (sample_file if os.path.isabs(sample_file) else (
434 | '..' + os.path.sep + sample_file), self.img_size, self.img_size))
435 | index.write("<td><img src='%s' width='%d' height='%d'></td>" % (image_path if os.path.isabs(image_path) else (
436 | '..' + os.path.sep + image_path), self.img_size, self.img_size))
437 | index.write("</tr>")
438 |
439 | for sample_file in test_B_files : # B -> A
440 | print('Processing B image: ' + sample_file)
441 | sample_image = np.asarray(load_test_data(sample_file))
442 | image_path = os.path.join(self.result_dir,'{0}'.format(os.path.basename(sample_file)))
443 |
444 | fake_img = self.sess.run(self.fake_A, feed_dict = {self.domain_B : sample_image, self.prob : 0.0, self.is_training : False})
445 |
446 | save_images(fake_img, [1, 1], image_path)
447 | index.write("<td>%s</td>" % os.path.basename(image_path))
448 | index.write("<td><img src='%s' width='%d' height='%d'></td>" % (sample_file if os.path.isabs(sample_file) else (
449 | '..' + os.path.sep + sample_file), self.img_size, self.img_size))
450 | index.write("<td><img src='%s' width='%d' height='%d'></td>" % (image_path if os.path.isabs(image_path) else (
451 | '..' + os.path.sep + image_path), self.img_size, self.img_size))
452 | index.write("</tr>")
453 | index.close()
--------------------------------------------------------------------------------
/UNIT_multi_gpu.py:
--------------------------------------------------------------------------------
1 | from ops import *
2 | from utils import *
3 | from glob import glob
4 | import time
5 |
6 | class UNIT(object):
7 | def __init__(self, sess, args):
8 | self.model_name = 'UNIT'
9 | self.sess = sess
10 | self.checkpoint_dir = args.checkpoint_dir
11 | self.result_dir = args.result_dir
12 | self.log_dir = args.log_dir
13 | self.sample_dir = args.sample_dir
14 | self.dataset_name = args.dataset
15 |
16 | self.epoch = args.epoch # 100000
17 | self.batch_size_per_gpu = args.batch_size
18 | self.batch_size = args.batch_size * args.gpu_num
19 | self.gpu_num = args.gpu_num
20 |
21 | self.lr = args.lr # 0.0001
22 | """ Weight about VAE """
23 | self.KL_weight = args.KL_weight # lambda 1
24 | self.L1_weight = args.L1_weight # lambda 2
25 |
26 | """ Weight about VAE Cycle"""
27 | self.KL_cycle_weight = args.KL_cycle_weight # lambda 3
28 | self.L1_cycle_weight = args.L1_cycle_weight # lambda 4
29 |
30 | """ Weight about GAN """
31 | self.GAN_weight = args.GAN_weight # lambda 0
32 |
33 |
34 | """ Encoder """
35 | self.ch = args.ch # base channel number per layer
36 | self.n_encoder = args.n_encoder
37 | self.n_enc_resblock = args.n_enc_resblock
38 | self.n_enc_share = args.n_enc_share
39 |
40 | """ Generator """
41 | self.n_gen_share = args.n_gen_share
42 | self.n_gen_resblock = args.n_gen_resblock
43 | self.n_gen_decoder = args.n_gen_decoder
44 |
45 | """ Discriminator """
46 | self.n_dis = args.n_dis # + 2
47 |
48 | self.res_dropout = args.res_dropout
49 | self.smoothing = args.smoothing
50 | self.lsgan = args.lsgan
51 | self.norm = args.norm
52 | self.replay_memory = args.replay_memory
53 | self.pool_size = args.pool_size
54 | self.img_size = args.img_size
55 | self.channel = args.img_ch
56 | self.augment_flag = args.augment_flag
57 | self.augment_size = self.img_size + (30 if self.img_size == 256 else 15)
58 | self.normal_weight_init = args.normal_weight_init
59 |
60 | self.trainA, self.trainB = prepare_data(dataset_name=self.dataset_name, size=self.img_size)
61 | self.num_batches = max(len(self.trainA) // self.batch_size, len(self.trainB) // self.batch_size)
62 |
63 | ##############################################################################
64 | # BEGIN of ENCODERS
65 | def encoder(self, x, is_training=True, reuse=False, scope="encoder"):
66 | channel = self.ch
67 | with tf.variable_scope(scope, reuse=reuse) :
68 | x = conv(x, channel, kernel=7, stride=1, pad=3, normal_weight_init=self.normal_weight_init, activation_fn='leaky', scope='conv_0')
69 |
70 | for i in range(1, self.n_encoder) :
71 | x = conv(x, channel*2, kernel=3, stride=2, pad=1, normal_weight_init=self.normal_weight_init, activation_fn='leaky', scope='conv_'+str(i))
72 | channel *= 2
73 |
74 | # channel = 256
75 | for i in range(0, self.n_enc_resblock) :
76 | x = resblock(x, channel, kernel=3, stride=1, pad=1, dropout_ratio=self.res_dropout,
77 | normal_weight_init=self.normal_weight_init,
78 | is_training=is_training, norm_fn=self.norm, scope='resblock_'+str(i))
79 |
80 | return x
81 | # END of ENCODERS
82 | ##############################################################################
83 |
84 | ##############################################################################
85 | # BEGIN of SHARED LAYERS
86 | # Shared residual-blocks
87 | def share_encoder(self, x, is_training=True, reuse=False, scope="share_encoder"):
88 | channel = self.ch * pow(2, self.n_encoder-1)
89 | with tf.variable_scope(scope, reuse=reuse) :
90 | for i in range(0, self.n_enc_share) :
91 | x = resblock(x, channel, kernel=3, stride=1, pad=1, dropout_ratio=self.res_dropout,
92 | normal_weight_init=self.normal_weight_init,
93 | is_training=is_training, norm_fn=self.norm, scope='resblock_'+str(i))
94 |
95 | x = gaussian_noise_layer(x)
96 |
97 | return x
98 |
99 | def share_generator(self, x, is_training=True, reuse=False, scope="share_generator"):
100 | channel = self.ch * pow(2, self.n_encoder-1)
101 | with tf.variable_scope(scope, reuse=reuse) :
102 | for i in range(0, self.n_gen_share) :
103 | x = resblock(x, channel, kernel=3, stride=1, pad=1, dropout_ratio=self.res_dropout,
104 | normal_weight_init=self.normal_weight_init,
105 | is_training=is_training, norm_fn=self.norm, scope='resblock_'+str(i))
106 |
107 | return x
108 | # END of SHARED LAYERS
109 | ##############################################################################
110 |
111 | ##############################################################################
112 | # BEGIN of DECODERS
113 | def generator(self, x, is_training=True, reuse=False, scope="generator"):
114 | channel = self.ch * pow(2, self.n_encoder - 1)
115 | with tf.variable_scope(scope, reuse=reuse) :
116 | for i in range(0, self.n_gen_resblock) :
117 | x = resblock(x, channel, kernel=3, stride=1, pad=1, dropout_ratio=self.res_dropout,
118 | normal_weight_init=self.normal_weight_init,
119 | is_training=is_training, norm_fn=self.norm, scope='resblock_'+str(i))
120 |
121 | for i in range(0, self.n_gen_decoder-1) :
122 | x = deconv(x, channel//2, kernel=3, stride=2, normal_weight_init=self.normal_weight_init, activation_fn='leaky', scope='deconv_'+str(i))
123 | channel = channel // 2
124 |
125 | x = deconv(x, self.channel, kernel=1, stride=1, normal_weight_init=self.normal_weight_init, activation_fn='tanh', scope='deconv_tanh')
126 |
127 | return x
128 | # END of DECODERS
129 | ##############################################################################
130 |
131 | ##############################################################################
132 | # BEGIN of DISCRIMINATORS
133 | def discriminator(self, x, reuse=False, scope="discriminator"):
134 | channel = self.ch
135 | with tf.variable_scope(scope, reuse=reuse):
136 | x = conv(x, channel, kernel=3, stride=2, pad=1, normal_weight_init=self.normal_weight_init, activation_fn='leaky', scope='conv_0')
137 |
138 | for i in range(1, self.n_dis) :
139 | x = conv(x, channel*2, kernel=3, stride=2, pad=1, normal_weight_init=self.normal_weight_init, activation_fn='leaky', scope='conv_'+str(i))
140 | channel *= 2
141 |
142 | x = conv(x, channels=1, kernel=1, stride=1, pad=0, normal_weight_init=self.normal_weight_init, activation_fn=None, scope='dis_logit')
143 |
144 | return x
145 | # END of DISCRIMINATORS
146 | ##############################################################################
147 |
148 | def translation(self, x_A, x_B):
149 | out = tf.concat([self.encoder(x_A, self.is_training, scope="encoder_A"), self.encoder(x_B, self.is_training, scope="encoder_B")], axis=0)
150 | shared = self.share_encoder(out, self.is_training)
151 | out = self.share_generator(shared, self.is_training)
152 |
153 | out_A = self.generator(out, self.is_training, scope="generator_A")
154 | out_B = self.generator(out, self.is_training, scope="generator_B")
155 |
156 | x_Aa, x_Ba = tf.split(out_A, 2, axis=0)
157 | x_Ab, x_Bb = tf.split(out_B, 2, axis=0)
158 |
159 | return x_Aa, x_Ba, x_Ab, x_Bb, shared
160 |
161 | def generate_a2b(self, x_A):
162 | out = self.encoder(x_A, self.is_training, reuse=True, scope="encoder_A")
163 | shared = self.share_encoder(out, self.is_training, reuse=True)
164 | out = self.share_generator(shared, self.is_training, reuse=True)
165 | out = self.generator(out, self.is_training, reuse=True, scope="generator_B")
166 |
167 | return out, shared
168 |
169 | def generate_b2a(self, x_B):
170 | out = self.encoder(x_B, self.is_training, reuse=True, scope="encoder_B")
171 | shared = self.share_encoder(out, self.is_training, reuse=True)
172 | out = self.share_generator(shared, self.is_training, reuse=True)
173 | out = self.generator(out, self.is_training, reuse=True, scope="generator_A")
174 |
175 | return out, shared
176 |
177 | def discriminate_real(self, x_A, x_B):
178 | real_A_logit = self.discriminator(x_A, scope="discriminator_A")
179 | real_B_logit = self.discriminator(x_B, scope="discriminator_B")
180 |
181 | return real_A_logit, real_B_logit
182 |
183 | def discriminate_fake(self, x_ba, x_ab):
184 | fake_A_logit = self.discriminator(x_ba, reuse=True, scope="discriminator_A")
185 | fake_B_logit = self.discriminator(x_ab, reuse=True, scope="discriminator_B")
186 |
187 | return fake_A_logit, fake_B_logit
188 |
189 | def discriminate_fake_pool(self, x_ba, x_ab):
190 | fake_A_pool_logit = self.discriminator(self.fake_A_pool.query(x_ba), reuse=True, scope="discriminator_A") # replay memory
191 | fake_B_pool_logit = self.discriminator(self.fake_B_pool.query(x_ab), reuse=True, scope="discriminator_B") # replay memory
192 |
193 | return fake_A_pool_logit, fake_B_pool_logit
194 |
195 | def build_model(self):
196 | self.is_training = tf.placeholder(tf.bool)
197 | self.prob = tf.placeholder(tf.float32)
198 | self.condition = tf.logical_and(tf.greater(self.prob, tf.constant(0.5)), self.is_training)
199 |
200 | """ Input Image"""
201 | domain_A = self.domain_A = tf.placeholder(tf.float32, [self.batch_size, self.img_size, self.img_size, self.channel], name='domain_A') # real A
202 | domain_B = self.domain_B = tf.placeholder(tf.float32, [self.batch_size, self.img_size, self.img_size, self.channel], name='domain_B') # real B
203 |
204 | self.test_domain_A = tf.placeholder(tf.float32, [1, self.img_size, self.img_size, self.channel], name='test_domain_A')
205 | self.test_domain_B = tf.placeholder(tf.float32, [1, self.img_size, self.img_size, self.channel], name='test_domain_B')
206 |
207 | if self.augment_flag :
208 | """ Augmentation """
209 | domain_A = tf.cond(
210 | self.condition,
211 | lambda : augmentation(domain_A, self.augment_size),
212 | lambda : domain_A
213 | )
214 |
215 | domain_B = tf.cond(
216 | self.condition,
217 | lambda : augmentation(domain_B, self.augment_size),
218 | lambda : domain_B
219 | )
220 |
221 | domain_A = tf.split(domain_A, self.gpu_num)
222 | domain_B = tf.split(domain_B, self.gpu_num)
223 |
224 | G_A_losses= []
225 | G_B_losses = []
226 | D_A_losses = []
227 | D_B_losses = []
228 |
229 | G_losses = []
230 | D_losses = []
231 |
232 | self.fake_A = []
233 | self.fake_B = []
234 | for gpu_id in range(self.gpu_num) :
235 | with tf.device(tf.DeviceSpec(device_type="GPU", device_index=gpu_id)) :
236 | with tf.variable_scope(tf.get_variable_scope(), reuse=(gpu_id > 0)) :
237 | """ Define Encoder, Generator, Discriminator """
238 | x_aa, x_ba, x_ab, x_bb, shared = self.translation(domain_A[gpu_id], domain_B[gpu_id])
239 | x_bab, shared_bab = self.generate_a2b(x_ba)
240 | x_aba, shared_aba = self.generate_b2a(x_ab)
241 |
242 | real_A_logit, real_B_logit = self.discriminate_real(domain_A[gpu_id], domain_B[gpu_id])
243 |
244 | if self.replay_memory :
245 | self.fake_A_pool = ImagePool(self.pool_size) # pool of generated A
246 | self.fake_B_pool = ImagePool(self.pool_size) # pool of generated B
247 | fake_A_logit, fake_B_logit = self.discriminate_fake_pool(x_ba, x_ab)
248 | else :
249 | fake_A_logit, fake_B_logit = self.discriminate_fake(x_ba, x_ab)
250 |
251 |
252 |
253 | """ Define Loss """
254 | G_ad_loss_a = generator_loss(fake_A_logit, smoothing=self.smoothing, use_lsgan=self.lsgan)
255 | G_ad_loss_b = generator_loss(fake_B_logit, smoothing=self.smoothing, use_lsgan=self.lsgan)
256 |
257 | D_ad_loss_a = discriminator_loss(real_A_logit, fake_A_logit, smoothing=self.smoothing, use_lasgan=self.lsgan)
258 | D_ad_loss_b = discriminator_loss(real_B_logit, fake_B_logit, smoothing=self.smoothing, use_lasgan=self.lsgan)
259 |
260 | enc_loss = KL_divergence(shared)
261 | enc_bab_loss = KL_divergence(shared_bab)
262 | enc_aba_loss = KL_divergence(shared_aba)
263 |
264 | l1_loss_a = L1_loss(x_aa, domain_A[gpu_id]) # identity
265 | l1_loss_b = L1_loss(x_bb, domain_B[gpu_id]) # identity
266 | l1_loss_aba = L1_loss(x_aba, domain_A[gpu_id]) # reconstruction
267 | l1_loss_bab = L1_loss(x_bab, domain_B[gpu_id]) # reconstruction
268 |
269 | Generator_A_loss_split = self.GAN_weight * G_ad_loss_a + \
270 | self.L1_weight * l1_loss_a + \
271 | self.L1_cycle_weight * l1_loss_aba + \
272 | self.KL_weight * enc_loss + \
273 | self.KL_cycle_weight * enc_bab_loss
274 |
275 | Generator_B_loss_split = self.GAN_weight * G_ad_loss_b + \
276 | self.L1_weight * l1_loss_b + \
277 | self.L1_cycle_weight * l1_loss_bab + \
278 | self.KL_weight * enc_loss + \
279 | self.KL_cycle_weight * enc_aba_loss
280 |
281 | Discriminator_A_loss_split = self.GAN_weight * D_ad_loss_a
282 | Discriminator_B_loss_split = self.GAN_weight * D_ad_loss_b
283 |
284 | Generator_loss_split = Generator_A_loss_split + Generator_B_loss_split
285 | Discriminator_loss_split = Discriminator_A_loss_split + Discriminator_B_loss_split
286 |
287 | """ Generated Image """
288 | fake_B, _ = self.generate_a2b(domain_A[gpu_id]) # for test
289 | fake_A, _ = self.generate_b2a(domain_B[gpu_id]) # for test
290 |
291 | G_A_losses.append(Generator_A_loss_split)
292 | G_B_losses.append(Generator_B_loss_split)
293 | D_A_losses.append(Discriminator_A_loss_split)
294 | D_B_losses.append(Discriminator_B_loss_split)
295 |
296 | G_losses.append(Generator_loss_split)
297 | D_losses.append(Discriminator_loss_split)
298 |
299 | self.fake_A.append(fake_A)
300 | self.fake_B.append(fake_B)
301 |
302 | Generator_A_loss = tf.reduce_mean(G_A_losses)
303 | Generator_B_loss = tf.reduce_mean(G_B_losses)
304 | Discriminator_A_loss = tf.reduce_mean(D_A_losses)
305 | Discriminator_B_loss = tf.reduce_mean(D_B_losses)
306 |
307 | self.Generator_loss = tf.reduce_mean(G_losses)
308 | self.Discriminator_loss = tf.reduce_mean(D_losses)
309 |
310 | self.fake_A = tf.concat(self.fake_A, axis=0)
311 | self.fake_B = tf.concat(self.fake_B, axis=0)
312 |
313 | self.test_fake_B, _ = self.generate_a2b(self.test_domain_A)
314 | self.test_fake_A, _ = self.generate_b2a(self.test_domain_B)
315 |
316 | """ Training """
317 | t_vars = tf.trainable_variables()
318 | G_vars = [var for var in t_vars if ('generator' in var.name) or ('encoder' in var.name)]
319 | D_vars = [var for var in t_vars if 'discriminator' in var.name]
320 |
321 |
322 | # with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
323 | self.G_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.Generator_loss, colocate_gradients_with_ops=True, var_list=G_vars)
324 | self.D_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.Discriminator_loss, colocate_gradients_with_ops=True, var_list=D_vars)
325 | """" Summary """
326 | self.all_G_loss = tf.summary.scalar("Generator_loss", self.Generator_loss)
327 | self.all_D_loss = tf.summary.scalar("Discriminator_loss", self.Discriminator_loss)
328 | self.G_A_loss = tf.summary.scalar("G_A_loss", Generator_A_loss)
329 | self.G_B_loss = tf.summary.scalar("G_B_loss", Generator_B_loss)
330 | self.D_A_loss = tf.summary.scalar("D_A_loss", Discriminator_A_loss)
331 | self.D_B_loss = tf.summary.scalar("D_B_loss", Discriminator_B_loss)
332 |
333 | self.G_loss = tf.summary.merge([self.G_A_loss, self.G_B_loss, self.all_G_loss])
334 | self.D_loss = tf.summary.merge([self.D_A_loss, self.D_B_loss, self.all_D_loss])
335 |
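The per-GPU loop above is the classic in-graph data-parallel pattern: split the batch across devices, build identical towers with reused variables, average the per-tower losses, and let `colocate_gradients_with_ops=True` compute each gradient on the device that ran the corresponding forward op. A condensed sketch of the pattern, where `build_tower`, `splits`, and `gpu_num` are hypothetical stand-ins:

```python
import tensorflow as tf

def data_parallel_loss(splits, gpu_num, build_tower):
    # splits: list of per-GPU input batches; build_tower: fn returning a scalar loss
    losses = []
    for gpu_id in range(gpu_num):
        with tf.device('/gpu:%d' % gpu_id):
            # reuse variables from the second tower on, so all towers share weights
            with tf.variable_scope(tf.get_variable_scope(), reuse=(gpu_id > 0)):
                losses.append(build_tower(splits[gpu_id]))
    return tf.reduce_mean(losses)

# train_op = tf.train.AdamOptimizer(lr).minimize(
#     data_parallel_loss(...), colocate_gradients_with_ops=True)
```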
336 |
337 | def train(self):
338 | # initialize all variables
339 | tf.global_variables_initializer().run()
340 |
341 | # saver to save model
342 | self.saver = tf.train.Saver()
343 |
344 | # summary writer
345 | self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_name, self.sess.graph)
346 |
347 |
348 | # restore checkpoint if it exists
349 | could_load, checkpoint_counter = self.load(self.checkpoint_dir)
350 | if could_load:
351 | start_epoch = checkpoint_counter // self.num_batches
352 | start_batch_id = checkpoint_counter - start_epoch * self.num_batches
353 | counter = checkpoint_counter
354 | print(" [*] Load SUCCESS")
355 | else:
356 | start_epoch = 0
357 | start_batch_id = 0
358 | counter = 1
359 | print(" [!] Load failed...")
360 |
361 | # loop for epoch
362 | start_time = time.time()
363 | for epoch in range(start_epoch, self.epoch):
364 | # get batch data
365 | for idx in range(start_batch_id, self.num_batches):
366 | random_index_A = np.random.choice(len(self.trainA), size=self.batch_size, replace=False)
367 | random_index_B = np.random.choice(len(self.trainB), size=self.batch_size, replace=False)
368 | batch_A_images = self.trainA[random_index_A]
369 | batch_B_images = self.trainB[random_index_B]
370 | p = np.random.uniform(low=0.0, high=1.0)
371 |
372 |
373 | train_feed_dict = {
374 | self.domain_A : batch_A_images,
375 | self.domain_B : batch_B_images,
376 | self.prob : p,
377 | self.is_training : True
378 | }
379 |
380 | # Update D
381 | _, d_loss, summary_str = self.sess.run([self.D_optim, self.Discriminator_loss, self.D_loss], feed_dict = train_feed_dict)
382 | self.writer.add_summary(summary_str, counter)
383 |
384 | # Update G
385 | fake_A, fake_B, _, g_loss, summary_str = self.sess.run([self.fake_A, self.fake_B, self.G_optim, self.Generator_loss, self.G_loss], feed_dict = train_feed_dict)
386 | self.writer.add_summary(summary_str, counter)
387 |
388 | # display training status
389 | counter += 1
390 | print("Epoch: [%2d] [%4d/%4d] time: %4.4f d_loss: %.8f, g_loss: %.8f" \
391 | % (epoch, idx, self.num_batches, time.time() - start_time, d_loss, g_loss))
392 |
393 | if np.mod(counter, 10) == 0 :
394 | batch_A_images = np.split(batch_A_images, self.gpu_num)
395 | batch_B_images = np.split(batch_B_images, self.gpu_num)
396 | fake_A = np.split(fake_A, self.gpu_num)
397 | fake_B = np.split(fake_B, self.gpu_num)
398 |
399 | for gpu_id in range(self.gpu_num) :
400 | save_images(batch_A_images[gpu_id], [self.batch_size_per_gpu, 1],
401 | './{}/real_A_{}_{:02d}_{:04d}.jpg'.format(self.sample_dir, gpu_id, epoch, idx+2))
402 | save_images(batch_B_images[gpu_id], [self.batch_size_per_gpu, 1],
403 | './{}/real_B_{}_{:02d}_{:04d}.jpg'.format(self.sample_dir, gpu_id, epoch, idx+2))
404 |
405 | save_images(fake_A[gpu_id], [self.batch_size_per_gpu, 1],
406 | './{}/fake_A_{}_{:02d}_{:04d}.jpg'.format(self.sample_dir, gpu_id, epoch, idx+2))
407 | save_images(fake_B[gpu_id], [self.batch_size_per_gpu, 1],
408 | './{}/fake_B_{}_{:02d}_{:04d}.jpg'.format(self.sample_dir, gpu_id, epoch, idx+2))
409 |
410 | # After an epoch, start_batch_id is set to zero
411 | # non-zero value is only for the first epoch after loading pre-trained model
412 | start_batch_id = 0
413 |
414 | # save model
415 | self.save(self.checkpoint_dir, counter)
416 |
417 | # save model for final step
418 | self.save(self.checkpoint_dir, counter)
419 |
420 |
421 | @property
422 | def model_dir(self):
423 | return "{}_{}_{}".format(
424 | self.model_name, self.dataset_name, self.norm)
425 |
426 | def save(self, checkpoint_dir, step):
427 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)
428 |
429 | if not os.path.exists(checkpoint_dir):
430 | os.makedirs(checkpoint_dir)
431 |
432 | self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name + '.model'), global_step=step)
433 |
434 | def load(self, checkpoint_dir):
435 | import re
436 | print(" [*] Reading checkpoints...")
437 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)
438 |
439 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
440 | if ckpt and ckpt.model_checkpoint_path:
441 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
442 | self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
443 | counter = int(next(re.finditer(r"(\d+)(?!.*\d)", ckpt_name)).group(0))
444 | print(" [*] Successfully read {}".format(ckpt_name))
445 | return True, counter
446 | else:
447 | print(" [*] Failed to find a checkpoint")
448 | return False, 0
449 |
450 | def test(self):
451 | tf.global_variables_initializer().run()
452 | test_A_files = glob('./dataset/{}/*.*'.format(self.dataset_name + '/testA'))
453 | test_B_files = glob('./dataset/{}/*.*'.format(self.dataset_name + '/testB'))
454 |
455 | """
456 | testA, testB = test_data(dataset_name=self.dataset_name, size=self.img_size)
457 | test_A_images = testA[:]
458 | test_B_images = testB[:]
459 | """
460 | self.saver = tf.train.Saver()
461 | could_load, checkpoint_counter = self.load(self.checkpoint_dir)
462 |
463 | if could_load :
464 | print(" [*] Load SUCCESS")
465 | else :
466 | print(" [!] Load failed...")
467 |
468 | # write html for visual comparison
469 | index_path = os.path.join(self.result_dir, 'index.html')
470 | index = open(index_path, 'w')
471 | index.write("<html><body><table><tr>")
472 | index.write("<th>name</th><th>input</th><th>output</th></tr>")
473 |
474 | for sample_file in test_A_files : # A -> B
475 | print('Processing A image: ' + sample_file)
476 | sample_image = np.asarray(load_test_data(sample_file))
477 | image_path = os.path.join(self.result_dir,'{0}'.format(os.path.basename(sample_file)))
478 |
479 | fake_img = self.sess.run(self.test_fake_B, feed_dict = {self.test_domain_A : sample_image, self.is_training : False})
480 |
481 | save_images(fake_img, [1, 1], image_path)
482 | index.write("<td>%s</td>" % os.path.basename(image_path))
483 | index.write("<td><img src='%s' width='%d' height='%d'></td>" % (sample_file if os.path.isabs(sample_file) else (
484 | '..' + os.path.sep + sample_file), self.img_size, self.img_size))
485 | index.write("<td><img src='%s' width='%d' height='%d'></td>" % (image_path if os.path.isabs(image_path) else (
486 | '..' + os.path.sep + image_path), self.img_size, self.img_size))
487 | index.write("</tr>")
488 |
489 | for sample_file in test_B_files : # B -> A
490 | print('Processing B image: ' + sample_file)
491 | sample_image = np.asarray(load_test_data(sample_file))
492 | image_path = os.path.join(self.result_dir,'{0}'.format(os.path.basename(sample_file)))
493 |
494 | fake_img = self.sess.run(self.test_fake_A, feed_dict = {self.test_domain_B : sample_image, self.is_training : False})
495 |
496 | save_images(fake_img, [1, 1], image_path)
497 | index.write("<td>%s</td>" % os.path.basename(image_path))
498 | index.write("<td><img src='%s' width='%d' height='%d'></td>" % (sample_file if os.path.isabs(sample_file) else (
499 | '..' + os.path.sep + sample_file), self.img_size, self.img_size))
500 | index.write("<td><img src='%s' width='%d' height='%d'></td>" % (image_path if os.path.isabs(image_path) else (
501 | '..' + os.path.sep + image_path), self.img_size, self.img_size))
502 | index.write("</tr>")
503 | index.close()
--------------------------------------------------------------------------------
/assests/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/UNIT-Tensorflow/4d7430a6f0bd3bea72d821e14db6e6442c02ed32/assests/.DS_Store
--------------------------------------------------------------------------------
/assests/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/UNIT-Tensorflow/4d7430a6f0bd3bea72d821e14db6e6442c02ed32/assests/architecture.png
--------------------------------------------------------------------------------
/assests/cat_species.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/UNIT-Tensorflow/4d7430a6f0bd3bea72d821e14db6e6442c02ed32/assests/cat_species.gif
--------------------------------------------------------------------------------
/assests/cat_trans.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/UNIT-Tensorflow/4d7430a6f0bd3bea72d821e14db6e6442c02ed32/assests/cat_trans.png
--------------------------------------------------------------------------------
/assests/compare.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/UNIT-Tensorflow/4d7430a6f0bd3bea72d821e14db6e6442c02ed32/assests/compare.png
--------------------------------------------------------------------------------
/assests/cycle.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/UNIT-Tensorflow/4d7430a6f0bd3bea72d821e14db6e6442c02ed32/assests/cycle.png
--------------------------------------------------------------------------------
/assests/dog_breed.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/UNIT-Tensorflow/4d7430a6f0bd3bea72d821e14db6e6442c02ed32/assests/dog_breed.gif
--------------------------------------------------------------------------------
/assests/dog_trans.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/UNIT-Tensorflow/4d7430a6f0bd3bea72d821e14db6e6442c02ed32/assests/dog_trans.png
--------------------------------------------------------------------------------
/assests/faces.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/UNIT-Tensorflow/4d7430a6f0bd3bea72d821e14db6e6442c02ed32/assests/faces.png
--------------------------------------------------------------------------------
/assests/fail.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/UNIT-Tensorflow/4d7430a6f0bd3bea72d821e14db6e6442c02ed32/assests/fail.png
--------------------------------------------------------------------------------
/assests/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/UNIT-Tensorflow/4d7430a6f0bd3bea72d821e14db6e6442c02ed32/assests/framework.png
--------------------------------------------------------------------------------
/assests/gan_model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/UNIT-Tensorflow/4d7430a6f0bd3bea72d821e14db6e6442c02ed32/assests/gan_model.png
--------------------------------------------------------------------------------
/assests/slide/compare.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/UNIT-Tensorflow/4d7430a6f0bd3bea72d821e14db6e6442c02ed32/assests/slide/compare.png
--------------------------------------------------------------------------------
/assests/slide/cycle.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/UNIT-Tensorflow/4d7430a6f0bd3bea72d821e14db6e6442c02ed32/assests/slide/cycle.png
--------------------------------------------------------------------------------
/assests/slide/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/UNIT-Tensorflow/4d7430a6f0bd3bea72d821e14db6e6442c02ed32/assests/slide/framework.png
--------------------------------------------------------------------------------
/assests/slide/gan_model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/UNIT-Tensorflow/4d7430a6f0bd3bea72d821e14db6e6442c02ed32/assests/slide/gan_model.png
--------------------------------------------------------------------------------
/assests/slide/training_objective.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/UNIT-Tensorflow/4d7430a6f0bd3bea72d821e14db6e6442c02ed32/assests/slide/training_objective.png
--------------------------------------------------------------------------------
/assests/slide/vae_model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/UNIT-Tensorflow/4d7430a6f0bd3bea72d821e14db6e6442c02ed32/assests/slide/vae_model.png
--------------------------------------------------------------------------------
/assests/success.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/UNIT-Tensorflow/4d7430a6f0bd3bea72d821e14db6e6442c02ed32/assests/success.png
--------------------------------------------------------------------------------
/assests/training_objective__.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/UNIT-Tensorflow/4d7430a6f0bd3bea72d821e14db6e6442c02ed32/assests/training_objective__.png
--------------------------------------------------------------------------------
/assests/vae_model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/UNIT-Tensorflow/4d7430a6f0bd3bea72d821e14db6e6442c02ed32/assests/vae_model.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from UNIT import UNIT
2 | import argparse
3 | from utils import *
4 |
5 | """parsing and configuration"""
6 | def parse_args():
7 | desc = "Tensorflow implementation of UNIT"
8 | parser = argparse.ArgumentParser(description=desc)
9 | parser.add_argument('--phase', type=str, default='train', help='train or test ?')
10 | parser.add_argument('--dataset', type=str, default='cat2dog', help='dataset_name')
11 |
12 | parser.add_argument('--epoch', type=int, default=200, help='The number of epochs to run')
13 | parser.add_argument('--batch_size', type=int, default=1, help='The size of batch')
14 |
15 | parser.add_argument('--lr', type=float, default=0.0001, help='The learning rate')
16 | parser.add_argument('--GAN_weight', type=float, default=10.0, help='Weight about GAN, lambda0')
17 | parser.add_argument('--KL_weight', type=float, default=0.1, help='Weight about VAE, lambda1')
18 | parser.add_argument('--L1_weight', type=float, default=100.0, help='Weight about VAE, lambda2' )
19 | parser.add_argument('--KL_cycle_weight', type=float, default=0.1, help='Weight about VAE Cycle, lambda3')
20 | parser.add_argument('--L1_cycle_weight', type=float, default=100.0, help='Weight about VAE Cycle, lambda4')
21 |
22 | parser.add_argument('--ch', type=int, default=64, help='base channel number per layer')
23 | parser.add_argument('--n_encoder', type=int, default=3, help='The number of encoder')
24 | parser.add_argument('--n_enc_resblock', type=int, default=3, help='The number of encoder_resblock')
25 | parser.add_argument('--n_enc_share', type=int, default=1, help='The number of share_encoder')
26 | parser.add_argument('--n_gen_share', type=int, default=1, help='The number of share_generator')
27 | parser.add_argument('--n_gen_resblock', type=int, default=3, help='The number of generator_resblock')
28 | parser.add_argument('--n_gen_decoder', type=int, default=3, help='The number of generator_decoder')
29 | parser.add_argument('--n_dis', type=int, default=6, help='The number of discriminator layer')
30 |
31 | parser.add_argument('--res_dropout', type=float, default=0.0, help='The dropout ratio of Resblock')
32 | parser.add_argument('--smoothing', type=bool, default=False, help='smoothing loss use or not')
33 | parser.add_argument('--lsgan', type=bool, default=False, help='lsgan loss use or not')
34 | parser.add_argument('--norm', type=str, default='instance', help='The norm type')
35 | parser.add_argument('--replay_memory', type=bool, default=False, help='discriminator pool use or not')
36 | parser.add_argument('--pool_size', type=int, default=50, help='The size of image buffer that stores previously generated images')
37 | parser.add_argument('--img_size', type=int, default=256, help='The size of image')
38 | parser.add_argument('--img_ch', type=int, default=3, help='The size of image channel')
39 | parser.add_argument('--augment_flag', type=bool, default=True, help='Image augmentation use or not')
40 | parser.add_argument('--normal_weight_init', type=bool, default=True, help='normal initialization use or not')
41 |
42 | parser.add_argument('--checkpoint_dir', type=str, default='checkpoint',
43 | help='Directory name to save the checkpoints')
44 | parser.add_argument('--result_dir', type=str, default='results',
45 | help='Directory name to save the generated images')
46 | parser.add_argument('--log_dir', type=str, default='logs',
47 | help='Directory name to save training logs')
48 | parser.add_argument('--sample_dir', type=str, default='samples',
49 | help='Directory name to save the samples on training')
50 |
51 | return check_args(parser.parse_args())
52 |
53 | """checking arguments"""
54 | def check_args(args):
55 | # --checkpoint_dir
56 | check_folder(args.checkpoint_dir)
57 |
58 | # --result_dir
59 | check_folder(args.result_dir)
60 |
61 |     # --log_dir
62 | check_folder(args.log_dir)
63 |
64 | # --sample_dir
65 | check_folder(args.sample_dir)
66 |
67 | # --epoch
68 |     if args.epoch < 1:
69 |         print('number of epochs must be larger than or equal to one')
70 |         return None
71 |
72 |
73 | # --batch_size
74 |     if args.batch_size < 1:
75 |         print('batch size must be larger than or equal to one')
76 |         return None
77 |
78 | return args
79 |
80 | """main"""
81 | def main():
82 | # parse arguments
83 | args = parse_args()
84 | if args is None:
85 | exit()
86 |
87 | # open session
88 | with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
89 | gan = UNIT(sess, args)
90 |
91 | # build graph
92 | gan.build_model()
93 |
94 | # show network architecture
95 | show_all_variables()
96 |
97 | if args.phase == 'train' :
98 | # launch the graph in a session
99 | gan.train()
100 | print(" [*] Training finished!")
101 |
102 | if args.phase == 'test' :
103 | gan.test()
104 | print(" [*] Test finished!")
105 |
106 | if __name__ == '__main__':
107 | main()
--------------------------------------------------------------------------------
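
A note on main.py: it picks up `tensorflow as tf` indirectly through `from utils import *` (utils.py imports TensorFlow), so no explicit TensorFlow import appears above. Also, the boolean flags (--smoothing, --lsgan, --replay_memory, --augment_flag, --normal_weight_init) use argparse's type=bool, which treats any non-empty string as True, so `python main.py --smoothing False` still yields True. A minimal sketch of the usual workaround, using a hypothetical str2bool helper that is not part of this repository:

import argparse

def str2bool(v):
    # hypothetical converter, not defined in this repo
    return str(v).lower() in ('true', '1', 'yes')

parser = argparse.ArgumentParser()
parser.add_argument('--smoothing', type=str2bool, default=False)
print(parser.parse_args(['--smoothing', 'False']).smoothing)  # False; with type=bool this would be True
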
/main_multi_gpu.py:
--------------------------------------------------------------------------------
1 | from UNIT_multi_gpu import UNIT
2 | import argparse
3 | from utils import *
4 |
5 | """parsing and configuration"""
6 | def parse_args():
7 | desc = "Tensorflow implementation of UNIT"
8 | parser = argparse.ArgumentParser(description=desc)
9 | parser.add_argument('--phase', type=str, default='train', help='train or test?')
10 | parser.add_argument('--dataset', type=str, default='cat2dog', help='dataset_name')
11 |
12 | parser.add_argument('--epoch', type=int, default=200, help='The number of epochs to run')
13 | parser.add_argument('--batch_size', type=int, default=1, help='The batch size per GPU')
14 | parser.add_argument('--gpu_num', type=int, default=8, help='The number of GPUs')
15 |
16 | parser.add_argument('--lr', type=float, default=0.0001, help='The learning rate')
17 | parser.add_argument('--GAN_weight', type=float, default=10.0, help='Weight for the GAN loss (lambda0)')
18 | parser.add_argument('--KL_weight', type=float, default=0.1, help='Weight for the VAE KL loss (lambda1)')
19 | parser.add_argument('--L1_weight', type=float, default=100.0, help='Weight for the VAE L1 reconstruction loss (lambda2)')
20 | parser.add_argument('--KL_cycle_weight', type=float, default=0.1, help='Weight for the cycle-consistency KL loss (lambda3)')
21 | parser.add_argument('--L1_cycle_weight', type=float, default=100.0, help='Weight for the cycle-consistency L1 loss (lambda4)')
22 |
23 | parser.add_argument('--ch', type=int, default=64, help='base channel number per layer')
24 | parser.add_argument('--n_encoder', type=int, default=3, help='The number of encoder layers')
25 | parser.add_argument('--n_enc_resblock', type=int, default=3, help='The number of encoder resblocks')
26 | parser.add_argument('--n_enc_share', type=int, default=1, help='The number of shared encoder layers')
27 | parser.add_argument('--n_gen_share', type=int, default=1, help='The number of shared generator layers')
28 | parser.add_argument('--n_gen_resblock', type=int, default=3, help='The number of generator resblocks')
29 | parser.add_argument('--n_gen_decoder', type=int, default=3, help='The number of generator decoder layers')
30 | parser.add_argument('--n_dis', type=int, default=6, help='The number of discriminator layers')
31 |
32 | parser.add_argument('--res_dropout', type=float, default=0.0, help='The dropout ratio of the resblocks')
33 | parser.add_argument('--smoothing', type=bool, default=False, help='whether to use label smoothing')
34 | parser.add_argument('--lsgan', type=bool, default=False, help='whether to use the LSGAN loss')
35 | parser.add_argument('--norm', type=str, default='instance', help='The norm type')
36 | parser.add_argument('--replay_memory', type=bool, default=False, help='whether to use an image pool for the discriminator')
37 | parser.add_argument('--pool_size', type=int, default=50, help='The size of image buffer that stores previously generated images')
38 | parser.add_argument('--img_size', type=int, default=256, help='The size of the input images')
39 | parser.add_argument('--img_ch', type=int, default=3, help='The number of image channels')
40 | parser.add_argument('--augment_flag', type=bool, default=True, help='whether to use image augmentation')
41 | parser.add_argument('--normal_weight_init', type=bool, default=True, help='whether to use truncated-normal weight initialization')
42 |
43 | parser.add_argument('--checkpoint_dir', type=str, default='checkpoint',
44 | help='Directory name to save the checkpoints')
45 | parser.add_argument('--result_dir', type=str, default='results',
46 | help='Directory name to save the generated images')
47 | parser.add_argument('--log_dir', type=str, default='logs',
48 | help='Directory name to save training logs')
49 | parser.add_argument('--sample_dir', type=str, default='samples',
50 | help='Directory name to save the samples on training')
51 |
52 | return check_args(parser.parse_args())
53 |
54 | """checking arguments"""
55 | def check_args(args):
56 | # --checkpoint_dir
57 | check_folder(args.checkpoint_dir)
58 |
59 | # --result_dir
60 | check_folder(args.result_dir)
61 |
62 |     # --log_dir
63 | check_folder(args.log_dir)
64 |
65 | # --sample_dir
66 | check_folder(args.sample_dir)
67 |
68 | # --epoch
69 |     if args.epoch < 1:
70 |         print('number of epochs must be larger than or equal to one')
71 |         return None
72 |
73 |
74 | # --batch_size
75 |     if args.batch_size < 1:
76 |         print('batch size must be larger than or equal to one')
77 |         return None
78 |
79 | return args
80 |
81 | """main"""
82 | def main():
83 | # parse arguments
84 | args = parse_args()
85 | if args is None:
86 | exit()
87 |
88 | # open session
89 | with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
90 | gan = UNIT(sess, args)
91 |
92 | # build graph
93 | gan.build_model()
94 |
95 | # show network architecture
96 | show_all_variables()
97 |
98 | if args.phase == 'train' :
99 | # launch the graph in a session
100 | gan.train()
101 | print(" [*] Training finished!")
102 |
103 | if args.phase == 'test' :
104 | gan.test()
105 | print(" [*] Test finished!")
106 |
107 | if __name__ == '__main__':
108 | main()
--------------------------------------------------------------------------------
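
main_multi_gpu.py differs from main.py only in the UNIT_multi_gpu import and the --gpu_num flag, so the effective global batch size is batch_size * gpu_num (1 * 8 = 8 with the defaults). UNIT_multi_gpu.py itself is not reproduced in this section; the conventional TF1 multi-tower pattern it would rely on averages per-variable gradients across GPU towers, roughly like the following sketch (an assumption about the implementation, not the repository's exact code):

import tensorflow as tf

def average_gradients(tower_grads):
    # tower_grads: one list of (gradient, variable) pairs per GPU tower
    averaged = []
    for grads_and_vars in zip(*tower_grads):
        # stack the per-tower gradients for one variable and average them
        grad = tf.reduce_mean(tf.stack([g for g, _ in grads_and_vars]), axis=0)
        averaged.append((grad, grads_and_vars[0][1]))
    return averaged
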
/ops.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import tensorflow.contrib as tf_contrib
3 | from tensorflow.contrib.layers import variance_scaling_initializer as he_init
4 |
5 | def conv(x, channels, kernel=3, stride=2, pad=0, normal_weight_init=False, activation_fn='leaky', scope='conv_0') :
6 | with tf.variable_scope(scope) :
7 | x = tf.pad(x, [[0,0], [pad, pad], [pad, pad], [0,0]])
8 |
9 | if normal_weight_init :
10 | x = tf.layers.conv2d(inputs=x, filters=channels, kernel_size=kernel, kernel_initializer=tf.truncated_normal_initializer(stddev=0.02),
11 | strides=stride, kernel_regularizer=tf_contrib.layers.l2_regularizer(scale=0.0001))
12 |
13 | else :
14 | if activation_fn == 'relu' :
15 | x = tf.layers.conv2d(inputs=x, filters=channels, kernel_size=kernel, kernel_initializer=he_init(), strides=stride,
16 | kernel_regularizer=tf_contrib.layers.l2_regularizer(scale=0.0001))
17 | else :
18 | x = tf.layers.conv2d(inputs=x, filters=channels, kernel_size=kernel, strides=stride,
19 | kernel_regularizer=tf_contrib.layers.l2_regularizer(scale=0.0001))
20 |
21 |
22 | x = activation(x, activation_fn)
23 |
24 | return x
25 |
26 | def deconv(x, channels, kernel=3, stride=2, normal_weight_init=False, activation_fn='leaky', scope='deconv_0') :
27 | with tf.variable_scope(scope):
28 | if normal_weight_init:
29 | x = tf.layers.conv2d_transpose(inputs=x, filters=channels, kernel_size=kernel,
30 | kernel_initializer=tf.truncated_normal_initializer(stddev=0.02),
31 | strides=stride, padding='SAME', kernel_regularizer=tf_contrib.layers.l2_regularizer(scale=0.0001))
32 |
33 | else:
34 | if activation_fn == 'relu' :
35 | x = tf.layers.conv2d_transpose(inputs=x, filters=channels, kernel_size=kernel, kernel_initializer=he_init(), strides=stride, padding='SAME',
36 | kernel_regularizer=tf_contrib.layers.l2_regularizer(scale=0.0001))
37 | else :
38 | x = tf.layers.conv2d_transpose(inputs=x, filters=channels, kernel_size=kernel, strides=stride, padding='SAME',
39 | kernel_regularizer=tf_contrib.layers.l2_regularizer(scale=0.0001))
40 |
41 | x = activation(x, activation_fn)
42 |
43 | return x
44 |
45 | def resblock(x_init, channels, kernel=3, stride=1, pad=1, dropout_ratio=0.0, normal_weight_init=False, is_training=True, norm_fn='instance', scope='resblock_0') :
46 | assert norm_fn in ['instance', 'batch', 'weight', 'spectral', None]
47 | with tf.variable_scope(scope) :
48 | with tf.variable_scope('res1') :
49 | x = tf.pad(x_init, [[0, 0], [pad, pad], [pad, pad], [0, 0]])
50 |
51 | if normal_weight_init :
52 | x = tf.layers.conv2d(inputs=x, filters=channels, kernel_size=kernel,
53 | kernel_initializer=tf.truncated_normal_initializer(stddev=0.02),
54 | strides=stride, kernel_regularizer=tf_contrib.layers.l2_regularizer(scale=0.0001))
55 | else :
56 | x = tf.layers.conv2d(inputs=x, filters=channels, kernel_size=kernel, kernel_initializer=he_init(),
57 | strides=stride, kernel_regularizer=tf_contrib.layers.l2_regularizer(scale=0.0001))
58 |
59 | if norm_fn == 'instance' :
60 | x = instance_norm(x, 'res1_instance')
61 | if norm_fn == 'batch' :
62 | x = batch_norm(x, is_training, 'res1_batch')
63 |
64 | x = relu(x)
65 | with tf.variable_scope('res2') :
66 | x = tf.pad(x, [[0, 0], [pad, pad], [pad, pad], [0, 0]])
67 |
68 | if normal_weight_init :
69 | x = tf.layers.conv2d(inputs=x, filters=channels, kernel_size=kernel,
70 | kernel_initializer=tf.truncated_normal_initializer(stddev=0.02),
71 | strides=stride, kernel_regularizer=tf_contrib.layers.l2_regularizer(scale=0.0001))
72 | else :
73 |                 x = tf.layers.conv2d(inputs=x, filters=channels, kernel_size=kernel, kernel_initializer=he_init(),
74 |                                      strides=stride, kernel_regularizer=tf_contrib.layers.l2_regularizer(scale=0.0001))
75 |
76 | if norm_fn == 'instance' :
77 | x = instance_norm(x, 'res2_instance')
78 | if norm_fn == 'batch' :
79 | x = batch_norm(x, is_training, 'res2_batch')
80 |
81 | if dropout_ratio > 0.0 :
82 | x = tf.layers.dropout(x, rate=dropout_ratio, training=is_training)
83 |
84 | return x + x_init
85 |
86 | def activation(x, activation_fn='leaky') :
87 | assert activation_fn in ['relu', 'leaky', 'tanh', 'sigmoid', 'swish', None]
88 | if activation_fn == 'leaky':
89 | x = lrelu(x)
90 |
91 | if activation_fn == 'relu':
92 | x = relu(x)
93 |
94 | if activation_fn == 'sigmoid':
95 | x = sigmoid(x)
96 |
97 | if activation_fn == 'tanh' :
98 | x = tanh(x)
99 |
100 | if activation_fn == 'swish' :
101 | x = swish(x)
102 |
103 | return x
104 |
105 | def lrelu(x, alpha=0.01) :
106 |     # default negative slope matches PyTorch's LeakyReLU (0.01)
107 | return tf.nn.leaky_relu(x, alpha)
108 |
109 | def relu(x) :
110 | return tf.nn.relu(x)
111 |
112 | def sigmoid(x) :
113 | return tf.sigmoid(x)
114 |
115 | def tanh(x) :
116 | return tf.tanh(x)
117 |
118 | def swish(x) :
119 | return x * sigmoid(x)
120 |
121 | def batch_norm(x, is_training=False, scope='batch_norm') :
122 | return tf_contrib.layers.batch_norm(x,
123 | decay=0.9, epsilon=1e-05,
124 | center=True, scale=True, updates_collections=None,
125 | is_training=is_training, scope=scope)
126 |
127 | def instance_norm(x, scope='instance') :
128 | return tf_contrib.layers.instance_norm(x,
129 | epsilon=1e-05,
130 | center=True, scale=True,
131 | scope=scope)
132 |
133 | def gaussian_noise_layer(mu):
134 |     sigma = 1.0  # fixed unit variance: z = mu + sigma * eps (reparameterization trick)
135 | gaussian_random_vector = tf.random_normal(shape=tf.shape(mu), mean=0.0, stddev=1.0, dtype=tf.float32)
136 | return mu + sigma * gaussian_random_vector
137 |
138 | def KL_divergence(mu) :
139 |     # Full KL against N(0, I): 0.5 * sum(mu^2 + sigma^2 - log(sigma^2) - 1).
140 |     # With sigma fixed to 1 (see gaussian_noise_layer) it reduces to a penalty on mu^2.
141 | mu_2 = tf.square(mu)
142 | loss = tf.reduce_mean(mu_2)
143 |
144 | return loss
145 |
146 | def L1_loss(x, y) :
147 | loss = tf.reduce_mean(tf.abs(x - y))
148 | return loss
149 |
150 | def discriminator_loss(real, fake, smoothing=False, use_lsgan=False) :
151 |     if use_lsgan :
152 | if smoothing :
153 | real_loss = tf.reduce_mean(tf.squared_difference(real, 0.9)) * 0.5
154 | else :
155 | real_loss = tf.reduce_mean(tf.squared_difference(real, 1.0)) * 0.5
156 |
157 | fake_loss = tf.reduce_mean(tf.square(fake)) * 0.5
158 | else :
159 | if smoothing :
160 | real_labels = tf.fill(tf.shape(real), 0.9)
161 | else :
162 | real_labels = tf.ones_like(real)
163 |
164 | fake_labels = tf.zeros_like(fake)
165 |
166 | real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=real_labels, logits=real))
167 | fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=fake_labels, logits=fake))
168 |
169 | loss = real_loss + fake_loss
170 |
171 | return loss
172 |
173 | def generator_loss(fake, smoothing=False, use_lsgan=False) :
174 | if use_lsgan :
175 | if smoothing :
176 | loss = tf.reduce_mean(tf.squared_difference(fake, 0.9)) * 0.5
177 | else :
178 | loss = tf.reduce_mean(tf.squared_difference(fake, 1.0)) * 0.5
179 | else :
180 | if smoothing :
181 | fake_labels = tf.fill(tf.shape(fake), 0.9)
182 | else :
183 | fake_labels = tf.ones_like(fake)
184 |
185 | loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=fake_labels, logits=fake))
186 |
187 | return loss
188 |
189 |
--------------------------------------------------------------------------------
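
A quick way to sanity-check the building blocks in ops.py: a 3x3 conv with stride=2 and pad=1 halves the spatial resolution, and a resblock (stride=1, pad=1) preserves it, so a 256x256x3 input should come out of conv + resblock as 128x128x64. A minimal TF1 smoke test (the scope names here are arbitrary):

import numpy as np
import tensorflow as tf
from ops import conv, resblock

x = tf.placeholder(tf.float32, [1, 256, 256, 3])
y = conv(x, channels=64, kernel=3, stride=2, pad=1, scope='conv_test')  # 256 -> 128
y = resblock(y, channels=64, scope='res_test')                          # shape-preserving

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out = sess.run(y, feed_dict={x: np.zeros([1, 256, 256, 3], np.float32)})
    print(out.shape)  # expected: (1, 128, 128, 64)
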
/utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.contrib import slim
3 | from scipy import misc
4 | import os, random
5 | import numpy as np
6 |
7 |
8 | class ImagePool:
9 | """ History of generated images
10 | Same logic as https://github.com/junyanz/CycleGAN/blob/master/util/image_pool.lua
11 | """
12 |
13 | def __init__(self, pool_size):
14 | self.pool_size = pool_size
15 | self.images = []
16 |
17 | def query(self, image):
18 | if self.pool_size == 0:
19 | return image
20 |
21 | if len(self.images) < self.pool_size:
22 | self.images.append(image)
23 | return image
24 | else:
25 | p = random.random()
26 | if p > 0.5:
27 | # use old image
28 | random_id = random.randrange(0, self.pool_size)
29 | tmp = self.images[random_id].copy()
30 | self.images[random_id] = image.copy()
31 | return tmp
32 | else:
33 | return image
34 |
35 |
36 | def prepare_data(dataset_name, size):
37 | data_path = os.path.join("./dataset", dataset_name)
38 |
39 | trainA = []
40 | trainB = []
41 |     for path, dirs, files in os.walk(data_path):
42 | for file in files:
43 | image = os.path.join(path, file)
44 |             if 'trainA' in path :
45 |                 trainA.append(misc.imresize(misc.imread(image, mode='RGB'), [size, size]))
46 |             if 'trainB' in path :
47 |                 trainB.append(misc.imresize(misc.imread(image, mode='RGB'), [size, size]))
48 |
49 |
50 | trainA = preprocessing(np.asarray(trainA))
51 | trainB = preprocessing(np.asarray(trainB))
52 |
53 | np.random.shuffle(trainA)
54 | np.random.shuffle(trainB)
55 |
56 | return trainA, trainB
57 |
58 | def test_data(dataset_name, size) :
59 | data_path = os.path.join("./dataset", dataset_name)
60 | testA = []
61 | testB = []
62 |     for path, dirs, files in os.walk(data_path) :
63 | for file in files :
64 | image = os.path.join(path, file)
65 |             if 'testA' in path :
66 |                 testA.append(misc.imresize(misc.imread(image, mode='RGB'), [size, size]))
67 |             if 'testB' in path :
68 |                 testB.append(misc.imresize(misc.imread(image, mode='RGB'), [size, size]))
69 |
70 | testA = preprocessing(np.asarray(testA))
71 | testB = preprocessing(np.asarray(testB))
72 |
73 | return testA, testB
74 |
75 | def load_test_data(image_path, size=256):
76 | img = misc.imread(image_path, mode='RGB')
77 | img = misc.imresize(img, [size, size])
78 | img = np.expand_dims(img, axis=0)
79 | img = preprocessing(img)
80 |
81 | return img
82 |
83 | def preprocessing(x):
84 | """
85 | # Create Normal distribution
86 | x = x.astype('float32')
87 | x[:, :, :, 0] = (x[:, :, :, 0] - np.mean(x[:, :, :, 0])) / np.std(x[:, :, :, 0])
88 | x[:, :, :, 1] = (x[:, :, :, 1] - np.mean(x[:, :, :, 1])) / np.std(x[:, :, :, 1])
89 | x[:, :, :, 2] = (x[:, :, :, 2] - np.mean(x[:, :, :, 2])) / np.std(x[:, :, :, 2])
90 | """
91 |     x = x/127.5 - 1  # scale pixel values to [-1, 1]
92 | return x
93 |
94 | def augmentation(image, augment_size):
95 | seed = random.randint(0, 2 ** 31 - 1)
96 | ori_image_shape = tf.shape(image)
97 | image = tf.image.resize_images(image, [augment_size, augment_size])
98 | image = tf.random_crop(image, ori_image_shape, seed=seed)
99 | image = tf.map_fn(lambda x: tf.image.random_flip_left_right(x, seed), image)
100 | return image
101 |
102 | def save_images(images, size, image_path):
103 | return imsave(inverse_transform(images), size, image_path)
104 |
105 | def inverse_transform(images):
106 | return (images+1.) / 2
107 |
108 | def imsave(images, size, path):
109 | return misc.imsave(path, merge(images, size))
110 |
111 | def merge(images, size):
112 | h, w = images.shape[1], images.shape[2]
113 | img = np.zeros((h * size[0], w * size[1], 3))
114 | for idx, image in enumerate(images):
115 | i = idx % size[1]
116 | j = idx // size[1]
117 | img[h*j:h*(j+1), w*i:w*(i+1), :] = image
118 |
119 | return img
120 |
121 | def show_all_variables():
122 | model_vars = tf.trainable_variables()
123 | slim.model_analyzer.analyze_vars(model_vars, print_info=True)
124 |
125 | def check_folder(log_dir):
126 | if not os.path.exists(log_dir):
127 | os.makedirs(log_dir)
128 | return log_dir
--------------------------------------------------------------------------------
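
ImagePool in utils.py implements the CycleGAN replay-memory trick: while the pool is filling, query() stores and returns each incoming image unchanged; once full, it returns either the incoming fake image or, with probability 0.5, a randomly chosen stored one (swapping the new image into that slot), so the discriminator also trains against a history of earlier generator outputs. A small usage sketch with a dummy image in the [-1, 1] range the rest of the code uses:

import numpy as np
from utils import ImagePool

pool = ImagePool(pool_size=50)
fake = np.random.uniform(-1, 1, size=(1, 256, 256, 3)).astype(np.float32)
# Returns `fake` itself while the pool fills; afterwards it may return an
# older generated image instead, replacing it in the pool with `fake`.
for_discriminator = pool.query(fake)
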