├── .gitattributes ├── .gitignore ├── README.md ├── bayesian_matting.py ├── gandalf.png ├── gandalfAlpha.png ├── gandalfTrimap.png └── orchard_bouman_clust.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | 4 | # Custom for Visual Studio 5 | *.cs diff=csharp 6 | 7 | # Standard to msysgit 8 | *.doc diff=astextplain 9 | *.DOC diff=astextplain 10 | *.docx diff=astextplain 11 | *.DOCX diff=astextplain 12 | *.dot diff=astextplain 13 | *.DOT diff=astextplain 14 | *.pdf diff=astextplain 15 | *.PDF diff=astextplain 16 | *.rtf diff=astextplain 17 | *.RTF diff=astextplain 18 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask instance folder 57 | instance/ 58 | 59 | # Scrapy stuff: 60 | .scrapy 61 | 62 | # Sphinx documentation 63 | docs/_build/ 64 | 65 | # PyBuilder 66 | target/ 67 | 68 | # IPython Notebook 69 | .ipynb_checkpoints 70 | 71 | # pyenv 72 | .python-version 73 | 74 | # celery beat schedule file 75 | celerybeat-schedule 76 | 77 | # dotenv 78 | .env 79 | 80 | # virtualenv 81 | venv/ 82 | ENV/ 83 | 84 | # Spyder project settings 85 | .spyderproject 86 | 87 | # Rope project settings 88 | .ropeproject 89 | 90 | # ========================= 91 | # Operating System Files 92 | # ========================= 93 | 94 | # OSX 95 | # ========================= 96 | 97 | .DS_Store 98 | .AppleDouble 99 | .LSOverride 100 | 101 | # Thumbnails 102 | ._* 103 | 104 | # Files that might appear in the root of a volume 105 | .DocumentRevisions-V100 106 | .fseventsd 107 | .Spotlight-V100 108 | .TemporaryItems 109 | .Trashes 110 | .VolumeIcon.icns 111 | 112 | # Directories potentially created on remote AFP share 113 | .AppleDB 114 | .AppleDesktop 115 | Network Trash Folder 116 | Temporary Items 117 | .apdisk 118 | 119 | # Windows 120 | # ========================= 121 | 122 | # Windows image file caches 123 | Thumbs.db 124 | ehthumbs.db 125 | 126 | # Folder config file 127 | Desktop.ini 128 | 129 | # Recycle Bin used on file shares 130 | $RECYCLE.BIN/ 131 | 132 | # Windows Installer files 133 | *.cab 134 | *.msi 135 | *.msm 136 | *.msp 137 | 138 | # Windows shortcuts 139 | *.lnk 140 | 141 | # vscode 142 | .vscode 143 | -------------------------------------------------------------------------------- 
/README.md: -------------------------------------------------------------------------------- 1 | # bayesian-matting 2 | Python implementation of the paper: Yung-Yu Chuang, Brian Curless, David H. Salesin, and Richard Szeliski. A Bayesian Approach to Digital Matting. In Proceedings of IEEE Computer Vision and Pattern Recognition (CVPR 2001), Vol. II, 264-271, December 2001 3 | 4 | 5 | ### Requirements 6 | - python 3.5+ (Though it should run on 2.7 with some minor tweaks) 7 | - scipy 8 | - numpy 9 | - numba > 0.30.1 (Not necessary, but does give a 5x speedup) 10 | - matplotlib 11 | - opencv 12 | - sys (standard library) 13 | - pathlib (standard library) 14 | - argparse (standard library) 15 | 16 | ### Running the demo 17 | - 'python bayesian_matting.py gandalf.png gandalfTrimap.png' 18 | - sigma (σ): fall-off (standard deviation) of the Gaussian weighting applied to the local window 19 | - N: size of the window from which the local fg/bg clusters are constructed 20 | - minN: minimum number of foreground and background samples (known or already-solved pixels) required in the local window before a pixel is solved 21 | - minN_reduction: maximum number of times minN may be reduced by 1 when no remaining pixel can be solved (avoids an infinite loop). May reduce accuracy 22 | 23 | 24 | ### Results 25 | Original image 26 | Trimap image 27 | Result 28 | 29 | 30 | 31 | ### More Information 32 | 33 | For more information see the original project website http://grail.cs.washington.edu/projects/digital-matting/image-matting/ 34 | This implementation was mostly adapted from Michael Rubinstein's MATLAB code here, 35 | http://www1.idc.ac.il/toky/CompPhoto-09/Projects/Stud_projects/Miki/index.html 36 | http://people.csail.mit.edu/mrub/code/bayesmat.zip 37 | 38 | ### Disclaimer 39 | 40 | The code is free for academic/research purposes. Use at your own risk and we are not responsible for any loss resulting from this code. Feel free to submit pull requests for bug fixes.
41 | 42 | ### Contact 43 | [Marco Forte](https://marcoforte.github.io/) (fortem@tcd.ie) 44 | 45 | #### Original authors: 46 | [Yung-Yu Chuang](http://www.cs.washington.edu/homes/cyy) 47 | [Brian Curless](http://www.cs.washington.edu/homes/curless) 48 | [David Salesin](http://www.cs.washington.edu/homes/salesin) 49 | [Richard Szeliski](http://www.research.microsoft.com/~szeliski) 50 | -------------------------------------------------------------------------------- /bayesian_matting.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | import numpy as np 4 | import cv2 5 | from numba import jit 6 | from argparse import ArgumentParser 7 | import warnings 8 | 9 | from orchard_bouman_clust import clustFunc 10 | 11 | 12 | def matlab_style_gauss2d(shape=(3, 3), sigma=0.5): 13 | """ 14 | 2D gaussian mask - should give the same result as MATLAB's 15 | fspecial('gaussian',[shape],[sigma]) 16 | """ 17 | m, n = [(ss - 1.) / 2. for ss in shape] 18 | y, x = np.ogrid[-m:m + 1, -n:n + 1] 19 | h = np.exp(-(x * x + y * y) / (2. 
* sigma * sigma)) 20 | h[h < np.finfo(h.dtype).eps * h.max()] = 0 21 | sumh = h.sum() 22 | if sumh != 0: 23 | h /= sumh 24 | return h 25 | 26 | 27 | # returns the surrounding N-rectangular neighborhood of matrix m, centered 28 | # at pixel (x,y), (odd valued N) 29 | @jit(nopython=True, cache=True) 30 | def get_window(m, x, y, N): 31 | h, w, c = m.shape 32 | halfN = N // 2 33 | r = np.zeros((N, N, c)) 34 | xmin = max(0, x - halfN); 35 | xmax = min(w, x + (halfN + 1)) 36 | ymin = max(0, y - halfN); 37 | ymax = min(h, y + (halfN + 1)) 38 | pxmin = halfN - (x - xmin); 39 | pxmax = halfN + (xmax - x) 40 | pymin = halfN - (y - ymin); 41 | pymax = halfN + (ymax - y) 42 | 43 | r[pymin:pymax, pxmin:pxmax] = m[ymin:ymax, xmin:xmax] 44 | return r 45 | 46 | 47 | @jit(nopython=True, cache=True) 48 | def solve(mu_F, Sigma_F, mu_B, Sigma_B, C, sigma_C, alpha_init, maxIter, minLike): 49 | ''' 50 | Solves for F,B and alpha that maximize the sum of log 51 | likelihoods at the given pixel C. 52 | input: 53 | mu_F - means of foreground clusters (for RGB, of size 3x#Fclusters) 54 | Sigma_F - covariances of foreground clusters (for RGB, of size 55 | 3x3x#Fclusters) 56 | mu_B,Sigma_B - same for background clusters 57 | C - observed pixel 58 | alpha_init - initial value for alpha 59 | maxIter - maximal number of iterations 60 | minLike - minimal change in likelihood between consecutive iterations 61 | 62 | returns: 63 | F,B,alpha - estimate of foreground, background and alpha 64 | channel (for RGB, each of size 3x1) 65 | ''' 66 | I = np.eye(3) 67 | FMax = np.zeros(3) 68 | BMax = np.zeros(3) 69 | alphaMax = 0 70 | maxlike = - np.inf 71 | invsgma2 = 1 / sigma_C ** 2 72 | for i in range(mu_F.shape[0]): 73 | mu_Fi = mu_F[i] 74 | invSigma_Fi = np.linalg.inv(Sigma_F[i]) 75 | for j in range(mu_B.shape[0]): 76 | mu_Bj = mu_B[j] 77 | invSigma_Bj = np.linalg.inv(Sigma_B[j]) 78 | 79 | alpha = alpha_init 80 | myiter = 1 81 | lastLike = -1.7977e+308 82 | while True: 83 | # solve for F,B 84 | A11 = 
invSigma_Fi + I * alpha ** 2 * invsgma2 85 | A12 = I * alpha * (1 - alpha) * invsgma2 86 | A22 = invSigma_Bj + I * (1 - alpha) ** 2 * invsgma2 87 | A = np.vstack((np.hstack((A11, A12)), np.hstack((A12, A22)))) 88 | b1 = invSigma_Fi @ mu_Fi + C * (alpha) * invsgma2 89 | b2 = invSigma_Bj @ mu_Bj + C * (1 - alpha) * invsgma2 90 | b = np.atleast_2d(np.concatenate((b1, b2))).T 91 | 92 | X = np.linalg.solve(A, b) 93 | F = np.maximum(0, np.minimum(1, X[0:3])) 94 | B = np.maximum(0, np.minimum(1, X[3:6])) 95 | # solve for alpha 96 | 97 | alpha = np.maximum(0, np.minimum(1, ((np.atleast_2d(C).T - B).T @ (F - B)) / np.sum((F - B) ** 2)))[ 98 | 0, 0] 99 | # # calculate likelihood 100 | L_C = - np.sum((np.atleast_2d(C).T - alpha * F - (1 - alpha) * B) ** 2) * invsgma2 101 | L_F = (- ((F - np.atleast_2d(mu_Fi).T).T @ invSigma_Fi @ (F - np.atleast_2d(mu_Fi).T)) / 2)[0, 0] 102 | L_B = (- ((B - np.atleast_2d(mu_Bj).T).T @ invSigma_Bj @ (B - np.atleast_2d(mu_Bj).T)) / 2)[0, 0] 103 | like = (L_C + L_F + L_B) 104 | # like = 0 105 | 106 | if like > maxlike: 107 | alphaMax = alpha 108 | maxLike = like 109 | FMax = F.ravel() 110 | BMax = B.ravel() 111 | 112 | if myiter >= maxIter or abs(like - lastLike) <= minLike: 113 | break 114 | 115 | lastLike = like 116 | myiter += 1 117 | return FMax, BMax, alphaMax 118 | 119 | 120 | def bayesian_matte(img, trimap, sigma=8, N=25, minN=10, minN_reduction=0): 121 | # check minN_reduction parameter 122 | if minN_reduction >= minN: 123 | raise ValueError("minN_reduction parameter must be less than minN") 124 | 125 | img = img / 255 126 | 127 | h, w, c = img.shape 128 | alpha = np.zeros((h, w)) 129 | 130 | fg_mask = trimap == 255 131 | bg_mask = trimap == 0 132 | unknown_mask = True ^ np.logical_or(fg_mask, bg_mask) 133 | foreground = img * np.repeat(fg_mask[:, :, np.newaxis], 3, axis=2) 134 | background = img * np.repeat(bg_mask[:, :, np.newaxis], 3, axis=2) 135 | 136 | gaussian_weights = matlab_style_gauss2d((N, N), sigma) 137 | gaussian_weights = 
gaussian_weights / np.max(gaussian_weights) 138 | 139 | alpha[fg_mask] = 1 140 | F = np.zeros(img.shape) 141 | B = np.zeros(img.shape) 142 | alphaRes = np.zeros(trimap.shape) 143 | 144 | n = 1 145 | alpha[unknown_mask] = np.nan 146 | nUnknown = np.sum(unknown_mask) 147 | unkreg = unknown_mask 148 | 149 | kernel = np.ones((3, 3)) 150 | while n < nUnknown: 151 | unkreg = cv2.erode(unkreg.astype(np.uint8), kernel, iterations=1) 152 | unkpixels = np.logical_and(np.logical_not(unkreg), unknown_mask) 153 | 154 | Y, X = np.nonzero(unkpixels) 155 | 156 | for i in range(Y.shape[0]): 157 | if n % 100 == 0: 158 | print(n, nUnknown) 159 | y, x = Y[i], X[i] 160 | p = img[y, x] 161 | # Try cluster Fg, Bg in p's known neighborhood 162 | 163 | # take surrounding alpha values 164 | a = get_window(alpha[:, :, np.newaxis], x, y, N)[:, :, 0] 165 | 166 | # Take surrounding foreground pixels 167 | f_pixels = get_window(foreground, x, y, N) 168 | f_weights = (a ** 2 * gaussian_weights).ravel() 169 | 170 | f_pixels = np.reshape(f_pixels, (N * N, 3)) 171 | posInds = np.nan_to_num(f_weights) > 0 172 | f_pixels = f_pixels[posInds, :] 173 | f_weights = f_weights[posInds] 174 | 175 | # Take surrounding background pixels 176 | b_pixels = get_window(background, x, y, N) 177 | b_weights = ((1 - a) ** 2 * gaussian_weights).ravel() 178 | 179 | b_pixels = np.reshape(b_pixels, (N * N, 3)) 180 | posInds = np.nan_to_num(b_weights) > 0 181 | b_pixels = b_pixels[posInds, :] 182 | b_weights = b_weights[posInds] 183 | 184 | # if not enough data, return to it later... 185 | if len(f_weights) < minN or len(b_weights) < minN: 186 | # if end of loop has been reached and n is still < nUnknown, infinite loop will occur 187 | if i == Y.shape[0] and n < nUnknown: 188 | # adjust minN, break loop, and retry. If that still fails, terminate the program 189 | if minN > (minN - minN_reduction): 190 | minN -= 1 191 | n = 1 192 | warnings.warn(message="Infinte loop encountered. 
Reducing minN by 1 and retrying.", 193 | category=RuntimeWarning) 194 | break 195 | else: 196 | raise RuntimeError("Terminating infinite loop. Adjust input parameters and retry.") 197 | continue 198 | 199 | # Partition foreground and background pixels to clusters (in a weighted manner) 200 | mu_f, sigma_f = clustFunc(f_pixels, f_weights) 201 | mu_b, sigma_b = clustFunc(b_pixels, b_weights) 202 | 203 | alpha_init = np.nanmean(a.ravel()) 204 | # Solve for F,B for all cluster pairs 205 | f, b, alphaT = solve(mu_f, sigma_f, mu_b, sigma_b, p, 0.01, alpha_init, 50, 1e-6) 206 | foreground[y, x] = f.ravel() 207 | background[y, x] = b.ravel() 208 | alpha[y, x] = alphaT 209 | unknown_mask[y, x] = 0 210 | n += 1 211 | 212 | return alpha 213 | 214 | 215 | def main(img, trimap, sigma, N, minN, minN_reduction): 216 | img = cv2.imread(str(Path(img)))[:, :, :3] 217 | trimap = cv2.imread(str(Path(trimap)), cv2.IMREAD_GRAYSCALE) 218 | alpha = bayesian_matte(img, trimap, sigma, N, minN, minN_reduction) 219 | # scipy.misc.imsave('gandalfAlpha.png', alpha) 220 | plt.title("Alpha matte") 221 | plt.imshow(alpha, cmap='gray') 222 | plt.show() 223 | 224 | 225 | if __name__ == '__main__': 226 | import matplotlib.pyplot as plt 227 | 228 | # start parser 229 | parser = ArgumentParser() 230 | 231 | # add args 232 | parser.add_argument('image', help="path to image to be segmented") 233 | parser.add_argument('trimap', help="path to trimap of image") 234 | parser.add_argument('-s', '--sigma', default=8, help="variance of gaussian for spatial weighting") 235 | parser.add_argument('-n', '--N', default=25, help="pixel neighborhood size") 236 | parser.add_argument('-mn', '--minN', default=10, help="minimum required foreground and background neighbors for " 237 | "optimization") 238 | parser.add_argument('-red', '--minN_reduction', default=0, help="number of times to reduce minN if an infinite " 239 | "loop is encountered") 240 | 241 | args = parser.parse_args() 242 | # call main with all args 243 | 
main(args.image, args.trimap, args.sigma, args.N, args.minN, args.minN_reduction) 244 | -------------------------------------------------------------------------------- /gandalf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarcoForte/bayesian-matting/8b9164c901157e65f94ecfddb9e902f6d080d288/gandalf.png -------------------------------------------------------------------------------- /gandalfAlpha.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarcoForte/bayesian-matting/8b9164c901157e65f94ecfddb9e902f6d080d288/gandalfAlpha.png -------------------------------------------------------------------------------- /gandalfTrimap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarcoForte/bayesian-matting/8b9164c901157e65f94ecfddb9e902f6d080d288/gandalfTrimap.png -------------------------------------------------------------------------------- /orchard_bouman_clust.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # Implementation of orchard bouman clustering 4 | # Not the cleanest code 5 | # Might be faster if rewritten in all numpy, node object -> array 6 | # May be able to be spedup with numba. 

