├── .gitattributes
├── .gitignore
├── README.md
├── bayesian_matting.py
├── gandalf.png
├── gandalfAlpha.png
├── gandalfTrimap.png
└── orchard_bouman_clust.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
4 | # Custom for Visual Studio
5 | *.cs diff=csharp
6 |
7 | # Standard to msysgit
8 | *.doc diff=astextplain
9 | *.DOC diff=astextplain
10 | *.docx diff=astextplain
11 | *.DOCX diff=astextplain
12 | *.dot diff=astextplain
13 | *.DOT diff=astextplain
14 | *.pdf diff=astextplain
15 | *.PDF diff=astextplain
16 | *.rtf diff=astextplain
17 | *.RTF diff=astextplain
18 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 |
27 | # PyInstaller
28 | # Usually these files are written by a python script from a template
29 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
30 | *.manifest
31 | *.spec
32 |
33 | # Installer logs
34 | pip-log.txt
35 | pip-delete-this-directory.txt
36 |
37 | # Unit test / coverage reports
38 | htmlcov/
39 | .tox/
40 | .coverage
41 | .coverage.*
42 | .cache
43 | nosetests.xml
44 | coverage.xml
45 | *,cover
46 | .hypothesis/
47 |
48 | # Translations
49 | *.mo
50 | *.pot
51 |
52 | # Django stuff:
53 | *.log
54 | local_settings.py
55 |
56 | # Flask instance folder
57 | instance/
58 |
59 | # Scrapy stuff:
60 | .scrapy
61 |
62 | # Sphinx documentation
63 | docs/_build/
64 |
65 | # PyBuilder
66 | target/
67 |
68 | # IPython Notebook
69 | .ipynb_checkpoints
70 |
71 | # pyenv
72 | .python-version
73 |
74 | # celery beat schedule file
75 | celerybeat-schedule
76 |
77 | # dotenv
78 | .env
79 |
80 | # virtualenv
81 | venv/
82 | ENV/
83 |
84 | # Spyder project settings
85 | .spyderproject
86 |
87 | # Rope project settings
88 | .ropeproject
89 |
90 | # =========================
91 | # Operating System Files
92 | # =========================
93 |
94 | # OSX
95 | # =========================
96 |
97 | .DS_Store
98 | .AppleDouble
99 | .LSOverride
100 |
101 | # Thumbnails
102 | ._*
103 |
104 | # Files that might appear in the root of a volume
105 | .DocumentRevisions-V100
106 | .fseventsd
107 | .Spotlight-V100
108 | .TemporaryItems
109 | .Trashes
110 | .VolumeIcon.icns
111 |
112 | # Directories potentially created on remote AFP share
113 | .AppleDB
114 | .AppleDesktop
115 | Network Trash Folder
116 | Temporary Items
117 | .apdisk
118 |
119 | # Windows
120 | # =========================
121 |
122 | # Windows image file caches
123 | Thumbs.db
124 | ehthumbs.db
125 |
126 | # Folder config file
127 | Desktop.ini
128 |
129 | # Recycle Bin used on file shares
130 | $RECYCLE.BIN/
131 |
132 | # Windows Installer files
133 | *.cab
134 | *.msi
135 | *.msm
136 | *.msp
137 |
138 | # Windows shortcuts
139 | *.lnk
140 |
141 | # vscode
142 | .vscode
143 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # bayesian-matting
2 | Python implementation of Yung-Yu Chuang, Brian Curless, David H. Salesin, and Richard Szeliski. A Bayesian Approach to Digital Matting. In Proceedings of IEEE Computer Vision and Pattern Recognition (CVPR 2001), Vol. II, 264-271, December 2001
3 |
4 |
5 | ### Requirements
6 | - python 3.5+ (Though it should run on 2.7 with some minor tweaks)
7 | - scipy
8 | - numpy
9 | - numba > 0.30.1 (Not necessary, but does give a 5x speedup)
10 | - matplotlib
11 | - opencv
12 | - sys
13 | - pathlib
14 | - argparse
15 |
16 | ### Running the demo
17 | - `python bayesian_matting.py gandalf.png gandalfTrimap.png`
18 | - `--sigma` (σ): fall-off of the Gaussian weighting applied to the local window
19 | - `--N`: size of the window used to construct the local fg/bg clusters
20 | - `--minN`: minimum number of known pixels in the local window required to proceed
21 | - `--minN_reduction`: number of times minN may be reduced if an infinite loop is encountered. May reduce accuracy
22 |
23 |
24 | ### Results
25 |
26 |
27 |
28 |
29 |
30 |
31 | ### More Information
32 |
33 | For more information see the original project website http://grail.cs.washington.edu/projects/digital-matting/image-matting/
34 | This implementation was mostly adapted from Michael Rubinstein's MATLAB code here,
35 | http://www1.idc.ac.il/toky/CompPhoto-09/Projects/Stud_projects/Miki/index.html
36 | http://people.csail.mit.edu/mrub/code/bayesmat.zip
37 |
38 | ### Disclaimer
39 |
40 | The code is free for academic/research purposes. Use at your own risk; we are not responsible for any loss resulting from this code. Feel free to submit pull requests for bug fixes.
41 |
42 | ### Contact
43 | [Marco Forte](https://marcoforte.github.io/) (fortem@tcd.ie)
44 |
45 | #### Original authors:
46 | [Yung-Yu Chuang](http://www.cs.washington.edu/homes/cyy)
47 | [Brian Curless](http://www.cs.washington.edu/homes/curless)
48 | [David Salesin](http://www.cs.washington.edu/homes/salesin)
49 | [Richard Szeliski](http://www.research.microsoft.com/~szeliski)
50 |
--------------------------------------------------------------------------------
/bayesian_matting.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from pathlib import Path
3 | import numpy as np
4 | import cv2
5 | from numba import jit
6 | from argparse import ArgumentParser
7 | import warnings
8 |
9 | from orchard_bouman_clust import clustFunc
10 |
11 |
def matlab_style_gauss2d(shape=(3, 3), sigma=0.5):
    """
    Build a normalised 2D gaussian kernel, intended to reproduce MATLAB's
    fspecial('gaussian',[shape],[sigma]).

    shape - (rows, cols) of the kernel
    sigma - standard deviation of the gaussian

    Returns a rows x cols array. Entries below machine epsilon times the
    peak value are zeroed, then the kernel is scaled to sum to 1 (the
    scaling is skipped when the sum is 0).
    """
    rows, cols = shape
    half_rows = (rows - 1.) / 2.
    half_cols = (cols - 1.) / 2.

    # open grids of offsets from the kernel centre
    dy, dx = np.ogrid[-half_rows:half_rows + 1, -half_cols:half_cols + 1]
    dist2 = dx * dx + dy * dy
    kernel = np.exp(-dist2 / (2. * sigma * sigma))

    # drop numerically negligible tails, as fspecial does
    cutoff = np.finfo(kernel.dtype).eps * kernel.max()
    kernel = np.where(kernel < cutoff, 0, kernel)

    total = kernel.sum()
    if total == 0:
        return kernel
    return kernel / total
25 |
26 |
@jit(nopython=True, cache=True)
def get_window(m, x, y, N):
    """
    Return the N x N neighbourhood of the 3D array m centred at pixel
    (x, y), where x is the column and y is the row (odd valued N).

    The result has shape (N, N, channels). Positions of the window that
    fall outside m are left at zero.
    """
    rows, cols, channels = m.shape
    radius = N // 2
    window = np.zeros((N, N, channels))

    # source rectangle, clipped to the bounds of m
    top = max(0, y - radius)
    bottom = min(rows, y + (radius + 1))
    left = max(0, x - radius)
    right = min(cols, x + (radius + 1))

    # matching destination rectangle inside the window
    dst_top = radius - (y - top)
    dst_bottom = radius + (bottom - y)
    dst_left = radius - (x - left)
    dst_right = radius + (right - x)

    window[dst_top:dst_bottom, dst_left:dst_right] = m[top:bottom, left:right]
    return window
45 |
46 |
@jit(nopython=True, cache=True)
def solve(mu_F, Sigma_F, mu_B, Sigma_B, C, sigma_C, alpha_init, maxIter, minLike):
    '''
    Solves for F,B and alpha that maximize the sum of log
    likelihoods at the given pixel C.
    input:
    mu_F - means of foreground clusters (for RGB, of size #Fclusters x 3)
    Sigma_F - covariances of foreground clusters (for RGB, of size
    #Fclusters x 3 x 3)
    mu_B,Sigma_B - same for background clusters
    C - observed pixel (1D array of length 3)
    sigma_C - standard deviation assumed for the observed pixel; its
    inverse square weights the data term
    alpha_init - initial value for alpha
    maxIter - maximal number of iterations
    minLike - minimal change in likelihood between consecutive iterations

    returns:
    F,B,alpha - estimate of foreground and background colours (1D arrays
    of length 3) and the scalar alpha, taken from the iterate with the
    highest likelihood over all foreground/background cluster pairs
    '''
    I = np.eye(3)
    FMax = np.zeros(3)
    BMax = np.zeros(3)
    alphaMax = 0.0
    maxlike = - np.inf
    invsgma2 = 1 / sigma_C ** 2
    # column-vector view of the observed pixel, reused in every iteration
    C_col = np.atleast_2d(C).T
    for i in range(mu_F.shape[0]):
        mu_Fi = mu_F[i]
        mu_Fi_col = np.atleast_2d(mu_Fi).T
        invSigma_Fi = np.linalg.inv(Sigma_F[i])
        for j in range(mu_B.shape[0]):
            mu_Bj = mu_B[j]
            mu_Bj_col = np.atleast_2d(mu_Bj).T
            invSigma_Bj = np.linalg.inv(Sigma_B[j])

            alpha = alpha_init
            myiter = 1
            lastLike = -1.7977e+308
            while True:
                # solve for F,B with alpha held fixed (6x6 linear system)
                A11 = invSigma_Fi + I * alpha ** 2 * invsgma2
                A12 = I * alpha * (1 - alpha) * invsgma2
                A22 = invSigma_Bj + I * (1 - alpha) ** 2 * invsgma2
                A = np.vstack((np.hstack((A11, A12)), np.hstack((A12, A22))))
                b1 = invSigma_Fi @ mu_Fi + C * (alpha) * invsgma2
                b2 = invSigma_Bj @ mu_Bj + C * (1 - alpha) * invsgma2
                b = np.atleast_2d(np.concatenate((b1, b2))).T

                X = np.linalg.solve(A, b)
                F = np.maximum(0, np.minimum(1, X[0:3]))
                B = np.maximum(0, np.minimum(1, X[3:6]))

                # solve for alpha with F,B held fixed. When F == B the
                # projection is undefined (0/0), so keep the current alpha
                # instead of feeding NaN into the next linear solve.
                FB = F - B
                denom = np.sum(FB ** 2)
                if denom > 0:
                    alpha = np.maximum(0, np.minimum(1, ((C_col - B).T @ FB) / denom))[0, 0]

                # calculate likelihood
                L_C = - np.sum((C_col - alpha * F - (1 - alpha) * B) ** 2) * invsgma2
                L_F = (- ((F - mu_Fi_col).T @ invSigma_Fi @ (F - mu_Fi_col)) / 2)[0, 0]
                L_B = (- ((B - mu_Bj_col).T @ invSigma_Bj @ (B - mu_Bj_col)) / 2)[0, 0]
                like = (L_C + L_F + L_B)

                # keep the best estimate seen so far; the running maximum
                # must be written back to the same variable it is compared to
                if like > maxlike:
                    alphaMax = alpha
                    maxlike = like
                    FMax = F.ravel()
                    BMax = B.ravel()

                if myiter >= maxIter or abs(like - lastLike) <= minLike:
                    break

                lastLike = like
                myiter += 1
    return FMax, BMax, alphaMax
118 |
119 |
def bayesian_matte(img, trimap, sigma=8, N=25, minN=10, minN_reduction=0):
    """
    Estimate an alpha matte for img from a trimap with Bayesian matting.

    Unknown pixels are processed from the border of the unknown region
    inwards: each pass erodes the unknown region once more and tries every
    still-unknown pixel outside the eroded region. Solved pixels feed their
    foreground/background/alpha estimates into later neighbourhoods.

    img - h x w x 3 colour image; it is divided by 255 before use
    trimap - h x w array; 255 marks foreground, 0 marks background and any
        other value marks an unknown pixel
    sigma - sigma of the gaussian used for spatial weighting in the window
    N - side length of the square neighbourhood window (must be odd)
    minN - minimum number of positively weighted foreground samples and of
        background samples a window needs before the pixel is solved
    minN_reduction - how many times minN may be lowered by 1 when no
        remaining pixel can be solved

    returns:
    alpha - h x w float array; 1 on foreground, 0 on background and the
        estimated value on pixels that were unknown

    raises:
    ValueError - minN_reduction >= minN, N is even, or trimap and img
        differ in height/width
    RuntimeError - no further pixel can be solved and no reduction of minN
        is left
    """
    # check minN_reduction parameter
    if minN_reduction >= minN:
        raise ValueError("minN_reduction parameter must be less than minN")
    # get_window needs a centre pixel, which only exists for odd N
    if N % 2 == 0:
        raise ValueError("N must be odd")
    if trimap.shape != img.shape[:2]:
        raise ValueError("trimap must have the same height and width as img")

    img = img / 255

    fg_mask = trimap == 255
    bg_mask = trimap == 0
    unknown_mask = np.logical_not(np.logical_or(fg_mask, bg_mask))
    # broadcast the 2D masks over the colour channels
    foreground = img * fg_mask[:, :, np.newaxis]
    background = img * bg_mask[:, :, np.newaxis]

    gaussian_weights = matlab_style_gauss2d((N, N), sigma)
    gaussian_weights = gaussian_weights / np.max(gaussian_weights)

    alpha = np.zeros(img.shape[:2])
    alpha[fg_mask] = 1
    # NaN marks "not yet estimated"; nan_to_num/nanmean below rely on it
    alpha[unknown_mask] = np.nan

    nUnknown = np.sum(unknown_mask)
    n_solved = 0
    reductions_left = minN_reduction
    unkreg = unknown_mask.astype(np.uint8)
    kernel = np.ones((3, 3))

    while n_solved < nUnknown:
        eroded = cv2.erode(unkreg, kernel, iterations=1)
        # once erosion stops changing the region, no new pixels will ever
        # join the set of candidates
        frontier_exhausted = np.array_equal(eroded, unkreg)
        unkreg = eroded
        unkpixels = np.logical_and(np.logical_not(unkreg), unknown_mask)

        Y, X = np.nonzero(unkpixels)
        solved_before = n_solved

        for y, x in zip(Y, X):
            p = img[y, x]
            # Try cluster Fg, Bg in p's known neighborhood

            # take surrounding alpha values
            a = get_window(alpha[:, :, np.newaxis], x, y, N)[:, :, 0]

            # Take surrounding foreground pixels
            f_pixels = get_window(foreground, x, y, N)
            f_weights = (a ** 2 * gaussian_weights).ravel()

            f_pixels = np.reshape(f_pixels, (N * N, 3))
            posInds = np.nan_to_num(f_weights) > 0
            f_pixels = f_pixels[posInds, :]
            f_weights = f_weights[posInds]

            # Take surrounding background pixels
            b_pixels = get_window(background, x, y, N)
            b_weights = ((1 - a) ** 2 * gaussian_weights).ravel()

            b_pixels = np.reshape(b_pixels, (N * N, 3))
            posInds = np.nan_to_num(b_weights) > 0
            b_pixels = b_pixels[posInds, :]
            b_weights = b_weights[posInds]

            # if not enough data, return to it later...
            if len(f_weights) < minN or len(b_weights) < minN:
                continue

            # Partition foreground and background pixels to clusters (in a weighted manner)
            mu_f, sigma_f = clustFunc(f_pixels, f_weights)
            mu_b, sigma_b = clustFunc(b_pixels, b_weights)

            alpha_init = np.nanmean(a.ravel())
            # Solve for F,B for all cluster pairs
            f, b, alphaT = solve(mu_f, sigma_f, mu_b, sigma_b, p, 0.01, alpha_init, 50, 1e-6)
            foreground[y, x] = f.ravel()
            background[y, x] = b.ravel()
            alpha[y, x] = alphaT
            unknown_mask[y, x] = 0
            n_solved += 1
            if n_solved % 100 == 0:
                print(n_solved, nUnknown)

        # A pass that solved nothing while the candidate set can no longer
        # grow would repeat forever: relax minN if allowed, otherwise stop.
        if n_solved == solved_before and frontier_exhausted:
            if reductions_left > 0:
                minN -= 1
                reductions_left -= 1
                warnings.warn(message="Infinte loop encountered. Reducing minN by 1 and retrying.",
                              category=RuntimeWarning)
            else:
                raise RuntimeError("Terminating infinite loop. Adjust input parameters and retry.")

    return alpha
213 |
214 |
def main(img, trimap, sigma, N, minN, minN_reduction):
    """
    Load an image and its trimap from disk, compute the alpha matte with
    bayesian_matte and display it in a matplotlib window.

    img, trimap - paths to the colour image and to its trimap
    sigma, N, minN, minN_reduction - forwarded unchanged to bayesian_matte

    raises:
    OSError - either file could not be read by cv2.imread
    """
    # Imported here so main() also works when this module is imported
    # rather than run as a script.
    import matplotlib.pyplot as plt

    img_path = str(Path(img))
    trimap_path = str(Path(trimap))

    # cv2.imread signals failure by returning None instead of raising
    image = cv2.imread(img_path)
    if image is None:
        raise OSError("could not read image: " + img_path)
    tri = cv2.imread(trimap_path, cv2.IMREAD_GRAYSCALE)
    if tri is None:
        raise OSError("could not read trimap: " + trimap_path)

    # keep only the first three channels
    alpha = bayesian_matte(image[:, :, :3], tri, sigma, N, minN, minN_reduction)

    plt.title("Alpha matte")
    plt.imshow(alpha, cmap='gray')
    plt.show()
223 |
224 |
if __name__ == '__main__':
    # kept at module scope for code that expects a global `plt`
    import matplotlib.pyplot as plt

    # start parser
    parser = ArgumentParser()

    # add args
    # Numeric options declare a type so that values given on the command
    # line are converted; without it argparse passes them on as strings.
    parser.add_argument('image', help="path to image to be segmented")
    parser.add_argument('trimap', help="path to trimap of image")
    parser.add_argument('-s', '--sigma', type=float, default=8,
                        help="variance of gaussian for spatial weighting")
    parser.add_argument('-n', '--N', type=int, default=25, help="pixel neighborhood size")
    parser.add_argument('-mn', '--minN', type=int, default=10,
                        help="minimum required foreground and background neighbors for "
                             "optimization")
    parser.add_argument('-red', '--minN_reduction', type=int, default=0,
                        help="number of times to reduce minN if an infinite "
                             "loop is encountered")

    args = parser.parse_args()
    # call main with all args
    main(args.image, args.trimap, args.sigma, args.N, args.minN, args.minN_reduction)
244 |
--------------------------------------------------------------------------------
/gandalf.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarcoForte/bayesian-matting/8b9164c901157e65f94ecfddb9e902f6d080d288/gandalf.png
--------------------------------------------------------------------------------
/gandalfAlpha.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarcoForte/bayesian-matting/8b9164c901157e65f94ecfddb9e902f6d080d288/gandalfAlpha.png
--------------------------------------------------------------------------------
/gandalfTrimap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarcoForte/bayesian-matting/8b9164c901157e65f94ecfddb9e902f6d080d288/gandalfTrimap.png
--------------------------------------------------------------------------------
/orchard_bouman_clust.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | # Implementation of orchard bouman clustering
4 | # Not the cleanest code
5 | # Might be faster if rewritten in all numpy, node object -> array
6 | # May be possible to speed up with numba.
7 |
class Node(object):
    """
    One cluster of the Orchard-Bouman binary splitting tree.

    Attributes set by the constructor:
    X, w  - the samples (n x d) and their weights (n), stored as given
    N     - number of samples
    mu    - weighted mean of the samples (d)
    cov   - weighted covariance (d x d) with 1e-5 added on the diagonal
    lmbda - largest eigenvalue magnitude of cov
    e     - eigenvector of cov belonging to lmbda (d)
    left, right - initialised to None
    """

    def __init__(self, matrix, w):
        W = np.sum(w)
        self.w = w
        self.X = matrix
        self.left = None
        self.right = None
        self.mu = np.einsum('ij,i->j', self.X, w) / W
        # broadcasting subtracts the mean from every row
        diff = self.X - self.mu
        t = np.einsum('ij,i->ij', diff, np.sqrt(w))
        # small diagonal term keeps the covariance invertible
        self.cov = (t.T @ t) / W + 1e-5 * np.eye(self.X.shape[1])
        self.N = self.X.shape[0]
        # cov is symmetric, so eigh applies and yields real results.
        # Eigenvectors are the COLUMNS of the returned matrix.
        evals, evecs = np.linalg.eigh(self.cov)
        k = np.argmax(np.abs(evals))
        self.lmbda = np.abs(evals[k])
        self.e = evecs[:, k]
24 |
25 |
def clustFunc(S, w, minVar=0.05):
    """
    Cluster weighted samples by repeated binary splitting.

    S is the measurements array - dim nxd
    w is the weights vector - dim n
    minVar - splitting continues while any cluster's lmbda exceeds this

    Returns (mu, sigma): the stacked cluster means and the stacked cluster
    covariances, one entry per final cluster.
    """
    clusters = [Node(S, w)]

    # keep splitting until every cluster is tight enough
    while max(node.lmbda for node in clusters) > minVar:
        clusters = split(clusters)

    means = np.array([node.mu for node in clusters])
    covariances = np.array([node.cov for node in clusters])
    return means, covariances
41 |
42 |
def split(nodes):
    """
    Split the node with the largest lmbda in two and return the list.

    Samples whose projection onto the node's direction e is at most the
    projection of the node's mean go to the first child, the rest to the
    second. The list is modified in place: the parent is removed and both
    children are appended at the end.
    """
    widest = max(range(len(nodes)), key=lambda k: nodes[k].lmbda)
    parent = nodes[widest]

    threshold = np.dot(parent.mu, parent.e)
    lower = parent.X @ parent.e <= threshold
    upper = np.logical_not(lower)

    child_a = Node(parent.X[lower], parent.w[lower])
    child_b = Node(parent.X[upper], parent.w[upper])

    del nodes[widest]
    nodes.extend((child_a, child_b))
    return nodes
53 |
--------------------------------------------------------------------------------