├── .gitignore ├── 1703.03872.pdf ├── Combined_Dataset ├── Test_set │ └── Composition_code_revised.py └── Training_set │ └── Composition_code_revised.py ├── LICENSE ├── README.md ├── clean.sh ├── config.py ├── custom_layers ├── __init__.py ├── scale_layer.py └── unpooling_layer.py ├── data_generator.py ├── demo.py ├── encoder_decoder.svg ├── history ├── 2018-05-17 07-52-29.png └── 2018-05-21 21-44-41.png ├── images ├── 0_alpha.png ├── 0_compose.png ├── 0_image.png ├── 0_new_bg.png ├── 0_out.png ├── 0_trimap.png ├── 1_alpha.png ├── 1_compose.png ├── 1_image.png ├── 1_new_bg.png ├── 1_out.png ├── 1_trimap.png ├── 2_alpha.png ├── 2_compose.png ├── 2_image.png ├── 2_new_bg.png ├── 2_out.png ├── 2_trimap.png ├── 3_alpha.png ├── 3_compose.png ├── 3_image.png ├── 3_new_bg.png ├── 3_out.png ├── 3_trimap.png ├── 4_alpha.png ├── 4_compose.png ├── 4_image.png ├── 4_new_bg.png ├── 4_out.png ├── 4_trimap.png ├── 5_alpha.png ├── 5_compose.png ├── 5_image.png ├── 5_new_bg.png ├── 5_out.png ├── 5_trimap.png ├── 6_alpha.png ├── 6_compose.png ├── 6_image.png ├── 6_new_bg.png ├── 6_out.png ├── 6_trimap.png ├── 7_alpha.png ├── 7_compose.png ├── 7_image.png ├── 7_new_bg.png ├── 7_out.png ├── 7_trimap.png ├── 8_alpha.png ├── 8_compose.png ├── 8_image.png ├── 8_new_bg.png ├── 8_out.png ├── 8_trimap.png ├── 9_alpha.png ├── 9_compose.png ├── 9_image.png ├── 9_new_bg.png ├── 9_out.png └── 9_trimap.png ├── migrate.py ├── model.py ├── model.svg ├── parallel_model.svg ├── plot_model.py ├── pre_process.py ├── predit_single.py ├── refinement.svg ├── segnet.py ├── test.py ├── test_alphamatting.py ├── train.py ├── train_encoder_decoder.py ├── train_final.py ├── train_names.txt ├── train_refinement.py ├── unit_tests.py ├── utils.py ├── valid_names.txt └── vgg16.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | Adobe_Deep_Matting_Dataset.zip 3 | __pycache__/ 4 | custom_layers/__pycache__/ 5 | models/ 6 | train2014/ 7 | logs/ 8 | bg/ 9 | 
fg/ 10 | mask/ 11 | merged/ 12 | bg_test/ 13 | fg_test/ 14 | mask_test/ 15 | merged_test/ 16 | temp/ 17 | .cache/ 18 | VOC2008test.tar 19 | VOCdevkit/ 20 | VOCtest_06-Nov-2007.tar 21 | VOCtrainval_06-Nov-2007.tar 22 | VOCtrainval_14-Jul-2008.tar 23 | train2014.zip 24 | Deep-Image-Matting.7z 25 | matting_evaluation.zip 26 | training.txt 27 | Combined_Dataset/Adobe Deep Image Mattng Dataset License Agreement.pdf 28 | Combined_Dataset/README.txt 29 | Combined_Dataset/Test_set/Adobe-licensed images/ 30 | Combined_Dataset/Test_set/Composition_code.py 31 | Combined_Dataset/Test_set/test_bg_names.txt 32 | Combined_Dataset/Test_set/test_fg_names.txt 33 | Combined_Dataset/Training_set/Adobe-licensed images/ 34 | Combined_Dataset/Training_set/Composition_code.py 35 | Combined_Dataset/Training_set/Other/ 36 | Combined_Dataset/Training_set/training_bg_names.txt 37 | Combined_Dataset/Training_set/training_fg_names.txt -------------------------------------------------------------------------------- /1703.03872.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/1703.03872.pdf -------------------------------------------------------------------------------- /Combined_Dataset/Test_set/Composition_code_revised.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | import cv2 as cv 4 | import numpy as np 5 | from tqdm import tqdm 6 | 7 | ############################################################## 8 | # Set your paths here 9 | 10 | # path to provided foreground images 11 | fg_path = 'data/fg_test/' 12 | 13 | # path to provided alpha mattes 14 | a_path = 'data/mask_test/' 15 | 16 | # Path to background images (MSCOCO) 17 | bg_path = 'data/bg_test/' 18 | 19 | # Path to folder where you want the composited images to go 20 | out_path = 'data/merged_test/' 21 | 22 | 23 | 
##############################################################

def composite4(fg, bg, a, w, h):
    """Alpha-blend a foreground over the top-left ``h`` x ``w`` part of a background.

    Args:
        fg: Foreground image; must broadcast against an ``(h, w, 1)`` alpha.
        bg: Background image with at least ``h`` rows and ``w`` columns.
        a: Alpha matte of shape ``(h, w)`` on a 0..255 scale.
        w: Width of the composite in pixels.
        h: Height of the composite in pixels.

    Returns:
        ``alpha * fg + (1 - alpha) * bg`` converted to ``uint8``
        (fractional parts are truncated by the cast, not rounded).
    """
    fg = np.array(fg, np.float32)
    bg = np.array(bg[0:h, 0:w], np.float32)
    alpha = np.zeros((h, w, 1), np.float32)
    alpha[:, :, 0] = a / 255.
    comp = alpha * fg + (1 - alpha) * bg
    return comp.astype(np.uint8)


def _read_image(path, flags=None):
    """Read an image with OpenCV, raising instead of returning ``None``."""
    img = cv.imread(path) if flags is None else cv.imread(path, flags)
    if img is None:
        # cv.imread signals a missing/unreadable file by returning None, which
        # would otherwise only surface later as an AttributeError on `.shape`.
        raise FileNotFoundError('cannot read image: {}'.format(path))
    return img


def process(im_name, bg_name, fcount, bcount):
    """Composite one foreground over one background and write the result.

    Args:
        im_name: File name shared by the foreground (under ``fg_path``) and
            its alpha matte (under ``a_path``).
        bg_name: File name of the background under ``bg_path``.
        fcount: Foreground index, used as the first part of the output name.
        bcount: Background index, used as the second part of the output name.

    Raises:
        FileNotFoundError: If any of the three input images cannot be read.
        IOError: If OpenCV reports that the composite could not be written.
    """
    im = _read_image(fg_path + im_name)
    a = _read_image(a_path + im_name, 0)  # flag 0: load as single-channel grayscale
    h, w = im.shape[:2]
    bg = _read_image(bg_path + bg_name)
    bh, bw = bg.shape[:2]

    # Upscale the background only when it is smaller than the foreground in
    # at least one dimension, so the top-left h x w crop always exists.
    ratio = max(w / bw, h / bh)
    if ratio > 1:
        bg = cv.resize(src=bg, dsize=(math.ceil(bw * ratio), math.ceil(bh * ratio)), interpolation=cv.INTER_CUBIC)

    out = composite4(im, bg, a, w, h)
    filename = out_path + str(fcount) + '_' + str(bcount) + '.png'
    if not cv.imwrite(filename, out):
        # cv.imwrite returns False (e.g. missing output folder) rather than raising.
        raise IOError('cannot write image: {}'.format(filename))


def do_composite_test():
    """Composite every test foreground over ``num_bgs`` consecutive backgrounds.

    Foreground and background names are read from the two name lists under
    ``data/Combined_Dataset/Test_set``. Background ``bcount`` is a running
    index, so foreground ``k`` uses backgrounds ``k * num_bgs`` through
    ``(k + 1) * num_bgs - 1``. Prints the elapsed wall-clock time when done.

    Raises:
        IndexError: If the background list is too short for all foregrounds.
    """
    num_bgs = 20

    with open('data/Combined_Dataset/Test_set/test_bg_names.txt') as f:
        bg_files = f.read().splitlines()
    with open('data/Combined_Dataset/Test_set/test_fg_names.txt') as f:
        fg_files = f.read().splitlines()

    # Fail before doing any work; the unchecked lookup would raise the same
    # exception type, but only after part of the set had been written.
    needed = len(fg_files) * num_bgs
    if len(bg_files) < needed:
        raise IndexError('need {} background names, found {}'.format(needed, len(bg_files)))

    start = time.time()
    bcount = 0
    for fcount, im_name in enumerate(tqdm(fg_files)):
        for _ in range(num_bgs):
            process(im_name, bg_files[bcount], fcount, bcount)
            bcount += 1

    elapsed = time.time() - start
    print('elapsed: {} seconds'.format(elapsed))
