# ├── LPM.py
# ├── README.md
# ├── church.mat
# ├── city.mat
# └── demo.py
#
# /LPM.py:
# --------------------------------------------------------------------------------
import numpy as np


def LPM_cosF(neighborX, neighborY, lbd, vec, d2, tau, K):
    """Compute the multi-scale LPM cost of every putative match.

    For each scale ``KK`` in ``(K + 2, K, K - 2)`` the cost of match ``i``
    is ``(c1 + c2) / KK`` where

    * ``c1 = KK - n_i``, with ``n_i`` the number of neighbours that appear in
      both the X and the Y neighbour lists of match ``i``, and
    * ``c2`` counts those shared neighbours ``j`` whose displacement is
      inconsistent with the one of ``i``, i.e.
      ``cos(vec_i, vec_j) * min(d2_i, d2_j) / max(d2_i, d2_j) < tau``.

    The per-scale costs are summed into ``C``; a match is kept when the mean
    over the three scales is at most ``lbd``.

    Args:
        neighborX, neighborY: integer index arrays with one row per match and
            at least ``K + 3`` columns; column 0 is skipped (in the first
            round of ``LPM_filter`` it is normally the query point itself).
        lbd: threshold on the mean cost.
        vec: displacement vectors (``Y - X`` in ``LPM_filter``), one row per
            match.
        d2: squared length of every row of ``vec``.
        tau: consistency threshold applied to ``cos * ratio``.
        K: base neighbourhood size; must be at least 3 so that the smallest
            scale ``K - 2`` is non-empty.

    Returns:
        ``(idx, C)``: indices of the matches whose mean cost is ``<= lbd``,
        and the summed (not averaged) cost of every match.
    """
    L = neighborX.shape[0]
    Km = np.array([K + 2, K, K - 2])
    M = len(Km)
    C = 0

    for KK in Km:
        # NOTE(review): neighborX/neighborY are re-bound on every pass, so the
        # later scales slice the previous scale's columns rather than the
        # original ones: scale K uses original columns 2..K+1 and scale K-2
        # uses columns 3..K, i.e. the nearest neighbours are progressively
        # dropped. Kept unchanged because the README promises the same output
        # as the MATLAB reference; confirm against it before "fixing" this.
        neighborX = neighborX[:, 1:KK + 1]
        neighborY = neighborY[:, 1:KK + 1]

        # Sort each row of the two concatenated neighbour lists. An index that
        # is in both lists then appears twice in a row, so a zero forward
        # difference flags the first copy of every shared neighbour. A False
        # column is appended to keep `shared` aligned with `index`.
        index = np.sort(np.hstack((neighborX, neighborY)), axis=1)
        shared = np.hstack(
            (np.diff(index, axis=1) == 0, np.zeros((L, 1), dtype=bool))
        ).astype(int)
        ni = np.sum(shared, axis=1)
        c1 = KK - ni

        # Cosine between each match's displacement and the displacements of
        # the entries of `index`, weighted by the ratio of squared lengths.
        d2_i = d2[:, np.newaxis]
        d2_j = d2[index]
        dot = np.sum(vec[:, np.newaxis, :] * vec[index, :], axis=2)
        cos_sita = dot / np.sqrt(d2_j * d2_i)
        ratio = np.minimum(d2_j, d2_i) / np.maximum(d2_j, d2_i)
        # Only shared neighbours contribute to c2. NaNs produced by
        # zero-length displacements compare False and are never counted.
        label = (cos_sita * ratio < tau).astype(int)
        c2 = np.sum(label * shared, axis=1)

        C = C + (c1 + c2) / KK

    idx = np.where((C / M) <= lbd)
    return idx[0], C


def LPM_filter(X, Y, lambda1=0.8, lambda2=0.5, numNeigh1=6, numNeigh2=6,
               tau1=0.2, tau2=0.2):
    """Return a boolean inlier mask for the putative matches ``X[i] <-> Y[i]``.

    Two rounds of ``LPM_cosF`` are run. The first builds neighbourhoods from
    all matches. If it keeps at least ``numNeigh2 + 4`` matches, a second
    round rebuilds the neighbourhoods from those survivors only (every match
    is still scored) and its result replaces the first one.

    The defaults are the values that used to be hard-coded here.

    Args:
        X, Y: coordinate arrays with one row per match; row ``i`` of X is
            matched to row ``i`` of Y.
        lambda1, lambda2: cost thresholds of the first and second round.
        numNeigh1, numNeigh2: base neighbourhood sizes (``K`` of
            ``LPM_cosF``) of the two rounds; each must be at least 3.
        tau1, tau2: displacement-consistency thresholds of the two rounds.

    Returns:
        1-D boolean array with ``X.shape[0]`` entries, True for inliers.

    Raises:
        ValueError: if a neighbourhood size is below 3, or there are fewer
            than ``numNeigh1 + 3`` matches.
    """
    # Imported here so that LPM_cosF stays usable with numpy alone.
    from sklearn.neighbors import KDTree

    if numNeigh1 < 3 or numNeigh2 < 3:
        raise ValueError("numNeigh1 and numNeigh2 must be at least 3")
    k1 = numNeigh1 + 3
    if X.shape[0] < k1:
        raise ValueError(
            "LPM_filter needs at least {} matches, got {}".format(k1, X.shape[0]))

    vec = Y - X
    d2 = np.sum(vec ** 2, axis=1)

    # numNeigh + 3 columns: the largest scale (K + 2) plus column 0, which
    # LPM_cosF skips.
    _, neighborX = KDTree(X).query(X, k=k1)
    _, neighborY = KDTree(Y).query(Y, k=k1)
    idx, _ = LPM_cosF(neighborX, neighborY, lambda1, vec, d2, tau1, numNeigh1)

    if len(idx) >= numNeigh2 + 4:
        # Second round: neighbours are searched among the first-round inliers
        # only, then mapped back from positions in X[idx] to match indices.
        k2 = numNeigh2 + 3
        _, neighborX2 = KDTree(X[idx, :]).query(X, k=k2)
        _, neighborY2 = KDTree(Y[idx, :]).query(Y, k=k2)
        idx, _ = LPM_cosF(idx[neighborX2], idx[neighborY2],
                          lambda2, vec, d2, tau2, numNeigh2)

    mask = np.zeros(X.shape[0], dtype=bool)
    mask[idx] = True
    return mask

