├── README.md ├── alpha.py ├── closed_form_matting.py └── lle.py /README.md: -------------------------------------------------------------------------------- 1 | # AlphaMatting-Information-Flow 2 | Implementation of "Designing Effective Inter-Pixel Information Flow for Natural Image Matting" 3 | 4 | 5 | References - 6 | Implementation of Closed form matting - https://github.com/MarcoForte/closed-form-matting/blob/master/closed_form_matting.py 7 | 8 | lle implementation in scipy python - https://github.com/scikit-learn/scikit-learn/blob/7b136e9/sklearn/manifold/locally_linear.py#L521 9 | -------------------------------------------------------------------------------- /alpha.py: -------------------------------------------------------------------------------- 1 | from math import sqrt as sqrt 2 | import heapq 3 | import numpy as np 4 | import cv2 5 | import matplotlib 6 | import matplotlib.pyplot as plt 7 | from sklearn.neighbors import KDTree 8 | 9 | from lle import locally_linear_embedding as lle 10 | from lle import barycenter_kneighbors_graph as bkg 11 | from lle import barycenter_kneighbors_graph_ku as bkgku 12 | 13 | from closed_form_matting import compute_weight 14 | 15 | from sklearn.neighbors import NearestNeighbors 16 | import numpy as np 17 | 18 | from scipy.sparse import csr_matrix as csr 19 | from scipy.sparse import diags 20 | from scipy.sparse.linalg import cg 21 | from scipy.sparse.linalg import spsolve 22 | 23 | from scipy.misc import imsave 24 | import os 25 | 26 | 27 | 28 | 29 | def cm(img, X): 30 | 31 | kcm = 20 32 | w = bkg(X, n_neighbors=kcm) 33 | return w 34 | # print(w) 35 | # X = np.array([[-1, -1, 2,1], [-2, -1, 4,7], [-3, -2,4,5], [1, 1,8,9], [2, 1,5,4], [3, 2,3,8]]) 36 | # nbrs = NearestNeighbors(n_neighbors=3, algorithm='ball_tree').fit(X) 37 | # y = nbrs._fit_X 38 | # print(nbrs.kneighbors(X, return_distance=False)) 39 | # exit() 40 | # print(X == y) 41 | # pp = NearestNeighbors(n_neighbors=4, algorithm='ball_tree').fit(nbrs) 42 | # 
def ku(img, tmap, X, kku=7):
    """Known-to-unknown (K-to-U) information flow.

    For every unknown pixel (0.1 < alpha < 0.9):
    1. Find its kku nearest known-foreground pixels (alpha > 0.9) and its
       kku nearest known-background pixels (alpha < 0.1). The search uses a
       feature space whose spatial columns X[:, 3:5] are multiplied by 10.
    2. Reconstruct the pixel's colour X[:, :-2] from those 2*kku colours
       with barycentric (LLE) weights.

    Args:
        img: unused; kept for a uniform call signature.
        tmap: trimap scaled to [0, 1]; tmap.ravel() must align with X's rows.
        X: (N, 5) features, as built by main(): [c0, c1, c2, row/h, col/w].
            The colours are assumed to be 0..255, which the 255**2
            normalisation of H relies on. X is NOT modified.
        kku: number of neighbours taken from each known set.

    Returns:
        wf: (N,) array. It is 1 on known foreground, the sum of an unknown
            pixel's foreground weights on unknown pixels, and 0 elsewhere.
        H: (N, N) diagonal CSR matrix. For unknown pixel i the entry is
            ||cF_i - cB_i||^2 / (3 * 255**2), where cF_i and cB_i are the
            weight-normalised foreground and background colour estimates.
    """
    N = X.shape[0]
    # Scale a private copy. Scaling X in place leaked the x10 factor into
    # the caller's array, and from there into intra_u().
    feats = np.array(X, dtype=np.float64)
    feats[:, 3:5] *= 10

    alpha = tmap.ravel()
    ind = np.arange(N)
    fg_mask = alpha > 0.9
    bg_mask = alpha < 0.1
    unk_mask = (alpha > 0.1) & (alpha < 0.9)
    find, bind, unkind = ind[fg_mask], ind[bg_mask], ind[unk_mask]

    wf = np.zeros(N)
    wf[find] = 1
    if unkind.size == 0:
        # No unknown pixels: there are no K-to-U edges to build.
        return wf, csr((N, N))

    unk = feats[unk_mask]
    # Nearest known foreground / background pixels to every unknown pixel.
    nf = KDTree(feats[fg_mask], leaf_size=30, metric='euclidean').query(
        unk, k=kku, return_distance=False)
    nb = KDTree(feats[bg_mask], leaf_size=30, metric='euclidean').query(
        unk, k=kku, return_distance=False)
    z_ind = np.concatenate((find[nf], bind[nb]), axis=1)   # (n_unk, 2*kku)

    colours = feats[:, :-2]                 # drop the two spatial columns
    z = colours[z_ind]                      # (n_unk, 2*kku, n_colour)
    W = bkgku(unk[:, :-2], z, z_ind, n_neighbors=2 * kku)[:, :, np.newaxis]

    weighted = W * z
    wsum_f = W[:, :kku].sum(axis=1)         # (n_unk, 1)
    wsum_b = W[:, kku:].sum(axis=1)
    cpf = np.abs(weighted[:, :kku].sum(axis=1) / wsum_f)
    cpb = np.abs(weighted[:, kku:].sum(axis=1) / wsum_b)

    wf[unkind] = wsum_f[:, 0]

    # Squared L2 distance of the two colour estimates, normalised to [0, 1]
    # for 0..255 colours.
    H = np.sum((cpf - cpb) ** 2, axis=1) / (3 * 255 * 255)
    return wf, csr((H, (unkind, unkind)), shape=(N, N))


def intra_u(img, tmap, X, kuu=5, color_scale=255.0):
    """Intra-unknown (U-to-U) information flow.

    Each unknown pixel (0.1 < alpha < 0.9) is linked to its kuu nearest
    unknown pixels. The query includes the pixel itself, and those
    self-loops cancel in D - W. Neighbours are found in the feature space
    [colour / color_scale, X[:, 3:5] / 20], and each edge weight is
    max(1 - ||f_p - f_q||_1, 0).

    Args:
        img: unused; kept for a uniform call signature.
        tmap: trimap scaled to [0, 1]; tmap.ravel() must align with X's rows.
        X: (N, 5) features, as built by main(): [c0, c1, c2, row/h, col/w].
            X is NOT modified.
        kuu: neighbours per unknown pixel.
        color_scale: divisor for the colour columns. With integer 0..255
            colours left unscaled, any two distinct colours are at least 1
            apart in L1, so every non-self weight clamped to 0. Pass 1.0 to
            reproduce that old behaviour.

    Returns:
        (N, N) CSR matrix W + W.T, which is symmetric.
    """
    N = X.shape[0]
    alpha = tmap.ravel()
    unkind = np.flatnonzero((alpha > 0.1) & (alpha < 0.9))
    if unkind.size == 0:
        return csr((N, N))

    # Private, rescaled copy of the unknown pixels' features.
    feats = np.array(X[unkind], dtype=np.float64)
    feats[:, :3] /= color_scale
    feats[:, 3:5] /= 20

    nbrs = KDTree(feats, leaf_size=30, metric='euclidean').query(
        feats, k=kuu, return_distance=False)          # (n_unk, kuu)
    l1 = np.abs(feats[nbrs] - feats[:, np.newaxis, :]).sum(axis=2)
    data = np.maximum(1 - l1, 0).ravel()

    rows = np.repeat(unkind, kuu)                      # source pixel per edge
    cols = unkind[nbrs].ravel()                        # neighbour per edge
    W = csr((data, (rows, cols)), shape=(N, N))
    return W + W.T
