├── .DS_Store ├── .gitignore ├── LICENSE ├── README.md ├── assets ├── compare.png ├── gcb.png ├── pixel_shuffle.png ├── rdb.png ├── rdn.png ├── relativistic.png ├── relativistic_s.png ├── srm.png ├── tf-cook.png └── tf-cookbook.png ├── ops.py └── utils.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/Tensorflow-Cookbook/580e5fc26e8f24023d1da6d095452fb6d5a121c7/.DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Junho Kim (1993.01.12) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
6 | # [Web page](http://bit.ly/jhkim_tf_cookbook)
7 | # [Tensorflow 2 Cookbook](https://github.com/taki0112/Tensorflow2-Cookbook)
8 | 
9 | ## Contributions
10 | Currently, this repo contains general architectures and functions that are useful for GANs and classification.
11 | 
12 | I will continue to add useful things for other areas.
13 | 
14 | Also, your pull requests and issues are always welcome.
15 | 
16 | If you write what you want implemented in an issue, I'll implement it.
17 | 
18 | # How to use
19 | ## Import
20 | * `ops.py`
21 |   * **operations**
22 |   * `from ops import *`
23 | * `utils.py`
24 |   * **image processing**
25 |   * `from utils import *`
26 | 
27 | ## Network template
28 | ```python
29 | def network(x, is_training=True, reuse=False, scope="network"):
30 |     with tf.variable_scope(scope, reuse=reuse):
31 |         x = conv(...)
32 | 
33 |         ...
34 | 
35 |         return logit
36 | ```
37 | 
38 | ## Insert data to network using DatasetAPI
39 | ```python
40 | Image_Data_Class = ImageData(img_size, img_ch, augment_flag)
41 | 
42 | trainA_dataset = ['./dataset/cat/trainA/a.jpg',
43 |                   './dataset/cat/trainA/b.png',
44 |                   './dataset/cat/trainA/c.jpeg',
45 |                   ...]
46 | trainA = tf.data.Dataset.from_tensor_slices(trainA_dataset)
47 | trainA = trainA.map(Image_Data_Class.image_processing, num_parallel_calls=16)
48 | trainA = trainA.shuffle(buffer_size=10000).prefetch(buffer_size=batch_size).batch(batch_size).repeat()
49 | 
50 | trainA_iterator = trainA.make_one_shot_iterator()
51 | data_A = trainA_iterator.get_next()
52 | 
53 | logit = network(data_A)
54 | ```
55 | * See [this](https://github.com/taki0112/Tensorflow-DatasetAPI) for more information.
56 | 
57 | ## Option
58 | * `padding='SAME'`
59 |   * pad = ceil[ (kernel - stride) / 2 ] (a small sketch follows the basic conv example below)
60 | * `pad_type`
61 |   * 'zero' or 'reflect'
62 | * `sn`
63 |   * use [spectral_normalization](https://arxiv.org/pdf/1802.05957.pdf) or not
64 | 
65 | ## Caution
66 | * If you don't want to share variables, **set all scope names differently.**
67 | 
68 | ---
69 | ## Weight
70 | ```python
71 | weight_init = tf.truncated_normal_initializer(mean=0.0, stddev=0.02)
72 | weight_regularizer = tf.contrib.layers.l2_regularizer(0.0001)
73 | weight_regularizer_fully = tf.contrib.layers.l2_regularizer(0.0001)
74 | ```
75 | ### Initialization
76 | * `Xavier` : tf.contrib.layers.xavier_initializer()
77 | ```python
78 | # To match pytorch's xavier, use tf.contrib.layers.variance_scaling_initializer() with:
79 | 
80 | if uniform :
81 |     factor = gain * gain
82 |     mode = 'FAN_AVG'
83 | else :
84 |     factor = (gain * gain) / 1.3
85 |     mode = 'FAN_AVG'
86 | ```
87 | * `He` : tf.contrib.layers.variance_scaling_initializer()
88 | ```python
89 | if uniform :
90 |     factor = gain * gain
91 |     mode = 'FAN_IN'
92 | else :
93 |     factor = (gain * gain) / 1.3
94 |     mode = 'FAN_OUT'
95 | ```
96 | * `Normal` : tf.random_normal_initializer(mean=0.0, stddev=0.02)
97 | * `Truncated_normal` : tf.truncated_normal_initializer(mean=0.0, stddev=0.02)
98 | * `Orthogonal` : tf.orthogonal_initializer(1.0)  # gain = sqrt(2) for relu, 1.0 for the others
99 | 
100 | ### Regularization
101 | * `l2_decay` : tf.contrib.layers.l2_regularizer(0.0001)
102 | * `orthogonal_regularizer` : orthogonal_regularizer(0.0001) & orthogonal_regularizer_fully(0.0001)
103 | 
104 | ## Convolution
105 | ### basic conv
106 | ```python
107 | x = conv(x, channels=64, kernel=3, stride=2, pad=1, pad_type='reflect', use_bias=True, sn=True, scope='conv')
108 | ```
109 | 
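The pad rule above can be checked directly. A minimal sketch in plain Python; `same_pad` is a hypothetical helper, not part of `ops.py`:

```python
import math

# Hypothetical helper: the `pad` argument that mimics padding='SAME',
# following pad = ceil[(kernel - stride) / 2] from the Option section.
def same_pad(kernel, stride):
    return math.ceil((kernel - stride) / 2)

print(same_pad(kernel=3, stride=2))  # 1 -> matches pad=1 in the basic conv example
print(same_pad(kernel=7, stride=1))  # 3 -> matches the 7x7 conv (pad=3) in ops.py's CBAM block
```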
113 | 114 | ### partial conv (NVIDIA [Partial Convolution](https://github.com/NVIDIA/partialconv)) 115 | ```python 116 | x = partial_conv(x, channels=64, kernel=3, stride=2, use_bias=True, padding='SAME', sn=True, scope='partial_conv') 117 | ``` 118 | 119 | ![p_conv](https://github.com/taki0112/partial_conv-Tensorflow/raw/master/assets/partial_conv.png) 120 | ![p_result](https://github.com/taki0112/partial_conv-Tensorflow/raw/master/assets/classification.png) 121 | 122 | ### dilated conv 123 | ```python 124 | x = dilate_conv(x, channels=64, kernel=3, rate=2, use_bias=True, padding='VALID', sn=True, scope='dilate_conv') 125 | ``` 126 |
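For intuition about `rate`: dilation enlarges the receptive field without adding parameters. A small sketch in plain Python; `effective_kernel` is a hypothetical helper, not part of `ops.py`:

```python
# Hypothetical helper: effective kernel size of a dilated (atrous) convolution.
# rate=1 is an ordinary conv; rate=2 with kernel=3 covers a 5x5 window.
def effective_kernel(kernel, rate):
    return kernel + (kernel - 1) * (rate - 1)

print(effective_kernel(kernel=3, rate=1))  # 3
print(effective_kernel(kernel=3, rate=2))  # 5
```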
129 | 130 | --- 131 | 132 | ## Deconvolution 133 | ### basic deconv 134 | ```python 135 | x = deconv(x, channels=64, kernel=3, stride=1, padding='SAME', use_bias=True, sn=True, scope='deconv') 136 | ``` 137 |
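The output size of `deconv` depends on `padding`. A sketch in plain Python; `deconv_out_size` is a hypothetical helper that mirrors the `output_shape` computation inside `deconv` in `ops.py`:

```python
# Hypothetical helper mirroring ops.py's deconv output_shape rule.
def deconv_out_size(size, kernel, stride, padding='SAME'):
    if padding == 'SAME':
        return size * stride
    return size * stride + max(kernel - stride, 0)

print(deconv_out_size(32, kernel=3, stride=1, padding='SAME'))   # 32, as in the example above
print(deconv_out_size(32, kernel=4, stride=2, padding='SAME'))   # 64
print(deconv_out_size(32, kernel=4, stride=2, padding='VALID'))  # 66
```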
140 | 141 | --- 142 | 143 | ## Fully-connected 144 | ```python 145 | x = fully_connected(x, units=64, use_bias=True, sn=True, scope='fully_connected') 146 | ``` 147 | 148 |
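`fully_connected` flattens its input internally (see `ops.py`), so a 4-D feature map can be passed in directly. A usage sketch, assuming `from ops import *` and an existing feature map `x`; the scope name is illustrative:

```python
from ops import *

# x: [bs, h, w, c] feature map from the last conv block
x = global_avg_pooling(x)                                      # [bs, 1, 1, c]
logit = fully_connected(x, units=1, sn=True, scope='d_logit')  # [bs, 1]
```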
151 | 152 | --- 153 | 154 | ## Pixel shuffle 155 | ```python 156 | x = conv_pixel_shuffle_down(x, scale_factor=2, use_bias=True, sn=True, scope='pixel_shuffle_down') 157 | x = conv_pixel_shuffle_up(x, scale_factor=2, use_bias=True, sn=True, scope='pixel_shuffle_up') 158 | ``` 159 | * `down` ===> [height, width] -> [**height // scale_factor, width // scale_factor**] 160 | * `up` ===> [height, width] -> [**height \* scale_factor, width \* scale_factor**] 161 | 162 | ![pixel_shuffle](./assets/pixel_shuffle.png) 163 | 164 | 165 | --- 166 | 167 | ## Block 168 | ### residual block 169 | ```python 170 | x = resblock(x, channels=64, is_training=is_training, use_bias=True, sn=True, scope='residual_block') 171 | x = resblock_down(x, channels=64, is_training=is_training, use_bias=True, sn=True, scope='residual_block_down') 172 | x = resblock_up(x, channels=64, is_training=is_training, use_bias=True, sn=True, scope='residual_block_up') 173 | ``` 174 | * `down` ===> [height, width] -> [**height // 2, width // 2**] 175 | * `up` ===> [height, width] -> [**height \* 2, width \* 2**] 176 |
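A usage sketch chaining the blocks above into a toy encoder/decoder, assuming `from ops import *`; the function name and channel counts are illustrative. Note that every scope name is unique, per the Caution section:

```python
import tensorflow as tf
from ops import *

def toy_autoencoder(x, is_training=True, reuse=False, scope="toy_autoencoder"):
    with tf.variable_scope(scope, reuse=reuse):
        for i in range(2):
            # [h, w] -> [h // 2, w // 2]
            x = resblock_down(x, channels=64, is_training=is_training, sn=True,
                              scope='resblock_down_' + str(i))
        for i in range(2):
            # [h, w] -> [h * 2, w * 2]
            x = resblock_up(x, channels=64, is_training=is_training, sn=True,
                            scope='resblock_up_' + str(i))
        return x
```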
179 | 
180 | ### dense block
181 | ```python
182 | x = denseblock(x, channels=64, n_db=6, is_training=is_training, use_bias=True, sn=True, scope='denseblock')
183 | ```
184 | * `n_db` ===> The number of bottleneck layers in the dense block; the resulting output width is sketched below
185 | 
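Because every bottleneck output is concatenated (see `denseblock` in `ops.py`), the output channel count grows with `n_db`. A quick sketch of the arithmetic, with illustrative numbers:

```python
# denseblock output width: the input plus `channels` feature maps from each
# of the n_db bottlenecks, all concatenated at the end.
c_in, channels, n_db = 64, 64, 6
print(c_in + n_db * channels)  # 448 output channels
```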
188 | 
189 | ### residual-dense block
190 | ```python
191 | x = res_denseblock(x, channels=64, n_rdb=20, n_rdb_conv=6, is_training=is_training, use_bias=True, sn=True, scope='res_denseblock')
192 | ```
193 | * `n_rdb` ===> The number of RDBs (residual dense blocks)
194 | * `n_rdb_conv` ===> The number of conv layers per RDB (the total depth is sketched below)
195 | 
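For a sense of scale, a rough sketch of the conv-layer count in `res_denseblock`, ignoring the optional 1x1 skip convolutions that only appear when channel counts differ:

```python
# Each RDB stacks n_rdb_conv 3x3 convs plus a 1x1 local-fusion conv;
# GFF_1x1 and GFF_3x3 close the whole block (see ops.py).
n_rdb, n_rdb_conv = 20, 6
print(n_rdb * (n_rdb_conv + 1) + 2)  # 142 convs with the defaults above
```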
201 | 202 | ### attention block 203 | ```python 204 | x = self_attention(x, use_bias=True, sn=True, scope='self_attention') 205 | x = self_attention_with_pooling(x, use_bias=True, sn=True, scope='self_attention_version_2') 206 | 207 | x = squeeze_excitation(x, ratio=16, use_bias=True, sn=True, scope='squeeze_excitation') 208 | 209 | x = convolution_block_attention(x, ratio=16, use_bias=True, sn=True, scope='convolution_block_attention') 210 | 211 | x = global_context_block(x, use_bias=True, sn=True, scope='gc_block') 212 | 213 | x = srm_block(x, use_bias=False, is_training=is_training, scope='srm_block') 214 | ``` 215 | 216 |
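A placement sketch, assuming `from ops import *`; the layout is illustrative, in the spirit of SAGAN, which inserts a single self-attention block at a mid-resolution feature map:

```python
import tensorflow as tf
from ops import *

def generator_body(x, is_training=True, reuse=False, scope="generator"):
    with tf.variable_scope(scope, reuse=reuse):
        x = resblock_up(x, channels=128, is_training=is_training, sn=True, scope='up_0')
        # attend once at a mid resolution, where the attention map is still affordable
        x = self_attention_with_pooling(x, sn=True, scope='self_attention')
        x = resblock_up(x, channels=64, is_training=is_training, sn=True, scope='up_1')
        return x
```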
247 | 
248 | ---
249 | 
250 | ## Normalization
251 | ```python
252 | x = batch_norm(x, is_training=is_training, scope='batch_norm')
253 | x = layer_norm(x, scope='layer_norm')
254 | x = instance_norm(x, scope='instance_norm')
255 | x = group_norm(x, groups=32, scope='group_norm')
256 | 
257 | x = pixel_norm(x)
258 | 
259 | x = batch_instance_norm(x, scope='batch_instance_norm')
260 | x = layer_instance_norm(x, scope='layer_instance_norm')
261 | x = switch_norm(x, scope='switch_norm')
262 | 
263 | x = condition_batch_norm(x, z, is_training=is_training, scope='condition_batch_norm')
264 | 
265 | x = adaptive_instance_norm(x, gamma, beta)
266 | x = adaptive_layer_instance_norm(x, gamma, beta, smoothing=True, scope='adaLIN')
267 | 
268 | ```
269 | * See [this](https://github.com/taki0112/BigGAN-Tensorflow) for how to use `condition_batch_norm`
270 | * See [this](https://github.com/taki0112/MUNIT-Tensorflow) for how to use `adaptive_instance_norm` (a minimal sketch follows below)
271 | * See [this](https://github.com/taki0112/UGATIT) for how to use `adaptive_layer_instance_norm` & `layer_instance_norm`
272 | 
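A minimal sketch of producing `gamma`/`beta` for `adaptive_instance_norm`, in the spirit of MUNIT's MLP; assumes `from ops import *`, and `style` is a hypothetical `[bs, style_dim]` style code:

```python
import tensorflow as tf
from ops import *

def apply_adain(x, style, scope='adain_params'):
    with tf.variable_scope(scope):
        channels = x.get_shape().as_list()[-1]

        # per-channel statistics predicted from the style code
        gamma = fully_connected(style, units=channels, scope='mlp_gamma')
        beta = fully_connected(style, units=channels, scope='mlp_beta')

        gamma = tf.reshape(gamma, shape=[-1, 1, 1, channels])
        beta = tf.reshape(beta, shape=[-1, 1, 1, channels])

    return adaptive_instance_norm(x, gamma, beta)
```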
281 | 
282 | ---
283 | 
284 | ## Activation
285 | ```python
286 | x = relu(x)
287 | x = lrelu(x, alpha=0.2)
288 | x = tanh(x)
289 | x = sigmoid(x)
290 | x = swish(x)
291 | x = elu(x)
292 | ```
293 | 
294 | ---
295 | 
296 | ## Pooling & Resize
297 | ```python
298 | x = nearest_up_sample(x, scale_factor=2)
299 | x = bilinear_up_sample(x, scale_factor=2)
300 | x = nearest_down_sample(x, scale_factor=2)
301 | x = bilinear_down_sample(x, scale_factor=2)
302 | 
303 | x = max_pooling(x, pool_size=2)
304 | x = avg_pooling(x, pool_size=2)
305 | 
306 | x = global_max_pooling(x)
307 | x = global_avg_pooling(x)
308 | 
309 | x = flatten(x)
310 | x = hw_flatten(x)
311 | ```
312 | 
313 | ---
314 | 
315 | ## Loss
316 | ### classification loss
317 | ```python
318 | loss, accuracy = classification_loss(logit, label)
319 | 
320 | loss = dice_loss(n_classes=10, logits=logit, labels=label)
321 | ```
322 | 
323 | ### regularization loss
324 | ```python
325 | g_reg_loss = regularization_loss('generator')
326 | d_reg_loss = regularization_loss('discriminator')
327 | ```
328 | 
329 | * If you use a `regularizer` in the layers, you must add this regularization loss to your total loss yourself (see the sketch after the gan loss section).
330 | 
331 | ### pixel loss
332 | ```python
333 | loss = L1_loss(x, y)
334 | loss = L2_loss(x, y)
335 | loss = huber_loss(x, y)
336 | loss = histogram_loss(x, y)
337 | 
338 | loss = gram_style_loss(x, y)
339 | 
340 | loss = color_consistency_loss(x, y)
341 | ```
342 | * `histogram_loss` measures the difference between the color distributions of the two images' pixel values.
343 | * `gram_style_loss` measures the difference between the styles of the two images, using Gram matrices.
344 | * `color_consistency_loss` measures the color difference between the generated image and the input image.
345 | 
346 | ### gan loss
347 | ```python
348 | d_loss = discriminator_loss(Ra=True, gan_type='wgan-gp', real=real_logit, fake=fake_logit)
349 | g_loss = generator_loss(Ra=True, gan_type='wgan-gp', real=real_logit, fake=fake_logit)
350 | ```
351 | * `Ra`
352 |   * use [relativistic gan](https://arxiv.org/pdf/1807.00734.pdf) or not
353 |   * ignored for the wgan / sphere losses (see `discriminator_loss` in `ops.py`)
354 | * `gan_type`
355 |   * gan
356 |   * lsgan
357 |   * hinge
358 |   * wgan-gp
359 |   * dragan
360 |   * sphere
361 | * See [this](https://github.com/taki0112/BigGAN-Tensorflow/blob/master/BigGAN_512.py#L180) for how to use `gradient_penalty`
362 | 
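A sketch of assembling full objectives from the pieces above; assumes `from ops import *` and that `real_logit`, `fake_logit`, `real_images`, `fake_images` already exist. `simple_gp` is the StyleGAN-style R1/R2 penalty defined in `ops.py`:

```python
from ops import *

# discriminator: adversarial term + R1 penalty + collected weight regularization
d_adv_loss = discriminator_loss(Ra=False, gan_type='gan-gp', real=real_logit, fake=fake_logit)
d_gp = simple_gp(real_logit, fake_logit, real_images, fake_images, r1_gamma=10, r2_gamma=0)
d_loss = d_adv_loss + d_gp + regularization_loss('discriminator')

# generator: adversarial term + collected weight regularization
g_adv_loss = generator_loss(Ra=False, gan_type='gan-gp', real=real_logit, fake=fake_logit)
g_loss = g_adv_loss + regularization_loss('generator')
```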
364 | 365 | ### [vdb loss](https://arxiv.org/abs/1810.00821) 366 | ```python 367 | d_bottleneck_loss = vdb_loss(real_mu, real_logvar, i_c) + vdb_loss(fake_mu, fake_logvar, i_c) 368 | ``` 369 | 370 | ### kl-divergence (z ~ N(0, 1)) 371 | ```python 372 | loss = kl_loss(mean, logvar) 373 | ``` 374 | 375 | --- 376 | 377 | ## Author 378 | [Junho Kim](http://bit.ly/jhkim_ai) 379 | -------------------------------------------------------------------------------- /assets/compare.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/Tensorflow-Cookbook/580e5fc26e8f24023d1da6d095452fb6d5a121c7/assets/compare.png -------------------------------------------------------------------------------- /assets/gcb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/Tensorflow-Cookbook/580e5fc26e8f24023d1da6d095452fb6d5a121c7/assets/gcb.png -------------------------------------------------------------------------------- /assets/pixel_shuffle.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/Tensorflow-Cookbook/580e5fc26e8f24023d1da6d095452fb6d5a121c7/assets/pixel_shuffle.png -------------------------------------------------------------------------------- /assets/rdb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/Tensorflow-Cookbook/580e5fc26e8f24023d1da6d095452fb6d5a121c7/assets/rdb.png -------------------------------------------------------------------------------- /assets/rdn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/Tensorflow-Cookbook/580e5fc26e8f24023d1da6d095452fb6d5a121c7/assets/rdn.png -------------------------------------------------------------------------------- /assets/relativistic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/Tensorflow-Cookbook/580e5fc26e8f24023d1da6d095452fb6d5a121c7/assets/relativistic.png -------------------------------------------------------------------------------- /assets/relativistic_s.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/Tensorflow-Cookbook/580e5fc26e8f24023d1da6d095452fb6d5a121c7/assets/relativistic_s.png -------------------------------------------------------------------------------- /assets/srm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/Tensorflow-Cookbook/580e5fc26e8f24023d1da6d095452fb6d5a121c7/assets/srm.png -------------------------------------------------------------------------------- /assets/tf-cook.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/Tensorflow-Cookbook/580e5fc26e8f24023d1da6d095452fb6d5a121c7/assets/tf-cook.png -------------------------------------------------------------------------------- /assets/tf-cookbook.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/Tensorflow-Cookbook/580e5fc26e8f24023d1da6d095452fb6d5a121c7/assets/tf-cookbook.png -------------------------------------------------------------------------------- /ops.py: 
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | from utils import pytorch_xavier_weight_factor, pytorch_kaiming_weight_factor
4 | 
5 | ##################################################################################
6 | # Initialization
7 | ##################################################################################
8 | 
9 | """
10 | 
11 | pytorch xavier (gain)
12 | https://pytorch.org/docs/stable/_modules/torch/nn/init.html
13 | 
14 | USE < tf.contrib.layers.variance_scaling_initializer() >
15 | if uniform :
16 |     factor = gain * gain
17 |     mode = 'FAN_AVG'
18 | else :
19 |     factor = (gain * gain) / 1.3
20 |     mode = 'FAN_AVG'
21 | 
22 | pytorch : trunc_stddev = gain * sqrt(2 / (fan_in + fan_out))
23 | tensorflow : trunc_stddev = sqrt(1.3 * factor * 2 / (fan_in + fan_out))
24 | 
25 | """
26 | 
27 | """
28 | pytorch kaiming (a=0)
29 | https://pytorch.org/docs/stable/_modules/torch/nn/init.html
30 | 
31 | if uniform :
32 |     a = 0 -> gain = sqrt(2)
33 |     factor = gain * gain
34 |     mode = 'FAN_IN'
35 | else :
36 |     a = 0 -> gain = sqrt(2)
37 |     factor = (gain * gain) / 1.3
38 |     mode = 'FAN_OUT'  # FAN_OUT is correct, but 'FAN_IN' is used more often
39 | 
40 | pytorch : trunc_stddev = gain * sqrt(2 / fan_in)
41 | tensorflow : trunc_stddev = sqrt(1.3 * factor * 2 / fan_in)
42 | 
43 | """
44 | 
45 | # Xavier : tf.contrib.layers.xavier_initializer()
46 | # He : tf.contrib.layers.variance_scaling_initializer()
47 | # Normal : tf.random_normal_initializer(mean=0.0, stddev=0.02)
48 | # Truncated_normal : tf.truncated_normal_initializer(mean=0.0, stddev=0.02)
49 | # Orthogonal : tf.orthogonal_initializer(0.02)
50 | 
51 | ##################################################################################
52 | # Regularization
53 | ##################################################################################
54 | 
55 | # l2_decay : tf.contrib.layers.l2_regularizer(0.0001)
56 | # orthogonal_regularizer : orthogonal_regularizer(0.0001) # orthogonal_regularizer_fully(0.0001)
57 | 
58 | # factor, mode, uniform = pytorch_xavier_weight_factor(gain=0.02, uniform=False)
59 | # weight_init = tf.contrib.layers.variance_scaling_initializer(factor=factor, mode=mode, uniform=uniform)
60 | 
61 | weight_init = tf.truncated_normal_initializer(mean=0.0, stddev=0.02)
62 | weight_regularizer = tf.contrib.layers.l2_regularizer(0.0001)
63 | weight_regularizer_fully = tf.contrib.layers.l2_regularizer(0.0001)
64 | 
65 | 
66 | ##################################################################################
67 | # Layers
68 | ##################################################################################
69 | 
70 | # padding='SAME' ======> pad = ceil[ (kernel - stride) / 2 ]
71 | def conv(x, channels, kernel=4, stride=2, pad=0, pad_type='zero', use_bias=True, sn=False, scope='conv_0'):
72 |     with tf.variable_scope(scope):
73 |         if pad > 0:
74 |             h = x.get_shape().as_list()[1]
75 |             if h % stride == 0:
76 |                 pad = pad * 2
77 |             else:
78 |                 pad = max(kernel - (h % stride), 0)
79 | 
80 |             pad_top = pad // 2
81 |             pad_bottom = pad - pad_top
82 |             pad_left = pad // 2
83 |             pad_right = pad - pad_left
84 | 
85 |             if pad_type == 'zero':
86 |                 x = tf.pad(x, [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]])
87 |             if pad_type == 'reflect':
88 |                 x = tf.pad(x, [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]], mode='REFLECT')
89 | 
90 |         if sn:
91 |             w = tf.get_variable("kernel", shape=[kernel, kernel, x.get_shape()[-1], channels], initializer=weight_init,
92 |
regularizer=weight_regularizer) 93 | x = tf.nn.conv2d(input=x, filter=spectral_norm(w), 94 | strides=[1, stride, stride, 1], padding='VALID') 95 | if use_bias: 96 | bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0)) 97 | x = tf.nn.bias_add(x, bias) 98 | 99 | else: 100 | x = tf.layers.conv2d(inputs=x, filters=channels, 101 | kernel_size=kernel, kernel_initializer=weight_init, 102 | kernel_regularizer=weight_regularizer, 103 | strides=stride, use_bias=use_bias) 104 | 105 | return x 106 | 107 | 108 | def partial_conv(x, channels, kernel=3, stride=2, use_bias=True, padding='SAME', sn=False, scope='conv_0'): 109 | with tf.variable_scope(scope): 110 | if padding.lower() == 'SAME'.lower(): 111 | with tf.variable_scope('mask'): 112 | _, h, w, _ = x.get_shape().as_list() 113 | 114 | slide_window = kernel * kernel 115 | mask = tf.ones(shape=[1, h, w, 1]) 116 | 117 | update_mask = tf.layers.conv2d(mask, filters=1, 118 | kernel_size=kernel, kernel_initializer=tf.constant_initializer(1.0), 119 | strides=stride, padding=padding, use_bias=False, trainable=False) 120 | 121 | mask_ratio = slide_window / (update_mask + 1e-8) 122 | update_mask = tf.clip_by_value(update_mask, 0.0, 1.0) 123 | mask_ratio = mask_ratio * update_mask 124 | 125 | with tf.variable_scope('x'): 126 | if sn: 127 | w = tf.get_variable("kernel", shape=[kernel, kernel, x.get_shape()[-1], channels], 128 | initializer=weight_init, regularizer=weight_regularizer) 129 | x = tf.nn.conv2d(input=x, filter=spectral_norm(w), strides=[1, stride, stride, 1], padding=padding) 130 | else: 131 | x = tf.layers.conv2d(x, filters=channels, 132 | kernel_size=kernel, kernel_initializer=weight_init, 133 | kernel_regularizer=weight_regularizer, 134 | strides=stride, padding=padding, use_bias=False) 135 | x = x * mask_ratio 136 | 137 | if use_bias: 138 | bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0)) 139 | 140 | x = tf.nn.bias_add(x, bias) 141 | x = x * update_mask 142 | else: 143 | if sn: 144 | w = tf.get_variable("kernel", shape=[kernel, kernel, x.get_shape()[-1], channels], 145 | initializer=weight_init, regularizer=weight_regularizer) 146 | x = tf.nn.conv2d(input=x, filter=spectral_norm(w), strides=[1, stride, stride, 1], padding=padding) 147 | if use_bias: 148 | bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0)) 149 | 150 | x = tf.nn.bias_add(x, bias) 151 | else: 152 | x = tf.layers.conv2d(x, filters=channels, 153 | kernel_size=kernel, kernel_initializer=weight_init, 154 | kernel_regularizer=weight_regularizer, 155 | strides=stride, padding=padding, use_bias=use_bias) 156 | 157 | return x 158 | 159 | 160 | def dilate_conv(x, channels, kernel=3, rate=2, use_bias=True, padding='SAME', sn=False, scope='conv_0'): 161 | with tf.variable_scope(scope): 162 | w = tf.get_variable("kernel", shape=[kernel, kernel, x.get_shape()[-1], channels], initializer=weight_init, 163 | regularizer=weight_regularizer) 164 | if sn: 165 | x = tf.nn.atrous_conv2d(x, spectral_norm(w), rate=rate, padding=padding) 166 | else: 167 | x = tf.nn.atrous_conv2d(x, w, rate=rate, padding=padding) 168 | 169 | if use_bias: 170 | bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0)) 171 | x = tf.nn.bias_add(x, bias) 172 | 173 | return x 174 | 175 | 176 | def deconv(x, channels, kernel=4, stride=2, padding='SAME', use_bias=True, sn=False, scope='deconv_0'): 177 | with tf.variable_scope(scope): 178 | x_shape = x.get_shape().as_list() 179 | 180 | if padding == 
'SAME': 181 | output_shape = [x_shape[0], x_shape[1] * stride, x_shape[2] * stride, channels] 182 | 183 | else: 184 | output_shape = [x_shape[0], x_shape[1] * stride + max(kernel - stride, 0), 185 | x_shape[2] * stride + max(kernel - stride, 0), channels] 186 | 187 | if sn: 188 | w = tf.get_variable("kernel", shape=[kernel, kernel, channels, x.get_shape()[-1]], initializer=weight_init, 189 | regularizer=weight_regularizer) 190 | x = tf.nn.conv2d_transpose(x, filter=spectral_norm(w), output_shape=output_shape, 191 | strides=[1, stride, stride, 1], padding=padding) 192 | 193 | if use_bias: 194 | bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0)) 195 | x = tf.nn.bias_add(x, bias) 196 | 197 | else: 198 | x = tf.layers.conv2d_transpose(inputs=x, filters=channels, 199 | kernel_size=kernel, kernel_initializer=weight_init, 200 | kernel_regularizer=weight_regularizer, 201 | strides=stride, padding=padding, use_bias=use_bias) 202 | 203 | return x 204 | 205 | 206 | def conv_pixel_shuffle_up(x, scale_factor=2, use_bias=True, sn=False, scope='pixel_shuffle'): 207 | channel = x.get_shape()[-1] * (scale_factor ** 2) 208 | x = conv(x, channel, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope=scope) 209 | x = tf.depth_to_space(x, block_size=scale_factor) 210 | 211 | return x 212 | 213 | 214 | def conv_pixel_shuffle_down(x, scale_factor=2, use_bias=True, sn=False, scope='pixel_shuffle'): 215 | channel = x.get_shape()[-1] // (scale_factor ** 2) 216 | x = conv(x, channel, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope=scope) 217 | x = tf.space_to_depth(x, block_size=scale_factor) 218 | 219 | return x 220 | 221 | 222 | def fully_connected(x, units, use_bias=True, sn=False, scope='linear'): 223 | with tf.variable_scope(scope): 224 | x = flatten(x) 225 | shape = x.get_shape().as_list() 226 | channels = shape[-1] 227 | 228 | if sn: 229 | w = tf.get_variable("kernel", [channels, units], tf.float32, 230 | initializer=weight_init, regularizer=weight_regularizer_fully) 231 | if use_bias: 232 | bias = tf.get_variable("bias", [units], 233 | initializer=tf.constant_initializer(0.0)) 234 | 235 | x = tf.matmul(x, spectral_norm(w)) + bias 236 | else: 237 | x = tf.matmul(x, spectral_norm(w)) 238 | 239 | else: 240 | x = tf.layers.dense(x, units=units, kernel_initializer=weight_init, 241 | kernel_regularizer=weight_regularizer_fully, 242 | use_bias=use_bias) 243 | 244 | return x 245 | 246 | 247 | ################################################################################## 248 | # Blocks 249 | ################################################################################## 250 | 251 | def resblock(x_init, channels, use_bias=True, is_training=True, sn=False, scope='resblock'): 252 | with tf.variable_scope(scope): 253 | with tf.variable_scope('res1'): 254 | x = conv(x_init, channels, kernel=3, stride=1, pad=1, use_bias=use_bias, sn=sn) 255 | x = batch_norm(x, is_training) 256 | x = relu(x) 257 | 258 | with tf.variable_scope('res2'): 259 | x = conv(x, channels, kernel=3, stride=1, pad=1, use_bias=use_bias, sn=sn) 260 | x = batch_norm(x, is_training) 261 | 262 | if channels != x_init.shape[-1]: 263 | with tf.variable_scope('skip'): 264 | x_init = conv(x_init, channels, kernel=1, stride=1, use_bias=use_bias, sn=sn) 265 | return relu(x + x_init) 266 | 267 | return x + x_init 268 | 269 | 270 | def resblock_up(x_init, channels, use_bias=True, is_training=True, sn=False, scope='resblock_up'): 271 | with tf.variable_scope(scope): 272 | with tf.variable_scope('res1'): 273 | x = 
deconv(x_init, channels, kernel=3, stride=2, use_bias=use_bias, sn=sn) 274 | x = batch_norm(x, is_training) 275 | x = relu(x) 276 | 277 | with tf.variable_scope('res2'): 278 | x = deconv(x, channels, kernel=3, stride=1, use_bias=use_bias, sn=sn) 279 | x = batch_norm(x, is_training) 280 | 281 | with tf.variable_scope('skip'): 282 | x_init = deconv(x_init, channels, kernel=3, stride=2, use_bias=use_bias, sn=sn) 283 | 284 | return relu(x + x_init) 285 | 286 | 287 | def resblock_up_condition(x_init, z, channels, use_bias=True, is_training=True, sn=False, scope='resblock_up'): 288 | # See https://github.com/taki0112/BigGAN-Tensorflow 289 | with tf.variable_scope(scope): 290 | with tf.variable_scope('res1'): 291 | x = deconv(x_init, channels, kernel=3, stride=2, use_bias=use_bias, sn=sn) 292 | x = condition_batch_norm(x, z, is_training) 293 | x = relu(x) 294 | 295 | with tf.variable_scope('res2'): 296 | x = deconv(x, channels, kernel=3, stride=1, use_bias=use_bias, sn=sn) 297 | x = condition_batch_norm(x, z, is_training) 298 | 299 | with tf.variable_scope('skip'): 300 | x_init = deconv(x_init, channels, kernel=3, stride=2, use_bias=use_bias, sn=sn) 301 | 302 | return relu(x + x_init) 303 | 304 | 305 | def resblock_down(x_init, channels, use_bias=True, is_training=True, sn=False, scope='resblock_down'): 306 | with tf.variable_scope(scope): 307 | with tf.variable_scope('res1'): 308 | x = conv(x_init, channels, kernel=3, stride=2, pad=1, use_bias=use_bias, sn=sn) 309 | x = batch_norm(x, is_training) 310 | x = relu(x) 311 | 312 | with tf.variable_scope('res2'): 313 | x = conv(x, channels, kernel=3, stride=1, pad=1, use_bias=use_bias, sn=sn) 314 | x = batch_norm(x, is_training) 315 | 316 | with tf.variable_scope('skip'): 317 | x_init = conv(x_init, channels, kernel=3, stride=2, pad=1, use_bias=use_bias, sn=sn) 318 | 319 | return relu(x + x_init) 320 | 321 | 322 | def denseblock(x_init, channels, n_db=6, use_bias=True, is_training=True, sn=False, scope='denseblock'): 323 | with tf.variable_scope(scope): 324 | layers = [] 325 | layers.append(x_init) 326 | 327 | with tf.variable_scope('bottle_neck_0'): 328 | x = conv(x_init, 4 * channels, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='conv_0') 329 | x = batch_norm(x, is_training, scope='batch_norm_0') 330 | x = relu(x) 331 | 332 | x = conv(x, channels, kernel=3, stride=1, pad=1, use_bias=use_bias, sn=sn, scope='conv_1') 333 | x = batch_norm(x, is_training, scope='batch_norm_1') 334 | x = relu(x) 335 | 336 | layers.append(x) 337 | 338 | for i in range(1, n_db): 339 | with tf.variable_scope('bottle_neck_' + str(i)): 340 | x = tf.concat(layers, axis=-1) 341 | 342 | x = conv(x, 4 * channels, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='conv_0') 343 | x = batch_norm(x, is_training, scope='batch_norm_0') 344 | x = relu(x) 345 | 346 | x = conv(x, channels, kernel=3, stride=1, pad=1, use_bias=use_bias, sn=sn, scope='conv_1') 347 | x = batch_norm(x, is_training, scope='batch_norm_1') 348 | x = relu(x) 349 | 350 | layers.append(x) 351 | 352 | x = tf.concat(layers, axis=-1) 353 | 354 | return x 355 | 356 | 357 | def res_denseblock(x_init, channels, n_rdb=20, n_rdb_conv=6, use_bias=True, is_training=True, sn=False, 358 | scope='res_denseblock'): 359 | with tf.variable_scope(scope): 360 | RDBs = [] 361 | x_input = x_init 362 | 363 | """ 364 | n_rdb = 20 ( RDB number ) 365 | n_rdb_conv = 6 ( per RDB conv layer ) 366 | """ 367 | 368 | for k in range(n_rdb): 369 | with tf.variable_scope('RDB_' + str(k)): 370 | layers = [] 371 | layers.append(x_init) 372 | 
373 | x = conv(x_init, channels, kernel=3, stride=1, pad=1, use_bias=use_bias, sn=sn, scope='conv_0') 374 | x = batch_norm(x, is_training, scope='batch_norm_0') 375 | x = relu(x) 376 | 377 | layers.append(x) 378 | 379 | for i in range(1, n_rdb_conv): 380 | x = tf.concat(layers, axis=-1) 381 | 382 | x = conv(x, channels, kernel=3, stride=1, pad=1, use_bias=use_bias, sn=sn, scope='conv_' + str(i)) 383 | x = batch_norm(x, is_training, scope='batch_norm_' + str(i)) 384 | x = relu(x) 385 | 386 | layers.append(x) 387 | 388 | # Local feature fusion 389 | x = tf.concat(layers, axis=-1) 390 | x = conv(x, channels, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='conv_last') 391 | 392 | # Local residual learning 393 | if channels != x_init.shape[-1] : 394 | x_init = conv(x_init, channels, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='local_skip_conv') 395 | x = relu(x + x_init) 396 | else : 397 | x = x_init + x 398 | 399 | RDBs.append(x) 400 | x_init = x 401 | 402 | with tf.variable_scope('GFF_1x1'): 403 | x = tf.concat(RDBs, axis=-1) 404 | x = conv(x, channels, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='conv') 405 | 406 | with tf.variable_scope('GFF_3x3'): 407 | x = conv(x, channels, kernel=3, stride=1, pad=1, use_bias=use_bias, sn=sn, scope='conv') 408 | 409 | # Global residual learning 410 | if channels != x_input.shape[-1]: 411 | x_input = conv(x_input, channels, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='global_skip_conv') 412 | x = relu(x + x_input) 413 | else : 414 | x = x_input + x 415 | 416 | return x 417 | 418 | 419 | def self_attention(x, use_bias=True, sn=False, scope='self_attention'): 420 | with tf.variable_scope(scope): 421 | channels = x.shape[-1] 422 | f = conv(x, channels // 8, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='f_conv') # [bs, h, w, c'] 423 | g = conv(x, channels // 8, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='g_conv') # [bs, h, w, c'] 424 | h = conv(x, channels, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='h_conv') # [bs, h, w, c] 425 | 426 | # N = h * w 427 | s = tf.matmul(hw_flatten(g), hw_flatten(f), transpose_b=True) # # [bs, N, N] 428 | 429 | beta = tf.nn.softmax(s) # attention map 430 | 431 | o = tf.matmul(beta, hw_flatten(h)) # [bs, N, C] 432 | gamma = tf.get_variable("gamma", [1], initializer=tf.constant_initializer(0.0)) 433 | 434 | o = tf.reshape(o, shape=x.shape) # [bs, h, w, C] 435 | x = gamma * o + x 436 | 437 | return x 438 | 439 | 440 | def self_attention_with_pooling(x, use_bias=True, sn=False, scope='self_attention'): 441 | with tf.variable_scope(scope): 442 | channels = x.shape[-1] 443 | f = conv(x, channels // 8, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='f_conv') # [bs, h, w, c'] 444 | f = max_pooling(f) 445 | 446 | g = conv(x, channels // 8, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='g_conv') # [bs, h, w, c'] 447 | 448 | h = conv(x, channels // 2, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='h_conv') # [bs, h, w, c] 449 | h = max_pooling(h) 450 | 451 | # N = h * w 452 | s = tf.matmul(hw_flatten(g), hw_flatten(f), transpose_b=True) # # [bs, N, N] 453 | 454 | beta = tf.nn.softmax(s) # attention map 455 | 456 | o = tf.matmul(beta, hw_flatten(h)) # [bs, N, C] 457 | gamma = tf.get_variable("gamma", [1], initializer=tf.constant_initializer(0.0)) 458 | 459 | o = tf.reshape(o, shape=[x.shape[0], x.shape[1], x.shape[2], channels // 2]) # [bs, h, w, C] 460 | o = conv(o, channels, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='attn_conv') 461 | x = gamma * o + x 462 | 
463 | return x 464 | 465 | 466 | def squeeze_excitation(x, ratio=16, use_bias=True, sn=False, scope='senet'): 467 | with tf.variable_scope(scope): 468 | channels = x.shape[-1] 469 | squeeze = global_avg_pooling(x) 470 | 471 | excitation = fully_connected(squeeze, units=channels // ratio, use_bias=use_bias, sn=sn, scope='fc1') 472 | excitation = relu(excitation) 473 | excitation = fully_connected(excitation, units=channels, use_bias=use_bias, sn=sn, scope='fc2') 474 | excitation = sigmoid(excitation) 475 | 476 | excitation = tf.reshape(excitation, [-1, 1, 1, channels]) 477 | 478 | scale = x * excitation 479 | 480 | return scale 481 | 482 | 483 | def convolution_block_attention(x, ratio=16, use_bias=True, sn=False, scope='cbam'): 484 | with tf.variable_scope(scope): 485 | channels = x.shape[-1] 486 | with tf.variable_scope('channel_attention'): 487 | x_gap = global_avg_pooling(x) 488 | x_gap = fully_connected(x_gap, units=channels // ratio, use_bias=use_bias, sn=sn, scope='fc1') 489 | x_gap = relu(x_gap) 490 | x_gap = fully_connected(x_gap, units=channels, use_bias=use_bias, sn=sn, scope='fc2') 491 | 492 | with tf.variable_scope('channel_attention', reuse=True): 493 | x_gmp = global_max_pooling(x) 494 | x_gmp = fully_connected(x_gmp, units=channels // ratio, use_bias=use_bias, sn=sn, scope='fc1') 495 | x_gmp = relu(x_gmp) 496 | x_gmp = fully_connected(x_gmp, units=channels, use_bias=use_bias, sn=sn, scope='fc2') 497 | 498 | scale = tf.reshape(x_gap + x_gmp, [-1, 1, 1, channels]) 499 | scale = sigmoid(scale) 500 | 501 | x = x * scale 502 | 503 | with tf.variable_scope('spatial_attention'): 504 | x_channel_avg_pooling = tf.reduce_mean(x, axis=-1, keepdims=True) 505 | x_channel_max_pooling = tf.reduce_max(x, axis=-1, keepdims=True) 506 | scale = tf.concat([x_channel_avg_pooling, x_channel_max_pooling], axis=-1) 507 | 508 | scale = conv(scale, channels=1, kernel=7, stride=1, pad=3, pad_type='reflect', use_bias=False, sn=sn, scope='conv') 509 | scale = sigmoid(scale) 510 | 511 | x = x * scale 512 | 513 | return x 514 | 515 | 516 | def global_context_block(x, use_bias=True, sn=False, scope='gc_block'): 517 | with tf.variable_scope(scope): 518 | channels = x.shape[-1] 519 | with tf.variable_scope('context_modeling'): 520 | bs, h, w, c = x.get_shape().as_list() 521 | input_x = x 522 | input_x = hw_flatten(input_x) # [N, H*W, C] 523 | input_x = tf.transpose(input_x, perm=[0, 2, 1]) 524 | input_x = tf.expand_dims(input_x, axis=1) 525 | 526 | context_mask = conv(x, channels=1, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='conv') 527 | context_mask = hw_flatten(context_mask) 528 | context_mask = tf.nn.softmax(context_mask, axis=1) # [N, H*W, 1] 529 | context_mask = tf.transpose(context_mask, perm=[0, 2, 1]) 530 | context_mask = tf.expand_dims(context_mask, axis=-1) 531 | 532 | context = tf.matmul(input_x, context_mask) 533 | context = tf.reshape(context, shape=[bs, 1, 1, c]) 534 | 535 | with tf.variable_scope('transform_0'): 536 | context_transform = conv(context, channels, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='conv_0') 537 | context_transform = layer_norm(context_transform) 538 | context_transform = relu(context_transform) 539 | context_transform = conv(context_transform, channels=c, kernel=1, stride=1, use_bias=use_bias, sn=sn, 540 | scope='conv_1') 541 | context_transform = sigmoid(context_transform) 542 | 543 | x = x * context_transform 544 | 545 | with tf.variable_scope('transform_1'): 546 | context_transform = conv(context, channels, kernel=1, stride=1, use_bias=use_bias, 
sn=sn, scope='conv_0') 547 | context_transform = layer_norm(context_transform) 548 | context_transform = relu(context_transform) 549 | context_transform = conv(context_transform, channels=c, kernel=1, stride=1, use_bias=use_bias, sn=sn, 550 | scope='conv_1') 551 | 552 | x = x + context_transform 553 | 554 | return x 555 | 556 | 557 | def srm_block(x, use_bias=False, is_training=True, scope='srm_block'): 558 | with tf.variable_scope(scope): 559 | bs, h, w, channels = x.get_shape().as_list() # c = channels 560 | 561 | x = tf.reshape(x, shape=[bs, -1, channels]) # [bs, h*w, c] 562 | 563 | x_mean, x_var = tf.nn.moments(x, axes=1, keep_dims=True) # [bs, 1, c] 564 | x_std = tf.sqrt(x_var + 1e-5) 565 | 566 | t = tf.concat([x_mean, x_std], axis=1) # [bs, 2, c] 567 | 568 | z = tf.layers.conv1d(t, channels, kernel_size=2, strides=1, use_bias=use_bias) 569 | z = batch_norm(z, is_training=is_training) 570 | 571 | g = tf.sigmoid(z) 572 | 573 | x = tf.reshape(x * g, shape=[bs, h, w, channels]) 574 | 575 | return x 576 | 577 | 578 | ################################################################################## 579 | # Normalization 580 | ################################################################################## 581 | 582 | def batch_norm(x, is_training=False, scope='batch_norm'): 583 | """ 584 | if x_norm = tf.layers.batch_normalization 585 | 586 | # ... 587 | 588 | with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): 589 | train_op = optimizer.minimize(loss) 590 | """ 591 | 592 | return tf.contrib.layers.batch_norm(x, 593 | decay=0.9, epsilon=1e-05, 594 | center=True, scale=True, updates_collections=None, 595 | is_training=is_training, scope=scope) 596 | 597 | # return tf.layers.batch_normalization(x, momentum=0.9, epsilon=1e-05, center=True, scale=True, training=is_training, name=scope) 598 | 599 | 600 | def instance_norm(x, scope='instance_norm'): 601 | return tf.contrib.layers.instance_norm(x, 602 | epsilon=1e-05, 603 | center=True, scale=True, 604 | scope=scope) 605 | 606 | 607 | def layer_norm(x, scope='layer_norm'): 608 | return tf.contrib.layers.layer_norm(x, 609 | center=True, scale=True, 610 | scope=scope) 611 | 612 | 613 | def group_norm(x, groups=32, scope='group_norm'): 614 | return tf.contrib.layers.group_norm(x, groups=groups, epsilon=1e-05, 615 | center=True, scale=True, 616 | scope=scope) 617 | 618 | 619 | def adaptive_instance_norm(content, gamma, beta, epsilon=1e-5): 620 | # gamma, beta = style_mean, style_std from MLP 621 | # See https://github.com/taki0112/MUNIT-Tensorflow 622 | 623 | c_mean, c_var = tf.nn.moments(content, axes=[1, 2], keep_dims=True) 624 | c_std = tf.sqrt(c_var + epsilon) 625 | 626 | return gamma * ((content - c_mean) / c_std) + beta 627 | 628 | def adaptive_layer_instance_norm(x, gamma, beta, smoothing=True, scope='ada_layer_instance_norm') : 629 | # proposed by UGATIT 630 | # https://github.com/taki0112/UGATIT 631 | with tf.variable_scope(scope): 632 | ch = x.shape[-1] 633 | eps = 1e-5 634 | 635 | ins_mean, ins_sigma = tf.nn.moments(x, axes=[1, 2], keep_dims=True) 636 | x_ins = (x - ins_mean) / (tf.sqrt(ins_sigma + eps)) 637 | 638 | ln_mean, ln_sigma = tf.nn.moments(x, axes=[1, 2, 3], keep_dims=True) 639 | x_ln = (x - ln_mean) / (tf.sqrt(ln_sigma + eps)) 640 | 641 | rho = tf.get_variable("rho", [ch], initializer=tf.constant_initializer(1.0), constraint=lambda x: tf.clip_by_value(x, clip_value_min=0.0, clip_value_max=1.0)) 642 | 643 | if smoothing : 644 | rho = tf.clip_by_value(rho - tf.constant(0.1), 0.0, 1.0) 645 | 646 | x_hat = 
rho * x_ins + (1 - rho) * x_ln 647 | 648 | 649 | x_hat = x_hat * gamma + beta 650 | 651 | return x_hat 652 | 653 | 654 | def condition_batch_norm(x, z, is_training=True, scope='batch_norm'): 655 | # See https://github.com/taki0112/BigGAN-Tensorflow 656 | with tf.variable_scope(scope): 657 | _, _, _, c = x.get_shape().as_list() 658 | decay = 0.9 659 | epsilon = 1e-05 660 | 661 | test_mean = tf.get_variable("pop_mean", shape=[c], dtype=tf.float32, 662 | initializer=tf.constant_initializer(0.0), trainable=False) 663 | test_var = tf.get_variable("pop_var", shape=[c], dtype=tf.float32, initializer=tf.constant_initializer(1.0), 664 | trainable=False) 665 | 666 | beta = fully_connected(z, units=c, scope='beta') 667 | gamma = fully_connected(z, units=c, scope='gamma') 668 | 669 | beta = tf.reshape(beta, shape=[-1, 1, 1, c]) 670 | gamma = tf.reshape(gamma, shape=[-1, 1, 1, c]) 671 | 672 | if is_training: 673 | batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2]) 674 | ema_mean = tf.assign(test_mean, test_mean * decay + batch_mean * (1 - decay)) 675 | ema_var = tf.assign(test_var, test_var * decay + batch_var * (1 - decay)) 676 | 677 | with tf.control_dependencies([ema_mean, ema_var]): 678 | return tf.nn.batch_normalization(x, batch_mean, batch_var, beta, gamma, epsilon) 679 | else: 680 | return tf.nn.batch_normalization(x, test_mean, test_var, beta, gamma, epsilon) 681 | 682 | 683 | def batch_instance_norm(x, scope='batch_instance_norm'): 684 | with tf.variable_scope(scope): 685 | ch = x.shape[-1] 686 | eps = 1e-5 687 | 688 | batch_mean, batch_sigma = tf.nn.moments(x, axes=[0, 1, 2], keep_dims=True) 689 | x_batch = (x - batch_mean) / (tf.sqrt(batch_sigma + eps)) 690 | 691 | ins_mean, ins_sigma = tf.nn.moments(x, axes=[1, 2], keep_dims=True) 692 | x_ins = (x - ins_mean) / (tf.sqrt(ins_sigma + eps)) 693 | 694 | rho = tf.get_variable("rho", [ch], initializer=tf.constant_initializer(1.0), constraint=lambda x: tf.clip_by_value(x, clip_value_min=0.0, clip_value_max=1.0)) 695 | gamma = tf.get_variable("gamma", [ch], initializer=tf.constant_initializer(1.0)) 696 | beta = tf.get_variable("beta", [ch], initializer=tf.constant_initializer(0.0)) 697 | 698 | x_hat = rho * x_batch + (1 - rho) * x_ins 699 | x_hat = x_hat * gamma + beta 700 | 701 | return x_hat 702 | 703 | def layer_instance_norm(x, scope='layer_instance_norm') : 704 | # proposed by UGATIT 705 | # https://github.com/taki0112/UGATIT 706 | with tf.variable_scope(scope): 707 | ch = x.shape[-1] 708 | eps = 1e-5 709 | 710 | ins_mean, ins_sigma = tf.nn.moments(x, axes=[1, 2], keep_dims=True) 711 | x_ins = (x - ins_mean) / (tf.sqrt(ins_sigma + eps)) 712 | 713 | ln_mean, ln_sigma = tf.nn.moments(x, axes=[1, 2, 3], keep_dims=True) 714 | x_ln = (x - ln_mean) / (tf.sqrt(ln_sigma + eps)) 715 | 716 | rho = tf.get_variable("rho", [ch], initializer=tf.constant_initializer(0.0), constraint=lambda x: tf.clip_by_value(x, clip_value_min=0.0, clip_value_max=1.0)) 717 | 718 | gamma = tf.get_variable("gamma", [ch], initializer=tf.constant_initializer(1.0)) 719 | beta = tf.get_variable("beta", [ch], initializer=tf.constant_initializer(0.0)) 720 | 721 | x_hat = rho * x_ins + (1 - rho) * x_ln 722 | 723 | x_hat = x_hat * gamma + beta 724 | 725 | return x_hat 726 | 727 | def pixel_norm(x, epsilon=1e-8): 728 | return x * tf.rsqrt(tf.reduce_mean(tf.square(x), axis=-1, keepdims=True) + epsilon) 729 | 730 | def switch_norm(x, scope='switch_norm'): 731 | with tf.variable_scope(scope): 732 | ch = x.shape[-1] 733 | eps = 1e-5 734 | 735 | batch_mean, batch_var = tf.nn.moments(x, [0, 
1, 2], keep_dims=True) 736 | ins_mean, ins_var = tf.nn.moments(x, [1, 2], keep_dims=True) 737 | layer_mean, layer_var = tf.nn.moments(x, [1, 2, 3], keep_dims=True) 738 | 739 | gamma = tf.get_variable("gamma", [ch], initializer=tf.constant_initializer(1.0)) 740 | beta = tf.get_variable("beta", [ch], initializer=tf.constant_initializer(0.0)) 741 | 742 | mean_weight = tf.nn.softmax(tf.get_variable("mean_weight", [3], initializer=tf.constant_initializer(1.0))) 743 | var_wegiht = tf.nn.softmax(tf.get_variable("var_weight", [3], initializer=tf.constant_initializer(1.0))) 744 | 745 | mean = mean_weight[0] * batch_mean + mean_weight[1] * ins_mean + mean_weight[2] * layer_mean 746 | var = var_wegiht[0] * batch_var + var_wegiht[1] * ins_var + var_wegiht[2] * layer_var 747 | 748 | x = (x - mean) / (tf.sqrt(var + eps)) 749 | x = x * gamma + beta 750 | 751 | return x 752 | 753 | def spectral_norm(w, iteration=1): 754 | w_shape = w.shape.as_list() 755 | w = tf.reshape(w, [-1, w_shape[-1]]) 756 | 757 | u = tf.get_variable("u", [1, w_shape[-1]], initializer=tf.random_normal_initializer(), trainable=False) 758 | 759 | u_hat = u 760 | v_hat = None 761 | for i in range(iteration): 762 | """ 763 | power iteration 764 | Usually iteration = 1 will be enough 765 | """ 766 | v_ = tf.matmul(u_hat, tf.transpose(w)) 767 | v_hat = tf.nn.l2_normalize(v_) 768 | 769 | u_ = tf.matmul(v_hat, w) 770 | u_hat = tf.nn.l2_normalize(u_) 771 | 772 | u_hat = tf.stop_gradient(u_hat) 773 | v_hat = tf.stop_gradient(v_hat) 774 | 775 | sigma = tf.matmul(tf.matmul(v_hat, w), tf.transpose(u_hat)) 776 | 777 | with tf.control_dependencies([u.assign(u_hat)]): 778 | w_norm = w / sigma 779 | w_norm = tf.reshape(w_norm, w_shape) 780 | 781 | return w_norm 782 | 783 | ################################################################################## 784 | # Activation Function 785 | ################################################################################## 786 | 787 | def lrelu(x, alpha=0.01): 788 | # pytorch alpha is 0.01 789 | return tf.nn.leaky_relu(x, alpha) 790 | 791 | 792 | def relu(x): 793 | return tf.nn.relu(x) 794 | 795 | 796 | def tanh(x): 797 | return tf.tanh(x) 798 | 799 | 800 | def sigmoid(x): 801 | return tf.sigmoid(x) 802 | 803 | 804 | def swish(x): 805 | return x * tf.sigmoid(x) 806 | 807 | 808 | def elu(x): 809 | return tf.nn.elu(x) 810 | 811 | 812 | ################################################################################## 813 | # Pooling & Resize 814 | ################################################################################## 815 | 816 | def nearest_up_sample(x, scale_factor=2): 817 | _, h, w, _ = x.get_shape().as_list() 818 | new_size = [h * scale_factor, w * scale_factor] 819 | return tf.image.resize_nearest_neighbor(x, size=new_size) 820 | 821 | def bilinear_up_sample(x, scale_factor=2): 822 | _, h, w, _ = x.get_shape().as_list() 823 | new_size = [h * scale_factor, w * scale_factor] 824 | return tf.image.resize_bilinear(x, size=new_size) 825 | 826 | def nearest_down_sample(x, scale_factor=2): 827 | _, h, w, _ = x.get_shape().as_list() 828 | new_size = [h // scale_factor, w // scale_factor] 829 | return tf.image.resize_nearest_neighbor(x, size=new_size) 830 | 831 | def bilinear_down_sample(x, scale_factor=2): 832 | _, h, w, _ = x.get_shape().as_list() 833 | new_size = [h // scale_factor, w // scale_factor] 834 | return tf.image.resize_bilinear(x, size=new_size) 835 | 836 | def global_avg_pooling(x): 837 | gap = tf.reduce_mean(x, axis=[1, 2], keepdims=True) 838 | return gap 839 | 840 | 841 | def 
global_max_pooling(x): 842 | gmp = tf.reduce_max(x, axis=[1, 2], keepdims=True) 843 | return gmp 844 | 845 | 846 | def max_pooling(x, pool_size=2): 847 | x = tf.layers.max_pooling2d(x, pool_size=pool_size, strides=pool_size, padding='SAME') 848 | return x 849 | 850 | 851 | def avg_pooling(x, pool_size=2): 852 | x = tf.layers.average_pooling2d(x, pool_size=pool_size, strides=pool_size, padding='SAME') 853 | return x 854 | 855 | 856 | def flatten(x): 857 | return tf.layers.flatten(x) 858 | 859 | 860 | def hw_flatten(x): 861 | return tf.reshape(x, shape=[x.shape[0], -1, x.shape[-1]]) 862 | 863 | 864 | ################################################################################## 865 | # Loss Function 866 | ################################################################################## 867 | 868 | def classification_loss(logit, label): 869 | loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=label, logits=logit)) 870 | prediction = tf.equal(tf.argmax(logit, -1), tf.argmax(label, -1)) 871 | accuracy = tf.reduce_mean(tf.cast(prediction, tf.float32)) 872 | 873 | return loss, accuracy 874 | 875 | 876 | def L1_loss(x, y): 877 | loss = tf.reduce_mean(tf.abs(x - y)) 878 | 879 | return loss 880 | 881 | 882 | def L2_loss(x, y): 883 | loss = tf.reduce_mean(tf.square(x - y)) 884 | 885 | return loss 886 | 887 | 888 | def huber_loss(x, y): 889 | return tf.losses.huber_loss(x, y) 890 | 891 | 892 | def regularization_loss(scope_name): 893 | """ 894 | If you want to use "Regularization" 895 | g_loss += regularization_loss('generator') 896 | d_loss += regularization_loss('discriminator') 897 | """ 898 | collection_regularization = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES) 899 | 900 | loss = [] 901 | for item in collection_regularization: 902 | if scope_name in item.name: 903 | loss.append(item) 904 | 905 | return tf.reduce_sum(loss) 906 | 907 | 908 | def histogram_loss(x, y): 909 | histogram_x = get_histogram(x) 910 | histogram_y = get_histogram(y) 911 | 912 | hist_loss = L1_loss(histogram_x, histogram_y) 913 | 914 | return hist_loss 915 | 916 | 917 | def get_histogram(img, bin_size=0.2): 918 | hist_entries = [] 919 | 920 | img_r, img_g, img_b = tf.split(img, num_or_size_splits=3, axis=-1) 921 | 922 | for img_chan in [img_r, img_g, img_b]: 923 | for i in np.arange(-1, 1, bin_size): 924 | gt = tf.greater(img_chan, i) 925 | leq = tf.less_equal(img_chan, i + bin_size) 926 | 927 | condition = tf.cast(tf.logical_and(gt, leq), tf.float32) 928 | hist_entries.append(tf.reduce_sum(condition)) 929 | 930 | hist = normalization(hist_entries) 931 | 932 | return hist 933 | 934 | 935 | def normalization(x): 936 | x = (x - tf.reduce_min(x)) / (tf.reduce_max(x) - tf.reduce_min(x)) 937 | return x 938 | 939 | 940 | def gram_matrix(x): 941 | b, h, w, c = x.get_shape().as_list() 942 | 943 | x = tf.reshape(x, shape=[b, -1, c]) 944 | 945 | x = tf.matmul(tf.transpose(x, perm=[0, 2, 1]), x) 946 | x = x / (h * w * c) 947 | 948 | return x 949 | 950 | 951 | def gram_style_loss(x, y): 952 | _, height, width, channels = x.get_shape().as_list() 953 | 954 | x = gram_matrix(x) 955 | y = gram_matrix(y) 956 | 957 | loss = L2_loss(x, y) # simple version 958 | 959 | # Original eqn as a constant to divide i.e 1/(4. 
* (channels ** 2) * (width * height) ** 2) 960 | # loss = tf.reduce_mean(tf.square(x - y)) / (channels ** 2 * width * height) # (4.0 * (channels ** 2) * (width * height) ** 2) 961 | 962 | return loss 963 | 964 | 965 | def color_consistency_loss(x, y): 966 | x_mu, x_var = tf.nn.moments(x, axes=[1, 2], keep_dims=True) 967 | y_mu, y_var = tf.nn.moments(y, axes=[1, 2], keep_dims=True) 968 | 969 | loss = L2_loss(x_mu, y_mu) + 5.0 * L2_loss(x_var, y_var) 970 | 971 | return loss 972 | 973 | 974 | def dice_loss(n_classes, logits, labels): 975 | """ 976 | :param n_classes: number of classes 977 | :param logits: [batch_size, m, n, n_classes] float32, output logits 978 | :param labels: [batch_size, m, n, 1] int32, class label 979 | :return: 980 | """ 981 | 982 | # https://github.com/keras-team/keras/issues/9395 983 | 984 | smooth = 1e-7 985 | dtype = tf.float32 986 | 987 | # alpha=beta=0.5 : dice coefficient 988 | # alpha=beta=1 : tanimoto coefficient (also known as jaccard) 989 | # alpha+beta=1 : produces set of F*-scores 990 | alpha, beta = 0.5, 0.5 991 | 992 | # make onehot label [batch_size, m, n, n_classes] 993 | # tf.one_hot() will ignore (creates zero vector) labels larger than n_class and less then 0 994 | onehot_labels = tf.one_hot(tf.squeeze(labels, axis=-1), depth=n_classes, dtype=dtype) 995 | 996 | ones = tf.ones_like(onehot_labels, dtype=dtype) 997 | predicted = tf.nn.softmax(logits) 998 | p0 = predicted 999 | p1 = ones - predicted 1000 | g0 = onehot_labels 1001 | g1 = ones - onehot_labels 1002 | 1003 | num = tf.reduce_sum(p0 * g0, axis=[0, 1, 2]) 1004 | den = num + alpha * tf.reduce_sum(p0 * g1, axis=[0, 1, 2]) + beta * tf.reduce_sum(p1 * g0, axis=[0, 1, 2]) 1005 | 1006 | loss = tf.cast(n_classes, dtype=dtype) - tf.reduce_sum((num + smooth) / (den + smooth)) 1007 | return loss 1008 | 1009 | 1010 | ################################################################################## 1011 | # GAN Loss Function 1012 | ################################################################################## 1013 | 1014 | def discriminator_loss(Ra, gan_type, real, fake): 1015 | # Ra = Relativistic 1016 | real_loss = 0 1017 | fake_loss = 0 1018 | 1019 | if Ra and (gan_type.__contains__('wgan') or gan_type == 'sphere'): 1020 | print("No exist [Ra + WGAN or Ra + Sphere], so use the {} loss function".format(gan_type)) 1021 | Ra = False 1022 | 1023 | if Ra: 1024 | real_logit = (real - tf.reduce_mean(fake)) 1025 | fake_logit = (fake - tf.reduce_mean(real)) 1026 | 1027 | if gan_type == 'lsgan': 1028 | real_loss = tf.reduce_mean(tf.square(real_logit - 1.0)) 1029 | fake_loss = tf.reduce_mean(tf.square(fake_logit + 1.0)) 1030 | 1031 | if gan_type == 'gan' or gan_type == 'gan-gp' or gan_type == 'dragan': 1032 | real_loss = tf.reduce_mean( 1033 | tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(real), logits=real_logit)) 1034 | fake_loss = tf.reduce_mean( 1035 | tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(fake), logits=fake_logit)) 1036 | 1037 | if gan_type == 'hinge': 1038 | real_loss = tf.reduce_mean(relu(1.0 - real_logit)) 1039 | fake_loss = tf.reduce_mean(relu(1.0 + fake_logit)) 1040 | 1041 | else: 1042 | if gan_type.__contains__('wgan'): 1043 | real_loss = -tf.reduce_mean(real) 1044 | fake_loss = tf.reduce_mean(fake) 1045 | 1046 | if gan_type == 'lsgan': 1047 | real_loss = tf.reduce_mean(tf.square(real - 1.0)) 1048 | fake_loss = tf.reduce_mean(tf.square(fake)) 1049 | 1050 | if gan_type == 'gan' or gan_type == 'gan-gp' or gan_type == 'dragan': 1051 | real_loss = 
tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(real), logits=real)) 1052 | fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(fake), logits=fake)) 1053 | 1054 | if gan_type == 'hinge': 1055 | real_loss = tf.reduce_mean(relu(1.0 - real)) 1056 | fake_loss = tf.reduce_mean(relu(1.0 + fake)) 1057 | 1058 | if gan_type == 'sphere': 1059 | bs, c = real.get_shape().as_list() 1060 | moment = 3 1061 | north_pole = tf.one_hot(tf.tile([c], multiples=[bs]), depth=c + 1) # [bs, c+1] -> [0, 0, 0, ... , 1] 1062 | 1063 | real_projection = inverse_stereographic_projection(real) 1064 | fake_projection = inverse_stereographic_projection(fake) 1065 | 1066 | for i in range(1, moment + 1): 1067 | real_loss += -tf.reduce_mean(tf.pow(sphere_loss(real_projection, north_pole), i)) 1068 | fake_loss += tf.reduce_mean(tf.pow(sphere_loss(fake_projection, north_pole), i)) 1069 | 1070 | 1071 | loss = real_loss + fake_loss 1072 | 1073 | return loss 1074 | 1075 | 1076 | def generator_loss(Ra, gan_type, real, fake): 1077 | # Ra = Relativistic 1078 | fake_loss = 0 1079 | real_loss = 0 1080 | 1081 | if Ra and (gan_type.__contains__('wgan') or gan_type == 'sphere'): 1082 | print("No exist [Ra + WGAN or Ra + Sphere], so use the {} loss function".format(gan_type)) 1083 | Ra = False 1084 | 1085 | if Ra: 1086 | fake_logit = (fake - tf.reduce_mean(real)) 1087 | real_logit = (real - tf.reduce_mean(fake)) 1088 | 1089 | if gan_type == 'lsgan': 1090 | fake_loss = tf.reduce_mean(tf.square(fake_logit - 1.0)) 1091 | real_loss = tf.reduce_mean(tf.square(real_logit + 1.0)) 1092 | 1093 | if gan_type == 'gan' or gan_type == 'gan-gp' or gan_type == 'dragan': 1094 | fake_loss = tf.reduce_mean( 1095 | tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(fake), logits=fake_logit)) 1096 | real_loss = tf.reduce_mean( 1097 | tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(real), logits=real_logit)) 1098 | 1099 | if gan_type == 'hinge': 1100 | fake_loss = tf.reduce_mean(relu(1.0 - fake_logit)) 1101 | real_loss = tf.reduce_mean(relu(1.0 + real_logit)) 1102 | 1103 | else: 1104 | if gan_type.__contains__('wgan'): 1105 | fake_loss = -tf.reduce_mean(fake) 1106 | 1107 | if gan_type == 'lsgan': 1108 | fake_loss = tf.reduce_mean(tf.square(fake - 1.0)) 1109 | 1110 | if gan_type == 'gan' or gan_type == 'gan-gp' or gan_type == 'dragan': 1111 | fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(fake), logits=fake)) 1112 | 1113 | if gan_type == 'hinge': 1114 | fake_loss = -tf.reduce_mean(fake) 1115 | 1116 | if gan_type == 'sphere': 1117 | bs, c = real.get_shape().as_list() 1118 | moment = 3 1119 | north_pole = tf.one_hot(tf.tile([c], multiples=[bs]), depth=c + 1) # [bs, c+1] -> [0, 0, 0, ... 


def vdb_loss(mu, logvar, i_c=0.1):
    # variational discriminator bottleneck loss
    kl_divergence = 0.5 * tf.reduce_sum(tf.square(mu) + tf.exp(logvar) - 1 - logvar, axis=-1)

    loss = tf.reduce_mean(kl_divergence - i_c)

    return loss


def simple_gp(real_logit, fake_logit, real_images, fake_images, r1_gamma=10, r2_gamma=0):
    # Used in StyleGAN

    r1_penalty = 0
    r2_penalty = 0

    if r1_gamma != 0:
        real_loss = tf.reduce_sum(real_logit)  # In some cases, you may use reduce_mean
        real_grads = tf.gradients(real_loss, real_images)[0]

        r1_penalty = 0.5 * r1_gamma * tf.reduce_mean(tf.reduce_sum(tf.square(real_grads), axis=[1, 2, 3]))

    if r2_gamma != 0:
        fake_loss = tf.reduce_sum(fake_logit)  # In some cases, you may use reduce_mean
        fake_grads = tf.gradients(fake_loss, fake_images)[0]

        r2_penalty = 0.5 * r2_gamma * tf.reduce_mean(tf.reduce_sum(tf.square(fake_grads), axis=[1, 2, 3]))

    return r1_penalty + r2_penalty
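
# Usage sketch (not part of the original file): R1 regularization as in StyleGAN,
# i.e. penalizing the discriminator gradient on real images only (r2_gamma=0).
# `d_real`/`d_fake` and `real_images`/`fake_images` are illustrative names.
#
#   gp = simple_gp(d_real, d_fake, real_images, fake_images, r1_gamma=10, r2_gamma=0)
#   d_loss = discriminator_loss(Ra=False, gan_type='gan', real=d_real, fake=d_fake) + gp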


def inverse_stereographic_projection(x):
    x_u = tf.transpose(2 * x) / (tf.pow(tf.norm(x, axis=-1), 2) + 1.0)
    x_v = (tf.pow(tf.norm(x, axis=-1), 2) - 1.0) / (tf.pow(tf.norm(x, axis=-1), 2) + 1.0)

    x_projection = tf.transpose(tf.concat([x_u, [x_v]], axis=0))

    return x_projection


def sphere_loss(x, y):
    loss = tf.math.acos(tf.matmul(x, tf.transpose(y)))

    return loss


##################################################################################
# KL-Divergence Loss Function
##################################################################################

# typical version
def z_sample(mean, logvar):
    eps = tf.random_normal(tf.shape(mean), mean=0.0, stddev=1.0, dtype=tf.float32)

    return mean + tf.exp(logvar * 0.5) * eps


def kl_loss(mean, logvar):
    # shape : [batch_size, channel]
    loss = 0.5 * tf.reduce_sum(tf.square(mean) + tf.exp(logvar) - 1 - logvar, axis=-1)
    loss = tf.reduce_mean(loss)

    return loss


# version 2
def z_sample_2(mean, sigma):
    eps = tf.random_normal(tf.shape(mean), mean=0.0, stddev=1.0, dtype=tf.float32)

    return mean + sigma * eps


def kl_loss_2(mean, sigma):
    # shape : [batch_size, channel]
    loss = 0.5 * tf.reduce_sum(tf.square(mean) + tf.square(sigma) - tf.log(1e-8 + tf.square(sigma)) - 1, axis=-1)
    loss = tf.reduce_mean(loss)

    return loss
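
# Usage sketch (not part of the original file): the typical VAE pattern with the
# helpers above. `encoder` and `decoder` are assumed functions not defined in
# this file; only z_sample, kl_loss, and L2_loss come from ops.py.
#
#   mean, logvar = encoder(x)        # both [batch_size, channel]
#   z = z_sample(mean, logvar)       # reparameterization trick
#   recon = decoder(z)
#   loss = L2_loss(x, recon) + kl_loss(mean, logvar)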

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import numpy as np
import random, os
from tensorflow.contrib import slim
import cv2


class ImageData:

    def __init__(self, img_height, img_width, channels, augment_flag):
        self.img_height = img_height
        self.img_width = img_width
        self.channels = channels
        self.augment_flag = augment_flag

    def image_processing(self, filename):
        x = tf.read_file(filename)
        x_decode = tf.image.decode_jpeg(x, channels=self.channels, dct_method='INTEGER_ACCURATE')
        img = tf.image.resize_images(x_decode, [self.img_height, self.img_width])
        img = tf.cast(img, tf.float32) / 127.5 - 1  # scale to [-1, 1]

        if self.augment_flag:
            augment_height = self.img_height + (30 if self.img_height == 256 else int(self.img_height * 0.1))
            augment_width = self.img_width + (30 if self.img_width == 256 else int(self.img_width * 0.1))

            img = tf.cond(pred=tf.greater_equal(tf.random_uniform(shape=[], minval=0.0, maxval=1.0), 0.5),
                          true_fn=lambda: augmentation(img, augment_height, augment_width),
                          false_fn=lambda: img)

        return img


def load_test_image(image_path, img_width, img_height, img_channel):
    if img_channel == 1:
        img = cv2.imread(image_path, flags=cv2.IMREAD_GRAYSCALE)
    else:
        img = cv2.imread(image_path, flags=cv2.IMREAD_COLOR)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    img = cv2.resize(img, dsize=(img_width, img_height))

    if img_channel == 1:
        img = np.expand_dims(img, axis=0)
        img = np.expand_dims(img, axis=-1)
    else:
        img = np.expand_dims(img, axis=0)

    img = img / 127.5 - 1  # scale to [-1, 1]

    return img


def augmentation(image, augment_height, augment_width):
    # random horizontal flip, then upscale-and-random-crop back to the original size
    seed = random.randint(0, 2 ** 31 - 1)
    ori_image_shape = tf.shape(image)
    image = tf.image.random_flip_left_right(image, seed=seed)
    image = tf.image.resize_images(image, [augment_height, augment_width])
    image = tf.random_crop(image, ori_image_shape, seed=seed)
    return image


def save_images(images, size, image_path):
    return imsave(inverse_transform(images), size, image_path)


def inverse_transform(images):
    return ((images + 1.) / 2) * 255.0  # [-1, 1] -> [0, 255]


def imsave(images, size, path):
    images = merge(images, size)
    images = cv2.cvtColor(images.astype('uint8'), cv2.COLOR_RGB2BGR)

    return cv2.imwrite(path, images)


def merge(images, size):
    # tile a batch of images into a size[0] x size[1] grid (assumes 3 channels)
    h, w = images.shape[1], images.shape[2]
    img = np.zeros((h * size[0], w * size[1], 3))
    for idx, image in enumerate(images):
        i = idx % size[1]
        j = idx // size[1]
        img[h * j:h * (j + 1), w * i:w * (i + 1), :] = image

    return img
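
# Usage sketch (not part of the original file): wiring ImageData into a tf.data
# input pipeline. `trainA_paths` is an illustrative list of image file paths.
#
#   image_data = ImageData(img_height=256, img_width=256, channels=3, augment_flag=True)
#   dataset = tf.data.Dataset.from_tensor_slices(trainA_paths)
#   dataset = dataset.shuffle(len(trainA_paths)).map(image_data.image_processing).batch(8)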


def orthogonal_regularizer(scale):
    """ Define the orthogonal regularizer and return the function to be used as a kernel regularizer in a Conv layer """

    def ortho_reg(w):
        """ Reshape the kernel into a 2D matrix for enforcing orthogonality """
        _, _, _, c = w.get_shape().as_list()

        w = tf.reshape(w, [-1, c])

        """ Declare an identity tensor of appropriate size """
        identity = tf.eye(c)

        """ Regularizer Wt*W - I """
        w_transpose = tf.transpose(w)
        w_mul = tf.matmul(w_transpose, w)
        reg = tf.subtract(w_mul, identity)

        """ Calculate the loss """
        ortho_loss = tf.nn.l2_loss(reg)

        return scale * ortho_loss

    return ortho_reg


def orthogonal_regularizer_fully(scale):
    """ Define the orthogonal regularizer and return the function to be used as a kernel regularizer in a fully connected layer """

    def ortho_reg_fully(w):
        """ The weight matrix is already a 2D tensor, so no reshape is needed """
        _, c = w.get_shape().as_list()

        """ Declare an identity tensor of appropriate size """
        identity = tf.eye(c)
        w_transpose = tf.transpose(w)
        w_mul = tf.matmul(w_transpose, w)
        reg = tf.subtract(w_mul, identity)

        """ Calculate the loss """
        ortho_loss = tf.nn.l2_loss(reg)

        return scale * ortho_loss

    return ortho_reg_fully
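
# Usage sketch (not part of the original file): passing the regularizers above
# as kernel regularizers, in the spirit of BigGAN's orthogonal regularization.
#
#   conv = tf.layers.conv2d(x, filters=64, kernel_size=3,
#                           kernel_regularizer=orthogonal_regularizer(scale=1e-4))
#   fc = tf.layers.dense(x, units=256,
#                        kernel_regularizer=orthogonal_regularizer_fully(scale=1e-4))
#   reg_loss = tf.losses.get_regularization_loss()  # add this term to the main loss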


def tf_rgb_to_gray(x):
    x = (x + 1.0) * 0.5  # [-1, 1] -> [0, 1]
    x = tf.image.rgb_to_grayscale(x)

    x = (x * 2) - 1.0  # [0, 1] -> [-1, 1]

    return x


def RGB2LAB(srgb):
    srgb = inverse_transform(srgb)

    lab = rgb_to_lab(srgb)
    l, a, b = preprocess_lab(lab)

    l = tf.expand_dims(l, axis=-1)
    a = tf.expand_dims(a, axis=-1)
    b = tf.expand_dims(b, axis=-1)

    x = tf.concat([l, a, b], axis=-1)

    return x


def LAB2RGB(lab):
    lab = inverse_transform(lab)

    rgb = lab_to_rgb(lab)
    rgb = tf.clip_by_value(rgb, 0, 1)

    # r, g, b = tf.unstack(rgb, axis=-1)
    # rgb = tf.concat([r, g, b], axis=-1)

    x = (rgb * 2) - 1.0

    return x


def rgb_to_lab(srgb):
    with tf.name_scope('rgb_to_lab'):
        srgb_pixels = tf.reshape(srgb, [-1, 3])

        with tf.name_scope('srgb_to_xyz'):
            linear_mask = tf.cast(srgb_pixels <= 0.04045, dtype=tf.float32)
            exponential_mask = tf.cast(srgb_pixels > 0.04045, dtype=tf.float32)
            rgb_pixels = (srgb_pixels / 12.92 * linear_mask) + (((srgb_pixels + 0.055) / 1.055) ** 2.4) * exponential_mask
            rgb_to_xyz = tf.constant([
                #    X         Y         Z
                [0.412453, 0.212671, 0.019334],  # R
                [0.357580, 0.715160, 0.119193],  # G
                [0.180423, 0.072169, 0.950227],  # B
            ])
            xyz_pixels = tf.matmul(rgb_pixels, rgb_to_xyz)

        with tf.name_scope('xyz_to_cielab'):
            # convert to fx = f(X/Xn), fy = f(Y/Yn), fz = f(Z/Zn)

            # normalize for D65 white point
            xyz_normalized_pixels = tf.multiply(xyz_pixels, [1 / 0.950456, 1.0, 1 / 1.088754])

            epsilon = 6 / 29
            linear_mask = tf.cast(xyz_normalized_pixels <= (epsilon ** 3), dtype=tf.float32)
            exponential_mask = tf.cast(xyz_normalized_pixels > (epsilon ** 3), dtype=tf.float32)
            fxfyfz_pixels = (xyz_normalized_pixels / (3 * epsilon ** 2) + 4 / 29) * linear_mask + (xyz_normalized_pixels ** (1 / 3)) * exponential_mask

            # convert to lab
            fxfyfz_to_lab = tf.constant([
                #    l       a       b
                [  0.0,  500.0,    0.0],  # fx
                [116.0, -500.0,  200.0],  # fy
                [  0.0,    0.0, -200.0],  # fz
            ])
            lab_pixels = tf.matmul(fxfyfz_pixels, fxfyfz_to_lab) + tf.constant([-16.0, 0.0, 0.0])

        return tf.reshape(lab_pixels, tf.shape(srgb))


def lab_to_rgb(lab):
    with tf.name_scope('lab_to_rgb'):
        lab_pixels = tf.reshape(lab, [-1, 3])

        with tf.name_scope('cielab_to_xyz'):
            # convert to fxfyfz
            lab_to_fxfyfz = tf.constant([
                #     fx         fy         fz
                [1 / 116.0, 1 / 116.0,  1 / 116.0],  # l
                [1 / 500.0,       0.0,        0.0],  # a
                [      0.0,       0.0, -1 / 200.0],  # b
            ])
            fxfyfz_pixels = tf.matmul(lab_pixels + tf.constant([16.0, 0.0, 0.0]), lab_to_fxfyfz)

            # convert to xyz
            epsilon = 6 / 29
            linear_mask = tf.cast(fxfyfz_pixels <= epsilon, dtype=tf.float32)
            exponential_mask = tf.cast(fxfyfz_pixels > epsilon, dtype=tf.float32)
            xyz_pixels = (3 * epsilon ** 2 * (fxfyfz_pixels - 4 / 29)) * linear_mask + (fxfyfz_pixels ** 3) * exponential_mask

            # denormalize for D65 white point
            xyz_pixels = tf.multiply(xyz_pixels, [0.950456, 1.0, 1.088754])

        with tf.name_scope('xyz_to_srgb'):
            xyz_to_rgb = tf.constant([
                #      r           g           b
                [ 3.2404542, -0.9692660,  0.0556434],  # x
                [-1.5371385,  1.8760108, -0.2040259],  # y
                [-0.4985314,  0.0415560,  1.0572252],  # z
            ])
            rgb_pixels = tf.matmul(xyz_pixels, xyz_to_rgb)
            # avoid a slightly negative number messing up the conversion
            rgb_pixels = tf.clip_by_value(rgb_pixels, 0.0, 1.0)
            linear_mask = tf.cast(rgb_pixels <= 0.0031308, dtype=tf.float32)
            exponential_mask = tf.cast(rgb_pixels > 0.0031308, dtype=tf.float32)
            srgb_pixels = (rgb_pixels * 12.92 * linear_mask) + ((rgb_pixels ** (1 / 2.4) * 1.055) - 0.055) * exponential_mask

        return tf.reshape(srgb_pixels, tf.shape(lab))


def preprocess_lab(lab):
    with tf.name_scope('preprocess_lab'):
        L_chan, a_chan, b_chan = tf.unstack(lab, axis=-1)
        # L_chan: black and white, with input range [0, 100]
        # a_chan/b_chan: color channels, with input range ~[-128, 127]
        # [0, 100] => [-1, 1], ~[-128, 127] => [-1, 1]

        L_chan = L_chan * 255.0 / 100.0
        a_chan = a_chan + 128
        b_chan = b_chan + 128

        L_chan /= 255.0
        a_chan /= 255.0
        b_chan /= 255.0

        L_chan = (L_chan - 0.5) / 0.5
        a_chan = (a_chan - 0.5) / 0.5
        b_chan = (b_chan - 0.5) / 0.5

        return [L_chan, a_chan, b_chan]


def show_all_variables():
    model_vars = tf.trainable_variables()
    slim.model_analyzer.analyze_vars(model_vars, print_info=True)


def check_folder(log_dir):
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    return log_dir


def str2bool(x):
    return x.lower() == 'true'


def pytorch_xavier_weight_factor(gain=0.02, uniform=False):
    if uniform:
        factor = gain * gain
    else:
        factor = (gain * gain) / 1.3

    mode = 'FAN_AVG'

    return factor, mode, uniform


def pytorch_kaiming_weight_factor(a=0.0, activation_function='relu', uniform=False):
    if activation_function == 'relu':
        gain = np.sqrt(2.0)
    elif activation_function == 'leaky_relu':
        gain = np.sqrt(2.0 / (1 + a ** 2))
    elif activation_function == 'tanh':
        gain = 5.0 / 3
    else:
        gain = 1.0

    if uniform:
        factor = gain * gain
    else:
        factor = (gain * gain) / 1.3

    mode = 'FAN_IN'

    return factor, mode, uniform
--------------------------------------------------------------------------------
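
A usage note (not from the repository): the two `pytorch_*_weight_factor` helpers return arguments for TF1's variance-scaling initializer, so layer initialization can match PyTorch defaults. A minimal sketch, assuming `tf.contrib` is available:

    # Kaiming (He) initialization for a ReLU network, fed into variance scaling
    factor, mode, uniform = pytorch_kaiming_weight_factor(activation_function='relu')
    weight_init = tf.contrib.layers.variance_scaling_initializer(factor=factor, mode=mode, uniform=uniform)
    x = tf.layers.conv2d(x, filters=64, kernel_size=3, kernel_initializer=weight_init)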