"""Normalized cross-correlation layer for Keras (TensorFlow backend).

NormalizedCrossCorrelationLayer
================================

The implementation of Normalized Cross Correlation Layer, which is proposed
by Dosovitskiy et al. [1], on Keras with tensorflow backend.

"NormalizedCrossCorrelation.py" (this module) contains the code of the layer
and a simple example of how to use it; the example runs when the file is
executed as a script and reads ``image.bmp`` from the working directory.

Reference
---------
[1] Dosovitskiy, A., Fischer, P., Ilg, E., Hausser, P., Hazirbas, C.,
    Golkov, V., ... & Brox, T. (2015). Flownet: Learning optical flow with
    convolutional networks. In Proceedings of the IEEE International
    Conference on Computer Vision (pp. 2758-2766).
"""
from keras.engine import Layer
from keras import backend as K
import tensorflow as tf
from keras.layers import Input
from keras.models import Model

import numpy as np


class Normlized_Cross_Correlation(Layer):
    """Correlate patches of one feature map with neighbourhoods of another.

    The layer is called on a list ``[input_1, input_2]`` of two
    ``channels_last`` tensors with identical static shape ``(N, H, W, C)``.
    At every ``s1``-th position a ``K x K`` patch (``K = 2 * k + 1``) of
    ``input_1`` is used as a convolution kernel over a window of ``input_2``
    anchored at the same position; the convolution runs with stride ``s2``
    and yields ``D x D`` responses (``D = 2 * d + 1``).  Before the
    convolution both operands are mean/std-normalised along the channel axis
    (see ``correlation``).

    Output shape: ``(N, H // s1, W // s1, D ** 2)``.

    Arguments:
        k: half size of the kernel patch taken from ``input_1``.
        d: half size of the ``D x D`` output grid.
        s1: stride between patch positions.
        s2: stride of the convolution inside each window.
    """

    def __init__(self, k=0, d=1, s1=1, s2=2, **kwargs):
        super(Normlized_Cross_Correlation, self).__init__(**kwargs)

        self.D = 2 * d + 1  # output grid is D x D
        self.K = 2 * k + 1  # kernel patch is K x K
        self.d = d
        self.k = k
        self.s1 = s1
        self.s2 = s2
        # Window size for which a 'valid' conv with a K x K kernel and
        # stride s2 produces exactly D x D outputs.
        self.D_stride = self.s2 * (self.D - 1) + self.K

    def build(self, input_shape):
        """No trainable weights; only defers to the base implementation."""
        super(Normlized_Cross_Correlation, self).build(input_shape)

    def compute_output_shape(self, input_shape):
        """Return ``(N, H // s1, W // s1, D ** 2)`` from the first input's shape."""
        first = input_shape[0]
        # Floor division keeps the dimensions integral on Python 3 as well.
        return (first[0], first[1] // self.s1, first[2] // self.s1, self.D ** 2)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = {'k': self.k, 'd': self.d, 's1': self.s1, 's2': self.s2}
        base_config = super(Normlized_Cross_Correlation, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def call(self, x, mask=None):
        """Compute the correlation volume for a pair of feature maps.

        Arguments:
            x: list ``[input_1, input_2]`` of tensors with the same shape.
            mask: unused.

        Returns:
            Tensor of shape ``(N, H // s1, W // s1, D ** 2)``.
        """
        input_1, input_2 = x
        input_shape = K.int_shape(input_1)

        assert input_shape == K.int_shape(input_2), \
            'both inputs must have the same shape'

        self.H = input_shape[1]
        self.W = input_shape[2]
        self.C = input_shape[3]

        padding1 = K.spatial_2d_padding(input_1,
                                        padding=((self.k, self.k + 1), (self.k, self.k + 1)),
                                        data_format='channels_last')
        # D_stride is always odd (s2 * 2d + 2k + 1), so this division is exact.
        half = (self.D_stride - 1) // 2
        padding2 = K.spatial_2d_padding(input_2,
                                        padding=((half, half + 1), (half, half + 1)),
                                        data_format='channels_last')
        # padding1 & padding2: [nS, h, w, c]

        # tf.scan is used to apply single_sample_corr to every sample of the
        # batch; the accumulator only fixes the per-sample output shape.
        out = tf.scan(self.single_sample_corr,
                      elems=[padding2, padding1],
                      initializer=(K.zeros((self.H // self.s1,
                                            self.W // self.s1,
                                            self.D ** 2))))

        return out

    def single_sample_corr(self, previous, features):
        """Correlation volume of one sample (the step function for tf.scan).

        Arguments:
            previous: scan accumulator; not used.
            features: pair ``(fea1, fea2)`` where ``fea1`` is the padded
                second input (displacement windows are cut from it) and
                ``fea2`` is the padded first input (kernels are cut from it).

        Returns:
            Tensor of shape ``(H // s1, W // s1, D ** 2)``.
        """
        fea1, fea2 = features  # fea1: the displacement, fea2: the kernel

        rows = self.H // self.s1
        cols = self.W // self.s1

        displaces = []
        kernels = []
        for i in range(rows):
            for j in range(cols):
                top = i * self.s1
                left = j * self.s1

                kernels.append(fea2[top:top + self.K, left:left + self.K, :])
                displaces.append(fea1[top:top + self.D_stride,
                                      left:left + self.D_stride, :])

        displaces = K.stack(displaces, axis=0)  # [rows*cols, D_stride, D_stride, C]
        kernels = K.stack(kernels, axis=0)
        kernels = K.permute_dimensions(kernels, (1, 2, 3, 0))  # [K, K, C, rows*cols]
        corr = self.correlation(kernels, displaces, s2=self.s2)  # [rows*cols, D, D, rows*cols]

        # Only the response of window i to its own kernel i is wanted, i.e.
        # the "diagonal" over the first and last axes.  The count must match
        # the number of stacked patches, hence rows * cols rather than
        # H * W / s1 ** 2 (they differ when H or W is not a multiple of s1).
        diagonal = [corr[i, :, :, i] for i in range(rows * cols)]
        a = K.stack(diagonal, axis=0)  # [rows*cols, D, D]

        out = K.reshape(a, (rows, cols, self.D ** 2))

        return out

    def correlation(self, kernel, displace, s2=1):
        """Normalise both operands along channels and convolve them.

        Arguments:
            kernel: tensor laid out ``[K, K, C, n_kernels]``.
            displace: tensor laid out ``[n_windows, D_stride, D_stride, C]``.
            s2: stride of the convolution.

        Returns:
            Result of a 'valid' ``K.conv2d`` of ``displace`` with ``kernel``.
        """
        dis_std = K.std(displace, axis=[3], keepdims=True)
        ker_std = K.std(kernel, axis=[2], keepdims=True)
        # NOTE(review): the two epsilons differ (1e-6 vs 1e-7); kept exactly
        # as in the original so numerical results do not change.
        displace = (displace - K.mean(displace, axis=[3], keepdims=True)) / (dis_std + 0.000001)
        kernel = (kernel - K.mean(kernel, axis=[2], keepdims=True)) / (ker_std + 0.0000001)
        return K.conv2d(displace, kernel, strides=(s2, s2), padding='valid',
                        data_format='channels_last')


# Correctly spelled alias; the original (misspelled) name is kept for
# backward compatibility.
NormalizedCrossCorrelation = Normlized_Cross_Correlation


if __name__ == '__main__':
    from skimage.io import imread

    img = imread('image.bmp')
    print(img.shape)

    img = img[np.newaxis, :, :, :]
    img2 = np.zeros((32, 64, 64, 3))
    for i in range(32):
        # Same 64x64 crop for every sample, scaled by a random factor in [0.7, 1.0).
        img2[i, :, :, :] = img[:, 164:228, 164:228, :] * (np.random.random() * 0.3 + 0.7)
    img2 /= 255.

    input1 = Input(shape=(64, 64, 3))
    input2 = Input(shape=(64, 64, 3))
    crop = Normlized_Cross_Correlation(k=2, d=5, s1=2, s2=1)([input1, input2])

    model = Model([input1, input2], crop)

    corr = model.predict([img2, img2])

    print(corr.shape)