├── README.md ├── inflation_foreground_forgit.py ├── test1-d.png └── test2-d.png /README.md: -------------------------------------------------------------------------------- 1 | # Post-Processing-for-Matting 2 | We often change background after matting. But due to the matting precision, the edge of body, especially hair, often leak pixels contain background, so we need some post-processing like this. 3 | 4 | ![result1](https://github.com/quziyan/Post-Processing-for-Matting/blob/main/test1-d.png) 5 | ![result2](https://github.com/quziyan/Post-Processing-for-Matting/blob/main/test2-d.png) 6 | -------------------------------------------------------------------------------- /inflation_foreground_forgit.py: -------------------------------------------------------------------------------- 1 | import cv2,os 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | import glob 5 | import numpy as np 6 | import scipy as sp 7 | import scipy.optimize as opt 8 | import time 9 | 10 | def cv2_imread(file_path, toRGB = False, max_border_len=None, shape=None): 11 | 12 | #cv_img = cv2.imdecode(np.fromfile(file_path, dtype=np.uint8), -1) 13 | cv_img = cv2.imread(file_path) 14 | if toRGB: 15 | cv_img = cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB) 16 | assert not (max_border_len is not None and shape is not None) 17 | if max_border_len is not None: 18 | h, w = cv_img.shape[0:2] 19 | ratio = max_border_len / max(h, w) 20 | new_h = int(h * ratio) 21 | new_w = int(w * ratio) 22 | cv_img = cv2.resize(cv_img, (new_w, new_h)) 23 | if shape is not None: 24 | h, w = shape 25 | cv_img = cv2.resize(cv_img, (w, h)) 26 | 27 | return cv_img 28 | def cv2_imwrite( path, img, toBGR=False): 29 | suffix = os.path.splitext(path)[-1] 30 | if toBGR: 31 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 32 | cv2.imencode(suffix, img)[1].tofile(path) 33 | def dilation(mask, ksize=20): 34 | kernel = np.ones((ksize, ksize),np.float) 35 | mask_di = cv2.dilate(mask, kernel, ) 36 | mask_di = mask_di[:,:,np.newaxis] 37 | return mask_di 38 | 39 | def get_similarity(vec1, vec2, similarity_type, space ): 40 | #similarity_type l2/cosine 41 | #space rgb/hsv/ 42 | if space == 'hsv': 43 | vec1 = cv2.cvtColor(vec1, cv2.COLOR_RGB2HSV)[:,:, 2] 44 | vec2 = cv2.cvtColor(vec2, cv2.COLOR_RGB2HSV)[:,:, 2] 45 | vec1 = vec1 / 255.0 46 | vec2 = vec2 / 255.0 47 | 48 | similarity = 0 49 | if similarity_type == 'cosine': 50 | vec1 = np.linalg.norm(vec1) 51 | vec2 = np.linalg.norm(vec2) 52 | similarity = np.dot(vec1, vec2.T)/(vec1 * vec2) 53 | elif similarity_type == 'l2': 54 | dist = ((vec1 - vec2)**2).mean() ** 0.5 55 | similarity = 1 - dist 56 | elif similarity_type == 'l1': 57 | dist = (np.abs(vec1 - vec2)).mean() 58 | similarity = 1 - dist 59 | return similarity 60 | 61 | def get_similarity_vec(vec1, vec2, similarity_type, space ): 62 | #similarity_type l2/cosine 63 | #space rgb/hsv/ 64 | if space == 'hsv': 65 | vec1 = cv2.cvtColor(vec1, cv2.COLOR_RGB2HSV)[:,:, 2:] 66 | vec2 = cv2.cvtColor(vec2, cv2.COLOR_RGB2HSV)[:,:, 2:] 67 | vec1 = vec1 / 255.0 68 | vec2 = vec2 / 255.0 69 | 70 | similarity = 0 71 | eps = 1e-6 72 | if similarity_type == 'cosine': 73 | vec1_norm = np.linalg.norm(vec1, axis=2, keepdims=False) 74 | vec2_norm = np.linalg.norm(vec2, axis=2, keepdims=False) 75 | similarity = (vec1 * vec2).sum(axis=2, keepdims=False) / (vec1_norm * vec2_norm + eps) 76 | elif similarity_type == 'l2': 77 | dist = ((vec1 - vec2)**2).mean(axis=2, keepdims=False) ** 0.5 78 | similarity = 1 - dist 79 | elif similarity_type == 'l1': 80 | dist = (np.abs(vec1 - vec2)).mean(axis=2, keepdims=False) 81 | similarity = 1 - dist 82 | return similarity 83 | 84 | def get_win_centerdist_similarity(mask_win, center_h, center_w): 85 | #计算window与中心点的位置归一化相似度(越近,相似度越高) 86 | ksize_h, ksize_w= mask_win.shape 87 | #r = max(ksize_h - center_i, ksize_w - ) // 2 88 | r_h = max(center_h, ksize_h - 1 - center_h) 89 | r_w = max(center_w, ksize_w - 1 - center_w) 90 | r= max(r_h, r_w) 91 | maxdist = (r ** 2 * 2) ** 0.5 92 | similarity = np.zeros_like(mask_win) 93 | for i in range(ksize_h): 94 | for j in range(ksize_w): 95 | ijdist = ((i - center_h) ** 2 + (j - center_w) ** 2) ** 0.5 96 | similarity[i , j] = (1-ijdist) / maxdist 97 | return similarity 98 | 99 | 100 | def get_max(image_ori, image, mask, ksize, similarity_type, space, evaluate_eps = 0.00): 101 | #image_ori 原始图 102 | #image 增强后的图 103 | image_re = np.copy(image) 104 | def get_mask_max(mask, i, j, ksize): 105 | h, w = mask.shape 106 | r = ksize // 2 107 | h_start = max(i - r, 0) 108 | h_end = min(h-1, i+r) 109 | w_start = max(j - r, 0) 110 | w_end = min(w-1, j+r) 111 | max_val = -1 112 | h_max = -1 113 | w_max = -1 114 | 115 | image_win = image[h_start:h_end, w_start:w_end] 116 | mask_win = mask[h_start:h_end, w_start:w_end] 117 | similarity = mask_win + \ 118 | 0.01*get_similarity_vec(image_win, image_ori[i:i+1,j:j+1], similarity_type = similarity_type, space=space) + \ 119 | 0 #0.01 * get_win_centerdist_similarity(mask_win=mask_win, center_h=i-h_start, center_w=j-w_start) 120 | pos = np.unravel_index(np.argmax(similarity),similarity.shape) 121 | 122 | h_max = h_start + pos[0] 123 | w_max = w_start + pos[1] 124 | return h_max, w_max 125 | ''' 126 | for h_ in range(h_start, h_end): 127 | for w_ in range(w_start, w_end): 128 | similarity = mask[h_, w_] + \ 129 | get_similarity(image[h_:h_+1, w_:w_+1], 130 | image[i:i+1,j:j+1], 131 | similarity_type = similarity_type, 132 | space=space) 133 | if similarity > max_val: 134 | max_val = similarity 135 | h_max = h_ 136 | w_max = w_ 137 | return h_max, w_max 138 | ''' 139 | 140 | h, w = mask.shape 141 | for i in range(h): 142 | for j in range(w): 143 | alpha_ij = mask[i, j] 144 | if alpha_ij <= evaluate_eps and alpha_ij > 0.0: # alpha_ij > 0 + evaluate_eps and alpha_ij < 1 - evaluate_eps: 145 | i_max, j_max = get_mask_max(mask, i, j, ksize) 146 | image_re[i, j] = image[i_max, j_max] 147 | #else: 148 | return image_re 149 | 150 | 151 | 152 | 153 | def get_image_bg(image_ori, mask, ksize, evaluate_eps=0.00, calculate_method='lsq'): 154 | #估计alpha在(0,1)范围内的背景 155 | def linear_strech(array, start, end): 156 | #这个函数不好用 157 | to_min_val = max(min(array), start) 158 | to_max_val = min(end, max(array)) 159 | result = (array - min(array)) / (max(array) - min(array)) * (to_max_val - to_min_val) + to_min_val 160 | return result 161 | 162 | 163 | image_bg = np.copy(image_ori) 164 | image_fg = np.copy(image_ori) 165 | h, w = mask.shape 166 | r = ksize // 2 167 | #dist_weight = get_dist_weight(ksize).reshape(-1, 1) 168 | for i in range(h): 169 | #print(i) 170 | for j in range(w): 171 | if mask[i,j] > 0 + evaluate_eps and mask[i,j] < 1 - evaluate_eps: 172 | h_start = max(i - r, 0) 173 | h_end = min(h-1, i+r) 174 | w_start = max(j - r, 0) 175 | w_end = min(w-1, j+r) 176 | mask_win = mask[h_start:h_end, w_start: w_end] 177 | image_ori_win = image_ori[h_start:h_end, w_start: w_end] 178 | ''' #Navie实现背景估计 179 | bg_win = image_ori_win[mask_win ==0] 180 | if bg_win.size == 0: 181 | pos = np.unravel_index(np.argmin(mask_win), mask_win.shape) 182 | image_bg[i,j] = image_ori[pos[0], pos[1]] 183 | #image_bg[i,j] = image_ori[i,j] 184 | else: 185 | image_bg[i,j] = bg_win.mean(axis=0, keepdims=True) 186 | ''' 187 | dist_weight = get_win_centerdist_similarity(mask_win, center_h=i-h_start, center_w=j-w_start).reshape(-1, 1) 188 | #最小二乘实现背景估计 189 | mask_win_vec_raw = mask_win.reshape(-1, 1) 190 | mask_win_vec = np.concatenate((mask_win_vec_raw, 1-mask_win_vec_raw), axis=1) #alpha, 1-alpha N*2 191 | image_ori_win_vec = image_ori_win.reshape(-1, 3) # N*3 192 | if calculate_method == 'lsq': 193 | #这部分假设过强,假设前景一致,背景一致 194 | #print(mask_win_vec.shape, dist_weight.shape) 195 | mask_win_vec = mask_win_vec * dist_weight #mask_win_vec_raw * dist_weight 196 | image_ori_win_vec = image_ori_win_vec * dist_weight #image_ori_win_vec * dist_weight 197 | #print(mask_win_vec.shape, image_ori_win_vec.shape) 198 | #fg_bg 2*3 199 | fg_bg, _res, rank, singular = np.linalg.lstsq(mask_win_vec, image_ori_win_vec) 200 | #print(fg_bg) 201 | #print(fg_bg) 202 | image_fg[i,j] = fg_bg[0].clip(0,255) #linear_strech(fg_bg[0], 0, 255) #fg_bg[0].clip(0,255) 203 | image_bg[i,j] = fg_bg[1].clip(0,255) #linear_strech(fg_bg[1], 0, 255) #fg_bg[1].clip(0,255) 204 | elif calculate_method == 'naive': 205 | bg_win = image_ori_win[mask_win ==0] 206 | if bg_win.size == 0: 207 | pos = np.unravel_index(np.argmin(mask_win), mask_win.shape) 208 | image_bg[i,j] = image_ori[pos[0], pos[1]] 209 | #image_bg[i,j] = image_ori[i,j] 210 | else: 211 | image_bg[i,j] = bg_win.mean(axis=0, keepdims=True) 212 | mask_ij = mask[i, j] if mask[i,j] >= 0.01 else 1.0 213 | image_fg[i, j] = ( image_ori[i, j] - ( 1 - mask_ij ) ) / mask_ij 214 | else: 215 | raise 216 | 217 | 218 | else: 219 | image_fg[i,j] = image_ori[i,j] 220 | image_bg[i,j] = image_ori[i,j] 221 | return image_bg, image_fg 222 | 223 | 224 | if __name__ == '__main__': 225 | DEBUG=True 226 | def get_blue_bg(refimg): 227 | bg = np.zeros_like(refimg) 228 | bg[:,:, 2] = 255 229 | return bg 230 | def text_on_img_bottom(img, text): 231 | h, w, c = img.shape 232 | font = cv2.FONT_HERSHEY_SIMPLEX 233 | re = cv2.putText(img, text, (50,50), font, 0.8, (255, 0, 0), 2) 234 | return re 235 | 236 | def process(img_ori, mask, blue_bg, img_pixel, imgoutput_path): 237 | #subplot = lambda i, j: plt.subplot(rows, cols, (i)* cols + j + 1) 238 | img_ori_bg, img_ori_fg = get_image_bg(img_ori, mask, ksize=10, evaluate_eps=0.05, calculate_method='lsq') 239 | matting = lambda foreground: (foreground * mask[:,:, np.newaxis] + blue_bg * (1-mask[:,:, np.newaxis])).astype(np.uint8) 240 | image_ori_fgenhanced = img_ori_fg # get_enhanced_image(img_ori, img_ori_bg, mask, mask_gamma_coeff=1.0) 241 | 242 | image_ori_fgenhanced_final = \ 243 | get_max(img_ori, image_ori_fgenhanced, mask, ksize=35, similarity_type='l1', space='hsv', evaluate_eps=0.05) 244 | 245 | matting_img_ori = matting(img_ori) 246 | matting_img_pixel = matting(img_pixel) 247 | matting_img_fgenhanced = matting(image_ori_fgenhanced_final) 248 | 249 | if DEBUG: 250 | output = np.concatenate(( 251 | np.repeat((mask[:,:,np.newaxis] * 255).astype(np.uint8), 3, axis=2), 252 | text_on_img_bottom(img_ori, 'ORI_IMG'), 253 | text_on_img_bottom(img_ori_bg, 'MID_BG'), 254 | text_on_img_bottom(image_ori_fgenhanced, 'MID_FG'), 255 | text_on_img_bottom(img_pixel, 'ORI_Pixel'), 256 | text_on_img_bottom(image_ori_fgenhanced_final, 'MID_FG_ENC'), 257 | text_on_img_bottom(matting_img_ori, 'MATTING_ORI_IMAGE'), 258 | text_on_img_bottom(matting_img_pixel, 'MATTING_ORI_PIXEL'), 259 | text_on_img_bottom(matting_img_fgenhanced, 'MATTING_FG_ENC'), 260 | ),axis=1) 261 | cv2_imwrite(imgoutput_path, output, toBGR=True) 262 | else: 263 | output = np.concatenate(( 264 | np.repeat((mask[:,:,np.newaxis] * 255).astype(np.uint8), 3, axis=2), 265 | text_on_img_bottom(img_ori, 'ORI_IMG'), 266 | text_on_img_bottom(img_pixel, 'ORI_Pixel'), 267 | text_on_img_bottom(image_ori_fgenhanced_final, 'MID_PIXEL'), 268 | text_on_img_bottom(matting_img_pixel, 'MATTING_ORI_IMAGE'), 269 | text_on_img_bottom(matting_img_pixel, 'MATTING_ORI_PIXEL'), 270 | text_on_img_bottom(matting_img_fgenhanced, 'MATTING_MID_PIXEL'), 271 | ),axis=1) 272 | cv2_imwrite(imgoutput_path, output, toBGR=True) 273 | 274 | def get_path(pathbase, pathtype): 275 | if pathtype == 'mask': 276 | re0 =pathbase.replace('-a', '-b') 277 | if pathtype == 'pixel': 278 | re0 =pathbase.replace('-a', '-c') 279 | if pathtype == 'output': 280 | re0 = pathbase.replace('-a', '-d') 281 | rawpath, ext = os.path.splitext(re0) 282 | re1 = rawpath + '.png' 283 | return re1 284 | raise 285 | 286 | 287 | image_ORI_paths = glob.glob(r'D:\Data\Download\new_data\*-a.*') 288 | #构建一个文件夹,原图为xxx-a.jpg, matting-mask为xxx-b.png, 输出为xxx-d.png 289 | maxBorderLen = 800 290 | for imgoripath in image_ORI_paths: 291 | maskpath = get_path(imgoripath, 'mask') 292 | imgpixelpath = get_path(imgoripath, 'pixel') 293 | imgoutput_path = get_path(imgoripath, 'output') 294 | 295 | imgori = cv2_imread(imgoripath, toRGB=True, max_border_len=maxBorderLen) 296 | mask = cv2_imread(maskpath, shape=imgori.shape[0:2])[:,:,0] / 255. 297 | 298 | #imgpixel设置成原图,姑且认为该变量没有用 299 | imgpixel = imgori # cv2_imread(imgpixelpath, toRGB=True, shape=imgori.shape[0:2]) 300 | print(imgoripath, imgori.shape, mask.shape, imgpixel.shape) 301 | blue_bg = get_blue_bg(imgori) 302 | process(imgori, mask, blue_bg, imgpixel, imgoutput_path) 303 | -------------------------------------------------------------------------------- /test1-d.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quziyan/Post-Processing-for-Matting/1160a6e5f2952b77ffc80ecd6db39a9dc7dc08b2/test1-d.png -------------------------------------------------------------------------------- /test2-d.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quziyan/Post-Processing-for-Matting/1160a6e5f2952b77ffc80ecd6db39a9dc7dc08b2/test2-d.png --------------------------------------------------------------------------------