"""GloRe: TensorFlow implementation of the Global Reasoning unit (GloRe).

From "Graph-Based Global Reasoning Networks".
"""
import tensorflow as tf

slim = tf.contrib.slim


def GloRe(X, C, N, activation_fn=None):
    """Apply a Global Reasoning (GloRe) unit to a feature map.

    The unit projects the spatial features onto ``N`` graph nodes, runs one
    graph-convolution step over those nodes, projects the result back onto
    the spatial grid and adds it to ``X`` as a residual.

    Args:
        X: 4-D feature map. The code reshapes it as ``[B, H, W, channels]``
            and requires the last (channel) dimension to be statically
            known. (Assumes NHWC layout, as the shape comments below do.)
        C: Number of channels of the reduced features, i.e. the size of
            each node's state vector.
        N: Number of graph nodes.
        activation_fn: Activation applied by every 1x1 convolution in the
            unit. ``None`` (the default) does not override anything, so the
            layers keep ``slim.conv2d``'s own default activation, exactly as
            this function behaved before the argument was wired in. Pass a
            callable (e.g. ``tf.identity`` for linear layers) to change it.

    Returns:
        ``X + x_res`` where ``x_res`` is the re-projected graph output,
        mapped back to the channel count of ``X`` by a 1x1 convolution.
    """
    # Only override slim.conv2d's activation when the caller asked for one;
    # with None the keyword is omitted and every layer is built as before.
    conv_kwargs = {}
    if activation_fn is not None:
        conv_kwargs["activation_fn"] = activation_fn

    input_channels = X.get_shape().as_list()[-1]
    dyn_shape = tf.shape(X)
    batch = dyn_shape[0]

    # NOTE: no explicit scope is passed to the conv layers, so their naming
    # is left to slim. The five conv calls are kept in their original order
    # so that layer/variable naming does not shift.

    # Projection weights from pixels to nodes.
    B = slim.conv2d(X, N, [1, 1], **conv_kwargs)
    B = tf.reshape(B, [batch, -1, N])  # [B, H*W, N]

    # Channel-reduced features.
    x_reduced = slim.conv2d(X, C, [1, 1], **conv_kwargs)
    x_reduced = tf.reshape(x_reduced, [batch, -1, C])  # [B, H*W, C]
    x_reduced = tf.transpose(x_reduced, perm=[0, 2, 1])  # [B, C, H*W]

    # Aggregate pixels into node states:
    # [B, C, H*W] x [B, H*W, N] -> [B, C, N]
    v = tf.matmul(x_reduced, B)  # [B, C, N]
    v = tf.expand_dims(v, axis=1)  # [B, 1, C, N], 4-D so conv2d can be used

    def GCN(Vnode, nodeN, mid_channels):
        """One graph-convolution step over node states [B, 1, C, N]."""
        # 1x1 conv over the last (node) axis mixes information across nodes.
        mixed = slim.conv2d(Vnode, nodeN, [1, 1], **conv_kwargs)  # [B, 1, C, N]
        net = Vnode - mixed  # (I - Ag)V
        # Move the state axis last so the next 1x1 conv acts on node states.
        net = tf.transpose(net, perm=[0, 3, 1, 2])  # [B, N, 1, C]
        net = slim.conv2d(net, mid_channels, [1, 1], **conv_kwargs)  # [B, N, 1, C]
        return net

    z = GCN(v, N, C)  # [B, N, 1, C]
    z = tf.reshape(z, [batch, N, C])  # [B, N, C]

    # Project node states back onto the pixels:
    # [B, H*W, N] x [B, N, C] -> [B, H*W, C]
    y = tf.matmul(B, z)  # [B, H*W, C]
    # Reshape straight back to the spatial grid (an intermediate expand_dims
    # would be undone by this reshape anyway).
    y = tf.reshape(y, [batch, dyn_shape[1], dyn_shape[2], C])  # [B, H, W, C]

    # Restore the input channel count so the residual sum is well-formed.
    x_res = slim.conv2d(y, input_channels, [1, 1], **conv_kwargs)

    return X + x_res