├── .DS_Store ├── assets ├── gcb.png └── teaser.png ├── LICENSE ├── .gitignore ├── README.md └── GCNet.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/GCNet-Tensorflow/HEAD/.DS_Store -------------------------------------------------------------------------------- /assets/gcb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/GCNet-Tensorflow/HEAD/assets/gcb.png -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/GCNet-Tensorflow/HEAD/assets/teaser.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Junho Kim (1993.01.12) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GCNet-Tensorflow 2 | 3 | Simple Tensorflow implementation of ["GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond"](https://arxiv.org/abs/1904.11492) 4 | 5 | 6 |
7 | 8 |
9 | 10 | ## Summary 11 | ### Architecture 12 | ![arc](./assets/gcb.png) 13 | 14 | ### Code 15 | ```python 16 | def global_context_block(x, channels, use_bias=True, sn=False, scope='gc_block'): 17 | with tf.variable_scope(scope): 18 | with tf.variable_scope('context_modeling'): 19 | bs, h, w, c = x.get_shape().as_list() 20 | input_x = x 21 | input_x = hw_flatten(input_x) # [N, H*W, C] 22 | input_x = tf.transpose(input_x, perm=[0, 2, 1]) 23 | input_x = tf.expand_dims(input_x, axis=1) 24 | 25 | context_mask = conv(x, channels=1, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='conv') 26 | context_mask = hw_flatten(context_mask) 27 | context_mask = tf.nn.softmax(context_mask, axis=1) # [N, H*W, 1] 28 | context_mask = tf.transpose(context_mask, perm=[0, 2, 1]) 29 | context_mask = tf.expand_dims(context_mask, axis=-1) 30 | 31 | context = tf.matmul(input_x, context_mask) 32 | context = tf.reshape(context, shape=[bs, 1, 1, c]) 33 | 34 | with tf.variable_scope('transform_0'): 35 | context_transform = conv(context, channels, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='conv_0') 36 | context_transform = layer_norm(context_transform) 37 | context_transform = relu(context_transform) 38 | context_transform = conv(context_transform, channels=c, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='conv_1') 39 | context_transform = sigmoid(context_transform) 40 | 41 | x = x * context_transform 42 | 43 | with tf.variable_scope('transform_1'): 44 | context_transform = conv(context, channels, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='conv_0') 45 | context_transform = layer_norm(context_transform) 46 | context_transform = relu(context_transform) 47 | context_transform = conv(context_transform, channels=c, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='conv_1') 48 | 49 | x = x + context_transform 50 | 51 | return x 52 | ``` 53 | 54 | ### Usage 55 | ```python 56 | from GCNet import * 57 | 58 | x = global_context_block(x, channels=64, use_bias=True, sn=True, 
import tensorflow as tf

# Shared initializer / L2 weight decay used by every conv kernel in this file.
weight_init = tf.truncated_normal_initializer(mean=0.0, stddev=0.02)
weight_regularizer = tf.contrib.layers.l2_regularizer(0.0001)


def conv(x, channels, kernel=4, stride=2, pad=0, pad_type='zero', use_bias=True, sn=False, scope='conv_0'):
    """2-D convolution with optional manual padding and spectral normalization.

    Args:
        x: 4-D input tensor [N, H, W, C].
        channels: number of output channels.
        kernel: square kernel size.
        stride: convolution stride.
        pad: requested padding; when > 0 the actual amount is recomputed from
            the static height so the output size works out SAME-like.
        pad_type: 'zero' or 'reflect'.
        use_bias: add a learned bias after the convolution.
        sn: apply spectral normalization to the kernel.
        scope: variable scope name.

    Returns:
        The convolved 4-D tensor.
    """
    with tf.variable_scope(scope):
        if pad > 0:
            h = x.get_shape().as_list()[1]
            if h % stride == 0:
                pad = pad * 2
            else:
                pad = max(kernel - (h % stride), 0)

            pad_top = pad // 2
            pad_bottom = pad - pad_top
            pad_left = pad // 2
            pad_right = pad - pad_left

            if pad_type == 'zero':
                x = tf.pad(x, [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]])
            if pad_type == 'reflect':
                x = tf.pad(x, [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]], mode='REFLECT')

        if sn:
            # Build the kernel explicitly so spectral_norm can wrap it.
            w = tf.get_variable("kernel", shape=[kernel, kernel, x.get_shape()[-1], channels],
                                initializer=weight_init,
                                regularizer=weight_regularizer)
            x = tf.nn.conv2d(input=x, filter=spectral_norm(w),
                             strides=[1, stride, stride, 1], padding='VALID')
            if use_bias:
                bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0))
                x = tf.nn.bias_add(x, bias)

        else:
            x = tf.layers.conv2d(inputs=x, filters=channels,
                                 kernel_size=kernel, kernel_initializer=weight_init,
                                 kernel_regularizer=weight_regularizer,
                                 strides=stride, use_bias=use_bias)

        return x