/Combined_Dataset/Training_set/Composition_code_revised.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | import cv2 as cv 4 | import numpy as np 5 | from tqdm import tqdm 6 | 7 | ############################################################## 8 | # Set your paths here 9 | 10 | # path to provided foreground images 11 | fg_path = 'data/fg/' 12 | 13 | # path to provided alpha mattes 14 | a_path = 'data/mask/' 15 | 16 | # Path to background images (MSCOCO) 17 | bg_path = 'data/bg/' 18 | 19 | # Path to folder where you want the composited images to go 20 | out_path = 'data/merged/' 21 | 22 | 23 | ############################################################## 24 | 25 | def composite4(fg, bg, a, w, h): 26 | fg = np.array(fg, np.float32) 27 | bg = np.array(bg[0:h, 0:w], np.float32) 28 | alpha = np.zeros((h, w, 1), np.float32) 29 | alpha[:, :, 0] = a / 255. 30 | comp = alpha * fg + (1 - alpha) * bg 31 | comp = comp.astype(np.uint8) 32 | return comp 33 | 34 | 35 | def process(im_name, bg_name, fcount, bcount): 36 | im = cv.imread(fg_path + im_name) 37 | a = cv.imread(a_path + im_name, 0) 38 | h, w = im.shape[:2] 39 | bg = cv.imread(bg_path + bg_name) 40 | bh, bw = bg.shape[:2] 41 | wratio = w / bw 42 | hratio = h / bh 43 | ratio = wratio if wratio > hratio else hratio 44 | if ratio > 1: 45 | bg = cv.resize(src=bg, dsize=(math.ceil(bw * ratio), math.ceil(bh * ratio)), interpolation=cv.INTER_CUBIC) 46 | 47 | out = composite4(im, bg, a, w, h) 48 | filename = out_path + str(fcount) + '_' + str(bcount) + '.png' 49 | cv.imwrite(filename, out) 50 | 51 | 52 | def do_composite(): 53 | num_bgs = 100 54 | 55 | with open('data/Combined_Dataset/Training_set/training_bg_names.txt') as f: 56 | bg_files = f.read().splitlines() 57 | with open('data/Combined_Dataset/Training_set/training_fg_names.txt') as f: 58 | fg_files = f.read().splitlines() 59 | 60 | # a_files = os.listdir(a_path) 61 | num_samples = len(fg_files) * 
num_bgs 62 | 63 | # pb = ProgressBar(total=100, prefix='Compose train images', suffix='', decimals=3, length=50, fill='=') 64 | start = time.time() 65 | bcount = 0 66 | for fcount in tqdm(range(len(fg_files))): 67 | im_name = fg_files[fcount] 68 | 69 | for i in range(num_bgs): 70 | bg_name = bg_files[bcount] 71 | process(im_name, bg_name, fcount, bcount) 72 | bcount += 1 73 | 74 | # pb.print_progress_bar(bcount * 100.0 / num_samples) 75 | 76 | end = time.time() 77 | elapsed = end - start 78 | print('elapsed: {} seconds'.format(elapsed)) 79 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Yang Liu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Just in case you are interested, [Deep Image Matting v2](https://github.com/foamliu/Deep-Image-Matting-v2) is an upgraded version of this. 2 | 3 | # Deep Image Matting 4 | This repository is a reproduction of Deep Image Matting. 5 | 6 | ## Dependencies 7 | - [NumPy](http://docs.scipy.org/doc/numpy-1.10.1/user/install.html) 8 | - [TensorFlow 1.9.0](https://www.tensorflow.org/) 9 | - [Keras 2.1.6](https://keras.io/#installation) 10 | - [OpenCV](https://opencv-python-tutroals.readthedocs.io/en/latest/) 11 | 12 | ## Dataset 13 | ### Adobe Deep Image Matting Dataset 14 | Follow the [instructions](https://sites.google.com/view/deepimagematting) to contact the authors and request the dataset. 15 | 16 | ### MSCOCO 17 | Go to [MSCOCO](http://cocodataset.org/#download) to download: 18 | * [2014 Train images](http://images.cocodataset.org/zips/train2014.zip) 19 | 20 | 21 | ### PASCAL VOC 22 | Go to [PASCAL VOC](http://host.robots.ox.ac.uk/pascal/VOC/) to download: 23 | * VOC challenge 2008 [training/validation data](http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar) 24 | * The test data for the [VOC2008 challenge](http://host.robots.ox.ac.uk/pascal/VOC/voc2008/index.html#testdata) 25 | 26 | ## ImageNet Pretrained Models 27 | Download [VGG16](https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_tf_dim_ordering_tf_kernels.h5) into the "models" folder. 
28 | 29 | 30 | ## Usage 31 | ### Data Pre-processing 32 | Extract training images: 33 | ```bash 34 | $ python pre_process.py 35 | ``` 36 | 37 | ### Train 38 | ```bash 39 | $ python train.py 40 | ``` 41 | 42 | If you want to visualize during training, run in your terminal: 43 | ```bash 44 | $ tensorboard --logdir path_to_current_dir/logs 45 | ``` 46 | 47 | ### Demo 48 | Download pre-trained Deep Image Matting [Model](https://github.com/foamliu/Deep-Image-Matting/releases/download/v1.0/final.42-0.0398.hdf5) to "models" folder then run: 49 | ```bash 50 | $ python demo.py 51 | ``` 52 | 53 | Image/Trimap | Output/GT | New BG/Compose | 54 | |---|---|---| 55 | |![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/0_image.png) | ![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/0_out.png) | ![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/0_new_bg.png) | 56 | |![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/0_trimap.png) | ![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/0_alpha.png) | ![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/0_compose.png)| 57 | |![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/1_image.png) | ![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/1_out.png) | ![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/1_new_bg.png) | 58 | |![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/1_trimap.png) | ![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/1_alpha.png) | ![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/1_compose.png)| 59 | |![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/2_image.png) | ![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/2_out.png) | 
![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/2_new_bg.png) | 60 | |![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/2_trimap.png) | ![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/2_alpha.png) | ![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/2_compose.png)| 61 | |![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/3_image.png) | ![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/3_out.png) | ![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/3_new_bg.png) | 62 | |![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/3_trimap.png) | ![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/3_alpha.png) | ![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/3_compose.png)| 63 | |![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/4_image.png) | ![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/4_out.png) | ![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/4_new_bg.png) | 64 | |![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/4_trimap.png) | ![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/4_alpha.png) | ![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/4_compose.png)| 65 | |![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/5_image.png) | ![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/5_out.png) | ![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/5_new_bg.png) | 66 | |![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/5_trimap.png) | ![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/5_alpha.png) | ![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/5_compose.png)| 67 | 
|![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/6_image.png) | ![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/6_out.png) | ![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/6_new_bg.png) | 68 | |![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/6_trimap.png) | ![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/6_alpha.png) | ![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/6_compose.png)| 69 | |![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/7_image.png) | ![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/7_out.png) | ![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/7_new_bg.png) | 70 | |![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/7_trimap.png) | ![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/7_alpha.png) | ![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/7_compose.png)| 71 | |![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/8_image.png) | ![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/8_out.png) | ![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/8_new_bg.png) | 72 | |![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/8_trimap.png) | ![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/8_alpha.png) | ![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/8_compose.png)| 73 | |![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/9_image.png) | ![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/9_out.png) | ![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/9_new_bg.png) | 74 | |![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/9_trimap.png) | 
![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/9_alpha.png) | ![image](https://github.com/foamliu/Deep-Image-Matting/raw/master/images/9_compose.png)| 75 | 76 | -------------------------------------------------------------------------------- /clean.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | rm models/model.* 3 | rm logs -r 4 | rm training.txt -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | img_rows, img_cols = 320, 320 2 | # img_rows_half, img_cols_half = 160, 160 3 | channel = 4 4 | batch_size = 16 5 | epochs = 1000 6 | patience = 50 7 | num_samples = 43100 8 | num_train_samples = 34480 9 | # num_samples - num_train_samples 10 | num_valid_samples = 8620 11 | unknown_code = 128 12 | epsilon = 1e-6 13 | epsilon_sqr = epsilon ** 2 14 | 15 | ############################################################## 16 | # Set your paths here 17 | 18 | # path to provided foreground images 19 | fg_path = 'data/fg/' 20 | 21 | # path to provided alpha mattes 22 | a_path = 'data/mask/' 23 | 24 | # Path to background images (MSCOCO) 25 | bg_path = 'data/bg/' 26 | 27 | # Path to folder where you want the composited images to go 28 | out_path = 'data/merged/' 29 | 30 | ############################################################## 31 | -------------------------------------------------------------------------------- /custom_layers/__init__.py: -------------------------------------------------------------------------------- 1 | # Python Package 2 | -------------------------------------------------------------------------------- /custom_layers/scale_layer.py: -------------------------------------------------------------------------------- 1 | from keras.layers.core import Layer 2 | from keras.engine import InputSpec 3 | from keras import backend as K 4 | try: 5 | from 
try:
    from keras import initializations
except ImportError:
    # Fallback for Keras installs that expose the module as `initializers`.
    from keras import initializers as initializations


class Scale(Layer):
    '''Learns a set of weights and biases used for scaling the input data.

    The output is an element-wise multiplication of the input by `gamma`
    plus an element-wise addition of `beta`:

        out = in * gamma + beta,

    where 'gamma' and 'beta' are the weights and biases learned.

    # Arguments
        weights: Initialization weights, passed to `set_weights` at the end
            of `build` when not None.
        axis: integer, axis of the input whose size gives the length of
            `gamma` and `beta`; both are broadcast along every other axis.
        momentum: stored and returned by `get_config`; not used by `call`.
        beta_init: name of the initialization function for the shift
            parameter, resolved through `initializations.get`.
        gamma_init: name of the initialization function for the scale
            parameter, resolved through `initializations.get`.
    '''

    def __init__(self, weights=None, axis=-1, momentum=0.9, beta_init='zero', gamma_init='one', **kwargs):
        self.momentum = momentum
        self.axis = axis
        self.beta_init = initializations.get(beta_init)
        self.gamma_init = initializations.get(gamma_init)
        self.initial_weights = weights
        super(Scale, self).__init__(**kwargs)

    def build(self, input_shape):
        '''Create `gamma` and `beta`, one value per entry along `self.axis`.'''
        self.input_spec = [InputSpec(shape=input_shape)]
        shape = (int(input_shape[self.axis]),)

        # Compatibility with TensorFlow >= 1.0.0
        self.gamma = K.variable(self.gamma_init(shape), name='{}_gamma'.format(self.name))
        self.beta = K.variable(self.beta_init(shape), name='{}_beta'.format(self.name))
        self.trainable_weights = [self.gamma, self.beta]

        if self.initial_weights is not None:
            self.set_weights(self.initial_weights)
            del self.initial_weights

    def call(self, x, mask=None):
        '''Return `gamma * x + beta` with both parameters broadcast along `self.axis`.'''
        input_shape = self.input_spec[0].shape
        # All-ones shape except along the scaled axis, so the 1-D parameters
        # broadcast against the full input.
        broadcast_shape = [1] * len(input_shape)
        broadcast_shape[self.axis] = input_shape[self.axis]

        out = K.reshape(self.gamma, broadcast_shape) * x + K.reshape(self.beta, broadcast_shape)
        return out

    def get_config(self):
        '''Return the base layer config extended with `momentum` and `axis`.'''
        config = {"momentum": self.momentum, "axis": self.axis}
        base_config = super(Scale, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


# --------------------------------------------------------------------------------
# /custom_layers/unpooling_layer.py
# --------------------------------------------------------------------------------
from keras import backend as K
from keras.engine.topology import Layer
from keras.layers import Lambda, Multiply


class Unpooling(Layer):
    """Mask one slice of a stacked input by comparing it with the other slice.

    `call` reads `inputs[:, 0]` and `inputs[:, 1]` and `compute_output_shape`
    drops axis 1, so the input is a rank-5 tensor holding two stacked maps.
    NOTE(review): presumably slice 0 is the original encoder feature map and
    slice 1 its pooled-and-upsampled counterpart - confirm against the model
    that builds this input.
    """

    def __init__(self, **kwargs):
        super(Unpooling, self).__init__(**kwargs)

    def build(self, input_shape):
        super(Unpooling, self).build(input_shape)

    def call(self, inputs, **kwargs):
        """Return `inputs[:, 1]` zeroed wherever `inputs[:, 0] < inputs[:, 1]`."""
        x = inputs[:, 1]
        bool_mask = Lambda(lambda t: K.greater_equal(t[:, 0], t[:, 1]),
                           output_shape=K.int_shape(x)[1:])(inputs)
        # Cast the boolean comparison to float so it can be multiplied in.
        mask = Lambda(lambda t: K.cast(t, dtype='float32'))(bool_mask)
        x = Multiply()([mask, x])
        return x

    def compute_output_shape(self, input_shape):
        """Return the input shape with the stacking axis (axis 1) removed."""
        return input_shape[0], input_shape[2], input_shape[3], input_shape[4]


# --------------------------------------------------------------------------------
# /data_generator.py
# --------------------------------------------------------------------------------
import math
import os
import random
from random import shuffle

import cv2 as cv
import numpy as np
from keras.utils import Sequence

from config import batch_size
from config import fg_path, bg_path, a_path
from config import img_cols, img_rows
from config import unknown_code
from utils import safe_crop

# Structuring element used to dilate the unknown band in generate_trimap().
kernel = cv.getStructuringElement(cv.MORPH_ELLIPSE, (3, 3))
with open('data/Combined_Dataset/Training_set/training_fg_names.txt') as f:
    fg_files = f.read().splitlines()
with open('data/Combined_Dataset/Test_set/test_fg_names.txt') as f:
    fg_test_files = f.read().splitlines()
with open('data/Combined_Dataset/Training_set/training_bg_names.txt') as f:
    bg_files = f.read().splitlines()
with open('data/Combined_Dataset/Test_set/test_bg_names.txt') as f:
    bg_test_files = f.read().splitlines()


def get_alpha(name):
    """Load the training alpha matte for a composite named ``<fg index>_...``.

    The integer before the first underscore indexes ``fg_files``; the matte
    with that file name is read as grayscale from ``data/mask``. The result
    is whatever ``cv.imread`` returns, i.e. ``None`` if the file is unreadable.
    """
    fg_i = int(name.split("_")[0])
    name = fg_files[fg_i]
    filename = os.path.join('data/mask', name)
    alpha = cv.imread(filename, 0)
    return alpha


def get_alpha_test(name):
    """Load the test alpha matte for a composite named ``<fg index>_...``.

    Same lookup as ``get_alpha`` but against ``fg_test_files`` and the
    ``data/mask_test`` folder.
    """
    fg_i = int(name.split("_")[0])
    name = fg_test_files[fg_i]
    filename = os.path.join('data/mask_test', name)
    alpha = cv.imread(filename, 0)
    return alpha


def composite4(fg, bg, a, w, h):
    """Alpha-blend a foreground over a randomly positioned background crop.

    Args:
        fg: Foreground image; must broadcast against an ``(h, w, 1)`` alpha.
        bg: Background image with at least ``h`` rows and ``w`` columns.
        a: Alpha matte of shape ``(h, w)`` on a 0..255 scale.
        w: Width of the composite in pixels.
        h: Height of the composite in pixels.

    Returns:
        Tuple ``(im, a, fg, bg)``: the ``uint8`` composite, the alpha as
        passed in, the foreground as float32 and the float32 background crop.
    """
    fg = np.array(fg, np.float32)
    bg_h, bg_w = bg.shape[:2]
    # randint's upper bound is exclusive; +1 makes the last valid offset
    # (crop flush with the right/bottom edge) reachable.
    x = 0
    if bg_w > w:
        x = np.random.randint(0, bg_w - w + 1)
    y = 0
    if bg_h > h:
        y = np.random.randint(0, bg_h - h + 1)
    bg = np.array(bg[y:y + h, x:x + w], np.float32)
    alpha = np.zeros((h, w, 1), np.float32)
    alpha[:, :, 0] = a / 255.
    im = alpha * fg + (1 - alpha) * bg
    im = im.astype(np.uint8)
    return im, a, fg, bg


def _read_image(path, flags=None):
    """Read an image with OpenCV, raising instead of returning ``None``."""
    img = cv.imread(path) if flags is None else cv.imread(path, flags)
    if img is None:
        # cv.imread reports an unreadable file by returning None, which would
        # otherwise only surface as an AttributeError on `.shape`.
        raise FileNotFoundError('cannot read image: {}'.format(path))
    return img


def process(im_name, bg_name):
    """Composite one training foreground over one background in memory.

    Args:
        im_name: File name shared by the foreground (under ``fg_path``) and
            its alpha matte (under ``a_path``).
        bg_name: File name of the background under ``bg_path``.

    Returns:
        The ``(im, a, fg, bg)`` tuple produced by ``composite4``.

    Raises:
        FileNotFoundError: If any of the three images cannot be read.
    """
    im = _read_image(fg_path + im_name)
    a = _read_image(a_path + im_name, 0)  # flag 0: load as single-channel grayscale
    h, w = im.shape[:2]
    bg = _read_image(bg_path + bg_name)
    bh, bw = bg.shape[:2]

    # Upscale the background only when it is smaller than the foreground in
    # at least one dimension, so an h x w crop always exists.
    ratio = max(w / bw, h / bh)
    if ratio > 1:
        bg = cv.resize(src=bg, dsize=(math.ceil(bw * ratio), math.ceil(bh * ratio)), interpolation=cv.INTER_CUBIC)

    return composite4(im, bg, a, w, h)


def generate_trimap(alpha):
    """Derive a trimap from an alpha matte.

    Pixels with alpha exactly 255 become 255; the remaining pixels inside
    the dilated non-zero-alpha region become 128; everything else is 0. The
    dilation uses the module-level 3x3 elliptical ``kernel`` for a random
    number of iterations drawn from ``np.random.randint(1, 20)``.

    Returns:
        The trimap as a ``uint8`` array with the same shape as ``alpha``.
    """
    fg = np.array(np.equal(alpha, 255).astype(np.float32))
    unknown = np.array(np.not_equal(alpha, 0).astype(np.float32))
    unknown = cv.dilate(unknown, kernel, iterations=np.random.randint(1, 20))
    trimap = fg * 255 + (unknown - fg) * 128
    return trimap.astype(np.uint8)


# Randomly crop (image, trimap) pairs centered on pixels in the unknown regions.
def random_choice(trimap, crop_size=(320, 320)):
    """Pick the top-left corner of a crop centred on a random unknown pixel.

    Args:
        trimap: 2-D array in which unknown pixels equal ``unknown_code``.
        crop_size: ``(height, width)`` of the intended crop.

    Returns:
        ``(x, y)`` of the crop's top-left corner, clamped at 0 on both axes.
        ``(0, 0)`` when the trimap has no unknown pixel.
    """
    crop_height, crop_width = crop_size
    y_indices, x_indices = np.where(trimap == unknown_code)
    num_unknowns = len(y_indices)
    x, y = 0, 0
    if num_unknowns > 0:
        ix = np.random.choice(range(num_unknowns))
        center_x = x_indices[ix]
        center_y = y_indices[ix]
        x = max(0, center_x - int(crop_width / 2))
        y = max(0, center_y - int(crop_height / 2))
    return x, y


class DataGenSequence(Sequence):
    """Keras ``Sequence`` that composites, crops and batches matting samples.

    Sample names of the form ``<fg index>_<bg index>.png`` are read from
    ``<usage>_names.txt`` and reshuffled after every epoch.
    """

    def __init__(self, usage):
        self.usage = usage

        filename = '{}_names.txt'.format(usage)
        with open(filename, 'r') as f:
            self.names = f.read().splitlines()

        np.random.shuffle(self.names)

    def __len__(self):
        """Return the number of batches, counting a final partial batch."""
        return int(np.ceil(len(self.names) / float(batch_size)))

    def __getitem__(self, idx):
        """Build batch ``idx``.

        Returns:
            ``(batch_x, batch_y)`` where ``batch_x[..., 0:3]`` is the image
            scaled to 0..1, ``batch_x[..., 3]`` the trimap scaled to 0..1,
            ``batch_y[..., 0]`` the alpha scaled to 0..1 and
            ``batch_y[..., 1]`` a 0/1 mask of trimap pixels equal to 128.
        """
        start = idx * batch_size
        batch_names = self.names[start:start + batch_size]
        length = len(batch_names)

        batch_x = np.empty((length, img_rows, img_cols, 4), dtype=np.float32)
        batch_y = np.empty((length, img_rows, img_cols, 2), dtype=np.float32)

        # crop size 320:640:480 = 1:1:1
        different_sizes = [(320, 320), (480, 480), (640, 640)]

        for i_batch, name in enumerate(batch_names):
            parts = name.split('.')[0].split('_')
            fcount = int(parts[0])
            bcount = int(parts[1])
            im_name = fg_files[fcount]
            bg_name = bg_files[bcount]
            image, alpha, fg, bg = process(im_name, bg_name)

            crop_size = random.choice(different_sizes)

            # The first trimap only serves to choose where to crop.
            trimap = generate_trimap(alpha)
            x, y = random_choice(trimap, crop_size)
            # NOTE(review): safe_crop lives in utils.py; the assignments below
            # rely on it returning (img_rows, img_cols) arrays - confirm.
            image = safe_crop(image, x, y, crop_size)
            alpha = safe_crop(alpha, x, y, crop_size)

            # The trimap fed to the network is regenerated from the cropped alpha.
            trimap = generate_trimap(alpha)

            # Flip array left to right randomly (prob=1:1)
            if np.random.random_sample() > 0.5:
                image = np.fliplr(image)
                trimap = np.fliplr(trimap)
                alpha = np.fliplr(alpha)

            batch_x[i_batch, :, :, 0:3] = image / 255.
            batch_x[i_batch, :, :, 3] = trimap / 255.

            mask = np.equal(trimap, 128).astype(np.float32)
            batch_y[i_batch, :, :, 0] = alpha / 255.
            batch_y[i_batch, :, :, 1] = mask

        return batch_x, batch_y

    def on_epoch_end(self):
        """Reshuffle the sample order in place."""
        np.random.shuffle(self.names)


def train_gen():
    """Return the sequence backed by ``train_names.txt``."""
    return DataGenSequence('train')


def valid_gen():
    """Return the sequence backed by ``valid_names.txt``."""
    return DataGenSequence('valid')


def shuffle_data():
    """Split all composite names into validation/training lists and save them.

    Names are ``<fg index>_<bg index>.png`` for 431 foregrounds with 100
    consecutive background indices each. ``config.num_valid_samples`` of them
    are sampled for validation; the rest form the training set. Both lists
    are shuffled and written to ``valid_names.txt`` / ``train_names.txt``.
    """
    from config import num_valid_samples

    num_fgs = 431
    num_bgs_per_fg = 100
    names = []
    bcount = 0
    for fcount in range(num_fgs):
        for _ in range(num_bgs_per_fg):
            names.append(str(fcount) + '_' + str(bcount) + '.png')
            bcount += 1

    valid_names = random.sample(names, num_valid_samples)
    # Set membership keeps the split linear; filtering against the list was
    # quadratic in the number of names.
    valid_set = set(valid_names)
    train_names = [n for n in names if n not in valid_set]
    shuffle(valid_names)
    shuffle(train_names)

    with open('valid_names.txt', 'w') as file:
        file.write('\n'.join(valid_names))

    with open('train_names.txt', 'w') as file:
        file.write('\n'.join(train_names))


if __name__ == '__main__':
    # Quick manual check: print the size of one composited image.
    filename = 'merged/357_35748.png'
    bgr_img = cv.imread(filename)
    bg_h, bg_w = bgr_img.shape[:2]
    print(bg_w, bg_h)


# --------------------------------------------------------------------------------
# /demo.py
# --------------------------------------------------------------------------------
import math
import os
import random

import cv2 as cv
import keras.backend as K
import numpy as np

from data_generator import generate_trimap, random_choice, get_alpha_test
from model import build_encoder_decoder, build_refinement
from utils import compute_mse_loss, compute_sad_loss
from utils import get_final_output, safe_crop, draw_str


def composite4(fg, bg, a, w, h):
    """Alpha-blend a foreground over a randomly positioned background crop.

    Args:
        fg: Foreground image; must broadcast against an ``(h, w, 1)`` alpha.
        bg: Background image with at least ``h`` rows and ``w`` columns.
        a: Alpha matte of shape ``(h, w)`` on a 0..255 scale.
        w: Width of the composite in pixels.
        h: Height of the composite in pixels.

    Returns:
        Tuple ``(im, bg)``: the ``uint8`` composite and the float32
        background crop it was blended onto.
    """
    fg = np.array(fg, np.float32)
    bg_h, bg_w = bg.shape[:2]
    # randint's upper bound is exclusive; +1 makes the last valid offset
    # (crop flush with the right/bottom edge) reachable.
    x = 0
    if bg_w > w:
        x = np.random.randint(0, bg_w - w + 1)
    y = 0
    if bg_h > h:
        y = np.random.randint(0, bg_h - h + 1)
    bg = np.array(bg[y:y + h, x:x + w], np.float32)
    alpha = np.zeros((h, w, 1), np.float32)
    alpha[:, :, 0] = a / 255.
    im = alpha * fg + (1 - alpha) * bg
    im = im.astype(np.uint8)
    return im, bg


if __name__ == '__main__':
    img_rows, img_cols = 320, 320
    channel = 4
    num_demo_samples = 10

    pretrained_path = 'models/final.42-0.0398.hdf5'
    encoder_decoder = build_encoder_decoder()
    final = build_refinement(encoder_decoder)
    final.load_weights(pretrained_path)
    # summary() prints by itself; wrapping it in print() only added "None".
    final.summary()

    out_test_path = 'data/merged_test/'
    test_images = [f for f in os.listdir(out_test_path) if
                   os.path.isfile(os.path.join(out_test_path, f)) and f.endswith('.png')]
    # random.sample raises when asked for more items than exist.
    samples = random.sample(test_images, min(num_demo_samples, len(test_images)))

    bg_test = 'data/bg_test/'
    test_bgs = [f for f in os.listdir(bg_test) if
                os.path.isfile(os.path.join(bg_test, f)) and f.endswith('.jpg')]
    sample_bgs = random.sample(test_bgs, min(num_demo_samples, len(test_bgs)))

    # zip stops at the shorter list, so a small background folder cannot
    # cause an index error part-way through the demo.
    for i, (filename, sample_bg) in enumerate(zip(samples, sample_bgs)):
        image_name = filename.split('.')[0]

        print('\nStart processing image: {}'.format(filename))

        bgr_img = cv.imread(os.path.join(out_test_path, filename))
        bg_h, bg_w = bgr_img.shape[:2]
        print('bg_h, bg_w: ' + str((bg_h, bg_w)))

        a = get_alpha_test(image_name)
        a_h, a_w = a.shape[:2]
        print('a_h, a_w: ' + str((a_h, a_w)))

        # Pad the matte with zeros up to the size of the composited image.
        alpha = np.zeros((bg_h, bg_w), np.float32)
        alpha[0:a_h, 0:a_w] = a
        trimap = generate_trimap(alpha)
        different_sizes = [(320, 320), (320, 320), (320, 320), (480, 480), (640, 640)]
        crop_size = random.choice(different_sizes)
        x, y = random_choice(trimap, crop_size)
        print('x, y: ' + str((x, y)))

        # NOTE(review): safe_crop lives in utils.py; the code below relies on
        # it returning (img_rows, img_cols) arrays - confirm.
        bgr_img = safe_crop(bgr_img, x, y, crop_size)
        alpha = safe_crop(alpha, x, y, crop_size)
        trimap = safe_crop(trimap, x, y, crop_size)
        cv.imwrite('images/{}_image.png'.format(i), np.array(bgr_img).astype(np.uint8))
        cv.imwrite('images/{}_trimap.png'.format(i), np.array(trimap).astype(np.uint8))
        cv.imwrite('images/{}_alpha.png'.format(i), np.array(alpha).astype(np.uint8))

        x_test = np.empty((1, img_rows, img_cols, channel), dtype=np.float32)
        x_test[0, :, :, 0:3] = bgr_img / 255.
        x_test[0, :, :, 3] = trimap / 255.

        y_pred = final.predict(x_test)

        y_pred = np.reshape(y_pred, (img_rows, img_cols))
        print(y_pred.shape)
        y_pred = y_pred * 255.0
        y_pred = get_final_output(y_pred, trimap)
        y_pred = y_pred.astype(np.uint8)

        sad_loss = compute_sad_loss(y_pred, alpha, trimap)
        mse_loss = compute_mse_loss(y_pred, alpha, trimap)
        str_msg = 'sad_loss: %.4f, mse_loss: %.4f, crop_size: %s' % (sad_loss, mse_loss, str(crop_size))
        print(str_msg)

        out = y_pred.copy()
        draw_str(out, (10, 20), str_msg)
        cv.imwrite('images/{}_out.png'.format(i), out)

        bg = cv.imread(os.path.join(bg_test, sample_bg))
        bh, bw = bg.shape[:2]
        # Upscale the new background only when it is smaller than the crop.
        ratio = max(img_cols / bw, img_rows / bh)
        if ratio > 1:
            bg = cv.resize(src=bg, dsize=(math.ceil(bw * ratio), math.ceil(bh * ratio)), interpolation=cv.INTER_CUBIC)
        im, bg = composite4(bgr_img, bg, y_pred, img_cols, img_rows)
        cv.imwrite('images/{}_compose.png'.format(i), im)
        cv.imwrite('images/{}_new_bg.png'.format(i), bg)

    K.clear_session()
-------------------------------------------------------------------------------- /history/2018-05-17 07-52-29.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/history/2018-05-17 07-52-29.png -------------------------------------------------------------------------------- /history/2018-05-21 21-44-41.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/history/2018-05-21 21-44-41.png -------------------------------------------------------------------------------- /images/0_alpha.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/0_alpha.png -------------------------------------------------------------------------------- /images/0_compose.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/0_compose.png -------------------------------------------------------------------------------- /images/0_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/0_image.png -------------------------------------------------------------------------------- /images/0_new_bg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/0_new_bg.png -------------------------------------------------------------------------------- /images/0_out.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/0_out.png -------------------------------------------------------------------------------- /images/0_trimap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/0_trimap.png -------------------------------------------------------------------------------- /images/1_alpha.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/1_alpha.png -------------------------------------------------------------------------------- /images/1_compose.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/1_compose.png -------------------------------------------------------------------------------- /images/1_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/1_image.png -------------------------------------------------------------------------------- /images/1_new_bg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/1_new_bg.png -------------------------------------------------------------------------------- /images/1_out.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/1_out.png -------------------------------------------------------------------------------- /images/1_trimap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/1_trimap.png -------------------------------------------------------------------------------- /images/2_alpha.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/2_alpha.png -------------------------------------------------------------------------------- /images/2_compose.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/2_compose.png -------------------------------------------------------------------------------- /images/2_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/2_image.png -------------------------------------------------------------------------------- /images/2_new_bg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/2_new_bg.png -------------------------------------------------------------------------------- /images/2_out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/2_out.png 
-------------------------------------------------------------------------------- /images/2_trimap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/2_trimap.png -------------------------------------------------------------------------------- /images/3_alpha.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/3_alpha.png -------------------------------------------------------------------------------- /images/3_compose.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/3_compose.png -------------------------------------------------------------------------------- /images/3_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/3_image.png -------------------------------------------------------------------------------- /images/3_new_bg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/3_new_bg.png -------------------------------------------------------------------------------- /images/3_out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/3_out.png -------------------------------------------------------------------------------- /images/3_trimap.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/3_trimap.png -------------------------------------------------------------------------------- /images/4_alpha.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/4_alpha.png -------------------------------------------------------------------------------- /images/4_compose.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/4_compose.png -------------------------------------------------------------------------------- /images/4_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/4_image.png -------------------------------------------------------------------------------- /images/4_new_bg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/4_new_bg.png -------------------------------------------------------------------------------- /images/4_out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/4_out.png -------------------------------------------------------------------------------- /images/4_trimap.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/4_trimap.png -------------------------------------------------------------------------------- /images/5_alpha.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/5_alpha.png -------------------------------------------------------------------------------- /images/5_compose.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/5_compose.png -------------------------------------------------------------------------------- /images/5_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/5_image.png -------------------------------------------------------------------------------- /images/5_new_bg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/5_new_bg.png -------------------------------------------------------------------------------- /images/5_out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/5_out.png -------------------------------------------------------------------------------- /images/5_trimap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/5_trimap.png 
-------------------------------------------------------------------------------- /images/6_alpha.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/6_alpha.png -------------------------------------------------------------------------------- /images/6_compose.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/6_compose.png -------------------------------------------------------------------------------- /images/6_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/6_image.png -------------------------------------------------------------------------------- /images/6_new_bg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/6_new_bg.png -------------------------------------------------------------------------------- /images/6_out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/6_out.png -------------------------------------------------------------------------------- /images/6_trimap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/6_trimap.png -------------------------------------------------------------------------------- /images/7_alpha.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/7_alpha.png -------------------------------------------------------------------------------- /images/7_compose.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/7_compose.png -------------------------------------------------------------------------------- /images/7_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/7_image.png -------------------------------------------------------------------------------- /images/7_new_bg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/7_new_bg.png -------------------------------------------------------------------------------- /images/7_out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/7_out.png -------------------------------------------------------------------------------- /images/7_trimap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/7_trimap.png -------------------------------------------------------------------------------- /images/8_alpha.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/8_alpha.png -------------------------------------------------------------------------------- /images/8_compose.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/8_compose.png -------------------------------------------------------------------------------- /images/8_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/8_image.png -------------------------------------------------------------------------------- /images/8_new_bg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/8_new_bg.png -------------------------------------------------------------------------------- /images/8_out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/8_out.png -------------------------------------------------------------------------------- /images/8_trimap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/8_trimap.png -------------------------------------------------------------------------------- /images/9_alpha.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/9_alpha.png 
-------------------------------------------------------------------------------- /images/9_compose.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/9_compose.png -------------------------------------------------------------------------------- /images/9_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/9_image.png -------------------------------------------------------------------------------- /images/9_new_bg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/9_new_bg.png -------------------------------------------------------------------------------- /images/9_out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/9_out.png -------------------------------------------------------------------------------- /images/9_trimap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/9_trimap.png -------------------------------------------------------------------------------- /migrate.py: -------------------------------------------------------------------------------- 1 | import keras.backend as K 2 | import numpy as np 3 | 4 | from config import channel 5 | from model import build_encoder_decoder 6 | from vgg16 import vgg16_model 7 | 8 | 9 | def migrate_model(new_model): 10 | old_model = vgg16_model(224, 224, 3) 11 | # print(old_model.summary()) 12 | old_layers 
= [l for l in old_model.layers] 13 | new_layers = [l for l in new_model.layers] 14 | 15 | old_conv1_1 = old_model.get_layer('conv1_1') 16 | old_weights = old_conv1_1.get_weights()[0] 17 | old_biases = old_conv1_1.get_weights()[1] 18 | new_weights = np.zeros((3, 3, channel, 64), dtype=np.float32) 19 | new_weights[:, :, 0:3, :] = old_weights 20 | new_weights[:, :, 3:channel, :] = 0.0 21 | new_conv1_1 = new_model.get_layer('conv1_1') 22 | new_conv1_1.set_weights([new_weights, old_biases]) 23 | 24 | for i in range(2, 31): 25 | old_layer = old_layers[i] 26 | new_layer = new_layers[i + 1] 27 | new_layer.set_weights(old_layer.get_weights()) 28 | 29 | # flatten = old_model.get_layer('flatten') 30 | # f_dim = flatten.input_shape 31 | # print('f_dim: ' + str(f_dim)) 32 | # old_dense1 = old_model.get_layer('dense1') 33 | # input_shape = old_dense1.input_shape 34 | # output_dim = old_dense1.get_weights()[1].shape[0] 35 | # print('output_dim: ' + str(output_dim)) 36 | # W, b = old_dense1.get_weights() 37 | # shape = (7, 7, 512, output_dim) 38 | # new_W = W.reshape(shape) 39 | # new_conv6 = new_model.get_layer('conv6') 40 | # new_conv6.set_weights([new_W, b]) 41 | 42 | del old_model 43 | 44 | 45 | if __name__ == '__main__': 46 | model = build_encoder_decoder() 47 | migrate_model(model) 48 | print(model.summary()) 49 | model.save_weights('models/model_weights.h5') 50 | 51 | K.clear_session() 52 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import keras.backend as K 2 | import tensorflow as tf 3 | from keras.layers import Input, Conv2D, UpSampling2D, BatchNormalization, ZeroPadding2D, MaxPooling2D, Concatenate, \ 4 | Reshape, Lambda 5 | from keras.models import Model 6 | from keras.utils import multi_gpu_model 7 | from keras.utils import plot_model 8 | 9 | from custom_layers.unpooling_layer import Unpooling 10 | 11 | 12 | def build_encoder_decoder(): 
13 | # Encoder 14 | input_tensor = Input(shape=(320, 320, 4)) 15 | x = ZeroPadding2D((1, 1))(input_tensor) 16 | x = Conv2D(64, (3, 3), activation='relu', name='conv1_1')(x) 17 | x = ZeroPadding2D((1, 1))(x) 18 | x = Conv2D(64, (3, 3), activation='relu', name='conv1_2')(x) 19 | orig_1 = x 20 | x = MaxPooling2D((2, 2), strides=(2, 2))(x) 21 | 22 | x = ZeroPadding2D((1, 1))(x) 23 | x = Conv2D(128, (3, 3), activation='relu', name='conv2_1')(x) 24 | x = ZeroPadding2D((1, 1))(x) 25 | x = Conv2D(128, (3, 3), activation='relu', name='conv2_2')(x) 26 | orig_2 = x 27 | x = MaxPooling2D((2, 2), strides=(2, 2))(x) 28 | 29 | x = ZeroPadding2D((1, 1))(x) 30 | x = Conv2D(256, (3, 3), activation='relu', name='conv3_1')(x) 31 | x = ZeroPadding2D((1, 1))(x) 32 | x = Conv2D(256, (3, 3), activation='relu', name='conv3_2')(x) 33 | x = ZeroPadding2D((1, 1))(x) 34 | x = Conv2D(256, (3, 3), activation='relu', name='conv3_3')(x) 35 | orig_3 = x 36 | x = MaxPooling2D((2, 2), strides=(2, 2))(x) 37 | 38 | x = ZeroPadding2D((1, 1))(x) 39 | x = Conv2D(512, (3, 3), activation='relu', name='conv4_1')(x) 40 | x = ZeroPadding2D((1, 1))(x) 41 | x = Conv2D(512, (3, 3), activation='relu', name='conv4_2')(x) 42 | x = ZeroPadding2D((1, 1))(x) 43 | x = Conv2D(512, (3, 3), activation='relu', name='conv4_3')(x) 44 | orig_4 = x 45 | x = MaxPooling2D((2, 2), strides=(2, 2))(x) 46 | 47 | x = ZeroPadding2D((1, 1))(x) 48 | x = Conv2D(512, (3, 3), activation='relu', name='conv5_1')(x) 49 | x = ZeroPadding2D((1, 1))(x) 50 | x = Conv2D(512, (3, 3), activation='relu', name='conv5_2')(x) 51 | x = ZeroPadding2D((1, 1))(x) 52 | x = Conv2D(512, (3, 3), activation='relu', name='conv5_3')(x) 53 | orig_5 = x 54 | x = MaxPooling2D((2, 2), strides=(2, 2))(x) 55 | 56 | # Decoder 57 | # x = Conv2D(4096, (7, 7), activation='relu', padding='valid', name='conv6')(x) 58 | # x = BatchNormalization()(x) 59 | # x = UpSampling2D(size=(7, 7))(x) 60 | 61 | x = Conv2D(512, (1, 1), activation='relu', padding='same', name='deconv6', 
kernel_initializer='he_normal', 62 | bias_initializer='zeros')(x) 63 | x = BatchNormalization()(x) 64 | x = UpSampling2D(size=(2, 2))(x) 65 | the_shape = K.int_shape(orig_5) 66 | shape = (1, the_shape[1], the_shape[2], the_shape[3]) 67 | origReshaped = Reshape(shape)(orig_5) 68 | # print('origReshaped.shape: ' + str(K.int_shape(origReshaped))) 69 | xReshaped = Reshape(shape)(x) 70 | # print('xReshaped.shape: ' + str(K.int_shape(xReshaped))) 71 | together = Concatenate(axis=1)([origReshaped, xReshaped]) 72 | # print('together.shape: ' + str(K.int_shape(together))) 73 | x = Unpooling()(together) 74 | 75 | x = Conv2D(512, (5, 5), activation='relu', padding='same', name='deconv5', kernel_initializer='he_normal', 76 | bias_initializer='zeros')(x) 77 | x = BatchNormalization()(x) 78 | x = UpSampling2D(size=(2, 2))(x) 79 | the_shape = K.int_shape(orig_4) 80 | shape = (1, the_shape[1], the_shape[2], the_shape[3]) 81 | origReshaped = Reshape(shape)(orig_4) 82 | xReshaped = Reshape(shape)(x) 83 | together = Concatenate(axis=1)([origReshaped, xReshaped]) 84 | x = Unpooling()(together) 85 | 86 | x = Conv2D(256, (5, 5), activation='relu', padding='same', name='deconv4', kernel_initializer='he_normal', 87 | bias_initializer='zeros')(x) 88 | x = BatchNormalization()(x) 89 | x = UpSampling2D(size=(2, 2))(x) 90 | the_shape = K.int_shape(orig_3) 91 | shape = (1, the_shape[1], the_shape[2], the_shape[3]) 92 | origReshaped = Reshape(shape)(orig_3) 93 | xReshaped = Reshape(shape)(x) 94 | together = Concatenate(axis=1)([origReshaped, xReshaped]) 95 | x = Unpooling()(together) 96 | 97 | x = Conv2D(128, (5, 5), activation='relu', padding='same', name='deconv3', kernel_initializer='he_normal', 98 | bias_initializer='zeros')(x) 99 | x = BatchNormalization()(x) 100 | x = UpSampling2D(size=(2, 2))(x) 101 | the_shape = K.int_shape(orig_2) 102 | shape = (1, the_shape[1], the_shape[2], the_shape[3]) 103 | origReshaped = Reshape(shape)(orig_2) 104 | xReshaped = Reshape(shape)(x) 105 | together = 
Concatenate(axis=1)([origReshaped, xReshaped]) 106 | x = Unpooling()(together) 107 | 108 | x = Conv2D(64, (5, 5), activation='relu', padding='same', name='deconv2', kernel_initializer='he_normal', 109 | bias_initializer='zeros')(x) 110 | x = BatchNormalization()(x) 111 | x = UpSampling2D(size=(2, 2))(x) 112 | the_shape = K.int_shape(orig_1) 113 | shape = (1, the_shape[1], the_shape[2], the_shape[3]) 114 | origReshaped = Reshape(shape)(orig_1) 115 | xReshaped = Reshape(shape)(x) 116 | together = Concatenate(axis=1)([origReshaped, xReshaped]) 117 | x = Unpooling()(together) 118 | x = Conv2D(64, (5, 5), activation='relu', padding='same', name='deconv1', kernel_initializer='he_normal', 119 | bias_initializer='zeros')(x) 120 | x = BatchNormalization()(x) 121 | 122 | x = Conv2D(1, (5, 5), activation='sigmoid', padding='same', name='pred', kernel_initializer='he_normal', 123 | bias_initializer='zeros')(x) 124 | 125 | model = Model(inputs=input_tensor, outputs=x) 126 | return model 127 | 128 | 129 | def build_refinement(encoder_decoder): 130 | input_tensor = encoder_decoder.input 131 | 132 | input = Lambda(lambda i: i[:, :, :, 0:3])(input_tensor) 133 | 134 | x = Concatenate(axis=3)([input, encoder_decoder.output]) 135 | x = Conv2D(64, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal', 136 | bias_initializer='zeros')(x) 137 | x = BatchNormalization()(x) 138 | x = Conv2D(64, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal', 139 | bias_initializer='zeros')(x) 140 | x = BatchNormalization()(x) 141 | x = Conv2D(64, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal', 142 | bias_initializer='zeros')(x) 143 | x = BatchNormalization()(x) 144 | x = Conv2D(1, (3, 3), activation='sigmoid', padding='same', name='refinement_pred', kernel_initializer='he_normal', 145 | bias_initializer='zeros')(x) 146 | 147 | model = Model(inputs=input_tensor, outputs=x) 148 | return model 149 | 150 | 151 | if __name__ == '__main__': 
152 | with tf.device("/cpu:0"): 153 | encoder_decoder = build_encoder_decoder() 154 | print(encoder_decoder.summary()) 155 | plot_model(encoder_decoder, to_file='encoder_decoder.svg', show_layer_names=True, show_shapes=True) 156 | 157 | with tf.device("/cpu:0"): 158 | refinement = build_refinement(encoder_decoder) 159 | print(refinement.summary()) 160 | plot_model(refinement, to_file='refinement.svg', show_layer_names=True, show_shapes=True) 161 | 162 | parallel_model = multi_gpu_model(refinement, gpus=None) 163 | print(parallel_model.summary()) 164 | plot_model(parallel_model, to_file='parallel_model.svg', show_layer_names=True, show_shapes=True) 165 | 166 | K.clear_session() 167 | -------------------------------------------------------------------------------- /model.svg: -------------------------------------------------------------------------------- 1 | 2 | 4 | 6 | 7 | 9 | 10 | G 11 | 12 | 13 | 140442454693144 14 | 15 | input: InputLayer 16 | 17 | input: 18 | 19 | output: 20 | 21 | (None, 320, 320, 4) 22 | 23 | (None, 320, 320, 4) 24 | 25 | 26 | 140439928904728 27 | 28 | zero_padding2d_1: ZeroPadding2D 29 | 30 | input: 31 | 32 | output: 33 | 34 | (None, 320, 320, 4) 35 | 36 | (None, 322, 322, 4) 37 | 38 | 39 | 140442454693144->140439928904728 40 | 41 | 42 | 43 | 44 | 140442454531040 45 | 46 | conv1_1: Conv2D 47 | 48 | input: 49 | 50 | output: 51 | 52 | (None, 322, 322, 4) 53 | 54 | (None, 320, 320, 64) 55 | 56 | 57 | 140439928904728->140442454531040 58 | 59 | 60 | 61 | 62 | 140439928532048 63 | 64 | zero_padding2d_2: ZeroPadding2D 65 | 66 | input: 67 | 68 | output: 69 | 70 | (None, 320, 320, 64) 71 | 72 | (None, 322, 322, 64) 73 | 74 | 75 | 140442454531040->140439928532048 76 | 77 | 78 | 79 | 80 | 140439928693592 81 | 82 | conv1_2: Conv2D 83 | 84 | input: 85 | 86 | output: 87 | 88 | (None, 322, 322, 64) 89 | 90 | (None, 320, 320, 64) 91 | 92 | 93 | 140439928532048->140439928693592 94 | 95 | 96 | 97 | 98 | 140439928694712 99 | 100 | max_pooling2d_1: MaxPooling2D 
101 | 102 | input: 103 | 104 | output: 105 | 106 | (None, 320, 320, 64) 107 | 108 | (None, 160, 160, 64) 109 | 110 | 111 | 140439928693592->140439928694712 112 | 113 | 114 | 115 | 116 | 140439928296896 117 | 118 | zero_padding2d_3: ZeroPadding2D 119 | 120 | input: 121 | 122 | output: 123 | 124 | (None, 160, 160, 64) 125 | 126 | (None, 162, 162, 64) 127 | 128 | 129 | 140439928694712->140439928296896 130 | 131 | 132 | 133 | 134 | 140439928298296 135 | 136 | conv2_1: Conv2D 137 | 138 | input: 139 | 140 | output: 141 | 142 | (None, 162, 162, 64) 143 | 144 | (None, 160, 160, 128) 145 | 146 | 147 | 140439928296896->140439928298296 148 | 149 | 150 | 151 | 152 | 140439928297120 153 | 154 | zero_padding2d_4: ZeroPadding2D 155 | 156 | input: 157 | 158 | output: 159 | 160 | (None, 160, 160, 128) 161 | 162 | (None, 162, 162, 128) 163 | 164 | 165 | 140439928298296->140439928297120 166 | 167 | 168 | 169 | 170 | 140439928386392 171 | 172 | conv2_2: Conv2D 173 | 174 | input: 175 | 176 | output: 177 | 178 | (None, 162, 162, 128) 179 | 180 | (None, 160, 160, 128) 181 | 182 | 183 | 140439928297120->140439928386392 184 | 185 | 186 | 187 | 188 | 140439928389080 189 | 190 | max_pooling2d_2: MaxPooling2D 191 | 192 | input: 193 | 194 | output: 195 | 196 | (None, 160, 160, 128) 197 | 198 | (None, 80, 80, 128) 199 | 200 | 201 | 140439928386392->140439928389080 202 | 203 | 204 | 205 | 206 | 140439927937624 207 | 208 | zero_padding2d_5: ZeroPadding2D 209 | 210 | input: 211 | 212 | output: 213 | 214 | (None, 80, 80, 128) 215 | 216 | (None, 82, 82, 128) 217 | 218 | 219 | 140439928389080->140439927937624 220 | 221 | 222 | 223 | 224 | 140439927938800 225 | 226 | conv3_1: Conv2D 227 | 228 | input: 229 | 230 | output: 231 | 232 | (None, 82, 82, 128) 233 | 234 | (None, 80, 80, 256) 235 | 236 | 237 | 140439927937624->140439927938800 238 | 239 | 240 | 241 | 242 | 140439928293920 243 | 244 | zero_padding2d_6: ZeroPadding2D 245 | 246 | input: 247 | 248 | output: 249 | 250 | (None, 80, 80, 256) 251 | 252 
| (None, 82, 82, 256) 253 | 254 | 255 | 140439927938800->140439928293920 256 | 257 | 258 | 259 | 260 | 140439928015336 261 | 262 | conv3_2: Conv2D 263 | 264 | input: 265 | 266 | output: 267 | 268 | (None, 82, 82, 256) 269 | 270 | (None, 80, 80, 256) 271 | 272 | 273 | 140439928293920->140439928015336 274 | 275 | 276 | 277 | 278 | 140439928060728 279 | 280 | zero_padding2d_7: ZeroPadding2D 281 | 282 | input: 283 | 284 | output: 285 | 286 | (None, 80, 80, 256) 287 | 288 | (None, 82, 82, 256) 289 | 290 | 291 | 140439928015336->140439928060728 292 | 293 | 294 | 295 | 296 | 140439928098544 297 | 298 | conv3_3: Conv2D 299 | 300 | input: 301 | 302 | output: 303 | 304 | (None, 82, 82, 256) 305 | 306 | (None, 80, 80, 256) 307 | 308 | 309 | 140439928060728->140439928098544 310 | 311 | 312 | 313 | 314 | 140439928142704 315 | 316 | max_pooling2d_3: MaxPooling2D 317 | 318 | input: 319 | 320 | output: 321 | 322 | (None, 80, 80, 256) 323 | 324 | (None, 40, 40, 256) 325 | 326 | 327 | 140439928098544->140439928142704 328 | 329 | 330 | 331 | 332 | 140439928184560 333 | 334 | zero_padding2d_8: ZeroPadding2D 335 | 336 | input: 337 | 338 | output: 339 | 340 | (None, 40, 40, 256) 341 | 342 | (None, 42, 42, 256) 343 | 344 | 345 | 140439928142704->140439928184560 346 | 347 | 348 | 349 | 350 | 140439927704432 351 | 352 | conv4_1: Conv2D 353 | 354 | input: 355 | 356 | output: 357 | 358 | (None, 42, 42, 256) 359 | 360 | (None, 40, 40, 512) 361 | 362 | 363 | 140439928184560->140439927704432 364 | 365 | 366 | 367 | 368 | 140439927702528 369 | 370 | zero_padding2d_9: ZeroPadding2D 371 | 372 | input: 373 | 374 | output: 375 | 376 | (None, 40, 40, 512) 377 | 378 | (None, 42, 42, 512) 379 | 380 | 381 | 140439927704432->140439927702528 382 | 383 | 384 | 385 | 386 | 140439927780016 387 | 388 | conv4_2: Conv2D 389 | 390 | input: 391 | 392 | output: 393 | 394 | (None, 42, 42, 512) 395 | 396 | (None, 40, 40, 512) 397 | 398 | 399 | 140439927702528->140439927780016 400 | 401 | 402 | 403 | 404 | 
140439927783376 405 | 406 | zero_padding2d_10: ZeroPadding2D 407 | 408 | input: 409 | 410 | output: 411 | 412 | (None, 40, 40, 512) 413 | 414 | (None, 42, 42, 512) 415 | 416 | 417 | 140439927780016->140439927783376 418 | 419 | 420 | 421 | 422 | 140439927866088 423 | 424 | conv4_3: Conv2D 425 | 426 | input: 427 | 428 | output: 429 | 430 | (None, 42, 42, 512) 431 | 432 | (None, 40, 40, 512) 433 | 434 | 435 | 140439927783376->140439927866088 436 | 437 | 438 | 439 | 440 | 140439927869336 441 | 442 | max_pooling2d_4: MaxPooling2D 443 | 444 | input: 445 | 446 | output: 447 | 448 | (None, 40, 40, 512) 449 | 450 | (None, 20, 20, 512) 451 | 452 | 453 | 140439927866088->140439927869336 454 | 455 | 456 | 457 | 458 | 140439925318376 459 | 460 | zero_padding2d_11: ZeroPadding2D 461 | 462 | input: 463 | 464 | output: 465 | 466 | (None, 20, 20, 512) 467 | 468 | (None, 22, 22, 512) 469 | 470 | 471 | 140439927869336->140439925318376 472 | 473 | 474 | 475 | 476 | 140439925321624 477 | 478 | conv5_1: Conv2D 479 | 480 | input: 481 | 482 | output: 483 | 484 | (None, 22, 22, 512) 485 | 486 | (None, 20, 20, 512) 487 | 488 | 489 | 140439925318376->140439925321624 490 | 491 | 492 | 493 | 494 | 140439925318600 495 | 496 | zero_padding2d_12: ZeroPadding2D 497 | 498 | input: 499 | 500 | output: 501 | 502 | (None, 20, 20, 512) 503 | 504 | (None, 22, 22, 512) 505 | 506 | 507 | 140439925321624->140439925318600 508 | 509 | 510 | 511 | 512 | 140439925399680 513 | 514 | conv5_2: Conv2D 515 | 516 | input: 517 | 518 | output: 519 | 520 | (None, 22, 22, 512) 521 | 522 | (None, 20, 20, 512) 523 | 524 | 525 | 140439925318600->140439925399680 526 | 527 | 528 | 529 | 530 | 140439925402368 531 | 532 | zero_padding2d_13: ZeroPadding2D 533 | 534 | input: 535 | 536 | output: 537 | 538 | (None, 20, 20, 512) 539 | 540 | (None, 22, 22, 512) 541 | 542 | 543 | 140439925399680->140439925402368 544 | 545 | 546 | 547 | 548 | 140439925487376 549 | 550 | conv5_3: Conv2D 551 | 552 | input: 553 | 554 | output: 555 | 556 
| (None, 22, 22, 512) 557 | 558 | (None, 20, 20, 512) 559 | 560 | 561 | 140439925402368->140439925487376 562 | 563 | 564 | 565 | 566 | 140439925488440 567 | 568 | max_pooling2d_5: MaxPooling2D 569 | 570 | input: 571 | 572 | output: 573 | 574 | (None, 20, 20, 512) 575 | 576 | (None, 10, 10, 512) 577 | 578 | 579 | 140439925487376->140439925488440 580 | 581 | 582 | 583 | 584 | 140439925036872 585 | 586 | deconv6: Conv2D 587 | 588 | input: 589 | 590 | output: 591 | 592 | (None, 10, 10, 512) 593 | 594 | (None, 10, 10, 512) 595 | 596 | 597 | 140439925488440->140439925036872 598 | 599 | 600 | 601 | 602 | 140439925038440 603 | 604 | batch_normalization_1: BatchNormalization 605 | 606 | input: 607 | 608 | output: 609 | 610 | (None, 10, 10, 512) 611 | 612 | (None, 10, 10, 512) 613 | 614 | 615 | 140439925036872->140439925038440 616 | 617 | 618 | 619 | 620 | 140439925124344 621 | 622 | up_sampling2d_1: UpSampling2D 623 | 624 | input: 625 | 626 | output: 627 | 628 | (None, 10, 10, 512) 629 | 630 | (None, 20, 20, 512) 631 | 632 | 633 | 140439925038440->140439925124344 634 | 635 | 636 | 637 | 638 | 140439925122552 639 | 640 | unpooling_1: Unpooling 641 | 642 | input: 643 | 644 | output: 645 | 646 | (None, 20, 20, 512) 647 | 648 | (None, 20, 20, 512) 649 | 650 | 651 | 140439925124344->140439925122552 652 | 653 | 654 | 655 | 656 | 140438117455296 657 | 658 | deconv5: Conv2D 659 | 660 | input: 661 | 662 | output: 663 | 664 | (None, 20, 20, 512) 665 | 666 | (None, 20, 20, 512) 667 | 668 | 669 | 140439925122552->140438117455296 670 | 671 | 672 | 673 | 674 | 140438787056864 675 | 676 | batch_normalization_2: BatchNormalization 677 | 678 | input: 679 | 680 | output: 681 | 682 | (None, 20, 20, 512) 683 | 684 | (None, 20, 20, 512) 685 | 686 | 687 | 140438117455296->140438787056864 688 | 689 | 690 | 691 | 692 | 140438787109944 693 | 694 | up_sampling2d_2: UpSampling2D 695 | 696 | input: 697 | 698 | output: 699 | 700 | (None, 20, 20, 512) 701 | 702 | (None, 40, 40, 512) 703 | 704 | 705 | 
140438787056864->140438787109944 706 | 707 | 708 | 709 | 710 | 140438787081552 711 | 712 | unpooling_2: Unpooling 713 | 714 | input: 715 | 716 | output: 717 | 718 | (None, 40, 40, 512) 719 | 720 | (None, 40, 40, 512) 721 | 722 | 723 | 140438787109944->140438787081552 724 | 725 | 726 | 727 | 728 | 140438787083008 729 | 730 | deconv4: Conv2D 731 | 732 | input: 733 | 734 | output: 735 | 736 | (None, 40, 40, 512) 737 | 738 | (None, 40, 40, 256) 739 | 740 | 741 | 140438787081552->140438787083008 742 | 743 | 744 | 745 | 746 | 140438787321136 747 | 748 | batch_normalization_3: BatchNormalization 749 | 750 | input: 751 | 752 | output: 753 | 754 | (None, 40, 40, 256) 755 | 756 | (None, 40, 40, 256) 757 | 758 | 759 | 140438787083008->140438787321136 760 | 761 | 762 | 763 | 764 | 140438787380224 765 | 766 | up_sampling2d_3: UpSampling2D 767 | 768 | input: 769 | 770 | output: 771 | 772 | (None, 40, 40, 256) 773 | 774 | (None, 80, 80, 256) 775 | 776 | 777 | 140438787321136->140438787380224 778 | 779 | 780 | 781 | 782 | 140438787292800 783 | 784 | unpooling_3: Unpooling 785 | 786 | input: 787 | 788 | output: 789 | 790 | (None, 80, 80, 256) 791 | 792 | (None, 80, 80, 256) 793 | 794 | 795 | 140438787380224->140438787292800 796 | 797 | 798 | 799 | 800 | 140438787289832 801 | 802 | deconv3: Conv2D 803 | 804 | input: 805 | 806 | output: 807 | 808 | (None, 80, 80, 256) 809 | 810 | (None, 80, 80, 128) 811 | 812 | 813 | 140438787292800->140438787289832 814 | 815 | 816 | 817 | 818 | 140437516990504 819 | 820 | batch_normalization_4: BatchNormalization 821 | 822 | input: 823 | 824 | output: 825 | 826 | (None, 80, 80, 128) 827 | 828 | (None, 80, 80, 128) 829 | 830 | 831 | 140438787289832->140437516990504 832 | 833 | 834 | 835 | 836 | 140437516521200 837 | 838 | up_sampling2d_4: UpSampling2D 839 | 840 | input: 841 | 842 | output: 843 | 844 | (None, 80, 80, 128) 845 | 846 | (None, 160, 160, 128) 847 | 848 | 849 | 140437516990504->140437516521200 850 | 851 | 852 | 853 | 854 | 140437516616536 
855 | 856 | unpooling_4: Unpooling 857 | 858 | input: 859 | 860 | output: 861 | 862 | (None, 160, 160, 128) 863 | 864 | (None, 160, 160, 128) 865 | 866 | 867 | 140437516521200->140437516616536 868 | 869 | 870 | 871 | 872 | 140437516618832 873 | 874 | deconv2: Conv2D 875 | 876 | input: 877 | 878 | output: 879 | 880 | (None, 160, 160, 128) 881 | 882 | (None, 160, 160, 64) 883 | 884 | 885 | 140437516616536->140437516618832 886 | 887 | 888 | 889 | 890 | 140437516220008 891 | 892 | batch_normalization_5: BatchNormalization 893 | 894 | input: 895 | 896 | output: 897 | 898 | (None, 160, 160, 64) 899 | 900 | (None, 160, 160, 64) 901 | 902 | 903 | 140437516618832->140437516220008 904 | 905 | 906 | 907 | 908 | 140436915486000 909 | 910 | up_sampling2d_5: UpSampling2D 911 | 912 | input: 913 | 914 | output: 915 | 916 | (None, 160, 160, 64) 917 | 918 | (None, 320, 320, 64) 919 | 920 | 921 | 140437516220008->140436915486000 922 | 923 | 924 | 925 | 926 | 140436915530216 927 | 928 | unpooling_5: Unpooling 929 | 930 | input: 931 | 932 | output: 933 | 934 | (None, 320, 320, 64) 935 | 936 | (None, 320, 320, 64) 937 | 938 | 939 | 140436915486000->140436915530216 940 | 941 | 942 | 943 | 944 | 140436915530160 945 | 946 | deconv1: Conv2D 947 | 948 | input: 949 | 950 | output: 951 | 952 | (None, 320, 320, 64) 953 | 954 | (None, 320, 320, 64) 955 | 956 | 957 | 140436915530216->140436915530160 958 | 959 | 960 | 961 | 962 | 140436914654288 963 | 964 | batch_normalization_6: BatchNormalization 965 | 966 | input: 967 | 968 | output: 969 | 970 | (None, 320, 320, 64) 971 | 972 | (None, 320, 320, 64) 973 | 974 | 975 | 140436915530160->140436914654288 976 | 977 | 978 | 979 | 980 | 140436914715056 981 | 982 | pred: Conv2D 983 | 984 | input: 985 | 986 | output: 987 | 988 | (None, 320, 320, 64) 989 | 990 | (None, 320, 320, 1) 991 | 992 | 993 | 140436914654288->140436914715056 994 | 995 | 996 | 997 | 998 | 999 | -------------------------------------------------------------------------------- 
/parallel_model.svg: -------------------------------------------------------------------------------- 1 | 2 | 4 | 6 | 7 | 9 | 10 | G 11 | 12 | 13 | 139941103165960 14 | 15 | input_1: InputLayer 16 | 17 | input: 18 | 19 | output: 20 | 21 | (None, 320, 320, 4) 22 | 23 | (None, 320, 320, 4) 24 | 25 | 26 | 139939966963272 27 | 28 | lambda_12: Lambda 29 | 30 | input: 31 | 32 | output: 33 | 34 | (None, 320, 320, 4) 35 | 36 | (None, 320, 320, 4) 37 | 38 | 39 | 139941103165960->139939966963272 40 | 41 | 42 | 43 | 44 | 139940906542864 45 | 46 | lambda_23: Lambda 47 | 48 | input: 49 | 50 | output: 51 | 52 | (None, 320, 320, 4) 53 | 54 | (None, 320, 320, 4) 55 | 56 | 57 | 139941103165960->139940906542864 58 | 59 | 60 | 61 | 62 | 139939971312440 63 | 64 | model_2: Model 65 | 66 | input: 67 | 68 | output: 69 | 70 | (None, 320, 320, 4) 71 | 72 | (None, 320, 320, 1) 73 | 74 | 75 | 139939966963272->139939971312440 76 | 77 | 78 | 79 | 80 | 139940906542864->139939971312440 81 | 82 | 83 | 84 | 85 | 139939974831184 86 | 87 | refinement_pred: Concatenate 88 | 89 | input: 90 | 91 | output: 92 | 93 | [(None, 320, 320, 1), (None, 320, 320, 1)] 94 | 95 | (None, 320, 320, 1) 96 | 97 | 98 | 139939971312440->139939974831184 99 | 100 | 101 | 102 | 103 | 104 | -------------------------------------------------------------------------------- /plot_model.py: -------------------------------------------------------------------------------- 1 | # dependency: pip install pydot & brew install graphviz 2 | from model import build_encoder_decoder 3 | from keras.utils import plot_model 4 | 5 | if __name__ == '__main__': 6 | img_rows, img_cols = 320, 320 7 | channel = 3 8 | model = build_encoder_decoder(img_rows, img_cols, channel) 9 | plot_model(model, to_file='model.svg', show_layer_names=True, show_shapes=True) 10 | -------------------------------------------------------------------------------- /pre_process.py: -------------------------------------------------------------------------------- 1 | # -*- 
"""pre_process.py: unpack the downloaded archives and lay out the dataset folders.

Running this script extracts the Adobe matting archive and the background
image archives under ``data/``, moves foregrounds, alpha mattes and the
listed background images into flat per-purpose folders, and finally runs the
test-set compositing step.
"""

import os
import shutil
import tarfile
import zipfile

from Combined_Dataset.Training_set.Composition_code_revised import do_composite
from Combined_Dataset.Test_set.Composition_code_revised import do_composite_test


def _extract_zip(zip_file, dest='data'):
    """Extract ``zip_file`` into ``dest``, closing the archive even on error."""
    print('Extracting {}...'.format(zip_file))
    with zipfile.ZipFile(zip_file, 'r') as zip_ref:
        zip_ref.extractall(dest)


def _extract_tar(tar_file, dest='data'):
    """Extract ``tar_file`` into ``dest``, closing the archive even on error."""
    print('Extracting {}...'.format(tar_file))
    # NOTE(review): extractall trusts member paths; only run on the official archives.
    with tarfile.open(tar_file) as tar:
        tar.extractall(dest)


def _read_names(list_file):
    """Return the lines of ``list_file`` without their line terminators."""
    with open(list_file) as f:
        return f.read().splitlines()


def _move_named(names, src_dir, dest_dir):
    """Move each file in ``names`` from ``src_dir`` to ``dest_dir``."""
    for name in names:
        shutil.move(os.path.join(src_dir, name), os.path.join(dest_dir, name))


def _move_all(src_dirs, dest_dir):
    """Move every entry of each directory in ``src_dirs`` into ``dest_dir``."""
    for src_dir in src_dirs:
        _move_named(os.listdir(src_dir), src_dir, dest_dir)


if __name__ == '__main__':
    # NOTE(review): the flattened source does not preserve indentation; each
    # move step is taken to run only when its target folder is first created.

    # path to provided foreground images
    fg_path = 'data/fg/'
    # path to provided alpha mattes
    a_path = 'data/mask/'
    # Path to background images (MSCOCO)
    bg_path = 'data/bg/'
    # Path to folder where you want the composited images to go
    out_path = 'data/merged/'

    train_folder = 'data/Combined_Dataset/Training_set/'

    # The Adobe archive is extracted on every run (the existence check was
    # disabled in the original script).
    _extract_zip('data/Adobe_Deep_Matting_Dataset.zip')

    if not os.path.exists(bg_path):
        _extract_zip('data/train2014.zip')
        training_bg_names = _read_names(os.path.join(train_folder, 'training_bg_names.txt'))

        os.makedirs(bg_path)
        _move_named(training_bg_names, 'data/train2014', bg_path)

    if not os.path.exists(fg_path):
        os.makedirs(fg_path)
        _move_all([train_folder + 'Adobe-licensed images/fg', train_folder + 'Other/fg'], fg_path)

    if not os.path.exists(a_path):
        os.makedirs(a_path)
        _move_all([train_folder + 'Adobe-licensed images/alpha', train_folder + 'Other/alpha'], a_path)

    if not os.path.exists(out_path):
        os.makedirs(out_path)
    # Training compositing (do_composite) is not run here; the call was
    # disabled in the original script and that behaviour is kept.

    # path to provided foreground images
    fg_test_path = 'data/fg_test/'
    # path to provided alpha mattes
    a_test_path = 'data/mask_test/'
    # Path to background images (PASCAL VOC)
    bg_test_path = 'data/bg_test/'
    # Path to folder where you want the composited images to go
    out_test_path = 'data/merged_test/'

    # test data gen
    test_folder = 'data/Combined_Dataset/Test_set/'

    if not os.path.exists(bg_test_path):
        os.makedirs(bg_test_path)

        _extract_tar('data/VOCtrainval_14-Jul-2008.tar')
        _extract_tar('data/VOC2008test.tar')

        test_bg_names = _read_names(os.path.join(test_folder, 'test_bg_names.txt'))
        _move_named(test_bg_names, 'data/VOCdevkit/VOC2008/JPEGImages', bg_test_path)

    if not os.path.exists(fg_test_path):
        os.makedirs(fg_test_path)
        _move_all([test_folder + 'Adobe-licensed images/fg'], fg_test_path)

    if not os.path.exists(a_test_path):
        os.makedirs(a_test_path)
        _move_all([test_folder + 'Adobe-licensed images/alpha'], a_test_path)

    if not os.path.exists(out_test_path):
        os.makedirs(out_test_path)

    do_composite_test()
# -*- coding: utf-8 -*-
"""predit_single.py: predict an alpha matte for one image by tiling it into patches.

The image and its trimap are stacked into a 4-channel array, cut into
PATCH_SIZE x PATCH_SIZE patches, each patch is run through the refinement
model, and the patch predictions are stitched back together and displayed.
"""
import cv2 as cv
import keras.backend as K
import numpy as np

from model import build_encoder_decoder, build_refinement
from utils import get_final_output, create_patches, patch_dims, assemble_patches
import tensorflow as tf
import time

# Session restricted to one GPU and one CPU device; installed as the Keras session.
config = tf.ConfigProto(device_count={"GPU": 1, "CPU": 1})
sess = tf.Session(config=config)
K.set_session(sess)


def _read_image(path, *flags):
    """Read ``path`` with ``cv.imread`` and fail loudly if it cannot be loaded."""
    image = cv.imread(path, *flags)
    if image is None:
        # cv.imread signals failure by returning None instead of raising.
        raise IOError('Could not read image: {}'.format(path))
    return image


if __name__ == '__main__':
    # load network
    PATCH_SIZE = 320
    PRETRAINED_PATH = 'models/final.42-0.0398.hdf5'
    TRIMAP_PATH = "images/trimap2.png"
    IMG_PATH = "images/frame2.png"

    encoder_decoder = build_encoder_decoder()
    final = build_refinement(encoder_decoder)
    final.load_weights(PRETRAINED_PATH)
    # summary() prints the table itself and returns None, so it is not wrapped in print().
    final.summary()

    # loading input files
    trimap = _read_image(TRIMAP_PATH, cv.IMREAD_GRAYSCALE)
    img = _read_image(IMG_PATH)
    img_rows, img_cols = trimap.shape

    # create patches from the image channels plus the trimap, scaled by 1/255
    x = np.dstack((img, np.expand_dims(trimap, axis=2))) / 255.
    patches = create_patches(x, PATCH_SIZE)

    # create mat for patches predictions
    # np.prod replaces the np.product alias, which NumPy 2.0 removed.
    patches_count = int(np.prod(patch_dims(mat_size=trimap.shape, patch_size=PATCH_SIZE)))
    patches_predictions = np.zeros(shape=(patches_count, PATCH_SIZE, PATCH_SIZE))

    # predicting
    for i in range(patches.shape[0]):
        print("Predicting patches {}/{}".format(i + 1, patches_count))

        patch_prediction = final.predict(np.expand_dims(patches[i, :, :, :], axis=0))
        patches_predictions[i] = np.reshape(patch_prediction, (PATCH_SIZE, PATCH_SIZE)) * 255.

    # assemble, then crop back to the trimap's size
    result = assemble_patches(patches_predictions, trimap.shape, PATCH_SIZE)
    result = result[:img_rows, :img_cols]

    prediction = get_final_output(result, trimap).astype(np.uint8)

    # show the result next to the input until a key is pressed
    cv.imshow("result", prediction)
    cv.imshow("image", img)
    cv.waitKey(0)

    K.clear_session()
"""segnet.py: encoder-decoder matting network plus a small refinement head."""
import keras.backend as K
import tensorflow as tf
from keras.layers import Input, Conv2D, UpSampling2D, BatchNormalization, ZeroPadding2D, MaxPooling2D, Reshape, \
    Concatenate, Lambda
from keras.models import Model
from keras.utils import multi_gpu_model
from keras.utils import plot_model

from custom_layers.unpooling_layer import Unpooling

# (filters, number of convolutions) for encoder stages 1..5.
_ENCODER_STAGES = ((64, 2), (128, 2), (256, 3), (512, 3), (512, 3))
# (filters, number of convolutions) for decoder stages 5..1, in build order.
_DECODER_STAGES = ((512, 3), (256, 3), (128, 3), (64, 2), (64, 2))


def _encoder_stage(x, filters, n_convs, stage, kernel):
    """Apply ``n_convs`` zero-padded ReLU convolutions named ``conv<stage>_<i>``."""
    for idx in range(1, n_convs + 1):
        x = ZeroPadding2D((1, 1))(x)
        x = Conv2D(filters, (kernel, kernel), activation='relu',
                   name='conv{}_{}'.format(stage, idx))(x)
    return x


def _unpool(x, encoder_features):
    """Upsample ``x`` by 2 and combine it with ``encoder_features`` via ``Unpooling``.

    Both tensors are reshaped to gain a leading axis of size 1 and are
    concatenated along that axis (encoder features first) before being handed
    to the project's custom ``Unpooling`` layer.
    """
    x = UpSampling2D(size=(2, 2))(x)
    _, rows, cols, channels = K.int_shape(encoder_features)
    shape = (1, rows, cols, channels)
    # Creation order (encoder reshape, then decoder reshape) is kept so that
    # auto-generated layer names stay the same as before.
    orig_reshaped = Reshape(shape)(encoder_features)
    x_reshaped = Reshape(shape)(x)
    together = Concatenate(axis=1)([orig_reshaped, x_reshaped])
    return Unpooling()(together)


def _decoder_stage(x, filters, n_convs, stage, kernel):
    """Apply ``n_convs`` conv + batch-norm pairs named ``deconv<stage>_<i>``."""
    for idx in range(1, n_convs + 1):
        x = Conv2D(filters, (kernel, kernel), activation='relu', padding='same',
                   name='deconv{}_{}'.format(stage, idx),
                   kernel_initializer='he_normal',
                   bias_initializer='zeros')(x)
        x = BatchNormalization()(x)
    return x


def build_encoder_decoder():
    """Build the encoder-decoder model.

    The input is a (320, 320, 4) tensor.  Five encoder stages of 3x3
    convolutions, each followed by 2x2 max pooling, are mirrored by five
    decoder stages that upsample, unpool against the matching encoder
    features and convolve.  A final sigmoid convolution named ``pred``
    produces a single-channel output.

    Returns:
        A ``keras.models.Model`` mapping the input tensor to the ``pred`` output.
    """
    kernel = 3

    # Encoder
    input_tensor = Input(shape=(320, 320, 4))
    x = input_tensor
    skips = []
    for stage, (filters, n_convs) in enumerate(_ENCODER_STAGES, start=1):
        x = _encoder_stage(x, filters, n_convs, stage, kernel)
        skips.append(x)  # pre-pooling features, reused by the decoder
        x = MaxPooling2D((2, 2), strides=(2, 2))(x)

    # Decoder: walk the encoder stages in reverse (5 -> 1)
    for offset, (filters, n_convs) in enumerate(_DECODER_STAGES):
        stage = len(_DECODER_STAGES) - offset
        x = _unpool(x, skips[stage - 1])
        x = _decoder_stage(x, filters, n_convs, stage, kernel)

    x = Conv2D(1, (kernel, kernel), activation='sigmoid', padding='same', name='pred',
               kernel_initializer='he_normal',
               bias_initializer='zeros')(x)

    return Model(inputs=input_tensor, outputs=x)


def build_refinement(encoder_decoder):
    """Stack the refinement head on top of ``encoder_decoder``.

    The first three channels of the encoder-decoder input are concatenated
    with its output, passed through three 64-filter conv + batch-norm pairs
    and a final sigmoid convolution named ``refinement_pred``.

    Args:
        encoder_decoder: model returned by ``build_encoder_decoder``.

    Returns:
        A ``keras.models.Model`` sharing the encoder-decoder's input.
    """
    input_tensor = encoder_decoder.input

    # Named `image_channels` rather than `input` to avoid shadowing the builtin.
    image_channels = Lambda(lambda i: i[:, :, :, 0:3])(input_tensor)

    x = Concatenate(axis=3)([image_channels, encoder_decoder.output])
    for _ in range(3):
        x = Conv2D(64, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal',
                   bias_initializer='zeros')(x)
        x = BatchNormalization()(x)
    x = Conv2D(1, (3, 3), activation='sigmoid', padding='same', name='refinement_pred',
               kernel_initializer='he_normal',
               bias_initializer='zeros')(x)

    return Model(inputs=input_tensor, outputs=x)


if __name__ == '__main__':
    with tf.device("/cpu:0"):
        encoder_decoder = build_encoder_decoder()
    # summary() prints the table itself and returns None, so it is not wrapped in print().
    encoder_decoder.summary()
    plot_model(encoder_decoder, to_file='encoder_decoder.svg', show_layer_names=True, show_shapes=True)

    with tf.device("/cpu:0"):
        refinement = build_refinement(encoder_decoder)
    refinement.summary()
    plot_model(refinement, to_file='refinement.svg', show_layer_names=True, show_shapes=True)

    parallel_model = multi_gpu_model(refinement, gpus=None)
    parallel_model.summary()
    plot_model(parallel_model, to_file='parallel_model.svg', show_layer_names=True, show_shapes=True)

    K.clear_session()
"""test.py: run the refinement model on a single 320x320 image/trimap pair."""
import argparse

import cv2 as cv
import numpy as np

from model import build_encoder_decoder, build_refinement
from utils import get_final_output


def _read_image(path, expected_hw, *flags):
    """Read ``path`` with ``cv.imread`` and check its height and width.

    Raises:
        IOError: the file could not be read (``cv.imread`` returned None).
        ValueError: the image is not ``expected_hw`` (rows, cols) in size.
    """
    image = cv.imread(path, *flags)
    if image is None:
        raise IOError('Could not read image: {}'.format(path))
    if tuple(image.shape[:2]) != tuple(expected_hw):
        raise ValueError('{} is {}x{}, expected {}x{}'.format(
            path, image.shape[0], image.shape[1], expected_hw[0], expected_hw[1]))
    return image


# python test.py -i "images/image.png" -t "images/trimap.png"
if __name__ == '__main__':
    img_rows, img_cols = 320, 320
    channel = 4

    # Parse arguments first so that --help or a bad flag does not build the model.
    ap = argparse.ArgumentParser()
    ap.add_argument("-i", "--image", help="path to the image file")
    ap.add_argument("-t", "--trimap", help="path to the trimap file")
    args = vars(ap.parse_args())
    image_path = args["image"]
    trimap_path = args["trimap"]

    if image_path is None:
        image_path = 'images/image.jpg'
    if trimap_path is None:
        trimap_path = 'images/trimap.jpg'

    model_weights_path = 'models/final.42-0.0398.hdf5'
    encoder_decoder = build_encoder_decoder()
    final = build_refinement(encoder_decoder)
    final.load_weights(model_weights_path)
    # summary() prints the table itself and returns None, so it is not wrapped in print().
    final.summary()

    print('Start processing image: {}'.format(image_path))

    x_test = np.empty((1, img_rows, img_cols, channel), dtype=np.float32)
    bgr_img = _read_image(image_path, (img_rows, img_cols))
    trimap = _read_image(trimap_path, (img_rows, img_cols), 0)

    # Channels 0-2 carry the image, channel 3 the trimap, both scaled by 1/255.
    x_test[0, :, :, 0:3] = bgr_img / 255.
    x_test[0, :, :, 3] = trimap / 255.

    out = final.predict(x_test)
    out = np.reshape(out, (img_rows, img_cols))
    print(out.shape)
    out = out * 255.0
    out = get_final_output(out, trimap)
    out = out.astype(np.uint8)
    cv.imshow('out', out)
    cv.imwrite('images/out.png', out)
    cv.waitKey(0)
    cv.destroyAllWindows()
"""test_alphamatting.py: tile the alphamatting.com low-res inputs into 320x320 crops."""
import os

import cv2 as cv
import numpy as np

from model import build_encoder_decoder, build_refinement

TILE = 320


def _tile_boxes(height, width, tile=TILE):
    """Yield ``(x, y, w, h)`` boxes covering a ``height`` x ``width`` image.

    Boxes are laid out row by row on a ``tile``-sized grid; boxes on the
    right and bottom edges are clipped to the image.
    """
    # Stepping the range by `tile` gives the same origins as
    # range(ceil(size / tile)) * tile without passing a float to range().
    for y in range(0, height, tile):
        for x in range(0, width, tile):
            yield x, y, min(tile, width - x), min(tile, height - y)


if __name__ == '__main__':
    pretrained_path = 'models/final.42-0.0398.hdf5'
    encoder_decoder = build_encoder_decoder()
    final = build_refinement(encoder_decoder)
    final.load_weights(pretrained_path)
    # summary() prints the table itself and returns None, so it is not wrapped in print().
    final.summary()

    images = [f for f in os.listdir('alphamatting/input_lowres') if f.endswith('.png')]

    for image_name in images:
        filename = os.path.join('alphamatting/input_lowres', image_name)
        im = cv.imread(filename)
        if im is None:
            raise IOError('Could not read image: {}'.format(filename))
        im_h, im_w = im.shape[:2]

        for trimap_id in [1, 2, 3]:
            trimap_name = os.path.join('alphamatting/trimap_lowres/Trimap{}'.format(trimap_id), image_name)
            trimap = cv.imread(trimap_name, 0)
            if trimap is None:
                raise IOError('Could not read trimap: {}'.format(trimap_name))

            for x, y, w, h in _tile_boxes(im_h, im_w):
                im_crop = im[y:y + h, x:x + w]
                tri_crop = trimap[y:y + h, x:x + w]
                # TODO(review): the crops are not consumed anywhere in the
                # visible script; prediction on each tile appears unfinished.
"""train.py: fine-tune the encoder-decoder + refinement model end to end."""
import argparse

import keras
import tensorflow as tf
from keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from keras.utils import multi_gpu_model

from config import patience, batch_size, epochs, num_train_samples, num_valid_samples
from data_generator import train_gen, valid_gen
from migrate import migrate_model
from segnet import build_encoder_decoder, build_refinement
from utils import overall_loss, get_available_cpus, get_available_gpus


class MyCbk(keras.callbacks.Callback):
    """Save a given model at the end of every epoch.

    Used in the multi-GPU case so that the weights are saved through the
    original template model rather than through the multi-GPU wrapper.

    Args:
        model: the model whose ``save`` method is called each epoch.
        checkpoint_dir: path prefix (with trailing separator) for saved files.
    """

    def __init__(self, model, checkpoint_dir='models/'):
        keras.callbacks.Callback.__init__(self)
        self.model_to_save = model
        self.checkpoint_dir = checkpoint_dir

    def on_epoch_end(self, epoch, logs=None):
        # `logs` defaults to None; fall back to NaN rather than failing on a subscript.
        val_loss = (logs or {}).get('val_loss', float('nan'))
        fmt = self.checkpoint_dir + 'final.%02d-%.4f.hdf5'
        # Keras passes a 0-based epoch; ModelCheckpoint names files with
        # epoch + 1, so do the same to keep both code paths consistent.
        self.model_to_save.save(fmt % (epoch + 1, val_loss))


def _build_final_model(pretrained_path):
    """Build the refinement model and initialise its weights.

    Weights are loaded from ``pretrained_path`` when given; otherwise
    ``migrate_model`` is applied to the freshly built model.
    """
    model = build_refinement(build_encoder_decoder())
    if pretrained_path is not None:
        model.load_weights(pretrained_path)
    else:
        migrate_model(model)
    return model


if __name__ == '__main__':
    checkpoint_models_path = 'models/'
    # Parse arguments
    ap = argparse.ArgumentParser()
    ap.add_argument("-p", "--pretrained", help="path to pretrained model weights to load")
    args = vars(ap.parse_args())
    pretrained_path = args["pretrained"]

    # Callbacks
    tensor_board = keras.callbacks.TensorBoard(log_dir='./logs', histogram_freq=0, write_graph=True, write_images=True)
    model_names = checkpoint_models_path + 'final.{epoch:02d}-{val_loss:.4f}.hdf5'
    model_checkpoint = ModelCheckpoint(model_names, monitor='val_loss', verbose=1, save_best_only=True)
    early_stop = EarlyStopping('val_loss', patience=patience)
    reduce_lr = ReduceLROnPlateau('val_loss', factor=0.1, patience=int(patience / 4), verbose=1)

    # Load our model, added support for Multi-GPUs
    num_gpu = len(get_available_gpus())
    if num_gpu >= 2:
        # Build the template model on the CPU, then replicate it across GPUs.
        with tf.device("/cpu:0"):
            model = _build_final_model(pretrained_path)

        final = multi_gpu_model(model, gpus=num_gpu)
        # rewrite the callback: saving through the original model and not the multi-gpu model.
        model_checkpoint = MyCbk(model, checkpoint_models_path)
    else:
        final = _build_final_model(pretrained_path)

    decoder_target = tf.placeholder(dtype='float32', shape=(None, None, None, None))
    final.compile(optimizer='nadam', loss=overall_loss, target_tensors=[decoder_target])

    # summary() prints the table itself and returns None, so it is not wrapped in print().
    final.summary()

    # Final callbacks
    callbacks = [tensor_board, model_checkpoint, early_stop, reduce_lr]

    # Start Fine-tuning
    final.fit_generator(train_gen(),
                        steps_per_epoch=num_train_samples // batch_size,
                        validation_data=valid_gen(),
                        validation_steps=num_valid_samples // batch_size,
                        epochs=epochs,
                        verbose=1,
                        callbacks=callbacks,
                        use_multiprocessing=True,
                        workers=2
                        )
class MyCbk(keras.callbacks.Callback):
    """Save a fixed model to the checkpoint directory after every epoch.

    The model to save is passed in explicitly so that it can differ from the
    model the callback is attached to.
    """

    def __init__(self, model):
        super(MyCbk, self).__init__()
        self.model_to_save = model

    def on_epoch_end(self, epoch, logs=None):
        # Build "<dir>model.<epoch>-<val_loss>.hdf5" and write the model there.
        val_loss = logs['val_loss']
        file_name = 'model.%02d-%.4f.hdf5' % (epoch, val_loss)
        self.model_to_save.save(checkpoint_models_path + file_name)
if __name__ == '__main__':
    checkpoint_models_path = 'models/'
    # Parse arguments
    ap = argparse.ArgumentParser()
    # Fine-tuning always starts from existing weights (load_weights is called
    # unconditionally below), so the option is mandatory: a missing path is
    # reported by argparse instead of failing later inside load_weights(None).
    ap.add_argument("-p", "--pretrained", required=True, help="path to pretrained model weights to load")
    args = vars(ap.parse_args())
    pretrained_path = args["pretrained"]

    # Callbacks
    tensor_board = keras.callbacks.TensorBoard(log_dir='./logs', histogram_freq=0, write_graph=True, write_images=True)
    model_names = checkpoint_models_path + 'final.{epoch:02d}-{val_loss:.4f}.hdf5'
    model_checkpoint = ModelCheckpoint(model_names, monitor='val_loss', verbose=1, save_best_only=True)
    early_stop = EarlyStopping('val_loss', patience=patience)
    reduce_lr = ReduceLROnPlateau('val_loss', factor=0.1, patience=int(patience / 4), verbose=1)


    class MyCbk(keras.callbacks.Callback):
        """Checkpoint callback that saves a given model at the end of every epoch."""

        def __init__(self, model):
            keras.callbacks.Callback.__init__(self)
            # Model whose weights are written out; kept separately because it
            # is not the model that fit_generator() is called on.
            self.model_to_save = model

        def on_epoch_end(self, epoch, logs=None):
            # Use the same 'final.' prefix as the ModelCheckpoint template
            # above, so single- and multi-GPU runs name their checkpoints the
            # same way and do not collide with 'model.*' files.
            fmt = checkpoint_models_path + 'final.%02d-%.4f.hdf5'
            self.model_to_save.save(fmt % (epoch, logs['val_loss']))

    num_gpu = len(get_available_gpus())
    if num_gpu >= 2:
        # The template model is created under the CPU device scope and then
        # replicated across the GPUs by multi_gpu_model().
        with tf.device("/cpu:0"):
            # Load our model, added support for Multi-GPUs
            model = build_encoder_decoder()
            model = build_refinement(model)
            model.load_weights(pretrained_path)

        final = multi_gpu_model(model, gpus=num_gpu)
        # rewrite the callback: saving through the original model and not the multi-gpu model.
        model_checkpoint = MyCbk(model)
    else:
        model = build_encoder_decoder()
        final = build_refinement(model)
        final.load_weights(pretrained_path)

    # finetune the whole network together.
    for layer in final.layers:
        layer.trainable = True

    sgd = keras.optimizers.SGD(lr=1e-5, decay=1e-6, momentum=0.9, nesterov=True)
    # The training target is fed through an explicit placeholder with no fixed dimensions.
    decoder_target = tf.placeholder(dtype='float32', shape=(None, None, None, None))
    final.compile(optimizer=sgd, loss=alpha_prediction_loss, target_tensors=[decoder_target])

    # summary() prints the table itself and returns None; do not wrap it in print().
    final.summary()

    # Summarize then go!
    # Use half of the available CPUs as data-generator workers.
    num_cpu = get_available_cpus()
    workers = int(round(num_cpu / 2))

    # Final callbacks
    callbacks = [tensor_board, model_checkpoint, early_stop, reduce_lr]

    # Start Fine-tuning
    final.fit_generator(train_gen(),
                        steps_per_epoch=num_train_samples // batch_size,
                        validation_data=valid_gen(),
                        validation_steps=num_valid_samples // batch_size,
                        epochs=epochs,
                        verbose=1,
                        callbacks=callbacks,
                        use_multiprocessing=True,
                        workers=workers
                        )
class TestStringMethods(unittest.TestCase):
    """Smoke tests for trimap generation, random crop selection and cropping.

    The tests read sample images from the ``fg``, ``mask`` and ``merged``
    folders and write their results into ``temp`` for visual inspection.
    """

    def test_generate_trimap(self):
        fg_image = cv.imread('fg/1-1252426161dfXY.jpg')
        alpha = cv.imread('mask/1-1252426161dfXY.jpg', 0)
        trimap = generate_trimap(alpha)
        self.assertEqual(trimap.shape, (615, 410))

        # Count the unknown pixels one by one and check that np.where finds
        # exactly the same number.
        rows, cols = trimap.shape[:2]
        brute_force_count = sum(1
                                for r in range(rows)
                                for c in range(cols)
                                if trimap[r, c] == unknown_code)
        first_axis_idx, second_axis_idx = np.where(trimap == unknown_code)
        num_unknowns = len(first_axis_idx)
        self.assertEqual(brute_force_count, num_unknowns)

        # Any index pair returned by np.where must address an unknown pixel.
        pick = random.choice(range(num_unknowns))
        self.assertEqual(trimap[first_axis_idx[pick], second_axis_idx[pick]], unknown_code)

        # Crop all three images at the same randomly chosen position.
        x, y = random_choice(trimap)
        cropped_image = safe_crop(fg_image, x, y)
        cropped_trimap = safe_crop(trimap, x, y)
        cropped_alpha = safe_crop(alpha, x, y)
        cv.imwrite('temp/test_generate_trimap_image.png', cropped_image)
        cv.imwrite('temp/test_generate_trimap_trimap.png', cropped_trimap)
        cv.imwrite('temp/test_generate_trimap_alpha.png', cropped_alpha)

    def test_flip(self):
        fg_image = cv.imread('fg/1-1252426161dfXY.jpg')
        alpha = cv.imread('mask/1-1252426161dfXY.jpg', 0)
        trimap = generate_trimap(alpha)
        x, y = random_choice(trimap)

        # Crop first, then mirror each crop left-to-right.
        flipped_image = np.fliplr(safe_crop(fg_image, x, y))
        flipped_trimap = np.fliplr(safe_crop(trimap, x, y))
        flipped_alpha = np.fliplr(safe_crop(alpha, x, y))
        cv.imwrite('temp/test_flip_image.png', flipped_image)
        cv.imwrite('temp/test_flip_trimap.png', flipped_trimap)
        cv.imwrite('temp/test_flip_alpha.png', flipped_alpha)

    def test_different_sizes(self):
        # (320, 320) is listed three times so it is drawn three times as
        # often as each of the other two sizes.
        candidates = [(320, 320)] * 3 + [(480, 480), (640, 640)]
        crop_size = random.choice(candidates)

    def test_resize(self):
        name = '0_0.png'
        merged = cv.imread(os.path.join('merged', name))
        bg_h, bg_w = merged.shape[:2]

        # Paste the alpha matte into the top-left corner of a zero canvas
        # that has the size of the merged image.
        matte = get_alpha(name)
        matte_h, matte_w = matte.shape[:2]
        alpha = np.zeros((bg_h, bg_w), np.float32)
        alpha[0:matte_h, 0:matte_w] = matte
        trimap = generate_trimap(alpha)

        # crop sizes 320:640:480 = 3:1:1
        crop_size = (480, 480)
        x, y = random_choice(trimap, crop_size)
        cropped_image = safe_crop(merged, x, y, crop_size)
        cropped_trimap = safe_crop(trimap, x, y, crop_size)
        cropped_alpha = safe_crop(alpha, x, y, crop_size)
        cv.imwrite('temp/test_resize_image.png', cropped_image)
        cv.imwrite('temp/test_resize_trimap.png', cropped_trimap)
        cv.imwrite('temp/test_resize_alpha.png', cropped_alpha)
# compositional loss: the absolute difference between the ground truth RGB colors and the predicted
# RGB colors composited by the ground truth foreground, the ground truth background and the predicted
# alpha mattes.
#
# y_true: 4-D tensor; channel 1 is the per-pixel weight mask, channels 2:5 the composited image,
#         channels 5:8 the foreground and channels 8:11 the background.
# y_pred: predicted alpha matte (assumed to be (batch, rows, cols, 1) so it broadcasts over RGB
#         -- TODO confirm against the model output).
# returns: scalar tensor, the smoothed absolute color error summed over all pixels and divided by
#          the sum of the mask.
def compositional_loss(y_true, y_pred):
    # Append a channel axis so the mask broadcasts over the three color channels. expand_dims
    # keeps whatever spatial size the batch has; reshaping to (-1, img_rows, img_cols, 1) would
    # fail -- or silently fold pixels into the batch axis -- for any other crop size.
    mask = K.expand_dims(y_true[:, :, :, 1], axis=-1)
    image = y_true[:, :, :, 2:5]
    fg = y_true[:, :, :, 5:8]
    bg = y_true[:, :, :, 8:11]
    c_g = image
    # Composite the prediction: alpha * foreground + (1 - alpha) * background.
    c_p = y_pred * fg + (1.0 - y_pred) * bg
    diff = (c_p - c_g) * mask
    num_pixels = K.sum(mask)
    # sqrt(x^2 + eps^2) is a differentiable stand-in for |x|; epsilon also guards the division
    # against an all-zero mask.
    return K.sum(K.sqrt(K.square(diff) + epsilon_sqr)) / (num_pixels + epsilon)
def patch_dims(mat_size, patch_size):
    """Return the number of patches needed along each axis of ``mat_size``.

    Each extent is divided by ``patch_size`` and rounded up, so a partially
    filled patch at the border counts as a full one. The result is an
    integer NumPy array with one entry per axis.
    """
    return np.ceil(np.array(mat_size) / patch_size).astype(int)


def create_patches(mat, patch_size):
    """Cut a 4-channel image into square, zero-padded patches.

    Args:
        mat: array of shape (height, width, 4) holding R, G, B and trimap.
        patch_size: side length of each square patch.

    Returns:
        float32 array of shape (count, patch_size, patch_size, 4). Patches
        are stored column by column: the patch in grid row ``y`` and grid
        column ``x`` has index ``y + x * grid_rows``. Patches that stick out
        over the bottom or right border are zero-padded.

    Raises:
        AssertionError: if ``mat`` is not 3-D or does not have 4 channels.
    """
    mat_size = mat.shape
    assert len(mat_size) == 3, "Input mat need to be a 3-D array (height, width, channels)"
    assert mat_size[-1] == 4, "Input mat need to have 4 channels (R, G, B, trimap)"

    patches_dim = patch_dims(mat_size=mat_size[:2], patch_size=patch_size)
    # np.prod instead of np.product: the latter alias was removed in NumPy 2.0.
    patches_count = int(np.prod(patches_dim))

    patches = np.zeros(shape=(patches_count, patch_size, patch_size, 4), dtype=np.float32)
    for y in range(patches_dim[0]):
        y_start = y * patch_size
        for x in range(patches_dim[1]):
            x_start = x * patch_size

            # extract patch from input mat (slicing clips at the image border)
            single_patch = mat[y_start: y_start + patch_size, x_start: x_start + patch_size, :]

            # copy into the top-left corner; the rest of the patch stays zero
            real_patch_h, real_patch_w = single_patch.shape[:2]
            patch_id = y + x * patches_dim[0]
            patches[patch_id, :real_patch_h, :real_patch_w, :] = single_patch

    return patches
# Crop a crop_size window out of mat with its top-left corner at column x, row y.
#
# mat: 2-D (single channel) or 3-D (three channel) image array.
# x, y: column and row of the top-left corner of the crop.
# crop_size: (height, width) of the window to cut out.
# returns: float32 array; the part of the window that lies outside mat is zero-filled. When
#          crop_size differs from (img_rows, img_cols) the crop is resized to img_rows x img_cols
#          with nearest-neighbour interpolation.
def safe_crop(mat, x, y, crop_size=(img_rows, img_cols)):
    crop_height, crop_width = crop_size
    if mat.ndim == 2:
        ret = np.zeros((crop_height, crop_width), np.float32)
    else:
        ret = np.zeros((crop_height, crop_width, 3), np.float32)
    # Slicing clips at the array border, so the crop can be smaller than requested.
    crop = mat[y:y + crop_height, x:x + crop_width]
    h, w = crop.shape[:2]
    ret[0:h, 0:w] = crop
    if crop_size != (img_rows, img_cols):
        # cv.resize expects dsize as (width, height), i.e. (cols, rows).
        ret = cv.resize(ret, dsize=(img_cols, img_rows), interpolation=cv.INTER_NEAREST)
    return ret
if __name__ == '__main__':
    # Build the 224x224 RGB VGG16 model and show its layer table.
    model = vgg16_model(224, 224, 3)
    # summary() prints the table itself and returns None, so wrapping it in
    # print() only adds a stray "None" line.
    model.summary()

    # Drop the backend session once the model is no longer needed.
    K.clear_session()