├── .gitignore
├── FUNIT.py
├── LICENSE
├── README.md
├── assets
│   ├── animal.gif
│   ├── architecture.png
│   ├── funit_example.jpg
│   ├── our_result.png
│   └── process.png
├── main.py
├── ops.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/FUNIT.py:
--------------------------------------------------------------------------------
1 | from ops import *
2 | from utils import *
3 | import time
4 | from tensorflow.contrib.data import prefetch_to_device, shuffle_and_repeat, map_and_batch
5 | import numpy as np
6 | from tqdm import tqdm
7 |
8 |
9 | class FUNIT(object):
10 | def __init__(self, sess, args):
11 |
12 | self.phase = args.phase
13 | self.model_name = 'FUNIT'
14 |
15 | self.sess = sess
16 | self.checkpoint_dir = args.checkpoint_dir
17 | self.result_dir = args.result_dir
18 | self.log_dir = args.log_dir
19 | self.dataset_name = args.dataset
20 | self.augment_flag = args.augment_flag
21 |
22 | self.gpu_num = args.gpu_num
23 |
24 |         self.iteration = args.iteration // args.gpu_num  # each step runs gpu_num batches in parallel, so fewer steps are needed
25 |
26 | self.batch_size = args.batch_size
27 | self.print_freq = args.print_freq
28 | self.save_freq = args.save_freq
29 |
30 | self.lr = args.lr
31 | self.ch = args.ch
32 | self.ema_decay = args.ema_decay
33 |
34 | self.K = args.K
35 |
36 | self.gan_type = args.gan_type
37 |
38 |
39 | """ Weight """
40 | self.adv_weight = args.adv_weight
41 | self.recon_weight = args.recon_weight
42 | self.feature_weight = args.feature_weight
43 |
44 |
45 | """ Generator """
46 | self.latent_dim = args.latent_dim
47 |
48 | """ Discriminator """
49 | self.sn = args.sn
50 |
51 | self.img_height = args.img_height
52 | self.img_width = args.img_width
53 |
54 | self.img_ch = args.img_ch
55 |
56 | self.sample_dir = os.path.join(args.sample_dir, self.model_dir)
57 | check_folder(self.sample_dir)
58 |
59 | self.dataset_path = os.path.join('./dataset', self.dataset_name, 'train')
60 | self.class_dim = len(glob(self.dataset_path + '/*'))
61 |
62 | print()
63 |
64 | print("##### Information #####")
65 | print("# dataset : ", self.dataset_name)
66 | print("# batch_size : ", self.batch_size)
67 | print("# max iteration : ", self.iteration)
68 | print("# gpu num : ", self.gpu_num)
69 |
70 | print()
71 |
72 | print("##### Generator #####")
73 | print("# latent_dim : ", self.latent_dim)
74 |
75 | print()
76 |
77 | print("##### Discriminator #####")
78 | print("# spectral normalization : ", self.sn)
79 |
80 | print()
81 |
82 | print("##### Weight #####")
83 | print("# adv_weight : ", self.adv_weight)
84 | print("# feature_weight : ", self.feature_weight)
85 | print("# recon_weight : ", self.recon_weight)
86 |
87 | print()
88 |
89 | ##################################################################################
90 | # Generator
91 | ##################################################################################
92 |
93 | def content_encoder(self, x_init, reuse=tf.AUTO_REUSE, scope='content_encoder'):
94 | channel = self.ch
95 | with tf.variable_scope(scope, reuse=reuse) :
96 | x = conv(x_init, channel, kernel=7, stride=1, pad=3, pad_type='reflect',scope='conv')
97 | x = instance_norm(x, scope='ins_norm')
98 | x = relu(x)
99 |
100 | for i in range(3) :
101 | x = conv(x, channel * 2, kernel=4, stride=2, pad=1, pad_type='reflect', scope='conv_' + str(i))
102 | x = instance_norm(x, scope='ins_norm_' + str(i))
103 | x = relu(x)
104 |
105 | channel = channel * 2
106 |
107 | for i in range(2) :
108 | x = resblock(x, channel, scope='resblock_' + str(i))
109 |
110 | return x
111 |
112 | def class_encoder(self, x_init, reuse=tf.AUTO_REUSE, scope='class_encoder'):
113 | channel = self.ch
114 | with tf.variable_scope(scope, reuse=reuse) :
115 | x = conv(x_init, channel, kernel=7, stride=1, pad=3, pad_type='reflect', scope='conv')
116 | x = relu(x)
117 |
118 | for i in range(2) :
119 | x = conv(x, channel * 2, kernel=4, stride=2, pad=1, pad_type='reflect', scope='conv_' + str(i))
120 | x = relu(x)
121 |
122 | channel = channel * 2
123 |
124 | for i in range(2) :
125 | x = conv(x, channel, kernel=4, stride=2, pad=1, pad_type='reflect', scope='fix_conv_' + str(i))
126 | x = relu(x)
127 |
128 | x = global_avg_pooling(x)
129 | x = conv(x, channels=self.latent_dim, kernel=1, stride=1, scope='style_logit')
130 |
131 | return x
132 |
133 | def generator(self, content, style, reuse=tf.AUTO_REUSE, scope="generator"):
134 | channel = self.ch * 8 # 512
135 | with tf.variable_scope(scope, reuse=reuse):
136 | x = content
137 |
138 |             mu, var = self.MLP(style, channel // 2, scope='MLP')  # four (mu, var) AdaIN pairs: two per adaptive resblock
139 |
140 | for i in range(2) :
141 | idx = 2 * i
142 |                 x = adaptive_resblock(x, channel, mu[idx], var[idx], mu[idx + 1], var[idx + 1], scope='ada_resblock_' + str(i))
143 |
144 | for i in range(3) :
145 |
146 | x = up_sample(x, scale_factor=2)
147 | x = conv(x, channel//2, kernel=5, stride=1, pad=2, pad_type='reflect', scope='up_conv_' + str(i))
148 | x = instance_norm(x, scope='ins_norm_' + str(i))
149 | x = relu(x)
150 |
151 | channel = channel // 2
152 |
153 | x = conv(x, channels=self.img_ch, kernel=7, stride=1, pad=3, pad_type='reflect', scope='g_logit')
154 | x = tanh(x)
155 |
156 | return x
157 |
158 | def MLP(self, style, channel, scope='MLP'):
159 | with tf.variable_scope(scope):
160 | x = style
161 |
162 | for i in range(2) :
163 | x = fully_connected(x, channel, scope='FC_' + str(i))
164 | x = relu(x)
165 |
166 | mu_list = []
167 | var_list = []
168 |
169 | for i in range(4) :
170 | mu = fully_connected(x, channel * 2, scope='FC_mu_' + str(i))
171 | var = fully_connected(x, channel * 2, scope='FC_var_' + str(i))
172 |
173 | mu = tf.reshape(mu, shape=[-1, 1, 1, channel * 2])
174 | var = tf.reshape(var, shape=[-1, 1, 1, channel * 2])
175 |
176 | mu_list.append(mu)
177 | var_list.append(var)
178 |
179 |
180 | return mu_list, var_list
181 |
182 |
183 | ##################################################################################
184 | # Discriminator
185 | ##################################################################################
186 |
187 | def discriminator(self, x_init, class_onehot, reuse=tf.AUTO_REUSE, scope="discriminator"):
188 | channel = self.ch
189 | class_onehot = tf.reshape(class_onehot, shape=[self.batch_size, 1, 1, -1])
190 |
191 | with tf.variable_scope(scope, reuse=reuse):
192 | x = conv(x_init, channel, kernel=7, stride=1, pad=3, pad_type='reflect', sn=self.sn, scope='conv')
193 |
194 | for i in range(4) :
195 | x = pre_resblock(x, channel * 2, sn=self.sn, scope='front_resblock_0_' + str(i))
196 | x = pre_resblock(x, channel * 2, sn=self.sn, scope='front_resblock_1_' + str(i))
197 | x = down_sample_avg(x, scale_factor=2)
198 |
199 | channel = channel * 2
200 |
201 | for i in range(2) :
202 | x = pre_resblock(x, channel, sn=self.sn, scope='back_resblock_' + str(i))
203 |
204 | x_feature = x
205 | x = lrelu(x, 0.2)
206 |
207 | x = conv(x, channels=self.class_dim, kernel=1, stride=1, sn=self.sn, scope='d_logit')
208 |             x = tf.reduce_sum(x * class_onehot, axis=-1, keepdims=True)  # pick the target-class logit with the one-hot vector, e.g. [1, 0, 0, 0, 0]
209 |
210 | return x, x_feature
211 |
212 | ##################################################################################
213 | # Model
214 | ##################################################################################
215 |
216 |
217 | def build_model(self):
218 |
219 | if self.phase == 'train' :
220 | """ Input Image"""
221 | img_data_class = Image_data(self.img_height, self.img_width, self.img_ch, self.dataset_path, self.augment_flag)
222 | img_data_class.preprocess()
223 |
224 | self.dataset_num = len(img_data_class.image_list)
225 |
226 |
227 | img_and_class = tf.data.Dataset.from_tensor_slices((img_data_class.image_list, img_data_class.class_list))
228 |
229 | gpu_device = '/gpu:0'
230 | img_and_class = img_and_class.apply(shuffle_and_repeat(self.dataset_num)).apply(
231 | map_and_batch(img_data_class.image_processing, batch_size=self.batch_size * self.gpu_num, num_parallel_batches=16,
232 | drop_remainder=True)).apply(prefetch_to_device(gpu_device, None))
233 |
234 |
235 | img_and_class_iterator = img_and_class.make_one_shot_iterator()
236 |
237 | self.content_img, self.content_class = img_and_class_iterator.get_next()
239 |             self.style_img, self.style_class = img_and_class_iterator.get_next()  # a second, independent batch serves as the style input
239 |
240 | self.content_img = tf.split(self.content_img, num_or_size_splits=self.gpu_num)
241 | self.content_class = tf.split(self.content_class, num_or_size_splits=self.gpu_num)
242 | self.style_img = tf.split(self.style_img, num_or_size_splits=self.gpu_num)
243 | self.style_class = tf.split(self.style_class, num_or_size_splits=self.gpu_num)
244 |
245 | self.fake_img = []
246 |
247 | d_adv_losses = []
248 | g_adv_losses = []
249 | g_recon_losses = []
250 | g_feature_losses = []
251 |
252 |
253 | for gpu_id in range(self.gpu_num):
254 | with tf.device(tf.DeviceSpec(device_type="GPU", device_index=gpu_id)):
255 | with tf.variable_scope(tf.get_variable_scope(), reuse=(gpu_id > 0)):
256 | """ Define Generator, Discriminator """
257 | content_code = self.content_encoder(self.content_img[gpu_id])
258 | style_class_code = self.class_encoder(self.style_img[gpu_id])
259 | content_class_code = self.class_encoder(self.content_img[gpu_id])
260 |
261 | fake_img = self.generator(content_code, style_class_code)
262 | recon_img = self.generator(content_code, content_class_code)
263 |
264 | real_logit, style_feature_map = self.discriminator(self.style_img[gpu_id], self.style_class[gpu_id])
265 | fake_logit, fake_feature_map = self.discriminator(fake_img, self.style_class[gpu_id])
266 |
267 | recon_logit, recon_feature_map = self.discriminator(recon_img, self.content_class[gpu_id])
268 | _, content_feature_map = self.discriminator(self.content_img[gpu_id], self.content_class[gpu_id])
269 |
270 | """ Define Loss """
271 | d_adv_loss = self.adv_weight * discriminator_loss(self.gan_type, real_logit, fake_logit, self.style_img[gpu_id])
272 | g_adv_loss = 0.5 * self.adv_weight * (generator_loss(self.gan_type, fake_logit) + generator_loss(self.gan_type, recon_logit))
273 |
274 | g_recon_loss = self.recon_weight * L1_loss(self.content_img[gpu_id], recon_img)
275 |
276 | content_feature_map = tf.reduce_mean(tf.reduce_mean(content_feature_map, axis=2), axis=1)
277 | recon_feature_map = tf.reduce_mean(tf.reduce_mean(recon_feature_map, axis=2), axis=1)
278 | fake_feature_map = tf.reduce_mean(tf.reduce_mean(fake_feature_map, axis=2), axis=1)
279 | style_feature_map = tf.reduce_mean(tf.reduce_mean(style_feature_map, axis=2), axis=1)
280 |
281 | g_feature_loss = self.feature_weight * (L1_loss(recon_feature_map, content_feature_map) + L1_loss(fake_feature_map, style_feature_map))
282 |
283 | d_adv_losses.append(d_adv_loss)
284 | g_adv_losses.append(g_adv_loss)
285 | g_recon_losses.append(g_recon_loss)
286 | g_feature_losses.append(g_feature_loss)
287 |
288 | self.fake_img.append(fake_img)
289 |
290 | self.g_loss = tf.reduce_mean(g_adv_losses) + \
291 | tf.reduce_mean(g_recon_losses) + \
292 | tf.reduce_mean(g_feature_losses) + regularization_loss('encoder') + regularization_loss('generator')
293 |
294 | self.d_loss = tf.reduce_mean(d_adv_losses) + regularization_loss('discriminator')
295 |
296 |
297 | """ Training """
298 | t_vars = tf.trainable_variables()
299 | G_vars = [var for var in t_vars if 'encoder' in var.name or 'generator' in var.name]
300 | D_vars = [var for var in t_vars if 'discriminator' in var.name]
301 |
302 | if self.gpu_num == 1 :
303 | prev_G_optim = tf.train.RMSPropOptimizer(self.lr, decay=0.99, epsilon=1e-8).minimize(self.g_loss, var_list=G_vars)
304 | self.D_optim = tf.train.RMSPropOptimizer(self.lr, decay=0.99, epsilon=1e-8).minimize(self.d_loss, var_list=D_vars)
305 | # Pytorch : decay=0.99, epsilon=1e-8
306 |
307 | else :
308 | prev_G_optim = tf.train.RMSPropOptimizer(self.lr, decay=0.99, epsilon=1e-8).minimize(self.g_loss, var_list=G_vars, colocate_gradients_with_ops=True)
309 | self.D_optim = tf.train.RMSPropOptimizer(self.lr, decay=0.99, epsilon=1e-8).minimize(self.d_loss, var_list=D_vars, colocate_gradients_with_ops=True)
310 | # Pytorch : decay=0.99, epsilon=1e-8
311 |
312 | self.ema = tf.train.ExponentialMovingAverage(decay=self.ema_decay)
313 | with tf.control_dependencies([prev_G_optim]):
314 |                 self.G_optim = self.ema.apply(G_vars)  # update the EMA shadow weights after every G step; test() restores them
315 |
316 |
317 |             """ Summary """
318 | self.summary_g_loss = tf.summary.scalar("g_loss", self.g_loss)
319 | self.summary_d_loss = tf.summary.scalar("d_loss", self.d_loss)
320 |
321 | self.summary_g_adv_loss = tf.summary.scalar("g_adv_loss", tf.reduce_mean(g_adv_losses))
322 | self.summary_g_recon_loss = tf.summary.scalar("g_recon_loss", tf.reduce_mean(g_recon_losses))
323 | self.summary_g_feature_loss = tf.summary.scalar("g_feature_loss", tf.reduce_mean(g_feature_losses))
324 |
325 |
326 | g_summary_list = [self.summary_g_loss,
327 | self.summary_g_adv_loss,
328 | self.summary_g_recon_loss, self.summary_g_feature_loss
329 | ]
330 |
331 | d_summary_list = [self.summary_d_loss]
332 |
333 | self.summary_merge_g_loss = tf.summary.merge(g_summary_list)
334 | self.summary_merge_d_loss = tf.summary.merge(d_summary_list)
335 |
336 | else :
337 | """ Test """
338 | self.ema = tf.train.ExponentialMovingAverage(decay=self.ema_decay)
339 | self.test_content_img = tf.placeholder(tf.float32, [1, self.img_height, self.img_width, self.img_ch])
340 | self.test_class_img = tf.placeholder(tf.float32, [self.K, self.img_height, self.img_width, self.img_ch])
341 |
342 | test_content_code = self.content_encoder(self.test_content_img)
343 |             test_style_class_code = tf.reduce_mean(self.class_encoder(self.test_class_img), axis=0, keepdims=True)  # average the K class codes into one style code
344 |
345 | self.test_fake_img = self.generator(test_content_code, test_style_class_code)
346 |
347 | def train(self):
348 | # initialize all variables
349 | tf.global_variables_initializer().run()
350 |
351 | # saver to save model
352 | self.saver = tf.train.Saver(max_to_keep=20)
353 |
354 | # summary writer
355 | self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_dir, self.sess.graph)
356 |
357 | # restore check-point if it exits
358 | could_load, checkpoint_counter = self.load(self.checkpoint_dir)
359 | if could_load:
360 | start_batch_id = checkpoint_counter
361 | counter = checkpoint_counter
362 | print(" [*] Load SUCCESS")
363 |
364 | else:
365 | start_batch_id = 0
366 | counter = 1
367 | print(" [!] Load failed...")
368 |
369 | # loop for epoch
370 | start_time = time.time()
371 | for idx in range(start_batch_id, self.iteration):
372 |
373 | # Update D
374 | _, d_loss, summary_str = self.sess.run([self.D_optim, self.d_loss, self.summary_merge_d_loss])
375 | self.writer.add_summary(summary_str, counter)
376 |
377 | # Update G
378 | content_images, style_images, fake_x_images, _, g_loss, summary_str = self.sess.run(
379 | [self.content_img[0], self.style_img[0], self.fake_img[0],
380 | self.G_optim,
381 | self.g_loss, self.summary_merge_g_loss])
382 |
383 | self.writer.add_summary(summary_str, counter)
384 |
385 |
386 | # display training status
387 | counter += 1
388 | print("iter: [%6d/%6d] time: %4.4f d_loss: %.8f, g_loss: %.8f" % (idx, self.iteration, time.time() - start_time, d_loss, g_loss))
389 |
390 | if np.mod(idx + 1, self.print_freq) == 0:
391 | content_images = np.expand_dims(content_images[0], axis=0)
392 | style_images = np.expand_dims(style_images[0], axis=0)
393 | fake_x_images = np.expand_dims(fake_x_images[0], axis=0)
394 |
395 | merge_images = np.concatenate([content_images, style_images, fake_x_images], axis=0)
396 |
397 | save_images(merge_images, [1, 3],
398 | './{}/merge_{:07d}.jpg'.format(self.sample_dir, idx + 1))
399 |
400 | # save_images(content_images, [1, 1],
401 | # './{}/content_{:07d}.jpg'.format(self.sample_dir, idx + 1))
402 | #
403 | # save_images(style_images, [1, 1],
404 | # './{}/style_{:07d}.jpg'.format(self.sample_dir, idx + 1))
405 | #
406 | # save_images(fake_x_images, [1, 1],
407 | # './{}/fake_{:07d}.jpg'.format(self.sample_dir, idx + 1))
408 |
409 |
410 | if np.mod(counter - 1, self.save_freq) == 0:
411 | self.save(self.checkpoint_dir, counter)
412 |
413 | # save model for final step
414 | self.save(self.checkpoint_dir, counter)
415 |
416 | @property
417 | def model_dir(self):
418 | if self.sn:
419 | sn = '_sn'
420 | else:
421 | sn = ''
422 |
423 | return "{}_{}_{}_{}adv_{}feature_{}recon{}".format(self.model_name, self.dataset_name, self.gan_type,
424 | self.adv_weight, self.feature_weight, self.recon_weight,
425 | sn)
426 |
427 | def save(self, checkpoint_dir, step):
428 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)
429 |
430 | if not os.path.exists(checkpoint_dir):
431 | os.makedirs(checkpoint_dir)
432 |
433 | self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name + '.model'), global_step=step)
434 |
435 | def load(self, checkpoint_dir):
436 | print(" [*] Reading checkpoints...")
437 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)
438 |
439 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
440 | if ckpt and ckpt.model_checkpoint_path:
441 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
442 | self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
443 | counter = int(ckpt_name.split('-')[-1])
444 | print(" [*] Success to read {}".format(ckpt_name))
445 | return True, counter
446 | else:
447 | print(" [*] Failed to find a checkpoint")
448 | return False, 0
449 |
450 | def test(self):
451 | tf.global_variables_initializer().run()
452 |
453 | content_images = glob('./dataset/{}/{}/{}/*.*'.format(self.dataset_name, 'test', 'content'))
454 | class_images = glob('./dataset/{}/{}/{}/*.*'.format(self.dataset_name, 'test', 'class'))
455 |
456 | t_vars = tf.trainable_variables()
457 | G_vars = [var for var in t_vars if 'encoder' in var.name or 'generator' in var.name]
458 |
459 | shadow_G_vars_dict = {}
460 |
461 | for g_var in G_vars :
462 | shadow_G_vars_dict[self.ema.average_name(g_var)] = g_var
463 |
464 | self.saver = tf.train.Saver(shadow_G_vars_dict)
465 | could_load, checkpoint_counter = self.load(self.checkpoint_dir)
466 | self.result_dir = os.path.join(self.result_dir, self.model_dir)
467 | check_folder(self.result_dir)
468 |
469 | if could_load:
470 | print(" [*] Load SUCCESS")
471 | else:
472 | print(" [!] Load failed...")
473 |
474 | # write html for visual comparison
475 | index_path = os.path.join(self.result_dir, 'index.html')
476 | index = open(index_path, 'w')
477 |         index.write("<html><body><table><tr>")
478 |         index.write("<th>name</th><th>content</th><th>style</th><th>output</th></tr>")
479 |
480 | for sample_content_image in tqdm(content_images):
481 | sample_image = load_test_image(sample_content_image, self.img_width, self.img_height)
482 |
483 | random_class_images = np.random.choice(class_images, size=self.K, replace=False)
484 | sample_class_image = np.concatenate([load_test_image(x, self.img_width, self.img_height) for x in random_class_images])
485 |
486 | fake_path = os.path.join(self.result_dir, '{}'.format(os.path.basename(sample_content_image)))
487 | class_path = os.path.join(self.result_dir, 'style_{}'.format(os.path.basename(sample_content_image)))
488 |
489 | fake_img = self.sess.run(self.test_fake_img, feed_dict={self.test_content_img : sample_image, self.test_class_img : sample_class_image})
490 |
491 | save_images(fake_img, [1, 1], fake_path)
492 | save_images(sample_class_image, [1, self.K], class_path)
493 |
494 | index.write("%s | " % os.path.basename(sample_content_image))
495 |             index.write(
496 |                 "<td><img src='%s' width='%d' height='%d'></td>" % (sample_content_image if os.path.isabs(sample_content_image) else (
497 |                     '../..' + os.path.sep + sample_content_image), self.img_width, self.img_height))
498 |
499 |             index.write(
500 |                 "<td><img src='%s' width='%d' height='%d'></td>" % (class_path if os.path.isabs(class_path) else (
501 |                     '../..' + os.path.sep + class_path), self.img_width * self.K, self.img_height))
502 |             index.write(
503 |                 "<td><img src='%s' width='%d' height='%d'></td>" % (fake_path if os.path.isabs(fake_path) else (
504 |                     '../..' + os.path.sep + fake_path), self.img_width, self.img_height))
505 | index.write("")
506 |
507 | index.close()
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Junho Kim
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FUNIT-Tensorflow
2 | ## Few-Shot Unsupervised Image-to-Image Translation (ICCV 2019)
3 |
4 |
5 | <div align="center"> <img src="./assets/animal.gif"> </div>
6 | <div align="center"> <img src="./assets/process.png"> </div>
7 |
8 |
9 | ### [Paper](https://arxiv.org/abs/1905.01723) | [Official Pytorch code](https://github.com/NVlabs/FUNIT)
10 |
11 | ### [Other Pytorch Implementation](https://github.com/znxlwm/FUNIT-pytorch)
12 |
13 | ## Usage
14 | ```
15 | ├── dataset
16 |    └── YOUR_DATASET_NAME
17 |        ├── train
18 |            ├── class1 (class folder)
19 |                ├── xxx.jpg (class1 image)
20 |                ├── yyy.png
21 |                ├── ...
22 |            ├── class2
23 |                ├── aaa.jpg (class2 image)
24 |                ├── bbb.png
25 |                ├── ...
26 |            ├── class3
27 |                ├── ...
28 |        ├── test
29 |            ├── content (content folder)
30 |                ├── zzz.jpg (any content image)
31 |                ├── www.png
32 |                ├── ...
33 |            ├── class (class folder)
34 |                ├── ccc.jpg (unseen target class image)
35 |                ├── ddd.jpg
36 |                ├── ...
37 | ```
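
Each sub-folder of `train` is treated as one class: `FUNIT.py` counts the folders to get the number of classes, and `utils.Image_data.preprocess` gives every image a one-hot label for its folder. A rough sketch of that labeling logic (the dataset path here is illustrative):

```python
import os
from glob import glob
import numpy as np

dataset_path = './dataset/YOUR_DATASET_NAME/train'
class_names = [os.path.basename(p) for p in glob(dataset_path + '/*')]

# one-hot label per class folder, e.g. [1, 0, 0] for the first of three classes
labels = {name: np.eye(len(class_names))[i] for i, name in enumerate(class_names)}
```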
38 |
39 | ### Train
40 | ```
41 | > python main.py --dataset flower
42 | ```
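
All training hyper-parameters are exposed as flags in `main.py` and can be overridden on the command line; for example (the values below are illustrative, not tuned recommendations):

```
> python main.py --dataset flower --iteration 800000 --batch_size 64 --gpu_num 1 --gan_type hinge
```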
43 |
44 | ### Test
45 | ```
46 | > python main.py --dataset flower --phase test
47 | ```
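
At test time, `K` random images (default `--K 5`) are drawn from `test/class` for each image in `test/content`; their class codes are averaged into a single style code, and the outputs plus an `index.html` comparison table are written under `results/`. For a 1-shot setting, for example:

```
> python main.py --dataset flower --phase test --K 1
```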
48 |
49 | ## Architecture
50 | <div align="center"> <img src="./assets/architecture.png"> </div>
51 |
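The graph in `FUNIT.py` follows the figure: a content encoder, a class (style) encoder, and an AdaIN-based decoder. A condensed sketch of how one translated image is produced, using the method names from `FUNIT.py`:

```python
# condensed from FUNIT.build_model
content_code = self.content_encoder(self.content_img[gpu_id])  # spatial content features
style_code = self.class_encoder(self.style_img[gpu_id])        # [B, 1, 1, latent_dim] class code
fake_img = self.generator(content_code, style_code)            # AdaIN resblocks + upsampling decoder
```
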
52 | ## Our result
53 | <div align="center"> <img src="./assets/our_result.png"> </div>
54 |
55 | ## Paper result
56 | <div align="center"> <img src="./assets/funit_example.jpg"> </div>
57 |
58 | ## Author
59 | [Junho Kim](http://bit.ly/jhkim_ai)
60 |
--------------------------------------------------------------------------------
/assets/animal.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/FUNIT-Tensorflow/6a160e5690544359133fc0860cedf2a61dbdcaf9/assets/animal.gif
--------------------------------------------------------------------------------
/assets/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/FUNIT-Tensorflow/6a160e5690544359133fc0860cedf2a61dbdcaf9/assets/architecture.png
--------------------------------------------------------------------------------
/assets/funit_example.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/FUNIT-Tensorflow/6a160e5690544359133fc0860cedf2a61dbdcaf9/assets/funit_example.jpg
--------------------------------------------------------------------------------
/assets/our_result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/FUNIT-Tensorflow/6a160e5690544359133fc0860cedf2a61dbdcaf9/assets/our_result.png
--------------------------------------------------------------------------------
/assets/process.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/FUNIT-Tensorflow/6a160e5690544359133fc0860cedf2a61dbdcaf9/assets/process.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from FUNIT import FUNIT
2 | import argparse
3 | from utils import *
4 |
5 | """parsing and configuration"""
6 | def parse_args():
7 | desc = "Tensorflow implementation of FUNIT"
8 | parser = argparse.ArgumentParser(description=desc)
9 | parser.add_argument('--phase', type=str, default='train', choices=('train', 'test'), help='phase name')
10 | parser.add_argument('--dataset', type=str, default='flower', help='dataset_name')
11 |
12 |     parser.add_argument('--iteration', type=int, default=800000, help='The total number of training iterations')
13 |     parser.add_argument('--batch_size', type=int, default=64, help='The batch size per GPU')
14 |     parser.add_argument('--print_freq', type=int, default=1000, help='How often (in iterations) to save sample images')
15 |     parser.add_argument('--save_freq', type=int, default=10000, help='How often (in iterations) to save checkpoints')
16 |     parser.add_argument('--gpu_num', type=int, default=1, help='The number of GPUs')
17 |
18 | parser.add_argument('--lr', type=float, default=0.0001, help='The learning rate')
19 | parser.add_argument('--ema_decay', type=float, default=0.999, help='ema decay value')
20 |     parser.add_argument('--K', type=int, default=5, help='The number of target-class images per content image at test time (K-shot)')
21 |
22 | parser.add_argument('--gan_type', type=str, default='hinge', help='[gan / lsgan / hinge]')
23 |
24 |     parser.add_argument('--adv_weight', type=float, default=1, help='Weight for the adversarial loss')
25 |     parser.add_argument('--feature_weight', type=float, default=1, help='Weight for the feature-matching loss')
26 |     parser.add_argument('--recon_weight', type=float, default=0.1, help='Weight for the reconstruction loss')
27 |
28 | parser.add_argument('--latent_dim', type=int, default=64, help='The dimension of class code')
29 | parser.add_argument('--ch', type=int, default=64, help='base channel number per layer')
30 |
31 | parser.add_argument('--sn', type=str2bool, default=False, help='using spectral norm')
32 |
33 |     parser.add_argument('--img_height', type=int, default=128, help='The height of the input images')
34 |     parser.add_argument('--img_width', type=int, default=128, help='The width of the input images')
35 |     parser.add_argument('--img_ch', type=int, default=3, help='The number of image channels')
36 |     parser.add_argument('--augment_flag', type=str2bool, default=True, help='Whether to use image augmentation')
37 |
38 | parser.add_argument('--checkpoint_dir', type=str, default='checkpoint',
39 | help='Directory name to save the checkpoints')
40 | parser.add_argument('--result_dir', type=str, default='results',
41 | help='Directory name to save the generated images')
42 | parser.add_argument('--log_dir', type=str, default='logs',
43 | help='Directory name to save training logs')
44 | parser.add_argument('--sample_dir', type=str, default='samples',
45 | help='Directory name to save the samples on training')
46 |
47 | return check_args(parser.parse_args())
48 |
49 | """checking arguments"""
50 | def check_args(args):
51 | # --checkpoint_dir
52 | check_folder(args.checkpoint_dir)
53 |
54 | # --result_dir
55 | check_folder(args.result_dir)
56 |
57 | # --log_dir
58 | check_folder(args.log_dir)
59 |
60 | # --sample_dir
61 | check_folder(args.sample_dir)
62 |
63 | # --batch_size
64 |     if args.batch_size < 1:
65 |         print('batch size must be larger than or equal to one')
66 |         return None  # main() treats None as a fatal argument error
67 | 
68 |     return args
69 |
70 | """main"""
71 | def main():
72 | # parse arguments
73 | args = parse_args()
74 | if args is None:
75 | exit()
76 |
77 | # open session
78 | with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
79 | gan = FUNIT(sess, args)
80 |
81 | # build graph
82 | gan.build_model()
83 |
84 | # show network architecture
85 | show_all_variables()
86 |
87 | if args.phase == 'train' :
88 | gan.train()
89 | print(" [*] Training finished!")
90 |
91 | if args.phase == 'test' :
92 | gan.test()
93 | print(" [*] Test finished!")
94 |
95 |
96 |
97 | if __name__ == '__main__':
98 | main()
99 |
--------------------------------------------------------------------------------
/ops.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import tensorflow.contrib as tf_contrib
3 | from utils import pytorch_kaiming_weight_factor
4 |
5 | ##################################################################################
6 | # Initialization
7 | ##################################################################################
8 |
9 | factor, mode, uniform = pytorch_kaiming_weight_factor(a=0.0, uniform=False)
10 | weight_init = tf_contrib.layers.variance_scaling_initializer(factor=factor, mode=mode, uniform=uniform)
11 | # weight_init = tf.random_normal_initializer(mean=0.0, stddev=0.02)
12 |
13 | weight_regularizer = tf.contrib.layers.l2_regularizer(0.0001)
14 | weight_regularizer_fully = tf.contrib.layers.l2_regularizer(0.0001)
15 |
16 | ##################################################################################
17 | # Layer
18 | ##################################################################################
19 |
20 | def conv(x, channels, kernel=4, stride=2, pad=0, pad_type='zero', use_bias=True, sn=False, scope='conv_0'):
21 | with tf.variable_scope(scope):
22 |         if pad > 0:  # expand the requested pad into a (possibly asymmetric) top/bottom/left/right split
23 | h = x.get_shape().as_list()[1]
24 | if h % stride == 0:
25 | pad = pad * 2
26 | else:
27 | pad = max(kernel - (h % stride), 0)
28 |
29 | pad_top = pad // 2
30 | pad_bottom = pad - pad_top
31 | pad_left = pad // 2
32 | pad_right = pad - pad_left
33 |
34 | if pad_type == 'zero':
35 | x = tf.pad(x, [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]])
36 | if pad_type == 'reflect':
37 | x = tf.pad(x, [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]], mode='REFLECT')
38 |
39 | if sn:
40 | w = tf.get_variable("kernel", shape=[kernel, kernel, x.get_shape()[-1], channels], initializer=weight_init,
41 | regularizer=weight_regularizer)
42 | x = tf.nn.conv2d(input=x, filter=spectral_norm(w),
43 | strides=[1, stride, stride, 1], padding='VALID')
44 | if use_bias:
45 | bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0))
46 | x = tf.nn.bias_add(x, bias)
47 |
48 | else:
49 | x = tf.layers.conv2d(inputs=x, filters=channels,
50 | kernel_size=kernel, kernel_initializer=weight_init,
51 | kernel_regularizer=weight_regularizer,
52 | strides=stride, use_bias=use_bias)
53 |
54 | return x
55 |
56 |
57 | def fully_connected(x, units, use_bias=True, sn=False, scope='linear'):
58 | with tf.variable_scope(scope):
59 | x = flatten(x)
60 | shape = x.get_shape().as_list()
61 | channels = shape[-1]
62 |
63 | if sn:
64 | w = tf.get_variable("kernel", [channels, units], tf.float32,
65 | initializer=weight_init, regularizer=weight_regularizer_fully)
66 | if use_bias:
67 | bias = tf.get_variable("bias", [units],
68 | initializer=tf.constant_initializer(0.0))
69 |
70 | x = tf.matmul(x, spectral_norm(w)) + bias
71 | else:
72 | x = tf.matmul(x, spectral_norm(w))
73 |
74 | else:
75 | x = tf.layers.dense(x, units=units, kernel_initializer=weight_init,
76 | kernel_regularizer=weight_regularizer_fully,
77 | use_bias=use_bias)
78 |
79 | return x
80 |
81 |
82 | def flatten(x):
83 | return tf.layers.flatten(x)
84 |
85 |
86 | ##################################################################################
87 | # Residual-block
88 | ##################################################################################
89 |
90 |
91 | def resblock(x_init, channels, use_bias=True, sn=False, scope='resblock'):
92 | with tf.variable_scope(scope):
93 | with tf.variable_scope('res1'):
94 | x = conv(x_init, channels, kernel=3, stride=1, pad=1, use_bias=use_bias, sn=sn)
95 | x = instance_norm(x)
96 | x = relu(x)
97 |
98 | with tf.variable_scope('res2'):
99 | x = conv(x, channels, kernel=3, stride=1, pad=1, use_bias=use_bias, sn=sn)
100 | x = instance_norm(x)
101 |
102 | return x + x_init
103 |
104 | def pre_resblock(x_init, channels, use_bias=True, sn=False, scope='resblock'):
105 | with tf.variable_scope(scope):
106 | _, _, _, init_channel = x_init.get_shape().as_list()
107 |
108 | with tf.variable_scope('res1'):
109 | x = lrelu(x_init, 0.2)
110 | x = conv(x, channels, kernel=3, stride=1, pad=1, use_bias=use_bias, sn=sn)
111 |
112 | with tf.variable_scope('res2'):
113 | x = lrelu(x, 0.2)
114 | x = conv(x, channels, kernel=3, stride=1, pad=1, use_bias=use_bias, sn=sn)
115 |
116 | if init_channel != channels :
117 | with tf.variable_scope('shortcut'):
118 | x_init = conv(x_init, channels, kernel=1, stride=1, use_bias=False, sn=sn)
119 |
120 | return x + x_init
121 |
122 | def adaptive_resblock(x_init, channels, gamma1, beta1, gamma2, beta2, use_bias=True, sn=False, scope='adaptive_resblock') :
123 | with tf.variable_scope(scope):
124 | with tf.variable_scope('res1'):
125 | x = conv(x_init, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias, sn=sn)
126 | x = adaptive_instance_norm(x, gamma1, beta1)
127 | x = relu(x)
128 |
129 | with tf.variable_scope('res2'):
130 | x = conv(x, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias, sn=sn)
131 | x = adaptive_instance_norm(x, gamma2, beta2)
132 |
133 | return x + x_init
134 |
135 |
136 | ##################################################################################
137 | # Sampling
138 | ##################################################################################
139 |
140 | def up_sample(x, scale_factor=2):
141 | _, h, w, _ = x.get_shape().as_list()
142 | new_size = [h * scale_factor, w * scale_factor]
143 | return tf.image.resize_nearest_neighbor(x, size=new_size)
144 |
145 | def down_sample_avg(x, scale_factor=2):
146 | return tf.layers.average_pooling2d(x, pool_size=3, strides=scale_factor, padding='SAME')
147 |
148 | def global_avg_pooling(x):
149 | gap = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
150 | return gap
151 |
152 |
153 | ##################################################################################
154 | # Activation function
155 | ##################################################################################
156 |
157 | def lrelu(x, alpha=0.01):
158 | # pytorch alpha is 0.01
159 | return tf.nn.leaky_relu(x, alpha)
160 |
161 |
162 | def relu(x):
163 | return tf.nn.relu(x)
164 |
165 |
166 | def tanh(x):
167 | return tf.tanh(x)
168 |
169 |
170 | ##################################################################################
171 | # Normalization function
172 | ##################################################################################
173 |
174 | def instance_norm(x, scope='instance_norm'):
175 | return tf_contrib.layers.instance_norm(x,
176 | epsilon=1e-05,
177 | center=True, scale=True,
178 | scope=scope)
179 |
180 | def param_free_norm(x, epsilon=1e-5):
181 | x_mean, x_var = tf.nn.moments(x, axes=[1, 2], keep_dims=True)
182 | x_std = tf.sqrt(x_var + epsilon)
183 |
184 | return (x - x_mean) / x_std
185 |
186 | def adaptive_instance_norm(content, gamma, beta, epsilon=1e-5):
187 | # gamma, beta = style_mean, style_std from MLP
188 |
189 | x = param_free_norm(content, epsilon)
190 |
191 | return gamma * x + beta
192 |
193 | def spectral_norm(w, iteration=1):
194 | w_shape = w.shape.as_list()
195 | w = tf.reshape(w, [-1, w_shape[-1]])
196 |
197 | u = tf.get_variable("u", [1, w_shape[-1]], initializer=tf.random_normal_initializer(), trainable=False)
198 |
199 | u_hat = u
200 | v_hat = None
201 | for i in range(iteration):
202 | """
203 | power iteration
204 | Usually iteration = 1 will be enough
205 | """
206 | v_ = tf.matmul(u_hat, tf.transpose(w))
207 | v_hat = tf.nn.l2_normalize(v_)
208 |
209 | u_ = tf.matmul(v_hat, w)
210 | u_hat = tf.nn.l2_normalize(u_)
211 |
212 | u_hat = tf.stop_gradient(u_hat)
213 | v_hat = tf.stop_gradient(v_hat)
214 |
215 | sigma = tf.matmul(tf.matmul(v_hat, w), tf.transpose(u_hat))
216 |
217 | with tf.control_dependencies([u.assign(u_hat)]):
218 | w_norm = w / sigma
219 | w_norm = tf.reshape(w_norm, w_shape)
220 |
221 | return w_norm
222 |
223 |
224 | ##################################################################################
225 | # Loss function
226 | ##################################################################################
227 |
228 | def L1_loss(x, y):
229 | loss = tf.reduce_mean(tf.abs(x - y)) # [64, h, w, c]
230 |
231 | return loss
232 |
233 | def discriminator_loss(gan_type, real_logit, fake_logit, real_images):
234 | real_loss = 0
235 | fake_loss = 0
236 |
237 | if gan_type == 'lsgan':
238 | real_loss = tf.reduce_mean(tf.squared_difference(real_logit, 1.0))
239 | fake_loss = tf.reduce_mean(tf.square(fake_logit))
240 |
241 | if gan_type == 'gan':
242 | real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(real_logit), logits=real_logit))
243 | fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(fake_logit), logits=fake_logit))
244 |
245 | if gan_type == 'hinge':
246 |
247 | real_loss = tf.reduce_mean(relu(1 - real_logit))
248 | fake_loss = tf.reduce_mean(relu(1 + fake_logit))
249 |
250 |     return real_loss + fake_loss + real_gp(real_images, real_logit)  # the R1 gradient penalty on real images is added for every gan_type
251 |
252 | def real_gp(real_images, real_logit) :
253 | grad_out = tf.gradients(tf.reduce_mean(real_logit), [real_images])[0]
254 | grad_out2 = tf.square(grad_out)
255 |
256 | r1_penalty = 10 * tf.reduce_mean(tf.reduce_sum(grad_out2, axis=[1, 2, 3]))
257 |
258 | return r1_penalty
259 |
260 | def generator_loss(gan_type, fake_logit):
261 | fake_loss = 0
262 |
263 | if gan_type == 'lsgan':
264 | fake_loss = tf.reduce_mean(tf.squared_difference(fake_logit, 1.0))
265 |
266 | if gan_type == 'gan':
267 | fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(fake_logit), logits=fake_logit))
268 |
269 | if gan_type == 'hinge':
270 | fake_loss = -tf.reduce_mean(fake_logit)
271 |
272 | return fake_loss
273 |
274 |
275 | def regularization_loss(scope_name):
276 | """
277 | If you want to use "Regularization"
278 | g_loss += regularization_loss('generator')
279 | d_loss += regularization_loss('discriminator')
280 | """
281 | collection_regularization = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
282 |
283 | loss = []
284 | for item in collection_regularization:
285 | if scope_name in item.name:
286 | loss.append(item)
287 |
288 | return tf.reduce_sum(loss)
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.contrib import slim
3 | import os
4 | import numpy as np
5 | from glob import glob
6 | import cv2
7 |
8 | class Image_data:
9 |
10 | def __init__(self, img_height, img_width, channels, dataset_path, augment_flag):
11 | self.img_height = img_height
12 | self.img_width = img_width
13 | self.channels = channels
14 | self.augment_flag = augment_flag
15 |
16 | self.dataset_path = dataset_path
17 |
18 |
19 | self.image_list = []
20 | self.class_list = []
21 |
22 |
23 | def image_processing(self, filename, label):
24 | x = tf.read_file(filename)
25 | x_decode = tf.image.decode_jpeg(x, channels=self.channels, dct_method='INTEGER_ACCURATE')
26 | img = preprocess_fit_train_image(x_decode, self.img_height, self.img_width)
27 |
28 |
29 | if self.augment_flag :
30 | augment_height_size = self.img_height + (30 if self.img_height == 256 else int(self.img_height * 0.1))
31 | augment_width_size = self.img_width + (30 if self.img_width == 256 else int(self.img_width * 0.1))
32 |
33 | img = tf.cond(pred=tf.greater_equal(tf.random_uniform(shape=[], minval=0.0, maxval=1.0), 0.5),
34 | true_fn=lambda : augmentation(img, augment_height_size, augment_width_size),
35 | false_fn=lambda : img)
36 |
37 | return img, label
38 |
39 | def preprocess(self):
40 | self.class_label = [os.path.basename(x) for x in glob(self.dataset_path + '/*')]
41 |
42 | v = 0
43 |
44 | for class_label in self.class_label :
45 | class_one_hot = list(get_one_hot(v, len(self.class_label))) # [1, 0, 0, 0, 0]
46 | v = v+1
47 |
48 | image_list = glob(os.path.join(self.dataset_path, class_label) + '/*.png') + glob(os.path.join(self.dataset_path, class_label) + '/*.jpg')
49 | class_one_hot = [class_one_hot] * len(image_list)
50 |
51 | self.image_list.extend(image_list)
52 | self.class_list.extend(class_one_hot)
53 |
54 | def load_test_image(image_path, img_width, img_height):
55 |
56 | img = cv2.imread(image_path, flags=cv2.IMREAD_COLOR)
57 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
58 |
59 | img = cv2.resize(img, dsize=(img_width, img_height))
60 | img = np.expand_dims(img, axis=0)
61 |
62 | img = adjust_dynamic_range(img)
63 |
64 | return img
65 |
66 |
67 | def preprocessing(x):
68 | x = x/127.5 - 1 # -1 ~ 1
69 | return x
70 |
71 | def preprocess_fit_train_image(images, height, width):
72 | images = tf.image.resize(images, size=[height, width], method=tf.image.ResizeMethod.BILINEAR)
73 | images = adjust_dynamic_range(images)
74 |
75 | return images
76 |
77 | def adjust_dynamic_range(images):
78 | drange_in = [0.0, 255.0]
79 | drange_out = [-1.0, 1.0]
80 | scale = (drange_out[1] - drange_out[0]) / (drange_in[1] - drange_in[0])
81 | bias = drange_out[0] - drange_in[0] * scale
82 | images = images * scale + bias
83 | return images
84 |
85 | def augmentation(image, augment_height, augment_width):
86 | seed = np.random.randint(0, 2 ** 31 - 1)
87 |
88 | ori_image_shape = tf.shape(image)
89 | image = tf.image.random_flip_left_right(image, seed=seed)
90 | image = tf.image.resize(image, size=[augment_height, augment_width], method=tf.image.ResizeMethod.BILINEAR)
91 | image = tf.random_crop(image, ori_image_shape, seed=seed)
92 |
93 |
94 | return image
95 |
96 | def save_images(images, size, image_path):
97 | return imsave(images, size, image_path)
98 |
99 | def imsave(images, size, path):
100 | images = merge(images, size)
101 | images = post_process_generator_output(images)
102 | images = cv2.cvtColor(images.astype('uint8'), cv2.COLOR_RGB2BGR)
103 | cv2.imwrite(path, images)
104 |
105 | def post_process_generator_output(generator_output):
106 |
107 | drange_min, drange_max = -1.0, 1.0
108 | scale = 255.0 / (drange_max - drange_min)
109 |
110 | scaled_image = generator_output * scale + (0.5 - drange_min * scale)
111 | scaled_image = np.clip(scaled_image, 0, 255)
112 |
113 | return scaled_image
114 |
115 | def merge(images, size):
116 | h, w = images.shape[1], images.shape[2]
117 | c = images.shape[3]
118 | img = np.zeros((h * size[0], w * size[1], c))
119 | for idx, image in enumerate(images):
120 | i = idx % size[1]
121 | j = idx // size[1]
122 | img[h*j:h*(j+1), w*i:w*(i+1), :] = image
123 |
124 | return img
125 |
126 | def show_all_variables():
127 | model_vars = tf.trainable_variables()
128 | slim.model_analyzer.analyze_vars(model_vars, print_info=True)
129 |
130 | def check_folder(log_dir):
131 | if not os.path.exists(log_dir):
132 | os.makedirs(log_dir)
133 | return log_dir
134 |
135 | def str2bool(x):
136 |     return x.lower() == 'true'  # note: "in ('true')" would be a substring test, since ('true') is a str, not a tuple
137 |
138 | def get_one_hot(targets, nb_classes):
139 |
140 | x = np.eye(nb_classes)[targets]
141 |
142 | return x
143 |
144 | def pytorch_xavier_weight_factor(gain=0.02, uniform=False) :
145 |
146 | if uniform :
147 | factor = gain * gain
148 | mode = 'FAN_AVG'
149 | else :
150 | factor = (gain * gain) / 1.3
151 | mode = 'FAN_AVG'
152 |
153 | return factor, mode, uniform
154 |
155 | def pytorch_kaiming_weight_factor(a=0.0, activation_function='leaky_relu', uniform=False) :
156 |
157 | if activation_function == 'relu' :
158 | gain = np.sqrt(2.0)
159 | elif activation_function == 'leaky_relu' :
160 | gain = np.sqrt(2.0 / (1 + a ** 2))
161 | elif activation_function == 'tanh' :
162 | gain = 5.0 / 3
163 | else :
164 | gain = 1.0
165 |
166 | if uniform :
167 | factor = gain * gain
168 | mode = 'FAN_IN'
169 | else :
170 | factor = (gain * gain) / 1.3
171 | mode = 'FAN_IN'
172 |
173 | return factor, mode, uniform
174 |
--------------------------------------------------------------------------------