├── .DS_Store
├── assets
├── gcb.png
└── teaser.png
├── LICENSE
├── .gitignore
├── README.md
└── GCNet.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/GCNet-Tensorflow/HEAD/.DS_Store
--------------------------------------------------------------------------------
/assets/gcb.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/GCNet-Tensorflow/HEAD/assets/gcb.png
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/GCNet-Tensorflow/HEAD/assets/teaser.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Junho Kim (1993.01.12)
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GCNet-Tensorflow
2 |
3 | Simple Tensorflow implementation of ["GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond"](https://arxiv.org/abs/1904.11492)
4 |
5 |
6 |
7 |

8 |
9 |
10 | ## Summary
11 | ### Architecture
12 | 
13 |
14 | ### Code
15 | ```python
def global_context_block(x, channels, use_bias=True, sn=False, scope='gc_block'):
    with tf.variable_scope(scope):
        with tf.variable_scope('context_modeling'):
            bs, h, w, c = x.get_shape().as_list()
            input_x = x
            input_x = hw_flatten(input_x)  # [N, H*W, C]
            input_x = tf.transpose(input_x, perm=[0, 2, 1])
            input_x = tf.expand_dims(input_x, axis=1)

            context_mask = conv(x, channels=1, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='conv')
            context_mask = hw_flatten(context_mask)
            context_mask = tf.nn.softmax(context_mask, axis=1)  # [N, H*W, 1]
            context_mask = tf.transpose(context_mask, perm=[0, 2, 1])
            context_mask = tf.expand_dims(context_mask, axis=-1)

            context = tf.matmul(input_x, context_mask)
            context = tf.reshape(context, shape=[bs, 1, 1, c])

        with tf.variable_scope('transform_0'):
            context_transform = conv(context, channels, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='conv_0')
            context_transform = layer_norm(context_transform)
            context_transform = relu(context_transform)
            context_transform = conv(context_transform, channels=c, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='conv_1')
            context_transform = sigmoid(context_transform)

            x = x * context_transform

        with tf.variable_scope('transform_1'):
            context_transform = conv(context, channels, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='conv_0')
            context_transform = layer_norm(context_transform)
            context_transform = relu(context_transform)
            context_transform = conv(context_transform, channels=c, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='conv_1')

            x = x + context_transform

        return x
52 | ```
53 |
54 | ### Usage
55 | ```python
56 | from GCNet import *
57 |
58 | x = global_context_block(x, channels=64, use_bias=True, sn=True, scope='gc_block')
59 | ```
60 |
61 | ## Related works
62 | * [Tensorflow cookbook](https://github.com/taki0112/Tensorflow-Cookbook)
63 | ## Author
64 | Junho Kim
65 |
--------------------------------------------------------------------------------
/GCNet.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
# Kernel initializer shared by every conv layer: truncated normal, std 0.02
# (the DCGAN-style initialization commonly used in GAN/vision codebases).
weight_init = tf.truncated_normal_initializer(mean=0.0, stddev=0.02)
# L2 weight decay (1e-4) attached to every conv kernel via `regularizer=`.
weight_regularizer = tf.contrib.layers.l2_regularizer(0.0001)
5 |
def conv(x, channels, kernel=4, stride=2, pad=0, pad_type='zero', use_bias=True, sn=False, scope='conv_0'):
    """2-D convolution with optional manual padding and spectral normalization.

    Args:
        x: input tensor, [N, H, W, C].
        channels: number of output channels.
        kernel, stride: convolution kernel size and stride.
        pad: requested padding amount; 0 disables manual padding entirely.
        pad_type: 'zero' or 'reflect' padding mode.
        use_bias: whether to add a learned bias.
        sn: if True, apply spectral normalization to the kernel.
        scope: variable scope name for the layer's parameters.

    Returns:
        The convolved feature map.
    """
    with tf.variable_scope(scope):
        if pad > 0:
            height = x.get_shape().as_list()[1]
            # Total padding: symmetric (pad on each side) when the height
            # divides the stride; otherwise just enough to cover the kernel
            # remainder (a "same-ish" heuristic).
            if height % stride == 0:
                total = pad * 2
            else:
                total = max(kernel - (height % stride), 0)

            top = total // 2
            bottom = total - top
            left = total // 2
            right = total - left

            paddings = [[0, 0], [top, bottom], [left, right], [0, 0]]
            if pad_type == 'zero':
                x = tf.pad(x, paddings)
            if pad_type == 'reflect':
                x = tf.pad(x, paddings, mode='REFLECT')

        if sn:
            # Explicit kernel variable so spectral_norm can wrap it; bias is
            # then added manually since tf.nn.conv2d has no bias argument.
            w = tf.get_variable("kernel",
                                shape=[kernel, kernel, x.get_shape()[-1], channels],
                                initializer=weight_init,
                                regularizer=weight_regularizer)
            x = tf.nn.conv2d(input=x, filter=spectral_norm(w),
                             strides=[1, stride, stride, 1], padding='VALID')
            if use_bias:
                bias = tf.get_variable("bias", [channels],
                                       initializer=tf.constant_initializer(0.0))
                x = tf.nn.bias_add(x, bias)
        else:
            # Plain convolution path (tf.layers handles kernel + bias).
            x = tf.layers.conv2d(inputs=x, filters=channels,
                                 kernel_size=kernel, kernel_initializer=weight_init,
                                 kernel_regularizer=weight_regularizer,
                                 strides=stride, use_bias=use_bias)

        return x
def global_context_block(x, channels, use_bias=True, sn=False, scope='gc_block'):
    """Global Context (GC) block from "GCNet: Non-local Networks Meet
    Squeeze-Excitation Networks and Beyond" (arXiv:1904.11492).

    Pipeline: global attention pooling (context modeling) followed by two
    bottleneck transforms — a sigmoid-gated multiplicative branch and an
    additive residual branch.

    Args:
        x: input feature map, [N, H, W, C]; C must be statically known.
        channels: bottleneck width of the transform 1x1 convolutions.
        use_bias, sn: forwarded to `conv`.
        scope: variable scope name.

    Returns:
        A tensor with the same shape as `x`.
    """
    with tf.variable_scope(scope):
        with tf.variable_scope('context_modeling'):
            c = x.get_shape().as_list()[-1]
            input_x = x
            input_x = hw_flatten(input_x)  # [N, H*W, C]
            input_x = tf.transpose(input_x, perm=[0, 2, 1])  # [N, C, H*W]
            input_x = tf.expand_dims(input_x, axis=1)  # [N, 1, C, H*W]

            # 1x1 conv -> one attention logit per spatial position,
            # normalized over H*W with softmax.
            context_mask = conv(x, channels=1, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='conv')
            context_mask = hw_flatten(context_mask)
            context_mask = tf.nn.softmax(context_mask, axis=1)  # [N, H*W, 1]
            context_mask = tf.transpose(context_mask, perm=[0, 2, 1])
            context_mask = tf.expand_dims(context_mask, axis=-1)  # [N, 1, H*W, 1]

            # Attention-weighted global pooling: [N, 1, C, 1].
            context = tf.matmul(input_x, context_mask)
            # BUGFIX: use -1 for the batch dimension so the block also builds
            # when the batch size is statically unknown (the original passed
            # `bs`, which is None for placeholder batches and crashes reshape).
            context = tf.reshape(context, shape=[-1, 1, 1, c])

        with tf.variable_scope('transform_0'):
            # Bottleneck transform with a sigmoid gate (channel attention).
            context_transform = conv(context, channels, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='conv_0')
            context_transform = layer_norm(context_transform)
            context_transform = relu(context_transform)
            context_transform = conv(context_transform, channels=c, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='conv_1')
            context_transform = sigmoid(context_transform)

            x = x * context_transform

        with tf.variable_scope('transform_1'):
            # Bottleneck transform fused back by addition (residual branch).
            context_transform = conv(context, channels, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='conv_0')
            context_transform = layer_norm(context_transform)
            context_transform = relu(context_transform)
            context_transform = conv(context_transform, channels=c, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='conv_1')

            x = x + context_transform

        return x
78 |
def layer_norm(x, scope='layer_norm'):
    """Layer normalization with learned scale and center, under `scope`."""
    return tf.contrib.layers.layer_norm(x, center=True, scale=True, scope=scope)
def relu(x):
    """Element-wise ReLU activation."""
    return tf.nn.relu(x)
85 |
def sigmoid(x):
    """Element-wise sigmoid activation."""
    return tf.sigmoid(x)
88 |
def hw_flatten(x):
    """Collapse the spatial dimensions: [N, H, W, C] -> [N, H*W, C].

    BUGFIX: use the dynamic batch size (`tf.shape(x)[0]`) instead of the
    static `x.shape[0]`, which is None for placeholder batches and makes
    the reshape fail at graph-build time. C must still be statically known.
    """
    return tf.reshape(x, shape=[tf.shape(x)[0], -1, x.get_shape().as_list()[-1]])
91 |
def spectral_norm(w, iteration=1):
    """Spectral normalization of weight `w` (Miyato et al., 2018).

    Divides `w` by an estimate of its largest singular value, computed by
    power iteration. The left-singular-vector estimate `u` is a persistent,
    non-trainable variable that is refined a little on every call.

    Args:
        w: weight tensor of any rank; flattened to 2-D internally.
        iteration: number of power-iteration steps (1 is usually enough).

    Returns:
        The normalized weight, reshaped back to `w`'s original shape.
    """
    w_shape = w.shape.as_list()
    # Flatten to 2-D: [prod(other dims), out_channels].
    w = tf.reshape(w, [-1, w_shape[-1]])

    # Persistent power-iteration state; excluded from optimizer updates.
    u = tf.get_variable("u", [1, w_shape[-1]], initializer=tf.random_normal_initializer(), trainable=False)

    u_hat = u
    v_hat = None
    for i in range(iteration):
        """
        power iteration
        Usually iteration = 1 will be enough
        """
        v_ = tf.matmul(u_hat, tf.transpose(w))
        v_hat = tf.nn.l2_normalize(v_)

        u_ = tf.matmul(v_hat, w)
        u_hat = tf.nn.l2_normalize(u_)

    # The singular-value estimate must not backpropagate into u/v.
    u_hat = tf.stop_gradient(u_hat)
    v_hat = tf.stop_gradient(v_hat)

    # sigma ≈ largest singular value of the flattened w.
    sigma = tf.matmul(tf.matmul(v_hat, w), tf.transpose(u_hat))

    # Persist the refined u BEFORE the normalized weight is used, so the
    # estimate improves across training steps.
    with tf.control_dependencies([u.assign(u_hat)]):
        w_norm = w / sigma
        w_norm = tf.reshape(w_norm, w_shape)

    return w_norm
--------------------------------------------------------------------------------