def global_context_block(x, channels, use_bias=True, sn=False, scope='gc_block'):
    """Global Context (GC) block, arXiv:1904.11492.

    Context modeling: a 1x1 conv produces one attention logit per spatial
    position; the softmax-weighted sum of features yields a [N, 1, 1, C]
    global context vector. Two bottleneck transforms then fuse it back into
    the input: a sigmoid-gated channel scaling and an additive broadcast term.
    Output shape equals the input shape.

    Args:
        x: 4-D input tensor [N, H, W, C].
        channels: bottleneck width of the two transform branches.
        use_bias: forwarded to the internal convs.
        sn: apply spectral normalization to the internal convs.
        scope: variable scope name.

    Returns:
        Tensor with the same shape as `x`.
    """
    with tf.variable_scope(scope):
        with tf.variable_scope('context_modeling'):
            # Static shape; bs is None when the batch size is dynamic.
            bs, h, w, c = x.get_shape().as_list()
            input_x = x
            input_x = hw_flatten(input_x)                     # [N, H*W, C]
            input_x = tf.transpose(input_x, perm=[0, 2, 1])   # [N, C, H*W]
            input_x = tf.expand_dims(input_x, axis=1)         # [N, 1, C, H*W]

            context_mask = conv(x, channels=1, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='conv')
            context_mask = hw_flatten(context_mask)
            context_mask = tf.nn.softmax(context_mask, axis=1)          # [N, H*W, 1]
            context_mask = tf.transpose(context_mask, perm=[0, 2, 1])   # [N, 1, H*W]
            context_mask = tf.expand_dims(context_mask, axis=-1)        # [N, 1, H*W, 1]

            context = tf.matmul(input_x, context_mask)        # [N, 1, C, 1]
            # Use -1 for the batch axis: `bs` is None for a dynamic batch
            # size, and tf.reshape does not accept None in its shape list.
            context = tf.reshape(context, shape=[-1, 1, 1, c])

        with tf.variable_scope('transform_0'):
            # Bottleneck transform -> sigmoid gate, fused multiplicatively.
            context_transform = conv(context, channels, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='conv_0')
            context_transform = layer_norm(context_transform)
            context_transform = relu(context_transform)
            context_transform = conv(context_transform, channels=c, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='conv_1')
            context_transform = sigmoid(context_transform)

            x = x * context_transform

        with tf.variable_scope('transform_1'):
            # Second bottleneck transform, fused additively (broadcast over H, W).
            context_transform = conv(context, channels, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='conv_0')
            context_transform = layer_norm(context_transform)
            context_transform = relu(context_transform)
            context_transform = conv(context_transform, channels=c, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='conv_1')

            x = x + context_transform

        return x


def layer_norm(x, scope='layer_norm'):
    """Layer normalization with learned scale and offset."""
    return tf.contrib.layers.layer_norm(x,
                                        center=True, scale=True,
                                        scope=scope)


def relu(x):
    return tf.nn.relu(x)


def sigmoid(x):
    return tf.sigmoid(x)


def hw_flatten(x):
    # [N, H, W, C] -> [N, H*W, C]. tf.shape(x)[0] keeps the reshape valid
    # when the batch dimension is dynamic (x.shape[0] would be None, which
    # tf.reshape rejects as a shape entry).
    return tf.reshape(x, shape=[tf.shape(x)[0], -1, x.shape[-1]])


def spectral_norm(w, iteration=1):
    """Normalize `w` by its largest singular value via power iteration.

    Keeps a persistent, non-trainable estimate `u` of the dominant left
    singular vector and refines it `iteration` times per call. The running
    `u` is updated as a side effect through a control dependency on the
    returned tensor.
    """
    w_shape = w.shape.as_list()
    w = tf.reshape(w, [-1, w_shape[-1]])

    u = tf.get_variable("u", [1, w_shape[-1]], initializer=tf.random_normal_initializer(), trainable=False)

    u_hat = u
    v_hat = None
    for i in range(iteration):
        """
        power iteration
        Usually iteration = 1 will be enough
        """
        v_ = tf.matmul(u_hat, tf.transpose(w))
        v_hat = tf.nn.l2_normalize(v_)

        u_ = tf.matmul(v_hat, w)
        u_hat = tf.nn.l2_normalize(u_)

    # Singular-vector estimates are constants w.r.t. gradients.
    u_hat = tf.stop_gradient(u_hat)
    v_hat = tf.stop_gradient(v_hat)

    sigma = tf.matmul(tf.matmul(v_hat, w), tf.transpose(u_hat))

    with tf.control_dependencies([u.assign(u_hat)]):
        w_norm = w / sigma
        w_norm = tf.reshape(w_norm, w_shape)

    return w_norm