print(csr.count_nonzero(z)) 190 | # print(csr.count_nonzero(w)) 191 | 192 | 193 | def local(img,tmap): 194 | umask = (tmap>0.1) & (tmap<0.9) 195 | W = compute_weight(img, mask=umask, eps=10**(-7), win_rad=1).tocsr() 196 | return W 197 | # print(type(W)) 198 | # X = csr.sum(W,axis=1) #numpy matrix 199 | # D = diags(X.A.ravel()).tocsr() 200 | # a = csr([[1,0,0],[2,3,0],[0,0,4]]) 201 | # b = csr([[1,0,0],[1,1,0],[0,0,1]]) 202 | # c = a-b 203 | # print(c.toarray()) 204 | # print(csr.count_nonzero(X)) 205 | 206 | 207 | def eq1(Wcm,Wuu,Wl,H,T,ak,wf): 208 | # sku = 0.05 209 | # suu = 0.01 210 | # sl = 1 211 | # lamd = 100 212 | 213 | sku = 0.05 214 | suu = 0.01 215 | sl = 1 216 | lamd = 100 217 | 218 | X = csr.sum(Wcm,axis=1) #numpy matrix 219 | Dcm = diags(X.A.ravel()).tocsr() 220 | 221 | X = csr.sum(Wuu,axis=1) #numpy matrix 222 | Duu = diags(X.A.ravel()).tocsr() 223 | 224 | X = csr.sum(Wl,axis=1) #numpy matrix 225 | Dl = diags(X.A.ravel()).tocsr() 226 | 227 | Lifm = csr.transpose(Dcm-Wcm).dot(Dcm-Wcm) + suu*(Duu-Wuu) + sl*(Dl-Wl) 228 | # Lifm = suu*(Duu-Wuu) + sl*(Dl-Wl) 229 | 230 | 231 | A = Lifm + lamd*T + sku*H 232 | b = (lamd*T + sku*H).dot(wf) 233 | # print(csr.sum(b)) 234 | M = diags(A.diagonal()) 235 | # print(A.shape) 236 | # print(b.shape) 237 | alpha = cg(A, b, x0=wf, tol=1e-05, maxiter=100, M=None, callback=None, atol=None) 238 | # alpha = spsolve(A, b) 239 | # print(alpha) 240 | # print(type(alpha[0])) 241 | return alpha[0]*255 242 | ###solve 243 | 244 | # A = Lifm + lamd*T 245 | # b = (lamd*T).dot(ak) 246 | ###solve 247 | 248 | 249 | def main(img_path, tri_map, save_path): 250 | 251 | # c = np.asarray([[1,0,0],[2,3,0],[0,0,4]]) 252 | # d = np.asarray([[1],[1],[1]]) 253 | # e = c.dot(d) 254 | # a = csr([[1,0,0],[2,3,0],[0,0,4]]) 255 | # b = csr(e) 256 | # b = b.T 257 | # b = csr(b) 258 | # print(a.shape) 259 | # print(b.shape) 260 | # x = cg(a, b, x0=None, tol=1e-05, maxiter=10, callback=None, atol=None) 261 | # print(x) 262 | # exit() 263 | 264 | # 
img_path = img_path 265 | # tri_map = tri_map 266 | 267 | img = cv2.imread(img_path) 268 | tri_map = cv2.imread(tri_map) 269 | tmap = tri_map[:,:,0].copy() 270 | tmap = tmap/255 271 | print(np.sum(tmap==1)) 272 | 273 | X = [] 274 | [h, w, c] = img.shape 275 | for i in range(h): 276 | for j in range(w): 277 | [r, g, b] = list(img[i, j, :]) 278 | X.append([r, g, b, i/h, j/w]) 279 | X = np.asarray(X) 280 | alpha = tmap.ravel() 281 | 282 | 283 | known = alpha.copy() 284 | known[(alpha>0.9)|(alpha<0.1)] = 1 285 | known[(alpha<0.9)&(alpha>0.1)] = 0 286 | T = diags(known).tocsr() 287 | print(T.count_nonzero()) 288 | # print(np.sum(known==0)) 289 | # print(X[(alpha>0.1)&(alpha<0.9)].shape[0]) 290 | # exit() 291 | 292 | Wcm = cm(img,X) 293 | # Wcm = csr((h*w,h*w)) 294 | wk,H = ku(img,tmap,X) 295 | # H = csr((h*w,h*w)) 296 | # wk = csr((wk.shape)) 297 | Wuu = intra_u(img,tmap,X) 298 | # Wuu = csr((h*w,h*w)) 299 | Wl = local(img,tmap) 300 | 301 | ak = alpha.copy() 302 | ak[ak<0.9] = 0 #set all non foreground pixels to 0 303 | ak[ak>=0.9] = 1 #set all foreground pixels to 1 304 | 305 | calc_alpha = eq1(Wcm,Wuu,Wl,H,T,ak,wk) 306 | calc_alpha[alpha==1] = 1 307 | calc_alpha[alpha==0] = 0 308 | calc_alpha[calc_alpha>0.2] = 1 309 | calc_alpha[calc_alpha<0.07] = 0 310 | imsave(save_path, calc_alpha.reshape((h,w))) 311 | 312 | 313 | def other(): 314 | dira = "./data/trimap_lowres/Trimap1" 315 | dirb = "./data/trimap_lowres/Trimap2" 316 | dirc = "./data/trimap_lowres/Trimap3" 317 | dir1 = "./data/input_lowres" 318 | 319 | # dir1 = "out2/Trimap" 320 | l = os.listdir(dir1) 321 | 322 | for file in l: 323 | f = str(file) 324 | f = "troll.png" 325 | img_path = dir1+'/'+f 326 | 327 | save_path = "out2/Trimap1/"+f 328 | tri_map = dira+'/'+f 329 | main(img_path, tri_map, save_path) 330 | exit() 331 | 332 | save_path = "out2/Trimap2/"+f 333 | tri_map = dirb+'/'+f 334 | main(img_path, tri_map, save_path) 335 | 336 | save_path = "out2/Trimap3/"+f 337 | tri_map = dirc+'/'+f 338 | main(img_path, 
tri_map, save_path) 339 | 340 | 341 | if __name__ == "__main__": 342 | # main() 343 | other() 344 | # cm() -------------------------------------------------------------------------------- /closed_form_matting.py: -------------------------------------------------------------------------------- 1 | """Implementation of Closed-Form Matting. 2 | 3 | This module implements natural image matting method described in: 4 | Levin, Anat, Dani Lischinski, and Yair Weiss. "A closed-form solution to natural image matting." 5 | IEEE Transactions on Pattern Analysis and Machine Intelligence 30.2 (2008): 228-242. 6 | 7 | The code can be used in two ways: 8 | 1. By importing solve_foregound_background in your code: 9 | ``` 10 | import closed_form_matting 11 | ... 12 | # For scribles input 13 | alpha = closed_form_matting.closed_form_matting_with_scribbles(image, scribbles) 14 | 15 | # For trimap input 16 | alpha = closed_form_matting.closed_form_matting_with_trimap(image, trimap) 17 | 18 | # For prior with confidence 19 | alpha = closed_form_matting.closed_form_matting_with_prior( 20 | image, prior, prior_confidence, optional_const_mask) 21 | 22 | # To get Matting Laplacian for image 23 | laplacian = compute_laplacian(image, optional_const_mask) 24 | ``` 25 | 2. From command line: 26 | ``` 27 | # Scribbles input 28 | ./closed_form_matting.py input_image.png -s scribbles_image.png -o output_alpha.png 29 | 30 | # Trimap input 31 | ./closed_form_matting.py input_image.png -t scribbles_image.png -o output_alpha.png 32 | 33 | # Add flag --solve-fg to compute foreground color and output RGBA image instead 34 | # of alpha. 
35 | ``` 36 | """ 37 | 38 | from __future__ import division 39 | 40 | import logging 41 | 42 | import cv2 43 | import numpy as np 44 | from numpy.lib.stride_tricks import as_strided 45 | import scipy.sparse 46 | import scipy.sparse.linalg 47 | 48 | 49 | def _rolling_block(A, block=(3, 3)): 50 | """Applies sliding window to given matrix.""" 51 | shape = (A.shape[0] - block[0] + 1, A.shape[1] - block[1] + 1) + block 52 | strides = (A.strides[0], A.strides[1]) + A.strides 53 | return as_strided(A, shape=shape, strides=strides) 54 | 55 | 56 | 57 | def compute_weight(img, mask=None, eps=10**(-7), win_rad=1): 58 | """Computes Matting Laplacian for a given image. 59 | 60 | Args: 61 | img: 3-dim numpy matrix with input image 62 | mask: mask of pixels for which Laplacian will be computed. 63 | If not set Laplacian will be computed for all pixels. 64 | eps: regularization parameter controlling alpha smoothness 65 | from Eq. 12 of the original paper. Defaults to 1e-7. 66 | win_rad: radius of window used to build Matting Laplacian (i.e. 67 | radius of omega_k in Eq. 12). 68 | Returns: sparse matrix holding Matting Laplacian. 
69 | """ 70 | 71 | win_size = (win_rad * 2 + 1) ** 2 72 | h, w, d = img.shape 73 | # Number of window centre indices in h, w axes 74 | c_h, c_w = h - 2 * win_rad, w - 2 * win_rad 75 | win_diam = win_rad * 2 + 1 76 | 77 | indsM = np.arange(h * w).reshape((h, w)) 78 | ravelImg = img.reshape(h * w, d) 79 | win_inds = _rolling_block(indsM, block=(win_diam, win_diam)) 80 | 81 | win_inds = win_inds.reshape(c_h, c_w, win_size) 82 | if mask is not None: 83 | mask = cv2.dilate( 84 | mask.astype(np.uint8), 85 | np.ones((win_diam, win_diam), np.uint8) 86 | ).astype(np.bool) 87 | win_mask = np.sum(mask.ravel()[win_inds], axis=2) 88 | win_inds = win_inds[win_mask > 0, :] 89 | else: 90 | win_inds = win_inds.reshape(-1, win_size) 91 | 92 | 93 | winI = ravelImg[win_inds] 94 | 95 | win_mu = np.mean(winI, axis=1, keepdims=True) 96 | win_var = np.einsum('...ji,...jk ->...ik', winI, winI) / win_size - np.einsum('...ji,...jk ->...ik', win_mu, win_mu) 97 | 98 | inv = np.linalg.inv(win_var + (eps/win_size)*np.eye(3)) 99 | 100 | X = np.einsum('...ij,...jk->...ik', winI - win_mu, inv) 101 | vals = (1.0/win_size)*(1 + np.einsum('...ij,...kj->...ik', X, winI - win_mu)) 102 | 103 | nz_indsCol = np.tile(win_inds, win_size).ravel() 104 | nz_indsRow = np.repeat(win_inds, win_size).ravel() 105 | nz_indsVal = vals.ravel() 106 | # L = np.array((nz_indsVal, (nz_indsRow, nz_indsCol)), shape=(h*w, h*w)) 107 | L = scipy.sparse.coo_matrix((nz_indsVal, (nz_indsRow, nz_indsCol)), shape=(h*w, h*w)) 108 | return L 109 | 110 | 111 | 112 | 113 | def compute_laplacian(img, mask=None, eps=10**(-7), win_rad=1): 114 | """Computes Matting Laplacian for a given image. 115 | 116 | Args: 117 | img: 3-dim numpy matrix with input image 118 | mask: mask of pixels for which Laplacian will be computed. 119 | If not set Laplacian will be computed for all pixels. 120 | eps: regularization parameter controlling alpha smoothness 121 | from Eq. 12 of the original paper. Defaults to 1e-7. 
122 | win_rad: radius of window used to build Matting Laplacian (i.e. 123 | radius of omega_k in Eq. 12). 124 | Returns: sparse matrix holding Matting Laplacian. 125 | """ 126 | 127 | win_size = (win_rad * 2 + 1) ** 2 128 | h, w, d = img.shape 129 | # Number of window centre indices in h, w axes 130 | c_h, c_w = h - 2 * win_rad, w - 2 * win_rad 131 | win_diam = win_rad * 2 + 1 132 | 133 | indsM = np.arange(h * w).reshape((h, w)) 134 | ravelImg = img.reshape(h * w, d) 135 | win_inds = _rolling_block(indsM, block=(win_diam, win_diam)) 136 | 137 | win_inds = win_inds.reshape(c_h, c_w, win_size) 138 | if mask is not None: 139 | mask = cv2.dilate( 140 | mask.astype(np.uint8), 141 | np.ones((win_diam, win_diam), np.uint8) 142 | ).astype(np.bool) 143 | win_mask = np.sum(mask.ravel()[win_inds], axis=2) 144 | win_inds = win_inds[win_mask > 0, :] 145 | else: 146 | win_inds = win_inds.reshape(-1, win_size) 147 | 148 | 149 | winI = ravelImg[win_inds] 150 | 151 | win_mu = np.mean(winI, axis=1, keepdims=True) 152 | win_var = np.einsum('...ji,...jk ->...ik', winI, winI) / win_size - np.einsum('...ji,...jk ->...ik', win_mu, win_mu) 153 | 154 | inv = np.linalg.inv(win_var + (eps/win_size)*np.eye(3)) 155 | 156 | X = np.einsum('...ij,...jk->...ik', winI - win_mu, inv) 157 | vals = np.eye(win_size) - (1.0/win_size)*(1 + np.einsum('...ij,...kj->...ik', X, winI - win_mu)) 158 | 159 | nz_indsCol = np.tile(win_inds, win_size).ravel() 160 | nz_indsRow = np.repeat(win_inds, win_size).ravel() 161 | nz_indsVal = vals.ravel() 162 | L = scipy.sparse.coo_matrix((nz_indsVal, (nz_indsRow, nz_indsCol)), shape=(h*w, h*w)) 163 | return L 164 | 165 | 166 | def closed_form_matting_with_prior(image, prior, prior_confidence, consts_map=None): 167 | """Applies closed form matting with prior alpha map to image. 168 | 169 | Args: 170 | image: 3-dim numpy matrix with input image. 171 | prior: matrix of same width and height as input image holding apriori alpha map. 
172 | prior_confidence: matrix of the same shape as prior hodling confidence of prior alpha. 173 | consts_map: binary mask of pixels that aren't expected to change due to high 174 | prior confidence. 175 | 176 | Returns: 2-dim matrix holding computed alpha map. 177 | """ 178 | 179 | assert image.shape[:2] == prior.shape, ('prior must be 2D matrix with height and width equal ' 180 | 'to image.') 181 | assert image.shape[:2] == prior_confidence.shape, ('prior_confidence must be 2D matrix with ' 182 | 'height and width equal to image.') 183 | assert (consts_map is None) or image.shape[:2] == consts_map.shape, ( 184 | 'consts_map must be 2D matrix with height and width equal to image.') 185 | 186 | logging.info('Computing Matting Laplacian.') 187 | laplacian = compute_laplacian(image, ~consts_map if consts_map is not None else None) 188 | 189 | confidence = scipy.sparse.diags(prior_confidence.ravel()) 190 | logging.info('Solving for alpha.') 191 | solution = scipy.sparse.linalg.spsolve( 192 | laplacian + confidence, 193 | prior.ravel() * prior_confidence.ravel() 194 | ) 195 | alpha = np.minimum(np.maximum(solution.reshape(prior.shape), 0), 1) 196 | return alpha 197 | 198 | 199 | def closed_form_matting_with_trimap(image, trimap, trimap_confidence=100.0): 200 | """Apply Closed-Form matting to given image using trimap.""" 201 | 202 | assert image.shape[:2] == trimap.shape, ('trimap must be 2D matrix with height and width equal ' 203 | 'to image.') 204 | consts_map = (trimap < 0.1) | (trimap > 0.9) 205 | return closed_form_matting_with_prior(image, trimap, trimap_confidence * consts_map, consts_map) 206 | 207 | 208 | def closed_form_matting_with_scribbles(image, scribbles, scribbles_confidence=100.0): 209 | """Apply Closed-Form matting to given image using scribbles image.""" 210 | 211 | assert image.shape == scribbles.shape, 'scribbles must have exactly same shape as image.' 
212 | prior = np.sign(np.sum(scribbles - image, axis=2)) / 2 + 0.5 213 | consts_map = prior != 0.5 214 | return closed_form_matting_with_prior( 215 | image, 216 | prior, 217 | scribbles_confidence * consts_map, 218 | consts_map 219 | ) 220 | 221 | 222 | closed_form_matting = closed_form_matting_with_trimap 223 | 224 | def main(): 225 | import argparse 226 | 227 | logging.basicConfig(level=logging.INFO) 228 | arg_parser = argparse.ArgumentParser(description=__doc__) 229 | arg_parser.add_argument('image', type=str, help='input image') 230 | 231 | arg_parser.add_argument('-t', '--trimap', type=str, help='input trimap') 232 | arg_parser.add_argument('-s', '--scribbles', type=str, help='input scribbles') 233 | arg_parser.add_argument('-o', '--output', type=str, required=True, help='output image') 234 | arg_parser.add_argument( 235 | '--solve-fg', dest='solve_fg', action='store_true', 236 | help='compute foreground color and output RGBA image' 237 | ) 238 | args = arg_parser.parse_args() 239 | 240 | image = cv2.imread(args.image, cv2.IMREAD_COLOR) / 255.0 241 | 242 | if args.scribbles: 243 | scribbles = cv2.imread(args.scribbles, cv2.IMREAD_COLOR) / 255.0 244 | alpha = closed_form_matting_with_scribbles(image, scribbles) 245 | elif args.trimap: 246 | trimap = cv2.imread(args.trimap, cv2.IMREAD_GRAYSCALE) / 255.0 247 | alpha = closed_form_matting_with_trimap(image, trimap) 248 | else: 249 | logging.error('Either trimap or scribbles must be specified.') 250 | arg_parser.print_help() 251 | exit(-1) 252 | 253 | if args.solve_fg: 254 | from solve_foreground_background import solve_foreground_background 255 | foreground, _ = solve_foreground_background(image, alpha) 256 | output = np.concatenate((foreground, alpha[:, :, np.newaxis]), axis=2) 257 | else: 258 | output = alpha 259 | 260 | cv2.imwrite(args.output, output * 255.0) 261 | 262 | 263 | if __name__ == "__main__": 264 | main() 265 | -------------------------------------------------------------------------------- /lle.py: 
-------------------------------------------------------------------------------- 1 | """Locally Linear Embedding""" 2 | 3 | # Author: Fabian Pedregosa -- 4 | # Jake Vanderplas -- 5 | # License: BSD 3 clause (C) INRIA 2011 6 | 7 | import numpy as np 8 | from scipy.linalg import eigh, svd, qr, solve 9 | from scipy.sparse import eye, csr_matrix 10 | from scipy.sparse.linalg import eigsh 11 | 12 | from sklearn.base import BaseEstimator, TransformerMixin 13 | from sklearn.utils import check_random_state, check_array 14 | from sklearn.utils.extmath import stable_cumsum 15 | from sklearn.utils.validation import check_is_fitted 16 | from sklearn.utils.validation import FLOAT_DTYPES 17 | from sklearn.neighbors import NearestNeighbors 18 | 19 | 20 | def barycenter_weights(X, Z, reg=1e-3): 21 | """Compute barycenter weights of X from Y along the first axis 22 | 23 | We estimate the weights to assign to each point in Y[i] to recover 24 | the point X[i]. The barycenter weights sum to 1. 25 | 26 | Parameters 27 | ---------- 28 | X : array-like, shape (n_samples, n_dim) 29 | 30 | Z : array-like, shape (n_samples, n_neighbors, n_dim) 31 | 32 | reg : float, optional 33 | amount of regularization to add for the problem to be 34 | well-posed in the case of n_neighbors > n_dim 35 | 36 | Returns 37 | ------- 38 | B : array-like, shape (n_samples, n_neighbors) 39 | 40 | Notes 41 | ----- 42 | See developers note for more information. 
43 | """ 44 | X = check_array(X, dtype=FLOAT_DTYPES) 45 | Z = check_array(Z, dtype=FLOAT_DTYPES, allow_nd=True) 46 | 47 | n_samples, n_neighbors = X.shape[0], Z.shape[1] 48 | B = np.empty((n_samples, n_neighbors), dtype=X.dtype) 49 | v = np.ones(n_neighbors, dtype=X.dtype) 50 | 51 | # this might raise a LinalgError if G is singular and has trace 52 | # zero 53 | for i, A in enumerate(Z.transpose(0, 2, 1)): 54 | C = A.T - X[i] # broadcasting 55 | G = np.dot(C, C.T) 56 | trace = np.trace(G) 57 | if trace > 0: 58 | R = reg * trace 59 | else: 60 | R = reg 61 | G.flat[::Z.shape[1] + 1] += R 62 | w = solve(G, v, sym_pos=True) 63 | B[i, :] = w / np.sum(w) 64 | return B 65 | 66 | 67 | def barycenter_kneighbors_graph(X, n_neighbors, reg=1e-3, n_jobs=None): 68 | """Computes the barycenter weighted graph of k-Neighbors for points in X 69 | 70 | Parameters 71 | ---------- 72 | X : {array-like, NearestNeighbors} 73 | Sample data, shape = (n_samples, n_features), in the form of a 74 | numpy array or a NearestNeighbors object. 75 | 76 | n_neighbors : int 77 | Number of neighbors for each sample. 78 | 79 | reg : float, optional 80 | Amount of regularization when solving the least-squares 81 | problem. Only relevant if mode='barycenter'. If None, use the 82 | default. 83 | 84 | n_jobs : int or None, optional (default=None) 85 | The number of parallel jobs to run for neighbors search. 86 | ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context. 87 | ``-1`` means using all processors. See :term:`Glossary ` 88 | for more details. 89 | 90 | Returns 91 | ------- 92 | A : sparse matrix in CSR format, shape = [n_samples, n_samples] 93 | A[i, j] is assigned the weight of edge that connects i to j. 
94 | 95 | See also 96 | -------- 97 | sklearn.neighbors.kneighbors_graph 98 | sklearn.neighbors.radius_neighbors_graph 99 | """ 100 | # print("in bkg:") 101 | # print(X) 102 | knn = NearestNeighbors(n_neighbors + 1, n_jobs=n_jobs).fit(X) 103 | X = knn._fit_X 104 | n_samples = X.shape[0] 105 | ind = knn.kneighbors(X, return_distance=False)[:, 1:] 106 | # ind = np.asarray([[1,2],[1,2],[1,2],[1,2],[1,2],[1,2]]) 107 | data = barycenter_weights(X[:,:-2], X[:,:-2][ind], reg=reg) 108 | indptr = np.arange(0, n_samples * n_neighbors + 1, n_neighbors) 109 | return csr_matrix((data.ravel(), ind.ravel(), indptr), 110 | shape=(n_samples, n_samples)) 111 | 112 | 113 | 114 | def barycenter_kneighbors_graph_ku(X, z, z_ind, n_neighbors, reg=1e-3, n_jobs=None): 115 | # knn = NearestNeighbors(n_neighbors + 1, n_jobs=n_jobs).fit(X) 116 | # X = knn._fit_X 117 | n_samples = X.shape[0] 118 | # ind = knn.kneighbors(X, return_distance=False)[:, 1:] 119 | # ind = np.asarray([[1,2],[1,2],[1,2],[1,2],[1,2],[1,2]]) 120 | data = barycenter_weights(X, z, reg=reg) 121 | return data 122 | indptr = np.arange(0, n_samples * n_neighbors + 1, n_neighbors) 123 | return csr_matrix((data.ravel(), z_ind.ravel(), indptr), 124 | shape=(n_samples, n_samples)) 125 | 126 | 127 | 128 | def null_space(M, k, k_skip=1, eigen_solver='arpack', tol=1E-6, max_iter=100, 129 | random_state=None): 130 | """ 131 | Find the null space of a matrix M. 132 | 133 | Parameters 134 | ---------- 135 | M : {array, matrix, sparse matrix, LinearOperator} 136 | Input covariance matrix: should be symmetric positive semi-definite 137 | 138 | k : integer 139 | Number of eigenvalues/vectors to return 140 | 141 | k_skip : integer, optional 142 | Number of low eigenvalues to skip. 143 | 144 | eigen_solver : string, {'auto', 'arpack', 'dense'} 145 | auto : algorithm will attempt to choose the best method for input data 146 | arpack : use arnoldi iteration in shift-invert mode. 
147 | For this method, M may be a dense matrix, sparse matrix, 148 | or general linear operator. 149 | Warning: ARPACK can be unstable for some problems. It is 150 | best to try several random seeds in order to check results. 151 | dense : use standard dense matrix operations for the eigenvalue 152 | decomposition. For this method, M must be an array 153 | or matrix type. This method should be avoided for 154 | large problems. 155 | 156 | tol : float, optional 157 | Tolerance for 'arpack' method. 158 | Not used if eigen_solver=='dense'. 159 | 160 | max_iter : int 161 | Maximum number of iterations for 'arpack' method. 162 | Not used if eigen_solver=='dense' 163 | 164 | random_state : int, RandomState instance or None, optional (default=None) 165 | If int, random_state is the seed used by the random number generator; 166 | If RandomState instance, random_state is the random number generator; 167 | If None, the random number generator is the RandomState instance used 168 | by `np.random`. Used when ``solver`` == 'arpack'. 169 | 170 | """ 171 | if eigen_solver == 'auto': 172 | if M.shape[0] > 200 and k + k_skip < 10: 173 | eigen_solver = 'arpack' 174 | else: 175 | eigen_solver = 'dense' 176 | 177 | if eigen_solver == 'arpack': 178 | random_state = check_random_state(random_state) 179 | # initialize with [-1,1] as in ARPACK 180 | v0 = random_state.uniform(-1, 1, M.shape[0]) 181 | try: 182 | eigen_values, eigen_vectors = eigsh(M, k + k_skip, sigma=0.0, 183 | tol=tol, maxiter=max_iter, 184 | v0=v0) 185 | except RuntimeError as msg: 186 | raise ValueError("Error in determining null-space with ARPACK. " 187 | "Error message: '%s'. " 188 | "Note that method='arpack' can fail when the " 189 | "weight matrix is singular or otherwise " 190 | "ill-behaved. method='dense' is recommended. " 191 | "See online documentation for more information." 
192 | % msg) 193 | 194 | return eigen_vectors[:, k_skip:], np.sum(eigen_values[k_skip:]) 195 | elif eigen_solver == 'dense': 196 | if hasattr(M, 'toarray'): 197 | M = M.toarray() 198 | eigen_values, eigen_vectors = eigh( 199 | M, eigvals=(k_skip, k + k_skip - 1), overwrite_a=True) 200 | index = np.argsort(np.abs(eigen_values)) 201 | return eigen_vectors[:, index], np.sum(eigen_values) 202 | else: 203 | raise ValueError("Unrecognized eigen_solver '%s'" % eigen_solver) 204 | 205 | 206 | def locally_linear_embedding( 207 | X, n_neighbors, n_components, reg=1e-3, eigen_solver='auto', tol=1e-6, 208 | max_iter=100, method='standard', hessian_tol=1E-4, modified_tol=1E-12, 209 | random_state=None, n_jobs=None): 210 | """Perform a Locally Linear Embedding analysis on the data. 211 | 212 | Read more in the :ref:`User Guide `. 213 | 214 | Parameters 215 | ---------- 216 | X : {array-like, NearestNeighbors} 217 | Sample data, shape = (n_samples, n_features), in the form of a 218 | numpy array or a NearestNeighbors object. 219 | 220 | n_neighbors : integer 221 | number of neighbors to consider for each point. 222 | 223 | n_components : integer 224 | number of coordinates for the manifold. 225 | 226 | reg : float 227 | regularization constant, multiplies the trace of the local covariance 228 | matrix of the distances. 229 | 230 | eigen_solver : string, {'auto', 'arpack', 'dense'} 231 | auto : algorithm will attempt to choose the best method for input data 232 | 233 | arpack : use arnoldi iteration in shift-invert mode. 234 | For this method, M may be a dense matrix, sparse matrix, 235 | or general linear operator. 236 | Warning: ARPACK can be unstable for some problems. It is 237 | best to try several random seeds in order to check results. 238 | 239 | dense : use standard dense matrix operations for the eigenvalue 240 | decomposition. For this method, M must be an array 241 | or matrix type. This method should be avoided for 242 | large problems. 
243 | 244 | tol : float, optional 245 | Tolerance for 'arpack' method 246 | Not used if eigen_solver=='dense'. 247 | 248 | max_iter : integer 249 | maximum number of iterations for the arpack solver. 250 | 251 | method : {'standard', 'hessian', 'modified', 'ltsa'} 252 | standard : use the standard locally linear embedding algorithm. 253 | see reference [1]_ 254 | hessian : use the Hessian eigenmap method. This method requires 255 | n_neighbors > n_components * (1 + (n_components + 1) / 2. 256 | see reference [2]_ 257 | modified : use the modified locally linear embedding algorithm. 258 | see reference [3]_ 259 | ltsa : use local tangent space alignment algorithm 260 | see reference [4]_ 261 | 262 | hessian_tol : float, optional 263 | Tolerance for Hessian eigenmapping method. 264 | Only used if method == 'hessian' 265 | 266 | modified_tol : float, optional 267 | Tolerance for modified LLE method. 268 | Only used if method == 'modified' 269 | 270 | random_state : int, RandomState instance or None, optional (default=None) 271 | If int, random_state is the seed used by the random number generator; 272 | If RandomState instance, random_state is the random number generator; 273 | If None, the random number generator is the RandomState instance used 274 | by `np.random`. Used when ``solver`` == 'arpack'. 275 | 276 | n_jobs : int or None, optional (default=None) 277 | The number of parallel jobs to run for neighbors search. 278 | ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context. 279 | ``-1`` means using all processors. See :term:`Glossary ` 280 | for more details. 281 | 282 | Returns 283 | ------- 284 | Y : array-like, shape [n_samples, n_components] 285 | Embedding vectors. 286 | 287 | squared_error : float 288 | Reconstruction error for the embedding vectors. Equivalent to 289 | ``norm(Y - W Y, 'fro')**2``, where W are the reconstruction weights. 290 | 291 | References 292 | ---------- 293 | 294 | .. [1] `Roweis, S. & Saul, L. 
Nonlinear dimensionality reduction 295 | by locally linear embedding. Science 290:2323 (2000).` 296 | .. [2] `Donoho, D. & Grimes, C. Hessian eigenmaps: Locally 297 | linear embedding techniques for high-dimensional data. 298 | Proc Natl Acad Sci U S A. 100:5591 (2003).` 299 | .. [3] `Zhang, Z. & Wang, J. MLLE: Modified Locally Linear 300 | Embedding Using Multiple Weights.` 301 | http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.70.382 302 | .. [4] `Zhang, Z. & Zha, H. Principal manifolds and nonlinear 303 | dimensionality reduction via tangent space alignment. 304 | Journal of Shanghai Univ. 8:406 (2004)` 305 | """ 306 | if eigen_solver not in ('auto', 'arpack', 'dense'): 307 | raise ValueError("unrecognized eigen_solver '%s'" % eigen_solver) 308 | 309 | if method not in ('standard', 'hessian', 'modified', 'ltsa'): 310 | raise ValueError("unrecognized method '%s'" % method) 311 | 312 | nbrs = NearestNeighbors(n_neighbors=n_neighbors + 1, n_jobs=n_jobs) 313 | nbrs.fit(X) 314 | X = nbrs._fit_X 315 | 316 | N, d_in = X.shape 317 | 318 | if n_components > d_in: 319 | raise ValueError("output dimension must be less than or equal " 320 | "to input dimension") 321 | if n_neighbors >= N: 322 | raise ValueError( 323 | "Expected n_neighbors <= n_samples, " 324 | " but n_samples = %d, n_neighbors = %d" % 325 | (N, n_neighbors) 326 | ) 327 | 328 | if n_neighbors <= 0: 329 | raise ValueError("n_neighbors must be positive") 330 | 331 | M_sparse = (eigen_solver != 'dense') 332 | 333 | if method == 'standard': 334 | W = barycenter_kneighbors_graph( 335 | nbrs, n_neighbors=n_neighbors, reg=reg, n_jobs=n_jobs) 336 | 337 | # we'll compute M = (I-W)'(I-W) 338 | # depending on the solver, we'll do this differently 339 | if M_sparse: 340 | M = eye(*W.shape, format=W.format) - W 341 | M = (M.T * M).tocsr() 342 | else: 343 | M = (W.T * W - W.T - W).toarray() 344 | M.flat[::M.shape[0] + 1] += 1 # W = W - I = W - I 345 | 346 | elif method == 'hessian': 347 | dp = n_components * 
(n_components + 1) // 2 348 | 349 | if n_neighbors <= n_components + dp: 350 | raise ValueError("for method='hessian', n_neighbors must be " 351 | "greater than " 352 | "[n_components * (n_components + 3) / 2]") 353 | 354 | neighbors = nbrs.kneighbors(X, n_neighbors=n_neighbors + 1, 355 | return_distance=False) 356 | neighbors = neighbors[:, 1:] 357 | 358 | Yi = np.empty((n_neighbors, 1 + n_components + dp), dtype=np.float64) 359 | Yi[:, 0] = 1 360 | 361 | M = np.zeros((N, N), dtype=np.float64) 362 | 363 | use_svd = (n_neighbors > d_in) 364 | 365 | for i in range(N): 366 | Gi = X[neighbors[i]] 367 | Gi -= Gi.mean(0) 368 | 369 | # build Hessian estimator 370 | if use_svd: 371 | U = svd(Gi, full_matrices=0)[0] 372 | else: 373 | Ci = np.dot(Gi, Gi.T) 374 | U = eigh(Ci)[1][:, ::-1] 375 | 376 | Yi[:, 1:1 + n_components] = U[:, :n_components] 377 | 378 | j = 1 + n_components 379 | for k in range(n_components): 380 | Yi[:, j:j + n_components - k] = (U[:, k:k + 1] * 381 | U[:, k:n_components]) 382 | j += n_components - k 383 | 384 | Q, R = qr(Yi) 385 | 386 | w = Q[:, n_components + 1:] 387 | S = w.sum(0) 388 | 389 | S[np.where(abs(S) < hessian_tol)] = 1 390 | w /= S 391 | 392 | nbrs_x, nbrs_y = np.meshgrid(neighbors[i], neighbors[i]) 393 | M[nbrs_x, nbrs_y] += np.dot(w, w.T) 394 | 395 | if M_sparse: 396 | M = csr_matrix(M) 397 | 398 | elif method == 'modified': 399 | if n_neighbors < n_components: 400 | raise ValueError("modified LLE requires " 401 | "n_neighbors >= n_components") 402 | 403 | neighbors = nbrs.kneighbors(X, n_neighbors=n_neighbors + 1, 404 | return_distance=False) 405 | neighbors = neighbors[:, 1:] 406 | 407 | # find the eigenvectors and eigenvalues of each local covariance 408 | # matrix. 
We want V[i] to be a [n_neighbors x n_neighbors] matrix, 409 | # where the columns are eigenvectors 410 | V = np.zeros((N, n_neighbors, n_neighbors)) 411 | nev = min(d_in, n_neighbors) 412 | evals = np.zeros([N, nev]) 413 | 414 | # choose the most efficient way to find the eigenvectors 415 | use_svd = (n_neighbors > d_in) 416 | 417 | if use_svd: 418 | for i in range(N): 419 | X_nbrs = X[neighbors[i]] - X[i] 420 | V[i], evals[i], _ = svd(X_nbrs, 421 | full_matrices=True) 422 | evals **= 2 423 | else: 424 | for i in range(N): 425 | X_nbrs = X[neighbors[i]] - X[i] 426 | C_nbrs = np.dot(X_nbrs, X_nbrs.T) 427 | evi, vi = eigh(C_nbrs) 428 | evals[i] = evi[::-1] 429 | V[i] = vi[:, ::-1] 430 | 431 | # find regularized weights: this is like normal LLE. 432 | # because we've already computed the SVD of each covariance matrix, 433 | # it's faster to use this rather than np.linalg.solve 434 | reg = 1E-3 * evals.sum(1) 435 | 436 | tmp = np.dot(V.transpose(0, 2, 1), np.ones(n_neighbors)) 437 | tmp[:, :nev] /= evals + reg[:, None] 438 | tmp[:, nev:] /= reg[:, None] 439 | 440 | w_reg = np.zeros((N, n_neighbors)) 441 | for i in range(N): 442 | w_reg[i] = np.dot(V[i], tmp[i]) 443 | w_reg /= w_reg.sum(1)[:, None] 444 | 445 | # calculate eta: the median of the ratio of small to large eigenvalues 446 | # across the points. This is used to determine s_i, below 447 | rho = evals[:, n_components:].sum(1) / evals[:, :n_components].sum(1) 448 | eta = np.median(rho) 449 | 450 | # find s_i, the size of the "almost null space" for each point: 451 | # this is the size of the largest set of eigenvalues 452 | # such that Sum[v; v in set]/Sum[v; v not in set] < eta 453 | s_range = np.zeros(N, dtype=int) 454 | evals_cumsum = stable_cumsum(evals, 1) 455 | eta_range = evals_cumsum[:, -1:] / evals_cumsum[:, :-1] - 1 456 | for i in range(N): 457 | s_range[i] = np.searchsorted(eta_range[i, ::-1], eta) 458 | s_range += n_neighbors - nev # number of zero eigenvalues 459 | 460 | # Now calculate M. 
461 | # This is the [N x N] matrix whose null space is the desired embedding 462 | M = np.zeros((N, N), dtype=np.float64) 463 | for i in range(N): 464 | s_i = s_range[i] 465 | 466 | # select bottom s_i eigenvectors and calculate alpha 467 | Vi = V[i, :, n_neighbors - s_i:] 468 | alpha_i = np.linalg.norm(Vi.sum(0)) / np.sqrt(s_i) 469 | 470 | # compute Householder matrix which satisfies 471 | # Hi*Vi.T*ones(n_neighbors) = alpha_i*ones(s) 472 | # using prescription from paper 473 | h = np.full(s_i, alpha_i) - np.dot(Vi.T, np.ones(n_neighbors)) 474 | 475 | norm_h = np.linalg.norm(h) 476 | if norm_h < modified_tol: 477 | h *= 0 478 | else: 479 | h /= norm_h 480 | 481 | # Householder matrix is 482 | # >> Hi = np.identity(s_i) - 2*np.outer(h,h) 483 | # Then the weight matrix is 484 | # >> Wi = np.dot(Vi,Hi) + (1-alpha_i) * w_reg[i,:,None] 485 | # We do this much more efficiently: 486 | Wi = (Vi - 2 * np.outer(np.dot(Vi, h), h) + 487 | (1 - alpha_i) * w_reg[i, :, None]) 488 | 489 | # Update M as follows: 490 | # >> W_hat = np.zeros( (N,s_i) ) 491 | # >> W_hat[neighbors[i],:] = Wi 492 | # >> W_hat[i] -= 1 493 | # >> M += np.dot(W_hat,W_hat.T) 494 | # We can do this much more efficiently: 495 | nbrs_x, nbrs_y = np.meshgrid(neighbors[i], neighbors[i]) 496 | M[nbrs_x, nbrs_y] += np.dot(Wi, Wi.T) 497 | Wi_sum1 = Wi.sum(1) 498 | M[i, neighbors[i]] -= Wi_sum1 499 | M[neighbors[i], i] -= Wi_sum1 500 | M[i, i] += s_i 501 | 502 | if M_sparse: 503 | M = csr_matrix(M) 504 | 505 | elif method == 'ltsa': 506 | neighbors = nbrs.kneighbors(X, n_neighbors=n_neighbors + 1, 507 | return_distance=False) 508 | neighbors = neighbors[:, 1:] 509 | 510 | M = np.zeros((N, N)) 511 | 512 | use_svd = (n_neighbors > d_in) 513 | 514 | for i in range(N): 515 | Xi = X[neighbors[i]] 516 | Xi -= Xi.mean(0) 517 | 518 | # compute n_components largest eigenvalues of Xi * Xi^T 519 | if use_svd: 520 | v = svd(Xi, full_matrices=True)[0] 521 | else: 522 | Ci = np.dot(Xi, Xi.T) 523 | v = eigh(Ci)[1][:, ::-1] 524 | 
525 | Gi = np.zeros((n_neighbors, n_components + 1)) 526 | Gi[:, 1:] = v[:, :n_components] 527 | Gi[:, 0] = 1. / np.sqrt(n_neighbors) 528 | 529 | GiGiT = np.dot(Gi, Gi.T) 530 | 531 | nbrs_x, nbrs_y = np.meshgrid(neighbors[i], neighbors[i]) 532 | M[nbrs_x, nbrs_y] -= GiGiT 533 | M[neighbors[i], neighbors[i]] += 1 534 | 535 | return null_space(M, n_components, k_skip=1, eigen_solver=eigen_solver, 536 | tol=tol, max_iter=max_iter, random_state=random_state) 537 | 538 | '''' 539 | class LocallyLinearEmbedding(BaseEstimator, TransformerMixin, 540 | _UnstableArchMixin): 541 | """Locally Linear Embedding 542 | 543 | Read more in the :ref:`User Guide `. 544 | 545 | Parameters 546 | ---------- 547 | n_neighbors : integer 548 | number of neighbors to consider for each point. 549 | 550 | n_components : integer 551 | number of coordinates for the manifold 552 | 553 | reg : float 554 | regularization constant, multiplies the trace of the local covariance 555 | matrix of the distances. 556 | 557 | eigen_solver : string, {'auto', 'arpack', 'dense'} 558 | auto : algorithm will attempt to choose the best method for input data 559 | 560 | arpack : use arnoldi iteration in shift-invert mode. 561 | For this method, M may be a dense matrix, sparse matrix, 562 | or general linear operator. 563 | Warning: ARPACK can be unstable for some problems. It is 564 | best to try several random seeds in order to check results. 565 | 566 | dense : use standard dense matrix operations for the eigenvalue 567 | decomposition. For this method, M must be an array 568 | or matrix type. This method should be avoided for 569 | large problems. 570 | 571 | tol : float, optional 572 | Tolerance for 'arpack' method 573 | Not used if eigen_solver=='dense'. 574 | 575 | max_iter : integer 576 | maximum number of iterations for the arpack solver. 577 | Not used if eigen_solver=='dense'. 
578 | 579 | method : string ('standard', 'hessian', 'modified' or 'ltsa') 580 | standard : use the standard locally linear embedding algorithm. see 581 | reference [1] 582 | hessian : use the Hessian eigenmap method. This method requires 583 | ``n_neighbors > n_components * (1 + (n_components + 1) / 2)`` 584 | see reference [2] 585 | modified : use the modified locally linear embedding algorithm. 586 | see reference [3] 587 | ltsa : use local tangent space alignment algorithm 588 | see reference [4] 589 | 590 | hessian_tol : float, optional 591 | Tolerance for Hessian eigenmapping method. 592 | Only used if ``method == 'hessian'`` 593 | 594 | modified_tol : float, optional 595 | Tolerance for modified LLE method. 596 | Only used if ``method == 'modified'`` 597 | 598 | neighbors_algorithm : string ['auto'|'brute'|'kd_tree'|'ball_tree'] 599 | algorithm to use for nearest neighbors search, 600 | passed to neighbors.NearestNeighbors instance 601 | 602 | random_state : int, RandomState instance or None, optional (default=None) 603 | If int, random_state is the seed used by the random number generator; 604 | If RandomState instance, random_state is the random number generator; 605 | If None, the random number generator is the RandomState instance used 606 | by `np.random`. Used when ``eigen_solver`` == 'arpack'. 607 | 608 | n_jobs : int or None, optional (default=None) 609 | The number of parallel jobs to run. 610 | ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context. 611 | ``-1`` means using all processors. See :term:`Glossary <n_jobs>` 612 | for more details. 613 | 614 | Attributes 615 | ---------- 616 | embedding_ : array-like, shape [n_samples, n_components] 617 | Stores the embedding vectors 618 | 619 | reconstruction_error_ : float 620 | Reconstruction error associated with `embedding_` 621 | 622 | nbrs_ : NearestNeighbors object 623 | Stores nearest neighbors instance, including BallTree or KDTree 624 | if applicable.
625 | 626 | Examples 627 | -------- 628 | >>> from sklearn.datasets import load_digits 629 | >>> from sklearn.manifold import LocallyLinearEmbedding 630 | >>> X, _ = load_digits(return_X_y=True) 631 | >>> X.shape 632 | (1797, 64) 633 | >>> embedding = LocallyLinearEmbedding(n_components=2) 634 | >>> X_transformed = embedding.fit_transform(X[:100]) 635 | >>> X_transformed.shape 636 | (100, 2) 637 | 638 | References 639 | ---------- 640 | 641 | .. [1] `Roweis, S. & Saul, L. Nonlinear dimensionality reduction 642 | by locally linear embedding. Science 290:2323 (2000).` 643 | .. [2] `Donoho, D. & Grimes, C. Hessian eigenmaps: Locally 644 | linear embedding techniques for high-dimensional data. 645 | Proc Natl Acad Sci U S A. 100:5591 (2003).` 646 | .. [3] `Zhang, Z. & Wang, J. MLLE: Modified Locally Linear 647 | Embedding Using Multiple Weights.` 648 | http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.70.382 649 | .. [4] `Zhang, Z. & Zha, H. Principal manifolds and nonlinear 650 | dimensionality reduction via tangent space alignment. 651 | Journal of Shanghai Univ. 
8:406 (2004)` 652 | """ 653 | 654 | def __init__(self, n_neighbors=5, n_components=2, reg=1E-3, 655 | eigen_solver='auto', tol=1E-6, max_iter=100, 656 | method='standard', hessian_tol=1E-4, modified_tol=1E-12, 657 | neighbors_algorithm='auto', random_state=None, n_jobs=None): 658 | self.n_neighbors = n_neighbors 659 | self.n_components = n_components 660 | self.reg = reg 661 | self.eigen_solver = eigen_solver 662 | self.tol = tol 663 | self.max_iter = max_iter 664 | self.method = method 665 | self.hessian_tol = hessian_tol 666 | self.modified_tol = modified_tol 667 | self.random_state = random_state 668 | self.neighbors_algorithm = neighbors_algorithm 669 | self.n_jobs = n_jobs 670 | 671 | def _fit_transform(self, X): 672 | self.nbrs_ = NearestNeighbors(self.n_neighbors, 673 | algorithm=self.neighbors_algorithm, 674 | n_jobs=self.n_jobs) 675 | 676 | random_state = check_random_state(self.random_state) 677 | X = check_array(X, dtype=float) 678 | self.nbrs_.fit(X) 679 | self.embedding_, self.reconstruction_error_ = \ 680 | locally_linear_embedding( 681 | self.nbrs_, self.n_neighbors, self.n_components, 682 | eigen_solver=self.eigen_solver, tol=self.tol, 683 | max_iter=self.max_iter, method=self.method, 684 | hessian_tol=self.hessian_tol, modified_tol=self.modified_tol, 685 | random_state=random_state, reg=self.reg, n_jobs=self.n_jobs) 686 | 687 | def fit(self, X, y=None): 688 | """Compute the embedding vectors for data X 689 | 690 | Parameters 691 | ---------- 692 | X : array-like of shape [n_samples, n_features] 693 | training set. 694 | 695 | y : Ignored 696 | 697 | Returns 698 | ------- 699 | self : returns an instance of self. 700 | """ 701 | self._fit_transform(X) 702 | return self 703 | 704 | def fit_transform(self, X, y=None): 705 | """Compute the embedding vectors for data X and transform X. 706 | 707 | Parameters 708 | ---------- 709 | X : array-like of shape [n_samples, n_features] 710 | training set. 
711 | 712 | y : Ignored 713 | 714 | Returns 715 | ------- 716 | X_new : array-like, shape (n_samples, n_components) 717 | """ 718 | self._fit_transform(X) 719 | return self.embedding_ 720 | 721 | def transform(self, X): 722 | """ 723 | Transform new points into embedding space. 724 | 725 | Parameters 726 | ---------- 727 | X : array-like, shape = [n_samples, n_features] 728 | 729 | Returns 730 | ------- 731 | X_new : array, shape = [n_samples, n_components] 732 | 733 | Notes 734 | ----- 735 | Because of scaling performed by this method, it is discouraged to use 736 | it together with methods that are not scale-invariant (like SVMs) 737 | """ 738 | check_is_fitted(self, "nbrs_") 739 | 740 | X = check_array(X) 741 | ind = self.nbrs_.kneighbors(X, n_neighbors=self.n_neighbors, 742 | return_distance=False) 743 | weights = barycenter_weights(X, self.nbrs_._fit_X[ind], 744 | reg=self.reg) 745 | X_new = np.empty((X.shape[0], self.n_components)) 746 | for i in range(X.shape[0]): 747 | X_new[i] = np.dot(self.embedding_[ind[i]].T, weights[i]) 748 | return X_new 749 | ''' --------------------------------------------------------------------------------