├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── compare.png
│   ├── gcb.png
│   ├── pixel_shuffle.png
│   ├── rdb.png
│   ├── rdn.png
│   ├── relativistic.png
│   ├── relativistic_s.png
│   ├── srm.png
│   ├── tf-cook.png
│   └── tf-cookbook.png
├── ops.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Junho Kim (1993.01.12)
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 |
3 |
4 |
5 |
6 | # [Web page](http://bit.ly/jhkim_tf_cookbook)
7 | # [Tensorflow 2 Cookbook](https://github.com/taki0112/Tensorflow2-Cookbook)
8 |
9 | ## Contributions
10 | For now, this repo contains general architectures and functions that are useful for GANs and classification.
11 |
12 | I will continue to add useful things to other areas.
13 |
14 | Also, your pull requests and issues are always welcome.
15 |
16 | If you write what you want implemented in an issue, I'll implement it.
17 |
18 | # How to use
19 | ## Import
20 | * `ops.py`
21 | * **operations**
22 | * from ops import *
23 | * `utils.py`
24 | * **image processing**
25 | * from utils import *
26 |
27 | ## Network template
28 | ```python
29 | def network(x, is_training=True, reuse=False, scope="network"):
30 | with tf.variable_scope(scope, reuse=reuse):
31 | x = conv(...)
32 |
33 | ...
34 |
35 | return logit
36 | ```
37 |
38 | ## Insert data to network using DatasetAPI
39 | ```python
40 | Image_Data_Class = ImageData(img_size, img_ch, augment_flag)
41 |
42 | trainA_dataset = ['./dataset/cat/trainA/a.jpg',
43 | './dataset/cat/trainA/b.png',
44 | './dataset/cat/trainA/c.jpeg',
45 | ...]
46 | trainA = tf.data.Dataset.from_tensor_slices(trainA_dataset)
47 | trainA = trainA.map(Image_Data_Class.image_processing, num_parallel_calls=16)
48 | trainA = trainA.shuffle(buffer_size=10000).prefetch(buffer_size=batch_size).batch(batch_size).repeat()
49 |
50 | trainA_iterator = trainA.make_one_shot_iterator()
51 | data_A = trainA_iterator.get_next()
52 |
53 | logit = network(data_A)
54 | ```
55 | * See [this](https://github.com/taki0112/Tensorflow-DatasetAPI) for more information.
56 |
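`ImageData.image_processing` lives in `utils.py`. Roughly, it reads the filename, decodes and resizes the image, and normalizes it to [-1, 1]; a sketch (field names illustrative, not the exact code):

```python
def image_processing(self, filename):
    x = tf.read_file(filename)
    img = tf.image.decode_jpeg(x, channels=self.channels)
    img = tf.image.resize_images(img, [self.img_size, self.img_size])
    img = tf.cast(img, tf.float32) / 127.5 - 1  # normalize to [-1, 1]
    return img
```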
57 | ## Option
58 | * `padding='SAME'`
59 |   * pad = ceil[ (kernel - stride) / 2 ] (see the example below)
60 | * `pad_type`
61 | * 'zero' or 'reflect'
62 | * `sn`
63 | * use [spectral_normalization](https://arxiv.org/pdf/1802.05957.pdf) or not
64 |
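For example, `kernel=3, stride=2` gives `pad = ceil((3 - 2) / 2) = 1`, which reproduces the `padding='SAME'` output size; a quick sanity check:

```python
x = tf.placeholder(tf.float32, [None, 256, 256, 3])
x = conv(x, channels=64, kernel=3, stride=2, pad=1, pad_type='zero', scope='conv_same')
# x : [None, 128, 128, 64], the same spatial size that padding='SAME' would give
```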
65 | ## Caution
66 | * If you don't want to share variables, **set all scope names differently.**
67 |
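For example, two layers with different scope names get their own (unshared) weights:

```python
x1 = conv(x, channels=64, kernel=3, stride=1, scope='conv_1')
x2 = conv(x, channels=64, kernel=3, stride=1, scope='conv_2')  # different scope -> different weights
```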
68 | ---
69 | ## Weight
70 | ```python
71 | weight_init = tf.truncated_normal_initializer(mean=0.0, stddev=0.02)
72 | weight_regularizer = tf.contrib.layers.l2_regularizer(0.0001)
73 | weight_regularizer_fully = tf.contrib.layers.l2_regularizer(0.0001)
74 | ```
75 | ### Initialization
76 | * `Xavier` : tf.contrib.layers.xavier_initializer()
77 | ```python
78 |
79 | # USE tf.contrib.layers.variance_scaling_initializer()
80 |
81 | if uniform :
82 | factor = gain * gain
83 | mode = 'FAN_AVG'
84 | else :
85 | factor = (gain * gain) / 1.3
86 | mode = 'FAN_AVG'
87 | ```
88 | * `He` : tf.contrib.layers.variance_scaling_initializer()
89 | ```python
90 | if uniform :
91 | factor = gain * gain
92 | mode = 'FAN_IN'
93 | else :
94 | factor = (gain * gain) / 1.3
95 | mode = 'FAN_OUT'
96 | ```
97 | * `Normal` : tf.random_normal_initializer(mean=0.0, stddev=0.02)
98 | * `Truncated_normal` : tf.truncated_normal_initializer(mean=0.0, stddev=0.02)
99 | * `Orthogonal` : tf.orthogonal_initializer(1.0)  # gain = sqrt(2) for relu, 1.0 for the others
100 |
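To build a pytorch-style initializer from the factors above, `ops.py` contains a commented-out example using `pytorch_xavier_weight_factor` from `utils.py`:

```python
factor, mode, uniform = pytorch_xavier_weight_factor(gain=0.02, uniform=False)
weight_init = tf.contrib.layers.variance_scaling_initializer(factor=factor, mode=mode, uniform=uniform)
```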
101 | ### Regularization
102 | * `l2_decay` : tf.contrib.layers.l2_regularizer(0.0001)
103 | * `orthogonal_regularizer` : orthogonal_regularizer(0.0001) & orthogonal_regularizer_fully(0.0001)
104 |
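To switch from l2 decay to orthogonal regularization, swap the globals in `ops.py`; a sketch using the two functions named above:

```python
weight_regularizer = orthogonal_regularizer(0.0001)
weight_regularizer_fully = orthogonal_regularizer_fully(0.0001)
```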
105 | ## Convolution
106 | ### basic conv
107 | ```python
108 | x = conv(x, channels=64, kernel=3, stride=2, pad=1, pad_type='reflect', use_bias=True, sn=True, scope='conv')
109 | ```
110 |
111 |
112 |
113 |
114 | ### partial conv (NVIDIA [Partial Convolution](https://github.com/NVIDIA/partialconv))
115 | ```python
116 | x = partial_conv(x, channels=64, kernel=3, stride=2, use_bias=True, padding='SAME', sn=True, scope='partial_conv')
117 | ```
118 |
119 | 
120 | 
121 |
122 | ### dilated conv
123 | ```python
124 | x = dilate_conv(x, channels=64, kernel=3, rate=2, use_bias=True, padding='VALID', sn=True, scope='dilate_conv')
125 | ```
126 |
127 |
128 |
129 |
130 | ---
131 |
132 | ## Deconvolution
133 | ### basic deconv
134 | ```python
135 | x = deconv(x, channels=64, kernel=3, stride=1, padding='SAME', use_bias=True, sn=True, scope='deconv')
136 | ```
137 |
138 |
139 |
140 |
141 | ---
142 |
143 | ## Fully-connected
144 | ```python
145 | x = fully_connected(x, units=64, use_bias=True, sn=True, scope='fully_connected')
146 | ```
147 |
148 |
149 |
150 |
151 |
152 | ---
153 |
154 | ## Pixel shuffle
155 | ```python
156 | x = conv_pixel_shuffle_down(x, scale_factor=2, use_bias=True, sn=True, scope='pixel_shuffle_down')
157 | x = conv_pixel_shuffle_up(x, scale_factor=2, use_bias=True, sn=True, scope='pixel_shuffle_up')
158 | ```
159 | * `down` ===> [height, width] -> [**height // scale_factor, width // scale_factor**]
160 | * `up` ===> [height, width] -> [**height \* scale_factor, width \* scale_factor**]
161 |
162 | 
163 |
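Both ops keep the channel count and only change the spatial size (the internal 1x1 conv adjusts channels before `space_to_depth` / `depth_to_space`):

```python
# x : [bs, 64, 64, 32]
x = conv_pixel_shuffle_down(x, scale_factor=2, scope='down')  # -> [bs, 32, 32, 32]
x = conv_pixel_shuffle_up(x, scale_factor=2, scope='up')      # -> [bs, 64, 64, 32]
```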
164 |
165 | ---
166 |
167 | ## Block
168 | ### residual block
169 | ```python
170 | x = resblock(x, channels=64, is_training=is_training, use_bias=True, sn=True, scope='residual_block')
171 | x = resblock_down(x, channels=64, is_training=is_training, use_bias=True, sn=True, scope='residual_block_down')
172 | x = resblock_up(x, channels=64, is_training=is_training, use_bias=True, sn=True, scope='residual_block_up')
173 | ```
174 | * `down` ===> [height, width] -> [**height // 2, width // 2**]
175 | * `up` ===> [height, width] -> [**height \* 2, width \* 2**]
176 |
177 |
178 |
179 |
180 | ### dense block
181 | ```python
182 | x = denseblock(x, channels=64, n_db=6, is_training=is_training, use_bias=True, sn=True, scope='denseblock')
183 | ```
184 | * `n_db` ===> The number of dense blocks
185 |
186 |
187 |
188 |
189 | ### residual-dense block
190 | ```python
191 | x = res_denseblock(x, channels=64, n_rdb=20, n_rdb_conv=6, is_training=is_training, use_bias=True, sn=True, scope='res_denseblock')
192 | ```
193 | * `n_rdb` ===> The number of RDBs
194 | * `n_rdb_conv` ===> The number of conv layers per RDB
195 |
196 |
201 |
202 | ### attention block
203 | ```python
204 | x = self_attention(x, use_bias=True, sn=True, scope='self_attention')
205 | x = self_attention_with_pooling(x, use_bias=True, sn=True, scope='self_attention_version_2')
206 |
207 | x = squeeze_excitation(x, ratio=16, use_bias=True, sn=True, scope='squeeze_excitation')
208 |
209 | x = convolution_block_attention(x, ratio=16, use_bias=True, sn=True, scope='convolution_block_attention')
210 |
211 | x = global_context_block(x, use_bias=True, sn=True, scope='gc_block')
212 |
213 | x = srm_block(x, use_bias=False, is_training=is_training, scope='srm_block')
214 | ```
215 |
216 | 
217 | 
218 |
219 |
220 | ---
249 |
250 | ## Normalization
251 | ```python
252 | x = batch_norm(x, is_training=is_training, scope='batch_norm')
253 | x = layer_norm(x, scope='layer_norm')
254 | x = instance_norm(x, scope='instance_norm')
255 | x = group_norm(x, groups=32, scope='group_norm')
256 |
257 | x = pixel_norm(x)
258 |
259 | x = batch_instance_norm(x, scope='batch_instance_norm')
260 | x = layer_instance_norm(x, scope='layer_instance_norm')
261 | x = switch_norm(x, scope='switch_norm')
262 |
263 | x = condition_batch_norm(x, z, is_training=is_training, scope='condition_batch_norm')
264 |
265 | x = adaptive_instance_norm(x, gamma, beta)
266 | x = adaptive_layer_instance_norm(x, gamma, beta, smoothing=True, scope='adaLIN')
267 |
268 | ```
269 | * See [this](https://github.com/taki0112/BigGAN-Tensorflow) for how to use `condition_batch_norm`
270 | * See [this](https://github.com/taki0112/MUNIT-Tensorflow) for how to use `adaptive_instance_norm`
271 | * See [this](https://github.com/taki0112/UGATIT) for how to use `adaptive_layer_instance_norm` & `layer_instance_norm`
272 |
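A minimal sketch of feeding `adaptive_instance_norm`, following the MUNIT pattern linked above (`style`, `channels` and the `mlp_*` scopes are illustrative):

```python
# gamma, beta are predicted from a style code by an MLP
gamma = fully_connected(style, units=channels, scope='mlp_gamma')
beta = fully_connected(style, units=channels, scope='mlp_beta')
gamma = tf.reshape(gamma, [-1, 1, 1, channels])
beta = tf.reshape(beta, [-1, 1, 1, channels])

x = adaptive_instance_norm(content, gamma, beta)
```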
273 |
282 | ---
283 |
284 | ## Activation
285 | ```python
286 | x = relu(x)
287 | x = lrelu(x, alpha=0.2)
288 | x = tanh(x)
289 | x = sigmoid(x)
290 | x = swish(x)
291 | x = elu(x)
292 | ```
293 |
294 | ---
295 |
296 | ## Pooling & Resize
297 | ```python
298 | x = nearest_up_sample(x, scale_factor=2)
299 | x = bilinear_up_sample(x, scale_factor=2)
300 | x = nearest_down_sample(x, scale_factor=2)
301 | x = bilinear_down_sample(x, scale_factor=2)
302 |
303 | x = max_pooling(x, pool_size=2)
304 | x = avg_pooling(x, pool_size=2)
305 |
306 | x = global_max_pooling(x)
307 | x = global_avg_pooling(x)
308 |
309 | x = flatten(x)
310 | x = hw_flatten(x)
311 | ```
312 |
313 | ---
314 |
315 | ## Loss
316 | ### classification loss
317 | ```python
318 | loss, accuracy = classification_loss(logit, label)
319 |
320 | loss = dice_loss(n_classes=10, logits=logit, labels=label)
321 | ```
322 |
323 | ### regularization loss
324 | ```python
325 | g_reg_loss = regularization_loss('generator')
326 | d_reg_loss = regularization_loss('discriminator')
327 | ```
328 |
329 | * If you want to use a `regularizer`, you must add this regularization loss to your total loss yourself (see the example below).
330 |
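Following the docstring in `ops.py`, add the collected losses to each total loss (`g_adv_loss` / `d_adv_loss` are illustrative names):

```python
g_loss = g_adv_loss + regularization_loss('generator')
d_loss = d_adv_loss + regularization_loss('discriminator')
```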
331 | ### pixel loss
332 | ```python
333 | loss = L1_loss(x, y)
334 | loss = L2_loss(x, y)
335 | loss = huber_loss(x, y)
336 | loss = histogram_loss(x, y)
337 |
338 | loss = gram_style_loss(x, y)
339 |
340 | loss = color_consistency_loss(x, y)
341 | ```
342 | * `histogram_loss` measures the difference between the color histograms (pixel-value distributions) of the two images.
343 | * `gram_style_loss` measures the difference between the styles of the two images, computed with gram matrices.
344 | * `color_consistency_loss` measures the color (channel-wise mean and variance) difference between the generated image and the input image.
345 |
346 | ### gan loss
347 | ```python
348 | d_loss = discriminator_loss(Ra=True, gan_type='hinge', real=real_logit, fake=fake_logit)
349 | g_loss = generator_loss(Ra=True, gan_type='hinge', real=real_logit, fake=fake_logit)
350 | ```
351 | * `Ra`
352 | * use [relativistic gan](https://arxiv.org/pdf/1807.00734.pdf) or not
353 | * `gan_type`
354 | * gan
355 | * lsgan
356 | * hinge
357 | * wgan-gp
358 | * dragan
359 | * See [this](https://github.com/taki0112/BigGAN-Tensorflow/blob/master/BigGAN_512.py#L180) for how to use `gradient_penalty`
360 |
361 | 
362 | 
363 |
364 |
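For the `-gp` variants you also need a gradient penalty. `ops.py` provides `simple_gp` (the StyleGAN-style R1/R2 penalty); a sketch, where `real_images` / `fake_images` are the discriminator inputs:

```python
d_adv_loss = discriminator_loss(Ra=False, gan_type='gan-gp', real=real_logit, fake=fake_logit)
gp = simple_gp(real_logit, fake_logit, real_images, fake_images, r1_gamma=10, r2_gamma=0)
d_loss = d_adv_loss + gp
```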
365 | ### [vdb loss](https://arxiv.org/abs/1810.00821)
366 | ```python
367 | d_bottleneck_loss = vdb_loss(real_mu, real_logvar, i_c) + vdb_loss(fake_mu, fake_logvar, i_c)
368 | ```
369 |
370 | ### kl-divergence (z ~ N(0, 1))
371 | ```python
372 | loss = kl_loss(mean, logvar)
373 | ```
374 |
375 | ---
376 |
377 | ## Author
378 | [Junho Kim](http://bit.ly/jhkim_ai)
379 |
--------------------------------------------------------------------------------
/assets/compare.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/Tensorflow-Cookbook/580e5fc26e8f24023d1da6d095452fb6d5a121c7/assets/compare.png
--------------------------------------------------------------------------------
/assets/gcb.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/Tensorflow-Cookbook/580e5fc26e8f24023d1da6d095452fb6d5a121c7/assets/gcb.png
--------------------------------------------------------------------------------
/assets/pixel_shuffle.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/Tensorflow-Cookbook/580e5fc26e8f24023d1da6d095452fb6d5a121c7/assets/pixel_shuffle.png
--------------------------------------------------------------------------------
/assets/rdb.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/Tensorflow-Cookbook/580e5fc26e8f24023d1da6d095452fb6d5a121c7/assets/rdb.png
--------------------------------------------------------------------------------
/assets/rdn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/Tensorflow-Cookbook/580e5fc26e8f24023d1da6d095452fb6d5a121c7/assets/rdn.png
--------------------------------------------------------------------------------
/assets/relativistic.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/Tensorflow-Cookbook/580e5fc26e8f24023d1da6d095452fb6d5a121c7/assets/relativistic.png
--------------------------------------------------------------------------------
/assets/relativistic_s.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/Tensorflow-Cookbook/580e5fc26e8f24023d1da6d095452fb6d5a121c7/assets/relativistic_s.png
--------------------------------------------------------------------------------
/assets/srm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/Tensorflow-Cookbook/580e5fc26e8f24023d1da6d095452fb6d5a121c7/assets/srm.png
--------------------------------------------------------------------------------
/assets/tf-cook.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/Tensorflow-Cookbook/580e5fc26e8f24023d1da6d095452fb6d5a121c7/assets/tf-cook.png
--------------------------------------------------------------------------------
/assets/tf-cookbook.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/Tensorflow-Cookbook/580e5fc26e8f24023d1da6d095452fb6d5a121c7/assets/tf-cookbook.png
--------------------------------------------------------------------------------
/ops.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | from utils import pytorch_xavier_weight_factor, pytorch_kaiming_weight_factor
4 |
5 | ##################################################################################
6 | # Initialization
7 | ##################################################################################
8 |
9 | """
10 |
11 | pytorch xavier (gain)
12 | https://pytorch.org/docs/stable/_modules/torch/nn/init.html
13 |
14 | USE < tf.contrib.layers.variance_scaling_initializer() >
15 | if uniform :
16 | factor = gain * gain
17 | mode = 'FAN_AVG'
18 | else :
19 | factor = (gain * gain) / 1.3
20 | mode = 'FAN_AVG'
21 |
22 | pytorch : trunc_stddev = gain * sqrt(2 / (fan_in + fan_out))
23 | tensorflow : trunc_stddev = sqrt(1.3 * factor * 2 / (fan_in + fan_out))
24 |
25 | """
26 |
27 | """
28 | pytorch kaiming (a=0)
29 | https://pytorch.org/docs/stable/_modules/torch/nn/init.html
30 |
31 | if uniform :
32 | a = 0 -> gain = sqrt(2)
33 | factor = gain * gain
34 | mode='FAN_IN'
35 | else :
36 | a = 0 -> gain = sqrt(2)
37 | factor = (gain * gain) / 1.3
38 |         mode = 'FAN_OUT'  # FAN_OUT is correct, but 'FAN_IN' is more commonly used
39 |
40 | pytorch : trunc_stddev = gain * sqrt(2 / fan_in)
41 | tensorflow : trunc_stddev = sqrt(1.3 * factor * 2 / fan_in)
42 |
43 | """
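# Sanity check for the non-uniform xavier case above: with factor = (gain * gain) / 1.3,
# sqrt(1.3 * factor * 2 / (fan_in + fan_out)) = gain * sqrt(2 / (fan_in + fan_out)),
# i.e. the tensorflow trunc_stddev matches the pytorch one exactly.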
44 |
45 | # Xavier : tf.contrib.layers.xavier_initializer()
46 | # He : tf.contrib.layers.variance_scaling_initializer()
47 | # Normal : tf.random_normal_initializer(mean=0.0, stddev=0.02)
48 | # Truncated_normal : tf.truncated_normal_initializer(mean=0.0, stddev=0.02)
49 | # Orthogonal : tf.orthogonal_initializer(1.0)  # gain = sqrt(2) for relu, 1.0 for the others
50 |
51 | ##################################################################################
52 | # Regularization
53 | ##################################################################################
54 |
55 | # l2_decay : tf.contrib.layers.l2_regularizer(0.0001)
56 | # orthogonal_regularizer : orthogonal_regularizer(0.0001) & orthogonal_regularizer_fully(0.0001)
57 |
58 | # factor, mode, uniform = pytorch_xavier_weight_factor(gain=0.02, uniform=False)
59 | # weight_init = tf.contrib.layers.variance_scaling_initializer(factor=factor, mode=mode, uniform=uniform)
60 |
61 | weight_init = tf.truncated_normal_initializer(mean=0.0, stddev=0.02)
62 | weight_regularizer = tf.contrib.layers.l2_regularizer(0.0001)
63 | weight_regularizer_fully = tf.contrib.layers.l2_regularizer(0.0001)
64 |
65 |
66 | ##################################################################################
67 | # Layers
68 | ##################################################################################
69 |
70 | # padding='SAME' ======> pad = floor[ (kernel - stride) / 2 ]
71 | def conv(x, channels, kernel=4, stride=2, pad=0, pad_type='zero', use_bias=True, sn=False, scope='conv_0'):
72 | with tf.variable_scope(scope):
73 | if pad > 0:
74 | h = x.get_shape().as_list()[1]
75 | if h % stride == 0:
76 | pad = pad * 2
77 | else:
78 | pad = max(kernel - (h % stride), 0)
79 |
80 | pad_top = pad // 2
81 | pad_bottom = pad - pad_top
82 | pad_left = pad // 2
83 | pad_right = pad - pad_left
84 |
85 | if pad_type == 'zero':
86 | x = tf.pad(x, [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]])
87 | if pad_type == 'reflect':
88 | x = tf.pad(x, [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]], mode='REFLECT')
89 |
90 | if sn:
91 | w = tf.get_variable("kernel", shape=[kernel, kernel, x.get_shape()[-1], channels], initializer=weight_init,
92 | regularizer=weight_regularizer)
93 | x = tf.nn.conv2d(input=x, filter=spectral_norm(w),
94 | strides=[1, stride, stride, 1], padding='VALID')
95 | if use_bias:
96 | bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0))
97 | x = tf.nn.bias_add(x, bias)
98 |
99 | else:
100 | x = tf.layers.conv2d(inputs=x, filters=channels,
101 | kernel_size=kernel, kernel_initializer=weight_init,
102 | kernel_regularizer=weight_regularizer,
103 | strides=stride, use_bias=use_bias)
104 |
105 | return x
106 |
107 |
108 | def partial_conv(x, channels, kernel=3, stride=2, use_bias=True, padding='SAME', sn=False, scope='conv_0'):
109 | with tf.variable_scope(scope):
110 |         if padding.lower() == 'same':
111 | with tf.variable_scope('mask'):
112 | _, h, w, _ = x.get_shape().as_list()
113 |
114 | slide_window = kernel * kernel
115 | mask = tf.ones(shape=[1, h, w, 1])
116 |
117 | update_mask = tf.layers.conv2d(mask, filters=1,
118 | kernel_size=kernel, kernel_initializer=tf.constant_initializer(1.0),
119 | strides=stride, padding=padding, use_bias=False, trainable=False)
120 |
121 | mask_ratio = slide_window / (update_mask + 1e-8)
122 | update_mask = tf.clip_by_value(update_mask, 0.0, 1.0)
123 | mask_ratio = mask_ratio * update_mask
124 |
125 | with tf.variable_scope('x'):
126 | if sn:
127 | w = tf.get_variable("kernel", shape=[kernel, kernel, x.get_shape()[-1], channels],
128 | initializer=weight_init, regularizer=weight_regularizer)
129 | x = tf.nn.conv2d(input=x, filter=spectral_norm(w), strides=[1, stride, stride, 1], padding=padding)
130 | else:
131 | x = tf.layers.conv2d(x, filters=channels,
132 | kernel_size=kernel, kernel_initializer=weight_init,
133 | kernel_regularizer=weight_regularizer,
134 | strides=stride, padding=padding, use_bias=False)
135 | x = x * mask_ratio
136 |
137 | if use_bias:
138 | bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0))
139 |
140 | x = tf.nn.bias_add(x, bias)
141 | x = x * update_mask
142 | else:
143 | if sn:
144 | w = tf.get_variable("kernel", shape=[kernel, kernel, x.get_shape()[-1], channels],
145 | initializer=weight_init, regularizer=weight_regularizer)
146 | x = tf.nn.conv2d(input=x, filter=spectral_norm(w), strides=[1, stride, stride, 1], padding=padding)
147 | if use_bias:
148 | bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0))
149 |
150 | x = tf.nn.bias_add(x, bias)
151 | else:
152 | x = tf.layers.conv2d(x, filters=channels,
153 | kernel_size=kernel, kernel_initializer=weight_init,
154 | kernel_regularizer=weight_regularizer,
155 | strides=stride, padding=padding, use_bias=use_bias)
156 |
157 | return x
158 |
159 |
160 | def dilate_conv(x, channels, kernel=3, rate=2, use_bias=True, padding='SAME', sn=False, scope='conv_0'):
161 | with tf.variable_scope(scope):
162 | w = tf.get_variable("kernel", shape=[kernel, kernel, x.get_shape()[-1], channels], initializer=weight_init,
163 | regularizer=weight_regularizer)
164 | if sn:
165 | x = tf.nn.atrous_conv2d(x, spectral_norm(w), rate=rate, padding=padding)
166 | else:
167 | x = tf.nn.atrous_conv2d(x, w, rate=rate, padding=padding)
168 |
169 | if use_bias:
170 | bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0))
171 | x = tf.nn.bias_add(x, bias)
172 |
173 | return x
174 |
175 |
176 | def deconv(x, channels, kernel=4, stride=2, padding='SAME', use_bias=True, sn=False, scope='deconv_0'):
177 | with tf.variable_scope(scope):
178 | x_shape = x.get_shape().as_list()
179 |
180 | if padding == 'SAME':
181 | output_shape = [x_shape[0], x_shape[1] * stride, x_shape[2] * stride, channels]
182 |
183 | else:
184 | output_shape = [x_shape[0], x_shape[1] * stride + max(kernel - stride, 0),
185 | x_shape[2] * stride + max(kernel - stride, 0), channels]
186 |
187 | if sn:
188 | w = tf.get_variable("kernel", shape=[kernel, kernel, channels, x.get_shape()[-1]], initializer=weight_init,
189 | regularizer=weight_regularizer)
190 | x = tf.nn.conv2d_transpose(x, filter=spectral_norm(w), output_shape=output_shape,
191 | strides=[1, stride, stride, 1], padding=padding)
192 |
193 | if use_bias:
194 | bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0))
195 | x = tf.nn.bias_add(x, bias)
196 |
197 | else:
198 | x = tf.layers.conv2d_transpose(inputs=x, filters=channels,
199 | kernel_size=kernel, kernel_initializer=weight_init,
200 | kernel_regularizer=weight_regularizer,
201 | strides=stride, padding=padding, use_bias=use_bias)
202 |
203 | return x
204 |
205 |
206 | def conv_pixel_shuffle_up(x, scale_factor=2, use_bias=True, sn=False, scope='pixel_shuffle'):
207 | channel = x.get_shape()[-1] * (scale_factor ** 2)
208 | x = conv(x, channel, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope=scope)
209 | x = tf.depth_to_space(x, block_size=scale_factor)
210 |
211 | return x
212 |
213 |
214 | def conv_pixel_shuffle_down(x, scale_factor=2, use_bias=True, sn=False, scope='pixel_shuffle'):
215 | channel = x.get_shape()[-1] // (scale_factor ** 2)
216 | x = conv(x, channel, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope=scope)
217 | x = tf.space_to_depth(x, block_size=scale_factor)
218 |
219 | return x
220 |
221 |
222 | def fully_connected(x, units, use_bias=True, sn=False, scope='linear'):
223 | with tf.variable_scope(scope):
224 | x = flatten(x)
225 | shape = x.get_shape().as_list()
226 | channels = shape[-1]
227 |
228 | if sn:
229 | w = tf.get_variable("kernel", [channels, units], tf.float32,
230 | initializer=weight_init, regularizer=weight_regularizer_fully)
231 | if use_bias:
232 | bias = tf.get_variable("bias", [units],
233 | initializer=tf.constant_initializer(0.0))
234 |
235 | x = tf.matmul(x, spectral_norm(w)) + bias
236 | else:
237 | x = tf.matmul(x, spectral_norm(w))
238 |
239 | else:
240 | x = tf.layers.dense(x, units=units, kernel_initializer=weight_init,
241 | kernel_regularizer=weight_regularizer_fully,
242 | use_bias=use_bias)
243 |
244 | return x
245 |
246 |
247 | ##################################################################################
248 | # Blocks
249 | ##################################################################################
250 |
251 | def resblock(x_init, channels, use_bias=True, is_training=True, sn=False, scope='resblock'):
252 | with tf.variable_scope(scope):
253 | with tf.variable_scope('res1'):
254 | x = conv(x_init, channels, kernel=3, stride=1, pad=1, use_bias=use_bias, sn=sn)
255 | x = batch_norm(x, is_training)
256 | x = relu(x)
257 |
258 | with tf.variable_scope('res2'):
259 | x = conv(x, channels, kernel=3, stride=1, pad=1, use_bias=use_bias, sn=sn)
260 | x = batch_norm(x, is_training)
261 |
262 | if channels != x_init.shape[-1]:
263 | with tf.variable_scope('skip'):
264 | x_init = conv(x_init, channels, kernel=1, stride=1, use_bias=use_bias, sn=sn)
265 | return relu(x + x_init)
266 |
267 | return x + x_init
268 |
269 |
270 | def resblock_up(x_init, channels, use_bias=True, is_training=True, sn=False, scope='resblock_up'):
271 | with tf.variable_scope(scope):
272 | with tf.variable_scope('res1'):
273 | x = deconv(x_init, channels, kernel=3, stride=2, use_bias=use_bias, sn=sn)
274 | x = batch_norm(x, is_training)
275 | x = relu(x)
276 |
277 | with tf.variable_scope('res2'):
278 | x = deconv(x, channels, kernel=3, stride=1, use_bias=use_bias, sn=sn)
279 | x = batch_norm(x, is_training)
280 |
281 | with tf.variable_scope('skip'):
282 | x_init = deconv(x_init, channels, kernel=3, stride=2, use_bias=use_bias, sn=sn)
283 |
284 | return relu(x + x_init)
285 |
286 |
287 | def resblock_up_condition(x_init, z, channels, use_bias=True, is_training=True, sn=False, scope='resblock_up'):
288 | # See https://github.com/taki0112/BigGAN-Tensorflow
289 | with tf.variable_scope(scope):
290 | with tf.variable_scope('res1'):
291 | x = deconv(x_init, channels, kernel=3, stride=2, use_bias=use_bias, sn=sn)
292 | x = condition_batch_norm(x, z, is_training)
293 | x = relu(x)
294 |
295 | with tf.variable_scope('res2'):
296 | x = deconv(x, channels, kernel=3, stride=1, use_bias=use_bias, sn=sn)
297 | x = condition_batch_norm(x, z, is_training)
298 |
299 | with tf.variable_scope('skip'):
300 | x_init = deconv(x_init, channels, kernel=3, stride=2, use_bias=use_bias, sn=sn)
301 |
302 | return relu(x + x_init)
303 |
304 |
305 | def resblock_down(x_init, channels, use_bias=True, is_training=True, sn=False, scope='resblock_down'):
306 | with tf.variable_scope(scope):
307 | with tf.variable_scope('res1'):
308 | x = conv(x_init, channels, kernel=3, stride=2, pad=1, use_bias=use_bias, sn=sn)
309 | x = batch_norm(x, is_training)
310 | x = relu(x)
311 |
312 | with tf.variable_scope('res2'):
313 | x = conv(x, channels, kernel=3, stride=1, pad=1, use_bias=use_bias, sn=sn)
314 | x = batch_norm(x, is_training)
315 |
316 | with tf.variable_scope('skip'):
317 | x_init = conv(x_init, channels, kernel=3, stride=2, pad=1, use_bias=use_bias, sn=sn)
318 |
319 | return relu(x + x_init)
320 |
321 |
322 | def denseblock(x_init, channels, n_db=6, use_bias=True, is_training=True, sn=False, scope='denseblock'):
323 | with tf.variable_scope(scope):
324 | layers = []
325 | layers.append(x_init)
326 |
327 | with tf.variable_scope('bottle_neck_0'):
328 | x = conv(x_init, 4 * channels, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='conv_0')
329 | x = batch_norm(x, is_training, scope='batch_norm_0')
330 | x = relu(x)
331 |
332 | x = conv(x, channels, kernel=3, stride=1, pad=1, use_bias=use_bias, sn=sn, scope='conv_1')
333 | x = batch_norm(x, is_training, scope='batch_norm_1')
334 | x = relu(x)
335 |
336 | layers.append(x)
337 |
338 | for i in range(1, n_db):
339 | with tf.variable_scope('bottle_neck_' + str(i)):
340 | x = tf.concat(layers, axis=-1)
341 |
342 | x = conv(x, 4 * channels, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='conv_0')
343 | x = batch_norm(x, is_training, scope='batch_norm_0')
344 | x = relu(x)
345 |
346 | x = conv(x, channels, kernel=3, stride=1, pad=1, use_bias=use_bias, sn=sn, scope='conv_1')
347 | x = batch_norm(x, is_training, scope='batch_norm_1')
348 | x = relu(x)
349 |
350 | layers.append(x)
351 |
352 | x = tf.concat(layers, axis=-1)
353 |
354 | return x
355 |
356 |
357 | def res_denseblock(x_init, channels, n_rdb=20, n_rdb_conv=6, use_bias=True, is_training=True, sn=False,
358 | scope='res_denseblock'):
359 | with tf.variable_scope(scope):
360 | RDBs = []
361 | x_input = x_init
362 |
363 | """
364 | n_rdb = 20 ( RDB number )
365 | n_rdb_conv = 6 ( per RDB conv layer )
366 | """
367 |
368 | for k in range(n_rdb):
369 | with tf.variable_scope('RDB_' + str(k)):
370 | layers = []
371 | layers.append(x_init)
372 |
373 | x = conv(x_init, channels, kernel=3, stride=1, pad=1, use_bias=use_bias, sn=sn, scope='conv_0')
374 | x = batch_norm(x, is_training, scope='batch_norm_0')
375 | x = relu(x)
376 |
377 | layers.append(x)
378 |
379 | for i in range(1, n_rdb_conv):
380 | x = tf.concat(layers, axis=-1)
381 |
382 | x = conv(x, channels, kernel=3, stride=1, pad=1, use_bias=use_bias, sn=sn, scope='conv_' + str(i))
383 | x = batch_norm(x, is_training, scope='batch_norm_' + str(i))
384 | x = relu(x)
385 |
386 | layers.append(x)
387 |
388 | # Local feature fusion
389 | x = tf.concat(layers, axis=-1)
390 | x = conv(x, channels, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='conv_last')
391 |
392 | # Local residual learning
393 | if channels != x_init.shape[-1] :
394 | x_init = conv(x_init, channels, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='local_skip_conv')
395 | x = relu(x + x_init)
396 | else :
397 | x = x_init + x
398 |
399 | RDBs.append(x)
400 | x_init = x
401 |
402 | with tf.variable_scope('GFF_1x1'):
403 | x = tf.concat(RDBs, axis=-1)
404 | x = conv(x, channels, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='conv')
405 |
406 | with tf.variable_scope('GFF_3x3'):
407 | x = conv(x, channels, kernel=3, stride=1, pad=1, use_bias=use_bias, sn=sn, scope='conv')
408 |
409 | # Global residual learning
410 | if channels != x_input.shape[-1]:
411 | x_input = conv(x_input, channels, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='global_skip_conv')
412 | x = relu(x + x_input)
413 | else :
414 | x = x_input + x
415 |
416 | return x
417 |
418 |
419 | def self_attention(x, use_bias=True, sn=False, scope='self_attention'):
420 | with tf.variable_scope(scope):
421 | channels = x.shape[-1]
422 | f = conv(x, channels // 8, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='f_conv') # [bs, h, w, c']
423 | g = conv(x, channels // 8, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='g_conv') # [bs, h, w, c']
424 | h = conv(x, channels, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='h_conv') # [bs, h, w, c]
425 |
426 | # N = h * w
427 |         s = tf.matmul(hw_flatten(g), hw_flatten(f), transpose_b=True)  # [bs, N, N]
428 |
429 | beta = tf.nn.softmax(s) # attention map
430 |
431 | o = tf.matmul(beta, hw_flatten(h)) # [bs, N, C]
432 | gamma = tf.get_variable("gamma", [1], initializer=tf.constant_initializer(0.0))
433 |
434 | o = tf.reshape(o, shape=x.shape) # [bs, h, w, C]
435 | x = gamma * o + x
436 |
437 | return x
438 |
439 |
440 | def self_attention_with_pooling(x, use_bias=True, sn=False, scope='self_attention'):
441 | with tf.variable_scope(scope):
442 | channels = x.shape[-1]
443 | f = conv(x, channels // 8, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='f_conv') # [bs, h, w, c']
444 | f = max_pooling(f)
445 |
446 | g = conv(x, channels // 8, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='g_conv') # [bs, h, w, c']
447 |
448 | h = conv(x, channels // 2, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='h_conv') # [bs, h, w, c]
449 | h = max_pooling(h)
450 |
451 | # N = h * w
452 |         s = tf.matmul(hw_flatten(g), hw_flatten(f), transpose_b=True)  # [bs, N, N]
453 |
454 | beta = tf.nn.softmax(s) # attention map
455 |
456 | o = tf.matmul(beta, hw_flatten(h)) # [bs, N, C]
457 | gamma = tf.get_variable("gamma", [1], initializer=tf.constant_initializer(0.0))
458 |
459 | o = tf.reshape(o, shape=[x.shape[0], x.shape[1], x.shape[2], channels // 2]) # [bs, h, w, C]
460 | o = conv(o, channels, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='attn_conv')
461 | x = gamma * o + x
462 |
463 | return x
464 |
465 |
466 | def squeeze_excitation(x, ratio=16, use_bias=True, sn=False, scope='senet'):
467 | with tf.variable_scope(scope):
468 | channels = x.shape[-1]
469 | squeeze = global_avg_pooling(x)
470 |
471 | excitation = fully_connected(squeeze, units=channels // ratio, use_bias=use_bias, sn=sn, scope='fc1')
472 | excitation = relu(excitation)
473 | excitation = fully_connected(excitation, units=channels, use_bias=use_bias, sn=sn, scope='fc2')
474 | excitation = sigmoid(excitation)
475 |
476 | excitation = tf.reshape(excitation, [-1, 1, 1, channels])
477 |
478 | scale = x * excitation
479 |
480 | return scale
481 |
482 |
483 | def convolution_block_attention(x, ratio=16, use_bias=True, sn=False, scope='cbam'):
484 | with tf.variable_scope(scope):
485 | channels = x.shape[-1]
486 | with tf.variable_scope('channel_attention'):
487 | x_gap = global_avg_pooling(x)
488 | x_gap = fully_connected(x_gap, units=channels // ratio, use_bias=use_bias, sn=sn, scope='fc1')
489 | x_gap = relu(x_gap)
490 | x_gap = fully_connected(x_gap, units=channels, use_bias=use_bias, sn=sn, scope='fc2')
491 |
492 | with tf.variable_scope('channel_attention', reuse=True):
493 | x_gmp = global_max_pooling(x)
494 | x_gmp = fully_connected(x_gmp, units=channels // ratio, use_bias=use_bias, sn=sn, scope='fc1')
495 | x_gmp = relu(x_gmp)
496 | x_gmp = fully_connected(x_gmp, units=channels, use_bias=use_bias, sn=sn, scope='fc2')
497 |
498 | scale = tf.reshape(x_gap + x_gmp, [-1, 1, 1, channels])
499 | scale = sigmoid(scale)
500 |
501 | x = x * scale
502 |
503 | with tf.variable_scope('spatial_attention'):
504 | x_channel_avg_pooling = tf.reduce_mean(x, axis=-1, keepdims=True)
505 | x_channel_max_pooling = tf.reduce_max(x, axis=-1, keepdims=True)
506 | scale = tf.concat([x_channel_avg_pooling, x_channel_max_pooling], axis=-1)
507 |
508 | scale = conv(scale, channels=1, kernel=7, stride=1, pad=3, pad_type='reflect', use_bias=False, sn=sn, scope='conv')
509 | scale = sigmoid(scale)
510 |
511 | x = x * scale
512 |
513 | return x
514 |
515 |
516 | def global_context_block(x, use_bias=True, sn=False, scope='gc_block'):
517 | with tf.variable_scope(scope):
518 | channels = x.shape[-1]
519 | with tf.variable_scope('context_modeling'):
520 | bs, h, w, c = x.get_shape().as_list()
521 | input_x = x
522 | input_x = hw_flatten(input_x) # [N, H*W, C]
523 | input_x = tf.transpose(input_x, perm=[0, 2, 1])
524 | input_x = tf.expand_dims(input_x, axis=1)
525 |
526 | context_mask = conv(x, channels=1, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='conv')
527 | context_mask = hw_flatten(context_mask)
528 | context_mask = tf.nn.softmax(context_mask, axis=1) # [N, H*W, 1]
529 | context_mask = tf.transpose(context_mask, perm=[0, 2, 1])
530 | context_mask = tf.expand_dims(context_mask, axis=-1)
531 |
532 | context = tf.matmul(input_x, context_mask)
533 | context = tf.reshape(context, shape=[bs, 1, 1, c])
534 |
535 | with tf.variable_scope('transform_0'):
536 | context_transform = conv(context, channels, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='conv_0')
537 | context_transform = layer_norm(context_transform)
538 | context_transform = relu(context_transform)
539 | context_transform = conv(context_transform, channels=c, kernel=1, stride=1, use_bias=use_bias, sn=sn,
540 | scope='conv_1')
541 | context_transform = sigmoid(context_transform)
542 |
543 | x = x * context_transform
544 |
545 | with tf.variable_scope('transform_1'):
546 | context_transform = conv(context, channels, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='conv_0')
547 | context_transform = layer_norm(context_transform)
548 | context_transform = relu(context_transform)
549 | context_transform = conv(context_transform, channels=c, kernel=1, stride=1, use_bias=use_bias, sn=sn,
550 | scope='conv_1')
551 |
552 | x = x + context_transform
553 |
554 | return x
555 |
556 |
557 | def srm_block(x, use_bias=False, is_training=True, scope='srm_block'):
558 | with tf.variable_scope(scope):
559 | bs, h, w, channels = x.get_shape().as_list() # c = channels
560 |
561 | x = tf.reshape(x, shape=[bs, -1, channels]) # [bs, h*w, c]
562 |
563 | x_mean, x_var = tf.nn.moments(x, axes=1, keep_dims=True) # [bs, 1, c]
564 | x_std = tf.sqrt(x_var + 1e-5)
565 |
566 | t = tf.concat([x_mean, x_std], axis=1) # [bs, 2, c]
567 |
568 | z = tf.layers.conv1d(t, channels, kernel_size=2, strides=1, use_bias=use_bias)
569 | z = batch_norm(z, is_training=is_training)
570 |
571 | g = tf.sigmoid(z)
572 |
573 | x = tf.reshape(x * g, shape=[bs, h, w, channels])
574 |
575 | return x
576 |
577 |
578 | ##################################################################################
579 | # Normalization
580 | ##################################################################################
581 |
582 | def batch_norm(x, is_training=False, scope='batch_norm'):
583 | """
584 | if x_norm = tf.layers.batch_normalization
585 |
586 | # ...
587 |
588 | with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
589 | train_op = optimizer.minimize(loss)
590 | """
591 |
592 | return tf.contrib.layers.batch_norm(x,
593 | decay=0.9, epsilon=1e-05,
594 | center=True, scale=True, updates_collections=None,
595 | is_training=is_training, scope=scope)
596 |
597 | # return tf.layers.batch_normalization(x, momentum=0.9, epsilon=1e-05, center=True, scale=True, training=is_training, name=scope)
598 |
599 |
600 | def instance_norm(x, scope='instance_norm'):
601 | return tf.contrib.layers.instance_norm(x,
602 | epsilon=1e-05,
603 | center=True, scale=True,
604 | scope=scope)
605 |
606 |
607 | def layer_norm(x, scope='layer_norm'):
608 | return tf.contrib.layers.layer_norm(x,
609 | center=True, scale=True,
610 | scope=scope)
611 |
612 |
613 | def group_norm(x, groups=32, scope='group_norm'):
614 | return tf.contrib.layers.group_norm(x, groups=groups, epsilon=1e-05,
615 | center=True, scale=True,
616 | scope=scope)
617 |
618 |
619 | def adaptive_instance_norm(content, gamma, beta, epsilon=1e-5):
620 | # gamma, beta = style_mean, style_std from MLP
621 | # See https://github.com/taki0112/MUNIT-Tensorflow
622 |
623 | c_mean, c_var = tf.nn.moments(content, axes=[1, 2], keep_dims=True)
624 | c_std = tf.sqrt(c_var + epsilon)
625 |
626 | return gamma * ((content - c_mean) / c_std) + beta
627 |
628 | def adaptive_layer_instance_norm(x, gamma, beta, smoothing=True, scope='ada_layer_instance_norm') :
629 | # proposed by UGATIT
630 | # https://github.com/taki0112/UGATIT
631 | with tf.variable_scope(scope):
632 | ch = x.shape[-1]
633 | eps = 1e-5
634 |
635 | ins_mean, ins_sigma = tf.nn.moments(x, axes=[1, 2], keep_dims=True)
636 | x_ins = (x - ins_mean) / (tf.sqrt(ins_sigma + eps))
637 |
638 | ln_mean, ln_sigma = tf.nn.moments(x, axes=[1, 2, 3], keep_dims=True)
639 | x_ln = (x - ln_mean) / (tf.sqrt(ln_sigma + eps))
640 |
641 | rho = tf.get_variable("rho", [ch], initializer=tf.constant_initializer(1.0), constraint=lambda x: tf.clip_by_value(x, clip_value_min=0.0, clip_value_max=1.0))
642 |
643 | if smoothing :
644 | rho = tf.clip_by_value(rho - tf.constant(0.1), 0.0, 1.0)
645 |
646 | x_hat = rho * x_ins + (1 - rho) * x_ln
647 |
648 |
649 | x_hat = x_hat * gamma + beta
650 |
651 | return x_hat
652 |
653 |
654 | def condition_batch_norm(x, z, is_training=True, scope='batch_norm'):
655 | # See https://github.com/taki0112/BigGAN-Tensorflow
656 | with tf.variable_scope(scope):
657 | _, _, _, c = x.get_shape().as_list()
658 | decay = 0.9
659 | epsilon = 1e-05
660 |
661 | test_mean = tf.get_variable("pop_mean", shape=[c], dtype=tf.float32,
662 | initializer=tf.constant_initializer(0.0), trainable=False)
663 | test_var = tf.get_variable("pop_var", shape=[c], dtype=tf.float32, initializer=tf.constant_initializer(1.0),
664 | trainable=False)
665 |
666 | beta = fully_connected(z, units=c, scope='beta')
667 | gamma = fully_connected(z, units=c, scope='gamma')
668 |
669 | beta = tf.reshape(beta, shape=[-1, 1, 1, c])
670 | gamma = tf.reshape(gamma, shape=[-1, 1, 1, c])
671 |
672 | if is_training:
673 | batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2])
674 | ema_mean = tf.assign(test_mean, test_mean * decay + batch_mean * (1 - decay))
675 | ema_var = tf.assign(test_var, test_var * decay + batch_var * (1 - decay))
676 |
677 | with tf.control_dependencies([ema_mean, ema_var]):
678 | return tf.nn.batch_normalization(x, batch_mean, batch_var, beta, gamma, epsilon)
679 | else:
680 | return tf.nn.batch_normalization(x, test_mean, test_var, beta, gamma, epsilon)
681 |
682 |
683 | def batch_instance_norm(x, scope='batch_instance_norm'):
684 | with tf.variable_scope(scope):
685 | ch = x.shape[-1]
686 | eps = 1e-5
687 |
688 | batch_mean, batch_sigma = tf.nn.moments(x, axes=[0, 1, 2], keep_dims=True)
689 | x_batch = (x - batch_mean) / (tf.sqrt(batch_sigma + eps))
690 |
691 | ins_mean, ins_sigma = tf.nn.moments(x, axes=[1, 2], keep_dims=True)
692 | x_ins = (x - ins_mean) / (tf.sqrt(ins_sigma + eps))
693 |
694 | rho = tf.get_variable("rho", [ch], initializer=tf.constant_initializer(1.0), constraint=lambda x: tf.clip_by_value(x, clip_value_min=0.0, clip_value_max=1.0))
695 | gamma = tf.get_variable("gamma", [ch], initializer=tf.constant_initializer(1.0))
696 | beta = tf.get_variable("beta", [ch], initializer=tf.constant_initializer(0.0))
697 |
698 | x_hat = rho * x_batch + (1 - rho) * x_ins
699 | x_hat = x_hat * gamma + beta
700 |
701 | return x_hat
702 |
703 | def layer_instance_norm(x, scope='layer_instance_norm') :
704 | # proposed by UGATIT
705 | # https://github.com/taki0112/UGATIT
706 | with tf.variable_scope(scope):
707 | ch = x.shape[-1]
708 | eps = 1e-5
709 |
710 | ins_mean, ins_sigma = tf.nn.moments(x, axes=[1, 2], keep_dims=True)
711 | x_ins = (x - ins_mean) / (tf.sqrt(ins_sigma + eps))
712 |
713 | ln_mean, ln_sigma = tf.nn.moments(x, axes=[1, 2, 3], keep_dims=True)
714 | x_ln = (x - ln_mean) / (tf.sqrt(ln_sigma + eps))
715 |
716 | rho = tf.get_variable("rho", [ch], initializer=tf.constant_initializer(0.0), constraint=lambda x: tf.clip_by_value(x, clip_value_min=0.0, clip_value_max=1.0))
717 |
718 | gamma = tf.get_variable("gamma", [ch], initializer=tf.constant_initializer(1.0))
719 | beta = tf.get_variable("beta", [ch], initializer=tf.constant_initializer(0.0))
720 |
721 | x_hat = rho * x_ins + (1 - rho) * x_ln
722 |
723 | x_hat = x_hat * gamma + beta
724 |
725 | return x_hat
726 |
727 | def pixel_norm(x, epsilon=1e-8):
728 | return x * tf.rsqrt(tf.reduce_mean(tf.square(x), axis=-1, keepdims=True) + epsilon)
729 |
730 | def switch_norm(x, scope='switch_norm'):
731 | with tf.variable_scope(scope):
732 | ch = x.shape[-1]
733 | eps = 1e-5
734 |
735 | batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], keep_dims=True)
736 | ins_mean, ins_var = tf.nn.moments(x, [1, 2], keep_dims=True)
737 | layer_mean, layer_var = tf.nn.moments(x, [1, 2, 3], keep_dims=True)
738 |
739 | gamma = tf.get_variable("gamma", [ch], initializer=tf.constant_initializer(1.0))
740 | beta = tf.get_variable("beta", [ch], initializer=tf.constant_initializer(0.0))
741 |
742 | mean_weight = tf.nn.softmax(tf.get_variable("mean_weight", [3], initializer=tf.constant_initializer(1.0)))
743 |         var_weight = tf.nn.softmax(tf.get_variable("var_weight", [3], initializer=tf.constant_initializer(1.0)))
744 |
745 |         mean = mean_weight[0] * batch_mean + mean_weight[1] * ins_mean + mean_weight[2] * layer_mean
746 |         var = var_weight[0] * batch_var + var_weight[1] * ins_var + var_weight[2] * layer_var
747 |
748 | x = (x - mean) / (tf.sqrt(var + eps))
749 | x = x * gamma + beta
750 |
751 | return x
752 |
753 | def spectral_norm(w, iteration=1):
754 | w_shape = w.shape.as_list()
755 | w = tf.reshape(w, [-1, w_shape[-1]])
756 |
757 | u = tf.get_variable("u", [1, w_shape[-1]], initializer=tf.random_normal_initializer(), trainable=False)
758 |
759 | u_hat = u
760 | v_hat = None
761 | for i in range(iteration):
762 | """
763 | power iteration
764 | Usually iteration = 1 will be enough
765 | """
766 | v_ = tf.matmul(u_hat, tf.transpose(w))
767 | v_hat = tf.nn.l2_normalize(v_)
768 |
769 | u_ = tf.matmul(v_hat, w)
770 | u_hat = tf.nn.l2_normalize(u_)
771 |
772 | u_hat = tf.stop_gradient(u_hat)
773 | v_hat = tf.stop_gradient(v_hat)
774 |
775 | sigma = tf.matmul(tf.matmul(v_hat, w), tf.transpose(u_hat))
776 |
777 | with tf.control_dependencies([u.assign(u_hat)]):
778 | w_norm = w / sigma
779 | w_norm = tf.reshape(w_norm, w_shape)
780 |
781 | return w_norm
782 |
783 | ##################################################################################
784 | # Activation Function
785 | ##################################################################################
786 |
787 | def lrelu(x, alpha=0.01):
788 | # pytorch alpha is 0.01
789 | return tf.nn.leaky_relu(x, alpha)
790 |
791 |
792 | def relu(x):
793 | return tf.nn.relu(x)
794 |
795 |
796 | def tanh(x):
797 | return tf.tanh(x)
798 |
799 |
800 | def sigmoid(x):
801 | return tf.sigmoid(x)
802 |
803 |
804 | def swish(x):
805 | return x * tf.sigmoid(x)
806 |
807 |
808 | def elu(x):
809 | return tf.nn.elu(x)
810 |
811 |
812 | ##################################################################################
813 | # Pooling & Resize
814 | ##################################################################################
815 |
816 | def nearest_up_sample(x, scale_factor=2):
817 | _, h, w, _ = x.get_shape().as_list()
818 | new_size = [h * scale_factor, w * scale_factor]
819 | return tf.image.resize_nearest_neighbor(x, size=new_size)
820 |
821 | def bilinear_up_sample(x, scale_factor=2):
822 | _, h, w, _ = x.get_shape().as_list()
823 | new_size = [h * scale_factor, w * scale_factor]
824 | return tf.image.resize_bilinear(x, size=new_size)
825 |
826 | def nearest_down_sample(x, scale_factor=2):
827 | _, h, w, _ = x.get_shape().as_list()
828 | new_size = [h // scale_factor, w // scale_factor]
829 | return tf.image.resize_nearest_neighbor(x, size=new_size)
830 |
831 | def bilinear_down_sample(x, scale_factor=2):
832 | _, h, w, _ = x.get_shape().as_list()
833 | new_size = [h // scale_factor, w // scale_factor]
834 | return tf.image.resize_bilinear(x, size=new_size)
835 |
836 | def global_avg_pooling(x):
837 | gap = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
838 | return gap
839 |
840 |
841 | def global_max_pooling(x):
842 | gmp = tf.reduce_max(x, axis=[1, 2], keepdims=True)
843 | return gmp
844 |
845 |
846 | def max_pooling(x, pool_size=2):
847 | x = tf.layers.max_pooling2d(x, pool_size=pool_size, strides=pool_size, padding='SAME')
848 | return x
849 |
850 |
851 | def avg_pooling(x, pool_size=2):
852 | x = tf.layers.average_pooling2d(x, pool_size=pool_size, strides=pool_size, padding='SAME')
853 | return x
854 |
855 |
856 | def flatten(x):
857 | return tf.layers.flatten(x)
858 |
859 |
860 | def hw_flatten(x):
861 | return tf.reshape(x, shape=[x.shape[0], -1, x.shape[-1]])
862 |
863 |
864 | ##################################################################################
865 | # Loss Function
866 | ##################################################################################
867 |
868 | def classification_loss(logit, label):
869 | loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=label, logits=logit))
870 | prediction = tf.equal(tf.argmax(logit, -1), tf.argmax(label, -1))
871 | accuracy = tf.reduce_mean(tf.cast(prediction, tf.float32))
872 |
873 | return loss, accuracy
874 |
875 |
876 | def L1_loss(x, y):
877 | loss = tf.reduce_mean(tf.abs(x - y))
878 |
879 | return loss
880 |
881 |
882 | def L2_loss(x, y):
883 | loss = tf.reduce_mean(tf.square(x - y))
884 |
885 | return loss
886 |
887 |
888 | def huber_loss(x, y):
889 | return tf.losses.huber_loss(x, y)
890 |
891 |
892 | def regularization_loss(scope_name):
893 | """
894 | If you want to use "Regularization"
895 | g_loss += regularization_loss('generator')
896 | d_loss += regularization_loss('discriminator')
897 | """
898 | collection_regularization = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
899 |
900 | loss = []
901 | for item in collection_regularization:
902 | if scope_name in item.name:
903 | loss.append(item)
904 |
905 | return tf.reduce_sum(loss)
906 |
907 |
908 | def histogram_loss(x, y):
909 | histogram_x = get_histogram(x)
910 | histogram_y = get_histogram(y)
911 |
912 | hist_loss = L1_loss(histogram_x, histogram_y)
913 |
914 | return hist_loss
915 |
916 |
917 | def get_histogram(img, bin_size=0.2):
918 | hist_entries = []
919 |
920 | img_r, img_g, img_b = tf.split(img, num_or_size_splits=3, axis=-1)
921 |
922 | for img_chan in [img_r, img_g, img_b]:
923 | for i in np.arange(-1, 1, bin_size):
924 | gt = tf.greater(img_chan, i)
925 | leq = tf.less_equal(img_chan, i + bin_size)
926 |
927 | condition = tf.cast(tf.logical_and(gt, leq), tf.float32)
928 | hist_entries.append(tf.reduce_sum(condition))
929 |
930 | hist = normalization(hist_entries)
931 |
932 | return hist
933 |
934 |
935 | def normalization(x):
936 | x = (x - tf.reduce_min(x)) / (tf.reduce_max(x) - tf.reduce_min(x))
937 | return x
938 |
939 |
940 | def gram_matrix(x):
941 | b, h, w, c = x.get_shape().as_list()
942 |
943 | x = tf.reshape(x, shape=[b, -1, c])
944 |
945 | x = tf.matmul(tf.transpose(x, perm=[0, 2, 1]), x)
946 | x = x / (h * w * c)
947 |
948 | return x
949 |
950 |
951 | def gram_style_loss(x, y):
952 | _, height, width, channels = x.get_shape().as_list()
953 |
954 | x = gram_matrix(x)
955 | y = gram_matrix(y)
956 |
957 | loss = L2_loss(x, y) # simple version
958 |
959 | # Original eqn as a constant to divide i.e 1/(4. * (channels ** 2) * (width * height) ** 2)
960 | # loss = tf.reduce_mean(tf.square(x - y)) / (channels ** 2 * width * height) # (4.0 * (channels ** 2) * (width * height) ** 2)
961 |
962 | return loss
963 |
964 |
965 | def color_consistency_loss(x, y):
966 | x_mu, x_var = tf.nn.moments(x, axes=[1, 2], keep_dims=True)
967 | y_mu, y_var = tf.nn.moments(y, axes=[1, 2], keep_dims=True)
968 |
969 | loss = L2_loss(x_mu, y_mu) + 5.0 * L2_loss(x_var, y_var)
970 |
971 | return loss
972 |
973 |
974 | def dice_loss(n_classes, logits, labels):
975 | """
976 | :param n_classes: number of classes
977 | :param logits: [batch_size, m, n, n_classes] float32, output logits
978 | :param labels: [batch_size, m, n, 1] int32, class label
979 | :return:
980 | """
981 |
982 | # https://github.com/keras-team/keras/issues/9395
983 |
984 | smooth = 1e-7
985 | dtype = tf.float32
986 |
987 | # alpha=beta=0.5 : dice coefficient
988 | # alpha=beta=1 : tanimoto coefficient (also known as jaccard)
989 | # alpha+beta=1 : produces set of F*-scores
990 | alpha, beta = 0.5, 0.5
991 |
992 | # make onehot label [batch_size, m, n, n_classes]
993 |     # tf.one_hot() ignores (creates a zero vector for) labels larger than n_classes or less than 0
994 | onehot_labels = tf.one_hot(tf.squeeze(labels, axis=-1), depth=n_classes, dtype=dtype)
995 |
996 | ones = tf.ones_like(onehot_labels, dtype=dtype)
997 | predicted = tf.nn.softmax(logits)
998 | p0 = predicted
999 | p1 = ones - predicted
1000 | g0 = onehot_labels
1001 | g1 = ones - onehot_labels
1002 |
1003 | num = tf.reduce_sum(p0 * g0, axis=[0, 1, 2])
1004 | den = num + alpha * tf.reduce_sum(p0 * g1, axis=[0, 1, 2]) + beta * tf.reduce_sum(p1 * g0, axis=[0, 1, 2])
1005 |
1006 | loss = tf.cast(n_classes, dtype=dtype) - tf.reduce_sum((num + smooth) / (den + smooth))
1007 | return loss
1008 |
1009 |
1010 | ##################################################################################
1011 | # GAN Loss Function
1012 | ##################################################################################
1013 |
1014 | def discriminator_loss(Ra, gan_type, real, fake):
1015 | # Ra = Relativistic
1016 | real_loss = 0
1017 | fake_loss = 0
1018 |
1019 |     if Ra and ('wgan' in gan_type or gan_type == 'sphere'):
1020 |         print("Relativistic loss is not supported for [WGAN or Sphere]; using the plain {} loss instead".format(gan_type))
1021 | Ra = False
1022 |
1023 | if Ra:
1024 | real_logit = (real - tf.reduce_mean(fake))
1025 | fake_logit = (fake - tf.reduce_mean(real))
1026 |
1027 | if gan_type == 'lsgan':
1028 | real_loss = tf.reduce_mean(tf.square(real_logit - 1.0))
1029 | fake_loss = tf.reduce_mean(tf.square(fake_logit + 1.0))
1030 |
1031 | if gan_type == 'gan' or gan_type == 'gan-gp' or gan_type == 'dragan':
1032 | real_loss = tf.reduce_mean(
1033 | tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(real), logits=real_logit))
1034 | fake_loss = tf.reduce_mean(
1035 | tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(fake), logits=fake_logit))
1036 |
1037 | if gan_type == 'hinge':
1038 | real_loss = tf.reduce_mean(relu(1.0 - real_logit))
1039 | fake_loss = tf.reduce_mean(relu(1.0 + fake_logit))
1040 |
1041 | else:
1042 |         if 'wgan' in gan_type:
1043 | real_loss = -tf.reduce_mean(real)
1044 | fake_loss = tf.reduce_mean(fake)
1045 |
1046 | if gan_type == 'lsgan':
1047 | real_loss = tf.reduce_mean(tf.square(real - 1.0))
1048 | fake_loss = tf.reduce_mean(tf.square(fake))
1049 |
1050 | if gan_type == 'gan' or gan_type == 'gan-gp' or gan_type == 'dragan':
1051 | real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(real), logits=real))
1052 | fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(fake), logits=fake))
1053 |
1054 | if gan_type == 'hinge':
1055 | real_loss = tf.reduce_mean(relu(1.0 - real))
1056 | fake_loss = tf.reduce_mean(relu(1.0 + fake))
1057 |
1058 | if gan_type == 'sphere':
1059 | bs, c = real.get_shape().as_list()
1060 | moment = 3
1061 | north_pole = tf.one_hot(tf.tile([c], multiples=[bs]), depth=c + 1) # [bs, c+1] -> [0, 0, 0, ... , 1]
1062 |
1063 | real_projection = inverse_stereographic_projection(real)
1064 | fake_projection = inverse_stereographic_projection(fake)
1065 |
1066 | for i in range(1, moment + 1):
1067 | real_loss += -tf.reduce_mean(tf.pow(sphere_loss(real_projection, north_pole), i))
1068 | fake_loss += tf.reduce_mean(tf.pow(sphere_loss(fake_projection, north_pole), i))
1069 |
1070 |
1071 | loss = real_loss + fake_loss
1072 |
1073 | return loss
1074 |
1075 |
1076 | def generator_loss(Ra, gan_type, real, fake):
1077 | # Ra = Relativistic
1078 | fake_loss = 0
1079 | real_loss = 0
1080 |
1081 |     if Ra and ('wgan' in gan_type or gan_type == 'sphere'):
1082 |         print("Relativistic loss is not supported with [WGAN, SphereGAN]; falling back to the plain {} loss".format(gan_type))
1083 | Ra = False
1084 |
1085 | if Ra:
1086 | fake_logit = (fake - tf.reduce_mean(real))
1087 | real_logit = (real - tf.reduce_mean(fake))
1088 |
1089 | if gan_type == 'lsgan':
1090 | fake_loss = tf.reduce_mean(tf.square(fake_logit - 1.0))
1091 | real_loss = tf.reduce_mean(tf.square(real_logit + 1.0))
1092 |
1093 | if gan_type == 'gan' or gan_type == 'gan-gp' or gan_type == 'dragan':
1094 | fake_loss = tf.reduce_mean(
1095 | tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(fake), logits=fake_logit))
1096 | real_loss = tf.reduce_mean(
1097 | tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(real), logits=real_logit))
1098 |
1099 | if gan_type == 'hinge':
1100 | fake_loss = tf.reduce_mean(relu(1.0 - fake_logit))
1101 | real_loss = tf.reduce_mean(relu(1.0 + real_logit))
1102 |
1103 | else:
1104 |         if 'wgan' in gan_type:
1105 | fake_loss = -tf.reduce_mean(fake)
1106 |
1107 | if gan_type == 'lsgan':
1108 | fake_loss = tf.reduce_mean(tf.square(fake - 1.0))
1109 |
1110 | if gan_type == 'gan' or gan_type == 'gan-gp' or gan_type == 'dragan':
1111 | fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(fake), logits=fake))
1112 |
1113 | if gan_type == 'hinge':
1114 | fake_loss = -tf.reduce_mean(fake)
1115 |
1116 | if gan_type == 'sphere':
1117 | bs, c = real.get_shape().as_list()
1118 | moment = 3
1119 | north_pole = tf.one_hot(tf.tile([c], multiples=[bs]), depth=c + 1) # [bs, c+1] -> [0, 0, 0, ... , 1]
1120 |
1121 | fake_projection = inverse_stereographic_projection(fake)
1122 |
1123 | for i in range(1, moment + 1):
1124 | fake_loss += -tf.reduce_mean(tf.pow(sphere_loss(fake_projection, north_pole), i))
1125 |
1126 | loss = fake_loss + real_loss
1127 |
1128 | return loss
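
# Usage sketch (added for illustration): the two losses are used as a pair on
# the discriminator's real/fake logits, e.g. a relativistic hinge GAN; here
# real_logit / fake_logit are assumed to be discriminator outputs of shape
# [batch_size, 1]:
#
#   d_loss = discriminator_loss(Ra=True, gan_type='hinge', real=real_logit, fake=fake_logit)
#   g_loss = generator_loss(Ra=True, gan_type='hinge', real=real_logit, fake=fake_logit)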
1129 |
1130 |
1131 | def vdb_loss(mu, logvar, i_c=0.1):
1132 | # variational discriminator bottleneck loss
1133 | kl_divergence = 0.5 * tf.reduce_sum(tf.square(mu) + tf.exp(logvar) - 1 - logvar, axis=-1)
1134 |
1135 | loss = tf.reduce_mean(kl_divergence - i_c)
1136 |
1137 | return loss
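
# Usage sketch (added for illustration): in the Variational Discriminator
# Bottleneck (Peng et al., 2018), the discriminator predicts a Gaussian code
# (mu, logvar) and this penalty is added to the discriminator loss, usually
# with an adaptively updated weight. mu / logvar are assumed [batch_size, z_dim]:
#
#   d_loss = d_loss + vdb_weight * vdb_loss(mu, logvar, i_c=0.1)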
1138 |
1139 |
1140 | def simple_gp(real_logit, fake_logit, real_images, fake_images, r1_gamma=10, r2_gamma=0):
1141 | # Used in StyleGAN
1142 |
1143 | r1_penalty = 0
1144 | r2_penalty = 0
1145 |
1146 | if r1_gamma != 0:
1147 | real_loss = tf.reduce_sum(real_logit) # In some cases, you may use reduce_mean
1148 | real_grads = tf.gradients(real_loss, real_images)[0]
1149 |
1150 | r1_penalty = 0.5 * r1_gamma * tf.reduce_mean(tf.reduce_sum(tf.square(real_grads), axis=[1, 2, 3]))
1151 |
1152 | if r2_gamma != 0:
1153 | fake_loss = tf.reduce_sum(fake_logit) # In some cases, you may use reduce_mean
1154 | fake_grads = tf.gradients(fake_loss, fake_images)[0]
1155 |
1156 | r2_penalty = 0.5 * r2_gamma * tf.reduce_mean(tf.reduce_sum(tf.square(fake_grads), axis=[1, 2, 3]))
1157 |
1158 | return r1_penalty + r2_penalty
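
# Usage sketch (added for illustration): the default (r1_gamma=10, r2_gamma=0)
# is the R1 penalty on real images used in StyleGAN; it is simply added to the
# discriminator loss:
#
#   gp = simple_gp(real_logit, fake_logit, real_images, fake_images, r1_gamma=10, r2_gamma=0)
#   d_loss = d_loss + gp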
1159 |
1160 | def inverse_stereographic_projection(x) :
1161 |
1162 | x_u = tf.transpose(2 * x) / (tf.pow(tf.norm(x, axis=-1), 2) + 1.0)
1163 | x_v = (tf.pow(tf.norm(x, axis=-1), 2) - 1.0) / (tf.pow(tf.norm(x, axis=-1), 2) + 1.0)
1164 |
1165 | x_projection = tf.transpose(tf.concat([x_u, [x_v]], axis=0))
1166 |
1167 | return x_projection
1168 |
1169 | def sphere_loss(x, y) :
1170 |
1171 | loss = tf.math.acos(tf.matmul(x, tf.transpose(y)))
1172 |
1173 | return loss
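
# Note (added): inverse_stereographic_projection maps x in R^c onto the unit
# sphere in R^(c+1) (every row of x_projection has norm 1), so sphere_loss is
# the arc-cosine of a dot product of unit vectors, i.e. the geodesic distance
# between projected samples and, e.g., the north pole used above.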
1174 |
1175 | ##################################################################################
1176 | # KL-Divergence Loss Function
1177 | ##################################################################################
1178 |
1179 | # typical version
1180 | def z_sample(mean, logvar):
1181 | eps = tf.random_normal(tf.shape(mean), mean=0.0, stddev=1.0, dtype=tf.float32)
1182 |
1183 | return mean + tf.exp(logvar * 0.5) * eps
1184 |
1185 |
1186 | def kl_loss(mean, logvar):
1187 | # shape : [batch_size, channel]
1188 | loss = 0.5 * tf.reduce_sum(tf.square(mean) + tf.exp(logvar) - 1 - logvar, axis=-1)
1189 | loss = tf.reduce_mean(loss)
1190 |
1191 | return loss
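
# Usage sketch (added for illustration): a typical VAE objective draws the
# latent with the reparameterization trick and adds the KL term to a
# reconstruction loss (recon_loss and kl_weight are placeholders):
#
#   z = z_sample(mean, logvar)
#   loss = recon_loss + kl_weight * kl_loss(mean, logvar)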
1192 |
1193 |
1194 | # version 2
1195 | def z_sample_2(mean, sigma):
1196 | eps = tf.random_normal(tf.shape(mean), mean=0.0, stddev=1.0, dtype=tf.float32)
1197 |
1198 | return mean + sigma * eps
1199 |
1200 |
1201 | def kl_loss_2(mean, sigma):
1202 | # shape : [batch_size, channel]
1203 | loss = 0.5 * tf.reduce_sum(tf.square(mean) + tf.square(sigma) - tf.log(1e-8 + tf.square(sigma)) - 1, axis=-1)
1204 | loss = tf.reduce_mean(loss)
1205 |
1206 | return loss
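
# Note (added): the two versions agree when sigma = tf.exp(0.5 * logvar);
# version 2 just parameterizes the encoder output by sigma instead of logvar
# (the 1e-8 inside the log only guards against sigma == 0).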
1207 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import random, os
4 | from tensorflow.contrib import slim
5 | import cv2
6 |
7 | class ImageData:
8 |
9 | def __init__(self, img_height, img_width, channels, augment_flag):
10 | self.img_height = img_height
11 | self.img_width = img_width
12 | self.channels = channels
13 | self.augment_flag = augment_flag
14 |
15 | def image_processing(self, filename):
16 | x = tf.read_file(filename)
17 | x_decode = tf.image.decode_jpeg(x, channels=self.channels, dct_method='INTEGER_ACCURATE')
18 | img = tf.image.resize_images(x_decode, [self.img_height, self.img_width])
19 | img = tf.cast(img, tf.float32) / 127.5 - 1
20 |
21 | if self.augment_flag :
22 | augment_height = self.img_height + (30 if self.img_height == 256 else int(self.img_height * 0.1))
23 | augment_width = self.img_width + (30 if self.img_width == 256 else int(self.img_width * 0.1))
24 |
25 | img = tf.cond(pred=tf.greater_equal(tf.random_uniform(shape=[], minval=0.0, maxval=1.0), 0.5),
26 | true_fn=lambda: augmentation(img, augment_height, augment_width),
27 | false_fn=lambda: img)
28 |
29 | return img
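
# Usage sketch (added for illustration): ImageData is meant to be mapped over
# a tf.data dataset of filenames (filenames / batch_size are placeholders):
#
#   image_data = ImageData(img_height=256, img_width=256, channels=3, augment_flag=True)
#   dataset = tf.data.Dataset.from_tensor_slices(filenames)
#   dataset = dataset.shuffle(1000).map(image_data.image_processing).batch(batch_size)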
30 |
31 | def load_test_image(image_path, img_width, img_height, img_channel):
32 |
33 | if img_channel == 1 :
34 | img = cv2.imread(image_path, flags=cv2.IMREAD_GRAYSCALE)
35 | else :
36 | img = cv2.imread(image_path, flags=cv2.IMREAD_COLOR)
37 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
38 |
39 | img = cv2.resize(img, dsize=(img_width, img_height))
40 |
41 | if img_channel == 1 :
42 | img = np.expand_dims(img, axis=0)
43 | img = np.expand_dims(img, axis=-1)
44 | else :
45 | img = np.expand_dims(img, axis=0)
46 |
47 | img = img/127.5 - 1
48 |
49 | return img
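
# Usage sketch (added for illustration): returns a [1, h, w, c] array scaled to
# [-1, 1], ready to feed a test placeholder (sess / test_input / fake_op are
# assumed to come from your own graph):
#
#   img = load_test_image('./sample.jpg', img_width=256, img_height=256, img_channel=3)
#   output = sess.run(fake_op, feed_dict={test_input: img})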
50 |
51 | def augmentation(image, augment_height, augment_width):
52 | seed = random.randint(0, 2 ** 31 - 1)
53 | ori_image_shape = tf.shape(image)
54 | image = tf.image.random_flip_left_right(image, seed=seed)
55 | image = tf.image.resize_images(image, [augment_height, augment_width])
56 | image = tf.random_crop(image, ori_image_shape, seed=seed)
57 | return image
58 |
59 | def save_images(images, size, image_path):
60 | return imsave(inverse_transform(images), size, image_path)
61 |
62 | def inverse_transform(images):
63 | return ((images+1.) / 2) * 255.0
64 |
65 |
66 | def imsave(images, size, path):
67 | images = merge(images, size)
68 | images = cv2.cvtColor(images.astype('uint8'), cv2.COLOR_RGB2BGR)
69 |
70 | return cv2.imwrite(path, images)
71 |
72 | def merge(images, size):
73 | h, w = images.shape[1], images.shape[2]
74 | img = np.zeros((h * size[0], w * size[1], 3))
75 | for idx, image in enumerate(images):
76 | i = idx % size[1]
77 | j = idx // size[1]
78 | img[h*j:h*(j+1), w*i:w*(i+1), :] = image
79 |
80 | return img
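
# Usage sketch (added for illustration): tile a batch of RGB outputs in [-1, 1]
# into a grid image on disk; a batch of 16 fits a 4 x 4 grid:
#
#   samples = sess.run(fake_images)                        # [16, h, w, 3] in [-1, 1]
#   save_images(samples, size=[4, 4], image_path='./samples/epoch_001.png')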
81 |
82 | def orthogonal_regularizer(scale) :
 83 |     """ Define the orthogonal regularizer and return a function to be used as the kernel regularizer of a conv layer """
84 |
85 | def ortho_reg(w) :
 86 |         """ Reshape the kernel into a 2-D matrix to enforce orthogonality """
87 | _, _, _, c = w.get_shape().as_list()
88 |
89 | w = tf.reshape(w, [-1, c])
90 |
 91 |         """ Declare an identity tensor of the appropriate size """
92 | identity = tf.eye(c)
93 |
 94 |         """ Regularizer : W^T * W - I """
95 | w_transpose = tf.transpose(w)
96 | w_mul = tf.matmul(w_transpose, w)
97 | reg = tf.subtract(w_mul, identity)
98 |
 99 |         """ Compute the loss """
100 | ortho_loss = tf.nn.l2_loss(reg)
101 |
102 | return scale * ortho_loss
103 |
104 | return ortho_reg
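
# Usage sketch (added for illustration): pass the returned function as a kernel
# regularizer, e.g. with tf.layers (the 1e-4 scale follows BigGAN-style code,
# it is not mandatory):
#
#   weight_regularizer = orthogonal_regularizer(0.0001)
#   x = tf.layers.conv2d(x, filters=64, kernel_size=3, kernel_regularizer=weight_regularizer)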
105 |
106 | def orthogonal_regularizer_fully(scale) :
107 |     """ Define the orthogonal regularizer and return a function to be used as the kernel regularizer of a fully connected layer """
108 |
109 | def ortho_reg_fully(w) :
110 |         """ The fully connected kernel is already a 2-D matrix; read off its column size """
111 | _, c = w.get_shape().as_list()
112 |
113 |         """ Declare an identity tensor of the appropriate size """
114 | identity = tf.eye(c)
115 | w_transpose = tf.transpose(w)
116 | w_mul = tf.matmul(w_transpose, w)
117 | reg = tf.subtract(w_mul, identity)
118 |
119 | """ Calculating the Loss """
120 | ortho_loss = tf.nn.l2_loss(reg)
121 |
122 | return scale * ortho_loss
123 |
124 | return ortho_reg_fully
125 |
126 | def tf_rgb_to_gray(x) :
127 | x = (x + 1.0) * 0.5
128 | x = tf.image.rgb_to_grayscale(x)
129 |
130 | x = (x * 2) - 1.0
131 |
132 | return x
133 |
134 | def RGB2LAB(srgb):
135 |     srgb = (srgb + 1.0) / 2.0  # rgb_to_lab() expects sRGB in [0, 1], not [0, 255]
136 |
137 | lab = rgb_to_lab(srgb)
138 | l, a, b = preprocess_lab(lab)
139 |
140 | l = tf.expand_dims(l, axis=-1)
141 | a = tf.expand_dims(a, axis=-1)
142 | b = tf.expand_dims(b, axis=-1)
143 |
144 | x = tf.concat([l, a, b], axis=-1)
145 |
146 | return x
147 |
148 | def LAB2RGB(lab) :
149 |     # undo preprocess_lab : [-1, 1] -> L in [0, 100], a/b in [-128, 127]
150 |     l, a, b = tf.unstack(lab, axis=-1)
151 |     lab = tf.stack([(l + 1.0) * 50.0, (a + 1.0) * 127.5 - 128.0, (b + 1.0) * 127.5 - 128.0], axis=-1)
152 |
153 |     rgb = lab_to_rgb(lab)
154 |     rgb = tf.clip_by_value(rgb, 0, 1)
155 |
156 |     # rescale [0, 1] sRGB back to [-1, 1]
157 |     x = (rgb * 2) - 1.0
158 |
159 |     return x
160 |
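# Round-trip sketch (added for illustration): both helpers work on tensors in
# [-1, 1], so for an RGB batch x in [-1, 1]:
#
#   lab = RGB2LAB(x)        # [-1, 1] RGB -> per-channel-normalized L*a*b* in [-1, 1]
#   rgb = LAB2RGB(lab)      # back to [-1, 1] RGB (approximately x)
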
161 | def rgb_to_lab(srgb):
162 | with tf.name_scope('rgb_to_lab'):
163 | srgb_pixels = tf.reshape(srgb, [-1, 3])
164 | with tf.name_scope('srgb_to_xyz'):
165 | linear_mask = tf.cast(srgb_pixels <= 0.04045, dtype=tf.float32)
166 | exponential_mask = tf.cast(srgb_pixels > 0.04045, dtype=tf.float32)
167 | rgb_pixels = (srgb_pixels / 12.92 * linear_mask) + (((srgb_pixels + 0.055) / 1.055) ** 2.4) * exponential_mask
168 | rgb_to_xyz = tf.constant([
169 | # X Y Z
170 | [0.412453, 0.212671, 0.019334], # R
171 | [0.357580, 0.715160, 0.119193], # G
172 | [0.180423, 0.072169, 0.950227], # B
173 | ])
174 | xyz_pixels = tf.matmul(rgb_pixels, rgb_to_xyz)
175 |
176 | with tf.name_scope('xyz_to_cielab'):
177 | # convert to fx = f(X/Xn), fy = f(Y/Yn), fz = f(Z/Zn)
178 |
179 | # normalize for D65 white point
180 | xyz_normalized_pixels = tf.multiply(xyz_pixels, [1/0.950456, 1.0, 1/1.088754])
181 |
182 | epsilon = 6/29
183 | linear_mask = tf.cast(xyz_normalized_pixels <= (epsilon**3), dtype=tf.float32)
184 | exponential_mask = tf.cast(xyz_normalized_pixels > (epsilon**3), dtype=tf.float32)
185 | fxfyfz_pixels = (xyz_normalized_pixels / (3 * epsilon**2) + 4/29) * linear_mask + (xyz_normalized_pixels ** (1/3)) * exponential_mask
186 |
187 | # convert to lab
188 | fxfyfz_to_lab = tf.constant([
189 | # l a b
190 | [ 0.0, 500.0, 0.0], # fx
191 | [116.0, -500.0, 200.0], # fy
192 | [ 0.0, 0.0, -200.0], # fz
193 | ])
194 | lab_pixels = tf.matmul(fxfyfz_pixels, fxfyfz_to_lab) + tf.constant([-16.0, 0.0, 0.0])
195 |
196 | return tf.reshape(lab_pixels, tf.shape(srgb))
197 |
198 |
199 | def lab_to_rgb(lab):
200 | with tf.name_scope('lab_to_rgb'):
201 | lab_pixels = tf.reshape(lab, [-1, 3])
202 | with tf.name_scope('cielab_to_xyz'):
203 | # convert to fxfyfz
204 | lab_to_fxfyfz = tf.constant([
205 | # fx fy fz
206 | [1/116.0, 1/116.0, 1/116.0], # l
207 | [1/500.0, 0.0, 0.0], # a
208 | [ 0.0, 0.0, -1/200.0], # b
209 | ])
210 | fxfyfz_pixels = tf.matmul(lab_pixels + tf.constant([16.0, 0.0, 0.0]), lab_to_fxfyfz)
211 |
212 | # convert to xyz
213 | epsilon = 6/29
214 | linear_mask = tf.cast(fxfyfz_pixels <= epsilon, dtype=tf.float32)
215 | exponential_mask = tf.cast(fxfyfz_pixels > epsilon, dtype=tf.float32)
216 | xyz_pixels = (3 * epsilon**2 * (fxfyfz_pixels - 4/29)) * linear_mask + (fxfyfz_pixels ** 3) * exponential_mask
217 |
218 | # denormalize for D65 white point
219 | xyz_pixels = tf.multiply(xyz_pixels, [0.950456, 1.0, 1.088754])
220 |
221 | with tf.name_scope('xyz_to_srgb'):
222 | xyz_to_rgb = tf.constant([
223 | # r g b
224 | [ 3.2404542, -0.9692660, 0.0556434], # x
225 | [-1.5371385, 1.8760108, -0.2040259], # y
226 | [-0.4985314, 0.0415560, 1.0572252], # z
227 | ])
228 | rgb_pixels = tf.matmul(xyz_pixels, xyz_to_rgb)
229 | # avoid a slightly negative number messing up the conversion
230 | rgb_pixels = tf.clip_by_value(rgb_pixels, 0.0, 1.0)
231 | linear_mask = tf.cast(rgb_pixels <= 0.0031308, dtype=tf.float32)
232 | exponential_mask = tf.cast(rgb_pixels > 0.0031308, dtype=tf.float32)
233 | srgb_pixels = (rgb_pixels * 12.92 * linear_mask) + ((rgb_pixels ** (1/2.4) * 1.055) - 0.055) * exponential_mask
234 |
235 | return tf.reshape(srgb_pixels, tf.shape(lab))
236 |
237 | def preprocess_lab(lab):
238 | with tf.name_scope('preprocess_lab'):
239 | L_chan, a_chan, b_chan = tf.unstack(lab, axis=-1)
240 | # L_chan: black and white with input range [0, 100]
241 | # a_chan/b_chan: color channels with input range [-128, 127]
242 | # [0, 100] => [-1, 1], ~[-128, 127] => [-1, 1]
243 |
244 | L_chan = L_chan * 255.0 / 100.0
245 | a_chan = a_chan + 128
246 | b_chan = b_chan + 128
247 |
248 | L_chan /= 255.0
249 | a_chan /= 255.0
250 | b_chan /= 255.0
251 |
252 | L_chan = (L_chan - 0.5) / 0.5
253 | a_chan = (a_chan - 0.5) / 0.5
254 | b_chan = (b_chan - 0.5) / 0.5
255 |
256 | return [L_chan, a_chan, b_chan]
257 |
258 | def show_all_variables():
259 | model_vars = tf.trainable_variables()
260 | slim.model_analyzer.analyze_vars(model_vars, print_info=True)
261 |
262 | def check_folder(log_dir):
263 | if not os.path.exists(log_dir):
264 | os.makedirs(log_dir)
265 | return log_dir
266 |
267 | def str2bool(x):
268 |     return x.lower() == 'true'  # ('true') is a plain string, so 'in' was a substring test
269 |
270 | def pytorch_xavier_weight_factor(gain=0.02, uniform=False) :
271 |
272 | if uniform :
273 | factor = gain * gain
274 | mode = 'FAN_AVG'
275 | else :
276 | factor = (gain * gain) / 1.3
277 | mode = 'FAN_AVG'
278 |
279 | return factor, mode, uniform
280 |
281 | def pytorch_kaiming_weight_factor(a=0.0, activation_function='relu', uniform=False) :
282 |
283 | if activation_function == 'relu' :
284 | gain = np.sqrt(2.0)
285 | elif activation_function == 'leaky_relu' :
286 | gain = np.sqrt(2.0 / (1 + a ** 2))
287 |     elif activation_function == 'tanh' :
288 | gain = 5.0 / 3
289 | else :
290 | gain = 1.0
291 |
292 | if uniform :
293 | factor = gain * gain
294 | mode = 'FAN_IN'
295 | else :
296 | factor = (gain * gain) / 1.3
297 | mode = 'FAN_IN'
298 |
299 | return factor, mode, uniform
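
# Usage sketch (added for illustration): the returned triple plugs into TF1's
# variance-scaling initializer to mimic the PyTorch defaults:
#
#   factor, mode, uniform = pytorch_kaiming_weight_factor(a=0.0, activation_function='relu')
#   weight_init = tf.contrib.layers.variance_scaling_initializer(factor=factor, mode=mode, uniform=uniform)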
--------------------------------------------------------------------------------