├── .gitignore
├── LICENSE
├── README.md
├── UGATIT.py
├── assets
│   ├── ablation.png
│   ├── discriminator_fix.png
│   ├── generator_fix.png
│   ├── kid_fix2.png
│   ├── teaser.png
│   └── user_study.png
├── main.py
├── ops.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Junho Kim
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## U-GAT-IT — Official TensorFlow Implementation (ICLR 2020)
2 | ### Unsupervised Generative Attentional Networks with Adaptive Layer-Instance Normalization for Image-to-Image Translation
3 |
4 |
5 | ![teaser](./assets/teaser.png)
6 |
7 |
8 | ### [Paper](https://arxiv.org/abs/1907.10830) | [Official Pytorch code](https://github.com/znxlwm/UGATIT-pytorch)
9 | This repository provides the **official TensorFlow implementation** of the following paper:
10 |
11 | > **U-GAT-IT: Unsupervised Generative Attentional Networks with Adaptive Layer-Instance Normalization for Image-to-Image Translation**
12 | > **Junho Kim (NCSOFT)**, Minjae Kim (NCSOFT), Hyeonwoo Kang (NCSOFT), Kwanghee Lee (Boeing Korea)
13 | >
14 | > **Abstract** *We propose a novel method for unsupervised image-to-image translation, which incorporates a new attention module and a new learnable normalization function in an end-to-end manner. The attention module guides our model to focus on more important regions distinguishing between source and target domains based on the attention map obtained by the auxiliary classifier. Unlike previous attention-based methods which cannot handle the geometric changes between domains, our model can translate both images requiring holistic changes and images requiring large shape changes. Moreover, our new AdaLIN (Adaptive Layer-Instance Normalization) function helps our attention-guided model to flexibly control the amount of change in shape and texture by learned parameters depending on datasets. Experimental results show the superiority of the proposed method compared to the existing state-of-the-art models with a fixed network architecture and hyper-parameters.*
15 |
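For reference, the AdaLIN function described above reduces to the short sketch below. It mirrors `adaptive_instance_layer_norm` in `ops.py`; in the actual model, `rho` is a learned per-channel variable clipped to [0, 1], and `gamma` / `beta` are predicted by an MLP from the attention features:

```
import tensorflow as tf

def adalin(x, gamma, beta, rho, eps=1e-5):
    # per-channel instance statistics vs. per-sample layer statistics
    ins_mean, ins_var = tf.nn.moments(x, axes=[1, 2], keep_dims=True)
    ln_mean, ln_var = tf.nn.moments(x, axes=[1, 2, 3], keep_dims=True)
    x_ins = (x - ins_mean) / tf.sqrt(ins_var + eps)
    x_ln = (x - ln_mean) / tf.sqrt(ln_var + eps)
    # rho interpolates between instance norm (rho=1) and layer norm (rho=0)
    return (rho * x_ins + (1 - rho) * x_ln) * gamma + beta
```
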
16 | ## Requirements
17 | * python == 3.6
18 | * tensorflow == 1.14
19 |
20 | ## Pretrained model
21 | > We release both 50-epoch and 100-epoch checkpoints so that results can be tested more widely.
22 | * [selfie2anime checkpoint (50 epoch)](https://drive.google.com/file/d/1V6GbSItG3HZKv3quYs7AP0rr1kOCT3QO/view?usp=sharing)
23 | * [selfie2anime checkpoint (100 epoch)](https://drive.google.com/file/d/19xQK2onIy-3S5W5K-XIh85pAg_RNvBVf/view?usp=sharing)
24 |
25 | ## Dataset
26 | * [selfie2anime dataset](https://drive.google.com/file/d/1xOWj1UVgp6NKMT3HbPhBbtq2A4EDkghF/view?usp=sharing)
27 |
28 | ## Web page
29 | * [Selfie2Anime](https://selfie2anime.com) by [Nathan Glover](https://github.com/t04glovern)
30 | * [Selfie2Waifu](https://waifu.lofiu.com) by [creke](https://github.com/creke)
31 |
32 | ## Telegram Bot
33 | * [Selfie2AnimeBot](https://t.me/selfie2animebot) by [Alex Spirin](https://github.com/sxela)
34 |
35 | ## Usage
36 | ```
37 | ├── dataset
38 |    └── YOUR_DATASET_NAME
39 |        ├── trainA
40 |        │   ├── xxx.jpg (name, format doesn't matter)
41 |        │   ├── yyy.png
42 |        │   └── ...
43 |        ├── trainB
44 |        │   ├── zzz.jpg
45 |        │   ├── www.png
46 |        │   └── ...
47 |        ├── testA
48 |        │   ├── aaa.jpg
49 |        │   ├── bbb.png
50 |        │   └── ...
51 |        └── testB
52 |            ├── ccc.jpg
53 |            ├── ddd.png
54 |            └── ...
55 | ```
56 |
57 | ### Train
58 | ```
59 | > python main.py --dataset selfie2anime
60 | ```
61 | * If GPU memory is **not sufficient**, set `--light` to **True**
62 | * The light version may **not** perform as well
63 | * The paper results use `--light` set to **False**
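* For example: `python main.py --dataset selfie2anime --light True`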
64 |
65 | ### Test
66 | ```
67 | > python main.py --dataset selfie2anime --phase test
68 | ```
69 |
70 | ## Architecture
71 |
72 | ![generator architecture](./assets/generator_fix.png)
73 |
74 |
75 | ---
76 |
77 |
78 | ![discriminator architecture](./assets/discriminator_fix.png)
79 |
80 |
81 | ## Results
82 | ### Ablation study
83 |
84 | ![ablation study](./assets/ablation.png)
85 |
86 |
87 | ### User study
88 |
89 | ![user study](./assets/user_study.png)
90 |
91 |
92 | ### Kernel Inception Distance (KID)
93 |
94 | ![KID results](./assets/kid_fix2.png)
95 |
96 |
97 | ## Citation
98 | If you find this code useful for your research, please cite our paper:
99 |
100 | ```
101 | @inproceedings{
102 | Kim2020U-GAT-IT:,
103 | title={U-GAT-IT: Unsupervised Generative Attentional Networks with Adaptive Layer-Instance Normalization for Image-to-Image Translation},
104 | author={Junho Kim and Minjae Kim and Hyeonwoo Kang and Kwang Hee Lee},
105 | booktitle={International Conference on Learning Representations},
106 | year={2020},
107 | url={https://openreview.net/forum?id=BJlZ5ySKPH}
108 | }
109 | ```
110 |
111 | ## Author
112 | [Junho Kim](http://bit.ly/jhkim_ai), Minjae Kim, Hyeonwoo Kang, Kwanghee Lee
113 |
--------------------------------------------------------------------------------
/UGATIT.py:
--------------------------------------------------------------------------------
1 | from ops import *
2 | from utils import *
3 | from glob import glob
4 | import time
5 | from tensorflow.contrib.data import prefetch_to_device, shuffle_and_repeat, map_and_batch
6 | import numpy as np
7 |
8 | class UGATIT(object) :
9 | def __init__(self, sess, args):
10 | self.light = args.light
11 |
12 | if self.light :
13 | self.model_name = 'UGATIT_light'
14 | else :
15 | self.model_name = 'UGATIT'
16 |
17 | self.sess = sess
18 | self.phase = args.phase
19 | self.checkpoint_dir = args.checkpoint_dir
20 | self.result_dir = args.result_dir
21 | self.log_dir = args.log_dir
22 | self.dataset_name = args.dataset
23 | self.augment_flag = args.augment_flag
24 |
25 | self.epoch = args.epoch
26 | self.iteration = args.iteration
27 | self.decay_flag = args.decay_flag
28 | self.decay_epoch = args.decay_epoch
29 |
30 | self.gan_type = args.gan_type
31 |
32 | self.batch_size = args.batch_size
33 | self.print_freq = args.print_freq
34 | self.save_freq = args.save_freq
35 |
36 | self.init_lr = args.lr
37 | self.ch = args.ch
38 |
39 | """ Weight """
40 | self.adv_weight = args.adv_weight
41 | self.cycle_weight = args.cycle_weight
42 | self.identity_weight = args.identity_weight
43 | self.cam_weight = args.cam_weight
44 | self.ld = args.GP_ld
45 | self.smoothing = args.smoothing
46 |
47 | """ Generator """
48 | self.n_res = args.n_res
49 |
50 | """ Discriminator """
51 | self.n_dis = args.n_dis
52 | self.n_critic = args.n_critic
53 | self.sn = args.sn
54 |
55 | self.img_size = args.img_size
56 | self.img_ch = args.img_ch
57 |
58 |
59 | self.sample_dir = os.path.join(args.sample_dir, self.model_dir)
60 | check_folder(self.sample_dir)
61 |
62 | # self.trainA, self.trainB = prepare_data(dataset_name=self.dataset_name, size=self.img_size
63 | self.trainA_dataset = glob('./dataset/{}/*.*'.format(self.dataset_name + '/trainA'))
64 | self.trainB_dataset = glob('./dataset/{}/*.*'.format(self.dataset_name + '/trainB'))
65 | self.dataset_num = max(len(self.trainA_dataset), len(self.trainB_dataset))
66 |
67 | print()
68 |
69 | print("##### Information #####")
70 | print("# light : ", self.light)
71 | print("# gan type : ", self.gan_type)
72 | print("# dataset : ", self.dataset_name)
73 | print("# max dataset number : ", self.dataset_num)
74 | print("# batch_size : ", self.batch_size)
75 | print("# epoch : ", self.epoch)
76 | print("# iteration per epoch : ", self.iteration)
77 | print("# smoothing : ", self.smoothing)
78 |
79 | print()
80 |
81 | print("##### Generator #####")
82 | print("# residual blocks : ", self.n_res)
83 |
84 | print()
85 |
86 | print("##### Discriminator #####")
87 | print("# discriminator layer : ", self.n_dis)
88 | print("# the number of critic : ", self.n_critic)
89 | print("# spectral normalization : ", self.sn)
90 |
91 | print()
92 |
93 | print("##### Weight #####")
94 | print("# adv_weight : ", self.adv_weight)
95 | print("# cycle_weight : ", self.cycle_weight)
96 | print("# identity_weight : ", self.identity_weight)
97 | print("# cam_weight : ", self.cam_weight)
98 |
99 | ##################################################################################
100 | # Generator
101 | ##################################################################################
102 |
103 | def generator(self, x_init, reuse=False, scope="generator"):
104 | channel = self.ch
105 | with tf.variable_scope(scope, reuse=reuse) :
106 | x = conv(x_init, channel, kernel=7, stride=1, pad=3, pad_type='reflect', scope='conv')
107 | x = instance_norm(x, scope='ins_norm')
108 | x = relu(x)
109 |
110 | # Down-Sampling
111 | for i in range(2) :
112 | x = conv(x, channel*2, kernel=3, stride=2, pad=1, pad_type='reflect', scope='conv_'+str(i))
113 | x = instance_norm(x, scope='ins_norm_'+str(i))
114 | x = relu(x)
115 |
116 | channel = channel * 2
117 |
118 | # Down-Sampling Bottleneck
119 | for i in range(self.n_res):
120 | x = resblock(x, channel, scope='resblock_' + str(i))
121 |
122 |
123 | # Class Activation Map
124 | cam_x = global_avg_pooling(x)
125 | cam_gap_logit, cam_x_weight = fully_connected_with_w(cam_x, scope='CAM_logit')
126 | x_gap = tf.multiply(x, cam_x_weight)
127 |
128 | cam_x = global_max_pooling(x)
129 | cam_gmp_logit, cam_x_weight = fully_connected_with_w(cam_x, reuse=True, scope='CAM_logit')
130 | x_gmp = tf.multiply(x, cam_x_weight)
131 |
132 |
133 | cam_logit = tf.concat([cam_gap_logit, cam_gmp_logit], axis=-1)
134 | x = tf.concat([x_gap, x_gmp], axis=-1)
135 |
136 | x = conv(x, channel, kernel=1, stride=1, scope='conv_1x1')
137 | x = relu(x)
138 |
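# channel-wise sum of the fused features gives the spatial attention heatmap (for visualization)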
139 | heatmap = tf.squeeze(tf.reduce_sum(x, axis=-1))
140 |
141 | # Gamma, Beta block
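# the MLP predicts the per-channel gamma and beta used by the AdaLIN up-sampling resblocks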
142 | gamma, beta = self.MLP(x, reuse=reuse)
143 |
144 | # Up-Sampling Bottleneck
145 | for i in range(self.n_res):
146 | x = adaptive_ins_layer_resblock(x, channel, gamma, beta, smoothing=self.smoothing, scope='adaptive_resblock' + str(i))
147 |
148 | # Up-Sampling
149 | for i in range(2) :
150 | x = up_sample(x, scale_factor=2)
151 | x = conv(x, channel//2, kernel=3, stride=1, pad=1, pad_type='reflect', scope='up_conv_'+str(i))
152 | x = layer_instance_norm(x, scope='layer_ins_norm_'+str(i))
153 | x = relu(x)
154 |
155 | channel = channel // 2
156 |
157 |
158 | x = conv(x, channels=3, kernel=7, stride=1, pad=3, pad_type='reflect', scope='G_logit')
159 | x = tanh(x)
160 |
161 | return x, cam_logit, heatmap
162 |
163 | def MLP(self, x, use_bias=True, reuse=False, scope='MLP'):
164 | channel = self.ch * self.n_res
165 |
166 | if self.light :
167 | x = global_avg_pooling(x)
168 |
169 | with tf.variable_scope(scope, reuse=reuse):
170 | for i in range(2) :
171 | x = fully_connected(x, channel, use_bias, scope='linear_' + str(i))
172 | x = relu(x)
173 |
174 |
175 | gamma = fully_connected(x, channel, use_bias, scope='gamma')
176 | beta = fully_connected(x, channel, use_bias, scope='beta')
177 |
178 | gamma = tf.reshape(gamma, shape=[self.batch_size, 1, 1, channel])
179 | beta = tf.reshape(beta, shape=[self.batch_size, 1, 1, channel])
180 |
181 | return gamma, beta
182 |
183 | ##################################################################################
184 | # Discriminator
185 | ##################################################################################
186 |
187 | def discriminator(self, x_init, reuse=False, scope="discriminator"):
188 | D_logit = []
189 | D_CAM_logit = []
190 | with tf.variable_scope(scope, reuse=reuse) :
191 | local_x, local_cam, local_heatmap = self.discriminator_local(x_init, reuse=reuse, scope='local')
192 | global_x, global_cam, global_heatmap = self.discriminator_global(x_init, reuse=reuse, scope='global')
193 |
194 | D_logit.extend([local_x, global_x])
195 | D_CAM_logit.extend([local_cam, global_cam])
196 |
197 | return D_logit, D_CAM_logit, local_heatmap, global_heatmap
198 |
199 | def discriminator_global(self, x_init, reuse=False, scope='discriminator_global'):
200 | with tf.variable_scope(scope, reuse=reuse):
201 | channel = self.ch
202 | x = conv(x_init, channel, kernel=4, stride=2, pad=1, pad_type='reflect', sn=self.sn, scope='conv_0')
203 | x = lrelu(x, 0.2)
204 |
205 | for i in range(1, self.n_dis - 1):
206 | x = conv(x, channel * 2, kernel=4, stride=2, pad=1, pad_type='reflect', sn=self.sn, scope='conv_' + str(i))
207 | x = lrelu(x, 0.2)
208 |
209 | channel = channel * 2
210 |
211 | x = conv(x, channel * 2, kernel=4, stride=1, pad=1, pad_type='reflect', sn=self.sn, scope='conv_last')
212 | x = lrelu(x, 0.2)
213 |
214 | channel = channel * 2
215 |
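# CAM branch: GAP/GMP logits serve as an auxiliary domain classifier and reweight the features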
216 | cam_x = global_avg_pooling(x)
217 | cam_gap_logit, cam_x_weight = fully_connected_with_w(cam_x, sn=self.sn, scope='CAM_logit')
218 | x_gap = tf.multiply(x, cam_x_weight)
219 |
220 | cam_x = global_max_pooling(x)
221 | cam_gmp_logit, cam_x_weight = fully_connected_with_w(cam_x, sn=self.sn, reuse=True, scope='CAM_logit')
222 | x_gmp = tf.multiply(x, cam_x_weight)
223 |
224 | cam_logit = tf.concat([cam_gap_logit, cam_gmp_logit], axis=-1)
225 | x = tf.concat([x_gap, x_gmp], axis=-1)
226 |
227 | x = conv(x, channel, kernel=1, stride=1, scope='conv_1x1')
228 | x = lrelu(x, 0.2)
229 |
230 | heatmap = tf.squeeze(tf.reduce_sum(x, axis=-1))
231 |
232 |
233 | x = conv(x, channels=1, kernel=4, stride=1, pad=1, pad_type='reflect', sn=self.sn, scope='D_logit')
234 |
235 | return x, cam_logit, heatmap
236 |
237 | def discriminator_local(self, x_init, reuse=False, scope='discriminator_local'):
238 | with tf.variable_scope(scope, reuse=reuse) :
239 | channel = self.ch
240 | x = conv(x_init, channel, kernel=4, stride=2, pad=1, pad_type='reflect', sn=self.sn, scope='conv_0')
241 | x = lrelu(x, 0.2)
242 |
243 | for i in range(1, self.n_dis - 2 - 1):
244 | x = conv(x, channel * 2, kernel=4, stride=2, pad=1, pad_type='reflect', sn=self.sn, scope='conv_' + str(i))
245 | x = lrelu(x, 0.2)
246 |
247 | channel = channel * 2
248 |
249 | x = conv(x, channel * 2, kernel=4, stride=1, pad=1, pad_type='reflect', sn=self.sn, scope='conv_last')
250 | x = lrelu(x, 0.2)
251 |
252 | channel = channel * 2
253 |
254 | cam_x = global_avg_pooling(x)
255 | cam_gap_logit, cam_x_weight = fully_connected_with_w(cam_x, sn=self.sn, scope='CAM_logit')
256 | x_gap = tf.multiply(x, cam_x_weight)
257 |
258 | cam_x = global_max_pooling(x)
259 | cam_gmp_logit, cam_x_weight = fully_connected_with_w(cam_x, sn=self.sn, reuse=True, scope='CAM_logit')
260 | x_gmp = tf.multiply(x, cam_x_weight)
261 |
262 | cam_logit = tf.concat([cam_gap_logit, cam_gmp_logit], axis=-1)
263 | x = tf.concat([x_gap, x_gmp], axis=-1)
264 |
265 | x = conv(x, channel, kernel=1, stride=1, scope='conv_1x1')
266 | x = lrelu(x, 0.2)
267 |
268 | heatmap = tf.squeeze(tf.reduce_sum(x, axis=-1))
269 |
270 | x = conv(x, channels=1, kernel=4, stride=1, pad=1, pad_type='reflect', sn=self.sn, scope='D_logit')
271 |
272 | return x, cam_logit, heatmap
273 |
274 | ##################################################################################
275 | # Model
276 | ##################################################################################
277 |
278 | def generate_a2b(self, x_A, reuse=False):
279 | out, cam, _ = self.generator(x_A, reuse=reuse, scope="generator_B")
280 |
281 | return out, cam
282 |
283 | def generate_b2a(self, x_B, reuse=False):
284 | out, cam, _ = self.generator(x_B, reuse=reuse, scope="generator_A")
285 |
286 | return out, cam
287 |
288 | def discriminate_real(self, x_A, x_B):
289 | real_A_logit, real_A_cam_logit, _, _ = self.discriminator(x_A, scope="discriminator_A")
290 | real_B_logit, real_B_cam_logit, _, _ = self.discriminator(x_B, scope="discriminator_B")
291 |
292 | return real_A_logit, real_A_cam_logit, real_B_logit, real_B_cam_logit
293 |
294 | def discriminate_fake(self, x_ba, x_ab):
295 | fake_A_logit, fake_A_cam_logit, _, _ = self.discriminator(x_ba, reuse=True, scope="discriminator_A")
296 | fake_B_logit, fake_B_cam_logit, _, _ = self.discriminator(x_ab, reuse=True, scope="discriminator_B")
297 |
298 | return fake_A_logit, fake_A_cam_logit, fake_B_logit, fake_B_cam_logit
299 |
300 | def gradient_penalty(self, real, fake, scope="discriminator_A"):
301 | if 'dragan' in self.gan_type:
302 | eps = tf.random_uniform(shape=tf.shape(real), minval=0., maxval=1.)
303 | _, x_var = tf.nn.moments(real, axes=[0, 1, 2, 3])
304 | x_std = tf.sqrt(x_var) # magnitude of noise decides the size of local region
305 |
306 | fake = real + 0.5 * x_std * eps
307 |
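# evaluate the penalty at random points on the line between real and fake samples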
308 | alpha = tf.random_uniform(shape=[self.batch_size, 1, 1, 1], minval=0., maxval=1.)
309 | interpolated = real + alpha * (fake - real)
310 |
311 | logit, cam_logit, _, _ = self.discriminator(interpolated, reuse=True, scope=scope)
312 |
313 |
314 | GP = []
315 | cam_GP = []
316 |
317 | for i in range(2) :
318 | grad = tf.gradients(logit[i], interpolated)[0] # gradient of D(interpolated)
319 | grad_norm = tf.norm(flatten(grad), axis=1) # l2 norm
320 |
321 | # WGAN - LP
322 | if self.gan_type == 'wgan-lp' :
323 | GP.append(self.ld * tf.reduce_mean(tf.square(tf.maximum(0.0, grad_norm - 1.))))
324 |
325 | elif self.gan_type == 'wgan-gp' or self.gan_type == 'dragan':
326 | GP.append(self.ld * tf.reduce_mean(tf.square(grad_norm - 1.)))
327 |
328 | for i in range(2) :
329 | grad = tf.gradients(cam_logit[i], interpolated)[0] # gradient of D(interpolated)
330 | grad_norm = tf.norm(flatten(grad), axis=1) # l2 norm
331 |
332 | # WGAN - LP
333 | if self.gan_type == 'wgan-lp' :
334 | cam_GP.append(self.ld * tf.reduce_mean(tf.square(tf.maximum(0.0, grad_norm - 1.))))
335 |
336 | elif self.gan_type == 'wgan-gp' or self.gan_type == 'dragan':
337 | cam_GP.append(self.ld * tf.reduce_mean(tf.square(grad_norm - 1.)))
338 |
339 |
340 | return sum(GP), sum(cam_GP)
341 |
342 | def build_model(self):
343 | if self.phase == 'train' :
344 | self.lr = tf.placeholder(tf.float32, name='learning_rate')
345 |
346 |
347 | """ Input Image"""
348 | Image_Data_Class = ImageData(self.img_size, self.img_ch, self.augment_flag)
349 |
350 | trainA = tf.data.Dataset.from_tensor_slices(self.trainA_dataset)
351 | trainB = tf.data.Dataset.from_tensor_slices(self.trainB_dataset)
352 |
353 |
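# shuffle and repeat the file lists, decode/augment in parallel batches, and prefetch to the GPU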
354 | gpu_device = '/gpu:0'
355 | trainA = trainA.apply(shuffle_and_repeat(self.dataset_num)).apply(map_and_batch(Image_Data_Class.image_processing, self.batch_size, num_parallel_batches=16, drop_remainder=True)).apply(prefetch_to_device(gpu_device, None))
356 | trainB = trainB.apply(shuffle_and_repeat(self.dataset_num)).apply(map_and_batch(Image_Data_Class.image_processing, self.batch_size, num_parallel_batches=16, drop_remainder=True)).apply(prefetch_to_device(gpu_device, None))
357 |
358 |
359 | trainA_iterator = trainA.make_one_shot_iterator()
360 | trainB_iterator = trainB.make_one_shot_iterator()
361 |
362 | self.domain_A = trainA_iterator.get_next()
363 | self.domain_B = trainB_iterator.get_next()
364 |
365 | """ Define Generator, Discriminator """
366 | x_ab, cam_ab = self.generate_a2b(self.domain_A) # real a
367 | x_ba, cam_ba = self.generate_b2a(self.domain_B) # real b
368 |
369 | x_aba, _ = self.generate_b2a(x_ab, reuse=True) # real b
370 | x_bab, _ = self.generate_a2b(x_ba, reuse=True) # real a
371 |
372 | x_aa, cam_aa = self.generate_b2a(self.domain_A, reuse=True) # identity mapping A -> A
373 | x_bb, cam_bb = self.generate_a2b(self.domain_B, reuse=True) # identity mapping B -> B
374 |
375 | real_A_logit, real_A_cam_logit, real_B_logit, real_B_cam_logit = self.discriminate_real(self.domain_A, self.domain_B)
376 | fake_A_logit, fake_A_cam_logit, fake_B_logit, fake_B_cam_logit = self.discriminate_fake(x_ba, x_ab)
377 |
378 |
379 | """ Define Loss """
380 | if 'wgan' in self.gan_type or self.gan_type == 'dragan' :
381 | GP_A, GP_CAM_A = self.gradient_penalty(real=self.domain_A, fake=x_ba, scope="discriminator_A")
382 | GP_B, GP_CAM_B = self.gradient_penalty(real=self.domain_B, fake=x_ab, scope="discriminator_B")
383 | else :
384 | GP_A, GP_CAM_A = 0, 0
385 | GP_B, GP_CAM_B = 0, 0
386 |
387 | G_ad_loss_A = (generator_loss(self.gan_type, fake_A_logit) + generator_loss(self.gan_type, fake_A_cam_logit))
388 | G_ad_loss_B = (generator_loss(self.gan_type, fake_B_logit) + generator_loss(self.gan_type, fake_B_cam_logit))
389 |
390 | D_ad_loss_A = (discriminator_loss(self.gan_type, real_A_logit, fake_A_logit) + discriminator_loss(self.gan_type, real_A_cam_logit, fake_A_cam_logit) + GP_A + GP_CAM_A)
391 | D_ad_loss_B = (discriminator_loss(self.gan_type, real_B_logit, fake_B_logit) + discriminator_loss(self.gan_type, real_B_cam_logit, fake_B_cam_logit) + GP_B + GP_CAM_B)
392 |
393 | reconstruction_A = L1_loss(x_aba, self.domain_A) # reconstruction
394 | reconstruction_B = L1_loss(x_bab, self.domain_B) # reconstruction
395 |
396 | identity_A = L1_loss(x_aa, self.domain_A)
397 | identity_B = L1_loss(x_bb, self.domain_B)
398 |
399 | cam_A = cam_loss(source=cam_ba, non_source=cam_aa)
400 | cam_B = cam_loss(source=cam_ab, non_source=cam_bb)
401 |
402 | Generator_A_gan = self.adv_weight * G_ad_loss_A
403 | Generator_A_cycle = self.cycle_weight * reconstruction_B
404 | Generator_A_identity = self.identity_weight * identity_A
405 | Generator_A_cam = self.cam_weight * cam_A
406 |
407 |
408 | Generator_B_gan = self.adv_weight * G_ad_loss_B
409 | Generator_B_cycle = self.cycle_weight * reconstruction_A
410 | Generator_B_identity = self.identity_weight * identity_B
411 | Generator_B_cam = self.cam_weight * cam_B
412 |
413 |
414 | Generator_A_loss = Generator_A_gan + Generator_A_cycle + Generator_A_identity + Generator_A_cam
415 | Generator_B_loss = Generator_B_gan + Generator_B_cycle + Generator_B_identity + Generator_B_cam
416 |
417 |
418 | Discriminator_A_loss = self.adv_weight * D_ad_loss_A
419 | Discriminator_B_loss = self.adv_weight * D_ad_loss_B
420 |
421 | self.Generator_loss = Generator_A_loss + Generator_B_loss + regularization_loss('generator')
422 | self.Discriminator_loss = Discriminator_A_loss + Discriminator_B_loss + regularization_loss('discriminator')
423 |
424 |
425 | """ Result Image """
426 | self.fake_A = x_ba
427 | self.fake_B = x_ab
428 |
429 | self.real_A = self.domain_A
430 | self.real_B = self.domain_B
431 |
432 |
433 | """ Training """
434 | t_vars = tf.trainable_variables()
435 | G_vars = [var for var in t_vars if 'generator' in var.name]
436 | D_vars = [var for var in t_vars if 'discriminator' in var.name]
437 |
438 | self.G_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.Generator_loss, var_list=G_vars)
439 | self.D_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.Discriminator_loss, var_list=D_vars)
440 |
441 |
442 | """" Summary """
443 | self.all_G_loss = tf.summary.scalar("Generator_loss", self.Generator_loss)
444 | self.all_D_loss = tf.summary.scalar("Discriminator_loss", self.Discriminator_loss)
445 |
446 | self.G_A_loss = tf.summary.scalar("G_A_loss", Generator_A_loss)
447 | self.G_A_gan = tf.summary.scalar("G_A_gan", Generator_A_gan)
448 | self.G_A_cycle = tf.summary.scalar("G_A_cycle", Generator_A_cycle)
449 | self.G_A_identity = tf.summary.scalar("G_A_identity", Generator_A_identity)
450 | self.G_A_cam = tf.summary.scalar("G_A_cam", Generator_A_cam)
451 |
452 | self.G_B_loss = tf.summary.scalar("G_B_loss", Generator_B_loss)
453 | self.G_B_gan = tf.summary.scalar("G_B_gan", Generator_B_gan)
454 | self.G_B_cycle = tf.summary.scalar("G_B_cycle", Generator_B_cycle)
455 | self.G_B_identity = tf.summary.scalar("G_B_identity", Generator_B_identity)
456 | self.G_B_cam = tf.summary.scalar("G_B_cam", Generator_B_cam)
457 |
458 | self.D_A_loss = tf.summary.scalar("D_A_loss", Discriminator_A_loss)
459 | self.D_B_loss = tf.summary.scalar("D_B_loss", Discriminator_B_loss)
460 |
461 | self.rho_var = []
462 | for var in tf.trainable_variables():
463 | if 'rho' in var.name:
464 | self.rho_var.append(tf.summary.histogram(var.name, var))
465 | self.rho_var.append(tf.summary.scalar(var.name + "_min", tf.reduce_min(var)))
466 | self.rho_var.append(tf.summary.scalar(var.name + "_max", tf.reduce_max(var)))
467 | self.rho_var.append(tf.summary.scalar(var.name + "_mean", tf.reduce_mean(var)))
468 |
469 | g_summary_list = [self.G_A_loss, self.G_A_gan, self.G_A_cycle, self.G_A_identity, self.G_A_cam,
470 | self.G_B_loss, self.G_B_gan, self.G_B_cycle, self.G_B_identity, self.G_B_cam,
471 | self.all_G_loss]
472 |
473 | g_summary_list.extend(self.rho_var)
474 | d_summary_list = [self.D_A_loss, self.D_B_loss, self.all_D_loss]
475 |
476 | self.G_loss = tf.summary.merge(g_summary_list)
477 | self.D_loss = tf.summary.merge(d_summary_list)
478 |
479 | else :
480 | """ Test """
481 | self.test_domain_A = tf.placeholder(tf.float32, [1, self.img_size, self.img_size, self.img_ch], name='test_domain_A')
482 | self.test_domain_B = tf.placeholder(tf.float32, [1, self.img_size, self.img_size, self.img_ch], name='test_domain_B')
483 |
484 |
485 | self.test_fake_B, _ = self.generate_a2b(self.test_domain_A)
486 | self.test_fake_A, _ = self.generate_b2a(self.test_domain_B)
487 |
488 |
489 | def train(self):
490 | # initialize all variables
491 | tf.global_variables_initializer().run()
492 |
493 | # saver to save model
494 | self.saver = tf.train.Saver()
495 |
496 | # summary writer
497 | self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_dir, self.sess.graph)
498 |
499 |
500 | # restore check-point if it exits
501 | could_load, checkpoint_counter = self.load(self.checkpoint_dir)
502 | if could_load:
503 | start_epoch = checkpoint_counter // self.iteration
504 | start_batch_id = checkpoint_counter - start_epoch * self.iteration
505 | counter = checkpoint_counter
506 | print(" [*] Load SUCCESS")
507 | else:
508 | start_epoch = 0
509 | start_batch_id = 0
510 | counter = 1
511 | print(" [!] Load failed...")
512 |
513 | # loop for epoch
514 | start_time = time.time()
515 | past_g_loss = -1.
516 | lr = self.init_lr
517 | for epoch in range(start_epoch, self.epoch):
518 | # lr = self.init_lr if epoch < self.decay_epoch else self.init_lr * (self.epoch - epoch) / (self.epoch - self.decay_epoch)
519 | if self.decay_flag :
520 | #lr = self.init_lr * pow(0.5, epoch // self.decay_epoch)
521 | lr = self.init_lr if epoch < self.decay_epoch else self.init_lr * (self.epoch - epoch) / (self.epoch - self.decay_epoch)
522 | for idx in range(start_batch_id, self.iteration):
523 | train_feed_dict = {
524 | self.lr : lr
525 | }
526 |
527 | # Update D
528 | _, d_loss, summary_str = self.sess.run([self.D_optim,
529 | self.Discriminator_loss, self.D_loss], feed_dict = train_feed_dict)
530 | self.writer.add_summary(summary_str, counter)
531 |
532 | # Update G
533 | g_loss = None
534 | if (counter - 1) % self.n_critic == 0 :
535 | batch_A_images, batch_B_images, fake_A, fake_B, _, g_loss, summary_str = self.sess.run([self.real_A, self.real_B,
536 | self.fake_A, self.fake_B,
537 | self.G_optim,
538 | self.Generator_loss, self.G_loss], feed_dict = train_feed_dict)
539 | self.writer.add_summary(summary_str, counter)
540 | past_g_loss = g_loss
541 |
542 | # display training status
543 | counter += 1
544 | if g_loss is None :
545 | g_loss = past_g_loss
546 | print("Epoch: [%2d] [%5d/%5d] time: %4.4f d_loss: %.8f, g_loss: %.8f" % (epoch, idx, self.iteration, time.time() - start_time, d_loss, g_loss))
547 |
548 | if np.mod(idx+1, self.print_freq) == 0 :
549 | save_images(batch_A_images, [self.batch_size, 1],
550 | './{}/real_A_{:03d}_{:05d}.png'.format(self.sample_dir, epoch, idx+1))
551 | # save_images(batch_B_images, [self.batch_size, 1],
552 | # './{}/real_B_{:03d}_{:05d}.png'.format(self.sample_dir, epoch, idx+1))
553 |
554 | # save_images(fake_A, [self.batch_size, 1],
555 | # './{}/fake_A_{:03d}_{:05d}.png'.format(self.sample_dir, epoch, idx+1))
556 | save_images(fake_B, [self.batch_size, 1],
557 | './{}/fake_B_{:03d}_{:05d}.png'.format(self.sample_dir, epoch, idx+1))
558 |
559 | if np.mod(idx + 1, self.save_freq) == 0:
560 | self.save(self.checkpoint_dir, counter)
561 |
562 |
563 |
564 | # After an epoch, start_batch_id is set to zero
565 | # non-zero value is only for the first epoch after loading pre-trained model
566 | start_batch_id = 0
567 |
568 | # save model for final step
569 | self.save(self.checkpoint_dir, counter)
570 |
571 | @property
572 | def model_dir(self):
573 | n_res = str(self.n_res) + 'resblock'
574 | n_dis = str(self.n_dis) + 'dis'
575 |
576 | if self.smoothing :
577 | smoothing = '_smoothing'
578 | else :
579 | smoothing = ''
580 |
581 | if self.sn :
582 | sn = '_sn'
583 | else :
584 | sn = ''
585 |
586 | return "{}_{}_{}_{}_{}_{}_{}_{}_{}_{}{}{}".format(self.model_name, self.dataset_name,
587 | self.gan_type, n_res, n_dis,
588 | self.n_critic,
589 | self.adv_weight, self.cycle_weight, self.identity_weight, self.cam_weight, sn, smoothing)
590 |
591 | def save(self, checkpoint_dir, step):
592 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)
593 |
594 | if not os.path.exists(checkpoint_dir):
595 | os.makedirs(checkpoint_dir)
596 |
597 | self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name + '.model'), global_step=step)
598 |
599 | def load(self, checkpoint_dir):
600 | print(" [*] Reading checkpoints...")
601 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)
602 |
603 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
604 | if ckpt and ckpt.model_checkpoint_path:
605 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
606 | self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
607 | counter = int(ckpt_name.split('-')[-1])
608 | print(" [*] Success to read {}".format(ckpt_name))
609 | return True, counter
610 | else:
611 | print(" [*] Failed to find a checkpoint")
612 | return False, 0
613 |
614 | def test(self):
615 | tf.global_variables_initializer().run()
616 | test_A_files = glob('./dataset/{}/*.*'.format(self.dataset_name + '/testA'))
617 | test_B_files = glob('./dataset/{}/*.*'.format(self.dataset_name + '/testB'))
618 |
619 | self.saver = tf.train.Saver()
620 | could_load, checkpoint_counter = self.load(self.checkpoint_dir)
621 | self.result_dir = os.path.join(self.result_dir, self.model_dir)
622 | check_folder(self.result_dir)
623 |
624 | if could_load :
625 | print(" [*] Load SUCCESS")
626 | else :
627 | print(" [!] Load failed...")
628 |
629 | # write html for visual comparison
630 | index_path = os.path.join(self.result_dir, 'index.html')
631 | index = open(index_path, 'w')
632 | index.write("<html><body><table><tr>")
633 | index.write("<th>name</th><th>input</th><th>output</th></tr>")
634 |
635 | for sample_file in test_A_files : # A -> B
636 | print('Processing A image: ' + sample_file)
637 | sample_image = np.asarray(load_test_data(sample_file, size=self.img_size))
638 | image_path = os.path.join(self.result_dir,'{0}'.format(os.path.basename(sample_file)))
639 |
640 | fake_img = self.sess.run(self.test_fake_B, feed_dict = {self.test_domain_A : sample_image})
641 | save_images(fake_img, [1, 1], image_path)
642 |
643 | index.write("%s | " % os.path.basename(image_path))
644 |
645 | index.write(" | " % (sample_file if os.path.isabs(sample_file) else (
646 | '../..' + os.path.sep + sample_file), self.img_size, self.img_size))
647 | index.write(" | " % (image_path if os.path.isabs(image_path) else (
648 | '../..' + os.path.sep + image_path), self.img_size, self.img_size))
649 | index.write("")
650 |
651 | for sample_file in test_B_files : # B -> A
652 | print('Processing B image: ' + sample_file)
653 | sample_image = np.asarray(load_test_data(sample_file, size=self.img_size))
654 | image_path = os.path.join(self.result_dir,'{0}'.format(os.path.basename(sample_file)))
655 |
656 | fake_img = self.sess.run(self.test_fake_A, feed_dict = {self.test_domain_B : sample_image})
657 |
658 | save_images(fake_img, [1, 1], image_path)
659 | index.write("<td>%s</td>" % os.path.basename(image_path))
660 | index.write("<td><img src='%s' width='%d' height='%d'></td>" % (sample_file if os.path.isabs(sample_file) else (
661 | '../..' + os.path.sep + sample_file), self.img_size, self.img_size))
662 | index.write("<td><img src='%s' width='%d' height='%d'></td>" % (image_path if os.path.isabs(image_path) else (
663 | '../..' + os.path.sep + image_path), self.img_size, self.img_size))
664 | index.write("</tr>")
665 | index.close()
666 |
--------------------------------------------------------------------------------
/assets/ablation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/UGATIT/d508e8f5188e47000d79d8aecada0cc9119e0d56/assets/ablation.png
--------------------------------------------------------------------------------
/assets/discriminator_fix.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/UGATIT/d508e8f5188e47000d79d8aecada0cc9119e0d56/assets/discriminator_fix.png
--------------------------------------------------------------------------------
/assets/generator_fix.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/UGATIT/d508e8f5188e47000d79d8aecada0cc9119e0d56/assets/generator_fix.png
--------------------------------------------------------------------------------
/assets/kid_fix2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/UGATIT/d508e8f5188e47000d79d8aecada0cc9119e0d56/assets/kid_fix2.png
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/UGATIT/d508e8f5188e47000d79d8aecada0cc9119e0d56/assets/teaser.png
--------------------------------------------------------------------------------
/assets/user_study.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/UGATIT/d508e8f5188e47000d79d8aecada0cc9119e0d56/assets/user_study.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from UGATIT import UGATIT
2 | import argparse
3 | from utils import *
4 |
5 | """parsing and configuration"""
6 |
7 | def parse_args():
8 | desc = "Tensorflow implementation of U-GAT-IT"
9 | parser = argparse.ArgumentParser(description=desc)
10 | parser.add_argument('--phase', type=str, default='train', help='[train / test]')
11 | parser.add_argument('--light', type=str2bool, default=False, help='[U-GAT-IT full version / U-GAT-IT light version]')
12 | parser.add_argument('--dataset', type=str, default='selfie2anime', help='dataset_name')
13 |
14 | parser.add_argument('--epoch', type=int, default=100, help='The number of epochs to run')
15 | parser.add_argument('--iteration', type=int, default=10000, help='The number of training iterations')
16 | parser.add_argument('--batch_size', type=int, default=1, help='The batch size')
17 | parser.add_argument('--print_freq', type=int, default=1000, help='Save sample images every print_freq iterations')
18 | parser.add_argument('--save_freq', type=int, default=1000, help='Save a checkpoint every save_freq iterations')
19 | parser.add_argument('--decay_flag', type=str2bool, default=True, help='Whether to decay the learning rate')
20 | parser.add_argument('--decay_epoch', type=int, default=50, help='decay epoch')
21 |
22 | parser.add_argument('--lr', type=float, default=0.0001, help='The learning rate')
23 | parser.add_argument('--GP_ld', type=int, default=10, help='The gradient penalty lambda')
24 | parser.add_argument('--adv_weight', type=int, default=1, help='Weight for the adversarial loss')
25 | parser.add_argument('--cycle_weight', type=int, default=10, help='Weight for the cycle-consistency loss')
26 | parser.add_argument('--identity_weight', type=int, default=10, help='Weight for the identity loss')
27 | parser.add_argument('--cam_weight', type=int, default=1000, help='Weight for the CAM loss')
28 | parser.add_argument('--gan_type', type=str, default='lsgan', help='[gan / lsgan / wgan-gp / wgan-lp / dragan / hinge]')
29 |
30 | parser.add_argument('--smoothing', type=str2bool, default=True, help='AdaLIN smoothing effect')
31 |
32 | parser.add_argument('--ch', type=int, default=64, help='base channel number per layer')
33 | parser.add_argument('--n_res', type=int, default=4, help='The number of resblock')
34 | parser.add_argument('--n_dis', type=int, default=6, help='The number of discriminator layer')
35 | parser.add_argument('--n_critic', type=int, default=1, help='The number of critic')
36 | parser.add_argument('--sn', type=str2bool, default=True, help='using spectral norm')
37 |
38 | parser.add_argument('--img_size', type=int, default=256, help='The size of the input image')
39 | parser.add_argument('--img_ch', type=int, default=3, help='The number of image channels')
40 | parser.add_argument('--augment_flag', type=str2bool, default=True, help='Whether to use image augmentation')
41 |
42 | parser.add_argument('--checkpoint_dir', type=str, default='checkpoint',
43 | help='Directory name to save the checkpoints')
44 | parser.add_argument('--result_dir', type=str, default='results',
45 | help='Directory name to save the generated images')
46 | parser.add_argument('--log_dir', type=str, default='logs',
47 | help='Directory name to save training logs')
48 | parser.add_argument('--sample_dir', type=str, default='samples',
49 | help='Directory name to save the samples on training')
50 |
51 | return check_args(parser.parse_args())
52 |
53 | """checking arguments"""
54 | def check_args(args):
55 | # --checkpoint_dir
56 | check_folder(args.checkpoint_dir)
57 |
58 | # --result_dir
59 | check_folder(args.result_dir)
60 |
61 | # --log_dir
62 | check_folder(args.log_dir)
63 |
64 | # --sample_dir
65 | check_folder(args.sample_dir)
66 |
67 | # --epoch
68 | try:
69 | assert args.epoch >= 1
70 | except AssertionError:
71 | print('number of epochs must be larger than or equal to one')
72 |
73 | # --batch_size
74 | try:
75 | assert args.batch_size >= 1
76 | except AssertionError:
77 | print('batch size must be larger than or equal to one')
78 | return args
79 |
80 | """main"""
81 | def main():
82 | # parse arguments
83 | args = parse_args()
84 | if args is None:
85 | exit()
86 |
87 | # open session
88 | with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
89 | gan = UGATIT(sess, args)
90 |
91 | # build graph
92 | gan.build_model()
93 |
94 | # show network architecture
95 | show_all_variables()
96 |
97 | if args.phase == 'train' :
98 | gan.train()
99 | print(" [*] Training finished!")
100 |
101 | if args.phase == 'test' :
102 | gan.test()
103 | print(" [*] Test finished!")
104 |
105 | if __name__ == '__main__':
106 | main()
107 |
--------------------------------------------------------------------------------
/ops.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import tensorflow.contrib as tf_contrib
3 |
4 | # Xavier : tf_contrib.layers.xavier_initializer()
5 | # He : tf_contrib.layers.variance_scaling_initializer()
6 | # Normal : tf.random_normal_initializer(mean=0.0, stddev=0.02)
7 | # l2_decay : tf_contrib.layers.l2_regularizer(0.0001)
8 |
9 | weight_init = tf.random_normal_initializer(mean=0.0, stddev=0.02)
10 | weight_regularizer = tf_contrib.layers.l2_regularizer(scale=0.0001)
11 |
12 | ##################################################################################
13 | # Layer
14 | ##################################################################################
15 |
16 | def conv(x, channels, kernel=4, stride=2, pad=0, pad_type='zero', use_bias=True, sn=False, scope='conv_0'):
17 | with tf.variable_scope(scope):
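# pad explicitly (asymmetric when kernel - stride is odd) so the convolution below can use VALID padding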
18 | if pad > 0 :
19 | if (kernel - stride) % 2 == 0:
20 | pad_top = pad
21 | pad_bottom = pad
22 | pad_left = pad
23 | pad_right = pad
24 |
25 | else:
26 | pad_top = pad
27 | pad_bottom = kernel - stride - pad_top
28 | pad_left = pad
29 | pad_right = kernel - stride - pad_left
30 |
31 | if pad_type == 'zero':
32 | x = tf.pad(x, [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]])
33 | if pad_type == 'reflect':
34 | x = tf.pad(x, [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]], mode='REFLECT')
35 |
36 | if sn :
37 | w = tf.get_variable("kernel", shape=[kernel, kernel, x.get_shape()[-1], channels], initializer=weight_init,
38 | regularizer=weight_regularizer)
39 | x = tf.nn.conv2d(input=x, filter=spectral_norm(w),
40 | strides=[1, stride, stride, 1], padding='VALID')
41 | if use_bias :
42 | bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0))
43 | x = tf.nn.bias_add(x, bias)
44 |
45 | else :
46 | x = tf.layers.conv2d(inputs=x, filters=channels,
47 | kernel_size=kernel, kernel_initializer=weight_init,
48 | kernel_regularizer=weight_regularizer,
49 | strides=stride, use_bias=use_bias)
50 |
51 |
52 | return x
53 |
54 | def fully_connected_with_w(x, use_bias=True, sn=False, reuse=False, scope='linear'):
55 | with tf.variable_scope(scope, reuse=reuse):
56 | x = flatten(x)
57 | bias = 0.0
58 | shape = x.get_shape().as_list()
59 | channels = shape[-1]
60 |
61 | w = tf.get_variable("kernel", [channels, 1], tf.float32,
62 | initializer=weight_init, regularizer=weight_regularizer)
63 |
64 | if sn :
65 | w = spectral_norm(w)
66 |
67 | if use_bias :
68 | bias = tf.get_variable("bias", [1],
69 | initializer=tf.constant_initializer(0.0))
70 |
71 | x = tf.matmul(x, w) + bias
72 | else :
73 | x = tf.matmul(x, w)
74 |
75 | if use_bias :
76 | weights = tf.gather(tf.transpose(tf.nn.bias_add(w, bias)), 0)
77 | else :
78 | weights = tf.gather(tf.transpose(w), 0)
79 |
80 | return x, weights
81 |
82 | def fully_connected(x, units, use_bias=True, sn=False, scope='linear'):
83 | with tf.variable_scope(scope):
84 | x = flatten(x)
85 | shape = x.get_shape().as_list()
86 | channels = shape[-1]
87 |
88 | if sn:
89 | w = tf.get_variable("kernel", [channels, units], tf.float32,
90 | initializer=weight_init, regularizer=weight_regularizer)
91 | if use_bias:
92 | bias = tf.get_variable("bias", [units],
93 | initializer=tf.constant_initializer(0.0))
94 |
95 | x = tf.matmul(x, spectral_norm(w)) + bias
96 | else:
97 | x = tf.matmul(x, spectral_norm(w))
98 |
99 | else :
100 | x = tf.layers.dense(x, units=units, kernel_initializer=weight_init, kernel_regularizer=weight_regularizer, use_bias=use_bias)
101 |
102 | return x
103 |
104 | def flatten(x) :
105 | return tf.layers.flatten(x)
106 |
107 | ##################################################################################
108 | # Residual-block
109 | ##################################################################################
110 |
111 | def resblock(x_init, channels, use_bias=True, scope='resblock_0'):
112 | with tf.variable_scope(scope):
113 | with tf.variable_scope('res1'):
114 | x = conv(x_init, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias)
115 | x = instance_norm(x)
116 | x = relu(x)
117 |
118 | with tf.variable_scope('res2'):
119 | x = conv(x, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias)
120 | x = instance_norm(x)
121 |
122 | return x + x_init
123 |
124 | def adaptive_ins_layer_resblock(x_init, channels, gamma, beta, use_bias=True, smoothing=True, scope='adaptive_resblock') :
125 | with tf.variable_scope(scope):
126 | with tf.variable_scope('res1'):
127 | x = conv(x_init, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias)
128 | x = adaptive_instance_layer_norm(x, gamma, beta, smoothing)
129 | x = relu(x)
130 |
131 | with tf.variable_scope('res2'):
132 | x = conv(x, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias)
133 | x = adaptive_instance_layer_norm(x, gamma, beta, smoothing)
134 |
135 | return x + x_init
136 |
137 |
138 | ##################################################################################
139 | # Sampling
140 | ##################################################################################
141 |
142 | def up_sample(x, scale_factor=2):
143 | _, h, w, _ = x.get_shape().as_list()
144 | new_size = [h * scale_factor, w * scale_factor]
145 | return tf.image.resize_nearest_neighbor(x, size=new_size)
146 |
147 |
148 | def global_avg_pooling(x):
149 | gap = tf.reduce_mean(x, axis=[1, 2])
150 | return gap
151 |
152 | def global_max_pooling(x):
153 | gmp = tf.reduce_max(x, axis=[1, 2])
154 | return gmp
155 |
156 | ##################################################################################
157 | # Activation function
158 | ##################################################################################
159 |
160 | def lrelu(x, alpha=0.01):
161 | # pytorch alpha is 0.01
162 | return tf.nn.leaky_relu(x, alpha)
163 |
164 |
165 | def relu(x):
166 | return tf.nn.relu(x)
167 |
168 |
169 | def tanh(x):
170 | return tf.tanh(x)
171 |
172 | def sigmoid(x) :
173 | return tf.sigmoid(x)
174 |
175 | ##################################################################################
176 | # Normalization function
177 | ##################################################################################
178 |
179 | def adaptive_instance_layer_norm(x, gamma, beta, smoothing=True, scope='instance_layer_norm') :
180 | with tf.variable_scope(scope):
181 | ch = x.shape[-1]
182 | eps = 1e-5
183 |
184 | ins_mean, ins_var = tf.nn.moments(x, axes=[1, 2], keep_dims=True)
185 | x_ins = (x - ins_mean) / (tf.sqrt(ins_var + eps))
186 |
187 | ln_mean, ln_var = tf.nn.moments(x, axes=[1, 2, 3], keep_dims=True)
188 | x_ln = (x - ln_mean) / (tf.sqrt(ln_var + eps))
189 |
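# rho in [0, 1] is learned per channel: rho -> 1 selects instance norm, rho -> 0 selects layer norm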
190 | rho = tf.get_variable("rho", [ch], initializer=tf.constant_initializer(1.0), constraint=lambda x: tf.clip_by_value(x, clip_value_min=0.0, clip_value_max=1.0))
191 |
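# smoothing shifts rho down by 0.1 (then re-clips), biasing the mix slightly toward layer norm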
192 | if smoothing :
193 | rho = tf.clip_by_value(rho - tf.constant(0.1), 0.0, 1.0)
194 |
195 | x_hat = rho * x_ins + (1 - rho) * x_ln
196 |
197 |
198 | x_hat = x_hat * gamma + beta
199 |
200 | return x_hat
201 |
202 | def instance_norm(x, scope='instance_norm'):
203 | return tf_contrib.layers.instance_norm(x,
204 | epsilon=1e-05,
205 | center=True, scale=True,
206 | scope=scope)
207 |
208 | def layer_norm(x, scope='layer_norm') :
209 | return tf_contrib.layers.layer_norm(x,
210 | center=True, scale=True,
211 | scope=scope)
212 |
213 | def layer_instance_norm(x, scope='layer_instance_norm') :
214 | with tf.variable_scope(scope):
215 | ch = x.shape[-1]
216 | eps = 1e-5
217 |
218 | ins_mean, ins_var = tf.nn.moments(x, axes=[1, 2], keep_dims=True)
219 | x_ins = (x - ins_mean) / (tf.sqrt(ins_var + eps))
220 |
221 | ln_mean, ln_var = tf.nn.moments(x, axes=[1, 2, 3], keep_dims=True)
222 | x_ln = (x - ln_mean) / (tf.sqrt(ln_var + eps))
223 |
224 | rho = tf.get_variable("rho", [ch], initializer=tf.constant_initializer(0.0), constraint=lambda x: tf.clip_by_value(x, clip_value_min=0.0, clip_value_max=1.0))
225 |
226 | gamma = tf.get_variable("gamma", [ch], initializer=tf.constant_initializer(1.0))
227 | beta = tf.get_variable("beta", [ch], initializer=tf.constant_initializer(0.0))
228 |
229 | x_hat = rho * x_ins + (1 - rho) * x_ln
230 |
231 | x_hat = x_hat * gamma + beta
232 |
233 | return x_hat
234 |
235 | def spectral_norm(w, iteration=1):
236 | w_shape = w.shape.as_list()
237 | w = tf.reshape(w, [-1, w_shape[-1]])
238 |
239 | u = tf.get_variable("u", [1, w_shape[-1]], initializer=tf.random_normal_initializer(), trainable=False)
240 |
241 | u_hat = u
242 | v_hat = None
243 | for i in range(iteration):
244 | """
245 | power iteration
246 | Usually iteration = 1 will be enough
247 | """
248 | v_ = tf.matmul(u_hat, tf.transpose(w))
249 | v_hat = tf.nn.l2_normalize(v_)
250 |
251 | u_ = tf.matmul(v_hat, w)
252 | u_hat = tf.nn.l2_normalize(u_)
253 |
254 | u_hat = tf.stop_gradient(u_hat)
255 | v_hat = tf.stop_gradient(v_hat)
256 |
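# sigma approximates the largest singular value of w (its spectral norm)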
257 | sigma = tf.matmul(tf.matmul(v_hat, w), tf.transpose(u_hat))
258 |
259 | with tf.control_dependencies([u.assign(u_hat)]):
260 | w_norm = w / sigma
261 | w_norm = tf.reshape(w_norm, w_shape)
262 |
263 |
264 | return w_norm
265 |
266 | ##################################################################################
267 | # Loss function
268 | ##################################################################################
269 |
270 | def L1_loss(x, y):
271 | loss = tf.reduce_mean(tf.abs(x - y))
272 |
273 | return loss
274 |
275 | def cam_loss(source, non_source) :
276 |
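# auxiliary CAM classifier: source-domain logits are pushed toward 1, non-source (identity) logits toward 0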
277 | identity_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(source), logits=source))
278 | non_identity_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(non_source), logits=non_source))
279 |
280 | loss = identity_loss + non_identity_loss
281 |
282 | return loss
283 |
284 | def regularization_loss(scope_name) :
285 | """
286 | If you want to use "Regularization"
287 | g_loss += regularization_loss('generator')
288 | d_loss += regularization_loss('discriminator')
289 | """
290 | collection_regularization = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
291 |
292 | loss = []
293 | for item in collection_regularization :
294 | if scope_name in item.name :
295 | loss.append(item)
296 |
297 | return tf.reduce_sum(loss)
298 |
299 |
300 | def discriminator_loss(loss_func, real, fake):
301 | loss = []
302 | real_loss = 0
303 | fake_loss = 0
304 |
305 | for i in range(2) :
306 | if 'wgan' in loss_func :
307 | real_loss = -tf.reduce_mean(real[i])
308 | fake_loss = tf.reduce_mean(fake[i])
309 |
310 | if loss_func == 'lsgan' :
311 | real_loss = tf.reduce_mean(tf.squared_difference(real[i], 1.0))
312 | fake_loss = tf.reduce_mean(tf.square(fake[i]))
313 |
314 | if loss_func == 'gan' or loss_func == 'dragan' :
315 | real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(real[i]), logits=real[i]))
316 | fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(fake[i]), logits=fake[i]))
317 |
318 | if loss_func == 'hinge' :
319 | real_loss = tf.reduce_mean(relu(1.0 - real[i]))
320 | fake_loss = tf.reduce_mean(relu(1.0 + fake[i]))
321 |
322 | loss.append(real_loss + fake_loss)
323 |
324 | return sum(loss)
325 |
326 | def generator_loss(loss_func, fake):
327 | loss = []
328 | fake_loss = 0
329 |
330 | for i in range(2) :
331 | if 'wgan' in loss_func :
332 | fake_loss = -tf.reduce_mean(fake[i])
333 |
334 | if loss_func == 'lsgan' :
335 | fake_loss = tf.reduce_mean(tf.squared_difference(fake[i], 1.0))
336 |
337 | if loss_func == 'gan' or loss_func == 'dragan' :
338 | fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(fake[i]), logits=fake[i]))
339 |
340 | if loss_func == 'hinge' :
341 | fake_loss = -tf.reduce_mean(fake[i])
342 |
343 | loss.append(fake_loss)
344 |
345 | return sum(loss)
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.contrib import slim
3 | import cv2
4 | import os, random
5 | import numpy as np
6 |
7 | class ImageData:
8 |
9 | def __init__(self, load_size, channels, augment_flag):
10 | self.load_size = load_size
11 | self.channels = channels
12 | self.augment_flag = augment_flag
13 |
14 | def image_processing(self, filename):
15 | x = tf.read_file(filename)
16 | x_decode = tf.image.decode_jpeg(x, channels=self.channels)
17 | img = tf.image.resize_images(x_decode, [self.load_size, self.load_size])
18 | img = tf.cast(img, tf.float32) / 127.5 - 1
19 |
20 | if self.augment_flag :
21 | augment_size = self.load_size + (30 if self.load_size == 256 else 15)
22 | p = random.random()
23 | if p > 0.5:
24 | img = augmentation(img, augment_size)
25 |
26 | return img
27 |
28 | def load_test_data(image_path, size=256):
29 | img = cv2.imread(image_path, flags=cv2.IMREAD_COLOR)
30 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
31 |
32 | img = cv2.resize(img, dsize=(size, size))
33 |
34 | img = np.expand_dims(img, axis=0)
35 | img = img/127.5 - 1
36 |
37 | return img
38 |
39 | def augmentation(image, augment_size):
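# random horizontal flip, then enlarge and randomly crop back to the original size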
40 | seed = random.randint(0, 2 ** 31 - 1)
41 | ori_image_shape = tf.shape(image)
42 | image = tf.image.random_flip_left_right(image, seed=seed)
43 | image = tf.image.resize_images(image, [augment_size, augment_size])
44 | image = tf.random_crop(image, ori_image_shape, seed=seed)
45 | return image
46 |
47 | def save_images(images, size, image_path):
48 | return imsave(inverse_transform(images), size, image_path)
49 |
50 | def inverse_transform(images):
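# map images from [-1, 1] back to [0, 255]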
51 | return ((images+1.) / 2) * 255.0
52 |
53 |
54 | def imsave(images, size, path):
55 | images = merge(images, size)
56 | images = cv2.cvtColor(images.astype('uint8'), cv2.COLOR_RGB2BGR)
57 |
58 | return cv2.imwrite(path, images)
59 |
60 | def merge(images, size):
61 | h, w = images.shape[1], images.shape[2]
62 | img = np.zeros((h * size[0], w * size[1], 3))
63 | for idx, image in enumerate(images):
64 | i = idx % size[1]
65 | j = idx // size[1]
66 | img[h*j:h*(j+1), w*i:w*(i+1), :] = image
67 |
68 | return img
69 |
70 | def show_all_variables():
71 | model_vars = tf.trainable_variables()
72 | slim.model_analyzer.analyze_vars(model_vars, print_info=True)
73 |
74 | def check_folder(log_dir):
75 | if not os.path.exists(log_dir):
76 | os.makedirs(log_dir)
77 | return log_dir
78 |
79 | def str2bool(x):
80 | return x.lower() == 'true'
81 |
--------------------------------------------------------------------------------