# --------------------------------------------------------------------------------
# /README.md:
# --------------------------------------------------------------------------------
# # LPM_Python
#
# A Python implementation of the Locality Preserving Matching (LPM) method for
# pruning outliers in image matching.
#
# This code is ported from the MATLAB version https://github.com/jiayi-ma/LPM
# and is intended to produce the same output with a similar time cost. The
# parameters can be tuned through the keyword arguments of LPM_filter in LPM.py.
#
# If you find this code useful for your research, please cite the paper:
#
# ```
# @article{ma2019locality,
#   title={Locality preserving matching},
#   author={Ma, Jiayi and Zhao, Ji and Jiang, Junjun and Zhou, Huabing and Guo, Xiaojie},
#   journal={International Journal of Computer Vision},
#   volume={127},
#   number={5},
#   pages={512--531},
#   year={2019},
#   publisher={Springer}
# }
# ```
#
# # USAGE
#
# Dependencies: numpy and scikit-learn are required for the core function LPM_filter.
#
# opencv-python and scipy are additionally required to run the demo.
#
# After installing the dependencies, run
# ```
# python demo.py
# ```
# for a simple example (pass `city.mat` as an argument to use the second data set).
# --------------------------------------------------------------------------------
# /church.mat:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/jiayi-ma/LPM_Python/15c13ffb5fcec0482d31d18402ecce5d4b5bceb3/church.mat
# --------------------------------------------------------------------------------
# /city.mat:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/jiayi-ma/LPM_Python/15c13ffb5fcec0482d31d18402ecce5d4b5bceb3/city.mat
# --------------------------------------------------------------------------------
# /demo.py:
# --------------------------------------------------------------------------------
import scipy.io
import sys
import time
import cv2

from LPM import LPM_filter


def draw_match(img1, img2, corr1, corr2):
    """Draw the correspondences ``corr1[i] <-> corr2[i]`` between two images.

    Args:
        img1, img2: images passed straight to ``cv2.drawMatches``.
        corr1, corr2: arrays with one ``(x, y)`` row per correspondence; row
            ``i`` of ``corr1`` is drawn as matched to row ``i`` of ``corr2``.

    Returns:
        The side-by-side image produced by ``cv2.drawMatches``.
    """
    assert len(corr1) == len(corr2)

    # Coordinates are cast to plain floats defensively, in case the arrays
    # are integer-typed.
    kp1 = [cv2.KeyPoint(float(x), float(y), 1) for x, y in corr1[:, :2]]
    kp2 = [cv2.KeyPoint(float(x), float(y), 1) for x, y in corr2[:, :2]]

    # Keypoint i of the first image is matched to keypoint i of the second.
    matches = [cv2.DMatch(i, i, 0) for i in range(len(kp1))]

    return cv2.drawMatches(img1, kp1, img2, kp2, matches, None,
                           matchColor=(0, 255, 0),
                           singlePointColor=(0, 0, 255),
                           # presumably cv2.DRAW_MATCHES_FLAGS_DRAW_RICH_KEYPOINTS
                           # - confirm against the OpenCV docs
                           flags=4
                           )


def _show(title, image):
    """Show ``image`` in a window and block until a key is pressed.

    All windows are closed when that key is ESC (code 27).
    """
    cv2.imshow(title, image)
    # press ESC to terminate imshow; keep only the low byte because some
    # OpenCV builds report extra bits in the waitKey code.
    if cv2.waitKey(0) & 0xFF == 27:
        cv2.destroyAllWindows()


def main(path='church.mat'):
    """Run LPM on the correspondences stored in ``path`` and display them.

    ``path`` must be a .mat file with the variables ``X``, ``Y`` (matched
    point coordinates) and ``I1``, ``I2`` (the two images), as read below.
    """
    data = scipy.io.loadmat(path)
    X = data['X']
    Y = data['Y']
    I1 = data['I1']
    I2 = data['I2']

    # time.clock() was removed in Python 3.8; perf_counter() measures the
    # elapsed time instead.
    start = time.perf_counter()
    mask = LPM_filter(X, Y)
    end = time.perf_counter()
    print("Time cost: {} seconds".format(end-start))

    print("Correspondences before LPM filter, please press ESC to terminate image window")
    _show("before", draw_match(I1, I2, X, Y))

    print("Correspondences after LPM filter, please press ESC to terminate image window")
    _show("after", draw_match(I1, I2, X[mask, :], Y[mask, :]))


if __name__ == "__main__":
    # Optional first argument selects the data file, e.g. city.mat.
    main(sys.argv[1] if len(sys.argv) > 1 else 'church.mat')
# --------------------------------------------------------------------------------