├── README.md ├── data ├── __init__.py ├── retrieve.py └── utils.py ├── data_dim ├── composition.py ├── retrieve.py └── test │ ├── alpha │ ├── 16452523375_08591714cf_o.png │ └── girl-1467820_1280.png │ ├── bg │ ├── 000000239187.jpg │ └── 000000393412.jpg │ ├── fg │ ├── 16452523375_08591714cf_o.png │ └── girl-1467820_1280.png │ └── trimap │ ├── 16452523375_08591714cf_o_0.png │ └── girl-1467820_1280_0.png ├── dataset.py ├── model.py ├── model_deconv.py ├── model_paper.py ├── predict.py ├── predict_trimap.py ├── test ├── 1803151818-00000003.jpg ├── 1803151818-00000003.png ├── 1803151818-00000004.jpg ├── 1803151818-00000004.png ├── 1803250719-00000103.jpg └── 1803250719-00000103.png └── train.py /README.md: -------------------------------------------------------------------------------- 1 | # deep_image_matting_pytorch 2 | A Pytorch implementation of Deep Image Matting. 3 | 4 | ### Usage 5 | 6 | See details at: [Pytorch 抠图算法 Deep Image Matting 模型实现](https://www.jianshu.com/p/91fc778cf4ed) 7 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /data/retrieve.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Fri Oct 18 10:47:28 2019 4 | 5 | @author: shirhe-lyh 6 | """ 7 | 8 | import cv2 9 | import numpy as np 10 | import os 11 | 12 | import utils 13 | 14 | 15 | def get_image_paths(root_dir): 16 | image_paths_dict = {} 17 | matting_paths_dict = {} 18 | for root, dirs, files in os.walk(root_dir): 19 | if not files: 20 | continue 21 | 22 | for file in files: 23 | file_path = os.path.join(root, file) 24 | file_path = file_path.replace('\\', '/') 25 | file_name = file.split('.')[0] 26 | dir_name = file_path.split('/')[-2] 27 | if dir_name.startswith('clip'): 28 | image_paths_dict[file_name] = file_path 29 | if dir_name.startswith('matting'): 30 | matting_paths_dict[file_name] = file_path 31 | 32 | image_corresponding_paths = [] 33 | for image_name, path in image_paths_dict.items(): 34 | matting_path = matting_paths_dict.get(image_name, None) 35 | if matting_path is not None: 36 | image_corresponding_paths.append([path, matting_path]) 37 | else: 38 | print(path) 39 | print('Number of valid images: ', len(image_corresponding_paths)) 40 | if len(image_corresponding_paths) < 1: 41 | raise ValueError('`root_dir` is error. 
Please reset it correctly.') 42 | return image_corresponding_paths 43 | 44 | 45 | def split(image_paths, num_val_samples=100): 46 | if image_paths is None: 47 | return None 48 | 49 | np.random.shuffle(image_paths) 50 | val_image_paths = image_paths[:num_val_samples] 51 | train_image_paths = image_paths[num_val_samples:] 52 | return train_image_paths, val_image_paths 53 | 54 | 55 | def write_to_txt(image_paths, txt_path, delimiter='@'): 56 | if image_paths is None: 57 | return 58 | 59 | with open(txt_path, 'w') as writer: 60 | for element in image_paths: 61 | line = delimiter.join(element) 62 | writer.write(line + '\n') 63 | print('Write successfully to: ', txt_path) 64 | 65 | 66 | def write_masks(image_paths, root_dir, add_mask_paths=True): 67 | if not image_paths: 68 | return image_paths 69 | 70 | alpha_dir = os.path.join(root_dir, 'alphas') 71 | mask_dir = os.path.join(root_dir, 'masks') 72 | if not os.path.exists(alpha_dir): 73 | os.mkdir(alpha_dir) 74 | if not os.path.exists(mask_dir): 75 | os.mkdir(mask_dir) 76 | 77 | new_image_paths = [] 78 | for i, [image_path, matting_path] in enumerate(image_paths): 79 | if (i + 1) % 1000 == 0: 80 | print('On image: {}/{}'.format(i + 1, len(image_paths))) 81 | 82 | matting_image = cv2.imread(matting_path, -1) 83 | if matting_image is None: 84 | print('Image does not exist: ', matting_path) 85 | continue 86 | alpha = utils.get_alpha(matting_image) 87 | mask = utils.to_mask(alpha) 88 | image_name = matting_path.split('/')[-1] 89 | alpha_path = os.path.join(alpha_dir, image_name) 90 | alpha_path = alpha_path.replace('\\', '/') 91 | mask_path = os.path.join(mask_dir, image_name) 92 | mask_path = mask_path.replace('\\', '/') 93 | cv2.imwrite(alpha_path, alpha) 94 | cv2.imwrite(mask_path, mask) 95 | 96 | if add_mask_paths: 97 | new_image_paths.append([image_path, matting_path, 98 | alpha_path, mask_path]) 99 | else: 100 | new_image_paths.append([image_path, matting_path]) 101 | 102 | print('Write successfully to: {} and {}'.format(alpha_dir, mask_dir)) 103 | print('Number of valid samples: ', len(new_image_paths)) 104 | return new_image_paths 105 | 106 | 107 | if __name__ == '__main__': 108 | root_dir = 'E://datasets/matting/Matting_Human_Half' 109 | train_txt_path = './train.txt' 110 | val_txt_path = './val.txt' 111 | image_paths = get_image_paths(root_dir=root_dir) 112 | image_paths = write_masks(image_paths, root_dir) 113 | train_image_paths, val_image_paths = split(image_paths) 114 | write_to_txt(train_image_paths, train_txt_path) 115 | write_to_txt(val_image_paths, val_txt_path) -------------------------------------------------------------------------------- /data/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Fri Nov 9 14:44:07 2018 4 | 5 | @author: shirhe-lyh 6 | """ 7 | 8 | import cv2 9 | import numpy as np 10 | import os 11 | 12 | 13 | def get_alpha(image): 14 | """Returns the alpha channel of a given image.""" 15 | if image.shape[2] > 3: 16 | alpha = image[:, :, 3] 17 | #alpha = remove_noise(alpha) 18 | else: 19 | reduced_image = np.sum(np.abs(255 - image), axis=2) 20 | alpha = np.where(reduced_image > 100, 255, 0) 21 | alpha = alpha.astype(np.uint8) 22 | return alpha 23 | 24 | 25 | def remove_noise(gray, area_threshold=5000): 26 | gray = gray.astype(np.uint8) 27 | ret, binary = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY) 28 | 29 | contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, 30 | cv2.CHAIN_APPROX_SIMPLE) 31 | 32 | remove_contours = 
[] 33 | for contour in contours: 34 | area = cv2.contourArea(contour) 35 | if area < area_threshold: 36 | remove_contours.append(contour) 37 | 38 | cv2.fillPoly(gray, remove_contours, 0) 39 | return gray 40 | 41 | 42 | def to_mask(alpha, threshold=50): 43 | mask = np.where(alpha > threshold, 1, 0) 44 | return mask.astype(np.uint8) 45 | 46 | 47 | def provide(txt_path, delimiter='@'): 48 | """Returns the paths of images. 49 | 50 | Args: 51 | txt_path: A .txt file with format: 52 | [image_path_11, image_path_12, ..., image_path_1n, 53 | image_path_21, image_path_22, ..., image_path_2n, 54 | ...]. 55 | 56 | Returns: 57 | The paths of images. 58 | 59 | Raises: 60 | ValueError: If txt_path does not exist. 61 | """ 62 | if not os.path.exists(txt_path): 63 | raise ValueError('`txt_path` does not exist.') 64 | 65 | with open(txt_path, 'r') as reader: 66 | txt_content = np.loadtxt(reader, str, delimiter=delimiter) 67 | np.random.shuffle(txt_content) 68 | image_paths = [] 69 | for line in txt_content: 70 | paths = [x for x in line] 71 | image_paths.append(paths) 72 | return image_paths 73 | -------------------------------------------------------------------------------- /data_dim/composition.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Tue Nov 12 11:43:38 2019 4 | 5 | @author: shirhe-lyh 6 | """ 7 | 8 | import cv2 9 | import glob 10 | import numpy as np 11 | import os 12 | import uuid 13 | 14 | import retrieve 15 | 16 | 17 | def compose(fg, bg, alpha): 18 | if fg is None or bg is None or alpha is None: 19 | return None 20 | 21 | height, width, _ = fg.shape 22 | height_bg, width_bg, _ = bg.shape 23 | alpha_exp = np.expand_dims(alpha, axis=2) / 255. 24 | if min(height_bg, width_bg) >= max(height, width): 25 | bg_resized = bg[:height, :width] 26 | else: 27 | bg_resized = cv2.resize(bg, (width, height)) 28 | image = alpha_exp * fg + (1 - alpha_exp) * bg_resized 29 | return image.astype(np.uint8) 30 | 31 | 32 | if __name__ == '__main__': 33 | root_dir = 'xxx/Combined_Dataset' 34 | bg_image_root_dir = '/data/COCO/train2017' 35 | output_dir = '/data/matting/dim_composite_images' 36 | train_txt_path = './train.txt' 37 | val_txt_path = './val.txt' 38 | num_bg_images_per_fg = 50 39 | 40 | train_fg_alpha_paths = retrieve.get_image_paths(root_dir) 41 | test_fg_alpha_paths = retrieve.get_image_paths(root_dir, dataset='Test_set') 42 | 43 | bg_image_paths = glob.glob(os.path.join(bg_image_root_dir, '*.*')) 44 | np.random.shuffle(bg_image_paths) 45 | 46 | if not os.path.exists(output_dir): 47 | os.makedirs(output_dir) 48 | 49 | # Generate training data 50 | index = 0 51 | fg_bg_alpha_paths = [] 52 | iterator = iter(bg_image_paths) 53 | for fg_path, alpha_path in train_fg_alpha_paths: 54 | for i in range(num_bg_images_per_fg): 55 | fg = cv2.imread(fg_path) 56 | alpha = cv2.imread(alpha_path, 0) 57 | bg_path = next(iterator) 58 | bg = cv2.imread(bg_path) 59 | image = compose(fg, bg, alpha) 60 | if image is None: 61 | print(fg_path) 62 | continue 63 | output_path = os.path.join(output_dir, str(uuid.uuid4()) + '.jpg') 64 | cv2.imwrite(output_path, image) 65 | 66 | bg_path = bg_path.replace('\\', '/') 67 | fg_bg_alpha_paths.append([output_path, fg_path, alpha_path, bg_path]) 68 | index += 1 69 | if index % 50 == 0: 70 | print('On image: {}/{}'.format(index, len(train_fg_alpha_paths))) 71 | 72 | retrieve.write_to_txt(fg_bg_alpha_paths, train_txt_path) 73 | 74 | # Generate validation data 75 | fg_bg_alpha_paths = [] 76 | 
num_bg_images_per_fg = 10 77 | iterator = iter(bg_image_paths) 78 | for fg_path, alpha_path in test_fg_alpha_paths: 79 | for i in range(num_bg_images_per_fg): 80 | bg_path = next(iterator) 81 | bg_path = bg_path.replace('\\', '/') 82 | fg_bg_alpha_paths.append([fg_path, alpha_path, bg_path]) 83 | 84 | retrieve.write_to_txt(fg_bg_alpha_paths, val_txt_path) 85 | 86 | 87 | 88 | 89 | -------------------------------------------------------------------------------- /data_dim/retrieve.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Tue Nov 12 11:22:00 2019 4 | 5 | @author: shirhe-lyh 6 | """ 7 | 8 | import os 9 | 10 | 11 | def get_image_paths(root_dir, dataset='Training_set'): 12 | image_paths_dict = {} 13 | matting_paths_dict = {} 14 | sub_root_dir = os.path.join(root_dir, dataset) 15 | for root, dirs, files in os.walk(sub_root_dir): 16 | if not files: 17 | continue 18 | 19 | for file in files: 20 | file_path = os.path.join(root, file) 21 | file_path = file_path.replace('\\', '/') 22 | file_name = file.split('.')[0] 23 | dir_name = file_path.split('/')[-2] 24 | if dir_name.startswith('fg'): 25 | image_paths_dict[file_name] = file_path 26 | if dir_name.startswith('alpha'): 27 | matting_paths_dict[file_name] = file_path 28 | 29 | image_corresponding_paths = [] 30 | for image_name, path in image_paths_dict.items(): 31 | matting_path = matting_paths_dict.get(image_name, None) 32 | if matting_path is not None: 33 | image_corresponding_paths.append([path, matting_path]) 34 | else: 35 | print(path) 36 | print('Number of valid images: ', len(image_corresponding_paths)) 37 | if len(image_corresponding_paths) < 1: 38 | raise ValueError('`root_dir` is error. Please reset it correctly.') 39 | return image_corresponding_paths 40 | 41 | 42 | def write_to_txt(image_paths, txt_path, delimiter='@'): 43 | if image_paths is None: 44 | return 45 | 46 | with open(txt_path, 'w') as writer: 47 | for element in image_paths: 48 | line = delimiter.join(element) 49 | writer.write(line + '\n') 50 | print('Write successfully to: ', txt_path) -------------------------------------------------------------------------------- /data_dim/test/alpha/16452523375_08591714cf_o.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shirhe-Lyh/deep_image_matting_pytorch/643e52975f153919eaed62de6dd87d31a65d6a0e/data_dim/test/alpha/16452523375_08591714cf_o.png -------------------------------------------------------------------------------- /data_dim/test/alpha/girl-1467820_1280.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shirhe-Lyh/deep_image_matting_pytorch/643e52975f153919eaed62de6dd87d31a65d6a0e/data_dim/test/alpha/girl-1467820_1280.png -------------------------------------------------------------------------------- /data_dim/test/bg/000000239187.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shirhe-Lyh/deep_image_matting_pytorch/643e52975f153919eaed62de6dd87d31a65d6a0e/data_dim/test/bg/000000239187.jpg -------------------------------------------------------------------------------- /data_dim/test/bg/000000393412.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Shirhe-Lyh/deep_image_matting_pytorch/643e52975f153919eaed62de6dd87d31a65d6a0e/data_dim/test/bg/000000393412.jpg -------------------------------------------------------------------------------- /data_dim/test/fg/16452523375_08591714cf_o.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shirhe-Lyh/deep_image_matting_pytorch/643e52975f153919eaed62de6dd87d31a65d6a0e/data_dim/test/fg/16452523375_08591714cf_o.png -------------------------------------------------------------------------------- /data_dim/test/fg/girl-1467820_1280.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shirhe-Lyh/deep_image_matting_pytorch/643e52975f153919eaed62de6dd87d31a65d6a0e/data_dim/test/fg/girl-1467820_1280.png -------------------------------------------------------------------------------- /data_dim/test/trimap/16452523375_08591714cf_o_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shirhe-Lyh/deep_image_matting_pytorch/643e52975f153919eaed62de6dd87d31a65d6a0e/data_dim/test/trimap/16452523375_08591714cf_o_0.png -------------------------------------------------------------------------------- /data_dim/test/trimap/girl-1467820_1280_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shirhe-Lyh/deep_image_matting_pytorch/643e52975f153919eaed62de6dd87d31a65d6a0e/data_dim/test/trimap/girl-1467820_1280_0.png -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Thu Jul 11 20:46:51 2019 4 | 5 | @author: shirhe-lyh 6 | """ 7 | 8 | import cv2 9 | import numpy as np 10 | import os 11 | import PIL 12 | import torch 13 | import torchvision as tv 14 | 15 | from data import utils 16 | 17 | 18 | def random_dilate(alpha, low=1, high=5, mode='constant'): 19 | """Dilation.""" 20 | iterations = np.random.randint(1, 20) 21 | erode_ksize = np.random.randint(low=low, high=high) 22 | dilate_ksize = np.random.randint(low=low, high=high) 23 | erode_kernel = cv2.getStructuringElement( 24 | cv2.MORPH_ELLIPSE, (erode_ksize, erode_ksize)) 25 | dilate_kernel = cv2.getStructuringElement( 26 | cv2.MORPH_ELLIPSE, (dilate_ksize, dilate_ksize)) 27 | alpha_eroded = cv2.erode(alpha, erode_kernel, iterations=iterations) 28 | alpha_dilated = cv2.dilate(alpha, dilate_kernel, iterations=iterations) 29 | if mode == 'constant': 30 | alpha_noise = 128 * np.ones_like(alpha) 31 | alpha_noise[alpha_eroded >= 255] = 255 32 | alpha_noise[alpha_dilated <= 0] = 0 33 | else: 34 | value = np.random.randint(low=100, high=255) 35 | alpha_noise = value * ((alpha_dilated - alpha_eroded) / 255.) 
36 | alpha_noise += alpha_eroded 37 | return alpha_noise 38 | 39 | 40 | def crop_offset(trimap, crop_size=320): 41 | """Generate top-left corner to crop.""" 42 | trimap_ = np.where(trimap == 255, 0, trimap) 43 | y_indices, x_indices = np.where(trimap_ > 0) 44 | num_unknowns = len(y_indices) 45 | y, x = 0, 0 46 | if num_unknowns > 0: 47 | index = np.random.randint(low=0, high=num_unknowns) 48 | xc = x_indices[index] 49 | yc = y_indices[index] 50 | y = max(0, yc - crop_size // 2) 51 | x = max(0, xc - crop_size // 2) 52 | return y, x 53 | 54 | 55 | class MattingDataset(torch.utils.data.Dataset): 56 | """Read dataset for Matting.""" 57 | 58 | def __init__(self, annotation_path, root_dir=None, transforms=None, 59 | output_size=320, dilation_mode='constant'): 60 | self._transforms = transforms 61 | self._output_size = output_size 62 | self._dilation_mode = dilation_mode 63 | 64 | # Transform 65 | if transforms is None: 66 | channel_means = [0.485, 0.456, 0.406] 67 | channel_std = [0.229, 0.224, 0.225] 68 | self._transforms = tv.transforms.Compose([ 69 | tv.transforms.ColorJitter(brightness=32/255., contrast=0.5, 70 | saturation=0.5, hue=0.2), 71 | tv.transforms.ToTensor(), 72 | tv.transforms.Normalize(mean=channel_means, std=channel_std)]) 73 | 74 | # Format [[image_path, alpha_path], ...] 75 | self._image_alpha_paths = self.get_image_mask_paths(annotation_path, 76 | root_dir=root_dir) 77 | self._remove_invalid_data() 78 | 79 | def __getitem__(self, index): 80 | image_path, alpha_path = self._image_alpha_paths[index] 81 | image = PIL.Image.open(image_path) 82 | alpha = PIL.Image.open(alpha_path) 83 | 84 | # Rotate 85 | # degree = np.random.randint(low=-30, high=30) 86 | # image = image.rotate(degree) 87 | # alpha = alpha.rotate(degree) 88 | 89 | # Crop 90 | width, height = alpha.size 91 | min_size = np.min((width, height)) 92 | #crop_sizes = [320, 480, 600, 800] # For Matting_Human_Half 93 | crop_sizes = [320, 480, 640] # For deep image matting dataset 94 | crop_size = np.random.choice(crop_sizes) 95 | if min_size >= crop_size: 96 | alpha_noise = random_dilate(alpha=np.array(alpha), 97 | mode=self._dilation_mode) 98 | height_offset, width_offset = crop_offset(alpha_noise, crop_size) 99 | box = (width_offset, height_offset, width_offset+crop_size, 100 | height_offset+crop_size) 101 | image = image.crop(box=box) 102 | alpha = alpha.crop(box=box) 103 | 104 | # Resize 105 | if crop_size > self._output_size or min_size < crop_size: 106 | image = image.resize((self._output_size, self._output_size), 107 | PIL.Image.ANTIALIAS) 108 | alpha = alpha.resize((self._output_size, self._output_size), 109 | PIL.Image.NEAREST) 110 | 111 | # Flip 112 | prob = np.random.uniform() 113 | if prob > 0.5: 114 | image = image.transpose(PIL.Image.FLIP_LEFT_RIGHT) 115 | alpha = alpha.transpose(PIL.Image.FLIP_LEFT_RIGHT) 116 | 117 | # Dilate, Erode 118 | alpha = np.array(alpha) 119 | alpha_noise = random_dilate(alpha, mode=self._dilation_mode) 120 | mask = np.ones_like(alpha_noise) 121 | if self._dilation_mode == 'constant': 122 | mask = np.equal(alpha_noise, 128).astype(np.float32) 123 | 124 | alpha = torch.Tensor(alpha / 255.) 125 | alpha_noise = torch.Tensor(alpha_noise / 255.) 
126 | mask = torch.Tensor(mask) 127 | # Transform 128 | image_preprocessed = self._transforms(image) 129 | alpha_u = torch.unsqueeze(alpha, dim=0) 130 | mask_u = torch.unsqueeze(mask, dim=0) 131 | alpha_noise_u = torch.unsqueeze(alpha_noise, dim=0) 132 | image_concated = torch.cat([image_preprocessed, alpha_noise_u], dim=0) 133 | return image_concated, alpha_u, alpha_noise_u, mask_u 134 | 135 | def __len__(self): 136 | return len(self._image_alpha_paths) 137 | 138 | def get_image_mask_paths(self, annotation_path, root_dir=None): 139 | """Get the paths of images and masks. 140 | 141 | Args: 142 | annotation_path: A file contains the paths of images and masks. 143 | 144 | Returns: 145 | A list [[image_path, mask_path], [image_path, mask_path], ...]. 146 | 147 | Raises: 148 | ValueError: If annotation_file does not exist. 149 | """ 150 | # Format: [[image_path, matting_path, alpha_path, mask_path], ...] 151 | image_matting_alpha_mask_paths = utils.provide(annotation_path) 152 | # Remove matting_paths, mask_paths 153 | image_alpha_paths = [] 154 | for image_path, _, alpha_path, _ in image_matting_alpha_mask_paths: 155 | if root_dir is not None: 156 | if not image_path.startswith(root_dir): 157 | image_path = os.path.join(root_dir, image_path) 158 | alpha_path = os.path.join(root_dir, alpha_path) 159 | image_path = image_path.replace('\\', '/') 160 | alpha_path = alpha_path.replace('\\', '/') 161 | image_alpha_paths.append([image_path, alpha_path]) 162 | return image_alpha_paths 163 | 164 | def _remove_invalid_data(self): 165 | valid_data = [] 166 | for image_path, alpha_path in self._image_alpha_paths: 167 | if not os.path.exists(image_path): 168 | continue 169 | if not os.path.exists(alpha_path): 170 | continue 171 | valid_data.append([image_path, alpha_path]) 172 | self._image_alpha_paths = valid_data 173 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sun Jul 21 07:08:58 2019 4 | 5 | @author: shirhe-lyh 6 | 7 | Implementation of paper: 8 | Deep Image Matting, Ning Xu, eta., arxiv:1703.03872 9 | """ 10 | 11 | import torch 12 | import torchvision as tv 13 | 14 | VGG16_BN_MODEL_URL = 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth' 15 | 16 | VGG16_BN_CONFIGS = { 17 | '13conv': 18 | [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 19 | 'M', 512, 512, 512], 20 | '10conv': 21 | [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512] 22 | } 23 | 24 | 25 | def make_layers(cfg, batch_norm=False): 26 | """Copy from: torchvision/models/vgg. 27 | 28 | Changs retrue_indices in MaxPool2d from False to True. 29 | """ 30 | layers = [] 31 | in_channels=3 32 | for v in cfg: 33 | if v == 'M': 34 | layers += [torch.nn.MaxPool2d(kernel_size=2, stride=2, 35 | return_indices=True)] 36 | else: 37 | conv2d = torch.nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 38 | if batch_norm: 39 | layers += [conv2d, torch.nn.BatchNorm2d(v), 40 | torch.nn.ReLU(inplace=True)] 41 | else: 42 | layers += [conv2d, torch.nn.ReLU(inplace=True)] 43 | in_channels = v 44 | return torch.nn.Sequential(*layers) 45 | 46 | 47 | class VGGFeatureExtractor(torch.nn.Module): 48 | """Feature extractor by VGG network.""" 49 | 50 | def __init__(self, config=None, batch_norm=True): 51 | """Constructor. 52 | 53 | Args: 54 | config: The convolutional architecture of VGG network. 
55 | batch_norm: A boolean indicating whether the architecture 56 | include Batch Normalization layers or not. 57 | """ 58 | super(VGGFeatureExtractor, self).__init__() 59 | self._config = config 60 | if self._config is None: 61 | self._config = VGG16_BN_CONFIGS.get('10conv') 62 | self.features = make_layers(self._config, batch_norm=batch_norm) 63 | self._indices = None 64 | self._pre_pool_shapes = None 65 | 66 | def forward(self, x): 67 | self._indices = [] 68 | self._pre_pool_shapes = [] 69 | for layer in self.features: 70 | if isinstance(layer, torch.nn.modules.pooling.MaxPool2d): 71 | self._pre_pool_shapes.append(x.size()) 72 | x, indices = layer(x) 73 | self._indices.append(indices) 74 | else: 75 | x = layer(x) 76 | return x 77 | 78 | 79 | def vgg16_bn_feature_extractor(config=None, pretrained=True, progress=True): 80 | model = VGGFeatureExtractor(config, batch_norm=True) 81 | if pretrained: 82 | state_dict = tv.models.utils.load_state_dict_from_url( 83 | VGG16_BN_MODEL_URL, progress=progress) 84 | model.load_state_dict(state_dict, strict=False) 85 | return model 86 | 87 | 88 | class DIM(torch.nn.Module): 89 | """Deep Image Matting.""" 90 | 91 | def __init__(self, feature_extractor): 92 | """Constructor. 93 | 94 | Args: 95 | feature_extractor: Feature extractor, such as VGGFeatureExtractor. 96 | """ 97 | super(DIM, self).__init__() 98 | # Head convolution layer, number of channels: 4 -> 3 99 | self._head_conv = torch.nn.Conv2d(in_channels=4, out_channels=3, 100 | kernel_size=5, padding=2) 101 | # Encoder 102 | self._feature_extractor = feature_extractor 103 | self._feature_extract_config = self._feature_extractor._config 104 | # Decoder 105 | self._decode_layers = self.decode_layers() 106 | # Prediction 107 | self._final_conv = torch.nn.Conv2d(self._feature_extract_config[0], 1, 108 | kernel_size=5, padding=2) 109 | self._sigmoid = torch.nn.Sigmoid() 110 | 111 | def forward(self, x): 112 | x = self._head_conv(x) 113 | x = self._feature_extractor(x) 114 | indices = self._feature_extractor._indices[::-1] 115 | unpool_shapes = self._feature_extractor._pre_pool_shapes[::-1] 116 | index = 0 117 | for layer in self._decode_layers: 118 | if isinstance(layer, torch.nn.modules.pooling.MaxUnpool2d): 119 | x = layer(x, indices[index], output_size=unpool_shapes[index]) 120 | index += 1 121 | else: 122 | x = layer(x) 123 | x = self._final_conv(x) 124 | x = self._sigmoid(x) 125 | return x 126 | 127 | def decode_layers(self): 128 | layers = [] 129 | strides = [1] 130 | channels = [] 131 | config_reversed = self._feature_extract_config[::-1] 132 | for i, v in enumerate(config_reversed): 133 | if v == 'M': 134 | strides.append(2) 135 | channels.append(config_reversed[i+1]) 136 | channels.append(channels[-1]) 137 | in_channels = self._feature_extract_config[-1] 138 | for stride, out_channels in zip(strides, channels): 139 | if stride == 2: 140 | layers += [torch.nn.MaxUnpool2d(kernel_size=2, stride=2)] 141 | layers += [torch.nn.Conv2d(in_channels, out_channels, 142 | kernel_size=5, padding=2), 143 | torch.nn.BatchNorm2d(num_features=out_channels), 144 | torch.nn.ReLU(inplace=True)] 145 | in_channels = out_channels 146 | return torch.nn.Sequential(*layers) 147 | 148 | 149 | def loss(alphas_pred, alphas_gt, masks, images=None, epsilon=1e-12): 150 | diff = alphas_pred - alphas_gt 151 | diff = diff * masks 152 | num_unkowns = torch.sum(masks) + epsilon 153 | losses = torch.sqrt(torch.mul(diff, diff) + epsilon) 154 | loss = torch.sum(losses) / num_unkowns 155 | if images is not None: 156 | images_fg_gt = 
torch.mul(images, alphas_gt) 157 | images_fg_pred = torch.mul(images, alphas_pred) 158 | images_fg_diff = images_fg_pred - images_fg_gt 159 | images_fg_diff = images_fg_diff * masks 160 | losses_image = torch.sqrt( 161 | torch.mul(images_fg_diff, images_fg_diff) + epsilon) 162 | loss += torch.sum(losses_image) / num_unkowns 163 | return loss -------------------------------------------------------------------------------- /model_deconv.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Tue Jul 9 15:40:13 2019 4 | 5 | @author: shirhe-lyh 6 | """ 7 | 8 | import torch 9 | import torchvision as tv 10 | 11 | 12 | VGG16_BN_URL = 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth' 13 | 14 | CFGS = {'all': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 15 | 512, 512, 512, 'M', 512, 512, 512, 'M'], 16 | '13conv': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 17 | 512, 512, 512, 'M', 512, 512, 512], 18 | '10conv': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 19 | 512, 512, 512]} 20 | 21 | 22 | class VGGFeatureExtractor(torch.nn.Module): 23 | """Extract features by VGG networks.""" 24 | 25 | def __init__(self, cfg=None, init_weights=True): 26 | """Constructor. 27 | 28 | Args: 29 | cfg: Neural architecture (exclude FC layers) of VGG net. 30 | init_weights: A boolean indicating whether to initialize the 31 | weights or not. 32 | """ 33 | super(VGGFeatureExtractor, self).__init__() 34 | self._cfg = cfg 35 | self.features = self._vgg16_bn_features() 36 | if init_weights: 37 | self._initialize_weights() 38 | 39 | def forward(self, x): 40 | """Forward computation. 41 | 42 | Args: 43 | x: A float32 tensor with shape [batch_size, channels, height, width] 44 | """ 45 | x = self._features(x) 46 | return x 47 | 48 | def _vgg16_bn_features(self): 49 | if self._cfg is None: 50 | self._cfg = CFGS['all'] 51 | return tv.models.vgg.make_layers(self._cfg, batch_norm=True) 52 | 53 | def _initialize_weights(self): 54 | for m in self.modules(): 55 | if isinstance(m, torch.nn.Conv2d): 56 | torch.nn.init.kaiming_normal_(m.weight, mode='fan_out', 57 | nonlinearity='relu') 58 | if m.bias is not None: 59 | torch.nn.init.constant_(m.bias, 0) 60 | elif isinstance(m, torch.nn.BatchNorm2d): 61 | torch.nn.init.constant_(m.weight, 1) 62 | torch.nn.init.constant_(m.bias, 0) 63 | elif isinstance(m, torch.nn.Linear): 64 | torch.nn.init.normal_(m.weight, 0, 0.01) 65 | torch.nn.init.constant_(m.bias, 0) 66 | 67 | 68 | def vgg16_bn_feature_extractor(pretrained=False, cfg=None, progress=True, 69 | **kwargs): 70 | if pretrained: 71 | kwargs['init_weights'] = False 72 | model = VGGFeatureExtractor(cfg, **kwargs) 73 | if pretrained: 74 | state_dict = tv.models.utils.load_state_dict_from_url( 75 | VGG16_BN_URL, progress=progress) 76 | model.load_state_dict(state_dict, strict=False) 77 | return model 78 | 79 | 80 | class DIMDecoder(torch.nn.Module): 81 | """Decoder of Deep Image Matting.""" 82 | 83 | def __init__(self, cfg=None, init_weights=True): 84 | """Constructor. 85 | Args: 86 | cfg: Neural architecture (exclude FC layers) of VGG net. 87 | init_weights: A boolean indicating whether to initialize the 88 | weights or not. 
89 | """ 90 | super(DIMDecoder, self).__init__() 91 | self._cfg = cfg 92 | self._decoder = self._dim_decoder() 93 | if init_weights: 94 | self._init_weights() 95 | 96 | def forward(self, x): 97 | return self._decoder(x) 98 | 99 | def _dim_decoder(self): 100 | if self._cfg is None: 101 | self._cfg = CFGS['all'] 102 | deconv_strides = [1] 103 | deconv_channels = [512] 104 | cfg_reversed = self._cfg[::-1] 105 | for i, e in enumerate(cfg_reversed): 106 | if e == 'M': 107 | deconv_strides.append(2) 108 | deconv_channels.append(cfg_reversed[i+1]) 109 | layers = [] 110 | in_channels = deconv_channels[0] 111 | for stride, num_channels in zip(deconv_strides, deconv_channels): 112 | output_padding = 0 if stride == 1 else 1 113 | layers += [torch.nn.ConvTranspose2d( 114 | in_channels=in_channels, out_channels=num_channels, 115 | kernel_size=5, padding=2, output_padding=output_padding, 116 | stride=stride)] 117 | layers += [torch.nn.BatchNorm2d(num_features=num_channels), 118 | torch.nn.ReLU(inplace=True)] 119 | in_channels = num_channels 120 | return torch.nn.Sequential(*layers) 121 | 122 | def _init_weights(self): 123 | for m in self.modules(): 124 | if isinstance(m, torch.nn.ConvTranspose2d): 125 | torch.nn.init.kaiming_normal_(m.weight, mode='fan_out', 126 | nonlinearity='relu') 127 | elif isinstance(m, torch.nn.BatchNorm2d): 128 | torch.nn.init.constant_(m.weight, 1) 129 | 130 | 131 | class DIM(torch.nn.Module): 132 | """Variant of Deep Image Matting.""" 133 | 134 | def __init__(self, cfg=None): 135 | """Constructor. 136 | 137 | Args: 138 | num_classes: Number of classes. 139 | """ 140 | super(DIM, self).__init__() 141 | if cfg is None: 142 | cfg = CFGS.get('10conv') 143 | self._head_conv = torch.nn.Conv2d(in_channels=4, 144 | out_channels=3, 145 | kernel_size=5, 146 | padding=2) 147 | self._head_batchnorm = torch.nn.BatchNorm2d(num_features=3) 148 | self._head_relu = torch.nn.ReLU(inplace=True) 149 | self._feature_extractor = vgg16_bn_feature_extractor( 150 | pretrained=True, cfg=cfg) 151 | self._decoder = DIMDecoder(cfg) 152 | self._alpha_conv = torch.nn.Conv2d(in_channels=64, 153 | out_channels=1, 154 | kernel_size=5, 155 | padding=2) 156 | self._sigmoid = torch.nn.Sigmoid() 157 | # Random initialization 158 | torch.nn.init.kaiming_normal_(self._head_conv.weight, mode='fan_out', 159 | nonlinearity='relu') 160 | torch.nn.init.constant_(self._head_batchnorm.weight, 1) 161 | torch.nn.init.kaiming_normal_(self._alpha_conv.weight, mode='fan_out', 162 | nonlinearity='sigmoid') 163 | 164 | def forward(self, x): 165 | x = self._head_conv(x) 166 | x = self._head_batchnorm(x) 167 | x = self._head_relu(x) 168 | x = self._feature_extractor(x) 169 | x = self._decoder(x) 170 | x = self._alpha_conv(x) 171 | x = self._sigmoid(x) 172 | return x 173 | 174 | 175 | def loss(alphas_pred, alphas_gt, masks, images=None, epsilon=1e-12): 176 | diff = alphas_pred - alphas_gt 177 | diff = diff * masks 178 | num_unkowns = torch.sum(masks) + epsilon 179 | losses = torch.sqrt(torch.mul(diff, diff) + epsilon) 180 | loss = torch.sum(losses) / num_unkowns 181 | if images is not None: 182 | images_fg_gt = torch.mul(images, alphas_gt) 183 | images_fg_pred = torch.mul(images, alphas_pred) 184 | images_fg_diff = images_fg_pred - images_fg_gt 185 | images_fg_diff = images_fg_diff * masks 186 | losses_image = torch.sqrt( 187 | torch.mul(images_fg_diff, images_fg_diff) + epsilon) 188 | loss += torch.sum(losses_image) / num_unkowns 189 | return loss 190 | 
-------------------------------------------------------------------------------- /model_paper.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sun Jul 21 07:08:58 2019 4 | 5 | @author: shirhe-lyh 6 | 7 | Implementation of paper: 8 | Deep Image Matting, Ning Xu, eta., arxiv:1703.03872 9 | """ 10 | 11 | import torch 12 | import torchvision as tv 13 | 14 | VGG16_BN_MODEL_URL = 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth' 15 | 16 | VGG16_BN_CONFIGS = { 17 | '13conv': 18 | [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 19 | 'M', 512, 512, 512], 20 | '10conv': 21 | [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512] 22 | } 23 | 24 | 25 | def make_layers(cfg, in_channels=3, batch_norm=False): 26 | """Copy from: torchvision/models/vgg. 27 | 28 | Changs retrue_indices in MaxPool2d from False to True. 29 | """ 30 | layers = [] 31 | for v in cfg: 32 | if v == 'M': 33 | layers += [torch.nn.MaxPool2d(kernel_size=2, stride=2, 34 | return_indices=True)] 35 | else: 36 | conv2d = torch.nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 37 | if batch_norm: 38 | layers += [conv2d, torch.nn.BatchNorm2d(v), 39 | torch.nn.ReLU(inplace=True)] 40 | else: 41 | layers += [conv2d, torch.nn.ReLU(inplace=True)] 42 | in_channels = v 43 | return torch.nn.Sequential(*layers) 44 | 45 | 46 | class VGGFeatureExtractor(torch.nn.Module): 47 | """Feature extractor by VGG network.""" 48 | 49 | def __init__(self, config=None, batch_norm=True): 50 | """Constructor. 51 | 52 | Args: 53 | config: The convolutional architecture of VGG network. 54 | batch_norm: A boolean indicating whether the architecture 55 | include Batch Normalization layers or not. 56 | """ 57 | super(VGGFeatureExtractor, self).__init__() 58 | self._config = config 59 | if self._config is None: 60 | self._config = VGG16_BN_CONFIGS.get('10conv') 61 | self.features = make_layers(self._config, in_channels=4, 62 | batch_norm=batch_norm) 63 | self._indices = None 64 | self._pre_pool_shapes = None 65 | 66 | def forward(self, x): 67 | self._indices = [] 68 | self._pre_pool_shapes = [] 69 | for layer in self.features: 70 | if isinstance(layer, torch.nn.modules.pooling.MaxPool2d): 71 | self._pre_pool_shapes.append(x.size()) 72 | x, indices = layer(x) 73 | self._indices.append(indices) 74 | else: 75 | x = layer(x) 76 | return x 77 | 78 | 79 | def vgg16_bn_feature_extractor(config=None, pretrained=True, progress=True): 80 | model = VGGFeatureExtractor(config, batch_norm=True) 81 | if pretrained: 82 | state_dict = tv.models.utils.load_state_dict_from_url( 83 | VGG16_BN_MODEL_URL, progress=progress) 84 | conv1_weight_name = 'features.0.weight' 85 | conv1_weight = model.state_dict()[conv1_weight_name] 86 | conv1_weight[:, :3, :, :] = state_dict[conv1_weight_name] 87 | conv1_weight[:, 3, :, :] = torch.tensor(0) 88 | state_dict[conv1_weight_name] = conv1_weight 89 | model.load_state_dict(state_dict, strict=False) 90 | return model 91 | 92 | 93 | class DIM(torch.nn.Module): 94 | """Deep Image Matting.""" 95 | 96 | def __init__(self, feature_extractor): 97 | """Constructor. 98 | 99 | Args: 100 | feature_extractor: Feature extractor, such as VGGFeatureExtractor. 
101 | """ 102 | super(DIM, self).__init__() 103 | # Encoder 104 | self._feature_extractor = feature_extractor 105 | self._feature_extract_config = self._feature_extractor._config 106 | # Decoder 107 | self._decode_layers = self.decode_layers() 108 | # Prediction 109 | self._final_conv = torch.nn.Conv2d(self._feature_extract_config[0], 1, 110 | kernel_size=5, padding=2) 111 | self._sigmoid = torch.nn.Sigmoid() 112 | # Initialization 113 | self._init_weights([self._final_conv]) 114 | self._init_weights(self._decode_layers) 115 | 116 | def forward(self, x): 117 | x = self._feature_extractor(x) 118 | indices = self._feature_extractor._indices[::-1] 119 | unpool_shapes = self._feature_extractor._pre_pool_shapes[::-1] 120 | index = 0 121 | for layer in self._decode_layers: 122 | if isinstance(layer, torch.nn.modules.pooling.MaxUnpool2d): 123 | x = layer(x, indices[index], output_size=unpool_shapes[index]) 124 | index += 1 125 | else: 126 | x = layer(x) 127 | x = self._final_conv(x) 128 | x = self._sigmoid(x) 129 | return x 130 | 131 | def decode_layers(self): 132 | layers = [] 133 | strides = [1] 134 | channels = [] 135 | config_reversed = self._feature_extract_config[::-1] 136 | for i, v in enumerate(config_reversed): 137 | if v == 'M': 138 | strides.append(2) 139 | channels.append(config_reversed[i+1]) 140 | channels.append(channels[-1]) 141 | in_channels = self._feature_extract_config[-1] 142 | for stride, out_channels in zip(strides, channels): 143 | if stride == 2: 144 | layers += [torch.nn.MaxUnpool2d(kernel_size=2, stride=2)] 145 | layers += [torch.nn.Conv2d(in_channels, out_channels, 146 | kernel_size=5, padding=2), 147 | torch.nn.BatchNorm2d(num_features=out_channels), 148 | torch.nn.ReLU(inplace=True)] 149 | in_channels = out_channels 150 | return torch.nn.Sequential(*layers) 151 | 152 | def _init_weights(self, layers): 153 | for layer in layers: 154 | if isinstance(layer, torch.nn.Conv2d): 155 | torch.nn.init.kaiming_normal_(layer.weight, mode='fan_out', 156 | nonlinearity='relu') 157 | if layer.bias is not None: 158 | torch.nn.init.constant_(layer.bias, 0) 159 | elif isinstance(layer, torch.nn.BatchNorm2d): 160 | torch.nn.init.constant_(layer.weight, 1) 161 | torch.nn.init.constant_(layer.bias, 0) 162 | 163 | 164 | def loss(alphas_pred, alphas_gt, masks, images=None, epsilon=1e-12): 165 | diff = alphas_pred - alphas_gt 166 | diff = diff * masks 167 | num_unkowns = torch.sum(masks) + epsilon 168 | losses = torch.sqrt(torch.mul(diff, diff) + epsilon) 169 | loss = torch.sum(losses) / num_unkowns 170 | if images is not None: 171 | images_fg_gt = torch.mul(images, alphas_gt) 172 | images_fg_pred = torch.mul(images, alphas_pred) 173 | images_fg_diff = images_fg_pred - images_fg_gt 174 | images_fg_diff = images_fg_diff * masks 175 | losses_image = torch.sqrt( 176 | torch.mul(images_fg_diff, images_fg_diff) + epsilon) 177 | loss += torch.sum(losses_image) / num_unkowns 178 | return loss -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Thu Jun 20 17:50:00 2019 4 | 5 | @author: shirhe-lyh 6 | """ 7 | 8 | import cv2 9 | import glob 10 | import numpy as np 11 | import os 12 | import torch 13 | import torchvision as tv 14 | 15 | import dataset 16 | import model 17 | 18 | # Device configuration 19 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 20 | 21 | 22 | def composite(image, alpha): 23 | 
alpha_exp = np.expand_dims(alpha, axis=2) 24 | image_ = np.concatenate([image, alpha_exp], axis=2) 25 | return image_ 26 | 27 | 28 | if __name__ == '__main__': 29 | ckpt_path = './models/model.ckpt' 30 | test_images_dir = './test/' 31 | output_masks_dir = './test/pred_alphas' 32 | test_images_paths = glob.glob(os.path.join(test_images_dir, '*.png')) 33 | 34 | if not os.path.exists(ckpt_path): 35 | raise ValueError('`ckpt_path` does not exist.') 36 | if not os.path.exists(output_masks_dir): 37 | os.makedirs(output_masks_dir) 38 | 39 | feature_extractor = model.vgg16_bn_feature_extractor( 40 | model.VGG16_BN_CONFIGS.get('13conv')).to(device) 41 | dim = model.DIM(feature_extractor).to(device) 42 | #dim.load_state_dict(torch.load(ckpt_path)) 43 | dim_pretrained_params = torch.load(ckpt_path).items() 44 | dim_state_dict = {k.replace('module.', ''): v for k, v in 45 | dim_pretrained_params} 46 | dim.load_state_dict(dim_state_dict) 47 | print('Load DIM pretrained parameters, Done') 48 | 49 | # Transform 50 | channel_means = [0.485, 0.456, 0.406] 51 | channel_std = [0.229, 0.224, 0.225] 52 | transforms = tv.transforms.Compose([ 53 | tv.transforms.ToTensor(), 54 | tv.transforms.Normalize(mean=channel_means, std=channel_std)]) 55 | 56 | dim.eval() 57 | with torch.no_grad(): 58 | for image_path in test_images_paths: 59 | image = cv2.imread(image_path, -1) 60 | image_fg = cv2.imread(image_path.replace('.png', '.jpg')) 61 | alpha = image[:, :, 3] 62 | image_rgb = cv2.cvtColor(image_fg, cv2.COLOR_BGR2RGB) 63 | image_processed = transforms(image_rgb).to(device) 64 | 65 | alpha_noise = dataset.random_dilate(alpha) 66 | alpha_noise_exp = np.expand_dims(alpha_noise / 255., axis=0) 67 | alpha_noise_exp = torch.Tensor(alpha_noise_exp).to(device) 68 | images = torch.cat([image_processed, alpha_noise_exp], dim=0) 69 | images = torch.unsqueeze(images, dim=0) 70 | 71 | outputs = dim(images) 72 | alpha_pred = outputs.data.cpu().numpy()[0][0] 73 | alpha_pred = 255 * alpha_pred 74 | alpha_pred = alpha_pred.astype(np.uint8) 75 | 76 | image_name = image_path.split('/')[-1] 77 | output_path = os.path.join(output_masks_dir, image_name) 78 | cv2.imwrite(output_path, composite(image_fg, alpha_pred)) 79 | output_path = os.path.join(output_masks_dir, 80 | image_name.replace('.png', '_alpha.png')) 81 | cv2.imwrite(output_path, alpha) 82 | output_path = os.path.join(output_masks_dir, 83 | image_name.replace('.png', '_matte.png')) 84 | cv2.imwrite(output_path, alpha_pred) 85 | output_path = os.path.join(output_masks_dir, 86 | image_name.replace('.png', '_noise.png')) 87 | cv2.imwrite(output_path, alpha_noise) 88 | -------------------------------------------------------------------------------- /predict_trimap.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sun Nov 10 23:01:35 2019 4 | 5 | @author: john 6 | """ 7 | 8 | import cv2 9 | import glob 10 | import numpy as np 11 | import os 12 | import torch 13 | import torchvision as tv 14 | 15 | import model 16 | 17 | # Device configuration 18 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 19 | 20 | 21 | def matte(image, alpha): 22 | alpha_exp = np.expand_dims(alpha, axis=2) 23 | image_ = np.concatenate([image, alpha_exp], axis=2) 24 | return image_ 25 | 26 | 27 | def compose(fg, bg, alpha): 28 | if fg is None or bg is None or alpha is None: 29 | return None 30 | 31 | height, width, _ = fg.shape 32 | height_bg, width_bg, _ = bg.shape 33 | alpha_exp = 
np.expand_dims(alpha, axis=2) / 255. 34 | if min(height_bg, width_bg) >= max(height, width): 35 | bg_resized = bg[:height, :width] 36 | else: 37 | bg_resized = cv2.resize(bg, (width, height)) 38 | image = alpha_exp * fg + (1 - alpha_exp) * bg_resized 39 | return image.astype(np.uint8) 40 | 41 | 42 | if __name__ == '__main__': 43 | ckpt_path = './models/model.ckpt' 44 | test_fg_dir = './data_dim/test/fg' 45 | test_bg_dir = './data_dim/test/bg' 46 | test_alpha_dir = './data_dim/test/alpha' 47 | test_trimap_dir = './data_dim/test/trimap' 48 | output_dir = './data_dim/test/preds' 49 | test_fg_paths = glob.glob(os.path.join(test_fg_dir, '*.*')) 50 | test_bg_paths = glob.glob(os.path.join(test_bg_dir, '*.*')) 51 | 52 | if not os.path.exists(ckpt_path): 53 | raise ValueError('`ckpt_path` does not exist.') 54 | if not os.path.exists(output_dir): 55 | os.makedirs(output_dir) 56 | 57 | feature_extractor = model.vgg16_bn_feature_extractor( 58 | model.VGG16_BN_CONFIGS.get('13conv'), pretrained=False).to(device) 59 | dim = model.DIM(feature_extractor).to(device) 60 | #dim.load_state_dict(torch.load(ckpt_path)) 61 | dim_pretrained_params = torch.load(ckpt_path).items() 62 | dim_state_dict = {k.replace('module.', ''): v for k, v in 63 | dim_pretrained_params} 64 | dim.load_state_dict(dim_state_dict) 65 | print('Load DIM pretrained parameters, Done') 66 | 67 | # Transform 68 | channel_means = [0.485, 0.456, 0.406] 69 | channel_std = [0.229, 0.224, 0.225] 70 | transforms = tv.transforms.Compose([ 71 | tv.transforms.ToTensor(), 72 | tv.transforms.Normalize(mean=channel_means, std=channel_std)]) 73 | 74 | dim.eval() 75 | with torch.no_grad(): 76 | for fg_path, bg_path in zip(test_fg_paths, test_bg_paths): 77 | fg = cv2.imread(fg_path) 78 | bg = cv2.imread(bg_path) 79 | image_name = fg_path.replace('\\', '/').split('/')[-1] 80 | alpha_path = os.path.join(test_alpha_dir, image_name) 81 | alpha = cv2.imread(alpha_path, 0) 82 | trimap_name = image_name.replace('.png', '_0.png') 83 | trimap_path = os.path.join(test_trimap_dir, trimap_name) 84 | trimap = cv2.imread(trimap_path, 0) 85 | image = compose(fg, bg, alpha) 86 | image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 87 | image_size = image_rgb.shape[:2] 88 | 89 | image_processed = transforms(image).to(device) 90 | trimap_exp = np.expand_dims(trimap / 255., axis=0) 91 | trimap_exp = torch.Tensor(trimap_exp).to(device) 92 | images = torch.cat([image_processed, trimap_exp], dim=0) 93 | images = torch.unsqueeze(images, dim=0) 94 | 95 | alphas_pred = dim(images) 96 | 97 | alpha_pred_ = alphas_pred.data.cpu().numpy()[0][0] 98 | alpha_pred_ = 255 * alpha_pred_ 99 | alpha_pred = alpha_pred_.astype(np.uint8) 100 | alpha_pred = np.where(trimap == 128, alpha_pred, trimap) 101 | 102 | output_path = os.path.join(output_dir, image_name) 103 | cv2.imwrite(output_path, image) 104 | output_path = os.path.join(output_dir, 105 | image_name.replace('.png', '_matte.png')) 106 | cv2.imwrite(output_path, matte(image, alpha_pred)) 107 | output_path = os.path.join(output_dir, 108 | image_name.replace('.png', '_alpha.png')) 109 | cv2.imwrite(output_path, alpha_pred) 110 | -------------------------------------------------------------------------------- /test/1803151818-00000003.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shirhe-Lyh/deep_image_matting_pytorch/643e52975f153919eaed62de6dd87d31a65d6a0e/test/1803151818-00000003.jpg -------------------------------------------------------------------------------- 
/test/1803151818-00000003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shirhe-Lyh/deep_image_matting_pytorch/643e52975f153919eaed62de6dd87d31a65d6a0e/test/1803151818-00000003.png -------------------------------------------------------------------------------- /test/1803151818-00000004.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shirhe-Lyh/deep_image_matting_pytorch/643e52975f153919eaed62de6dd87d31a65d6a0e/test/1803151818-00000004.jpg -------------------------------------------------------------------------------- /test/1803151818-00000004.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shirhe-Lyh/deep_image_matting_pytorch/643e52975f153919eaed62de6dd87d31a65d6a0e/test/1803151818-00000004.png -------------------------------------------------------------------------------- /test/1803250719-00000103.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shirhe-Lyh/deep_image_matting_pytorch/643e52975f153919eaed62de6dd87d31a65d6a0e/test/1803250719-00000103.jpg -------------------------------------------------------------------------------- /test/1803250719-00000103.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shirhe-Lyh/deep_image_matting_pytorch/643e52975f153919eaed62de6dd87d31a65d6a0e/test/1803250719-00000103.png -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Mon Jun 17 19:19:14 2019 4 | 5 | @author: shirhe-lyh 6 | """ 7 | 8 | import argparse 9 | import json 10 | import os 11 | import torch 12 | 13 | from torch.utils.tensorboard import SummaryWriter 14 | 15 | import dataset 16 | import model 17 | 18 | parser = argparse.ArgumentParser(description='Train deep image matting model.') 19 | 20 | parser.add_argument('--gpu_indices', default=[0, 1, 2, 3], type=int, nargs='+', 21 | help='The indices of gpus to be used.') 22 | parser.add_argument('--num_epochs', default=300, type=int, 23 | help='Number of epochs') 24 | parser.add_argument('--batch_size_per_gpu', default=16, type=int, 25 | help='Batch size of one gpu.') 26 | parser.add_argument('--learning_rate', default=1e-4, type=float, 27 | help='Initial learning rate.') 28 | parser.add_argument('--end_learning_rate', default=1e-6, type=float, 29 | help='End learning rate.') 30 | parser.add_argument('--decay_epochs', default=20, type=int, 31 | help='Decay learning rate every decay_step.') 32 | parser.add_argument('--lr_decay_factor', default=0.9, type=float, 33 | help='Learning rate decay factor.') 34 | parser.add_argument('--annotation_path', default='./data/train.txt', type=str, 35 | help='Path to the annotation file.') 36 | parser.add_argument('--root_dir', default=None, type=str, 37 | help='Path to the images folder: xxx/Matting_Human_Half.') 38 | parser.add_argument('--model_dir', default='./models', type=str, 39 | help='Where the trained model file is stored.') 40 | 41 | FLAGS = parser.parse_args() 42 | 43 | 44 | def config_learning_rate(optimizer, decay=0.9): 45 | lr = FLAGS.learning_rate * decay 46 | if lr < FLAGS.end_learning_rate: 47 | return FLAGS.end_learning_rate 48 | 49 | for param_group in 
optimizer.param_groups: 50 | param_group['lr'] = lr 51 | return lr 52 | 53 | 54 | def train(): 55 | gpu_indices = FLAGS.gpu_indices 56 | num_epochs = FLAGS.num_epochs 57 | learning_rate = FLAGS.learning_rate 58 | lr_decay = FLAGS.lr_decay_factor 59 | batch_size = FLAGS.batch_size_per_gpu * len(gpu_indices) 60 | num_steps_to_save_checkpoint = 128000 // batch_size 61 | annotation_path = FLAGS.annotation_path 62 | root_dir = FLAGS.root_dir 63 | model_dir = FLAGS.model_dir 64 | 65 | gpu_ids_str = ','.join([str(index) for index in gpu_indices]) 66 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu_ids_str 67 | 68 | # Device configuration 69 | cuda_ = 'cuda:{}'.format(gpu_indices[0]) 70 | device = torch.device(cuda_ if torch.cuda.is_available() else 'cpu') 71 | 72 | matting_dataset = dataset.MattingDataset(annotation_path=annotation_path, 73 | root_dir=root_dir) 74 | train_loader = torch.utils.data.DataLoader(matting_dataset, 75 | batch_size=batch_size, 76 | num_workers=32, 77 | shuffle=True, 78 | drop_last=True) 79 | 80 | feature_extractor = model.vgg16_bn_feature_extractor( 81 | model.VGG16_BN_CONFIGS.get('13conv')).to(device) 82 | dim = model.DIM(feature_extractor).to(device) 83 | 84 | # Load pretrained parameters 85 | start_epoch, start_step = 0, 0 86 | last_dim_checkpoint_path = None 87 | json_path = os.path.join(model_dir, 'checkpoint.json') 88 | if not os.path.exists(model_dir): 89 | os.makedirs(model_dir) 90 | else: 91 | if os.path.exists(json_path): 92 | with open(json_path, 'r') as reader: 93 | ckpt_dict = json.load(reader) 94 | start_epoch = ckpt_dict.get('epoch', 0) + 1 95 | start_step = ckpt_dict.get('step', 0) + 1 96 | dim_name = 'model-{}-{}.ckpt'.format(start_epoch, start_step) 97 | if os.path.exists(os.path.join(model_dir, dim_name)): 98 | last_dim_checkpoint_path = os.path.join(model_dir, dim_name) 99 | if os.path.exists(os.path.join(model_dir, 'model.ckpt')): 100 | last_dim_checkpoint_path = os.path.join(model_dir, 'model.ckpt') 101 | if last_dim_checkpoint_path and os.path.exists(last_dim_checkpoint_path): 102 | #dim.load_state_dict(torch.load(last_dim_checkpoint_path)) 103 | dim_pretrained_params = torch.load(last_dim_checkpoint_path).items() 104 | dim_state_dict = {k.replace('module.', ''): v for k, v in 105 | dim_pretrained_params} 106 | dim.load_state_dict(dim_state_dict) 107 | print('Load DIM pretrained parameters, Done') 108 | 109 | # Multiple GPUs 110 | dim = torch.nn.DataParallel(dim, device_ids=gpu_indices) 111 | 112 | optimizer = torch.optim.Adam(dim.parameters(), lr=learning_rate) 113 | 114 | # Tensorboard 115 | log_dir = os.path.join(model_dir, 'logs') 116 | log = SummaryWriter(log_dir=log_dir) 117 | 118 | total_step = len(train_loader) 119 | for epoch in range(start_epoch, num_epochs): 120 | for i, (images, alphas, alphas_noise, masks) in enumerate(train_loader): 121 | images = images.to(device) 122 | alphas = alphas.to(device) 123 | alphas_noise = alphas_noise.to(device) 124 | masks = masks.to(device) 125 | 126 | # Forward pass 127 | outputs = dim(images) 128 | loss = model.loss(outputs, alphas, masks=masks) 129 | 130 | # Backward and optimize 131 | optimizer.zero_grad() 132 | loss.backward() 133 | optimizer.step() 134 | 135 | # Log 136 | step = i + epoch * total_step 137 | if (i+1) % 50 == 0: 138 | # print('Epoch {}/{}, Step: {}/{}, Loss: {:.4f}'.format( 139 | # epoch+1, num_epochs, i+1, total_step, loss.item())) 140 | 141 | # Log scalar values 142 | log.add_scalar('loss', loss.item(), step+1) 143 | 144 | # Log training images 145 | info = {'alphas': 146 | 
alphas.cpu().numpy()[:2], 147 | 'alphas_noise': 148 | alphas_noise.cpu().numpy()[:2], 149 | 'alphas_pred': 150 | outputs.data.cpu().numpy()[:2]} 151 | for tag, imgs in info.items(): 152 | log.add_images(tag, imgs, step+1, dataformats='NCHW') 153 | 154 | # Save model 155 | if (step + 1) % num_steps_to_save_checkpoint == 0: 156 | print('Save Model: Epoch {}/{}, Step: {}/{}'.format( 157 | epoch+1, num_epochs, i+1, total_step)) 158 | model_name = 'model-{}-{}.ckpt'.format(epoch+1, i+1) 159 | model_path = os.path.join(model_dir, model_name) 160 | torch.save(dim.state_dict(), model_path) 161 | ckpt_dict = {'epoch': epoch, 'step': i, 'global_step': step} 162 | with open (json_path, 'w') as writer: 163 | json.dump(ckpt_dict, writer) 164 | 165 | # Decay learning rate 166 | if epoch % FLAGS.decay_epochs == 0: 167 | num_decays = epoch // FLAGS.decay_epochs 168 | lr = config_learning_rate(optimizer, decay=lr_decay ** num_decays) 169 | log.add_scalar('learning_rate', lr, step+1) 170 | log.close() 171 | 172 | # Final save 173 | model_path = os.path.join(model_dir, 'model.ckpt') 174 | torch.save(dim.state_dict(), model_path) 175 | ckpt_dict = {'epoch': num_epochs-1, 'step': total_step-1, 176 | 'global_step': num_epochs * total_step - 1} 177 | with open (json_path, 'w') as writer: 178 | json.dump(ckpt_dict, writer) 179 | 180 | 181 | if __name__ == '__main__': 182 | train() 183 | 184 | --------------------------------------------------------------------------------
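For reference, the decay logic in config_learning_rate and the training loop above amounts to a simple step schedule. The sketch below is illustrative only; the default values mirror the argparse flags (learning_rate=1e-4, lr_decay_factor=0.9, decay_epochs=20, end_learning_rate=1e-6).

# Effective learning rate implied by train.py's schedule: every decay_epochs
# epochs the base rate is scaled by lr_decay_factor ** (epoch // decay_epochs),
# floored at end_learning_rate.
def scheduled_lr(epoch, base_lr=1e-4, decay_factor=0.9, decay_epochs=20,
                 end_lr=1e-6):
    lr = base_lr * decay_factor ** (epoch // decay_epochs)
    return max(lr, end_lr)

for epoch in (0, 20, 100, 280):
    print(epoch, scheduled_lr(epoch))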
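The trimaps under data_dim/test/trimap encode certain background as 0, certain foreground as 255, and the unknown band to be solved as 128, which is exactly the region predict_trimap.py keeps from the network output (np.where(trimap == 128, alpha_pred, trimap)). The helper below is not part of the repository; it is a minimal sketch of how such a trimap can be derived from a ground-truth alpha matte by erosion/dilation, mirroring the 'constant' mode of dataset.random_dilate. The kernel size and iteration count are arbitrary illustrative choices.

import cv2
import numpy as np

def make_trimap(alpha, ksize=5, iterations=10):
    """Turn a uint8 alpha matte into a 0/128/255 trimap (illustrative helper)."""
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (ksize, ksize))
    eroded = cv2.erode(alpha, kernel, iterations=iterations)
    dilated = cv2.dilate(alpha, kernel, iterations=iterations)
    trimap = np.full_like(alpha, 128)    # start everywhere as "unknown"
    trimap[eroded >= 255] = 255          # certainly foreground
    trimap[dilated <= 0] = 0             # certainly background
    return trimap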