class Node(object):
    """
    A weighted cluster of samples used by Orchard-Bouman clustering.

    Attributes (all set in __init__):
      w          - per-sample weights (length n)
      X          - samples (n x d)
      left/right - not used by the code in this module
      mu         - weighted mean (length d)
      cov        - weighted covariance (d x d) plus 1e-5 * I for stability
      N          - number of samples
      lmbda      - largest-magnitude eigenvalue of cov
      e          - unit eigenvector of cov belonging to lmbda (split axis)
    """

    def __init__(self, matrix, w):
        W = np.sum(w)
        self.w = w
        self.X = matrix
        self.left = None
        self.right = None
        self.mu = np.einsum('ij,i->j', self.X, w) / W
        diff = self.X - self.mu
        t = diff * np.sqrt(w)[:, np.newaxis]
        self.cov = (t.T @ t) / W + 1e-5 * np.eye(self.X.shape[1])
        self.N = self.X.shape[0]
        # cov is symmetric, so eigh gives real eigenvalues and orthonormal
        # eigenvectors; the eigenvectors are the COLUMNS of the returned matrix.
        eigvals, eigvecs = np.linalg.eigh(self.cov)
        k = np.argmax(np.abs(eigvals))
        self.lmbda = np.abs(eigvals[k])
        self.e = eigvecs[:, k]


# S is measurements vector - dim nxd
# w is weights vector - dim n
def clustFunc(S, w, minVar=0.05):
    """
    Cluster weighted samples by repeated Orchard-Bouman splitting.

    The cluster with the largest principal variance is split until no
    cluster's largest covariance eigenvalue exceeds ``minVar``.

    :param S: n x d sample matrix.
    :param w: length-n array of positive sample weights.
    :param minVar: splitting stops once every cluster's lmbda is <= minVar.
    :return: (mu, sigma) with shapes (k, d) and (k, d, d) for k clusters.
    """
    nodes = [Node(S, w)]

    while max(node.lmbda for node in nodes) > minVar:
        nodes = split(nodes)

    mu = np.array([node.mu for node in nodes])
    sigma = np.array([node.cov for node in nodes])
    return mu, sigma


def split(nodes):
    """
    Split the node with the largest lmbda into two children.

    Samples are divided by the plane through the node's mean perpendicular to
    its principal axis ``e``. The parent is removed from ``nodes`` in place,
    the two children are appended, and the same list is returned.
    """
    idx_max = max(range(len(nodes)), key=lambda i: nodes[i].lmbda)
    C_i = nodes.pop(idx_max)
    idx = C_i.X @ C_i.e <= np.dot(C_i.mu, C_i.e)
    nodes.append(Node(C_i.X[idx], C_i.w[idx]))
    nodes.append(Node(C_i.X[~idx], C_i.w[~idx]))
    return nodes